Skip to content
Snippets Groups Projects
Unverified Commit a72a4005 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Add lwtnn export() method to HepModel

parent 6d2779e7
No related branches found
No related tags found
1 merge request!39Resolve "LWTNN model export"
Pipeline #12554 passed
%% Cell type:code id: tags:
``` python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import SGD
from nnfwtbn import Variable, Process, Cut, \
HepNet, ClassicalCV, EstimatorNormalizer, \
HistogramFactory, confusion_matrix, atlasify, \
McStack
from nnfwtbn import toydata
```
%% Cell type:code id: tags:
``` python
df = toydata.get()
```
%% Cell type:code id: tags:
``` python
p_ztt = Process(r"$Z\rightarrow\tau\tau$", range=(0, 0))
p_sig = Process(r"Signal", range=(1, 1))
s_all = McStack(p_ztt, p_sig)
```
%% Cell type:code id: tags:
``` python
hist_factory = HistogramFactory(df, stacks=[s_all], weight="weight")
```
%% Cell type:markdown id: tags:
# Cut-based
%% Cell type:code id: tags:
``` python
hist_factory(Variable("$\Delta \eta^{jj}$",
lambda d: (d.jet_1_eta - d.jet_2_eta).abs()),
bins=20, range=(0, 8))
hist_factory(Variable("$m^{jj}$", "m_jj"),
bins=20, range=(0, 1500))
None
```
%% Cell type:code id: tags:
``` python
c_sr = Cut(lambda d: d.m_jj > 400) & \
Cut(lambda d: d.jet_2_pt >= 30) & \
Cut(lambda d: d.jet_1_eta * d.jet_2_eta < 0) & \
Cut(lambda d: (d.jet_2_eta - d.jet_1_eta).abs() > 3)
c_sr.label = "Signal"
c_rest = (~c_sr)
c_rest.label = "Rest"
```
%% Cell type:code id: tags:
``` python
confusion_matrix(df, [p_sig, p_ztt], [c_sr, c_rest], info=False,
x_label="Signal", y_label="Region", annot=True, weight="weight")
confusion_matrix(df, [p_sig, p_ztt], [c_sr, c_rest], normalize_rows=True, info=False,
x_label="Signal", y_label="Region", annot=True, weight="weight")
None
```
%% Cell type:markdown id: tags:
# Neural Network
%% Cell type:code id: tags:
``` python
df['dijet_deta'] = (df.jet_1_eta - df.jet_2_eta).abs()
df['dijet_prod_eta'] = (df.jet_1_eta * df.jet_2_eta)
input_var = ['dijet_prod_eta', 'm_jj', 'dijet_deta', 'higgs_pt', 'jet_2_pt', 'jet_1_eta', 'jet_2_eta', 'tau_eta']
output_var = ['is_sig', 'is_ztt']
```
%% Cell type:code id: tags:
``` python
df["is_sig"] = p_sig.selection.idx_array(df)
df["is_ztt"] = p_ztt.selection.idx_array(df)
```
%% Cell type:code id: tags:
``` python
sns.pairplot(df.sample(n=1000), vars=input_var, hue="is_sig")
None
```
%% Cell type:code id: tags:
``` python
def model():
m = Sequential()
m.add(Dense(units=15, activation='relu', input_dim=len(input_var)))
m.add(Dense(units=5, activation='relu'))
m.add(Dense(units=2, activation='softmax'))
m.compile(loss='categorical_crossentropy',
optimizer=SGD(lr=0.1),
metrics=['categorical_accuracy'])
return m
df['random'] = toydata.rng.random(size=len(df))
cv = ClassicalCV(5, frac_var='random')
net = HepNet(model, cv, EstimatorNormalizer, input_var, output_var)
```
%% Cell type:code id: tags:
``` python
sig_wf = len(p_sig.selection(df).weight) / p_sig.selection(df).weight.sum()
ztt_wf = len(p_ztt.selection(df).weight) / p_ztt.selection(df).weight.sum()
```
%% Cell type:code id: tags:
``` python
net.fit(df, epochs=150, verbose=0, batch_size=2048,
weight=Variable("weight", lambda d: d.weight * (d.is_sig * sig_wf + d.is_ztt * ztt_wf)))
```
%% Cell type:code id: tags:
``` python
sns.lineplot(x='epoch', y='loss', data=net.history, label="Training")
sns.lineplot(x='epoch', y='val_loss', data=net.history, label="Validation")
plt.ylabel("loss")
atlasify("Internal")
None
```
%% Cell type:markdown id: tags:
## Accuracy
%% Cell type:code id: tags:
``` python
sns.lineplot(x='epoch', y='categorical_accuracy', data=net.history, label="Training")
sns.lineplot(x='epoch', y='val_categorical_accuracy', data=net.history, label="Validation")
plt.ylabel("Accuracy")
atlasify("Internal")
None
```
%% Cell type:code id: tags:
``` python
sns.lineplot(x='epoch', y='val_categorical_accuracy', data=net.history, hue="fold")
atlasify("Internal", enlarge=1.6)
None
```
%% Cell type:code id: tags:
``` python
out = net.predict(df, cv='test')
out['pred_sig'] = out.pred_is_sig >= 0.5
```
%% Cell type:code id: tags:
``` python
c_pred_sig = Process("Signal", lambda d: d.pred_is_sig >= 0.5)
c_pred_ztt = Process(r"$Z\rightarrow\tau\tau$", lambda d: d.pred_is_sig < 0.5)
confusion_matrix(out, [p_sig, p_ztt], [c_pred_sig, c_pred_ztt], info=False,
x_label="Truth", y_label="Classification", annot=True, weight="weight")
confusion_matrix(out, [p_sig, p_ztt], [c_pred_sig, c_pred_ztt], normalize_rows=True, info=False,
x_label="Truth", y_label="Classification", annot=True, weight="weight")
None
```
%% Cell type:markdown id: tags:
## Export to lwtnn
In order to use the network in lwtnn, we need to export the neural network with the `export()` method. This export one network per fold. It is the reposibility of the use to implement the cross validation in the analysis framework.
%% Cell type:code id: tags:
``` python
net.export("lwtnn")
```
%% Cell type:code id: tags:
``` python
!ls lwtnn*
```
%% Cell type:markdown id: tags:
The final, manuel step is to run the lwtnn's converter using the shortcut script `test.sh`.
......
......@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
import os
import sys
import h5py
import json
import numpy as np
import pandas as pd
......@@ -269,6 +270,24 @@ class Normalizer(ABC):
Check if two normalizers are the same.
"""
@property
@abstractmethod
def scales(self):
"""
Every normalizor must reduce to a simple (offset + scale * x)
normalization to be used with lwtnn. This property returns the scale
parameters for all variables.
"""
@property
@abstractmethod
def offsets(self):
"""
Every normalizor must reduce to a simple (offset + scale * x)
normalization to be used with lwtnn. This property returns the offset
parameters for all variables.
"""
def save_to_h5(self, path, key, overwrite=False):
"""
Save normalizer definition to a hdf5 file.
......@@ -377,6 +396,14 @@ class EstimatorNormalizer(Normalizer):
width = pd.read_hdf(path, os.path.join(key, "width"))
return cls(None, center=center, width=width)
@property
def scales(self):
return 1 / self.width
@property
def offsets(self):
return -self.center / self. width
def normalize_category_weights(df, categories, weight='weight'):
"""
The categorical weight normalizer acts on the weight variable only. The
......@@ -615,3 +642,43 @@ class HepNet:
instance.norms.append(norm)
return instance
def export(self, path_base, command="converters/keras2json.py"):
"""
Exports the network such that it can be converted to lwtnn's json
format. The method generate a set of files for each cross validation
fold. For every fold, the archtecture, the weights, the input
variables and their normalization is exported. To simplify the
conversion to lwtnn's json format, the method also creates a bash
script which converts all folds.
The path_base argument should be a path or a name of the network. The
names of the gerneated files are created by appending to path_base.
"""
for fold_i in range(self.cv.k):
# get the architecture as a json string
arch = self.models[fold_i].to_json()
# save the architecture string to a file somehow, the below will work
with open('%s_arch_%d.json' % (path_base, fold_i), 'w') as arch_file:
arch_file.write(arch)
# now save the weights as an HDF5 file
self.models[fold_i].save_weights('%s_wght_%d.h5' % (path_base, fold_i))
with open("%s_vars_%d.json" % (path_base, fold_i), "w") \
as variable_file:
scales = self.norms[fold_i].scales
offsets = self.norms[fold_i].offsets
inputs = [dict(name=v, offset=o, scale=s)
for v, o, s in zip(self.input_list, offsets, scales)]
json.dump(dict(inputs=inputs, class_labels=self.output_list),
variable_file)
mode = "w" if fold_i == 0 else "a"
with open("%s.sh" % path_base, mode) as script_file:
print(f"{command} {path_base}_arch_{fold_i}.json "
f"{path_base}_vars_{fold_i}.json "
f"{path_base}_wght_{fold_i}.h5 "
f"> {path_base}_{fold_i}.json", file=script_file)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment