Repository: spark Updated Branches: refs/heads/branch-1.6 91a5ca5e8 -> 9d45ec466
[SPARK-13047][PYSPARK][ML] Pyspark Params.hasParam should not throw an error

Pyspark Params class has a method `hasParam(paramName)` which returns `True` if the class has a parameter by that name, but throws an `AttributeError` otherwise. There is not currently a way of getting a Boolean to indicate if a class has a parameter. With Spark 2.0 we could modify the existing behavior of `hasParam` or add an additional method with this functionality.

In Python:
```python
from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()
print nb.hasParam("smoothing")
print nb.hasParam("notAParam")
```
produces:
> True
> AttributeError: 'NaiveBayes' object has no attribute 'notAParam'

However, in Scala:
```scala
import org.apache.spark.ml.classification.NaiveBayes
val nb = new NaiveBayes()
nb.hasParam("smoothing")
nb.hasParam("notAParam")
```
produces:
> true
> false

cc holdenk

Author: sethah <seth.hendrickso...@gmail.com>

Closes #10962 from sethah/SPARK-13047.

(cherry picked from commit b35467388612167f0bc3d17142c21a406f6c620d)
Signed-off-by: Xiangrui Meng <m...@databricks.com>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9d45ec46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9d45ec46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9d45ec46

Branch: refs/heads/branch-1.6
Commit: 9d45ec466a4067bb2d0b59ff1174bec630daa7b1
Parents: 91a5ca5
Author: sethah <seth.hendrickso...@gmail.com>
Authored: Thu Feb 11 16:42:44 2016 -0800
Committer: Xiangrui Meng <m...@databricks.com>
Committed: Thu Feb 11 16:42:52 2016 -0800

----------------------------------------------------------------------
python/pyspark/ml/param/__init__.py | 7 +++++--
python/pyspark/ml/tests.py | 9 +++++++--
2 files changed, 12 insertions(+), 4 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/9d45ec46/python/pyspark/ml/param/__init__.py
---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 35c9b77..666689c 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -156,8 +156,11 @@ class Params(Identifiable): Tests whether this instance contains a param with a given (string) name. """ - param = self._resolveParam(paramName) - return param in self.params + if isinstance(paramName, str): + p = getattr(self, paramName, None) + return isinstance(p, Param) + else: + raise TypeError("hasParam(): paramName must be a string") @since("1.4.0") def getOrDefault(self, param): http://git-wip-us.apache.org/repos/asf/spark/blob/9d45ec46/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7a16cf5..674dbe9 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -170,6 +170,11 @@ class ParamTests(PySparkTestCase): self.assertEqual(maxIter.doc, "max number of iterations (>= 0).") self.assertTrue(maxIter.parent == testParams.uid) + def test_hasparam(self): + testParams = TestParams() + self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) + self.assertFalse(testParams.hasParam("notAParameter")) + def test_params(self): testParams = TestParams() maxIter = testParams.maxIter @@ -179,7 +184,7 @@ class ParamTests(PySparkTestCase): params = testParams.params self.assertEqual(params, [inputCol, maxIter, seed]) - self.assertTrue(testParams.hasParam(maxIter)) + self.assertTrue(testParams.hasParam(maxIter.name)) self.assertTrue(testParams.hasDefault(maxIter)) self.assertFalse(testParams.isSet(maxIter)) self.assertTrue(testParams.isDefined(maxIter)) @@ -188,7 +193,7 @@ class ParamTests(PySparkTestCase): self.assertTrue(testParams.isSet(maxIter)) self.assertEqual(testParams.getMaxIter(), 100) - 
self.assertTrue(testParams.hasParam(inputCol)) + self.assertTrue(testParams.hasParam(inputCol.name)) self.assertFalse(testParams.hasDefault(inputCol)) self.assertFalse(testParams.isSet(inputCol)) self.assertFalse(testParams.isDefined(inputCol)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org