Repository: spark
Updated Branches:
  refs/heads/master 5855b5c03 -> 2d868d939


[SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python API

## What changes were proposed in this pull request?

Add a Python API for VectorIndexerModel to support handling unseen categories via handleInvalid.
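
For example, after this change the new parameter can be exercised from PySpark roughly as follows (a minimal sketch mirroring the added doctests; it assumes an active SparkSession bound to `spark`, and the toy data and column names are illustrative):

```python
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorIndexer

# Feature 0 has two categories (-1.0 and 0.0); feature 1 takes three values
# and, with maxCategories=2, is treated as continuous.
df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
                            (Vectors.dense([0.0, 1.0]),),
                            (Vectors.dense([0.0, 2.0]),)], ["a"])
dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])

indexer = VectorIndexer(maxCategories=2, inputCol="a", outputCol="indexed")

# The default is 'error': transforming dfWithInvalid would raise.
# 'skip' drops rows containing unseen categories.
model = indexer.setHandleInvalid("skip").fit(df)
model.transform(dfWithInvalid).count()          # 0

# 'keep' maps an unseen category to an extra bucket at index numCategories.
model = indexer.setHandleInvalid("keep").fit(df)
model.transform(dfWithInvalid).head().indexed   # DenseVector([2.0, 1.0])
```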

## How was this patch tested?

Doctests added.

Author: WeichenXu <weichen...@databricks.com>

Closes #19753 from WeichenXu123/vector_indexer_invalid_py.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d868d93
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d868d93
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d868d93

Branch: refs/heads/master
Commit: 2d868d93987ea1757cc66cdfb534bc49794eb0d0
Parents: 5855b5c
Author: WeichenXu <weichen...@databricks.com>
Authored: Tue Nov 21 10:53:53 2017 -0800
Committer: Holden Karau <holdenka...@google.com>
Committed: Tue Nov 21 10:53:53 2017 -0800

----------------------------------------------------------------------
 .../apache/spark/ml/feature/VectorIndexer.scala |  7 +++--
 python/pyspark/ml/feature.py                    | 30 +++++++++++++++-----
 2 files changed, 27 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2d868d93/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 3403ec4..e6ec4e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
    * Options are:
    * 'skip': filter out rows with invalid data.
    * 'error': throw an error.
-   * 'keep': put invalid data in a special additional bucket, at index numCategories.
+   * 'keep': put invalid data in a special additional bucket, at index of the number of
+   * categories of the feature.
    * Default value: "error"
    * @group param
    */
@@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
   override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
     "How to handle invalid data (unseen labels or NULL values). " +
     "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " +
-    "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
+    "or 'keep' (put invalid data in a special additional bucket, at index of the " +
+    "number of categories of the feature).",
     ParamValidators.inArray(VectorIndexer.supportedHandleInvalids))
 
   setDefault(handleInvalid, VectorIndexer.ERROR_INVALID)
@@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu
  *  - Preserve metadata in transform; if a feature's metadata is already present, do not recompute.
  *  - Specify certain features to not index, either via a parameter or via existing metadata.
  *  - Add warning if a categorical feature has only 1 category.
- *  - Add option for allowing unknown categories.
  */
 @Since("1.4.0")
 class VectorIndexer @Since("1.4.0") (

http://git-wip-us.apache.org/repos/asf/spark/blob/2d868d93/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 232ae3e..608f2a5 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2490,7 +2490,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl
 
 
 @inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+                    JavaMLWritable):
     """
     Class for indexing categorical feature columns in a dataset of `Vector`.
 
@@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
         do not recompute.
      - Specify certain features to not index, either via a parameter or via existing metadata.
       - Add warning if a categorical feature has only 1 category.
-      - Add option for allowing unknown categories.
 
     >>> from pyspark.ml.linalg import Vectors
     >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),),
@@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
     True
     >>> loadedModel.categoryMaps == model.categoryMaps
     True
+    >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"])
+    >>> indexer.getHandleInvalid()
+    'error'
+    >>> model3 = indexer.setHandleInvalid("skip").fit(df)
+    >>> model3.transform(dfWithInvalid).count()
+    0
+    >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df)
+    >>> model4.transform(dfWithInvalid).head().indexed
+    DenseVector([2.0, 1.0])
 
     .. versionadded:: 1.4.0
     """
@@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
                           "(>= 2). If a feature is found to have > maxCategories values, then " +
                           "it is declared continuous.", typeConverter=TypeConverters.toInt)
 
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " +
+                          "(unseen labels or NULL values). Options are 'skip' (filter out " +
+                          "rows with invalid data), 'error' (throw an error), or 'keep' (put " +
+                          "invalid data in a special additional bucket, at index of the number " +
+                          "of categories of the feature).",
+                          typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
+    def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        __init__(self, maxCategories=20, inputCol=None, outputCol=None)
+        __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         """
         super(VectorIndexer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
-        self._setDefault(maxCategories=20)
+        self._setDefault(maxCategories=20, handleInvalid="error")
         kwargs = self._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.4.0")
-    def setParams(self, maxCategories=20, inputCol=None, outputCol=None):
+    def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"):
         """
-        setParams(self, maxCategories=20, inputCol=None, outputCol=None)
+        setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error")
         Sets params for this VectorIndexer.
         """
         kwargs = self._input_kwargs

