Repository: spark
Updated Branches:
  refs/heads/master 0e2f21633 -> 2f6fd5256
[SPARK-9654] [ML] [PYSPARK] Add IndexToString to PySpark

Adds IndexToString to PySpark.

Author: Holden Karau <hol...@pigscanfly.ca>

Closes #7976 from holdenk/SPARK-9654-add-string-indexer-inverse-in-pyspark.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f6fd525
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f6fd525
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f6fd525

Branch: refs/heads/master
Commit: 2f6fd5256c6650868916a3eefaa0beb091187cbb
Parents: 0e2f216
Author: Holden Karau <hol...@pigscanfly.ca>
Authored: Tue Sep 8 22:13:05 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Sep 8 22:13:05 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala |  2 +-
 python/pyspark/ml/feature.py                    | 74 ++++++++++++++++++--
 python/pyspark/ml/wrapper.py                    |  3 +-
 3 files changed, 73 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2f6fd525/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 77aeed0..b6482ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -102,7 +102,7 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
  * [[StringIndexerModel.transform]] would return the input dataset unmodified.
  * This is a temporary fix for the case when target labels do not exist during prediction.
* - * @param labels Ordered list of labels, corresponding to indices to be assigned + * @param labels Ordered list of labels, corresponding to indices to be assigned. */ @Experimental class StringIndexerModel ( http://git-wip-us.apache.org/repos/asf/spark/blob/2f6fd525/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index a7c5b2b..8c26cfb 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -27,10 +27,11 @@ from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'Bucketizer', 'DCT', 'ElementwiseProduct', 'HashingTF', 'IDF', 'IDFModel', - 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', - 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', - 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', - 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', 'StopWordsRemover'] + 'IndexToString', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', + 'RegexTokenizer', 'SQLTransformer', 'StandardScaler', 'StandardScalerModel', + 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', + 'Word2Vec', 'Word2VecModel', 'PCA', 'PCAModel', 'RFormula', 'RFormulaModel', + 'StopWordsRemover'] @inherit_doc @@ -934,6 +935,11 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol): >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), ... key=lambda x: x[0]) [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)] + >>> inverter = IndexToString(inputCol="indexed", outputCol="label2", labels=model.labels()) + >>> itd = inverter.transform(td) + >>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]), + ... 
key=lambda x: x[0]) + [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')] """ @keyword_only @@ -965,6 +971,66 @@ class StringIndexerModel(JavaModel): Model fitted by StringIndexer. """ + @property + def labels(self): + """ + Ordered list of labels, corresponding to indices to be assigned. + """ + return self._java_obj.labels + + +@inherit_doc +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol): + """ + .. note:: Experimental + + A :py:class:`Transformer` that maps a column of string indices back to a new column of + corresponding string values using either the ML attributes of the input column, or if + provided using the labels supplied by the user. + All original columns are kept during transformation. + See L{StringIndexer} for converting strings into indices. + """ + + # a placeholder to make the labels show up in generated doc + labels = Param(Params._dummy(), "labels", + "Optional array of labels to be provided by the user, if not supplied or " + + "empty, column metadata is read for labels") + + @keyword_only + def __init__(self, inputCol=None, outputCol=None, labels=None): + """ + __init__(self, inputCol=None, outputCol=None, labels=None) + """ + super(IndexToString, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", + self.uid) + self.labels = Param(self, "labels", + "Optional array of labels to be provided by the user, if not " + + "supplied or empty, column metadata is read for labels") + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCol=None, outputCol=None, labels=None): + """ + setParams(self, inputCol=None, outputCol=None, labels=None) + Sets params for this IndexToString. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setLabels(self, value): + """ + Sets the value of :py:attr:`labels`. 
+ """ + self._paramMap[self.labels] = value + return self + + def getLabels(self): + """ + Gets the value of :py:attr:`labels` or its default value. + """ + return self.getOrDefault(self.labels) class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol): http://git-wip-us.apache.org/repos/asf/spark/blob/2f6fd525/python/pyspark/ml/wrapper.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 253705b..8218c7c 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -136,7 +136,8 @@ class JavaEstimator(Estimator, JavaWrapper): class JavaTransformer(Transformer, JavaWrapper): """ Base class for :py:class:`Transformer`s that wrap Java/Scala - implementations. + implementations. Subclasses should ensure they have the transformer Java object + available as _java_obj. """ __metaclass__ = ABCMeta --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org