Skip to content
Snippets Groups Projects
Unverified Commit e83a812c authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Merge branch '18-rotate-confusion-matrix'

parents e698251f c1541045
No related branches found
No related tags found
1 merge request!27Resolve "Rotate confusion matrix"
Pipeline #12498 passed
This diff is collapsed.
......@@ -493,7 +493,8 @@ def hist(dataframe, variable, bins, stacks, selection=None,
return figure, axes
def confusion_matrix(df, x_processes, y_processes, x_label, y_label,
weight=None, axes=None, figure=None, **kwds):
weight=None, axes=None, figure=None,
normalize_rows=False, **kwds):
"""
Creates a confusion matrix.
"""
......@@ -510,6 +511,11 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label,
axes = figure.subplots()
y_processes.reverse()
if normalize_rows:
# Swap processes
y_processes, x_processes = x_processes, y_processes
data = np.zeros((len(y_processes), len(x_processes)))
for i_x, x_process in enumerate(x_processes):
x_df = x_process(df)
......@@ -518,14 +524,26 @@ def confusion_matrix(df, x_processes, y_processes, x_label, y_label,
x_y_df = y_process(x_df)
data[i_y][i_x] = weight(x_y_df).sum() / total_weight
if normalize_rows:
# Swap processes back
y_processes, x_processes = x_processes, y_processes
data = data.T
data = pd.DataFrame(data,
columns=[p.label for p in x_processes],
index=[p.label for p in y_processes])
if normalize_rows:
# Swap labels
y_label, x_label = x_label, y_label
sns.heatmap(data, **dict(vmin=0, vmax=1, cmap="Greens", ax=axes,
cbar_kws={
'label': "$P($%s$|$%s$)$" % (y_label, x_label)
}
), **kwds)
if normalize_rows:
# Swap labels back
y_label, x_label = x_label, y_label
axes.set_xlabel(x_label)
axes.set_ylabel(y_label)
axes.set_ylim(len(y_processes), 0)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment