Skip to content
Snippets Groups Projects
Commit 590e33bb authored by Ahmed Markhoos's avatar Ahmed Markhoos
Browse files

path-token-fix

parent 0bd99ca6
Branches 44-data-content-interface
No related tags found
1 merge request!70path-token-fix
Pipeline #12724 failed
......@@ -727,14 +727,21 @@ class HepNet:
# save model architecture and weights (only if already trained)
if len(self.models) == self.cv.k:
for fold_i in range(self.cv.k):
path_token = path.rsplit(".", 1)
path_token = path.rsplit("/", 1)
file_token = path_token[-1].rsplit(".", 1)
if len(file_token) == 1:
file_token.append(f"fold_{fold_i}")
else:
file_token.insert(-1, f"fold_{fold_i}")
if len(path_token) == 1:
path_token.append(f"fold_{fold_i}")
path_token = [".".join(file_token)]
else:
path_token.insert(-1, f"fold_{fold_i}")
path_token = [path_token[0]] + [".".join(file_token)]
# this is the built-in save function from keras
self.models[fold_i].save(".".join(path_token))
self.models[fold_i].save("/".join(path_token))
with h5py.File(path, "w") as output_file:
# save default model class
......@@ -790,13 +797,20 @@ class HepNet:
# load trained models (if existing)
with h5py.File(path, "r") as input_file:
for fold_i in range(cv.k):
path_token = path.rsplit(".", 1)
path_token = path.rsplit("/", 1)
file_token = path_token[-1].rsplit(".", 1)
if len(file_token) == 1:
file_token.append(f"fold_{fold_i}")
else:
file_token.insert(-1, f"fold_{fold_i}")
if len(path_token) == 1:
path_token.append(f"fold_{fold_i}")
path_token = [".".join(file_token)]
else:
path_token.insert(-1, f"fold_{fold_i}")
path_token = [path_token[0]] + [".".join(file_token)]
model = tensorflow.keras.models.load_model(".".join(path_token))
model = tensorflow.keras.models.load_model("/".join(path_token))
instance.models.append(model)
# load normalizer
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment