From 49ce9eb2544c0671024966a78f082776a2f1d863 Mon Sep 17 00:00:00 2001 From: Ahmed Markhoos <ahmed.markhoos@cern.ch> Date: Thu, 22 Sep 2022 11:35:16 +0200 Subject: [PATCH] add _get_model_path() and its test function --- freeforestml/model.py | 41 ++++++++++---------------------- freeforestml/tests/test_model.py | 22 +++++++++++++++++ 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/freeforestml/model.py b/freeforestml/model.py index 4703f7c..0e056d9 100644 --- a/freeforestml/model.py +++ b/freeforestml/model.py @@ -633,6 +633,15 @@ class HepNet: return True + def _get_model_path(self, path, fold_i): + """ + Returns the path of the fold_i model + """ + path_token = list( os.path.splitext(path) ) + path_token.insert(1, f".fold_{fold_i}") + + return "".join(path_token) + def fit(self, df, weight=None, **kwds): """ Calls fit() on all folds. All kwds are passed to fit(). @@ -727,21 +736,9 @@ class HepNet: # save model architecture and weights (only if already trained) if len(self.models) == self.cv.k: for fold_i in range(self.cv.k): - path_token = path.rsplit("/", 1) - file_token = path_token[-1].rsplit(".", 1) - - if len(file_token) == 1: - file_token.append(f"fold_{fold_i}") - else: - file_token.insert(-1, f"fold_{fold_i}") - - if len(path_token) == 1: - path_token = [".".join(file_token)] - else: - path_token = [path_token[0]] + [".".join(file_token)] - + path_token = self._get_model_path(path, fold_i) # this is the built-in save function from keras - self.models[fold_i].save("/".join(path_token)) + self.models[fold_i].save(path_token) with h5py.File(path, "w") as output_file: # save default model class @@ -797,20 +794,8 @@ class HepNet: # load trained models (if existing) with h5py.File(path, "r") as input_file: for fold_i in range(cv.k): - path_token = path.rsplit("/", 1) - file_token = path_token[-1].rsplit(".", 1) - - if len(file_token) == 1: - file_token.append(f"fold_{fold_i}") - else: - file_token.insert(-1, f"fold_{fold_i}") - - if len(path_token) == 1: - path_token = [".".join(file_token)] - else: - path_token = [path_token[0]] + [".".join(file_token)] - - model = tensorflow.keras.models.load_model("/".join(path_token)) + path_token = instance._get_model_path(path, fold_i) + model = tensorflow.keras.models.load_model(path_token) instance.models.append(model) # load normalizer diff --git a/freeforestml/tests/test_model.py b/freeforestml/tests/test_model.py index d7c9b80..6e0eb56 100644 --- a/freeforestml/tests/test_model.py +++ b/freeforestml/tests/test_model.py @@ -879,6 +879,28 @@ class CategoricalWeightNormalizerTestCase(unittest.TestCase): class HepNetTestCase(unittest.TestCase): + def test_get_model_path(self): + """ + Test retrieving the fold_i model path + """ + net = HepNet(None,None,None,None,None) + + self.assertEqual("/system.directory/path/to/model.fold_0", + net._get_model_path("/system.directory/path/to/model",0)) + + self.assertEqual("/system.directory/path/to/model.fold_0.h5", + net._get_model_path("/system.directory/path/to/model.h5",0)) + + self.assertEqual("/system_directory/path/to/model.fold_0", + net._get_model_path("/system_directory/path/to/model",0)) + + self.assertEqual("/system_directory/path/to/model.fold_0.h5", + net._get_model_path("/system_directory/path/to/model.h5",0)) + + self.assertEqual("model.fold_0", net._get_model_path("model",0)) + self.assertEqual("model.fold_0.h5", net._get_model_path("model.h5",0)) + + def test_saving_and_loading(self): """ Test that saving and loading a neural network doesn't change its configuration. -- GitLab