From a464fccecdfd8c948af9da6098a094eacdf9a401 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Wed, 3 Jul 2019 16:07:12 +0200
Subject: [PATCH] Add name attribute to cuts

---
 nnfwtbn/cut.py            | 11 ++++++++++-
 nnfwtbn/tests/test_cut.py |  8 ++++++++
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py
index 74227e4..22d0704 100644
--- a/nnfwtbn/cut.py
+++ b/nnfwtbn/cut.py
@@ -57,9 +57,17 @@ class Cut:
     >>> sel_pos_even_lambda(df)
        value
     4      4
+
+    Cuts might be named by passing the 'name' argument to the constructor.
+    Cut names can be used during plotting as labels to specify the plotted
+    region.
+
+    >>> sel_sr = Cut(lambda df: df.is_sr == 1, name="Signal Region")
+    >>> sel_sr.name
+    'Signal Region'
     """
 
-    def __init__(self, func=None):
+    def __init__(self, func=None, name=None):
         """
         Creates a new cut. The optional func argument is called with the
         dataframe upon evaluation. The function must return an index array. If
@@ -67,6 +75,7 @@ class Cut:
         accepted by this cut.
         """
         self.func = func
+        self.name = name
 
     def __call__(self, dataframe):
         """
diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py
index 2c22f2f..3c8fbdb 100644
--- a/nnfwtbn/tests/test_cut.py
+++ b/nnfwtbn/tests/test_cut.py
@@ -253,3 +253,11 @@ class CutTestCase(unittest.TestCase):
         self.assertEqual(list(high_sale_years.sale), [4.7, 5.6, 7.5, 4.2])
         self.assertEqual(list(high_sale_years.year), [2012, 2013, 2014, 2017])
 
+    def test_name(self):
+        """
+        Check that names specified during construction are available via the
+        'name' attribute.
+        """
+        high_sale = Cut(lambda df: df.sale > 4, name="High sales volume")
+        self.assertEqual(high_sale.name, "High sales volume")
+
-- 
GitLab