Skip to content
Snippets Groups Projects
Unverified Commit 4b9f72b7 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Draft HepNet class

parent 3c949c92
Branches 59-add-systematic-band-stack
No related tags found
1 merge request!5Resolve "Implement Meta Model"
......@@ -32,8 +32,6 @@ class CrossValidator(ABC):
# Handle variable
if isinstance(self.variable, str):
self.variable = Variable(self.variable, self.variable)
@abstractmethod
def select_slice(self, df, slice_id):
......@@ -202,33 +200,28 @@ class HepNet:
Creates a new HEP model. The keras model parameter is a callable that
returns a new instance of the model (The HEP net needs to able to create
multiple models, one for each cross validation fold.)
"""
def set(self, key, value):
"""
Stores properties such as normalization moments.
"""
The cross_validator must be a CrossValidator object.
def get(self, key, value):
"""
Returns properties such as normalization moments.
The normalizor must be a callable that returns a normalizor. Each
cross_validation fold uses a separate normalizor with independent
normalization weights.
"""
self.model = keras_model
self.cv = cross_validator
self.norm = normalizor
def store(self, file):
"""
Write properties and training weights to the given file.
"""
def restore(self, file):
"""
Loads properties and training weights from the given file.
"""
def fit(self, df, selection):
"""
Calls fit() on the Keras model
Calls fit() on all folds.
"""
### Loop over folds:
# seed normalizors
# fit folds
def predict(self, df):
"""
Calls predict() on the Keras model
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment