This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 91b95056806 [SPARK-39823][SQL][PYTHON] Rename Dataset.as as Dataset.to and add DataFrame.to in PySpark 91b95056806 is described below commit 91b950568066830ecd7a4581ab5bf4dbdbbeb474 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Wed Jul 27 08:11:18 2022 +0800 [SPARK-39823][SQL][PYTHON] Rename Dataset.as as Dataset.to and add DataFrame.to in PySpark ### What changes were proposed in this pull request? 1, rename `Dataset.as(StructType)` to `Dataset.to(StructType)`, since `as` is a keyword in Python and we don't want PySpark to use a different name; 2, Add `DataFrame.to(StructType)` in Python ### Why are the changes needed? for function parity ### Does this PR introduce _any_ user-facing change? yes, new api ### How was this patch tested? added UT Closes #37233 from zhengruifeng/py_ds_as. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../source/reference/pyspark.sql/dataframe.rst | 1 + python/pyspark/sql/dataframe.py | 50 +++++++++++++++++++++ python/pyspark/sql/tests/test_dataframe.py | 36 ++++++++++++++- .../main/scala/org/apache/spark/sql/Dataset.scala | 4 +- ...emaSuite.scala => DataFrameToSchemaSuite.scala} | 52 +++++++++++----------- 5 files changed, 114 insertions(+), 29 deletions(-) diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst index 5b6e704ba48..8cf083e5dd4 100644 --- a/python/docs/source/reference/pyspark.sql/dataframe.rst +++ b/python/docs/source/reference/pyspark.sql/dataframe.rst @@ -102,6 +102,7 @@ DataFrame DataFrame.summary DataFrame.tail DataFrame.take + DataFrame.to DataFrame.toDF DataFrame.toJSON DataFrame.toLocalIterator diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index efebd05c08d..481dafa310d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1422,6 +1422,56 @@ class 
DataFrame(PandasMapOpsMixin, PandasConversionMixin): jc = self._jdf.colRegex(colName) return Column(jc) + def to(self, schema: StructType) -> "DataFrame": + """ + Returns a new :class:`DataFrame` where each row is reconciled to match the specified + schema. + + Notes + ----- + 1, Reorder columns and/or inner fields by name to match the specified schema. + + 2, Project away columns and/or inner fields that are not needed by the specified schema. + Missing columns and/or inner fields (present in the specified schema but not input + DataFrame) lead to failures. + + 3, Cast the columns and/or inner fields to match the data types in the specified schema, + if the types are compatible, e.g., numeric to numeric (error if overflows), but not string + to int. + + 4, Carry over the metadata from the specified schema, while the columns and/or inner fields + still keep their own metadata if not overwritten by the specified schema. + + 5, Fail if the nullability is not compatible. For example, the column and/or inner field + is nullable but the specified schema requires them to be not nullable. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + schema : :class:`StructType` + Specified schema. + + Examples + -------- + >>> df = spark.createDataFrame([("a", 1)], ["i", "j"]) + >>> df.schema + StructType([StructField('i', StringType(), True), StructField('j', LongType(), True)]) + >>> schema = StructType([StructField("j", StringType()), StructField("i", StringType())]) + >>> df2 = df.to(schema) + >>> df2.schema + StructType([StructField('j', StringType(), True), StructField('i', StringType(), True)]) + >>> df2.show() + +---+---+ + | j| i| + +---+---+ + | 1| a| + +---+---+ + """ + assert schema is not None + jschema = self._jdf.sparkSession().parseDataType(schema.json()) + return DataFrame(self._jdf.to(jschema), self.sparkSession) + def alias(self, alias: str) -> "DataFrame": """Returns a new :class:`DataFrame` with an alias set. 
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py index ac6b6f68aed..7c7d3d1e51c 100644 --- a/python/pyspark/sql/tests/test_dataframe.py +++ b/python/pyspark/sql/tests/test_dataframe.py @@ -25,11 +25,12 @@ import unittest from typing import cast from pyspark.sql import SparkSession, Row -from pyspark.sql.functions import col, lit, count, sum, mean +from pyspark.sql.functions import col, lit, count, sum, mean, struct from pyspark.sql.types import ( StringType, IntegerType, DoubleType, + LongType, StructType, StructField, BooleanType, @@ -1200,6 +1201,39 @@ class DataFrameTests(ReusedSQLTestCase): [Row(value=None)], ) + def test_to(self): + schema = StructType( + [StructField("i", StringType(), True), StructField("j", IntegerType(), True)] + ) + df = self.spark.createDataFrame([("a", 1)], schema) + + schema1 = StructType([StructField("j", StringType()), StructField("i", StringType())]) + df1 = df.to(schema1) + self.assertEqual(schema1, df1.schema) + self.assertEqual(df.count(), df1.count()) + + schema2 = StructType([StructField("j", LongType())]) + df2 = df.to(schema2) + self.assertEqual(schema2, df2.schema) + self.assertEqual(df.count(), df2.count()) + + schema3 = StructType([StructField("struct", schema1, False)]) + df3 = df.select(struct("i", "j").alias("struct")).to(schema3) + self.assertEqual(schema3, df3.schema) + self.assertEqual(df.count(), df3.count()) + + # incompatible field nullability + schema4 = StructType([StructField("j", LongType(), False)]) + self.assertRaisesRegex( + AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4) + ) + + # field cannot upcast + schema5 = StructType([StructField("i", LongType())]) + self.assertRaisesRegex( + AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5) + ) + class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 49b4a8389f9..2e1dc7d83d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -476,14 +476,14 @@ class Dataset[T] private[sql]( * int.</li> * <li>Carry over the metadata from the specified schema, while the columns and/or inner fields * still keep their own metadata if not overwritten by the specified schema.</li> - * <li>Fail if the nullability are not compatible. For example, the column and/or inner field is + * <li>Fail if the nullability is not compatible. For example, the column and/or inner field is * nullable but the specified schema requires them to be not nullable.</li> * </ul> * * @group basic * @since 3.4.0 */ - def as(schema: StructType): DataFrame = withPlan { + def to(schema: StructType): DataFrame = withPlan { Project.matchSchema(logicalPlan, schema, sparkSession.sessionState.conf) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala similarity index 93% rename from sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsSchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala index eccbfc339f0..26ddbc4569e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala @@ -22,33 +22,33 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ -class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { +class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession { import testImplicits._ test("reorder columns by name") { val schema = new StructType().add("j", StringType).add("i", 
StringType) - val df = Seq("a" -> "b").toDF("i", "j").as(schema) + val df = Seq("a" -> "b").toDF("i", "j").to(schema) assert(df.schema == schema) checkAnswer(df, Row("b", "a")) } test("case insensitive: reorder columns by name") { val schema = new StructType().add("J", StringType).add("I", StringType) - val df = Seq("a" -> "b").toDF("i", "j").as(schema) + val df = Seq("a" -> "b").toDF("i", "j").to(schema) assert(df.schema == schema) checkAnswer(df, Row("b", "a")) } test("select part of the columns") { val schema = new StructType().add("j", StringType) - val df = Seq("a" -> "b").toDF("i", "j").as(schema) + val df = Seq("a" -> "b").toDF("i", "j").to(schema) assert(df.schema == schema) checkAnswer(df, Row("b")) } test("negative: column not found") { val schema = new StructType().add("non_exist", StringType) - val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "j").as(schema)) + val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "j").to(schema)) checkError( exception = e, errorClass = "UNRESOLVED_COLUMN", @@ -59,7 +59,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { test("negative: ambiguous column") { val schema = new StructType().add("i", StringType) - val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "I").as(schema)) + val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "I").to(schema)) checkError( exception = e, errorClass = "AMBIGUOUS_COLUMN_OR_FIELD", @@ -72,7 +72,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("j", IntegerType) val data = Seq("a" -> 1).toDF("i", "j") assert(!data.schema.fields(1).nullable) - val df = data.as(schema) + val df = data.to(schema) val finalSchema = new StructType().add("j", IntegerType, nullable = false) assert(df.schema == finalSchema) checkAnswer(df, Row(1)) @@ -82,7 +82,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("i", IntegerType, 
nullable = false) val data = sql("SELECT i FROM VALUES 1, NULL as t(i)") assert(data.schema.fields(0).nullable) - val e = intercept[SparkThrowable](data.as(schema)) + val e = intercept[SparkThrowable](data.to(schema)) checkError( exception = e, errorClass = "NULLABLE_COLUMN_OR_FIELD", @@ -91,14 +91,14 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { test("upcast the original column") { val schema = new StructType().add("j", LongType, nullable = false) - val df = Seq("a" -> 1).toDF("i", "j").as(schema) + val df = Seq("a" -> 1).toDF("i", "j").to(schema) assert(df.schema == schema) checkAnswer(df, Row(1L)) } test("negative: column cannot upcast") { val schema = new StructType().add("i", IntegerType) - val e = intercept[SparkThrowable](Seq("a" -> 1).toDF("i", "j").as(schema)) + val e = intercept[SparkThrowable](Seq("a" -> 1).toDF("i", "j").to(schema)) checkError( exception = e, errorClass = "INVALID_COLUMN_OR_FIELD_DATA_TYPE", @@ -113,7 +113,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val metadata1 = new MetadataBuilder().putString("a", "1").putString("b", "2").build() val metadata2 = new MetadataBuilder().putString("b", "3").putString("c", "4").build() val schema = new StructType().add("i", IntegerType, nullable = true, metadata = metadata2) - val df = Seq((1)).toDF("i").select($"i".as("i", metadata1)).as(schema) + val df = Seq((1)).toDF("i").select($"i".as("i", metadata1)).to(schema) // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added. 
val resultMetadata = new MetadataBuilder() .putString("a", "1").putString("b", "3").putString("c", "4").build() @@ -124,7 +124,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { test("reorder inner fields by name") { val innerFields = new StructType().add("j", StringType).add("i", StringType) val schema = new StructType().add("struct", innerFields, nullable = false) - val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema) + val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema) assert(df.schema == schema) checkAnswer(df, Row(Row("b", "a"))) } @@ -132,7 +132,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { test("case insensitive: reorder inner fields by name") { val innerFields = new StructType().add("J", StringType).add("I", StringType) val schema = new StructType().add("struct", innerFields, nullable = false) - val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema) + val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema) assert(df.schema == schema) checkAnswer(df, Row(Row("b", "a"))) } @@ -141,7 +141,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val innerFields = new StructType().add("non_exist", StringType) val schema = new StructType().add("struct", innerFields, nullable = false) val e = intercept[SparkThrowable] { - Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema) + Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema) } checkError( exception = e, @@ -158,7 +158,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val data = Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")) assert(!data.schema.fields(0).nullable) assert(!data.schema.fields(0).dataType.asInstanceOf[StructType].fields(1).nullable) - val df = data.as(schema) + val df 
= data.to(schema) val finalFields = new StructType().add("j", IntegerType, nullable = false) val finalSchema = new StructType().add("struct", finalFields, nullable = false) assert(df.schema == finalSchema) @@ -171,7 +171,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val data = sql("SELECT i FROM VALUES 1, NULL as t(i)").select(struct($"i").as("struct")) assert(!data.schema.fields(0).nullable) assert(data.schema.fields(0).dataType.asInstanceOf[StructType].fields(0).nullable) - val e = intercept[SparkThrowable](data.as(schema)) + val e = intercept[SparkThrowable](data.to(schema)) checkError( exception = e, errorClass = "NULLABLE_COLUMN_OR_FIELD", @@ -181,7 +181,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { test("upcast the original field") { val innerFields = new StructType().add("j", LongType, nullable = false) val schema = new StructType().add("struct", innerFields, nullable = false) - val df = Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema) + val df = Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema) assert(df.schema == schema) checkAnswer(df, Row(Row(1L))) } @@ -190,7 +190,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val innerFields = new StructType().add("i", IntegerType) val schema = new StructType().add("struct", innerFields, nullable = false) val e = intercept[SparkThrowable] { - Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema) + Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema) } checkError( exception = e, @@ -210,7 +210,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val df = Seq((1)).toDF("i") .select($"i".as("i", metadata1)) .select(struct($"i").as("struct")) - .as(schema) + .to(schema) // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added. 
val resultMetadata = new MetadataBuilder() .putString("a", "1").putString("b", "3").putString("c", "4").build() @@ -223,7 +223,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("arr", arr, nullable = false) val df = Seq("a" -> "b").toDF("i", "j") .select(array(struct($"i", $"j")).as("arr")) - .as(schema) + .to(schema) assert(df.schema == schema) checkAnswer(df, Row(Seq(Row("b", "a")))) } @@ -234,7 +234,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("arr", arr, nullable = false) val df = Seq("a" -> 1).toDF("i", "j") .select(array(struct($"i", $"j")).as("arr")) - .as(schema) + .to(schema) assert(df.schema == schema) checkAnswer(df, Row(Seq(Row(1L)))) } @@ -244,7 +244,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("arr", arr) val data = sql("SELECT i FROM VALUES 1, NULL as t(i)").select(array($"i").as("arr")) assert(data.schema.fields(0).dataType.asInstanceOf[ArrayType].containsNull) - val e = intercept[SparkThrowable](data.as(schema)) + val e = intercept[SparkThrowable](data.to(schema)) checkError( exception = e, errorClass = "NULLABLE_ARRAY_OR_MAP_ELEMENT", @@ -260,7 +260,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val df = Seq((1)).toDF("i") .select($"i") .select(array(struct($"i")).as("arr", metadata1)) - .as(schema) + .to(schema) // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added. val resultMetadata = new MetadataBuilder() .putString("a", "1").putString("b", "3").putString("c", "4").build() @@ -276,7 +276,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val df = Seq((1)).toDF("i") .select($"i".as("i", metadata1)) .select(array(struct($"i")).as("arr")) - .as(schema) + .to(schema) // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added. 
val resultMetadata = new MetadataBuilder() .putString("a", "1").putString("b", "3").putString("c", "4").build() @@ -290,7 +290,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("map", m, nullable = false) val df = Seq("a" -> "b").toDF("i", "j") .select(map(struct($"i", $"j"), $"i").as("map")) - .as(schema) + .to(schema) assert(df.schema == schema) checkAnswer(df, Row(Map(Row("b", "a") -> "a"))) } @@ -301,7 +301,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession { val schema = new StructType().add("map", m, nullable = false) val df = Seq("a" -> "b").toDF("i", "j") .select(map($"i", struct($"i", $"j")).as("map")) - .as(schema) + .to(schema) assert(df.schema == schema) checkAnswer(df, Row(Map("a" -> Row("b", "a")))) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org