This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e9d31a0a1dd [SPARK-41875][CONNECT][PYTHON] Add test cases for `Dataset.to()` e9d31a0a1dd is described below commit e9d31a0a1dd54900207f92760b63bdf53f6688b4 Author: Jiaan Geng <belie...@163.com> AuthorDate: Sat Jan 7 09:43:49 2023 +0800 [SPARK-41875][CONNECT][PYTHON] Add test cases for `Dataset.to()` ### What changes were proposed in this pull request? 1. This PR makes the parameter type of `Dataset.to()` the same as in PySpark (`StructType` instead of `DataType`). 2. Spark Connect's `Dataset.to()` was missing some test cases. This PR adds them, following the existing PySpark tests at https://github.com/apache/spark/blob/89666d44a39c48df841a0102ff6f54eaeb4c6140/python/pyspark/sql/tests/test_dataframe.py#L1464 ### Why are the changes needed? To give Spark Connect's `Dataset.to()` the same test coverage as PySpark's. ### Does this PR introduce _any_ user-facing change? No functional change: the `to()` type hint is narrowed to `StructType` to match PySpark, and the rest of the patch is test-only. ### How was this patch tested? New test cases. Closes #39422 from beliefer/SPARK-41875. Authored-by: Jiaan Geng <belie...@163.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/sql/connect/dataframe.py | 4 +- .../sql/tests/connect/test_connect_basic.py | 116 ++++++++++++--------- 2 files changed, 69 insertions(+), 51 deletions(-) diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 8aca9fbb968..17b88461a43 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -40,7 +40,7 @@ from collections.abc import Iterable from pyspark import _NoValue, SparkContext, SparkConf from pyspark._globals import _NoValueType -from pyspark.sql.types import DataType, StructType, Row +from pyspark.sql.types import StructType, Row import pyspark.sql.connect.plan as plan from pyspark.sql.connect.group import GroupedData @@ -1210,7 +1210,7 @@ class DataFrame: inputFiles.__doc__ = PySparkDataFrame.inputFiles.__doc__ - def to(self, schema: DataType) -> "DataFrame": + def to(self, schema: StructType) -> 
"DataFrame": assert schema is not None return DataFrame.withPlan( plan.ToSchema(child=self._plan, schema=schema), diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 72e60712b98..31a7e6fdbad 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -30,7 +30,6 @@ from pyspark.sql.types import ( ArrayType, Row, ) -import pyspark.sql.functions from pyspark.testing.utils import ReusedPySparkTestCase from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.pandasutils import PandasOnSparkTestCase @@ -43,9 +42,11 @@ if should_test_connect: from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.sql.connect.client import ChannelBuilder from pyspark.sql.connect.column import Column + from pyspark.sql.dataframe import DataFrame from pyspark.sql.connect.dataframe import DataFrame as CDataFrame from pyspark.sql.connect.function_builder import udf - from pyspark.sql.connect.functions import lit, col + from pyspark.sql import functions as SF + from pyspark.sql.connect import functions as CF @unittest.skipIf(not should_test_connect, connect_requirement_message) @@ -333,7 +334,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): """SPARK-41114: Test creating a dataframe using local data""" pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}) df = self.connect.createDataFrame(pdf) - rows = df.filter(df.a == lit(3)).collect() + rows = df.filter(df.a == CF.lit(3)).collect() self.assertTrue(len(rows) == 1) self.assertEqual(rows[0][0], 3) self.assertEqual(rows[0][1], "c") @@ -679,6 +680,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): def test_to(self): # SPARK-41464: test DataFrame.to() + cdf = self.connect.read.table(self.tbl_name) + df = self.spark.read.table(self.tbl_name) + + def assert_eq_schema(cdf: CDataFrame, df: 
DataFrame, schema: StructType): + cdf_to = cdf.to(schema) + df_to = df.to(schema) + self.assertEqual(cdf_to.schema, df_to.schema) + self.assert_eq(cdf_to.toPandas(), df_to.toPandas()) + # The schema has not changed schema = StructType( [ @@ -687,11 +697,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): ] ) - cdf = self.connect.read.table(self.tbl_name).to(schema) - df = self.spark.read.table(self.tbl_name).to(schema) + assert_eq_schema(cdf, df, schema) + + # Change schema with struct + schema2 = StructType([StructField("struct", schema, False)]) + + cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2) + df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2) - self.assertEqual(cdf.schema, df.schema) - self.assert_eq(cdf.toPandas(), df.toPandas()) + self.assertEqual(cdf_to.schema, df_to.schema) # Change the column name schema = StructType( @@ -701,11 +715,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): ] ) - cdf = self.connect.read.table(self.tbl_name).to(schema) - df = self.spark.read.table(self.tbl_name).to(schema) - - self.assertEqual(cdf.schema, df.schema) - self.assert_eq(cdf.toPandas(), df.toPandas()) + assert_eq_schema(cdf, df, schema) # Change the column data type schema = StructType( @@ -715,26 +725,44 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): ] ) - cdf = self.connect.read.table(self.tbl_name).to(schema) - df = self.spark.read.table(self.tbl_name).to(schema) + assert_eq_schema(cdf, df, schema) + + # Reduce the column quantity and change data type + schema = StructType( + [ + StructField("id", LongType(), True), + ] + ) + + assert_eq_schema(cdf, df, schema) + + # incompatible field nullability + schema = StructType([StructField("id", LongType(), False)]) + self.assertRaisesRegex( + SparkConnectAnalysisException, + "NULLABLE_COLUMN_OR_FIELD", + lambda: cdf.to(schema).toPandas(), + ) - self.assertEqual(cdf.schema, df.schema) - self.assert_eq(cdf.toPandas(), df.toPandas()) + # field 
cannot upcast + schema = StructType([StructField("name", LongType())]) + self.assertRaisesRegex( + SparkConnectAnalysisException, + "INVALID_COLUMN_OR_FIELD_DATA_TYPE", + lambda: cdf.to(schema).toPandas(), + ) - # Change the column data type failed schema = StructType( [ StructField("id", IntegerType(), True), StructField("name", IntegerType(), True), ] ) - - with self.assertRaises(SparkConnectException) as context: - self.connect.read.table(self.tbl_name).to(schema).toPandas() - self.assertIn( - """Column or field `name` is of type "STRING" while it's required to be "INT".""", - str(context.exception), - ) + self.assertRaisesRegex( + SparkConnectAnalysisException, + "INVALID_COLUMN_OR_FIELD_DATA_TYPE", + lambda: cdf.to(schema).toPandas(), + ) # Test map type and array type schema = StructType( @@ -744,11 +772,10 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): StructField("my_array", ArrayType(IntegerType(), False), True), ] ) - cdf = self.connect.read.table(self.tbl_name4).to(schema) - df = self.spark.read.table(self.tbl_name4).to(schema) + cdf = self.connect.read.table(self.tbl_name4) + df = self.spark.read.table(self.tbl_name4) - self.assertEqual(cdf.schema, df.schema) - self.assert_eq(cdf.toPandas(), df.toPandas()) + assert_eq_schema(cdf, df, schema) def test_toDF(self): # SPARK-41310: test DataFrame.toDF() @@ -1195,21 +1222,19 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): def test_with_columns(self): # SPARK-41256: test withColumn(s). 
self.assert_eq( - self.connect.read.table(self.tbl_name).withColumn("id", lit(False)).toPandas(), - self.spark.read.table(self.tbl_name) - .withColumn("id", pyspark.sql.functions.lit(False)) - .toPandas(), + self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(), + self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(), ) self.assert_eq( self.connect.read.table(self.tbl_name) - .withColumns({"id": lit(False), "col_not_exist": lit(False)}) + .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)}) .toPandas(), self.spark.read.table(self.tbl_name) .withColumns( { - "id": pyspark.sql.functions.lit(False), - "col_not_exist": pyspark.sql.functions.lit(False), + "id": SF.lit(False), + "col_not_exist": SF.lit(False), } ) .toPandas(), @@ -1392,9 +1417,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): def test_stat_sample_by(self): # SPARK-41069: Test stat.sample_by - from pyspark.sql import functions as SF - from pyspark.sql.connect import functions as CF - cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key")) sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key")) @@ -1475,7 +1497,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): """SPARK-41203: Support DF.transform""" def transform_df(input_df: CDataFrame) -> CDataFrame: - return input_df.select((col("id") + lit(10)).alias("id")) + return input_df.select((CF.col("id") + CF.lit(10)).alias("id")) df = self.connect.range(1, 100) result_left = df.transform(transform_df).collect() @@ -1490,13 +1512,13 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): """Testing supported and unsupported alias""" col0 = ( self.connect.range(1, 10) - .select(col("id").alias("name", metadata={"max": 99})) + .select(CF.col("id").alias("name", metadata={"max": 99})) .schema.names[0] ) self.assertEqual("name", col0) with self.assertRaises(SparkConnectException) as exc: - self.connect.range(1, 
10).select(col("id").alias("this", "is", "not")).collect() + self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect() self.assertIn("(this, is, not)", str(exc.exception)) def test_column_regexp(self) -> None: @@ -1555,7 +1577,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): # SPARK-41325: groupby.avg() df = ( self.connect.range(10) - .groupBy((col("id") % lit(2)).alias("moded")) + .groupBy((CF.col("id") % CF.lit(2)).alias("moded")) .avg("id") .sort("moded") ) @@ -1565,13 +1587,11 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): self.assertEqual(5.0, res[1][1]) # Additional GroupBy tests with 3 rows - import pyspark.sql.connect.functions as CF - import pyspark.sql.functions as PF - df_a = self.connect.range(10).groupBy((col("id") % lit(3)).alias("moded")) - df_b = self.spark.range(10).groupBy((PF.col("id") % PF.lit(3)).alias("moded")) + df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded")) + df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded")) self.assertEqual( - set(df_b.agg(PF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect()) + set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect()) ) # Dict agg @@ -1603,8 +1623,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase): ) def test_grouped_data(self): - from pyspark.sql import functions as SF - from pyspark.sql.connect import functions as CF query = """ SELECT * FROM VALUES --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org