diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 08c39a0f4ce68c896d2b2ffe26924b0645beebeb..b77ffe53655e953c3c3fbf77dc242aa8900683c8 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,7 +9,7 @@ doctest: image: python:3.7 script: - pip install -r requirements.txt - - "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')" + - ci/doctest.sh unittest: stage: test diff --git a/ci/doctest.sh b/ci/doctest.sh new file mode 100755 index 0000000000000000000000000000000000000000..69542152e76c038ee2a5765221cefa8ad4e2b0a4 --- /dev/null +++ b/ci/doctest.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python3 -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py') diff --git a/freeforestml/cut.py b/freeforestml/cut.py index cf4e1e45c271627a878ceba8acc1594e01e360fd..bb53dd39b121a85377129a82c7d45a63f19499ad 100644 --- a/freeforestml/cut.py +++ b/freeforestml/cut.py @@ -8,7 +8,7 @@ class Cut: quantities. Cuts store the condition to be applied to a dataframe. New cut objects - accept all event by default. The selection can be limited by passing a + accept all events by default. The selection can be limited by passing a lambda to the constructor. >>> sel_all = Cut() @@ -65,9 +65,20 @@ class Cut: >>> sel_sr = Cut(lambda df: df.is_sr == 1, label="Signal Region") >>> sel_sr.label 'Signal Region' + + If the application of a cut requires to change the event weights by a so + called scale factors, you can pass additional optional keyword arguments + that specify how the new weight should be computed. + + >>> sel_sample = Cut(lambda df: df.value % 2 == 0, \ + weight=lambda df: df.weight * 2) + + The argument name 'weight' in this example is arbitrary. It is even + possible to add new columns to the returned dataframe in this way, + however, this is not recommended. """ - def __init__(self, func=None, label=None): + def __init__(self, func=None, label=None, **columns): """ Creates a new cut. The optional func argument is called with the dataframe upon evaluation. The function must return an index array. If @@ -77,16 +88,21 @@ class Cut: if isinstance(func, Cut): self.func = func.func self.label = label or func.label + self.columns = columns or func.columns else: self.func = func self.label = label + self.columns = columns def __call__(self, dataframe): """ Applies the internally stored cut to the given dataframe and returns a new dataframe containing only entries passing the event selection. """ - return dataframe[self.idx_array(dataframe)] + new_df = dataframe[self.idx_array(dataframe)] + if self.columns: + new_df = new_df.assign(**self.columns) + return new_df def idx_array(self, dataframe): """ diff --git a/freeforestml/tests/test_cut.py b/freeforestml/tests/test_cut.py index bff8f785a0ea865ba222db397d29da94ad733b57..1d0f3c7eae3426b4f24e97f107cb0466c12190bc 100644 --- a/freeforestml/tests/test_cut.py +++ b/freeforestml/tests/test_cut.py @@ -311,3 +311,26 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 10) self.assertEqual(list(high_sale(self.df).year), []) + + def test_assign_columns(self): + """ + Check that passing a keyword argument overwrites an existing column. + """ + alternate = Cut(lambda df: df.year % 2 == 0, + sale=lambda df: df.sale * 2) + + df_alt = alternate(self.df) + self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016]) + self.assertEqual(list(df_alt.sale), [7.8, 9.4, 15.0, 4.6]) + + + def test_assign_new_columns(self): + """ + Check that passing a keyword argument creates a new columns + """ + alternate = Cut(lambda df: df.year % 2 == 0, + weight=lambda df: df.year * 0 + 2) + df_alt = alternate(self.df) + self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016]) + self.assertEqual(list(df_alt.sale), [3.9, 4.7, 7.5, 2.3]) + self.assertEqual(list(df_alt.weight), [2, 2, 2, 2])