Repository: spark Updated Branches: refs/heads/master 5855b5c03 -> 2d868d939
[SPARK-22521][ML] VectorIndexerModel support handle unseen categories via handleInvalid: Python API ## What changes were proposed in this pull request? Add python api for VectorIndexerModel support handle unseen categories via handleInvalid. ## How was this patch tested? doctest added. Author: WeichenXu <weichen...@databricks.com> Closes #19753 from WeichenXu123/vector_indexer_invalid_py. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2d868d93 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2d868d93 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2d868d93 Branch: refs/heads/master Commit: 2d868d93987ea1757cc66cdfb534bc49794eb0d0 Parents: 5855b5c Author: WeichenXu <weichen...@databricks.com> Authored: Tue Nov 21 10:53:53 2017 -0800 Committer: Holden Karau <holdenka...@google.com> Committed: Tue Nov 21 10:53:53 2017 -0800 ---------------------------------------------------------------------- .../apache/spark/ml/feature/VectorIndexer.scala | 7 +++-- python/pyspark/ml/feature.py | 30 +++++++++++++++----- 2 files changed, 27 insertions(+), 10 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2d868d93/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 3403ec4..e6ec4e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -47,7 +47,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * Options are: * 'skip': filter out rows with invalid data. * 'error': throw an error. - * 'keep': put invalid data in a special additional bucket, at index numCategories. + * 'keep': put invalid data in a special additional bucket, at index of the number of + * categories of the feature. * Default value: "error" * @group param */ @@ -55,7 +56,8 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "How to handle invalid data (unseen labels or NULL values). " + "Options are 'skip' (filter out rows with invalid data), 'error' (throw an error), " + - "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", + "or 'keep' (put invalid data in a special additional bucket, at index of the " + + "number of categories of the feature).", ParamValidators.inArray(VectorIndexer.supportedHandleInvalids)) setDefault(handleInvalid, VectorIndexer.ERROR_INVALID) @@ -112,7 +114,6 @@ private[ml] trait VectorIndexerParams extends Params with HasInputCol with HasOu * - Preserve metadata in transform; if a feature's metadata is already present, do not recompute. * - Specify certain features to not index, either via a parameter or via existing metadata. * - Add warning if a categorical feature has only 1 category. - * - Add option for allowing unknown categories. */ @Since("1.4.0") class VectorIndexer @Since("1.4.0") ( http://git-wip-us.apache.org/repos/asf/spark/blob/2d868d93/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 232ae3e..608f2a5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2490,7 +2490,8 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadabl @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ Class for indexing categorical feature columns in a dataset of `Vector`. @@ -2525,7 +2526,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja do not recompute. - Specify certain features to not index, either via a parameter or via existing metadata. - Add warning if a categorical feature has only 1 category. - - Add option for allowing unknown categories. >>> from pyspark.ml.linalg import Vectors >>> df = spark.createDataFrame([(Vectors.dense([-1.0, 0.0]),), @@ -2556,6 +2556,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja True >>> loadedModel.categoryMaps == model.categoryMaps True + >>> dfWithInvalid = spark.createDataFrame([(Vectors.dense([3.0, 1.0]),)], ["a"]) + >>> indexer.getHandleInvalid() + 'error' + >>> model3 = indexer.setHandleInvalid("skip").fit(df) + >>> model3.transform(dfWithInvalid).count() + 0 + >>> model4 = indexer.setParams(handleInvalid="keep", outputCol="indexed").fit(df) + >>> model4.transform(dfWithInvalid).head().indexed + DenseVector([2.0, 1.0]) .. versionadded:: 1.4.0 """ @@ -2565,22 +2574,29 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja "(>= 2). If a feature is found to have > maxCategories values, then " + "it is declared continuous.", typeConverter=TypeConverters.toInt) + handleInvalid = Param(Params._dummy(), "handleInvalid", "How to handle invalid data " + + "(unseen labels or NULL values). Options are 'skip' (filter out " + + "rows with invalid data), 'error' (throw an error), or 'keep' (put " + + "invalid data in a special additional bucket, at index of the number " + + "of categories of the feature).", + typeConverter=TypeConverters.toString) + @keyword_only - def __init__(self, maxCategories=20, inputCol=None, outputCol=None): + def __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"): """ - __init__(self, maxCategories=20, inputCol=None, outputCol=None) + __init__(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") """ super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) - self._setDefault(maxCategories=20) + self._setDefault(maxCategories=20, handleInvalid="error") kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.4.0") - def setParams(self, maxCategories=20, inputCol=None, outputCol=None): + def setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error"): """ - setParams(self, maxCategories=20, inputCol=None, outputCol=None) + setParams(self, maxCategories=20, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this VectorIndexer. """ kwargs = self._input_kwargs --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org