From ba465c0918aaa9076084601f7fe1418573846d0a Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Tue, 25 Jun 2019 12:49:48 +0200
Subject: [PATCH] Implement and test Variable

---
 nnfwtbn/tests/test_variable.py | 83 ++++++++++++++++++++++++++++++++++
 nnfwtbn/variable.py            | 39 ++++++++++++----
 2 files changed, 113 insertions(+), 9 deletions(-)
 create mode 100644 nnfwtbn/tests/test_variable.py

diff --git a/nnfwtbn/tests/test_variable.py b/nnfwtbn/tests/test_variable.py
new file mode 100644
index 0000000..148b59a
--- /dev/null
+++ b/nnfwtbn/tests/test_variable.py
@@ -0,0 +1,83 @@
+
+import unittest
+import numpy as np
+import pandas as pd
+
+from nnfwtbn.variable import Variable, RangeBlinding
+
+class VariableTestCase(unittest.TestCase):
+    """
+    Test the implementation of the variable class.
+    """
+
+    def test_init_store(self):
+        """
+        Check that all arguments are stored in the object.
+        """
+        blinding = RangeBlinding(100, 125)
+        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV", blinding)
+
+        self.assertEqual(variable.name, "MMC")
+        self.assertIsNotNone(variable.definition)
+        self.assertEqual(variable.unit, "GeV")
+        self.assertEqual(variable.blinding, blinding)
+
+    def test_init_definition_string(self):
+        """
+        Check that a string used as the variable definition is wrapped into a
+        lambda.
+        """
+        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
+        self.assertTrue(callable(variable.definition))
+
+    def test_init_blinding_type(self):
+        """
+        Check that an error is thrown if the blinding object is not an
+        instance of the abstract blinding class.
+        """
+        self.assertRaises(TypeError, "MMC", "ditau_mmc_mlm_m", "GeV", "blind")
+
+    def test_repr(self):
+        """
+        Check that the string representation contains the name of the
+        variable.
+        """
+        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
+        self.assertEqual(repr(variable), "<Variable 'MMC' [GeV]>")
+
+        variable = Variable(r"$\Delta \eta$",
+                            lambda df: df.jet_0_eta - df.jet_1_eta)
+        self.assertEqual(repr(variable), r"<Variable '$\\Delta \\eta$'>")
+
+
+    def generate_df(self):
+        """
+        Generate a toy dataframe.
+        """
+        return pd.DataFrame({
+            "x": np.arange(5),
+            "y": np.arange(5)**2
+        })
+            
+
+    def test_call_column(self):
+        """
+        Check that calling the variable extracts the given column name.
+        """
+        df = self.generate_df()
+
+        variable = Variable("$y$", "y")
+        y_col = variable(df)
+
+        self.assertListEqual(list(y_col), [0, 1, 4, 9, 16])
+
+    def test_call_lambda(self):
+        """
+        Check that calling the variable called the given lambda.
+        """
+        df = self.generate_df()
+
+        variable = Variable("$x + y$", lambda d: d.x + d.y)
+        sum = variable(df)
+
+        self.assertListEqual(list(sum), [0, 2, 6, 12, 20])
diff --git a/nnfwtbn/variable.py b/nnfwtbn/variable.py
index 96311c8..476760a 100644
--- a/nnfwtbn/variable.py
+++ b/nnfwtbn/variable.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 
 class Blinding(ABC):
     """
-    The blinding class represents a blinding strategies. This is an abstract
+    The blinding class represents a blinding strategy. This is an abstract
     base class. Sub-classes must implement the __call__ method.
     """
 
@@ -22,7 +22,7 @@ class Blinding(ABC):
         
 class RangeBlinding(Blinding):
     """
-    Concrete blinding strategy to which remove all events between a certain
+    Concrete blinding strategy which removes all events between a certain
     x-axis range. The range might be extended to match the bin borders.
     """
 
@@ -33,6 +33,9 @@ class RangeBlinding(Blinding):
         end. The range might be extended to match bin borders.
         """
 
+    def __call__(self, dataframe, variable, bins, selection, range=None):
+        pass
+
 class Variable:
     """
     Representation of a quantity derived from the columns of a dataframe. The
@@ -43,11 +46,11 @@ class Variable:
     labeling of axes.
 
     >>> Variable("MMC", "ditau_mmc_mlm_m", "GeV")
-    <Variable: MMC [GeV]>
+    <Variable 'MMC' [GeV]>
     """
 
     def __init__(self, name, definition, unit=None, blinding=None):
-        """
+        r"""
         Returns a new variable object. The first argument is a human-readable
         name (potentially using latex). The second argument defines the value
         of the variable. This can be a string naming the column of the
@@ -55,10 +58,10 @@ class Variable:
         passed to it.
 
         >>> Variable("MMC", "ditau_mmc_mlm_m", "GeV")
-        <Variable: MMC [GeV]>
+        <Variable 'MMC' [GeV]>
 
-        >>> Variable(r"$\Delta \eta$", lambda df: df.jet_0_eta - df.jet_1_eta)
-        <Variable: $\Delta \eta$ >
+        >>> Variable("$\\Delta \\eta$", lambda df: df.jet_0_eta - df.jet_1_eta)
+        <Variable '$\\Delta \\eta$'>
 
         The optional argument unit defines the unit of the variable. This
         information is used for plotting, especially for labeling axes.
@@ -67,14 +70,32 @@ class Variable:
         the blinding strategy.
         """ 
 
+        if isinstance(definition, str):
+            # Wrap column string by lambda
+            self.definition = lambda d: d[definition]
+        else:
+            self.definition = definition
+        self.name = name
+        self.unit = unit
+
+        if blinding is not None and not isinstance(blinding, Blinding):
+            raise InvalidBlinding("Blinding object must inherit from "
+                                  "Blinding class.")
+        self.blinding = blinding
+
+
     def __call__(self, dataframe):
         """
         Returns an array or series of variable computed from the given
-        dataframe.
+        dataframe. This method does not apply the blinding!
         """
-
+        return self.definition(dataframe)
 
     def __repr__(self):
         """
         Returns a string representation.
         """
+        if self.unit is None:
+            return "<Variable %s>" % repr(self.name)
+        else:
+            return "<Variable %s [%s]>" % (repr(self.name), self.unit)
-- 
GitLab