Repository: spark Updated Branches: refs/heads/master 8a4f228dc -> 8598d03a0
[SPARK-15243][ML][SQL][PYTHON] Add missing support for unicode in Param methods & functions in dataframe ## What changes were proposed in this pull request? This PR proposes to support unicode strings in Param methods in ML, and other missing functions in DataFrame. For example, this causes a `ValueError` in Python 2.x when param is a unicode string: ```python >>> from pyspark.ml.classification import LogisticRegression >>> lr = LogisticRegression() >>> lr.hasParam("threshold") True >>> lr.hasParam(u"threshold") Traceback (most recent call last): ... raise TypeError("hasParam(): paramName must be a string") TypeError: hasParam(): paramName must be a string ``` This PR is based on https://github.com/apache/spark/pull/13036 ## How was this patch tested? Unit tests in `python/pyspark/ml/tests.py` and `python/pyspark/sql/tests.py`. Author: hyukjinkwon <gurwls...@gmail.com> Author: sethah <seth.hendrickso...@gmail.com> Closes #17096 from HyukjinKwon/SPARK-15243. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8598d03a Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8598d03a Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8598d03a Branch: refs/heads/master Commit: 8598d03a00a39dd23646bf752f9fed5d28e271c6 Parents: 8a4f228 Author: hyukjinkwon <gurwls...@gmail.com> Authored: Fri Sep 8 11:57:33 2017 -0700 Committer: Holden Karau <hol...@us.ibm.com> Committed: Fri Sep 8 11:57:33 2017 -0700 ---------------------------------------------------------------------- python/pyspark/ml/param/__init__.py | 4 ++-- python/pyspark/ml/tests.py | 15 +++++++++++++++ python/pyspark/sql/dataframe.py | 22 +++++++++++----------- python/pyspark/sql/tests.py | 22 ++++++++++++++-------- 4 files changed, 42 insertions(+), 21 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/ml/param/__init__.py 
---------------------------------------------------------------------- diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index 1334207..043c25c 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -330,7 +330,7 @@ class Params(Identifiable): Tests whether this instance contains a param with a given (string) name. """ - if isinstance(paramName, str): + if isinstance(paramName, basestring): p = getattr(self, paramName, None) return isinstance(p, Param) else: @@ -413,7 +413,7 @@ class Params(Identifiable): if isinstance(param, Param): self._shouldOwn(param) return param - elif isinstance(param, str): + elif isinstance(param, basestring): return self.getParam(param) else: raise ValueError("Cannot resolve %r as a param." % param) http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6076b3c..509698f 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with @@ -352,6 +353,20 @@ class ParamTests(PySparkTestCase): testParams = TestParams() self.assertTrue(all([testParams.hasParam(p.name) for p in testParams.params])) self.assertFalse(testParams.hasParam("notAParameter")) + self.assertTrue(testParams.hasParam(u"maxIter")) + + def test_resolveparam(self): + testParams = TestParams() + self.assertEqual(testParams._resolveParam(testParams.maxIter), testParams.maxIter) + self.assertEqual(testParams._resolveParam("maxIter"), testParams.maxIter) + + self.assertEqual(testParams._resolveParam(u"maxIter"), testParams.maxIter) + if sys.version_info[0] >= 3: + # In Python 3, it is allowed to get/set attributes with non-ascii characters. 
+ e_cls = AttributeError + else: + e_cls = UnicodeEncodeError + self.assertRaises(e_cls, lambda: testParams._resolveParam(u"ì")) def test_params(self): testParams = TestParams() http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/sql/dataframe.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 1cea130..8f88545 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -748,7 +748,7 @@ class DataFrame(object): +---+-----+ """ - if not isinstance(col, str): + if not isinstance(col, basestring): raise ValueError("col must be a string, but got %r" % type(col)) if not isinstance(fractions, dict): raise ValueError("fractions must be a dict but got %r" % type(fractions)) @@ -1664,18 +1664,18 @@ class DataFrame(object): Added support for multiple columns. """ - if not isinstance(col, (str, list, tuple)): + if not isinstance(col, (basestring, list, tuple)): raise ValueError("col should be a string, list or tuple, but got %r" % type(col)) - isStr = isinstance(col, str) + isStr = isinstance(col, basestring) if isinstance(col, tuple): col = list(col) - elif isinstance(col, str): + elif isStr: col = [col] for c in col: - if not isinstance(c, str): + if not isinstance(c, basestring): raise ValueError("columns should be strings, but got %r" % type(c)) col = _to_list(self._sc, col) @@ -1707,9 +1707,9 @@ class DataFrame(object): :param col2: The name of the second column :param method: The correlation method. 
Currently only supports "pearson" """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") if not method: method = "pearson" @@ -1727,9 +1727,9 @@ class DataFrame(object): :param col1: The name of the first column :param col2: The name of the second column """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") return self._jdf.stat().cov(col1, col2) @@ -1749,9 +1749,9 @@ class DataFrame(object): :param col2: The name of the second column. Distinct items will make the column names of the DataFrame. """ - if not isinstance(col1, str): + if not isinstance(col1, basestring): raise ValueError("col1 should be a string.") - if not isinstance(col2, str): + if not isinstance(col2, basestring): raise ValueError("col2 should be a string.") return DataFrame(self._jdf.stat().crosstab(col1, col2), self.sql_ctx) http://git-wip-us.apache.org/repos/asf/spark/blob/8598d03a/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1bc889c..4d65abc 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1140,11 +1140,12 @@ class SQLTests(ReusedPySparkTestCase): def test_approxQuantile(self): df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF() - aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1) - self.assertTrue(isinstance(aq, list)) - self.assertEqual(len(aq), 3) + for f in ["a", u"a"]: + aq = df.stat.approxQuantile(f, [0.1, 0.5, 0.9], 0.1) + self.assertTrue(isinstance(aq, list)) + self.assertEqual(len(aq), 3) self.assertTrue(all(isinstance(q, float) for q in aq)) - aqs 
= df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1) + aqs = df.stat.approxQuantile(["a", u"b"], [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aqs, list)) self.assertEqual(len(aqs), 2) self.assertTrue(isinstance(aqs[0], list)) @@ -1153,7 +1154,7 @@ class SQLTests(ReusedPySparkTestCase): self.assertTrue(isinstance(aqs[1], list)) self.assertEqual(len(aqs[1]), 3) self.assertTrue(all(isinstance(q, float) for q in aqs[1])) - aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1) + aqt = df.stat.approxQuantile((u"a", "b"), [0.1, 0.5, 0.9], 0.1) self.assertTrue(isinstance(aqt, list)) self.assertEqual(len(aqt), 2) self.assertTrue(isinstance(aqt[0], list)) @@ -1169,17 +1170,22 @@ class SQLTests(ReusedPySparkTestCase): def test_corr(self): import math df = self.sc.parallelize([Row(a=i, b=math.sqrt(i)) for i in range(10)]).toDF() - corr = df.stat.corr("a", "b") + corr = df.stat.corr(u"a", "b") self.assertTrue(abs(corr - 0.95734012) < 1e-6) + def test_sampleby(self): + df = self.sc.parallelize([Row(a=i, b=(i % 3)) for i in range(10)]).toDF() + sampled = df.stat.sampleBy(u"b", fractions={0: 0.5, 1: 0.5}, seed=0) + self.assertTrue(sampled.count() == 3) + def test_cov(self): df = self.sc.parallelize([Row(a=i, b=2 * i) for i in range(10)]).toDF() - cov = df.stat.cov("a", "b") + cov = df.stat.cov(u"a", "b") self.assertTrue(abs(cov - 55.0 / 3) < 1e-6) def test_crosstab(self): df = self.sc.parallelize([Row(a=i % 3, b=i % 2) for i in range(1, 7)]).toDF() - ct = df.stat.crosstab("a", "b").collect() + ct = df.stat.crosstab(u"a", "b").collect() ct = sorted(ct, key=lambda x: x[0]) for i, row in enumerate(ct): self.assertEqual(row[0], str(i)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org