diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py index 74227e4a6464fa7ad16422b45326c825c7467044..22d070431748d8261f129cd90dea16355826122b 100644 --- a/nnfwtbn/cut.py +++ b/nnfwtbn/cut.py @@ -57,9 +57,17 @@ class Cut: >>> sel_pos_even_lambda(df) value 4 4 + + Cuts might be named by passing the 'name' argument to the constructor. + Cut names can be used during plotting as labels to specify the plotted + region. + + >>> sel_sr = Cut(lambda df: df.is_sr == 1, name="Signal Region") + >>> sel_sr.name + 'Signal Region' """ - def __init__(self, func=None): + def __init__(self, func=None, name=None): """ Creates a new cut. The optional func argument is called with the dataframe upon evaluation. The function must return an index array. If @@ -67,6 +75,7 @@ class Cut: accepted by this cut. """ self.func = func + self.name = name def __call__(self, dataframe): """ diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py index 2c22f2fcdf4df5d898543d48e830369e3bf5fc24..3c8fbdb3c643d226943d755ef8b6d7c1e95c5be5 100644 --- a/nnfwtbn/tests/test_cut.py +++ b/nnfwtbn/tests/test_cut.py @@ -253,3 +253,11 @@ class CutTestCase(unittest.TestCase): self.assertEqual(list(high_sale_years.sale), [4.7, 5.6, 7.5, 4.2]) self.assertEqual(list(high_sale_years.year), [2012, 2013, 2014, 2017]) + def test_name(self): + """ + Check that names specified during construction are available via the + 'name' attribute. + """ + high_sale = Cut(lambda df: df.sale > 4, name="High sales volume") + self.assertEqual(high_sale.name, "High sales volume") +