From 6e6ead6b47c9ccf61334029d5553d83289a94507 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Wed, 26 Jun 2019 22:17:07 +0200
Subject: [PATCH] Implement histtype treatment

---
 histogram.ipynb | 25 ++++++++++++++----------
 nnfwtbn/plot.py | 51 +++++++++++++++++++++++++++++++------------------
 2 files changed, 47 insertions(+), 29 deletions(-)

diff --git a/histogram.ipynb b/histogram.ipynb
index 2537a3a..bc459f8 100644
--- a/histogram.ipynb
+++ b/histogram.ipynb
@@ -11,19 +11,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 14,
    "metadata": {},
    "outputs": [],
    "source": [
     "import pandas as pd\n",
     "import seaborn as sns\n",
-    "df = pd.read_hdf(\"test.h5\")\n",
-    "#df = pd.read_hdf(\"demo/mva.h5\")"
+    "#df = pd.read_hdf(\"test.h5\")\n",
+    "df = pd.read_hdf(\"demo/mva.h5\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 11,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -32,7 +32,11 @@
     "p_zll = Process(r\"$Z\\rightarrow\\ell\\ell$\", range=(-599, -500))\n",
     "p_fake = Process(r\"Fake\", range=(-199, -100))\n",
     "\n",
-    "p_sig = Process(r\"Signal\", range=(1, 1000))"
+    "p_sig = Process(r\"Signal\", range=(1, 1000))\n",
+    "\n",
+    "p_data = Process(r\"Data\", range=(0, 0))\n",
+    "p_asimov = Process(r\"Asimov\", selection=(\n",
+    "        p_top.selection | p_ztt.selection | p_zll.selection | p_fake.selection | p_sig.selection ))"
    ]
   },
   {
@@ -59,14 +63,14 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 19,
    "metadata": {},
    "outputs": [
     {
      "data": {
-      "image/png": "\n",
+      "image/png": "\n",
       "text/plain": [
-       "<matplotlib.figure.Figure at 0x7f855ba95630>"
+       "<matplotlib.figure.Figure at 0x7fc1553c4e10>"
       ]
      },
      "metadata": {},
@@ -74,8 +78,9 @@
     }
    ],
    "source": [
-    "hist(df, v_mmc, 40, [p_fake, p_top, p_zll, p_ztt], [p_sig], range=(0, 200), selection=c_vbf,\n",
-    "     weight=\"weight\", color=[sns.color_palette(\"Blues\"), sns.color_palette()[1:]])\n",
+    "hist(df, v_mmc, 20, [p_fake, p_top, p_zll, p_ztt, p_sig], [p_asimov], range=(0, 200), selection=c_vbf,\n",
+    "     weight=\"weight\", color=[sns.color_palette(\"Blues\")[:4] + sns.color_palette()[1:], ['black']],\n",
+    "     histtype=[['stepfilled']*4 + ['step'], 'points'])\n",
     "None"
    ]
   },
diff --git a/nnfwtbn/plot.py b/nnfwtbn/plot.py
index 9167e6c..31e2f11 100644
--- a/nnfwtbn/plot.py
+++ b/nnfwtbn/plot.py
@@ -33,19 +33,6 @@ class HistogramFactory:
         The method returns the return value of hist.
         """
 
-
-def _type_to_histtype(type):
-    """
-    Returns the matplotlib histogram type for a given process plotting type.
-
-    >>> _type_to_histtype("fill")
-    'stepfilled'
-    >>> _type_to_histtype("line")
-    'step'
-    """
-    type_map = {"fill": "stepfilled", "line": "step"} 
-    return type_map[type]
-
 def hist(dataframe, variable, bins, *stacks, selection=None,
          range=None, blind=None, axes=None, figure=None,
          weight=None, **kwds):
@@ -155,12 +142,38 @@ def hist(dataframe, variable, bins, *stacks, selection=None,
                 stack_props = props[i_stack]
                 process_kwds[kwd] = stack_props[i_process % len(stack_props)]
 
-            n, _, _ = axes.hist(variable(dataframe[sel(dataframe)]),
-                                bins=bins, range=range,
-                                bottom=bottom,
-                                label=process.label,
-                                weights=weight(dataframe[sel(dataframe)]),
-                                **process_kwds)
+            # Select this process's events once for both values and weights.
+            process_df = dataframe[sel(dataframe)]
+            if process_kwds.get("histtype") == "points":
+                # "points" is not a matplotlib histtype: bin the values with
+                # numpy and draw markers with error bars instead.
+                del process_kwds['histtype']
+
+                # Explicitly passed keywords override the marker defaults.
+                defaults = {
+                    'markersize': 4,
+                    'fmt': 'o'
+                }
+                defaults.update(process_kwds)
+                process_kwds = defaults
+
+                n, _ = np.histogram(variable(process_df),
+                                    bins=bins, range=range,
+                                    weights=weight(process_df))
+
+                bin_centers = (bins[1:] + bins[:-1]) / 2
+                bin_widths = bins[1:] - bins[:-1]
+
+                axes.errorbar(bin_centers, bottom + n,
+                              yerr=np.sqrt(n), xerr=bin_widths / 2,
+                              label=process.label, **process_kwds)
+            else:
+                # Also taken when no histtype is given, so that `n` is always
+                # defined for the stacking below (matplotlib default style).
+                n, _, _ = axes.hist(variable(process_df),
+                                    bins=bins, range=range, bottom=bottom,
+                                    label=process.label,
+                                    weights=weight(process_df), **process_kwds)
             bottom += n 
 
     axes.set_xlim((bins.min(), bins.max()))
-- 
GitLab