diff --git a/nnfwtbn/model.py b/nnfwtbn/model.py
index a725b6bdbd7197d95dd04462a1d91e02c4322d6c..711d203284688dc653c79a4780387e657e89290c 100644
--- a/nnfwtbn/model.py
+++ b/nnfwtbn/model.py
@@ -89,6 +89,33 @@ class CrossValidator(ABC):
         given fold.
         """
 
+    def select_cv_set(self, df, cv, fold_i):
+        """
+        Returns the index array to select all events from the cross validator
+        set specified with cv ('train', 'val', 'test') for the given fold.
+        """
+        if cv not in ['train', 'val', 'test']:
+            raise ValueError("Argument 'cv' must be one of 'train', 'val', "
+                             "'test', 'all'; but was %s." % repr(cv))
+        if cv == "train":
+            selected = self.select_training(df, fold_i)
+        elif cv == "val":
+            selected = self.select_validation(df, fold_i)
+        else:
+            selected = self.select_test(df, fold_i)
+        return selected
+    
+    def retrieve_fold_info(self, df, cv):
+        """
+        Returns and array of integers to specify which event was used 
+        for train/val/test in which fold
+        """
+        fold_info = np.zeros(len(df), dtype='bool') - 1
+        for fold_i in range(self.k):        
+            selected = self.select_cv_set(df, cv, fold_i)
+            fold_info[selected] = fold_i
+        return fold_info
+
     def save_to_h5(self, path, key, overwrite=False):
         """
         Save cross validator definition to a hdf5 file.
@@ -184,6 +211,7 @@ class ClassicalCV(CrossValidator):
             selected = selected | self.select_slice(df, slice_i)
 
         return selected
+        
 
 class BinaryCV(CrossValidator):
     """
@@ -607,7 +635,6 @@ class HepNet:
                              "'test', 'all'; but was %s." % repr(cv))
 
         out = np.zeros((len(df), len(self.output_list)))
-        if retrieve_fold_info: fold = np.zeros(len(df), dtype='int') - 1
         test_set = np.zeros(len(df), dtype='bool')
 
         for fold_i in range(self.cv.k):
@@ -615,25 +642,19 @@ class HepNet:
             norm = self.norms[fold_i]
 
             # identify fold
-            if cv == "train":
-                selected = self.cv.select_training(df, fold_i)
-            elif cv == "val":
-                selected = self.cv.select_validation(df, fold_i)
-            else:
-                selected = self.cv.select_test(df, fold_i)
+            selected = self.cv.select_cv_set(df, cv, fold_i)
 
             test_set |= selected
             out[selected] = model.predict(norm(df[selected][self.input_list]),
                                           **kwds)
-            if retrieve_fold_info: fold[selected] = fold_i
-        
+            
         test_df = df[test_set]
         out = out[test_set].transpose()
         out = dict(zip(["pred_" + s for s in self.output_list], out))
         test_df = test_df.assign(**out)
         
         if retrieve_fold_info:
-            fold = {cv + "_fold" :  fold}
+            fold = {cv + "_fold" :  self.cv.retrieve_fold_info(df, cv)}
             test_df = test_df.assign(**fold)
             
         return test_df