diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 08c39a0f4ce68c896d2b2ffe26924b0645beebeb..7d6055bd004ba38487875b108edeed349bafb3bf 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -9,6 +9,7 @@ doctest: image: python:3.7 script: - pip install -r requirements.txt + - pip install -r dev-requirements.txt - "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')" unittest: @@ -16,6 +17,7 @@ unittest: image: python:3.7 script: - pip install -r requirements.txt + - pip install -r dev-requirements.txt - pip install pytest - pytest diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e214769a654c0d5b900d76a14e8910929d9dc52f --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1 @@ +pyspark diff --git a/freeforestml/cut.py b/freeforestml/cut.py index cf4e1e45c271627a878ceba8acc1594e01e360fd..b34114d823ba62bd42ca1a28a3883f8433be11ed 100644 --- a/freeforestml/cut.py +++ b/freeforestml/cut.py @@ -86,6 +86,12 @@ class Cut: Applies the internally stored cut to the given dataframe and returns a new dataframe containing only entries passing the event selection. """ + if self.func is None: + return dataframe + + if hasattr(dataframe, "_sc"): + return dataframe.filter(self.func(dataframe)) + return dataframe[self.idx_array(dataframe)] def idx_array(self, dataframe): diff --git a/freeforestml/tests/test_spark.py b/freeforestml/tests/test_spark.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe9e90029ae9549e3fdb4fd17823e411ae846c8 --- /dev/null +++ b/freeforestml/tests/test_spark.py @@ -0,0 +1,56 @@ + +import unittest +from pyspark.sql import SparkSession +import freeforestml as ff + +class TestparkCut(unittest.TestCase): + """Check that cut object can operate on PySpark dataframes""" + def setUp(self): + """Create or get spark session""" + self.spark = SparkSession.builder.master("local").getOrCreate() + self.sc = self.spark.sparkContext + + def test_func(self): + """Check that lambda-cut can be applied to a spark dataframe""" + rdd = self.sc.parallelize([(x, x**2) for x in range(100)]) + df = rdd.toDF(["x", "y"]) + + cut = ff.Cut(lambda d: d.y < 100) + selected = cut(df).count() + self.assertEqual(selected, 10) + + def test_and(self): + """Check that derived cuts can be used on a pyspark dataframe""" + rdd = self.sc.parallelize([(x, x**2) for x in range(100)]) + df = rdd.toDF(["x", "y"]) + + cut_y = ff.Cut(lambda d: d.y < 1000) + cut_x = ff.Cut(lambda d: d.x > 2) + cut = cut_x & cut_y + + selected = cut(df).count() + self.assertEqual(selected, 29) + + def test_or(self): + """Check that derived cuts can be used on a pyspark dataframe""" + rdd = self.sc.parallelize([(x, x**2) for x in range(100)]) + df = rdd.toDF(["x", "y"]) + + cut_y = ff.Cut(lambda d: d.y < 1000) + cut_x = ff.Cut(lambda d: d.x > 2) + cut = cut_x | cut_y + + selected = cut(df).count() + self.assertEqual(selected, 29) + + def test_or(self): + """Check that derived cuts can be used on a pyspark dataframe""" + rdd = self.sc.parallelize([(x, x**2) for x in range(100)]) + df = rdd.toDF(["x", "y"]) + + cut_y = ff.Cut(lambda d: d.y < 1000) + cut_x = ff.Cut(lambda d: d.x > 2) + cut = cut_x | cut_y + + selected = cut(df).count() + self.assertEqual(selected, 29)