Repository: spark
Updated Branches:
  refs/heads/master 02d0a1ffd -> 20fa45693


[SPARK-25090][ML] Enforce implicit type coercion in ParamGridBuilder

## What changes were proposed in this pull request?

When `ParamGridBuilder` builds the parameter grid, implicit type coercion is
not enforced. As a result, passing an integer in the list of values for a
parameter that accepts a double can cause a class cast exception.

The PR proposes to enforce the type coercion when building the parameters.
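
As an illustration only, the behavior can be sketched in plain Python without Spark. Here `Param`, `to_float`, `build_without_coercion` and `build_with_coercion` are hypothetical stand-ins written for this example, not the PySpark classes. The sketch assumes only what the diff shows: each grid key exposes a `typeConverter` callable that is applied to every value while the grid is built.

```python
import itertools


def to_float(value):
    """Hypothetical converter: accept ints and floats, return a float."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError("Could not convert %r to float" % (value,))
    return float(value)


class Param(object):
    """Hypothetical stand-in for a parameter that carries a type converter."""

    def __init__(self, name, typeConverter):
        self.name = name
        self.typeConverter = typeConverter


def build_without_coercion(param_grid):
    # Old behavior: values are copied into the grid exactly as given.
    keys = list(param_grid.keys())
    grid_values = param_grid.values()
    return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]


def build_with_coercion(param_grid):
    # New behavior: each value goes through its key's converter first.
    keys = list(param_grid.keys())
    grid_values = param_grid.values()
    return [
        dict((key, key.typeConverter(value)) for key, value in zip(keys, prod))
        for prod in itertools.product(*grid_values)
    ]


reg_param = Param("regParam", to_float)
grid = {reg_param: [0.5, 1]}  # note the integer 1 among the values

before = build_without_coercion(grid)
after = build_with_coercion(grid)

print([type(m[reg_param]).__name__ for m in before])  # ['float', 'int']
print([type(m[reg_param]).__name__ for m in after])   # ['float', 'float']
```

In the uncoerced grid the integer `1` survives as an `int`, which is the kind of value that can later trigger the cast failure described above. With coercion, every entry is already a `float` by the time the grid is returned.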

## How was this patch tested?

Added a unit test (`test_param_grid_type_coercion`).

Closes #22076 from mgaido91/SPARK-25090.

Authored-by: Marco Gaido <marcogaid...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20fa4569
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20fa4569
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20fa4569

Branch: refs/heads/master
Commit: 20fa45693238cd39e162b129214f5d6a93e5552e
Parents: 02d0a1f
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Mon Aug 13 09:11:37 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Mon Aug 13 09:11:37 2018 +0800

----------------------------------------------------------------------
 python/pyspark/ml/tests.py  | 7 +++++++
 python/pyspark/ml/tuning.py | 6 +++++-
 2 files changed, 12 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/20fa4569/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3d8883b..a770bad 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -950,6 +950,13 @@ class CrossValidatorTests(SparkSessionTestCase):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
+    def test_param_grid_type_coercion(self):
+        lr = LogisticRegression(maxIter=10)
+        paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
+        for param in paramGrid:
+            for v in param.values():
+                assert(type(v) == float)
+
     def test_save_load_trained_model(self):
         # This tests saving and loading the trained model only.
         # Save/load for CrossValidator will be added later: SPARK-13786

http://git-wip-us.apache.org/repos/asf/spark/blob/20fa4569/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 0c8029f..1f4abf5 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -115,7 +115,11 @@ class ParamGridBuilder(object):
         """
         keys = self._param_grid.keys()
         grid_values = self._param_grid.values()
-        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
+
+        def to_key_value_pairs(keys, values):
+            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]
+
+        return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
 class ValidatorParams(HasSeed):


