diff --git a/nnfwtbn/tests/test_variable.py b/nnfwtbn/tests/test_variable.py
new file mode 100644
index 0000000000000000000000000000000000000000..148b59a6f6b5aaf8e6a42f0537e99363ad96bac8
--- /dev/null
+++ b/nnfwtbn/tests/test_variable.py
@@ -0,0 +1,84 @@
+
+import unittest
+import numpy as np
+import pandas as pd
+
+from nnfwtbn.variable import Variable, RangeBlinding
+
+class VariableTestCase(unittest.TestCase):
+    """
+    Test the implementation of the variable class.
+    """
+
+    def test_init_store(self):
+        """
+        Check that all arguments are stored in the object.
+        """
+        blinding = RangeBlinding(100, 125)
+        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV", blinding)
+
+        self.assertEqual(variable.name, "MMC")
+        self.assertIsNotNone(variable.definition)
+        self.assertEqual(variable.unit, "GeV")
+        self.assertEqual(variable.blinding, blinding)
+
+    def test_init_definition_string(self):
+        """
+        Check that a string used as the variable definition is wrapped into a
+        lambda.
+        """
+        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
+        self.assertTrue(callable(variable.definition))
+
+    def test_init_blinding_type(self):
+        """
+        Check that an error is thrown if the blinding object is not an
+        instance of the abstract blinding class.
+        """
+        self.assertRaises(TypeError, Variable,
+                          "MMC", "ditau_mmc_mlm_m", "GeV", "blind")
+
+    def test_repr(self):
+        """
+        Check that the string representation contains the name of the
+        variable.
+        """
+        variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV")
+        self.assertEqual(repr(variable), "<Variable 'MMC' [GeV]>")
+
+        variable = Variable(r"$\Delta \eta$",
+                            lambda df: df.jet_0_eta - df.jet_1_eta)
+        self.assertEqual(repr(variable), r"<Variable '$\\Delta \\eta$'>")
+
+
+    def generate_df(self):
+        """
+        Generate a toy dataframe.
+        """
+        return pd.DataFrame({
+            "x": np.arange(5),
+            "y": np.arange(5)**2
+        })
+
+
+    def test_call_column(self):
+        """
+        Check that calling the variable extracts the given column name.
+        """
+        df = self.generate_df()
+
+        variable = Variable("$y$", "y")
+        y_col = variable(df)
+
+        self.assertListEqual(list(y_col), [0, 1, 4, 9, 16])
+
+    def test_call_lambda(self):
+        """
+        Check that calling the variable calls the given lambda.
+        """
+        df = self.generate_df()
+
+        variable = Variable("$x + y$", lambda d: d.x + d.y)
+        total = variable(df)
+
+        self.assertListEqual(list(total), [0, 2, 6, 12, 20])
diff --git a/nnfwtbn/variable.py b/nnfwtbn/variable.py
index 96311c88104ebb0763ac69ad44c51782f4b48264..476760a3d0c3210be0d15ade7afe003e4cc641a4 100644
--- a/nnfwtbn/variable.py
+++ b/nnfwtbn/variable.py
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
 
 class Blinding(ABC):
     """
-    The blinding class represents a blinding strategies. This is an abstract
+    The blinding class represents a blinding strategy. This is an abstract
     base class. Sub-classes must implement the __call__ method.
     """
 
@@ -22,7 +22,7 @@ class Blinding(ABC):
 
 class RangeBlinding(Blinding):
     """
-    Concrete blinding strategy to which remove all events between a certain
+    Concrete blinding strategy which removes all events between a certain
     x-axis range. The range might be extended to match the bin borders.
     """
 
@@ -33,6 +33,9 @@ class RangeBlinding(Blinding):
         end. The range might be extended to match bin borders.
         """
 
+    def __call__(self, dataframe, variable, bins, selection, range=None):
+        pass
+
 class Variable:
     """
     Representation of a quantity derived from the columns of a dataframe. The
@@ -43,11 +46,11 @@ class Variable:
     labeling of axes.
 
     >>> Variable("MMC", "ditau_mmc_mlm_m", "GeV")
-    <Variable: MMC [GeV]>
+    <Variable 'MMC' [GeV]>
     """
 
     def __init__(self, name, definition, unit=None, blinding=None):
-        """
+        r"""
         Returns a new variable object. The first argument is a human-readable
         name (potentially using latex). The second argument defines the value
         of the variable. This can be a string naming the column of the
@@ -55,10 +58,10 @@ class Variable:
         passed to it.
 
         >>> Variable("MMC", "ditau_mmc_mlm_m", "GeV")
-        <Variable: MMC [GeV]>
+        <Variable 'MMC' [GeV]>
 
-        >>> Variable(r"$\Delta \eta$", lambda df: df.jet_0_eta - df.jet_1_eta)
-        <Variable: $\Delta \eta$ >
+        >>> Variable("$\\Delta \\eta$", lambda df: df.jet_0_eta - df.jet_1_eta)
+        <Variable '$\\Delta \\eta$'>
 
         The optional argument unit defines the unit of the variable. This
         information is used for plotting, especially for labeling axes.
@@ -67,14 +70,32 @@ class Variable:
         the blinding strategy.
         """
 
+        if isinstance(definition, str):
+            # Wrap column string by lambda
+            self.definition = lambda d: d[definition]
+        else:
+            self.definition = definition
+        self.name = name
+        self.unit = unit
+
+        if blinding is not None and not isinstance(blinding, Blinding):
+            raise TypeError("Blinding object must inherit from "
+                            "Blinding class.")
+        self.blinding = blinding
+
+
     def __call__(self, dataframe):
         """
         Returns an array or series of variable computed from the given
-        dataframe.
+        dataframe. This method does not apply the blinding!
         """
-
+        return self.definition(dataframe)
 
     def __repr__(self):
         """
         Returns a string representation.
         """
+        if self.unit is None:
+            return "<Variable %s>" % repr(self.name)
+        else:
+            return "<Variable %s [%s]>" % (repr(self.name), self.unit)