Repository: spark
Updated Branches:
  refs/heads/master 02d0a1ffd -> 20fa45693

[SPARK-25090][ML] Enforce implicit type coercion in ParamGridBuilder

## What changes were proposed in this pull request?

When the parameter grid is created in `ParamGridBuilder`, implicit type coercion is not enforced. As a result, passing an integer in the list of values for a parameter that expects a double (for example, `[0.5, 1]` for `regParam`) can cause a class cast exception.

This PR enforces type coercion when building the parameter maps by passing each value through its `Param`'s `typeConverter`.

## How was this patch tested?

Added a unit test (`test_param_grid_type_coercion`).

Closes #22076 from mgaido91/SPARK-25090.

Authored-by: Marco Gaido <marcogaid...@gmail.com>
Signed-off-by: hyukjinkwon <gurwls...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/20fa4569
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/20fa4569
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/20fa4569

Branch: refs/heads/master
Commit: 20fa45693238cd39e162b129214f5d6a93e5552e
Parents: 02d0a1f
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Mon Aug 13 09:11:37 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Mon Aug 13 09:11:37 2018 +0800

----------------------------------------------------------------------
 python/pyspark/ml/tests.py  | 7 +++++++
 python/pyspark/ml/tuning.py | 6 +++++-
 2 files changed, 12 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/20fa4569/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3d8883b..a770bad 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -950,6 +950,13 @@ class CrossValidatorTests(SparkSessionTestCase):
                          "Best model should have zero induced error")
         self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
 
+    def test_param_grid_type_coercion(self):
+        lr = LogisticRegression(maxIter=10)
+        paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.5, 1]).build()
+        for param in paramGrid:
+            for v in param.values():
+                assert(type(v) == float)
+
     def test_save_load_trained_model(self):
         # This tests saving and loading the trained model only.
         # Save/load for CrossValidator will be added later: SPARK-13786


http://git-wip-us.apache.org/repos/asf/spark/blob/20fa4569/python/pyspark/ml/tuning.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 0c8029f..1f4abf5 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -115,7 +115,11 @@ class ParamGridBuilder(object):
         """
         keys = self._param_grid.keys()
         grid_values = self._param_grid.values()
-        return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
+
+        def to_key_value_pairs(keys, values):
+            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]
+
+        return [dict(to_key_value_pairs(keys, prod)) for prod in itertools.product(*grid_values)]
 
 
 class ValidatorParams(HasSeed):
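
The coercion mechanism in the patched `ParamGridBuilder.build()` can be illustrated without a Spark installation. The sketch below is a simplified, hypothetical re-creation: `Param`, `to_float` and `SimpleParamGridBuilder` are stand-ins written for illustration, not PySpark's classes. (In PySpark, each `Param`, such as `lr.regParam`, carries its own `typeConverter`.) Only the body of `build()` mirrors the patch: each value in every Cartesian-product combination is passed through its `Param`'s `typeConverter` before the parameter map is assembled.

```python
import itertools
from typing import Any, Callable, Dict, List


class Param:
    """Hypothetical stand-in for a PySpark Param: only a name and a
    typeConverter callable, which is all the coercion logic needs."""

    def __init__(self, name: str, typeConverter: Callable[[Any], Any]):
        self.name = name
        self.typeConverter = typeConverter

    def __repr__(self) -> str:
        return "Param(%r)" % self.name


def to_float(value: Any) -> float:
    # Hypothetical float converter: accept ints and floats, reject anything
    # else (bool is excluded explicitly because it subclasses int).
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise TypeError("Could not convert %r to float" % (value,))
    return float(value)


class SimpleParamGridBuilder:
    """Hypothetical, minimal builder whose build() follows the patched logic."""

    def __init__(self) -> None:
        self._param_grid: Dict[Param, List[Any]] = {}

    def addGrid(self, param: Param, values: List[Any]) -> "SimpleParamGridBuilder":
        self._param_grid[param] = values
        return self

    def build(self) -> List[Dict[Param, Any]]:
        keys = list(self._param_grid.keys())
        grid_values = self._param_grid.values()

        def to_key_value_pairs(keys, values):
            # Convert every value with its Param's converter, so the int 1
            # is stored as 1.0 for a float-typed Param.
            return [(key, key.typeConverter(value)) for key, value in zip(keys, values)]

        # One parameter map per combination in the Cartesian product.
        return [dict(to_key_value_pairs(keys, prod))
                for prod in itertools.product(*grid_values)]


reg_param = Param("regParam", to_float)
grid = SimpleParamGridBuilder().addGrid(reg_param, [0.5, 1]).build()
print(grid)  # [{Param('regParam'): 0.5}, {Param('regParam'): 1.0}]
```

Without the `key.typeConverter(value)` call, the second map would hold the int `1`. That is the situation the added unit test guards against by checking that every value in the grid is a `float`.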