From 8cf2a8e86e4e0cc5ef5640eeead42931b6ab26e9 Mon Sep 17 00:00:00 2001 From: Frank Sauerburger <f.sauerburger@cern.ch> Date: Thu, 22 Oct 2020 10:23:40 +0200 Subject: [PATCH] Fix net naming scheme for unit tests --- nnfwtbn/model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/nnfwtbn/model.py b/nnfwtbn/model.py index 711d203..79a1d8f 100644 --- a/nnfwtbn/model.py +++ b/nnfwtbn/model.py @@ -668,7 +668,10 @@ class HepNet: if len(self.models) == self.cv.k: for fold_i in range(self.cv.k): path_token = path.rsplit(".", 1) - path_token.insert(-1, f"fold_{fold_i}") + if len(path_token) == 1: + path_token.append(f"fold_{fold_i}") + else: + path_token.insert(-1, f"fold_{fold_i}") # this is the built-in save function from keras self.models[fold_i].save(".".join(path_token)) @@ -728,7 +731,10 @@ class HepNet: with h5py.File(path, "r") as input_file: for fold_i in range(cv.k): path_token = path.rsplit(".", 1) - path_token.insert(-1, f"fold_{fold_i}") + if len(path_token) == 1: + path_token.append(f"fold_{fold_i}") + else: + path_token.insert(-1, f"fold_{fold_i}") model = keras.models.load_model(".".join(path_token)) instance.models.append(model) -- GitLab