Repository: spark Updated Branches: refs/heads/master 8ed044928 -> b5e118808
[SPARK-25124][ML] VectorSizeHint setSize and getSize don't return values ## What changes were proposed in this pull request? In feature.py, VectorSizeHint's setSize and getSize don't return a value. Add the missing return statements. ## How was this patch tested? I tested the changes locally. Closes #22136 from huaxingao/spark-25124. Authored-by: Huaxin Gao <huax...@us.ibm.com> Signed-off-by: Joseph K. Bradley <jos...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b5e11880 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b5e11880 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b5e11880 Branch: refs/heads/master Commit: b5e11880871d6ef31efe3ec42b3caa0fc403e71b Parents: 8ed0449 Author: Huaxin Gao <huax...@us.ibm.com> Authored: Thu Aug 23 16:17:27 2018 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Thu Aug 23 16:17:27 2018 -0700 ---------------------------------------------------------------------- python/pyspark/ml/feature.py | 4 ++-- python/pyspark/ml/tests.py | 17 +++++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b5e11880/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ddba738..760aa82 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -3843,12 +3843,12 @@ class VectorSizeHint(JavaTransformer, HasInputCol, HasHandleInvalid, JavaMLReada @since("2.3.0") def getSize(self): """ Gets size param, the size of vectors in `inputCol`.""" - self.getOrDefault(self.size) + return self.getOrDefault(self.size) @since("2.3.0") def setSize(self, value): """ Sets size param, the size of vectors in `inputCol`.""" - self._set(size=value) + return self._set(size=value) if __name__ == 
"__main__": http://git-wip-us.apache.org/repos/asf/spark/blob/b5e11880/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index a770bad..5c87d1d 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -844,6 +844,23 @@ class FeatureTests(SparkSessionTestCase): .select(model_default.getOrDefault(model_default.outputCol)).collect() self.assertEqual(len(transformed_list), 5) + def test_vector_size_hint(self): + df = self.spark.createDataFrame( + [(0, Vectors.dense([0.0, 10.0, 0.5])), + (1, Vectors.dense([1.0, 11.0, 0.5, 0.6])), + (2, Vectors.dense([2.0, 12.0]))], + ["id", "vector"]) + + sizeHint = VectorSizeHint( + inputCol="vector", + handleInvalid="skip") + sizeHint.setSize(3) + self.assertEqual(sizeHint.getSize(), 3) + + output = sizeHint.transform(df).head().vector + expected = DenseVector([0.0, 10.0, 0.5]) + self.assertEqual(output, expected) + class HasInducedError(Params): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org