Repository: spark
Updated Branches:
  refs/heads/master 25826c77d -> 1347b2a69
[SPARK-21633][ML][PYTHON] UnaryTransformer in Python

## What changes were proposed in this pull request?

Implemented UnaryTransformer in Python.

## How was this patch tested?

Tested by adding a MockUnaryTransformer class to the unit tests; it extends UnaryTransformer, and the tests verify that its transform function produces correct output.

Author: Ajay Saini <ajays...@gmail.com>

Closes #18746 from ajaysaini725/AddPythonUnaryTransformer.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1347b2a6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1347b2a6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1347b2a6

Branch: refs/heads/master
Commit: 1347b2a697aa798c04b39fbb352efc735aa42ea3
Parents: 25826c7
Author: Ajay Saini <ajays...@gmail.com>
Authored: Fri Aug 4 01:01:32 2017 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Aug 4 01:01:32 2017 -0700

----------------------------------------------------------------------
 python/pyspark/ml/__init__.py |  4 +--
 python/pyspark/ml/base.py     | 56 ++++++++++++++++++++++++++++++++++
 python/pyspark/ml/tests.py    | 62 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 119 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/__init__.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index 1d42d49..129d7d6 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -19,7 +19,7 @@
 DataFrame-based machine learning APIs to let users quickly assemble and configure practical
 machine learning pipelines.
 """
-from pyspark.ml.base import Estimator, Model, Transformer
+from pyspark.ml.base import Estimator, Model, Transformer, UnaryTransformer
 from pyspark.ml.pipeline import Pipeline, PipelineModel
 
-__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
+__all__ = ["Transformer", "UnaryTransformer", "Estimator", "Model", "Pipeline", "PipelineModel"]


http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/base.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/base.py b/python/pyspark/ml/base.py
index 339e5d6..a6767ce 100644
--- a/python/pyspark/ml/base.py
+++ b/python/pyspark/ml/base.py
@@ -17,9 +17,14 @@
 
 from abc import ABCMeta, abstractmethod
 
+import copy
+
 from pyspark import since
 from pyspark.ml.param import Params
+from pyspark.ml.param.shared import *
 from pyspark.ml.common import inherit_doc
+from pyspark.sql.functions import udf
+from pyspark.sql.types import StructField, StructType, DoubleType
 
 
 @inherit_doc
@@ -116,3 +121,54 @@ class Model(Transformer):
     """
 
     __metaclass__ = ABCMeta
+
+
+@inherit_doc
+class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
+    """
+    Abstract class for transformers that take one input column, apply a transformation,
+    and output the result as a new column.
+
+    .. versionadded:: 2.3.0
+    """
+
+    @abstractmethod
+    def createTransformFunc(self):
+        """
+        Creates the transform function using the given param map. The input param map
+        already takes account of the embedded param map, so the param values should be
+        determined solely by the input param map.
+ """ + raise NotImplementedError() + + @abstractmethod + def outputDataType(self): + """ + Returns the data type of the output column. + """ + raise NotImplementedError() + + @abstractmethod + def validateInputType(self, inputType): + """ + Validates the input type. Throw an exception if it is invalid. + """ + raise NotImplementedError() + + def transformSchema(self, schema): + inputType = schema[self.getInputCol()].dataType + self.validateInputType(inputType) + if self.getOutputCol() in schema.names: + raise ValueError("Output column %s already exists." % self.getOutputCol()) + outputFields = copy.copy(schema.fields) + outputFields.append(StructField(self.getOutputCol(), + self.outputDataType(), + nullable=False)) + return StructType(outputFields) + + def _transform(self, dataset): + self.transformSchema(dataset.schema) + transformUDF = udf(self.createTransformFunc(), self.outputDataType()) + transformedDataset = dataset.withColumn(self.getOutputCol(), + transformUDF(dataset[self.getInputCol()])) + return transformedDataset http://git-wip-us.apache.org/repos/asf/spark/blob/1347b2a6/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7ee2c2f..3bd4d37 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -45,7 +45,7 @@ from numpy import abs, all, arange, array, array_equal, inf, ones, tile, zeros import inspect from pyspark import keyword_only, SparkContext -from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer +from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer, UnaryTransformer from pyspark.ml.classification import * from pyspark.ml.clustering import * from pyspark.ml.common import _java2py, _py2java @@ -66,6 +66,7 @@ from pyspark.ml.wrapper import JavaParams, JavaWrapper from pyspark.serializers import PickleSerializer from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.functions import rand +from pyspark.sql.types import DoubleType, IntegerType from pyspark.storagelevel import * from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -121,6 +122,36 @@ class MockTransformer(Transformer, HasFake): return dataset +class MockUnaryTransformer(UnaryTransformer): + + shift = Param(Params._dummy(), "shift", "The amount by which to shift " + + "data in a DataFrame", + typeConverter=TypeConverters.toFloat) + + def __init__(self, shiftVal=1): + super(MockUnaryTransformer, self).__init__() + self._setDefault(shift=1) + self._set(shift=shiftVal) + + def getShift(self): + return self.getOrDefault(self.shift) + + def setShift(self, shift): + self._set(shift=shift) + + def createTransformFunc(self): + shiftVal = self.getShift() + return lambda x: x + shiftVal + + def outputDataType(self): + return DoubleType() + + def validateInputType(self, inputType): + if inputType != DoubleType(): + raise TypeError("Bad input type: {}. 
".format(inputType) + + "Requires Integer.") + + class MockEstimator(Estimator, HasFake): def __init__(self): @@ -2008,6 +2039,35 @@ class ChiSquareTestTests(SparkSessionTestCase): self.assertTrue(all(field in fieldNames for field in expectedFields)) +class UnaryTransformerTests(SparkSessionTestCase): + + def test_unary_transformer_validate_input_type(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal)\ + .setInputCol("input").setOutputCol("output") + + # should not raise any errors + transformer.validateInputType(DoubleType()) + + with self.assertRaises(TypeError): + # passing the wrong input type should raise an error + transformer.validateInputType(IntegerType()) + + def test_unary_transformer_transform(self): + shiftVal = 3 + transformer = MockUnaryTransformer(shiftVal=shiftVal)\ + .setInputCol("input").setOutputCol("output") + + df = self.spark.range(0, 10).toDF('input') + df = df.withColumn("input", df.input.cast(dataType="double")) + + transformed_df = transformer.transform(df) + results = transformed_df.select("input", "output").collect() + + for res in results: + self.assertEqual(res.input + shiftVal, res.output) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org