Skip to content
Snippets Groups Projects
Unverified Commit f91cb672 authored by Frank Sauerburger's avatar Frank Sauerburger
Browse files

Prepare cuts for PySpark Dataframes

parent a51669be
No related branches found
No related tags found
1 merge request!66Draft: Resolve "Integrate with pyspark"
Pipeline #12690 failed
......@@ -9,6 +9,7 @@ doctest:
image: python:3.7
script:
- pip install -r requirements.txt
- pip install -r dev-requirements.txt
- "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')"
unittest:
......@@ -16,6 +17,7 @@ unittest:
image: python:3.7
script:
- pip install -r requirements.txt
- pip install -r dev-requirements.txt
- pip install pytest
- pytest
......
......@@ -86,6 +86,12 @@ class Cut:
Applies the internally stored cut to the given dataframe and returns a
new dataframe containing only entries passing the event selection.
"""
if self.func is None:
return dataframe
if hasattr(dataframe, "_sc"):
return dataframe.filter(self.func(dataframe))
return dataframe[self.idx_array(dataframe)]
def idx_array(self, dataframe):
......
import unittest
from pyspark.sql import SparkSession
import freeforestml as ff
class TestparkCut(unittest.TestCase):  # NOTE(review): name likely intended as TestSparkCut; kept for compatibility
    """Check that Cut objects can operate on PySpark dataframes."""

    def setUp(self):
        """Create or get a local Spark session shared by the tests."""
        self.spark = SparkSession.builder.master("local").getOrCreate()
        self.sc = self.spark.sparkContext

    def _make_df(self):
        """Return a two-column dataframe with x in [0, 100) and y = x**2."""
        rdd = self.sc.parallelize([(x, x**2) for x in range(100)])
        return rdd.toDF(["x", "y"])

    def test_func(self):
        """Check that a lambda-cut can be applied to a spark dataframe."""
        df = self._make_df()
        cut = ff.Cut(lambda d: d.y < 100)
        # y = x**2 < 100 selects x in 0..9 -> 10 rows
        self.assertEqual(cut(df).count(), 10)

    def test_and(self):
        """Check that AND-combined cuts can be used on a pyspark dataframe."""
        df = self._make_df()
        cut_y = ff.Cut(lambda d: d.y < 1000)
        cut_x = ff.Cut(lambda d: d.x > 2)
        cut = cut_x & cut_y
        # x > 2 and x**2 < 1000 selects x in 3..31 -> 29 rows
        self.assertEqual(cut(df).count(), 29)

    def test_or(self):
        """Check that OR-combined cuts can be used on a pyspark dataframe."""
        # Original file defined test_or twice (the second shadowed the first)
        # and asserted 29, copy-pasted from the AND case. The union of
        # x > 2 (3..99) and x**2 < 1000 (0..31) covers all 100 rows.
        df = self._make_df()
        cut_y = ff.Cut(lambda d: d.y < 1000)
        cut_x = ff.Cut(lambda d: d.x > 2)
        cut = cut_x | cut_y
        self.assertEqual(cut(df).count(), 100)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment