diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py index 21fb501def084c7d32cb79a94c0c7521706fab7a..cf4e1e45c271627a878ceba8acc1594e01e360fd 100644 --- a/nnfwtbn/cut.py +++ b/nnfwtbn/cut.py @@ -74,8 +74,12 @@ class Cut: the optional function is omitted, Every row in the dataframe is accepted by this cut. """ - self.func = func - self.label = label + if isinstance(func, Cut): + self.func = func.func + self.label = label or func.label + else: + self.func = func + self.label = label def __call__(self, dataframe): """ diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py index f9759bcd270324ef39d2a773746a4904b2a02063..4051542652931bb112c502bcc1adcaa7930e71ba 100644 --- a/nnfwtbn/tests/test_cut.py +++ b/nnfwtbn/tests/test_cut.py @@ -261,3 +261,30 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4, label="High sales volume") self.assertEqual(high_sale.label, "High sales volume") + def test_init_cut(self): + """ + Check that a cut can be passed to the constructor. + """ + high_sale = Cut(lambda df: df.sale > 4) + high_sale2 = Cut(high_sale) + + self.assertEqual(len(high_sale2(self.df)), 4) + self.assertEqual(len(high_sale2.idx_array(self.df)), 8) + + def test_init_cut_name_inherit(self): + """ + Check that the name of a cut passed to the constructor is inherited. + """ + high_sale = Cut(lambda df: df.sale > 4, label="High sales volume") + high_sale2 = Cut(high_sale) + + self.assertEqual(high_sale2.label, "High sales volume") + + def test_init_cut_name_inherit_precedence(self): + """ + Check that the name argument has precedence over the given cut. + """ + high_sale = Cut(lambda df: df.sale > 4, label="High sales volume") + high_sale2 = Cut(high_sale, label="Other label") + + self.assertEqual(high_sale2.label, "Other label")