Repository: spark
Updated Branches:
  refs/heads/master 25826c77d -> 1347b2a69


[SPARK-21633][ML][PYTHON] UnaryTransformer in Python

## What changes were proposed in this pull request?

Implemented UnaryTransformer in Python.
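
With this change, a concrete transformer only needs to implement `createTransformFunc`, `outputDataType`, and `validateInputType`. A minimal sketch of a subclass (the `PlusOneTransformer` name and its behavior are illustrative, not part of this patch):

```python
from pyspark.ml import UnaryTransformer
from pyspark.sql.types import DoubleType

class PlusOneTransformer(UnaryTransformer):
    """Illustrative subclass that adds 1.0 to each value of the input column."""

    def createTransformFunc(self):
        # UnaryTransformer._transform wraps this function in a SQL UDF.
        return lambda x: x + 1.0

    def outputDataType(self):
        # Data type of the generated output column.
        return DoubleType()

    def validateInputType(self, inputType):
        # Raise if the input column is not of DoubleType.
        if inputType != DoubleType():
            raise TypeError("Bad input type: {}. Requires Double.".format(inputType))
```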

## How was this patch tested?

This patch was tested by adding a MockUnaryTransformer class to the unit tests; it extends UnaryTransformer, and the tests verify that its transform function produces the correct output.
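
In the same spirit as the new test, a sketch of how such a transformer is exercised (assumes an active SparkSession bound to `spark` and the illustrative `PlusOneTransformer` above):

```python
df = spark.range(0, 10).toDF("input")
df = df.withColumn("input", df.input.cast("double"))

transformer = PlusOneTransformer().setInputCol("input").setOutputCol("output")
transformed = transformer.transform(df)

for row in transformed.select("input", "output").collect():
    assert row.output == row.input + 1.0
```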

Author: Ajay Saini <ajays...@gmail.com>

Closes #18746 from ajaysaini725/AddPythonUnaryTransformer.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1347b2a6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1347b2a6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1347b2a6

Branch: refs/heads/master
Commit: 1347b2a697aa798c04b39fbb352efc735aa42ea3
Parents: 25826c7
Author: Ajay Saini <ajays...@gmail.com>
Authored: Fri Aug 4 01:01:32 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Aug 4 01:01:32 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/__init__.py |  4 +--
 python/pyspark/ml/base.py     | 56 ++++++++++++++++++++++++++++++++++
 python/pyspark/ml/tests.py    | 62 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 119 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 1d42d49..129d7d6 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,7 +19,7 @@
 DataFrame-based machine learning APIs to let users quickly assemble and configure practical
 machine learning pipelines.
 """
-from pyspark.ml.base import Estimator, Model, Transformer
+from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
 from pyspark.ml.pipeline import Pipeline, PipelineModel
 
-__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
+__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", 
"Pipeline", "PipelineModel"]

http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/base.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 339e5d6..a6767ce 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -17,9 +17,14 @@
 
 from abc import ABCMeta, abstractmethod
 
+import copy
+
 from pyspark import since
 from pyspark.ml.param import Params
+from pyspark.ml.param.shared import *
 from pyspark.ml.common import inherit_doc
+from pyspark.sql.functions import udf
+from pyspark.sql.types import StructField, StructType, DoubleType
 
 
 @inherit_doc
@@ -116,3 +121,54 @@ class Model(Transformer):
     """
 
     __metaclass__ = ABCMeta
+
+
+@inherit_doc
+class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
+    """
+    Abstract class for transformers that take one input column, apply transformation,
+    and output the result as a new column.
+
+    .. versionadded:: 2.3.0
+    """
+
+    @abstractmethod
+    def createTransformFunc(self):
+        """
+        Creates the transform function using the given param map. The input param map already takes
+        account of the embedded param map. So the param values should be determined
+        solely by the input param map.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def outputDataType(self):
+        """
+        Returns the data type of the output column.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def validateInputType(self, inputType):
+        """
+        Validates the input type. Throws an exception if it is invalid.
+        """
+        raise NotImplementedError()
+
+    def transformSchema(self, schema):
+        inputType = schema[self.getInputCol()].dataType
+        self.validateInputType(inputType)
+        if self.getOutputCol() in schema.names:
+            raise ValueError("Output column %s already exists." % self.getOutputCol())
+        outputFields = copy.copy(schema.fields)
+        outputFields.append(StructField(self.getOutputCol(),
+                                        self.outputDataType(),
+                                        nullable=False))
+        return StructType(outputFields)
+
+    def _transform(self, dataset):
+        self.transformSchema(dataset.schema)
+        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
+        transformedDataset = dataset.withColumn(self.getOutputCol(),
+                                                transformUDF(dataset[self.getInputCol()]))
+        return transformedDataset

http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7ee2c2f..3bd4d37 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -45,7 +45,7 @@ from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros
 import inspect
 
 from pyspark import keyword_only, SparkContext
-from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
+from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer
 from pyspark.ml.classification import *
 from pyspark.ml.clustering import *
 from pyspark.ml.common import _java2py, _py2java
@@ -66,6 +66,7 @@ from pyspark.ml.wrapper import JavaParams, JavaWrapper
 from pyspark.serializers import PickleSerializer
 from pyspark.sql import DataFrame, Row, SparkSession
 from pyspark.sql.functions import rand
+from pyspark.sql.types import DoubleType, IntegerType
 from pyspark.storagelevel import *
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 
@@ -121,6 +122,36 @@ class MockTransformer(Transformer, HasFake):
         return dataset
 
 
+class MockUnaryTransformer(UnaryTransformer):
+
+    shift = Param(Params._dummy(), "shift", "The amount by which to shift " +
+                  "data in a DataFrame",
+                  typeConverter=TypeConverters.toFloat)
+
+    def __init__(self, shiftVal=1):
+        super(MockUnaryTransformer, self).__init__()
+        self._setDefault(shift=1)
+        self._set(shift=shiftVal)
+
+    def getShift(self):
+        return self.getOrDefault(self.shift)
+
+    def setShift(self, shift):
+        self._set(shift=shift)
+
+    def createTransformFunc(self):
+        shiftVal = self.getShift()
+        return lambda x: x + shiftVal
+
+    def outputDataType(self):
+        return DoubleType()
+
+    def validateInputType(self, inputType):
+        if inputType != DoubleType():
+            raise TypeError("Bad input type: {}. ".format(inputType) +
+                            "Requires Double.")
+
+
 class MockEstimator(Estimator, HasFake):
 
     def __init__(self):
@@ -2008,6 +2039,35 @@ class ChiSquareTestTests(SparkSessionTestCase):
         self.assertTrue(all(field in fieldNames for field in expectedFields))
 
 
+class UnaryTransformerTests(SparkSessionTestCase):
+
+    def test_unary_transformer_validate_input_type(self):
+        shiftVal = 3
+        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
+            .setInputCol("input").setOutputCol("output")
+
+        # should not raise any errors
+        transformer.validateInputType(DoubleType())
+
+        with self.assertRaises(TypeError):
+            # passing the wrong input type should raise an error
+            transformer.validateInputType(IntegerType())
+
+    def test_unary_transformer_transform(self):
+        shiftVal = 3
+        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
+            .setInputCol("input").setOutputCol("output")
+
+        df = self.spark.range(0, 10).toDF('input')
+        df = df.withColumn("input", df.input.cast(dataType="double"))
+
+        transformed_df = transformer.transform(df)
+        results = transformed_df.select("input", "output").collect()
+
+        for res in results:
+            self.assertEqual(res.input + shiftVal, res.output)
+
+
 if __name__ == "__main__":
     from pyspark.ml.tests import *
     if xmlrunner:

