diff --git a/nnfwtbn/model.py b/nnfwtbn/model.py index c4f7309cee36e47b92231b4960d2749bce188ebf..9057edcaf0f816686c28a15c9e44f2c76ad3a376 100644 --- a/nnfwtbn/model.py +++ b/nnfwtbn/model.py @@ -636,18 +636,21 @@ class HepNet: Save the model and all associated components to a hdf5 file. """ + # save model architecture and weights (only if already trained) + if len(self.models) == self.cv.k: + for fold_i in range(self.cv.k): + path_token = path.rsplit(".", 1) + path_token.insert(-1, f"fold_{fold_i}") + + # this is the built-in save function from keras + self.models[fold_i].save(".".join(path_token)) + with h5py.File(path, "w") as output_file: # save default model class # since this is a arbitrary piece of python code we need to use the python_to_str function group = output_file.create_group("models/default") group.attrs["model_cls"] = np.string_(python_to_str(self.model_cls)) - # save model architecture and weights (only if already trained) - if len(self.models) == self.cv.k: - for fold_i in range(self.cv.k): - group = output_file.create_group("models/fold_{}".format(fold_i)) - # this is the built-in save function from keras - self.models[fold_i].save(group) # save class name of default normalizer as string group = output_file.create_group("normalizers/default") @@ -696,10 +699,11 @@ class HepNet: # load trained models (if existing) with h5py.File(path, "r") as input_file: for fold_i in range(cv.k): - key = "models/fold_{}".format(fold_i) - if key in input_file: - model = keras.models.load_model(input_file[key]) - instance.models.append(model) + path_token = path.rsplit(".", 1) + path_token.insert(-1, f"fold_{fold_i}") + + model = keras.models.load_model(".".join(path_token)) + instance.models.append(model) # load normalizer for fold_i in range(cv.k):