diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 5e7292ff531f5675e897dd228387363fb61b02e2..e903defad5fddbb5783298abea89778b084ddbbd 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,7 +10,7 @@ doctest: script: - pip install -r requirements.txt - pip install -r dev-requirements.txt - - "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')" + - ci/doctest.sh unittest: stage: test diff --git a/ci/doctest.sh b/ci/doctest.sh new file mode 100755 index 0000000000000000000000000000000000000000..69542152e76c038ee2a5765221cefa8ad4e2b0a4 --- /dev/null +++ b/ci/doctest.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python3 -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py') diff --git a/freeforestml/cut.py b/freeforestml/cut.py index b34114d823ba62bd42ca1a28a3883f8433be11ed..982c69c9b40cf1e1798b752e93ec4d45fc2cfe22 100644 --- a/freeforestml/cut.py +++ b/freeforestml/cut.py @@ -8,7 +8,7 @@ class Cut: quantities. Cuts store the condition to be applied to a dataframe. New cut objects - accept all event by default. The selection can be limited by passing a + accept all events by default. The selection can be limited by passing a lambda to the constructor. >>> sel_all = Cut() @@ -65,9 +65,20 @@ class Cut: >>> sel_sr = Cut(lambda df: df.is_sr == 1, label="Signal Region") >>> sel_sr.label 'Signal Region' + + If the application of a cut requires changing the event weights by + so-called scale factors, you can pass additional optional keyword arguments + that specify how the new weight should be computed. + + >>> sel_sample = Cut(lambda df: df.value % 2 == 0, \ + weight=lambda df: df.weight * 2) + + The argument name 'weight' in this example is arbitrary. It is even + possible to add new columns to the returned dataframe in this way; + however, this is not recommended. """ - def __init__(self, func=None, label=None): + def __init__(self, func=None, label=None, **columns): """ Creates a new cut. 
The optional func argument is called with the dataframe upon evaluation. The function must return an index array. If @@ -77,9 +88,16 @@ class Cut: if isinstance(func, Cut): self.func = func.func self.label = label or func.label + self.columns = columns or func.columns else: self.func = func self.label = label + self.columns = columns + + @staticmethod + def _is_pyspark(dataframe): + """Return true if assumed to be a PySpark Dataframe""" + return hasattr(dataframe, "_sc") def __call__(self, dataframe): """ @@ -89,10 +107,26 @@ class Cut: if self.func is None: return dataframe - if hasattr(dataframe, "_sc"): - return dataframe.filter(self.func(dataframe)) - - return dataframe[self.idx_array(dataframe)] + if self._is_pyspark(dataframe): + # Is PySpark + new_df = dataframe.filter(self.func(dataframe)) + if self.columns: + add_columns = [ + col(new_df).alias(name) + for name, col in self.columns.items() + ] + + existing = [ + n for n in new_df.columns if n not in self.columns + ] + new_df = new_df.select(*existing, *add_columns) + + return new_df + + new_df = dataframe[self.idx_array(dataframe)] + if self.columns: + new_df = new_df.assign(**self.columns) + return new_df def idx_array(self, dataframe): """ diff --git a/freeforestml/tests/test_cut.py b/freeforestml/tests/test_cut.py index bff8f785a0ea865ba222db397d29da94ad733b57..1d0f3c7eae3426b4f24e97f107cb0466c12190bc 100644 --- a/freeforestml/tests/test_cut.py +++ b/freeforestml/tests/test_cut.py @@ -311,3 +311,26 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 10) self.assertEqual(list(high_sale(self.df).year), []) + + def test_assign_columns(self): + """ + Check that passing a keyword argument overwrites an existing column. 
+ """ + alternate = Cut(lambda df: df.year % 2 == 0, + sale=lambda df: df.sale * 2) + + df_alt = alternate(self.df) + self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016]) + self.assertEqual(list(df_alt.sale), [7.8, 9.4, 15.0, 4.6]) + + + def test_assign_new_columns(self): + """ + Check that passing a keyword argument creates a new column. + """ + alternate = Cut(lambda df: df.year % 2 == 0, + weight=lambda df: df.year * 0 + 2) + df_alt = alternate(self.df) + self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016]) + self.assertEqual(list(df_alt.sale), [3.9, 4.7, 7.5, 2.3]) + self.assertEqual(list(df_alt.weight), [2, 2, 2, 2]) diff --git a/freeforestml/tests/test_spark.py b/freeforestml/tests/test_spark.py index 69463289de39dc4c0716ed39f029c9aaf0f6c833..61ff2552347812be0d469d45f493175b5b6f6bfa 100644 --- a/freeforestml/tests/test_spark.py +++ b/freeforestml/tests/test_spark.py @@ -54,3 +54,21 @@ class TestparkCut(unittest.TestCase): selected = cut(df).count() self.assertEqual(selected, 100) + + def test_assign(self): + """Check that a cut can assign a column""" + rdd = self.sc.parallelize([(x, x**2) for x in range(100)]) + df = rdd.toDF(["x", "y"]) + + cut_new = ff.Cut(lambda d: d.x < 5, weight=lambda d: d.x * 2) + sum_new = cut_new(df).agg({'weight': 'sum'}).collect() + self.assertEqual(sum_new[0][0], 2 + 4 + 6 + 8) + + def test_overwrite(self): + """Check that a cut can overwrite a column""" + rdd = self.sc.parallelize([(x, x**2) for x in range(100)]) + df = rdd.toDF(["x", "y"]) + + cut_overwrite = ff.Cut(lambda d: d.x < 5, x=lambda d: d.x * 2) + sum_overwrite = cut_overwrite(df).agg({'x': 'sum'}).collect() + self.assertEqual(sum_overwrite[0][0], 2 + 4 + 6 + 8)