Repository: spark
Updated Branches:
  refs/heads/master bd201bf61 -> 8a72734f3

[SPARK-15009][PYTHON][ML] Construct a CountVectorizerModel from a vocabulary list

## What changes were proposed in this pull request?

Added a class method to construct a CountVectorizerModel from a list of vocabulary strings, equivalent to the Scala version. Introduced a common param base class, `_CountVectorizerParams`, so that the Python model also owns the parameters; the Python class hierarchy now matches the Scala one.

## How was this patch tested?

Added a CountVectorizer doctest that runs a transform on a model constructed from a vocabulary, and a unit test to verify that the params and the vocabulary are constructed correctly.

Author: Bryan Cutler <cutl...@gmail.com>

Closes #16770 from BryanCutler/pyspark-CountVectorizerModel-vocab_ctor-SPARK-15009.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8a72734f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8a72734f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8a72734f

Branch: refs/heads/master
Commit: 8a72734f33f6a0abbd3207b0d661633c8b25d9ad
Parents: bd201bf
Author: Bryan Cutler <cutl...@gmail.com>
Authored: Fri Mar 16 11:42:57 2018 -0700
Committer: Holden Karau <hol...@pigscanfly.ca>
Committed: Fri Mar 16 11:42:57 2018 -0700

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 168 +++++++++++++++++++++++++-------
 python/pyspark/ml/tests.py   |  32 +++++++-
 2 files changed, 142 insertions(+), 58 deletions(-)
----------------------------------------------------------------------
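
A minimal usage sketch of the new constructor (assuming an active SparkSession
bound to `spark`; the expected output mirrors the doctest added in the diff
below):

    from pyspark.ml.feature import CountVectorizerModel

    # No fitting step: the vocabulary fully determines the model, so this only
    # needs an active SparkContext to instantiate the underlying Java model.
    model = CountVectorizerModel.from_vocabulary(
        ["a", "b", "c"], inputCol="raw", outputCol="vectors")

    df = spark.createDataFrame(
        [(0, ["a", "b", "c"]), (1, ["a", "b", "b", "c", "a"])], ["label", "raw"])

    # Each row becomes a sparse vector of term counts over the fixed vocabulary.
    model.transform(df).show(truncate=False)
    # +-----+---------------+-------------------------+
    # |label|raw            |vectors                  |
    # +-----+---------------+-------------------------+
    # |0    |[a, b, c]      |(3,[0,1,2],[1.0,1.0,1.0])|
    # |1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
    # +-----+---------------+-------------------------+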
+ + " If this is an integer >= 1, this specifies the number of documents the term must" + + " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + + " Default 1.0", typeConverter=TypeConverters.toFloat) + vocabSize = Param( + Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", + typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) + + def __init__(self, *args): + super(_CountVectorizerParams, self).__init__(*args) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) + + @since("1.6.0") + def getMinTF(self): + """ + Gets the value of minTF or its default value. + """ + return self.getOrDefault(self.minTF) + + @since("1.6.0") + def getMinDF(self): + """ + Gets the value of minDF or its default value. + """ + return self.getOrDefault(self.minDF) + + @since("1.6.0") + def getVocabSize(self): + """ + Gets the value of vocabSize or its default value. + """ + return self.getOrDefault(self.vocabSize) + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + + @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): +class CountVectorizer(JavaEstimator, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Extracts a vocabulary from document collections and generates a :py:attr:`CountVectorizerModel`. @@ -437,33 +498,20 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, >>> loadedModel = CountVectorizerModel.load(modelPath) >>> loadedModel.vocabulary == model.vocabulary True + >>> fromVocabModel = CountVectorizerModel.from_vocabulary(["a", "b", "c"], + ... inputCol="raw", outputCol="vectors") + >>> fromVocabModel.transform(df).show(truncate=False) + +-----+---------------+-------------------------+ + |label|raw |vectors | + +-----+---------------+-------------------------+ + |0 |[a, b, c] |(3,[0,1,2],[1.0,1.0,1.0])| + |1 |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])| + +-----+---------------+-------------------------+ + ... .. versionadded:: 1.6.0 """ - minTF = Param( - Params._dummy(), "minTF", "Filter to ignore rare words in" + - " a document. For each document, terms with frequency/count less than the given" + - " threshold are ignored. If this is an integer >= 1, then this specifies a count (of" + - " times the term must appear in the document); if this is a double in [0,1), then this " + - "specifies a fraction (out of the document's token count). Note that the parameter is " + - "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0", - typeConverter=TypeConverters.toFloat) - minDF = Param( - Params._dummy(), "minDF", "Specifies the minimum number of" + - " different documents a term must appear in to be included in the vocabulary." + - " If this is an integer >= 1, this specifies the number of documents the term must" + - " appear in; if this is a double in [0,1), then this specifies the fraction of documents." + - " Default 1.0", typeConverter=TypeConverters.toFloat) - vocabSize = Param( - Params._dummy(), "vocabSize", "max size of the vocabulary. 
Default 1 << 18.", - typeConverter=TypeConverters.toInt) - binary = Param( - Params._dummy(), "binary", "Binary toggle to control the output vector values." + - " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + - " for discrete probabilistic models that model binary events rather than integer counts." + - " Default False", typeConverter=TypeConverters.toBoolean) - @keyword_only def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, outputCol=None): @@ -474,7 +522,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self._input_kwargs self.setParams(**kwargs) @@ -498,13 +545,6 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, return self._set(minTF=value) @since("1.6.0") - def getMinTF(self): - """ - Gets the value of minTF or its default value. - """ - return self.getOrDefault(self.minTF) - - @since("1.6.0") def setMinDF(self, value): """ Sets the value of :py:attr:`minDF`. @@ -512,26 +552,12 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, return self._set(minDF=value) @since("1.6.0") - def getMinDF(self): - """ - Gets the value of minDF or its default value. - """ - return self.getOrDefault(self.minDF) - - @since("1.6.0") def setVocabSize(self, value): """ Sets the value of :py:attr:`vocabSize`. """ return self._set(vocabSize=value) - @since("1.6.0") - def getVocabSize(self): - """ - Gets the value of vocabSize or its default value. - """ - return self.getOrDefault(self.vocabSize) - @since("2.0.0") def setBinary(self, value): """ @@ -539,24 +565,40 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ return self._set(binary=value) - @since("2.0.0") - def getBinary(self): - """ - Gets the value of binary or its default value. - """ - return self.getOrDefault(self.binary) - def _create_model(self, java_model): return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): +@inherit_doc +class CountVectorizerModel(JavaModel, _CountVectorizerParams, JavaMLReadable, JavaMLWritable): """ Model fitted by :py:class:`CountVectorizer`. .. versionadded:: 1.6.0 """ + @classmethod + @since("2.4.0") + def from_vocabulary(cls, vocabulary, inputCol, outputCol=None, minTF=None, binary=None): + """ + Construct the model directly from a vocabulary list of strings, + requires an active SparkContext. + """ + sc = SparkContext._active_spark_context + java_class = sc._gateway.jvm.java.lang.String + jvocab = CountVectorizerModel._new_java_array(vocabulary, java_class) + model = CountVectorizerModel._create_from_java_class( + "org.apache.spark.ml.feature.CountVectorizerModel", jvocab) + model.setInputCol(inputCol) + if outputCol is not None: + model.setOutputCol(outputCol) + if minTF is not None: + model.setMinTF(minTF) + if binary is not None: + model.setBinary(binary) + model._set(vocabSize=len(vocabulary)) + return model + @property @since("1.6.0") def vocabulary(self): @@ -565,6 +607,20 @@ class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ return self._call_java("vocabulary") + @since("2.4.0") + def setMinTF(self, value): + """ + Sets the value of :py:attr:`minTF`. 
+ """ + return self._set(minTF=value) + + @since("2.4.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + return self._set(binary=value) + @inherit_doc class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): http://git-wip-us.apache.org/repos/asf/spark/blob/8a72734f/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6dee693..fd45fd0 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -679,6 +679,34 @@ class FeatureTests(SparkSessionTestCase): feature, expected = r self.assertEqual(feature, expected) + def test_count_vectorizer_from_vocab(self): + model = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words", + outputCol="features", minTF=2) + self.assertEqual(model.vocabulary, ["a", "b", "c"]) + self.assertEqual(model.getMinTF(), 2) + + dataset = self.spark.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 3.0, 1: 2.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 2.0}),), + (2, "a b".split(' '), SparseVector(3, {}),)], ["id", "words", "expected"]) + + transformed_list = model.transform(dataset).select("features", "expected").collect() + + for r in transformed_list: + feature, expected = r + self.assertEqual(feature, expected) + + # Test an empty vocabulary + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"): + CountVectorizerModel.from_vocabulary([], inputCol="words") + + # Test model with default settings can transform + model_default = CountVectorizerModel.from_vocabulary(["a", "b", "c"], inputCol="words") + transformed_list = model_default.transform(dataset)\ + .select(model_default.getOrDefault(model_default.outputCol)).collect() + self.assertEqual(len(transformed_list), 3) + def test_rformula_force_index_label(self): df = self.spark.createDataFrame([ (1.0, 1.0, "a"), @@ -2019,8 +2047,8 @@ class DefaultValuesTests(PySparkTestCase): pyspark.ml.regression] for module in modules: for name, cls in inspect.getmembers(module, inspect.isclass): - if not name.endswith('Model') and issubclass(cls, JavaParams)\ - and not inspect.isabstract(cls): + if not name.endswith('Model') and not name.endswith('Params')\ + and issubclass(cls, JavaParams) and not inspect.isabstract(cls): # NOTE: disable check_params_exist until there is parity with Scala API ParamTests.check_params(self, cls(), check_params_exist=False) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org