Repository: spark
Updated Branches:
  refs/heads/master b6935ffb4 -> e99825058


[SPARK-23828][ML][PYTHON] PySpark StringIndexerModel should have a constructor from labels

## What changes were proposed in this pull request?

The Scala StringIndexerModel has an alternate constructor that creates the
model from an array of label strings. Add the corresponding Python API:

model = StringIndexerModel.from_labels(["a", "b", "c"])

## How was this patch tested?

Added a doctest and a unit test.

Author: Huaxin Gao <huax...@us.ibm.com>

Closes #20968 from huaxingao/spark-23828.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9982505
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9982505
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9982505

Branch: refs/heads/master
Commit: e998250588de0df250e2800278da4d3e3705c259
Parents: b6935ff
Author: Huaxin Gao <huax...@us.ibm.com>
Authored: Fri Apr 6 11:51:36 2018 -0700
Committer: Bryan Cutler <cutl...@gmail.com>
Committed: Fri Apr 6 11:51:36 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 88 ++++++++++++++++++++++++++++-----------
 python/pyspark/ml/tests.py   | 41 +++++++++++++++++-
 2 files changed, 104 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9982505/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index fcb0dfc..5a3e0dd 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2342,9 +2342,38 @@ class StandardScalerModel(JavaModel, JavaMLReadable, 
JavaMLWritable):
         return self._call_java("mean")
 
 
+class _StringIndexerParams(JavaParams, HasHandleInvalid, HasInputCol, 
HasOutputCol):
+    """
+    Params for :py:attr:`StringIndexer` and :py:attr:`StringIndexerModel`.
+    """
+
+    stringOrderType = Param(Params._dummy(), "stringOrderType",
+                            "How to order labels of string column. The first 
label after " +
+                            "ordering is assigned an index of 0. Supported 
options: " +
+                            "frequencyDesc, frequencyAsc, alphabetDesc, 
alphabetAsc.",
+                            typeConverter=TypeConverters.toString)
+
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle 
invalid data (unseen " +
+                          "or NULL values) in features and label column of 
string type. " +
+                          "Options are 'skip' (filter out rows with invalid 
data), " +
+                          "error (throw an error), or 'keep' (put invalid data 
" +
+                          "in a special additional bucket, at index 
numLabels).",
+                          typeConverter=TypeConverters.toString)
+
+    def __init__(self, *args):
+        super(_StringIndexerParams, self).__init__(*args)
+        self._setDefault(handleInvalid="error", 
stringOrderType="frequencyDesc")
+
+    @since("2.3.0")
+    def getStringOrderType(self):
+        """
+        Gets the value of :py:attr:`stringOrderType` or its default value 
'frequencyDesc'.
+        """
+        return self.getOrDefault(self.stringOrderType)
+
+
 @inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, 
HasHandleInvalid, JavaMLReadable,
-                    JavaMLWritable):
+class StringIndexer(JavaEstimator, _StringIndexerParams, JavaMLReadable, 
JavaMLWritable):
     """
     A label indexer that maps a string column of labels to an ML column of 
label indices.
     If the input column is numeric, we cast it to string and index the string 
values.
@@ -2388,23 +2417,16 @@ class StringIndexer(JavaEstimator, HasInputCol, 
HasOutputCol, HasHandleInvalid,
     >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, 
td.indexed).collect()]),
     ...     key=lambda x: x[0])
     [(0, 2.0), (1, 1.0), (2, 0.0), (3, 2.0), (4, 2.0), (5, 0.0)]
+    >>> fromlabelsModel = StringIndexerModel.from_labels(["a", "b", "c"],
+    ...     inputCol="label", outputCol="indexed", handleInvalid="error")
+    >>> result = fromlabelsModel.transform(stringIndDf)
+    >>> sorted(set([(i[0], i[1]) for i in result.select(result.id, 
result.indexed).collect()]),
+    ...     key=lambda x: x[0])
+    [(0, 0.0), (1, 1.0), (2, 2.0), (3, 0.0), (4, 0.0), (5, 2.0)]
 
     .. versionadded:: 1.4.0
     """
 
-    stringOrderType = Param(Params._dummy(), "stringOrderType",
-                            "How to order labels of string column. The first 
label after " +
-                            "ordering is assigned an index of 0. Supported 
options: " +
-                            "frequencyDesc, frequencyAsc, alphabetDesc, 
alphabetAsc.",
-                            typeConverter=TypeConverters.toString)
-
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle 
invalid data (unseen " +
-                          "or NULL values) in features and label column of 
string type. " +
-                          "Options are 'skip' (filter out rows with invalid 
data), " +
-                          "error (throw an error), or 'keep' (put invalid data 
" +
-                          "in a special additional bucket, at index 
numLabels).",
-                          typeConverter=TypeConverters.toString)
-
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None, handleInvalid="error",
                  stringOrderType="frequencyDesc"):
@@ -2414,7 +2436,6 @@ class StringIndexer(JavaEstimator, HasInputCol, 
HasOutputCol, HasHandleInvalid,
         """
         super(StringIndexer, self).__init__()
         self._java_obj = 
self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
-        self._setDefault(handleInvalid="error", 
stringOrderType="frequencyDesc")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
@@ -2440,21 +2461,33 @@ class StringIndexer(JavaEstimator, HasInputCol, 
HasOutputCol, HasHandleInvalid,
         """
         return self._set(stringOrderType=value)
 
-    @since("2.3.0")
-    def getStringOrderType(self):
-        """
-        Gets the value of :py:attr:`stringOrderType` or its default value 
'frequencyDesc'.
-        """
-        return self.getOrDefault(self.stringOrderType)
-
 
-class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
+class StringIndexerModel(JavaModel, _StringIndexerParams, JavaMLReadable, 
JavaMLWritable):
     """
     Model fitted by :py:class:`StringIndexer`.
 
     .. versionadded:: 1.4.0
     """
 
+    @classmethod
+    @since("2.4.0")
+    def from_labels(cls, labels, inputCol, outputCol=None, handleInvalid=None):
+        """
+        Construct the model directly from an array of label strings,
+        requires an active SparkContext.
+        """
+        sc = SparkContext._active_spark_context
+        java_class = sc._gateway.jvm.java.lang.String
+        jlabels = StringIndexerModel._new_java_array(labels, java_class)
+        model = StringIndexerModel._create_from_java_class(
+            "org.apache.spark.ml.feature.StringIndexerModel", jlabels)
+        model.setInputCol(inputCol)
+        if outputCol is not None:
+            model.setOutputCol(outputCol)
+        if handleInvalid is not None:
+            model.setHandleInvalid(handleInvalid)
+        return model
+
     @property
     @since("1.5.0")
     def labels(self):
@@ -2463,6 +2496,13 @@ class StringIndexerModel(JavaModel, JavaMLReadable, 
JavaMLWritable):
         """
         return self._call_java("labels")
 
+    @since("2.4.0")
+    def setHandleInvalid(self, value):
+        """
+        Sets the value of :py:attr:`handleInvalid`.
+        """
+        return self._set(handleInvalid=value)
+
 
 @inherit_doc
 class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, 
JavaMLReadable, JavaMLWritable):

http://git-wip-us.apache.org/repos/asf/spark/blob/e9982505/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c2c4861..4ce5454 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -800,6 +800,43 @@ class FeatureTests(SparkSessionTestCase):
         expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)]
         self.assertEqual(actual2, expected2)
 
+    def test_string_indexer_from_labels(self):
+        model = StringIndexerModel.from_labels(["a", "b", "c"], 
inputCol="label",
+                                               outputCol="indexed", 
handleInvalid="keep")
+        self.assertEqual(model.labels, ["a", "b", "c"])
+
+        df1 = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "c"),
+            (2, None),
+            (3, "b"),
+            (4, "b")], ["id", "label"])
+
+        result1 = model.transform(df1)
+        actual1 = result1.select("id", "indexed").collect()
+        expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=2.0), Row(id=2, 
indexed=3.0),
+                     Row(id=3, indexed=1.0), Row(id=4, indexed=1.0)]
+        self.assertEqual(actual1, expected1)
+
+        model_empty_labels = StringIndexerModel.from_labels(
+            [], inputCol="label", outputCol="indexed", handleInvalid="keep")
+        actual2 = model_empty_labels.transform(df1).select("id", 
"indexed").collect()
+        expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=0.0), Row(id=2, 
indexed=0.0),
+                     Row(id=3, indexed=0.0), Row(id=4, indexed=0.0)]
+        self.assertEqual(actual2, expected2)
+
+        # Test model with default settings can transform
+        model_default = StringIndexerModel.from_labels(["a", "b", "c"], 
inputCol="label")
+        df2 = self.spark.createDataFrame([
+            (0, "a"),
+            (1, "c"),
+            (2, "b"),
+            (3, "b"),
+            (4, "b")], ["id", "label"])
+        transformed_list = model_default.transform(df2)\
+            
.select(model_default.getOrDefault(model_default.outputCol)).collect()
+        self.assertEqual(len(transformed_list), 5)
+
 
 class HasInducedError(Params):
 
@@ -2097,9 +2134,11 @@ class DefaultValuesTests(PySparkTestCase):
                     ParamTests.check_params(self, cls(), 
check_params_exist=False)
 
         # Additional classes that need explicit construction
-        from pyspark.ml.feature import CountVectorizerModel
+        from pyspark.ml.feature import CountVectorizerModel, StringIndexerModel
         ParamTests.check_params(self, 
CountVectorizerModel.from_vocabulary(['a'], 'input'),
                                 check_params_exist=False)
+        ParamTests.check_params(self, StringIndexerModel.from_labels(['a', 
'b'], 'input'),
+                                check_params_exist=False)
 
 
 def _squared_distance(a, b):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to