From e4131c4f2adbe3030f86cd0b3588cfb69c54303f Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <f.sauerburger@cern.ch> Date: Mon, 10 May 2021 15:41:28 +0200 Subject: [PATCH] Add optional column creators to Cut --- .gitlab-ci.yml | 2 +- ci/doctest.sh | 3 +++ freeforestml/cut.py | 22 +++++++++++++++++++--- freeforestml/tests/test_cut.py | 23 +++++++++++++++++++++++ 4 files changed, 46 insertions(+), 4 deletions(-) create mode 100755 ci/doctest.sh diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 08c39a0..b77ffe5 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,7 +9,7 @@ doctest: image: python:3.7 script: - pip install -r requirements.txt - - "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')" + - ci/doctest.sh unittest: stage: test diff --git a/ci/doctest.sh b/ci/doctest.sh new file mode 100755 index 0000000..6954215 --- /dev/null +++ b/ci/doctest.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +python3 -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py') diff --git a/freeforestml/cut.py b/freeforestml/cut.py index cf4e1e4..bb53dd3 100644 --- a/freeforestml/cut.py +++ b/freeforestml/cut.py @@ -8,7 +8,7 @@ class Cut: quantities. Cuts store the condition to be applied to a dataframe. New cut objects - accept all event by default. The selection can be limited by passing a + accept all events by default. The selection can be limited by passing a lambda to the constructor. >>> sel_all = Cut() @@ -65,9 +65,20 @@ class Cut: >>> sel_sr = Cut(lambda df: df.is_sr == 1, label="Signal Region") >>> sel_sr.label 'Signal Region' + + If the application of a cut requires to change the event weights by a so + called scale factors, you can pass additional optional keyword arguments + that specify how the new weight should be computed. + + >>> sel_sample = Cut(lambda df: df.value % 2 == 0, \ + weight=lambda df: df.weight * 2) + + The argument name 'weight' in this example is arbitrary. It is even + possible to add new columns to the returned dataframe in this way, + however, this is not recommended. """ - def __init__(self, func=None, label=None): + def __init__(self, func=None, label=None, **columns): """ Creates a new cut. The optional func argument is called with the dataframe upon evaluation. The function must return an index array. If @@ -77,16 +88,21 @@ class Cut: if isinstance(func, Cut): self.func = func.func self.label = label or func.label + self.columns = columns or func.columns else: self.func = func self.label = label + self.columns = columns def __call__(self, dataframe): """ Applies the internally stored cut to the given dataframe and returns a new dataframe containing only entries passing the event selection. """ - return dataframe[self.idx_array(dataframe)] + new_df = dataframe[self.idx_array(dataframe)] + if self.columns: + new_df = new_df.assign(**self.columns) + return new_df def idx_array(self, dataframe): """ diff --git a/freeforestml/tests/test_cut.py b/freeforestml/tests/test_cut.py index bff8f78..1d0f3c7 100644 --- a/freeforestml/tests/test_cut.py +++ b/freeforestml/tests/test_cut.py @@ -311,3 +311,26 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 10) self.assertEqual(list(high_sale(self.df).year), []) + + def test_assign_columns(self): + """ + Check that passing a keyword argument overwrites an existing column. + """ + alternate = Cut(lambda df: df.year % 2 == 0, + sale=lambda df: df.sale * 2) + + df_alt = alternate(self.df) + self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016]) + self.assertEqual(list(df_alt.sale), [7.8, 9.4, 15.0, 4.6]) + + + def test_assign_new_columns(self): + """ + Check that passing a keyword argument creates a new columns + """ + alternate = Cut(lambda df: df.year % 2 == 0, + weight=lambda df: df.year * 0 + 2) + df_alt = alternate(self.df) + self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016]) + self.assertEqual(list(df_alt.sale), [3.9, 4.7, 7.5, 2.3]) + self.assertEqual(list(df_alt.weight), [2, 2, 2, 2]) -- GitLab