diff --git a/freeforestml/model.py b/freeforestml/model.py
index 62c70e5de80489c47eb384ade40165a8fab29c60..0629fde03d8203377762415daf309d8f89920715 100644
--- a/freeforestml/model.py
+++ b/freeforestml/model.py
@@ -552,6 +552,66 @@ class EstimatorNormalizer(Normalizer):
     def offsets(self):
         return -self.center / self. width
 
+class IdentityNormalizer(Normalizer):
+    '''
+    Normalizer that has no effect on the dataframe: calling the normalizer
+    returns the input unchanged (center 0 and width 1 for every column).
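+
+    Example (illustrative; ``df`` is any pandas DataFrame)::
+
+        norm = IdentityNormalizer(df)
+        norm(df).equals(df)   # True: the dataframe is returned unchanged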
+    '''
+    def __init__(self, df, input_list=None, center=None, width=None):
+        if center is not None and width is not None:
+            self.center = center
+            self.width = width
+        else:
+            if input_list is not None:
+                # Flatten a nested input list and drop duplicates while
+                # preserving the original order.
+                if isinstance(input_list[0], list):
+                    input_list = sum(input_list, [])
+                    input_list = sorted(set(input_list), key=input_list.index)
+                df = df[input_list]
+
+            keys = df.keys()
+            self.center = pd.Series({name: 0.0 for name in keys})
+            self.width = pd.Series({name: 1.0 for name in keys})
+
+    def __call__(self, df):
+        return df
+    
+    def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return False
+
+        if not self.center.equals(other.center):
+            return False
+
+        if not self.width.equals(other.width):
+            return False
+
+        return True
+
+    def _save_to_h5(self, path, key):
+        self.center.to_hdf(path, key=os.path.join(key, "center"))
+        self.width.to_hdf(path, key=os.path.join(key, "width"))
+
+    @classmethod
+    def _load_from_h5(cls, path, key):
+        center = pd.read_hdf(path, os.path.join(key, "center"))
+        width = pd.read_hdf(path, os.path.join(key, "width"))
+        return cls(None, center=center, width=width)
+
+    @property
+    def scales(self):
+        return 1.0 / self.width
+
+    @property
+    def offsets(self):
+        return -self.center / self.width
+
 def normalize_category_weights(df, categories, weight='weight'):
     """
     The categorical weight normalizer acts on the weight variable only. The
@@ -600,7 +660,7 @@ class HepNet:
         """
         self.model_cls = keras_model
         self.cv = cross_validator
-        self.norm_cls = normalizer
+        self.norm_cls = IdentityNormalizer if normalizer is None else normalizer
         self.input_list = input_list
         self.output_list = output_list
         self.norms = []
@@ -710,8 +770,6 @@
             history['epoch'] = np.arange(len(history['loss']))
             self.history = pd.concat([self.history, pd.DataFrame(history)])
 
-            tf.keras.backend.clear_session()
-
     def predict(self, df, cv='val', retrieve_fold_info = False, **kwds):
         """
         Calls predict() on the Keras model. The argument cv specifies the
@@ -822,7 +880,10 @@
         self.history.to_hdf(path, "history")
 
     @classmethod
-    def load(cls, path):
+    def load(cls, path, **kwds):
         """
         Restore a model from a hdf5 file.
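+
+        Any additional keyword arguments are passed on to
+        tf.keras.models.load_model() when restoring the per-fold models.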
         """
@@ -868,7 +929,7 @@
         with h5py.File(path, "r") as input_file:
             for fold_i in range(cv.k):
                 path_token = instance._get_model_path(path, fold_i)
-                model = tf.keras.models.load_model(path_token)
+                model = tf.keras.models.load_model(path_token, **kwds)
                 instance.models.append(model)
 
         # load normalizer
diff --git a/freeforestml/tests/test_model.py b/freeforestml/tests/test_model.py
index a4589d2d9fbacaa085da52aaf7498f5b18ee8747..b01fdba3bb2e7c7405b7c06ab657c374f043e326 100644
--- a/freeforestml/tests/test_model.py
+++ b/freeforestml/tests/test_model.py
@@ -11,7 +11,7 @@ from tensorflow.keras.layers import Dense, Dropout, Input, Concatenate
 from tensorflow.keras.optimizers import SGD
 
 from freeforestml.model import CrossValidator, ClassicalCV, MixedCV, \
-                          Normalizer, EstimatorNormalizer, \
+                          Normalizer, EstimatorNormalizer, IdentityNormalizer, \
                           normalize_category_weights, BinaryCV, HepNet, \
                           NoTestCV
 from freeforestml.variable import Variable
@@ -820,6 +820,109 @@ class EstimatorNormalizerTestCase(unittest.TestCase):
         self.assertTrue(norm1 == norm2)
 
 
+class IdentityNormalizerTestCase(unittest.TestCase):
+    """
+    Test the implementation of IdentityNormalizer.
+    """
+
+    def generate_df(self):
+        """
+        Generate toy dataframe.
+        """
+        return pd.DataFrame({
+            "x": [9, 10, 10, 12, 12, 13],
+            "y": [0, 1, 1, 1, 1, 2],
+            "z": [0, 0, 0, 0, 0, 0],  # Column with zero width
+        })
+
+    def generate_test_df(self):
+        """
+        Generate toy dataframe used to test the normalization.
+        """
+        return pd.DataFrame({
+            "x": [6, 11, 16],
+            "y": [-1, 1, 3],
+            "z": [-1, 0, 1],
+        })
+
+    def test_init(self):
+        """
+        Check that the constructor stores a center of 0 and a width of 1
+        for every column.
+        """
+        df = self.generate_df()
+        norm = IdentityNormalizer(df)
+
+        self.assertEqual(len(norm.center), 3)
+        self.assertAlmostEqual(norm.center["x"], 0)
+        self.assertAlmostEqual(norm.center["y"], 0)
+        self.assertAlmostEqual(norm.center["z"], 0)
+
+        self.assertEqual(len(norm.width), 3)
+        self.assertAlmostEqual(norm.width["x"], 1)
+        self.assertAlmostEqual(norm.width["y"], 1)
+        self.assertAlmostEqual(norm.width["z"], 1)
+
+    def test_init_input_list(self):
+        """
+        Check that the constructor assigns (center,width)=(0,1) only to the
+        columns listed in the input_list.
+        """
+        df = self.generate_df()
+        norm = IdentityNormalizer(df, input_list=["x", "z"])
+
+        self.assertEqual(list(norm.center.index), ["x", "z"])
+        self.assertEqual(list(norm.width.index), ["x", "z"])
+
+    def test_call(self):
+        """
+        Check that the normalizer is in fact an identity map.
+        """
+        df   = self.generate_df()
+        norm = IdentityNormalizer(df)
+
+        df_test = self.generate_test_df()
+        normed  = norm(self.generate_test_df())
+
+        self.assertEqual(list(normed.index), list(df_test.index))
+        self.assertEqual(list(normed.keys()), list(df_test.keys()))
+        self.assertEqual(normed.values.tolist(), df_test.values.tolist())
+
+    def test_call_other_vars(self):
+        """
+        Check that columns without stored moments are passed through
+        unchanged.
+        """
+        df = self.generate_df()
+        norm = IdentityNormalizer(df, input_list=["x", "z"])
+
+        normed = norm(self.generate_test_df())
+        self.assertEqual(list(normed.x), [6, 11, 16])
+        self.assertEqual(list(normed.y), [-1, 1, 3])
+        self.assertEqual(list(normed.z), [-1, 0, 1])
+
+    def test_equal_same_values(self):
+        df = self.generate_df()
+        norm1 = IdentityNormalizer(df)
+        norm2 = IdentityNormalizer(df)
+        self.assertTrue(norm1 == norm2)
+
+    def test_saving_and_loading(self):
+        """
+        Test that saving and loading the normalizer doesn't change its configuration.
+        """
+        df = self.generate_df()
+        norm1 = IdentityNormalizer(df)
+        fd, path = tempfile.mkstemp()
+        try:
+            norm1.save_to_h5(path, "norm")
+            norm2 = Normalizer.load_from_h5(path, "norm")
+        finally:
+            # close file descriptor and delete file
+            os.close(fd)
+            os.remove(path)
+        self.assertTrue(norm1 == norm2)
+
 
 class CategoricalWeightNormalizerTestCase(unittest.TestCase):
     """