From a8f5fca14c99618940cef69d5d9f74d8387bc173 Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <f.sauerburger@cern.ch> Date: Wed, 3 Jul 2019 15:36:26 +0200 Subject: [PATCH] Make Cut.__call__ return the selected dataframe This change is not compatible with previous versions and breaks existing code. When a cut is called with a dataframe, the selected events are returned instead of an index array. To get the index array, use Cut.idx_array(). --- nnfwtbn/cut.py | 28 +++++++++++++++++----------- nnfwtbn/plot.py | 22 +++++++++++----------- nnfwtbn/tests/test_cut.py | 2 +- nnfwtbn/tests/test_variable.py | 8 ++++---- 4 files changed, 33 insertions(+), 27 deletions(-) diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py index cc29e72..0a410f0 100644 --- a/nnfwtbn/cut.py +++ b/nnfwtbn/cut.py @@ -72,7 +72,7 @@ class Cut: Applies the internally stored cut to the given dataframe and returns a new dataframe containing only entries passing the event selection. """ - return self.idx_array(dataframe) + return dataframe[self.idx_array(dataframe)] def idx_array(self, dataframe): """ @@ -89,36 +89,42 @@ class Cut: Returns a new cut implementing the logical AND of this cut and the other cut. The other cat be a Cut or any callable. """ - if callable(other): - return Cut(lambda df: self(df) & other(df)) + if isinstance(other, Cut): + return Cut(lambda df: self.idx_array(df) & other.idx_array(df)) + elif callable(other): + return Cut(lambda df: self.idx_array(df) & other(df)) else: - return Cut(lambda df: self(df) & other) + return Cut(lambda df: self.idx_array(df) & other) def __or__(self, other): """ Returns a new cut implementing the logical OR of this cut and the other cut. The other cat be a Cut or any callable. 
""" - if callable(other): - return Cut(lambda df: self(df) | other(df)) + if isinstance(other, Cut): + return Cut(lambda df: self.idx_array(df) | other.idx_array(df)) + elif callable(other): + return Cut(lambda df: self.idx_array(df) | other(df)) else: - return Cut(lambda df: self(df) | other) + return Cut(lambda df: self.idx_array(df) | other) def __xor__(self, other): """ Returns a new cut implementing the logical XOR of this cut and the other cut. The other can be a callable. """ - if callable(other): - return Cut(lambda df: self(df) ^ other(df)) + if isinstance(other, Cut): + return Cut(lambda df: self.idx_array(df) ^ other.idx_array(df)) + elif callable(other): + return Cut(lambda df: self.idx_array(df) ^ other(df)) else: - return Cut(lambda df: self(df) ^ other) + return Cut(lambda df: self.idx_array(df) ^ other) def __invert__(self): """ Returns a new cut implementing the logical NOT of this cut. """ - return Cut(lambda df: ~self(df)) + return Cut(lambda df: ~self.idx_array(df)) def __rand__(self, other): return self & other diff --git a/nnfwtbn/plot.py b/nnfwtbn/plot.py index 0491314..1105a41 100644 --- a/nnfwtbn/plot.py +++ b/nnfwtbn/plot.py @@ -160,9 +160,9 @@ def hist(dataframe, variable, bins, stacks, selection=None, process_kwds = defaults n, _ = np.histogram( - variable(dataframe[sel(dataframe)]), + variable(sel(dataframe)), bins=bins, range=range, - weights=weight(dataframe[sel(dataframe)])) + weights=weight(sel(dataframe))) bin_centers = (bins[1:] + bins[:-1]) / 2 bin_widths = bins[1:] - bins[:-1] @@ -173,11 +173,11 @@ def hist(dataframe, variable, bins, stacks, selection=None, else: n, _, _ = axes.hist( - variable(dataframe[sel(dataframe)]), + variable(sel(dataframe)), bins=bins, range=range, bottom=bottom, label=process.label, - weights=weight(dataframe[sel(dataframe)]), + weights=weight(sel(dataframe)), **process_kwds) bottom += n @@ -243,10 +243,10 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label, y_processes.reverse() data 
= np.zeros((len(y_processes), len(x_processes))) for i_x, x_process in enumerate(x_processes): - x_df = df[x_process.selection(df)] + x_df = x_process.selection(df) total_weight = weight(x_df).sum() for i_y, y_process in enumerate(y_processes): - x_y_df = x_df[y_process.selection(x_df)] + x_y_df = y_process.selection(x_df) data[i_y][i_x] = weight(x_y_df).sum() / total_weight data = pd.DataFrame(data, @@ -302,12 +302,12 @@ def roc(df, signal_process, background_process, discriminant, steps=100, if max is None: max = discriminant(df).max() - df = df[selection(df)] + df = selection(df) signal_effs = [] background_ieffs = [] - n_total_sig = weight(df[signal_process.selection(df)]).sum() - n_total_bkg = weight(df[background_process.selection(df)]).sum() + n_total_sig = weight(signal_process.selection(df)).sum() + n_total_bkg = weight(background_process.selection(df)).sum() for cut_value in np.linspace(min, max, steps): residual_df = df[discriminant(df) >= cut_value] @@ -315,8 +315,8 @@ def roc(df, signal_process, background_process, discriminant, steps=100, if n_total == 0: continue - signal_df = residual_df[signal_process.selection(residual_df)] - background_df = residual_df[background_process.selection(residual_df)] + signal_df = signal_process.selection(residual_df) + background_df = background_process.selection(residual_df) n_signal = weight(signal_df).sum() n_background = weight(background_df).sum() diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py index 68f1dbb..2c22f2f 100644 --- a/nnfwtbn/tests/test_cut.py +++ b/nnfwtbn/tests/test_cut.py @@ -202,7 +202,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = (lambda df: df.year < 2015) ^ high_sale - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, False, False, False, False, False, True]) diff --git a/nnfwtbn/tests/test_variable.py b/nnfwtbn/tests/test_variable.py index 78ffe3c..c2edc75 100644 
--- a/nnfwtbn/tests/test_variable.py +++ b/nnfwtbn/tests/test_variable.py @@ -114,7 +114,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=30, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( @@ -140,7 +140,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=15, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( @@ -162,7 +162,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=15, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( @@ -184,7 +184,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=15, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( -- GitLab