diff --git a/nnfwtbn/model.py b/nnfwtbn/model.py index a725b6bdbd7197d95dd04462a1d91e02c4322d6c..711d203284688dc653c79a4780387e657e89290c 100644 --- a/nnfwtbn/model.py +++ b/nnfwtbn/model.py @@ -89,6 +89,33 @@ class CrossValidator(ABC): given fold. """ + def select_cv_set(self, df, cv, fold_i): + """ + Returns the index array to select all events from the cross validator + set specified with cv ('train', 'val', 'test') for the given fold. + """ + if cv not in ['train', 'val', 'test']: + raise ValueError("Argument 'cv' must be one of 'train', 'val', " + "'test'; but was %s." % repr(cv)) + if cv == "train": + selected = self.select_training(df, fold_i) + elif cv == "val": + selected = self.select_validation(df, fold_i) + else: + selected = self.select_test(df, fold_i) + return selected + + def retrieve_fold_info(self, df, cv): + """ + Returns an array of integers to specify which event was used + for train/val/test in which fold (-1 if not selected in any fold). + """ + fold_info = np.full(len(df), -1, dtype='int') + for fold_i in range(self.k): + selected = self.select_cv_set(df, cv, fold_i) + fold_info[selected] = fold_i + return fold_info + def save_to_h5(self, path, key, overwrite=False): """ Save cross validator definition to a hdf5 file. @@ -184,6 +211,7 @@ class ClassicalCV(CrossValidator): selected = selected | self.select_slice(df, slice_i) return selected + class BinaryCV(CrossValidator): """ @@ -607,7 +635,6 @@ class HepNet: "'test', 'all'; but was %s." 
% repr(cv)) out = np.zeros((len(df), len(self.output_list))) - if retrieve_fold_info: fold = np.zeros(len(df), dtype='int') - 1 test_set = np.zeros(len(df), dtype='bool') for fold_i in range(self.cv.k): @@ -615,25 +642,19 @@ class HepNet: norm = self.norms[fold_i] # identify fold - if cv == "train": - selected = self.cv.select_training(df, fold_i) - elif cv == "val": - selected = self.cv.select_validation(df, fold_i) - else: - selected = self.cv.select_test(df, fold_i) + selected = self.cv.select_cv_set(df, cv, fold_i) test_set |= selected out[selected] = model.predict(norm(df[selected][self.input_list]), **kwds) - if retrieve_fold_info: fold[selected] = fold_i - + test_df = df[test_set] out = out[test_set].transpose() out = dict(zip(["pred_" + s for s in self.output_list], out)) test_df = test_df.assign(**out) if retrieve_fold_info: - fold = {cv + "_fold" : fold} + fold = {cv + "_fold" : self.cv.retrieve_fold_info(df, cv)} test_df = test_df.assign(**fold) return test_df