diff --git a/nnfwtbn/cut.py b/nnfwtbn/cut.py index cc29e7264e241c8b7558dc5c796bca914430b1b4..0a410f0e103a50fa3ba7e78e14a9bceb68c65a17 100644 --- a/nnfwtbn/cut.py +++ b/nnfwtbn/cut.py @@ -72,7 +72,7 @@ class Cut: Applies the internally stored cut to the given dataframe and returns a new dataframe containing only entries passing the event selection. """ - return self.idx_array(dataframe) + return dataframe[self.idx_array(dataframe)] def idx_array(self, dataframe): """ @@ -89,36 +89,42 @@ class Cut: Returns a new cut implementing the logical AND of this cut and the other cut. The other cat be a Cut or any callable. """ - if callable(other): - return Cut(lambda df: self(df) & other(df)) + if isinstance(other, Cut): + return Cut(lambda df: self.idx_array(df) & other.idx_array(df)) + elif callable(other): + return Cut(lambda df: self.idx_array(df) & other(df)) else: - return Cut(lambda df: self(df) & other) + return Cut(lambda df: self.idx_array(df) & other) def __or__(self, other): """ Returns a new cut implementing the logical OR of this cut and the other cut. The other cat be a Cut or any callable. """ - if callable(other): - return Cut(lambda df: self(df) | other(df)) + if isinstance(other, Cut): + return Cut(lambda df: self.idx_array(df) | other.idx_array(df)) + elif callable(other): + return Cut(lambda df: self.idx_array(df) | other(df)) else: - return Cut(lambda df: self(df) | other) + return Cut(lambda df: self.idx_array(df) | other) def __xor__(self, other): """ Returns a new cut implementing the logical XOR of this cut and the other cut. The other can be a callable. """ - if callable(other): - return Cut(lambda df: self(df) ^ other(df)) + if isinstance(other, Cut): + return Cut(lambda df: self.idx_array(df) ^ other.idx_array(df)) + elif callable(other): + return Cut(lambda df: self.idx_array(df) ^ other(df)) else: - return Cut(lambda df: self(df) ^ other) + return Cut(lambda df: self.idx_array(df) ^ other) def __invert__(self): """ Returns a new cut implementing the logical NOT of this cut. """ - return Cut(lambda df: ~self(df)) + return Cut(lambda df: ~self.idx_array(df)) def __rand__(self, other): return self & other diff --git a/nnfwtbn/plot.py b/nnfwtbn/plot.py index 04913140cce0884082084a853750410c0803f534..1105a41e5c4ab7b6c9429969c8eaab95bd007499 100644 --- a/nnfwtbn/plot.py +++ b/nnfwtbn/plot.py @@ -160,9 +160,9 @@ def hist(dataframe, variable, bins, stacks, selection=None, process_kwds = defaults n, _ = np.histogram( - variable(dataframe[sel(dataframe)]), + variable(sel(dataframe)), bins=bins, range=range, - weights=weight(dataframe[sel(dataframe)])) + weights=weight(sel(dataframe))) bin_centers = (bins[1:] + bins[:-1]) / 2 bin_widths = bins[1:] - bins[:-1] @@ -173,11 +173,11 @@ def hist(dataframe, variable, bins, stacks, selection=None, else: n, _, _ = axes.hist( - variable(dataframe[sel(dataframe)]), + variable(sel(dataframe)), bins=bins, range=range, bottom=bottom, label=process.label, - weights=weight(dataframe[sel(dataframe)]), + weights=weight(sel(dataframe)), **process_kwds) bottom += n @@ -243,10 +243,10 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label, y_processes.reverse() data = np.zeros((len(y_processes), len(x_processes))) for i_x, x_process in enumerate(x_processes): - x_df = df[x_process.selection(df)] + x_df = x_process.selection(df) total_weight = weight(x_df).sum() for i_y, y_process in enumerate(y_processes): - x_y_df = x_df[y_process.selection(x_df)] + x_y_df = y_process.selection(x_df) data[i_y][i_x] = weight(x_y_df).sum() / total_weight data = pd.DataFrame(data, @@ -302,12 +302,12 @@ def roc(df, signal_process, background_process, discriminant, steps=100, if max is None: max = discriminant(df).max() - df = df[selection(df)] + df = selection(df) signal_effs = [] background_ieffs = [] - n_total_sig = weight(df[signal_process.selection(df)]).sum() - n_total_bkg = weight(df[background_process.selection(df)]).sum() + n_total_sig = weight(signal_process.selection(df)).sum() + n_total_bkg = weight(background_process.selection(df)).sum() for cut_value in np.linspace(min, max, steps): residual_df = df[discriminant(df) >= cut_value] @@ -315,8 +315,8 @@ def roc(df, signal_process, background_process, discriminant, steps=100, if n_total == 0: continue - signal_df = residual_df[signal_process.selection(residual_df)] - background_df = residual_df[background_process.selection(residual_df)] + signal_df = signal_process.selection(residual_df) + background_df = background_process.selection(residual_df) n_signal = weight(signal_df).sum() n_background = weight(background_df).sum() diff --git a/nnfwtbn/tests/test_cut.py b/nnfwtbn/tests/test_cut.py index 68f1dbbaf87a60b967d89020403bdeb5468b93cf..2c22f2fcdf4df5d898543d48e830369e3bf5fc24 100644 --- a/nnfwtbn/tests/test_cut.py +++ b/nnfwtbn/tests/test_cut.py @@ -202,7 +202,7 @@ class CutTestCase(unittest.TestCase): high_sale = Cut(lambda df: df.sale > 4) combined = (lambda df: df.year < 2015) ^ high_sale - selected = combined(self.df) + selected = combined.idx_array(self.df) self.assertEqual(list(selected), [True, True, False, False, False, False, False, True]) diff --git a/nnfwtbn/tests/test_variable.py b/nnfwtbn/tests/test_variable.py index 78ffe3ccbc398baf4c70b2b3ee91b55186362056..c2edc75bd6173052ea44eabbae78574140ffa475 100644 --- a/nnfwtbn/tests/test_variable.py +++ b/nnfwtbn/tests/test_variable.py @@ -114,7 +114,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=30, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( @@ -140,7 +140,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=15, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( @@ -162,7 +162,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=15, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue(( @@ -184,7 +184,7 @@ class RangeBlindingTestCase(unittest.TestCase): df = self.generate_df() blinding = blinding_strategy(variable, bins=15, range=(50, 200)) - blinded_df = df[blinding(df)] + blinded_df = blinding(df) # All events outside self.assertTrue((