diff --git a/nnfwtbn/model.py b/nnfwtbn/model.py index a7fdb909012ef80e0ff8e80c1d90c63e7b408ede..e9ede9a1172a9bcd63db62b15ae54aa6c8d1f537 100644 --- a/nnfwtbn/model.py +++ b/nnfwtbn/model.py @@ -69,7 +69,7 @@ class CrossValidator(ABC): """ @abstractmethod - def select_training(self, df, fold_i): + def select_training(self, df, fold_i, for_predicting = False): """ Returns the index array to select all training events from the dataset for the given fold.