From a464fccecdfd8c948af9da6098a094eacdf9a401 Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <f.sauerburger@cern.ch> Date: Wed, 3 Jul 2019 16:07:12 +0200 Subject: [PATCH] Add name attribute to cuts --- nnfwtbn/cut.py | 11 ++++++++++- nnfwtbn/tests/test_cut.py | 8 ++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py index 74227e4..22d0704 100644 --- a/nnfwtbn/cut.py +++ b/nnfwtbn/cut.py @@ -57,9 +57,17 @@ class Cut: >>> sel_pos_even_lambda(df) value 4 4 + + Cuts might be named by passing the 'name' argument to the constructor. + Cut names can be used during plotting as labels to specify the plotted + region. + + >>> sel_sr = Cut(lambda df: df.is_sr == 1, name="Signal Region") + >>> sel_sr.name + 'Signal Region' """ - def __init__(self, func=None): + def __init__(self, func=None, name=None): """ Creates a new cut. The optional func argument is called with the dataframe upon evaluation. The function must return an index array. If @@ -67,6 +75,7 @@ class Cut: accepted by this cut. """ self.func = func + self.name = name def __call__(self, dataframe): """ diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py index 2c22f2f..3c8fbdb 100644 --- a/nnfwtbn/tests/test_cut.py +++ b/nnfwtbn/tests/test_cut.py @@ -253,3 +253,11 @@ class CutTestCase(unittest.TestCase): self.assertEqual(list(high_sale_years.sale), [4.7, 5.6, 7.5, 4.2]) self.assertEqual(list(high_sale_years.year), [2012, 2013, 2014, 2017]) + def test_name(self): + """ + Check that names specified during construction are available via the + 'name' attribute. + """ + high_sale = Cut(lambda df: df.sale > 4, name="High sales volume") + self.assertEqual(high_sale.name, "High sales volume") + -- GitLab