Skip to content
Snippets Groups Projects
Unverified Commit ad25a3c7 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Update classification example for keras API

parent 21ea2cbd
Branches 44-data-content-interface
No related tags found
1 merge request!71Draft: Resolve "Warnings about missing 'weighted_metrics' in tensorflow evaluate()"
Pipeline #12732 passed
%% Cell type:markdown id: tags:
# Classification
%% Cell type:code id: tags:
``` python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import SGD
from freeforestml import Variable, Process, Cut, \
HepNet, ClassicalCV, EstimatorNormalizer, \
HistogramFactory, confusion_matrix, atlasify, \
McStack
from freeforestml import toydata, example_style
# Activate the example style shipped with freeforestml before any plot is drawn.
# NOTE(review): presumably this sets global matplotlib defaults — confirm.
example_style()
```
%% Cell type:code id: tags:
``` python
# Load the toy dataset. Later cells call .compute() on it, so it is presumably
# a lazy (dask-like) dataframe — confirm.
df = toydata.get()
```
%% Cell type:code id: tags:
``` python
# Two physics processes, each selected by a range.
# NOTE(review): the column the ranges (0, 0) and (1, 1) apply to is not visible here.
p_ztt = Process(r"$Z\rightarrow\tau\tau$", range=(0, 0))
p_sig = Process(r"Signal", range=(1, 1))
# Stack of both processes, handed to the histogram factory.
s_all = McStack(p_ztt, p_sig)
```
%% Cell type:code id: tags:
``` python
# Reusable plotting callable bound to the dataset, the stack and the "weight" column.
hist_factory = HistogramFactory(df, stacks=[s_all], weight="weight")
```
%% Cell type:markdown id: tags:
## Cut-based
%% Cell type:markdown id: tags:
First, we set up a cut-based event selection as a benchmark.
%% Cell type:code id: tags:
``` python
# Histograms of the two variables the cut-based selection relies on.
# The first label is a raw string: "\D" and "\e" are not valid Python escape
# sequences, so the non-raw literal only worked by accident and triggers an
# invalid-escape warning on current interpreters.
hist_factory(Variable(r"$\Delta \eta^{jj}$",
                      lambda d: (d.jet_1_eta - d.jet_2_eta).abs()),
             bins=20, range=(0, 8))
hist_factory(Variable("$m^{jj}$", "m_jj"),
             bins=20, range=(0, 1500))
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:code id: tags:
``` python
# Signal region: an event must pass all four requirements at once.
c_sr = (
    Cut(lambda evt: evt.m_jj > 400)
    & Cut(lambda evt: evt.jet_2_pt >= 30)
    & Cut(lambda evt: evt.jet_1_eta * evt.jet_2_eta < 0)
    & Cut(lambda evt: (evt.jet_2_eta - evt.jet_1_eta).abs() > 3)
)
c_sr.label = "Signal"

# Complement of the signal region.
c_rest = ~c_sr
c_rest.label = "Rest"
```
%% Cell type:code id: tags:
``` python
# Confusion matrix of true process versus cut-based region, using event weights.
confusion_matrix(df, [p_sig, p_ztt], [c_sr, c_rest],
x_label="Signal", y_label="Region", annot=True, weight="weight")
# Same matrix again, this time with normalize_rows=True.
confusion_matrix(df, [p_sig, p_ztt], [c_sr, c_rest], normalize_rows=True,
x_label="Signal", y_label="Region", annot=True, weight="weight")
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:markdown id: tags:
## Neural Network
%% Cell type:code id: tags:
``` python
# Derived features: absolute eta separation and eta product of the two jets.
df['dijet_deta'] = (df.jet_1_eta - df.jet_2_eta).abs()
df['dijet_prod_eta'] = (df.jet_1_eta * df.jet_2_eta)
# Columns fed to the network (its input width is len(input_var)).
input_var = ['dijet_prod_eta', 'm_jj', 'dijet_deta', 'higgs_pt', 'jet_2_pt', 'jet_1_eta', 'jet_2_eta', 'tau_eta']
# Target columns; they are created from the process selections in the next cell.
output_var = ['is_sig', 'is_ztt']
```
%% Cell type:code id: tags:
``` python
# One truth-label column per process, taken from the process selection.
# NOTE(review): idx_array presumably yields a per-event membership mask — confirm.
df["is_sig"] = p_sig.selection.idx_array(df)
df["is_ztt"] = p_ztt.selection.idx_array(df)
```
%% Cell type:code id: tags:
``` python
# Draw roughly 1000 events (the fraction is 1000 / len(df)) for the pair plot.
# NOTE(review): with fewer than 1000 rows the fraction exceeds 1 — confirm that
# sample() accepts this for the dataset in use.
sample_df = df.sample(frac=1000 / len(df)).compute()
# Pairwise scatter plots of all input variables, coloured by the signal label.
sns.pairplot(sample_df, vars=input_var, hue="is_sig")
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:code id: tags:
``` python
def model():
    """Build and compile the classification network.

    The network takes ``len(input_var)`` inputs, has two hidden ReLU layers
    with 15 and 5 units, and ends in a 2-unit softmax layer. It is compiled
    with categorical cross-entropy and plain SGD at a learning rate of 0.1.

    Returns:
        The compiled ``Sequential`` model.
    """
    m = Sequential()
    m.add(Dense(units=15, activation='relu', input_dim=len(input_var)))
    m.add(Dense(units=5, activation='relu'))
    m.add(Dense(units=2, activation='softmax'))
    # `learning_rate` replaces the deprecated `lr` argument name.
    # `weighted_metrics` is passed because training supplies per-event weights;
    # `metrics` is kept so the unweighted accuracy is still tracked.
    # NOTE(review): with both lists set, confirm which history key each variant
    # is reported under before relying on the accuracy plots.
    m.compile(loss='categorical_crossentropy',
              optimizer=SGD(learning_rate=0.1),
              metrics=['categorical_accuracy'],
              weighted_metrics=['categorical_accuracy'])
    return m
# Cross-validation scheme constructed with 5 and the 'random' column.
# NOTE(review): 5 is presumably the number of folds — confirm.
cv = ClassicalCV(5, frac_var='random')
# Wrapper that receives the model factory, the CV scheme, the normalizer class
# and the input/output column lists.
net = HepNet(model, cv, EstimatorNormalizer, input_var, output_var)
```
%% Cell type:code id: tags:
``` python
# Per-class weight factor: number of selected events divided by their summed
# weight, computed with the same expression for signal and Z->tautau.
sig_wf, ztt_wf = (
    len(proc.selection(df).weight) / proc.selection(df).weight.sum()
    for proc in (p_sig, p_ztt)
)
```
%% Cell type:code id: tags:
``` python
# Train on the materialised dataframe. The training weight multiplies each
# event weight by its class factor (sig_wf or ztt_wf), so the weights of each
# class sum to that class's event count.
net.fit(df.compute(), epochs=150, verbose=0, batch_size=2048,
weight=Variable("weight", lambda d: d.weight * (d.is_sig * sig_wf + d.is_ztt * ztt_wf)))
```
%% Cell type:code id: tags:
``` python
# Training and validation loss per epoch, read from the fit history.
sns.lineplot(x='epoch', y='loss', data=net.history, label="Training")
sns.lineplot(x='epoch', y='val_loss', data=net.history, label="Validation")
plt.ylabel("loss")
atlasify(False, "FreeForestML Example")
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:markdown id: tags:
### Accuracy
%% Cell type:code id: tags:
``` python
# Training and validation accuracy per epoch, read from the fit history.
sns.lineplot(x='epoch', y='categorical_accuracy', data=net.history, label="Training")
sns.lineplot(x='epoch', y='val_categorical_accuracy', data=net.history, label="Validation")
plt.ylabel("Accuracy")
atlasify(False, "FreeForestML Example")
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:code id: tags:
``` python
# Validation accuracy per epoch, one line per cross-validation fold.
sns.lineplot(x='epoch', y='val_categorical_accuracy', data=net.history, hue="fold")
# loc=4 places the legend in the lower-right corner.
plt.legend(loc=4)
atlasify(False, "FreeForestML Example")
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:code id: tags:
``` python
# Predict on the materialised dataframe with cv='test'.
# NOTE(review): presumably each event is scored by the fold model that did not train on it — confirm.
out = net.predict(df.compute(), cv='test')
# Boolean signal decision: the predicted signal score at a 0.5 threshold.
out['pred_sig'] = out.pred_is_sig >= 0.5
```
%% Cell type:code id: tags:
``` python
# Predicted classes, split at a signal score of 0.5.
c_pred_sig = Process("Signal", lambda d: d.pred_is_sig >= 0.5)
c_pred_ztt = Process(r"$Z\rightarrow\tau\tau$", lambda d: d.pred_is_sig < 0.5)
# Confusion matrix of true process versus network classification, using event weights.
confusion_matrix(out, [p_sig, p_ztt], [c_pred_sig, c_pred_ztt],
x_label="Truth", y_label="Classification", annot=True, weight="weight")
# Same matrix again, this time with normalize_rows=True.
confusion_matrix(out, [p_sig, p_ztt], [c_pred_sig, c_pred_ztt], normalize_rows=True,
x_label="Truth", y_label="Classification", annot=True, weight="weight")
# Trailing None so the cell's last expression has no value to display.
None
```
%% Cell type:markdown id: tags:
### Export to lwtnn
%% Cell type:markdown id: tags:
In order to use the network in lwtnn, we need to export the neural network with the `export()` method. This exports one network per fold. It is the responsibility of the user to implement the cross validation in the analysis framework.
%% Cell type:code id: tags:
``` python
# Write the trained networks to disk, using "lwtnn" as the output name.
net.export("lwtnn")
```
%% Cell type:code id: tags:
``` python
# IPython shell escape: list the files produced by the export.
!ls lwtnn*
```
%% Cell type:markdown id: tags:
The final, manual step is to run lwtnn's converter using the shortcut script `test.sh`.
%% Cell type:code id: tags:
``` python
```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment