From 3953a4e0c95a42b2d2662174d775c1dabbbae803 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Fri, 23 Oct 2020 19:14:58 +0200
Subject: [PATCH] Add new implementation draft

---
 nnfwtbn/plot.py | 198 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 196 insertions(+), 2 deletions(-)

diff --git a/nnfwtbn/plot.py b/nnfwtbn/plot.py
index 86be94f..7d4a9e0 100644
--- a/nnfwtbn/plot.py
+++ b/nnfwtbn/plot.py
@@ -9,6 +9,7 @@ import dask.dataframe as dd
 import seaborn as sns
 
 from atlasify import atlasify
+import uhepp
 
 from nnfwtbn.stack import Stack
 from nnfwtbn.process import Process
@@ -60,7 +61,7 @@ def uhepplot(dataframe, variable, bins, stacks, selection=None,
          weight=None, y_log=False, y_min=None, vlines=[],
          denominator=0, numerator=-1, ratio_label=None, diff=False,
          ratio_range=None, atlas=None, info=None, enlarge=1.6,
-         density=False, include_outside=False, **kwds):
+         density=False, include_outside=False, return_uhepp=False, **kwds):
     """
     Creates a histogram of stacked processes. The first argument is the
     dataframe to operate on. The 'variable' argument defines the x-axis. The
@@ -119,8 +120,201 @@ def uhepplot(dataframe, variable, bins, stacks, selection=None,
     If the density argument is True, the area of each stack is normalized to
     unity.
 
-    The method returns a Universal HEP plot object.
+    If return_uhepp is True, the method return a UHepPlot object.
     """
+    uhepp_obj = uhepp.UHepPlot()
+    uhepp_obj.producer = "nnfwtbn"
+
+    if (atlas is not False) or (info is not False):
+        uhepp_obj.atlas = True
+
+        atlas, info = fill_labels(atlas, info)
+        uhepp_obj.label = atlas
+        uhepp_obj.subtext = info
+
+    # Wrap column string by variable
+    if isinstance(variable, str):
+        variable = Variable(variable, variable)
+
+    uhepp_obj.symbol = variable.name
+    if variable.unit is not None:
+        uhepp_obj.unit = variable.unit
+
+    uhepp_obj.log = y_log
+
+    if weight is None:
+        weight = Variable("unity", lambda d: variable(d) * 0 + 1)
+    elif isinstance(weight, str):
+        weight = Variable(weight, weight)
+
+    squared_weight = Variable("squared weight", lambda d: weight(d)**2)
+
+    draw_ratio = (denominator is not None) and (numerator is not None)
+
+    # Disable ratio plots if there is only one stack 
+    if len(stacks) == 1 and \
+            isinstance(denominator, int) and \
+            isinstance(numerator, int):
+        draw_ratio = False
+
+    # Handle selection
+    if selection is None:
+        selection = Cut(lambda d: variable(d) * 0 == 0)
+    elif not isinstance(selection, Cut):
+        selection = Cut(selection)
+
+    # Handle range/bins
+    if range is not None:
+        # Build bins
+        if not isinstance(bins, int):
+            raise err.InvalidBins("When range is given, bins must be int.")
+        if not isinstance(range, tuple)  or len(range) != 2:
+            raise err.InvalidProcessSelection("Range argument must be a "
+                                              "tuple of two numbers.")
+        bins = np.linspace(range[0], range[1], bins + 1)
+
+    uhepp_obj.bin_edges = [float(x) for x in bins]
+    uhepp_obj.include_overflow = include_outside
+    uhepp_obj.include_overflow = include_outside
+
+    # Create separate under- and overflow bin
+    def pad(edges):
+        padded = [edges[0] - 1] + list(edges) + [edges[-1] + 1]
+        return np.asarray(padded)
+    bins = pad(bins)
+
+    if blind is None:
+        blind = []
+    elif isinstance(blind, Stack):
+        # Wrap scalar blind stack
+        blind = [blind]
+
+    # Handle stack
+    yields = {}
+    for i_stack, stack in enumerate(stacks):
+        uhepp_stack = uhepp.Stack()
+        uhepp_obj.stacks.append(uhepp_stack)
+        uhepp_stack.type = stack.get_histtype(0)
+
+        if stack in blind and variable.blinding is not None:
+            c_blind = variable.blinding(variable, bins)
+        else:
+            c_blind = lambda d: d
+
+        density_norm = 1.0
+        if density:
+            density_norm = stack.get_total(dataframe,
+                                           [float('-inf'), float('inf')],
+                                           variable, weight, include_outside).sum()
+             
+        for i_process, process in enumerate(stack.processes):
+            histogram = stack.get_hist(c_blind(dataframe), i_process, bins,
+                                       variable, weight, include_outside)
+
+            histogram = histogram / density_norm
+
+            uncertainty = stack.get_total_uncertainty(c_blind(dataframe),
+                                                      bins, variable,
+                                                      weight,
+                                                      include_outside)
+            uncertainty = uncertainty / density_norm
+
+            yield_ = {}
+            yield_["base"] = histogram
+            yield_["stat"] = uncertainty
+
+            process_id = f"s{i_stack}_p{i_process}"
+            yields[process_id] = yield_
+
+            uhepp_sitem = uhepp.StackItem()
+            uhepp_sitem.label = process.label
+            uhepp_sitem.style = stack.get_aux(i_process)
+            uhepp_sitem.yield_ = [process_id]
+            uhepp_stack.content.append(uhepp_sitem)
+
+        # Resolve numerator/denominator indices
+        if isinstance(numerator, int) and stacks[numerator] == stack:
+            numerator = stack
+        if isinstance(denominator, int) and stacks[denominator] == stack:
+            denominator= stack
+
+
+    if draw_ratio:
+        ratio = []
+
+        # Get denominator hist
+        if denominator in blind and variable.blinding is not None:
+            c_blind = variable.blinding(variable, bins)
+        else:
+            c_blind = lambda d: d
+
+        yield_ = {}
+        yield_["base"] = list(denominator.get_total(c_blind(dataframe), bins,
+                                               variable, weight,
+                                               include_outside))
+        yield_["stat"] = list(denominator.get_total_uncertainty(c_blind(dataframe),
+                                                           bins, variable,
+                                                           weight,
+                                                           include_outside))
+        yields["den"] = yield_
+
+        # Process numerators
+        numerators = numerator
+        if not isinstance(numerators, (list, tuple)):
+            numerators = [numerators]
+
+        numerators_data = []
+        for i_numerator, numerator in enumerate(numerators):
+            if numerator in blind and variable.blinding is not None:
+                c_blind = variable.blinding(variable, bins)
+            else:
+                c_blind = lambda d: d
+
+            yield_ = {}
+            process_id = f"num_{i_numerator}"
+            yield_["base"] = numerator.get_total(c_blind(dataframe), bins,
+                                                    variable, weight,
+                                                    include_outside)
+            yield_["stat"] = numerator.get_total_uncertainty(c_blind(dataframe),
+                                                                bins,
+                                                                variable, weight,
+                                                                include_outside)
+            yields[process_id] = yield_
+
+
+            uhepp_ritem = uhepp.RatioItem()
+            uhepp_ritem.numerator = [process_id]
+            uhepp_ritem.denominator = ["den"]
+            uhepp_ritem.style = {'markersize': 4, 'fmt': 'o'}
+            uhepp_ritem.style.update(numerator.get_aux(0))
+
+            histtype = numerator.get_histtype(0)
+            uhepp_ritem.type = "points" if histtype == "points" else "steps"
+
+            ratio.append(uhepp_ritem)
+
+        uhepp_obj.ratio = ratio
+
+    # Compute delayed dask 
+    yields, = dask.compute(yields)
+    sterile_yields = {}
+    for key, item in yields.items():
+
+        sterile_item = {}
+        for item_key, sub_item in item.items():
+            sterile_item[item_key] = [float(x) for x in sub_item]
+
+        sterile_yields[key] = uhepp.Yield(sterile_item)
+
+    # TODO: Reorder processes if y_log
+
+    uhepp_obj.yields = sterile_yields
+
+    # TODO: vertical and horizontal lines
+
+    if return_uhepp:
+        return uhepp_obj
+
 
 
 def hist(dataframe, variable, bins, stacks, selection=None,
-- 
GitLab