Repository: spark Updated Branches: refs/heads/branch-2.3 9fa7b0e10 -> 8875e47ce
[SPARK-23387][SQL][PYTHON][TEST][BRANCH-2.3] Backport assertPandasEqual to branch-2.3. ## What changes were proposed in this pull request? When backporting a pr with tests using `assertPandasEqual` from master to branch-2.3, the tests fail because `assertPandasEqual` doesn't exist in branch-2.3. We should backport `assertPandasEqual` to branch-2.3 to avoid the failures. ## How was this patch tested? Modified tests. Author: Takuya UESHIN <ues...@databricks.com> Closes #20577 from ueshin/issues/SPARK-23387/branch-2.3. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8875e47c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8875e47c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8875e47c Branch: refs/heads/branch-2.3 Commit: 8875e47cec01ae8da4ffb855409b54089e1016fb Parents: 9fa7b0e Author: Takuya UESHIN <ues...@databricks.com> Authored: Sun Feb 11 22:16:47 2018 +0900 Committer: hyukjinkwon <gurwls...@gmail.com> Committed: Sun Feb 11 22:16:47 2018 +0900 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 44 +++++++++++++++++----------------------- 1 file changed, 19 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/8875e47c/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 0f76c96..5480144 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -195,6 +195,12 @@ class ReusedSQLTestCase(ReusedPySparkTestCase): ReusedPySparkTestCase.tearDownClass() cls.spark.stop() + def assertPandasEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + + "\n\nResult:\n%s\n%s" % (result, result.dtypes)) + self.assertTrue(expected.equals(result), msg=msg) + class DataTypeTests(unittest.TestCase): # regression test for SPARK-6055 @@ -3422,12 +3428,6 @@ class ArrowTests(ReusedSQLTestCase): time.tzset() ReusedSQLTestCase.tearDownClass() - def assertFramesEqual(self, df_with_arrow, df_without): - msg = ("DataFrame from Arrow is not equal" + - ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + - ("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes))) - self.assertTrue(df_without.equals(df_with_arrow), msg=msg) - def create_pandas_data_frame(self): import pandas as pd import numpy as np @@ -3466,8 +3466,8 @@ class ArrowTests(ReusedSQLTestCase): df = self.spark.createDataFrame(self.data, schema=self.schema) pdf, pdf_arrow = self._toPandas_arrow_toggle(df) expected = self.create_pandas_data_frame() - self.assertFramesEqual(expected, pdf) - self.assertFramesEqual(expected, pdf_arrow) + self.assertPandasEqual(expected, pdf) + self.assertPandasEqual(expected, pdf_arrow) def test_toPandas_respect_session_timezone(self): df = self.spark.createDataFrame(self.data, schema=self.schema) @@ -3478,11 +3478,11 @@ class ArrowTests(ReusedSQLTestCase): self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false") try: pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_la, pdf_la) + self.assertPandasEqual(pdf_arrow_la, pdf_la) finally: self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true") pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df) - self.assertFramesEqual(pdf_arrow_ny, pdf_ny) + self.assertPandasEqual(pdf_arrow_ny, pdf_ny) self.assertFalse(pdf_ny.equals(pdf_la)) @@ -3492,7 +3492,7 @@ class ArrowTests(ReusedSQLTestCase): if isinstance(field.dataType, TimestampType): pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz( pdf_la_corrected[field.name], timezone) - self.assertFramesEqual(pdf_ny, pdf_la_corrected) + self.assertPandasEqual(pdf_ny, pdf_la_corrected) finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) @@ -3500,7 +3500,7 @@ class ArrowTests(ReusedSQLTestCase): pdf = self.create_pandas_data_frame() df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_filtered_frame(self): df = self.spark.range(3).toDF("i") @@ -3558,7 +3558,7 @@ class ArrowTests(ReusedSQLTestCase): df = self.spark.createDataFrame(pdf, schema=self.schema) self.assertEquals(self.schema, df.schema) pdf_arrow = df.toPandas() - self.assertFramesEqual(pdf_arrow, pdf) + self.assertPandasEqual(pdf_arrow, pdf) def test_createDataFrame_with_incorrect_schema(self): pdf = self.create_pandas_data_frame() @@ -4318,12 +4318,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase): _pandas_requirement_message or _pyarrow_requirement_message) class GroupedMapPandasUDFTests(ReusedSQLTestCase): - def assertFramesEqual(self, expected, result): - msg = ("DataFrames are not equal: " + - ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + - ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) - self.assertTrue(expected.equals(result), msg=msg) - @property def data(self): from pyspark.sql.functions import array, explode, col, lit @@ -4347,7 +4341,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4371,7 +4365,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_coerce(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4386,7 +4380,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): result = df.groupby('id').apply(foo).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) expected = expected.assign(v=expected.v.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_complex_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4405,7 +4399,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_empty_groupby(self): from pyspark.sql.functions import pandas_udf, col, PandasUDFType @@ -4424,7 +4418,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): expected = normalize.func(pdf) expected = expected.sort_values(['id', 'v']).reset_index(drop=True) expected = expected.assign(norm=expected.norm.astype('float64')) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_datatype_string(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4438,7 +4432,7 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase): result = df.groupby('id').apply(foo_udf).sort('id').toPandas() expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertFramesEqual(expected, result) + self.assertPandasEqual(expected, result) def test_wrong_return_type(self): from pyspark.sql.functions import pandas_udf, PandasUDFType --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org