diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml
index 5e7292ff531f5675e897dd228387363fb61b02e2..e903defad5fddbb5783298abea89778b084ddbbd 100644
--- a/.gitlab-ci.yml
+++ b/.gitlab-ci.yml
@@ -10,7 +10,7 @@ doctest:
   script:
     - pip install -r requirements.txt
     - pip install -r dev-requirements.txt
-    - "python -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')"
+    - ci/doctest.sh
 
 unittest:
   stage: test
diff --git a/ci/doctest.sh b/ci/doctest.sh
new file mode 100755
index 0000000000000000000000000000000000000000..69542152e76c038ee2a5765221cefa8ad4e2b0a4
--- /dev/null
+++ b/ci/doctest.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+python3 -m doctest -v $(ls freeforestml/*.py | grep -v '__init__.py')
diff --git a/freeforestml/cut.py b/freeforestml/cut.py
index b34114d823ba62bd42ca1a28a3883f8433be11ed..982c69c9b40cf1e1798b752e93ec4d45fc2cfe22 100644
--- a/freeforestml/cut.py
+++ b/freeforestml/cut.py
@@ -8,7 +8,7 @@ class Cut:
     quantities.
 
     Cuts store the condition to be applied to a dataframe. New cut objects
-    accept all event by default. The selection can be limited by passing a
+    accept all events by default. The selection can be limited by passing a
     lambda to the constructor. 
 
     >>> sel_all = Cut()
@@ -65,9 +65,20 @@ class Cut:
     >>> sel_sr = Cut(lambda df: df.is_sr == 1, label="Signal Region")
     >>> sel_sr.label
     'Signal Region'
+
+    If the application of a cut requires to change the event weights by a so
+    called scale factors, you can pass additional optional keyword arguments
+    that specify how the new weight should be computed.
+
+    >>> sel_sample = Cut(lambda df: df.value % 2 == 0, \
+                         weight=lambda df: df.weight * 2)
+
+    The argument name 'weight' in this example is arbitrary. It is even
+    possible to add new columns to the returned dataframe in this way,
+    however, this is not recommended.
     """
 
-    def __init__(self, func=None, label=None):
+    def __init__(self, func=None, label=None, **columns):
         """
         Creates a new cut. The optional func argument is called with the
         dataframe upon evaluation. The function must return an index array. If
@@ -77,9 +88,16 @@ class Cut:
         if isinstance(func, Cut):
             self.func = func.func
             self.label = label or func.label
+            self.columns = columns or func.columns
         else:
             self.func = func
             self.label = label
+            self.columns = columns
+
+    @staticmethod
+    def _is_pyspark(dataframe):
+        """Return true if assumed to be a PySpark Dataframe"""
+        return hasattr(dataframe, "_sc")
 
     def __call__(self, dataframe):
         """
@@ -89,10 +107,26 @@ class Cut:
         if self.func is None:
             return dataframe
 
-        if hasattr(dataframe, "_sc"):
-            return dataframe.filter(self.func(dataframe))
-
-        return dataframe[self.idx_array(dataframe)]
+        if self._is_pyspark(dataframe):
+            # Is PySpark
+            new_df = dataframe.filter(self.func(dataframe))
+            if self.columns:
+                add_columns = [
+                    col(new_df).alias(name)
+                    for name, col in self.columns.items()
+                ]
+
+                existing = [    
+                    n for n in new_df.columns if n not in self.columns
+                ]
+                new_df = new_df.select(*existing, *add_columns)
+
+            return new_df
+
+        new_df = dataframe[self.idx_array(dataframe)]
+        if self.columns:
+            new_df = new_df.assign(**self.columns)
+        return new_df
 
     def idx_array(self, dataframe):
         """
diff --git a/freeforestml/tests/test_cut.py b/freeforestml/tests/test_cut.py
index bff8f785a0ea865ba222db397d29da94ad733b57..1d0f3c7eae3426b4f24e97f107cb0466c12190bc 100644
--- a/freeforestml/tests/test_cut.py
+++ b/freeforestml/tests/test_cut.py
@@ -311,3 +311,26 @@ class CutTestCase(unittest.TestCase):
         high_sale = Cut(lambda df: df.sale > 10)
 
         self.assertEqual(list(high_sale(self.df).year), [])
+
+    def test_assign_columns(self):
+        """
+        Check that passing a keyword argument overwrites an existing column.
+        """
+        alternate = Cut(lambda df: df.year % 2 == 0,
+                        sale=lambda df: df.sale * 2)
+
+        df_alt = alternate(self.df)
+        self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016])
+        self.assertEqual(list(df_alt.sale), [7.8, 9.4, 15.0, 4.6])
+
+
+    def test_assign_new_columns(self):
+        """
+        Check that passing a keyword argument creates a new columns
+        """
+        alternate = Cut(lambda df: df.year % 2 == 0,
+                        weight=lambda df: df.year * 0 + 2)
+        df_alt = alternate(self.df)
+        self.assertEqual(list(df_alt.year), [2010, 2012, 2014, 2016])
+        self.assertEqual(list(df_alt.sale), [3.9, 4.7, 7.5, 2.3])
+        self.assertEqual(list(df_alt.weight), [2, 2, 2, 2])
diff --git a/freeforestml/tests/test_spark.py b/freeforestml/tests/test_spark.py
index 69463289de39dc4c0716ed39f029c9aaf0f6c833..61ff2552347812be0d469d45f493175b5b6f6bfa 100644
--- a/freeforestml/tests/test_spark.py
+++ b/freeforestml/tests/test_spark.py
@@ -54,3 +54,21 @@ class TestparkCut(unittest.TestCase):
 
         selected = cut(df).count()
         self.assertEqual(selected, 100)
+
+    def test_assign(self):
+        """Check that a cut can assign a columns"""
+        rdd = self.sc.parallelize([(x, x**2) for x in range(100)])
+        df = rdd.toDF(["x", "y"])
+
+        cut_new = ff.Cut(lambda d: d.x < 5, weight=lambda d: d.x * 2)
+        sum_new = cut_new(df).agg({'weight': 'sum'}).collect()
+        self.assertEqual(sum_new[0][0], 2 + 4 + 6 + 8)
+
+    def test_overwrite(self):
+        """Check that a cut can overwrite a columns"""
+        rdd = self.sc.parallelize([(x, x**2) for x in range(100)])
+        df = rdd.toDF(["x", "y"])
+
+        cut_overwrite = ff.Cut(lambda d: d.x < 5, x=lambda d: d.x * 2)
+        sum_overwrite = cut_overwrite(df).agg({'x': 'sum'}).collect()
+        self.assertEqual(sum_overwrite[0][0], 2 + 4 + 6 + 8)