From 8cf2a8e86e4e0cc5ef5640eeead42931b6ab26e9 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Thu, 22 Oct 2020 10:23:40 +0200
Subject: [PATCH] Fix net naming scheme for unit tests

---
 nnfwtbn/model.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/nnfwtbn/model.py b/nnfwtbn/model.py
index 711d203..79a1d8f 100644
--- a/nnfwtbn/model.py
+++ b/nnfwtbn/model.py
@@ -668,7 +668,10 @@ class HepNet:
         if len(self.models) == self.cv.k:
             for fold_i in range(self.cv.k):
                 path_token = path.rsplit(".", 1)
-                path_token.insert(-1, f"fold_{fold_i}")
+                if len(path_token) == 1:
+                    path_token.append(f"fold_{fold_i}")
+                else:
+                    path_token.insert(-1, f"fold_{fold_i}")
 
                 # this is the built-in save function from keras
                 self.models[fold_i].save(".".join(path_token))
@@ -728,7 +731,10 @@ class HepNet:
         with h5py.File(path, "r") as input_file:
             for fold_i in range(cv.k):
                 path_token = path.rsplit(".", 1)
-                path_token.insert(-1, f"fold_{fold_i}")
+                if len(path_token) == 1:
+                    path_token.append(f"fold_{fold_i}")
+                else:
+                    path_token.insert(-1, f"fold_{fold_i}")
 
                 model = keras.models.load_model(".".join(path_token))
                 instance.models.append(model)
-- 
GitLab