From a8f5fca14c99618940cef69d5d9f74d8387bc173 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Wed, 3 Jul 2019 15:36:26 +0200
Subject: [PATCH] Make Cut.call return the selected dataframe

This change is not compatible with previous versions and breaks existing code.
When a cut is called with a dataframe, the selected events are returned instead
of an index array. To get the index array, use Cut.idx_array().
---
 nnfwtbn/cut.py                 | 28 +++++++++++++++++-----------
 nnfwtbn/plot.py                | 22 +++++++++++-----------
 nnfwtbn/tests/test_cut.py      |  2 +-
 nnfwtbn/tests/test_variable.py |  8 ++++----
 4 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py
index cc29e72..0a410f0 100644
--- a/nnfwtbn/cut.py
+++ b/nnfwtbn/cut.py
@@ -72,7 +72,7 @@ class Cut:
         Applies the internally stored cut to the given dataframe and returns a
         new dataframe containing only entries passing the event selection.
         """
-        return self.idx_array(dataframe)
+        return dataframe[self.idx_array(dataframe)]
 
     def idx_array(self, dataframe):
         """
@@ -89,36 +89,42 @@ class Cut:
         Returns a new cut implementing the logical AND of this cut and the
         other cut. The other cat be a Cut or any callable.
         """
-        if callable(other):
-            return Cut(lambda df: self(df) & other(df))
+        if isinstance(other, Cut):
+            return Cut(lambda df: self.idx_array(df) & other.idx_array(df))
+        elif callable(other):
+            return Cut(lambda df: self.idx_array(df) & other(df))
         else:
-            return Cut(lambda df: self(df) & other)
+            return Cut(lambda df: self.idx_array(df) & other)
 
     def __or__(self, other):
         """
         Returns a new cut implementing the logical OR of this cut and the
         other cut. The other cat be a Cut or any callable.
         """
-        if callable(other):
-            return Cut(lambda df: self(df) | other(df))
+        if isinstance(other, Cut):
+            return Cut(lambda df: self.idx_array(df) | other.idx_array(df))
+        elif callable(other):
+            return Cut(lambda df: self.idx_array(df) | other(df))
         else:
-            return Cut(lambda df: self(df) | other)
+            return Cut(lambda df: self.idx_array(df) | other)
 
     def __xor__(self, other):
         """
         Returns a new cut implementing the logical XOR of this cut and the
         other cut. The other can be a callable.
         """
-        if callable(other):
-            return Cut(lambda df: self(df) ^ other(df))
+        if isinstance(other, Cut):
+            return Cut(lambda df: self.idx_array(df) ^ other.idx_array(df))
+        elif callable(other):
+            return Cut(lambda df: self.idx_array(df) ^ other(df))
         else:
-            return Cut(lambda df: self(df) ^ other)
+            return Cut(lambda df: self.idx_array(df) ^ other)
 
     def __invert__(self):
         """
         Returns a new cut implementing the logical NOT of this cut.
         """
-        return Cut(lambda df: ~self(df))
+        return Cut(lambda df: ~self.idx_array(df))
 
     def __rand__(self, other):
         return self & other
diff --git a/nnfwtbn/plot.py b/nnfwtbn/plot.py
index 0491314..1105a41 100644
--- a/nnfwtbn/plot.py
+++ b/nnfwtbn/plot.py
@@ -160,9 +160,9 @@ def hist(dataframe, variable, bins, stacks, selection=None,
                 process_kwds = defaults
 
                 n, _ = np.histogram(
-                    variable(dataframe[sel(dataframe)]),
+                    variable(sel(dataframe)),
                     bins=bins, range=range,
-                    weights=weight(dataframe[sel(dataframe)]))
+                    weights=weight(sel(dataframe)))
 
                 bin_centers = (bins[1:] + bins[:-1]) / 2
                 bin_widths = bins[1:] - bins[:-1]
@@ -173,11 +173,11 @@ def hist(dataframe, variable, bins, stacks, selection=None,
 
             else:
                 n, _, _ = axes.hist(
-                    variable(dataframe[sel(dataframe)]),
+                    variable(sel(dataframe)),
                     bins=bins, range=range,
                     bottom=bottom,
                     label=process.label,
-                    weights=weight(dataframe[sel(dataframe)]),
+                    weights=weight(sel(dataframe)),
                     **process_kwds)
             bottom += n 
 
@@ -243,10 +243,10 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label,
     y_processes.reverse()
     data = np.zeros((len(y_processes), len(x_processes)))
     for i_x, x_process in enumerate(x_processes):
-        x_df = df[x_process.selection(df)]
+        x_df = x_process.selection(df)
         total_weight = weight(x_df).sum()
         for i_y, y_process in enumerate(y_processes):
-            x_y_df = x_df[y_process.selection(x_df)]
+            x_y_df = y_process.selection(x_df)
             data[i_y][i_x] = weight(x_y_df).sum() / total_weight
 
     data = pd.DataFrame(data,
@@ -302,12 +302,12 @@ def roc(df, signal_process, background_process, discriminant, steps=100,
     if max is None:
         max = discriminant(df).max()
 
-    df = df[selection(df)]
+    df = selection(df)
 
     signal_effs = []
     background_ieffs = []
-    n_total_sig = weight(df[signal_process.selection(df)]).sum()
-    n_total_bkg = weight(df[background_process.selection(df)]).sum()
+    n_total_sig = weight(signal_process.selection(df)).sum()
+    n_total_bkg = weight(background_process.selection(df)).sum()
     for cut_value in np.linspace(min, max, steps):
         residual_df = df[discriminant(df) >= cut_value]
 
@@ -315,8 +315,8 @@ def roc(df, signal_process, background_process, discriminant, steps=100,
         if n_total == 0:
             continue
 
-        signal_df = residual_df[signal_process.selection(residual_df)]
-        background_df = residual_df[background_process.selection(residual_df)]
+        signal_df = signal_process.selection(residual_df)
+        background_df = background_process.selection(residual_df)
 
         n_signal = weight(signal_df).sum()
         n_background = weight(background_df).sum()
diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py
index 68f1dbb..2c22f2f 100644
--- a/nnfwtbn/tests/test_cut.py
+++ b/nnfwtbn/tests/test_cut.py
@@ -202,7 +202,7 @@ class CutTestCase(unittest.TestCase):
         high_sale = Cut(lambda df: df.sale > 4)
 
         combined = (lambda df: df.year < 2015) ^ high_sale
-        selected = combined(self.df)
+        selected = combined.idx_array(self.df)
 
         self.assertEqual(list(selected),
                          [True, True, False, False, False, False, False, True])
diff --git a/nnfwtbn/tests/test_variable.py b/nnfwtbn/tests/test_variable.py
index 78ffe3c..c2edc75 100644
--- a/nnfwtbn/tests/test_variable.py
+++ b/nnfwtbn/tests/test_variable.py
@@ -114,7 +114,7 @@ class RangeBlindingTestCase(unittest.TestCase):
 
         df = self.generate_df()
         blinding = blinding_strategy(variable, bins=30, range=(50, 200))
-        blinded_df = df[blinding(df)]
+        blinded_df = blinding(df)
 
         # All events outside
         self.assertTrue((
@@ -140,7 +140,7 @@ class RangeBlindingTestCase(unittest.TestCase):
 
         df = self.generate_df()
         blinding = blinding_strategy(variable, bins=15, range=(50, 200))
-        blinded_df = df[blinding(df)]
+        blinded_df = blinding(df)
 
         # All events outside
         self.assertTrue((
@@ -162,7 +162,7 @@ class RangeBlindingTestCase(unittest.TestCase):
 
         df = self.generate_df()
         blinding = blinding_strategy(variable, bins=15, range=(50, 200))
-        blinded_df = df[blinding(df)]
+        blinded_df = blinding(df)
 
         # All events outside
         self.assertTrue((
@@ -184,7 +184,7 @@ class RangeBlindingTestCase(unittest.TestCase):
 
         df = self.generate_df()
         blinding = blinding_strategy(variable, bins=15, range=(50, 200))
-        blinded_df = df[blinding(df)]
+        blinded_df = blinding(df)
 
         # All events outside
         self.assertTrue((
-- 
GitLab