Skip to content
Snippets Groups Projects
Unverified Commit 87dc8df5 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Implement and test RangeBlinding

parent 562f197e
No related branches found
No related tags found
No related merge requests found
Pipeline #12312 passed
__version__ = "0.0.0"
from nnfwtbn.variable import Variable, RangeBlinding
from nnfwtbn.variable import Variable, RangeBlindingStrategy
from nnfwtbn.process import Process
from nnfwtbn.cut import Cut
from nnfwtbn.plot import HistogramFactory, hist
......@@ -2,3 +2,4 @@
class InvalidProcessSelection(ValueError):
    """Error type for an invalid process selection (a ValueError)."""


class InvalidProcessType(ValueError):
    """Error type for an invalid process type (a ValueError)."""


class InvalidBlinding(TypeError):
    """Error type for a blinding object of an unsupported type."""


class InvalidBins(TypeError):
    """Error type for a bins argument of an unsupported type."""
......@@ -14,7 +14,7 @@ class HistogramFactory:
def __call__(self, *args, **kwds):
"""
Proxy for method to hist(). The positional argument passed to hist()
are the positional argument given to the constructor concatinated with
are the positional argument given to the constructor concatenated with
the positional argument given to this method. The keyword argument for
hist() is the union of the keyword arguments passed to the constructor
and this method. The argument passed to this method have precedence.
......
......@@ -3,7 +3,7 @@ import unittest
import numpy as np
import pandas as pd
from nnfwtbn.variable import Variable, RangeBlinding
from nnfwtbn.variable import Variable, RangeBlindingStrategy
class VariableTestCase(unittest.TestCase):
"""
......@@ -14,7 +14,7 @@ class VariableTestCase(unittest.TestCase):
"""
Check that all arguments are stored in the object.
"""
blinding = RangeBlinding(100, 125)
blinding = RangeBlindingStrategy(100, 125)
variable = Variable("MMC", "ditau_mmc_mlm_m", "GeV", blinding)
self.assertEqual(variable.name, "MMC")
......@@ -81,3 +81,117 @@ class VariableTestCase(unittest.TestCase):
sum = variable(df)
self.assertListEqual(list(sum), [0, 2, 6, 12, 20])
class RangeBlindingTestCase(unittest.TestCase):
    """
    Test the implementation of the RangeBlindingStrategy class.
    """

    def generate_df(self):
        """
        Returns a toy dataframe with 400 rows. The column ditau_mmc_mlm_m
        is evenly spaced from 0 to 400, the column x from 0 to 1.
        """
        return pd.DataFrame({
            "ditau_mmc_mlm_m": np.linspace(0, 400, 400),
            "x": np.linspace(0, 1, 400),
        })

    def test_init_store(self):
        """
        Check that the constructor stores all arguments.
        """
        blinding = RangeBlindingStrategy(100, 125)

        # assertEqual is used because the assertEquals alias is deprecated
        # and no longer exists in recent Python versions.
        self.assertEqual(blinding.start, 100)
        self.assertEqual(blinding.end, 125)

    def test_event_blinding(self):
        """
        Check that events in the given region are removed.
        """
        blinding_strategy = RangeBlindingStrategy(100, 125)
        variable = Variable("MMC", "ditau_mmc_mlm_m")
        df = self.generate_df()

        # 30 bins over (50, 200) put a border every 5 units, so 100 and 125
        # already coincide with bin borders.
        blinding = blinding_strategy(variable, bins=30, range=(50, 200))
        blinded_df = df[blinding(df)]

        # All events outside
        self.assertTrue((
            (blinded_df.ditau_mmc_mlm_m < 100) |
            (blinded_df.ditau_mmc_mlm_m > 125)).all())

        # No events inside
        self.assertFalse((
            (blinded_df.ditau_mmc_mlm_m > 100) &
            (blinded_df.ditau_mmc_mlm_m < 125)).any())

        # Boundary not enlarged: events between 125 and 130 must survive
        self.assertTrue((
            (blinded_df.ditau_mmc_mlm_m > 100) &
            (blinded_df.ditau_mmc_mlm_m < 130)).any())

    def test_bin_border(self):
        """
        Check that the blind range is extended to match the bin borders.
        """
        blinding_strategy = RangeBlindingStrategy(100, 125)
        variable = Variable("MMC", "ditau_mmc_mlm_m")
        df = self.generate_df()

        # 15 bins over (50, 200) put a border every 10 units, so the upper
        # end 125 lies inside a bin and is expected to move to 130.
        blinding = blinding_strategy(variable, bins=15, range=(50, 200))
        blinded_df = df[blinding(df)]

        # All events outside
        self.assertTrue((
            (blinded_df.ditau_mmc_mlm_m < 100) |
            (blinded_df.ditau_mmc_mlm_m > 130)).all())

        # No events inside
        self.assertFalse((
            (blinded_df.ditau_mmc_mlm_m > 100) &
            (blinded_df.ditau_mmc_mlm_m < 130)).any())

    def test_bin_border_left(self):
        """
        Check that the blinding does not break if the blinding starts left
        of the first bin.
        """
        blinding_strategy = RangeBlindingStrategy(10, 125)
        variable = Variable("MMC", "ditau_mmc_mlm_m")
        df = self.generate_df()

        # The start 10 is below the first bin border 50 and stays as given;
        # only the end 125 is expected to move to the border 130.
        blinding = blinding_strategy(variable, bins=15, range=(50, 200))
        blinded_df = df[blinding(df)]

        # All events outside
        self.assertTrue((
            (blinded_df.ditau_mmc_mlm_m < 10) |
            (blinded_df.ditau_mmc_mlm_m > 130)).all())

        # No events inside
        self.assertFalse((
            (blinded_df.ditau_mmc_mlm_m > 10) &
            (blinded_df.ditau_mmc_mlm_m < 130)).any())

    def test_bin_border_right(self):
        """
        Check that the blinding does not break if the blinding ends right
        of the last bin.
        """
        blinding_strategy = RangeBlindingStrategy(100, 225)
        variable = Variable("MMC", "ditau_mmc_mlm_m")
        df = self.generate_df()

        # The end 225 is above the last bin border 200 and stays as given;
        # the start 100 already coincides with a bin border.
        blinding = blinding_strategy(variable, bins=15, range=(50, 200))
        blinded_df = df[blinding(df)]

        # All events outside
        self.assertTrue((
            (blinded_df.ditau_mmc_mlm_m < 100) |
            (blinded_df.ditau_mmc_mlm_m > 225)).all())

        # No events inside
        self.assertFalse((
            (blinded_df.ditau_mmc_mlm_m > 100) &
            (blinded_df.ditau_mmc_mlm_m < 225)).any())
from abc import ABC, abstractmethod
class Blinding(ABC):
import numpy as np
from nnfwtbn.cut import Cut
class BlindingStrategy(ABC):
    """
    The BlindingStrategy class represents a blinding strategy. This is an
    abstract base class. Sub-classes must implement the __call__ method.
    """

    @abstractmethod
    def __call__(self, dataframe, variable, bins, range=None):
        """
        Returns the additional selection in order to blind a process. The
        first argument is the dataframe to operate on. The second argument is
        the variable whose histogram should be blinded. The arguments bins and
        range are identical to the ones for the hist method. They might be
        used in sub-classes to align the blinding cuts to bin borders.

        This base implementation always raises NotImplementedError.
        """
        # NOTE(review): RangeBlindingStrategy.__call__ in this file takes
        # (variable, bins, range=None) without a dataframe argument -- confirm
        # which of the two signatures is the intended contract.
        raise NotImplementedError()
class RangeBlinding(Blinding):
class RangeBlindingStrategy(BlindingStrategy):
    """
    Concrete blinding strategy which removes all events between a certain
    x-axis range. The range might be extended to match the bin borders.
    """

    def __init__(self, start, end):
        """
        Returns a new RangeBlindingStrategy object. When the object is called,
        it returns a selection removing all events that lay between start and
        end. The range might be extended to match bin borders.
        """
        self.start = start
        self.end = end

    def __call__(self, variable, bins, range=None):
        """
        See base class. Returns the additional selection.

        The argument variable is a callable which returns the values to cut
        on when called with a dataframe. If range is given, it must be a
        tuple of two numbers and bins must be the integer number of
        equal-width bins in that range. If range is None, bins must be a
        one-dimensional sequence of bin borders.

        A blinding border which lies strictly between the lowest and the
        highest bin border is moved outwards to the nearest bin border: the
        start moves down, the end moves up. Borders outside of the binning
        are used as given.

        The returned Cut keeps events whose value is strictly below the
        start or strictly above the end of the (aligned) blinded range.

        Raises InvalidBins if bins does not match the description above and
        InvalidProcessSelection if range is not a tuple of length two.
        """
        # NOTE(review): the import binding the name `err` is not visible in
        # this part of the file -- confirm that it exists at module level.
        if range is not None:
            # Build the bin borders from the number of bins and the range
            if not isinstance(bins, (int, np.integer)):
                raise err.InvalidBins("When range is given, bins must be int.")
            if not isinstance(range, tuple) or len(range) != 2:
                # NOTE(review): InvalidProcessSelection is kept for backward
                # compatibility although the error is about the range.
                raise err.InvalidProcessSelection("Range argument must be a "
                                                  "tuple of two numbers.")
            borders = np.linspace(range[0], range[1], bins + 1)
        else:
            # Accept lists and tuples of borders as well as arrays
            borders = np.asarray(bins)
            if borders.ndim != 1 or borders.size == 0:
                raise err.InvalidBins("When range is not given, bins must be "
                                      "a sequence of bin borders.")

        lowest = borders.min()
        highest = borders.max()

        start = self.start
        if lowest < start < highest:
            # Align to bin border: largest border not above the start
            start = borders[borders <= start].max()

        end = self.end
        if lowest < end < highest:
            # Align to bin border: smallest border not below the end
            end = borders[borders >= end].min()

        def outside_blinded_range(dataframe):
            # Evaluate the variable only once per dataframe
            values = variable(dataframe)
            return (values < start) | (values > end)

        return Cut(outside_blinded_range)
class Variable:
"""
......@@ -78,7 +109,7 @@ class Variable:
self.name = name
self.unit = unit
if blinding is not None and not isinstance(blinding, Blinding):
if blinding is not None and not isinstance(blinding, BlindingStrategy):
raise InvalidBlinding("Blinding object must inherit from "
"Blinding class.")
self.blinding = blinding
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment