From f91cb672725c58a98d01a963b900748a1b7c4f85 Mon Sep 17 00:00:00 2001
From: Frank Sauerburger <f.sauerburger@cern.ch>
Date: Fri, 30 Apr 2021 17:23:14 +0200
Subject: [PATCH] Prepare cuts for PySpark Dataframes

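Teach the Cut class to operate on PySpark DataFrames in addition to
pandas DataFrames: when the given object looks like a Spark DataFrame
(duck typing via its `_sc` attribute), the stored selection is applied
lazily through DataFrame.filter() instead of boolean indexing. PySpark
is added as a development dependency installed in both CI jobs, and
unit tests cover plain, AND-combined and OR-combined cuts.

A minimal usage sketch, assuming a local Spark session (it mirrors the
new unit tests):

    from pyspark.sql import SparkSession
    import freeforestml as ff

    spark = SparkSession.builder.master("local").getOrCreate()
    df = spark.sparkContext.parallelize(
        [(x, x**2) for x in range(100)]).toDF(["x", "y"])
    ff.Cut(lambda d: d.y < 100)(df).count()  # 10 rows pass the cut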
---
 .gitlab-ci.yml                   |  2 ++
 dev-requirements.txt             |  1 +
 freeforestml/cut.py              |  8 ++++++++
 freeforestml/tests/test_spark.py | 47 +++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 58 insertions(+)
 create mode 100644 dev-requirements.txt
 create mode 100644 freeforestml/tests/test_spark.py

diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 08c39a0..7d6055b 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -9,6 +9,7 @@ doctest:
   image: python:3.7
   script:
     - pip install -r requirements.txt
+    - pip install -r dev-requirements.txt
     - "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')"
 
 unittest:
@@ -16,6 +17,7 @@ unittest:
   image: python:3.7
   script:
     - pip install -r requirements.txt
+    - pip install -r dev-requirements.txt
     - pip install pytest
     - pytest
 
diff --git a/dev-requirements.txt b/dev-requirements.txt
new file mode 100644
index 0000000..e214769
--- /dev/null
+++ b/dev-requirements.txt
@@ -0,0 +1 @@
+pyspark
diff --git a/freeforestml/cut.py b/freeforestml/cut.py
index cf4e1e4..b34114d 100644
--- a/freeforestml/cut.py
+++ b/freeforestml/cut.py
@@ -86,6 +86,14 @@ class Cut:
         Applies the internally stored cut to the given dataframe and returns a
         new dataframe containing only entries passing the event selection.
         """
+        if self.func is None:
+            return dataframe
+
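+        # Duck-typing check: a PySpark DataFrame exposes its SparkContext
+        # as `_sc`; apply the cut lazily via DataFrame.filter().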
+        if hasattr(dataframe, "_sc"):
+            return dataframe.filter(self.func(dataframe))
+
         return dataframe[self.idx_array(dataframe)]
 
     def idx_array(self, dataframe):
diff --git a/freeforestml/tests/test_spark.py b/freeforestml/tests/test_spark.py
new file mode 100644
index 0000000..cfe9e90
--- /dev/null
+++ b/freeforestml/tests/test_spark.py
@@ -0,0 +1,47 @@
+
+import unittest
+from pyspark.sql import SparkSession
+import freeforestml as ff
+
+class TestSparkCut(unittest.TestCase):
+    """Check that Cut objects can operate on PySpark DataFrames"""
+    def setUp(self):
+        """Create or get spark session"""
+        self.spark = SparkSession.builder.master("local").getOrCreate()
+        self.sc = self.spark.sparkContext
+
+    def test_func(self):
+        """Check that lambda-cut can be applied to a spark dataframe"""
+        rdd = self.sc.parallelize([(x, x**2) for x in range(100)])
+        df = rdd.toDF(["x", "y"])
+
+        cut = ff.Cut(lambda d: d.y < 100)
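+        # y = x**2 < 100 holds for x = 0..9, i.e. 10 of the 100 rows.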
+        selected = cut(df).count()
+        self.assertEqual(selected, 10)
+
+    def test_and(self):
+        """Check that derived cuts can be used on a pyspark dataframe"""
+        rdd = self.sc.parallelize([(x, x**2) for x in range(100)])
+        df = rdd.toDF(["x", "y"])
+
+        cut_y = ff.Cut(lambda d: d.y < 1000)
+        cut_x = ff.Cut(lambda d: d.x > 2)
+        cut = cut_x & cut_y
+
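+        # x > 2 and x**2 < 1000 holds for x = 3..31, i.e. 29 rows.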
+        selected = cut(df).count()
+        self.assertEqual(selected, 29)
+
+    def test_or(self):
+        """Check that derived cuts can be used on a pyspark dataframe"""
+        rdd = self.sc.parallelize([(x, x**2) for x in range(100)])
+        df = rdd.toDF(["x", "y"])
+
+        cut_y = ff.Cut(lambda d: d.y < 1000)
+        cut_x = ff.Cut(lambda d: d.x > 2)
+        cut = cut_x | cut_y
+
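+        # x > 2 or x**2 < 1000 holds for every x = 0..99, i.e. all 100 rows.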
+        selected = cut(df).count()
+        self.assertEqual(selected, 100)
-- 
GitLab