Skip to content
Snippets Groups Projects

Add option to retrieve fold info

Merged Frank Sauerburger requested to merge add-option-to-retrieve-fold-info into master
1 file
+ 31
10
Compare changes
  • Side-by-side
  • Inline
+ 31
10
@@ -89,6 +89,33 @@ class CrossValidator(ABC):
given fold.
"""
def select_cv_set(self, df, cv, fold_i):
"""
Returns the index array to select all events from the cross validator
set specified with cv ('train', 'val', 'test') for the given fold.
"""
if cv not in ['train', 'val', 'test']:
raise ValueError("Argument 'cv' must be one of 'train', 'val', "
"'test', 'all'; but was %s." % repr(cv))
if cv == "train":
selected = self.select_training(df, fold_i)
elif cv == "val":
selected = self.select_validation(df, fold_i)
else:
selected = self.select_test(df, fold_i)
return selected
def retrieve_fold_info(self, df, cv):
"""
Returns and array of integers to specify which event was used
for train/val/test in which fold
"""
fold_info = np.zeros(len(df), dtype='bool') - 1
for fold_i in range(self.k):
selected = self.select_cv_set(df, cv, fold_i)
fold_info[selected] = fold_i
return fold_info
def save_to_h5(self, path, key, overwrite=False):
"""
Save cross validator definition to a hdf5 file.
@@ -184,6 +211,7 @@ class ClassicalCV(CrossValidator):
selected = selected | self.select_slice(df, slice_i)
return selected
class BinaryCV(CrossValidator):
"""
@@ -607,7 +635,6 @@ class HepNet:
"'test', 'all'; but was %s." % repr(cv))
out = np.zeros((len(df), len(self.output_list)))
if retrieve_fold_info: fold = np.zeros(len(df), dtype='int') - 1
test_set = np.zeros(len(df), dtype='bool')
for fold_i in range(self.cv.k):
@@ -615,25 +642,19 @@ class HepNet:
norm = self.norms[fold_i]
# identify fold
if cv == "train":
selected = self.cv.select_training(df, fold_i)
elif cv == "val":
selected = self.cv.select_validation(df, fold_i)
else:
selected = self.cv.select_test(df, fold_i)
selected = self.cv.select_cv_set(df, cv, fold_i)
test_set |= selected
out[selected] = model.predict(norm(df[selected][self.input_list]),
**kwds)
if retrieve_fold_info: fold[selected] = fold_i
test_df = df[test_set]
out = out[test_set].transpose()
out = dict(zip(["pred_" + s for s in self.output_list], out))
test_df = test_df.assign(**out)
if retrieve_fold_info:
fold = {cv + "_fold" : fold}
fold = {cv + "_fold" : self.cv.retrieve_fold_info(df, cv)}
test_df = test_df.assign(**fold)
return test_df
Loading