From 24550d781a270c41b5d18e914167b7e6e1361b9e Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Wed, 3 Jul 2019 16:50:13 +0200
Subject: [PATCH] Add Cut copy constructor

---
 nnfwtbn/cut.py            |  8 ++++++--
 nnfwtbn/tests/test_cut.py | 27 +++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 2 deletions(-)

diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py
index 21fb501..cf4e1e4 100644
--- a/nnfwtbn/cut.py
+++ b/nnfwtbn/cut.py
@@ -74,8 +74,12 @@ class Cut:
         the optional function is omitted, Every row in the dataframe is
         accepted by this cut.
         """
-        self.func = func
-        self.label = label
+        if isinstance(func, Cut):
+            self.func = func.func
+            self.label = label or func.label
+        else:
+            self.func = func
+            self.label = label
 
     def __call__(self, dataframe):
         """
diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py
index f9759bc..4051542 100644
--- a/nnfwtbn/tests/test_cut.py
+++ b/nnfwtbn/tests/test_cut.py
@@ -261,3 +261,30 @@ class CutTestCase(unittest.TestCase):
         high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
         self.assertEqual(high_sale.label, "High sales volume")
 
+    def test_init_cut(self):
+        """
+        Check that a cut can be passed to the constructor.
+        """
+        high_sale = Cut(lambda df: df.sale > 4)
+        high_sale2 = Cut(high_sale)
+
+        self.assertEqual(len(high_sale2(self.df)), 4)
+        self.assertEqual(len(high_sale2.idx_array(self.df)), 8)
+        
+    def test_init_cut_name_inherit(self):
+        """
+        Check that the name of a cut passed to the constructor is inherited.
+        """
+        high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
+        high_sale2 = Cut(high_sale)
+
+        self.assertEqual(high_sale2.label, "High sales volume")
+        
+    def test_init_cut_name_inherit_precedence(self):
+        """
+        Check that the name argument has precedence over the given cut.
+        """
+        high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
+        high_sale2 = Cut(high_sale, label="Other label")
+
+        self.assertEqual(high_sale2.label, "Other label")
-- 
GitLab