This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 8585fafab86 [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API
8585fafab86 is described below

commit 8585fafab8633c02e6f1b989acd2bbdb0eb1678e
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Mon Aug 1 09:39:12 2022 +0800

    [SPARK-39877][PYTHON] Add unpivot to PySpark DataFrame API

    ### What changes were proposed in this pull request?
    This adds `unpivot` and its alias `melt` to the PySpark API. It calls into the Scala `Dataset.unpivot` (#36150). A small difference from the Scala method signature is that the PySpark method has default values. This is similar to `melt` in the Spark Pandas API.

    ### Why are the changes needed?
    To support `unpivot` in Python.

    ### Does this PR introduce _any_ user-facing change?
    Yes, it adds `DataFrame.unpivot` and `DataFrame.melt` to the PySpark API.

    ### How was this patch tested?
    Added a test to `test_dataframe.py`.

    Closes #37304 from EnricoMi/branch-pyspark-unpivot.

    Authored-by: Enrico Minack <git...@enrico.minack.dev>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/dataframe.py                    | 134 +++++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         | 144 +++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  11 ++
 3 files changed, 289 insertions(+)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 481dafa310d..8c9632fe766 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2238,6 +2238,140 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):

         return GroupedData(jgd, self)

+    def unpivot(
+        self,
+        ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        variableColumnName: str,
+        valueColumnName: str,
+    ) -> "DataFrame":
+        """
+        Unpivot a DataFrame from wide format to long format, optionally leaving
+        identifier columns set. This is the reverse of `groupBy(...).pivot(...).agg(...)`,
+        except for the aggregation, which cannot be reversed.
+
+        This function is useful to massage a DataFrame into a format where some
+        columns are identifier columns ("ids"), while all other columns ("values")
+        are "unpivoted" to the rows, leaving just two non-id columns, named as given
+        by `variableColumnName` and `valueColumnName`.
+
+        When no "id" columns are given, the unpivoted DataFrame consists of only the
+        "variable" and "value" columns.
+
+        All "value" columns must share a least common data type. Unless they are the same
+        data type, all "value" columns are cast to the nearest common data type. For
+        instance, types `IntegerType` and `LongType` are cast to `LongType`, while
+        `IntegerType` and `StringType` do not have a common data type and `unpivot` fails.
+
+        :func:`melt` is an alias for :func:`unpivot`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        ids : str, Column, tuple, list, optional
+            Column(s) to use as identifiers. Can be a single column or column name,
+            or a list or tuple for multiple columns.
+        values : str, Column, tuple, list, optional
+            Column(s) to unpivot. Can be a single column or column name, or a list or tuple
+            for multiple columns. If not specified or empty, uses all columns that
+            are not set as `ids`.
+        variableColumnName : str
+            Name of the variable column.
+        valueColumnName : str
+            Name of the value column.
+
+        Returns
+        -------
+        DataFrame
+            Unpivoted DataFrame.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame(
+        ...     [(1, 11, 1.1), (2, 12, 1.2)],
+        ...     ["id", "int", "double"],
+        ... )
+        >>> df.show()
+        +---+---+------+
+        | id|int|double|
+        +---+---+------+
+        |  1| 11|   1.1|
+        |  2| 12|   1.2|
+        +---+---+------+
+
+        >>> df.unpivot("id", ["int", "double"], "var", "val").show()
+        +---+------+----+
+        | id|   var| val|
+        +---+------+----+
+        |  1|   int|11.0|
+        |  1|double| 1.1|
+        |  2|   int|12.0|
+        |  2|double| 1.2|
+        +---+------+----+
+        """
+
+        def to_jcols(
+            cols: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]]
+        ) -> JavaObject:
+            if cols is None:
+                lst = []
+            elif isinstance(cols, tuple):
+                lst = list(cols)
+            elif isinstance(cols, list):
+                lst = cols
+            else:
+                lst = [cols]
+            return self._jcols(*lst)
+
+        return DataFrame(
+            self._jdf.unpivotWithSeq(
+                to_jcols(ids), to_jcols(values), variableColumnName, valueColumnName
+            ),
+            self.sparkSession,
+        )
+
+    def melt(
+        self,
+        ids: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        values: Optional[Union["ColumnOrName", List["ColumnOrName"], Tuple["ColumnOrName", ...]]],
+        variableColumnName: str,
+        valueColumnName: str,
+    ) -> "DataFrame":
+        """
+        Unpivot a DataFrame from wide format to long format, optionally leaving
+        identifier columns set. This is the reverse of `groupBy(...).pivot(...).agg(...)`,
+        except for the aggregation, which cannot be reversed.
+
+        :func:`melt` is an alias for :func:`unpivot`.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        ids : str, Column, tuple, list, optional
+            Column(s) to use as identifiers. Can be a single column or column name,
+            or a list or tuple for multiple columns.
+        values : str, Column, tuple, list, optional
+            Column(s) to unpivot. Can be a single column or column name, or a list or tuple
+            for multiple columns. If not specified or empty, uses all columns that
+            are not set as `ids`.
+        variableColumnName : str
+            Name of the variable column.
+        valueColumnName : str
+            Name of the value column.
+
+        Returns
+        -------
+        DataFrame
+            Unpivoted DataFrame.
+
+        See Also
+        --------
+        DataFrame.unpivot
+        """
+        return self.unpivot(ids, values, variableColumnName, valueColumnName)
+
     def agg(self, *exprs: Union[Column, Dict[str, str]]) -> "DataFrame":
         """Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy().agg()``).
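For a quick feel of the new API before the tests below, here is a minimal usage sketch. It is not part of the patch; it assumes Spark 3.4.0+ with this commit applied and an active SparkSession named `spark`, and the column names are illustrative:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [(1, 11, 1.1), (2, 12, 1.2)],
        ["id", "int", "double"],
    )

    # Keep "id" as the identifier column; "int" and "double" are unpivoted into
    # (variable, value) pairs and cast to their least common type, double.
    long_df = df.unpivot("id", ["int", "double"], "var", "val")

    # Passing values=None unpivots every column that is not an id column.
    assert long_df.collect() == df.unpivot("id", None, "var", "val").collect()

    # melt is an alias for unpivot.
    assert long_df.collect() == df.melt("id", ["int", "double"], "var", "val").collect()

Note that `ids` and `values` each accept a single column or column name, a list, or a tuple; the `to_jcols` helper above normalizes all of these shapes before the py4j call.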
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 7c7d3d1e51c..987ff91402d 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -534,6 +534,150 @@ class DataFrameTests(ReusedSQLTestCase):
         self.assertEqual(1, logical_plan.toString().count("what"))
         self.assertEqual(3, logical_plan.toString().count("itworks"))

+    def test_unpivot(self):
+        # SPARK-39877: test the DataFrame.unpivot method
+        df = self.spark.createDataFrame(
+            [
+                (1, 10, 1.0, "one"),
+                (2, 20, 2.0, "two"),
+                (3, 30, 3.0, "three"),
+            ],
+            ["id", "int", "double", "str"],
+        )
+
+        with self.subTest(desc="with no identifier and no value columns"):
+            # select only columns that have a common data type (double)
+            actual = df.select("id", "int", "double").unpivot(
+                ids=None, values=None, variableColumnName="var", valueColumnName="val"
+            )
+            self.assertEqual(actual.schema.simpleString(), "struct<var:string,val:double>")
+            self.assertEqual(
+                actual.collect(),
+                [
+                    Row(variable="id", value=1.0),
+                    Row(variable="int", value=10.0),
+                    Row(variable="double", value=1.0),
+                    Row(variable="id", value=2.0),
+                    Row(variable="int", value=20.0),
+                    Row(variable="double", value=2.0),
+                    Row(variable="id", value=3.0),
+                    Row(variable="int", value=30.0),
+                    Row(variable="double", value=3.0),
+                ],
+            )
+
+        with self.subTest(desc="with no identifier column and multiple value columns"):
+            for id in [None, [], ()]:
+                for values in [["int", "double"], ("int", "double")]:
+                    with self.subTest(ids=id, values=values):
+                        actual = df.unpivot(id, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(), "struct<var:string,val:double>"
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(variable="int", value=10.0),
+                                Row(variable="double", value=1.0),
+                                Row(variable="int", value=20.0),
+                                Row(variable="double", value=2.0),
+                                Row(variable="int", value=30.0),
+                                Row(variable="double", value=3.0),
+                            ],
+                        )
+
+        with self.subTest(desc="with single identifier column and multiple value columns"):
+            for id in ["id", ["id"], ("id",)]:
+                for values in [["int", "double"], ("int", "double")]:
+                    with self.subTest(ids=id, values=values):
+                        actual = df.unpivot(id, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(),
+                            "struct<id:bigint,var:string,val:double>",
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(id=1, variable="int", value=10.0),
+                                Row(id=1, variable="double", value=1.0),
+                                Row(id=2, variable="int", value=20.0),
+                                Row(id=2, variable="double", value=2.0),
+                                Row(id=3, variable="int", value=30.0),
+                                Row(id=3, variable="double", value=3.0),
+                            ],
+                        )
+
+        with self.subTest(desc="with multiple identifier columns and a single given value column"):
+            for ids in [["id", "double"], ("id", "double")]:
+                for values in ["str", ["str"], ("str",)]:
+                    with self.subTest(ids=ids, values=values):
+                        actual = df.unpivot(ids, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(),
+                            "struct<id:bigint,double:double,var:string,val:string>",
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(id=1, double=1.0, variable="str", value="one"),
+                                Row(id=2, double=2.0, variable="str", value="two"),
+                                Row(id=3, double=3.0, variable="str", value="three"),
+                            ],
+                        )
+
+        with self.subTest(desc="with multiple identifier columns but no given value columns"):
+            for ids in [["id", "str"], ("id", "str")]:
+                for values in [None, [], ()]:
+                    with self.subTest(ids=ids, values=values):
+                        actual = df.unpivot(ids, values, "var", "val")
+                        self.assertEqual(
+                            actual.schema.simpleString(),
+                            "struct<id:bigint,str:string,var:string,val:double>",
+                        )
+                        self.assertEqual(
+                            actual.collect(),
+                            [
+                                Row(id=1, str="one", variable="int", value=10.0),
+                                Row(id=1, str="one", variable="double", value=1.0),
+                                Row(id=2, str="two", variable="int", value=20.0),
+                                Row(id=2, str="two", variable="double", value=2.0),
+                                Row(id=3, str="three", variable="int", value=30.0),
+                                Row(id=3, str="three", variable="double", value=3.0),
+                            ],
+                        )
+
+        with self.subTest(desc="with value columns without common data type"):
+            with self.assertRaisesRegex(
+                AnalysisException,
+                r"\[UNPIVOT_VALUE_DATA_TYPE_MISMATCH\] Unpivot value columns must share "
+                r"a least common type, some types do not: .*",
+            ):
+                df.unpivot("id", ["int", "str"], "var", "val")
+
+        with self.subTest(desc="with columns"):
+            for id in [df.id, [df.id], (df.id,)]:
+                for values in [[df.int, df.double], (df.int, df.double)]:
+                    with self.subTest(ids=id, values=values):
+                        self.assertEqual(
+                            df.unpivot(id, values, "var", "val").collect(),
+                            df.unpivot("id", ["int", "double"], "var", "val").collect(),
+                        )
+
+        with self.subTest(desc="with column names and columns"):
+            for ids in [[df.id, "str"], (df.id, "str")]:
+                for values in [[df.int, "double"], (df.int, "double")]:
+                    with self.subTest(ids=ids, values=values):
+                        self.assertEqual(
+                            df.unpivot(ids, values, "var", "val").collect(),
+                            df.unpivot(["id", "str"], ["int", "double"], "var", "val").collect(),
+                        )
+
+        with self.subTest(desc="melt alias"):
+            self.assertEqual(
+                df.unpivot("id", ["int", "double"], "var", "val").collect(),
+                df.melt("id", ["int", "double"], "var", "val").collect(),
+            )
+
     def test_observe(self):
         # SPARK-36263: tests the DataFrame.observe(Observation, *Column) method
         from pyspark.sql import Observation

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 4bc337e5af3..3f0cef33b5f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2107,6 +2107,17 @@ class Dataset[T] private[sql](
       valueColumnName: String): DataFrame =
     unpivot(ids, Array.empty, variableColumnName, valueColumnName)

+  /**
+   * Called from Python, as a Seq[Column] is easier to create via py4j than an Array[Column].
+   * The public unpivot API uses Array[Column] rather than Seq[Column] because arrays are
+   * Java-friendly.
+   */
+  private[sql] def unpivotWithSeq(
+      ids: Seq[Column],
+      values: Seq[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids.toArray, values.toArray, variableColumnName, valueColumnName)
+
   /**
    * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
    * set. This is the reverse of `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
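For completeness, here is a sketch of the Python-to-JVM hand-off that `unpivotWithSeq` enables. It is not part of the patch; `_jdf` and `_jcols` are private PySpark internals, shown only to illustrate the call path, and the assertion assumes the same toy DataFrame as in the earlier sketch:

    from pyspark.sql import DataFrame, SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 11, 1.1), (2, 12, 1.2)], ["id", "int", "double"])

    # What users call:
    long_df = df.unpivot("id", ["int", "double"], "var", "val")

    # Roughly what the wrapper does internally: _jcols builds a JVM Seq[Column]
    # via py4j, which unpivotWithSeq converts to the Array[Column] that the
    # public Scala unpivot expects.
    jdf = df._jdf.unpivotWithSeq(
        df._jcols("id"), df._jcols("int", "double"), "var", "val"
    )
    same_df = DataFrame(jdf, df.sparkSession)

    assert long_df.collect() == same_df.collect()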