This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 55c3347c48f  [SPARK-38864][SQL] Add unpivot / melt to Dataset
55c3347c48f is described below

commit 55c3347c48f93a9c5c5c2fb00b30f838eb081b7f
Author: Enrico Minack <git...@enrico.minack.dev>
AuthorDate: Tue Jul 26 15:50:03 2022 +0800

    [SPARK-38864][SQL] Add unpivot / melt to Dataset

    ### What changes were proposed in this pull request?
    This proposes a Scala implementation of the `melt` (a.k.a. `unpivot`) operation.

    ### Why are the changes needed?
    The Scala Dataset API provides the `pivot` operation, but not its reverse operation `unpivot` (also known as `melt`). The `melt` operation is part of the [Pandas API](https://pandas.pydata.org/docs/reference/api/pandas.melt.html), which is why the PySpark Pandas API provides this method, implemented purely in Python. [It should be implemented in Scala](https://github.com/apache/spark/pull/26912#pullrequestreview-332975715) to make this operation available to Scala / Java, SQL, and PySpark, and to let the PySpark Pandas API reuse the implementation.

    The `melt` / `unpivot` operation exists in other systems such as [BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#unpivot_operator), [T-SQL](https://docs.microsoft.com/en-us/sql/t-sql/queries/from-using-pivot-and-unpivot?view=sql-server-ver15#unpivot-example), and [Oracle](https://www.oracletutorial.com/oracle-basics/oracle-unpivot/).

    It supports expressions for id and value columns, including `*` expansion and structs, so this also fixes / includes SPARK-39292.

    ### Does this PR introduce _any_ user-facing change?
    It adds `unpivot` / `melt` to the `Dataset` API (Scala and Java).

    ### How was this patch tested?
    It is tested in `DatasetUnpivotSuite` and `JavaDatasetSuite`.

    Closes #36150 from EnricoMi/branch-melt.
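For illustration, a minimal, self-contained sketch of the new API, adapted from the scaladoc example added to Dataset.scala in the diff below (assumes a local SparkSession; column names are illustrative):

    val spark = org.apache.spark.sql.SparkSession.builder().master("local").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")

    // one output row per (id, value column) pair; the "int" values are widened
    // to the least common type of "int" and "long", i.e. LongType
    df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
    // +---+--------+-----+
    // | id|variable|value|
    // +---+--------+-----+
    // |  1|     int|   11|
    // |  1|    long|   12|
    // |  2|     int|   21|
    // |  2|    long|   22|
    // +---+--------+-----+

    // melt is an alias for unpivot
    df.melt(Array($"id"), Array($"int", $"long"), "variable", "value").show()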
Authored-by: Enrico Minack <git...@enrico.minack.dev>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 core/src/main/resources/error/error-classes.json   |  12 +
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  41 ++
 .../sql/catalyst/analysis/AnsiTypeCoercion.scala   |   1 +
 .../sql/catalyst/analysis/CheckAnalysis.scala      |   8 +
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  16 +
 .../plans/logical/basicLogicalOperators.scala      |  39 ++
 .../sql/catalyst/rules/RuleIdCollection.scala      |   1 +
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   1 +
 .../spark/sql/errors/QueryCompilationErrors.scala  |  18 +
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 138 +++++-
 .../spark/sql/RelationalGroupedDataset.scala       |  18 +
 .../org/apache/spark/sql/DatasetUnpivotSuite.scala | 543 +++++++++++++++++++++
 .../spark/sql/errors/QueryErrorsSuiteBase.scala    |   3 +-
 13 files changed, 837 insertions(+), 2 deletions(-)

diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index e2a99c1a62e..29ca280719e 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -375,6 +375,18 @@
       "Unable to acquire <requestedBytes> bytes of memory, got <receivedBytes>"
     ]
   },
+  "UNPIVOT_REQUIRES_VALUE_COLUMNS" : {
+    "message" : [
+      "At least one value column needs to be specified for UNPIVOT, all columns specified as ids"
+    ],
+    "sqlState" : "42000"
+  },
+  "UNPIVOT_VALUE_DATA_TYPE_MISMATCH" : {
+    "message" : [
+      "Unpivot value columns must share a least common type, some types do not: [<types>]"
+    ],
+    "sqlState" : "42000"
+  },
   "UNRECOGNIZED_SQL_TYPE" : {
     "message" : [
       "Unrecognized SQL type <typeName>"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f40c260eb6f..a6108c2a3d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -293,6 +293,7 @@ class Analyzer(override val catalogManager: CatalogManager)
       ResolveUpCast ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveUnpivot ::
       ResolveOrdinalInOrderByAndGroupBy ::
       ResolveAggAliasInGroupBy ::
       ResolveMissingReferences ::
@@ -514,6 +515,10 @@ class Analyzer(override val catalogManager: CatalogManager)
           if child.resolved && groupByOpt.isDefined && hasUnresolvedAlias(groupByOpt.get) =>
         Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues, aggregates, child)

+      case up: Unpivot if up.child.resolved &&
+        (hasUnresolvedAlias(up.ids) || hasUnresolvedAlias(up.values)) =>
+        up.copy(ids = assignAliases(up.ids), values = assignAliases(up.values))
+
       case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
         Project(assignAliases(projectList), child)

@@ -859,6 +864,36 @@ class Analyzer(override val catalogManager: CatalogManager)
     }
   }

+  object ResolveUnpivot extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+      _.containsPattern(UNPIVOT), ruleId) {
+
+      // once children and ids are resolved, we can determine values, if none were given
+      case up: Unpivot if up.childrenResolved && up.ids.forall(_.resolved) && up.values.isEmpty =>
+        up.copy(values = up.child.output.diff(up.ids))
+
+      case up: Unpivot if !up.childrenResolved || !up.ids.forall(_.resolved) ||
+        up.values.isEmpty || !up.values.forall(_.resolved) || !up.valuesTypeCoercioned => up
+
+      // TypeCoercionBase.UnpivotCoercion determines valueType
+      // and casts values once values are set and resolved
+      case Unpivot(ids, values, variableColumnName, valueColumnName, child) =>
+        // construct unpivot expressions for Expand
+        val exprs: Seq[Seq[Expression]] = values.map {
+          value => ids ++ Seq(Literal(value.name), value)
+        }
+
+        // construct output attributes
+        val output = ids.map(_.toAttribute) ++ Seq(
+          AttributeReference(variableColumnName, StringType, nullable = false)(),
+          AttributeReference(valueColumnName, values.head.dataType, values.exists(_.nullable))()
+        )
+
+        // expand the unpivot expressions
+        Expand(exprs, output, child)
+    }
+  }
+
   private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty
   private def isReferredTempViewName(nameParts: Seq[String]): Boolean = {
     AnalysisContext.get.referredTempViewNames.exists { n =>
@@ -1349,6 +1384,12 @@ class Analyzer(override val catalogManager: CatalogManager)
       case g: Generate if containsStar(g.generator.children) =>
         throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF",
           extractStar(g.generator.children))
+      // If the Unpivot ids or values contain Stars, expand them.
+      case up: Unpivot if containsStar(up.ids) || containsStar(up.values) =>
+        up.copy(
+          ids = buildExpandedProjectList(up.ids, up.child),
+          values = buildExpandedProjectList(up.values, up.child)
+        )

       case u @ Union(children, _, _)
         // if there are duplicate output columns, give them unique expr ids
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index fd3885fe834..56dbb2a8590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -74,6 +74,7 @@ import org.apache.spark.sql.types._
  */
 object AnsiTypeCoercion extends TypeCoercionBase {
   override def typeCoercionRules: List[Rule[LogicalPlan]] =
+    UnpivotCoercion ::
     WidenSetOperationTypes ::
     new AnsiCombinedTypeCoercionRule(
       InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index cf734b7aa26..3f5b535b947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -422,6 +422,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
           }
           metrics.foreach(m => checkMetric(m, m))

+        // see Analyzer.ResolveUnpivot
+        case up: Unpivot
+          if up.childrenResolved && up.ids.forall(_.resolved) && up.values.isEmpty =>
+          throw QueryCompilationErrors.unpivotRequiresValueColumns()
+        // see TypeCoercionBase.UnpivotCoercion
+        case up: Unpivot if !up.valuesTypeCoercioned =>
+          throw QueryCompilationErrors.unpivotValDataTypeMismatchError(up.values)
+
         case Sort(orders, _, _) =>
           orders.foreach { order =>
             if (!RowOrdering.isOrderable(order.dataType)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c3db4787eca..4e66c87f361 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -198,6 +198,21 @@ abstract class TypeCoercionBase {
     }
   }

+  /**
+   * Widens the data types of the [[Unpivot]] values.
+   */
+  object UnpivotCoercion extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case up: Unpivot
+          if up.values.nonEmpty && up.values.forall(_.resolved) && !up.valuesTypeCoercioned =>
+        val valueDataType = findWiderTypeWithoutStringPromotion(up.values.map(_.dataType))
+        val values = valueDataType.map(valueType =>
+          up.values.map(value => Alias(Cast(value, valueType), value.name)())
+        ).getOrElse(up.values)
+        up.copy(values = values)
+    }
+  }
+
   /**
    * Widens the data types of the children of Union/Except/Intersect.
    * 1. When ANSI mode is off:
@@ -806,6 +821,7 @@ abstract class TypeCoercionBase {

 object TypeCoercion extends TypeCoercionBase {
   override def typeCoercionRules: List[Rule[LogicalPlan]] =
+    UnpivotCoercion ::
     WidenSetOperationTypes ::
     new CombinedTypeCoercionRule(
       InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index bdc7bf9bd7d..22134a06288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1354,6 +1354,45 @@ case class Pivot(
   override protected def withNewChildInternal(newChild: LogicalPlan): Pivot =
     copy(child = newChild)
 }

+/**
+ * A constructor for creating an Unpivot, which will later be converted to an [[Expand]]
+ * during the query analysis.
+ *
+ * An empty values array will be replaced during analysis with all resolved outputs of child
+ * except the ids. This expansion makes it easy to unpivot all non-id columns.
+ *
+ * @see `org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveUnpivot`
+ *
+ * The type of the value column is derived from all value columns during analysis once all values
+ * are resolved. All values' types have to be compatible, otherwise the individual values cannot
+ * be assigned to the result value column and an AnalysisException is thrown.
+ *
+ * @see `org.apache.spark.sql.catalyst.analysis.TypeCoercionBase.UnpivotCoercion`
+ *
+ * @param ids Id columns
+ * @param values Value columns to unpivot
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ * @param child Child operator
+ */
+case class Unpivot(
+    ids: Seq[NamedExpression],
+    values: Seq[NamedExpression],
+    variableColumnName: String,
+    valueColumnName: String,
+    child: LogicalPlan) extends UnaryNode {
+  override lazy val resolved = false  // Unpivot will be replaced after being resolved.
+  override def output: Seq[Attribute] = Nil
+  override def metadataOutput: Seq[Attribute] = Nil
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNPIVOT)
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): Unpivot =
+    copy(child = newChild)
+
+  def valuesTypeCoercioned: Boolean = values.nonEmpty && values.forall(_.resolved) &&
+    values.tail.forall(v => v.dataType.sameType(values.head.dataType))
+}
+
 /**
  * A constructor for creating a logical limit, which is split into two separate logical nodes:
  * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]].
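For intuition, the row expansion that ResolveUnpivot performs (and that the Unpivot node above documents) can be modeled over plain Scala collections — a simplified sketch only, where a row as a name-to-value map stands in for Catalyst expressions, not the actual implementation:

    // Mirrors `values.map { value => ids ++ Seq(Literal(value.name), value) }`:
    // every value column yields one output row of (ids..., column name, column value).
    def unpivotRows(
        rows: Seq[Map[String, Any]],
        ids: Seq[String],
        values: Seq[String]): Seq[Seq[Any]] =
      rows.flatMap { row =>
        values.map(value => ids.map(row) ++ Seq(value, row(value)))
      }

    unpivotRows(
      rows = Seq(Map("id" -> 1, "int" -> 11, "long" -> 12L)),
      ids = Seq("id"),
      values = Seq("int", "long"))
    // => Seq(Seq(1, "int", 11), Seq(1, "long", 12L))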
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 2f118db8248..eda6ff60e61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -71,6 +71,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTables" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTempViews" ::
+      "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUnpivot" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUserSpecifiedColumns" ::
       "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 93273b5a2c7..3342f11a0fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -87,6 +87,7 @@ object TreePattern extends Enumeration {
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
   val UNARY_POSITIVE: Value = Value
+  val UNPIVOT: Value = Value
   val UPDATE_FIELDS: Value = Value
   val UPPER_OR_LOWER: Value = Value
   val UP_CAST: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index c828318f2cd..c344c64997f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -92,6 +92,24 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
         pivotVal.toString, pivotVal.dataType.simpleString, pivotCol.dataType.catalogString))
   }

+  def unpivotRequiresValueColumns(): Throwable = {
+    new AnalysisException(
+      errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+      messageParameters = Array.empty)
+  }
+
+  def unpivotValDataTypeMismatchError(values: Seq[NamedExpression]): Throwable = {
+    val dataTypes = values
+      .groupBy(_.dataType)
+      .mapValues(values => values.map(value => toSQLId(value.toString)))
+      .mapValues(values => if (values.length > 3) values.take(3) :+ "..." else values)
+      .map { case (dataType, values) => s"${toSQLType(dataType)} (${values.mkString(", ")})" }
+
+    new AnalysisException(
+      errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+      messageParameters = Array(dataTypes.mkString(", ")))
+  }
+
   def unsupportedIfNotExistsError(tableName: String): Throwable = {
     new AnalysisException(
       errorClass = "UNSUPPORTED_FEATURE",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bc0b37e5923..49b4a8389f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1065,7 +1065,7 @@ class Dataset[T] private[sql](
    * @param joinType Type of join to perform. Default `inner`. Must be one of:
    *                 `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
    *                 `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
-   *                 `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, left_anti`.
+   *                 `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
    *
    * @note If you perform a self-join using this function without aliasing the input
    * `DataFrame`s, you will NOT be able to reference any columns after the join, since
@@ -2036,6 +2036,142 @@ class Dataset[T] private[sql](
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*)

+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed.
+   *
+   * This function is useful to massage a DataFrame into a format where some
+   * columns are identifier columns ("ids"), while all other columns ("values")
+   * are "unpivoted" to the rows, leaving just two non-id columns, named as given
+   * by `variableColumnName` and `valueColumnName`.
+   *
+   * {{{
+   *   val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")
+   *   df.show()
+   *   // output:
+   *   // +---+---+----+
+   *   // | id|int|long|
+   *   // +---+---+----+
+   *   // |  1| 11|  12|
+   *   // |  2| 21|  22|
+   *   // +---+---+----+
+   *
+   *   df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
+   *   // output:
+   *   // +---+--------+-----+
+   *   // | id|variable|value|
+   *   // +---+--------+-----+
+   *   // |  1|     int|   11|
+   *   // |  1|    long|   12|
+   *   // |  2|     int|   21|
+   *   // |  2|    long|   22|
+   *   // +---+--------+-----+
+   *   // schema:
+   *   //root
+   *   // |-- id: integer (nullable = false)
+   *   // |-- variable: string (nullable = false)
+   *   // |-- value: long (nullable = true)
+   * }}}
+   *
+   * When no "id" columns are given, the unpivoted DataFrame consists of only the
+   * "variable" and "value" columns.
+   *
+   * All "value" columns must share a least common data type. Unless they are the same data type,
+   * all "value" columns are cast to the nearest common data type. For instance, types
+   * `IntegerType` and `LongType` are cast to `LongType`, while `IntegerType` and `StringType`
+   * do not have a common data type and `unpivot` fails.
+   *
+   * @param ids Id columns
+   * @param values Value columns to unpivot
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def unpivot(
+      ids: Array[Column],
+      values: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame = withPlan {
+    Unpivot(
+      ids.map(_.named),
+      values.map(_.named),
+      variableColumnName,
+      valueColumnName,
+      logicalPlan
+    )
+  }
+
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed.
+   *
+   * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+   *
+   * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
+   * where `values` is set to all non-id columns that exist in the DataFrame.
+   *
+   * @param ids Id columns
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def unpivot(
+      ids: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids, Array.empty, variableColumnName, valueColumnName)
+
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed. This is an alias for `unpivot`.
+   *
+   * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+   *
+   * @param ids Id columns
+   * @param values Value columns to unpivot
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def melt(
+      ids: Array[Column],
+      values: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids, values, variableColumnName, valueColumnName)
+
+  /**
+   * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns
+   * set. This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+   * which cannot be reversed. This is an alias for `unpivot`.
+   *
+   * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+   *
+   * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
+   * where `values` is set to all non-id columns that exist in the DataFrame.
+   *
+   * @param ids Id columns
+   * @param variableColumnName Name of the variable column
+   * @param valueColumnName Name of the value column
+   *
+   * @group untypedrel
+   * @since 3.4.0
+   */
+  def melt(
+      ids: Array[Column],
+      variableColumnName: String,
+      valueColumnName: String): DataFrame =
+    unpivot(ids, variableColumnName, valueColumnName)
+
   /**
    * Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
    * that returns the same result as the input, with the following guarantees:
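The pivot/unpivot round trip that the @see notes in the following RelationalGroupedDataset diff refer to can be sketched like this — data and column names are illustrative, mirroring the "unpivot after pivot" test in the new suite:

    val sales = Seq(
      (2012, "dotNET", 10000), (2012, "Java", 20000),
      (2013, "dotNET", 48000), (2013, "Java", 30000)
    ).toDF("year", "course", "earnings")

    // pivot: one column per course value
    val pivoted = sales.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")

    // unpivot reverses the reshaping, but not the aggregation; with no value
    // columns given, all non-id columns ("dotNET", "Java") are unpivoted
    val unpivoted = pivoted.unpivot(Array($"year"), "course", "earnings")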
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7e3c6221961..989ee325218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -343,6 +343,9 @@ class RelationalGroupedDataset protected[sql](
    *   df.groupBy("year").pivot("course").sum("earnings")
    * }}}
    *
+   * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+   *      except for the aggregation.
+   *
    * @param pivotColumn Name of the column to pivot.
    * @since 1.6.0
    */
@@ -371,6 +374,9 @@ class RelationalGroupedDataset protected[sql](
    *     .agg(sum($"earnings"))
    * }}}
    *
+   * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+   *      except for the aggregation.
+   *
    * @param pivotColumn Name of the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
    * @since 1.6.0
@@ -395,6 +401,9 @@ class RelationalGroupedDataset protected[sql](
    *   df.groupBy("year").pivot("course").sum("earnings");
    * }}}
    *
+   * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+   *      except for the aggregation.
+   *
    * @param pivotColumn Name of the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
    * @since 1.6.0
@@ -412,6 +421,9 @@ class RelationalGroupedDataset protected[sql](
    *   df.groupBy($"year").pivot($"course").sum($"earnings");
    * }}}
    *
+   * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+   *      except for the aggregation.
+   *
    * @param pivotColumn the column to pivot.
    * @since 2.4.0
    */
@@ -444,6 +456,9 @@ class RelationalGroupedDataset protected[sql](
    *   df.groupBy($"year").pivot($"course", Seq("dotNET", "Java")).sum($"earnings")
    * }}}
    *
+   * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+   *      except for the aggregation.
+   *
    * @param pivotColumn the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
    * @since 2.4.0
@@ -477,6 +492,9 @@ class RelationalGroupedDataset protected[sql](
    * aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
    * the `String` type.
    *
+   * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+   *      except for the aggregation.
+   *
    * @param pivotColumn the column to pivot.
    * @param values List of values that will be translated to columns in the output DataFrame.
    * @since 2.4.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
new file mode 100644
index 00000000000..8ccad457e8d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
@@ -0,0 +1,543 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.errors.QueryErrorsSuiteBase
+import org.apache.spark.sql.functions.{length, struct, sum}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+/**
+ * Comprehensive tests for Dataset.unpivot.
+ */
+class DatasetUnpivotSuite extends QueryTest
+  with QueryErrorsSuiteBase
+  with SharedSparkSession {
+  import testImplicits._
+
+  lazy val wideDataDs: Dataset[WideData] = Seq(
+    WideData(1, "one", "One", Some(1), Some(1L)),
+    WideData(2, "two", null, None, Some(2L)),
+    WideData(3, null, "three", Some(3), None),
+    WideData(4, null, null, None, None)
+  ).toDS()
+
+  val longDataRows = Seq(
+    Row(1, "str1", "one"),
+    Row(1, "str2", "One"),
+    Row(2, "str1", "two"),
+    Row(2, "str2", null),
+    Row(3, "str1", null),
+    Row(3, "str2", "three"),
+    Row(4, "str1", null),
+    Row(4, "str2", null)
+  )
+
+  val longDataWithoutIdRows: Seq[Row] =
+    longDataRows.map(row => Row(row.getString(1), row.getString(2)))
+
+  val longSchema: StructType = StructType(Seq(
+    StructField("id", IntegerType, nullable = false),
+    StructField("var", StringType, nullable = false),
+    StructField("val", StringType, nullable = true)
+  ))
+
+  lazy val wideStructDataDs: DataFrame = wideDataDs.select(
+    struct($"id").as("an"),
+    struct(
+      $"str1".as("one"),
+      $"str2".as("two")
+    ).as("str")
+  )
+
+  val longStructDataRows: Seq[Row] = longDataRows.map(row =>
+    Row(
+      row.getInt(0),
+      row.getString(1) match {
+        case "str1" => "one"
+        case "str2" => "two"
+      },
+      row.getString(2))
+  )
+
+  test("overloaded unpivot without values") {
+    val ds = wideDataDs.select($"id", $"str1", $"str2")
+    checkAnswer(
+      ds.unpivot(Array($"id"), "var", "val"),
+      ds.unpivot(Array($"id"), Array.empty, "var", "val"))
+  }
+
+  test("unpivot with single id") {
+    val unpivoted = wideDataDs
+      .unpivot(
+        Array($"id"),
+        Array($"str1", $"str2"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted.schema === longSchema)
+    checkAnswer(unpivoted, longDataRows)
+  }
+
+  test("unpivot with two ids") {
+    val unpivotedRows = Seq(
+      Row(1, 1, "str1", "one"),
+      Row(1, 1, "str2", "One"),
+      Row(2, null, "str1", "two"),
+      Row(2, null, "str2", null),
+      Row(3, 3, "str1", null),
+      Row(3, 3, "str2", "three"),
+      Row(4, null, "str1", null),
+      Row(4, null, "str2", null))
+
+    val unpivoted = wideDataDs
+      .unpivot(
+        Array($"id", $"int1"),
+        Array($"str1", $"str2"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("id", IntegerType, nullable = false),
+      StructField("int1", IntegerType, nullable = true),
+      StructField("var", StringType, nullable = false),
+      StructField("val", StringType, nullable = true))))
+    checkAnswer(unpivoted, unpivotedRows)
+  }
+
+  test("unpivot without ids") {
+    val unpivoted = wideDataDs
+      .unpivot(
+        Array.empty,
+        Array($"str1", $"str2"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("var", StringType, nullable = false),
+      StructField("val", StringType, nullable = true))))
+    checkAnswer(unpivoted, longDataWithoutIdRows)
+  }
+
+  test("unpivot without values") {
+    val unpivoted = wideDataDs.select($"id", $"str1", $"str2")
+      .unpivot(
+        Array($"id"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted.schema === longSchema)
+    checkAnswer(unpivoted, longDataRows)
+
+    val unpivoted2 = wideDataDs.select($"id", $"str1", $"str2")
+      .unpivot(
+        Array($"id"),
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted2.schema === longSchema)
+    checkAnswer(unpivoted2, longDataRows)
+
+    val unpivotedRows = Seq(
+      Row(1, "id", 1L),
+      Row(1, "int1", 1L),
+      Row(1, "long1", 1L),
+      Row(2, "id", 2L),
+      Row(2, "int1", null),
+      Row(2, "long1", 2L),
+      Row(3, "id", 3L),
+      Row(3, "int1", 3L),
+      Row(3, "long1", null),
+      Row(4, "id", 4L),
+      Row(4, "int1", null),
+      Row(4, "long1", null)
+    )
+
+    val unpivoted3 = wideDataDs.select($"id", $"int1", $"long1")
+      .unpivot(
+        Array($"id" * 2),
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted3.schema === StructType(Seq(
+      StructField("(id * 2)", IntegerType, nullable = false),
+      StructField("var", StringType, nullable = false),
+      StructField("val", LongType, nullable = true)
+    )))
+    checkAnswer(unpivoted3, unpivotedRows.map(row =>
+      Row(row.getInt(0) * 2, row.get(1), row.get(2))))
+
+    val unpivoted4 = wideDataDs.select($"id", $"int1", $"long1")
+      .unpivot(
+        Array($"id".as("uid")),
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted4.schema === StructType(Seq(
+      StructField("uid", IntegerType, nullable = false),
+      StructField("var", StringType, nullable = false),
+      StructField("val", LongType, nullable = true)
+    )))
+    checkAnswer(unpivoted4, unpivotedRows)
+  }
+
+  test("unpivot without ids or values") {
+    val unpivoted = wideDataDs.select($"str1", $"str2")
+      .unpivot(
+        Array.empty,
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("var", StringType, nullable = false),
+      StructField("val", StringType, nullable = true))))
+    checkAnswer(unpivoted, longDataWithoutIdRows)
+  }
+
+  test("unpivot with star values") {
+    val unpivoted = wideDataDs.select($"str1", $"str2")
+      .unpivot(
+        Array.empty,
+        Array($"*"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("var", StringType, nullable = false),
+      StructField("val", StringType, nullable = true))))
+    checkAnswer(unpivoted, longDataWithoutIdRows)
+  }
+
+  test("unpivot with id and star values") {
+    val unpivoted = wideDataDs.select($"id", $"int1", $"long1")
+      .unpivot(
+        Array($"id"),
+        Array($"*"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("id", IntegerType, nullable = false),
+      StructField("var", StringType, nullable = false),
+      StructField("val", LongType, nullable = true))))
+
+    checkAnswer(unpivoted, wideDataDs.collect().flatMap { row => Seq(
+      Row(row.id, "id", row.id),
+      Row(row.id, "int1", row.int1.orNull),
+      Row(row.id, "long1", row.long1.orNull)
+    )})
+  }
+
+  test("unpivot with expressions") {
+    // ids and values are all expressions (computed)
+    val unpivoted = wideDataDs
+      .unpivot(
+        Array(($"id" * 10).as("primary"), $"str1".as("secondary")),
+        Array(($"int1" + $"long1").as("sum"), length($"str2").as("len")),
+        variableColumnName = "var",
+        valueColumnName = "val")
+
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("primary", IntegerType, nullable = false),
+      StructField("secondary", StringType, nullable = true),
+      StructField("var", StringType, nullable = false),
+      StructField("val", LongType, nullable = true))))
+
+    checkAnswer(unpivoted, wideDataDs.collect().flatMap { row =>
+      Seq(
+        Row(
+          row.id * 10,
+          row.str1,
+          "sum",
+          // sum of int1 and long1 when both are set, or null otherwise
+          row.int1.flatMap(i => row.long1.map(l => i + l)).orNull),
+        Row(
+          row.id * 10,
+          row.str1,
+          "len",
+          // length of str2 if set, or null otherwise
+          Option(row.str2).map(_.length).orNull)
+      )
+    })
+  }
+
+  test("unpivot with variable / value columns") {
+    // with value column `variable` and `value`
+    val unpivoted = wideDataDs
+      .withColumnRenamed("str1", "var")
+      .withColumnRenamed("str2", "val")
+      .unpivot(
+        Array($"id"),
+        Array($"var", $"val"),
+        variableColumnName = "var",
+        valueColumnName = "val")
+    checkAnswer(unpivoted, longDataRows.map(row => Row(
+      row.getInt(0),
+      row.getString(1) match {
+        case "str1" => "var"
+        case "str2" => "val"
+      },
+      row.getString(2))))
+  }
+
+  test("unpivot with incompatible value types") {
+    val e = intercept[AnalysisException] {
+      wideDataDs
+        .select(
+          $"id",
+          $"str1",
+          $"int1", $"int1".as("int2"), $"int1".as("int3"), $"int1".as("int4"),
+          $"long1", $"long1".as("long2")
+        )
+        .unpivot(
+          Array($"id"),
+          Array(),
+          variableColumnName = "var",
+          valueColumnName = "val"
+        )
+    }
+    checkErrorClass(
+      exception = e,
+      errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+      msg = "Unpivot value columns must share a least common type, some types do not: \\[" +
+        "\"STRING\" \\(`str1#\\d+`\\), " +
+        "\"INT\" \\(`int1#\\d+`, `int2#\\d+`, `int3#\\d+`, ...\\), " +
+        "\"BIGINT\" \\(`long1#\\d+L`, `long2#\\d+L`\\)\\];(\n.*)*",
+      matchMsg = true)
+  }
+
+  test("unpivot with compatible value types") {
+    val unpivoted = wideDataDs.unpivot(
+      Array($"id"),
+      Array($"int1", $"long1"),
+      variableColumnName = "var",
+      valueColumnName = "val")
+    assert(unpivoted.schema === StructType(Seq(
+      StructField("id", IntegerType, nullable = false),
+      StructField("var", StringType, nullable = false),
+      StructField("val", LongType, nullable = true)
+    )))
+
+    val unpivotedRows = Seq(
+      Row(1, "int1", 1L),
+      Row(1, "long1", 1L),
+      Row(2, "int1", null),
+      Row(2, "long1", 2L),
+      Row(3, "int1", 3L),
+      Row(3, "long1", null),
+      Row(4, "int1", null),
+      Row(4, "long1", null)
+    )
+    checkAnswer(unpivoted, unpivotedRows)
+  }
+
+  test("unpivot and drop nulls") {
+    checkAnswer(
+      wideDataDs
+        .unpivot(Array($"id"), Array($"str1", $"str2"), "var", "val")
+        .where($"val".isNotNull),
+      longDataRows.filter(_.getString(2) != null))
+  }
+
+  test("unpivot with invalid arguments") {
+    // unpivoting where id column does not exist
+    val e1 = intercept[AnalysisException] {
+      wideDataDs.unpivot(
+        Array($"1", $"2"),
+        Array($"str1", $"str2"),
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e1,
+      errorClass = "UNRESOLVED_COLUMN",
+      msg = "A column or function parameter with name `1` cannot be resolved\\. " +
+        "Did you mean one of the following\\? \\[`id`, `int1`, `str1`, `str2`, `long1`\\];(\n.*)*",
+      matchMsg = true)
+
+    // unpivoting where value column does not exist
+    val e2 = intercept[AnalysisException] {
+      wideDataDs.unpivot(
+        Array($"id"),
+        Array($"does", $"not", $"exist"),
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e2,
+      errorClass = "UNRESOLVED_COLUMN",
+      msg = "A column or function parameter with name `does` cannot be resolved\\. " +
+        "Did you mean one of the following\\? \\[`id`, `int1`, `long1`, `str1`, `str2`\\];(\n.*)*",
+      matchMsg = true)
+
+    // unpivoting with empty list of value columns
+    // where potential value columns are of incompatible types
+    val e3 = intercept[AnalysisException] {
+      wideDataDs.unpivot(
+        Array.empty,
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e3,
+      errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+      msg = "Unpivot value columns must share a least common type, some types do not: \\[" +
+        "\"INT\" \\(`id#\\d+`, `int1#\\d+`\\), " +
+        "\"STRING\" \\(`str1#\\d+`, `str2#\\d+`\\), " +
+        "\"BIGINT\" \\(`long1#\\d+L`\\)\\];(\n.*)*",
+      matchMsg = true)
+
+    // unpivoting with star id columns so that no value columns are left
+    val e4 = intercept[AnalysisException] {
+      wideDataDs.unpivot(
+        Array($"*"),
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e4,
+      errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+      msg = "At least one value column needs to be specified for UNPIVOT, " +
+        "all columns specified as ids;(\\n.*)*",
+      matchMsg = true)
+
+    // unpivoting with star value columns
+    // where potential value columns are of incompatible types
+    val e5 = intercept[AnalysisException] {
+      wideDataDs.unpivot(
+        Array.empty,
+        Array($"*"),
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e5,
+      errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+      msg = "Unpivot value columns must share a least common type, some types do not: \\[" +
+        "\"INT\" \\(`id#\\d+`, `int1#\\d+`\\), " +
+        "\"STRING\" \\(`str1#\\d+`, `str2#\\d+`\\), " +
+        "\"BIGINT\" \\(`long1#\\d+L`\\)\\];(\n.*)*",
+      matchMsg = true)
+
+    // unpivoting without giving values and no non-id columns
+    val e6 = intercept[AnalysisException] {
+      wideDataDs.select($"id", $"str1", $"str2").unpivot(
+        Array($"id", $"str1", $"str2"),
+        Array.empty,
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e6,
+      errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+      msg = "At least one value column needs to be specified for UNPIVOT, " +
+        "all columns specified as ids;(\\n.*)*",
+      matchMsg = true)
+  }
+
+  test("unpivot after pivot") {
+    // see test "pivot courses" in DataFramePivotSuite
+    val pivoted = courseSales.groupBy("year").pivot("course", Array("dotNET", "Java"))
+      .agg(sum($"earnings"))
+    val unpivoted = pivoted.unpivot(Array($"year"), "course", "earnings")
+    val expected = courseSales.groupBy("year", "course").sum("earnings")
+    checkAnswer(unpivoted, expected)
+  }
+
+  test("unpivot of unpivot") {
+    checkAnswer(
+      wideDataDs
+        .unpivot(Array($"id"), Array($"str1", $"str2"), "var", "val")
+        .unpivot(Array($"id"), Array($"var", $"val"), "col", "value"),
+      longDataRows.flatMap(row => Seq(
+        Row(row.getInt(0), "var", row.getString(1)),
+        Row(row.getInt(0), "val", row.getString(2)))))
+  }
+
+  test("unpivot with dot and backtick") {
+    val ds = wideDataDs
+      .withColumnRenamed("id", "an.id")
+      .withColumnRenamed("str1", "str.one")
+      .withColumnRenamed("str2", "str.two")
+
+    val unpivoted = ds.unpivot(
+      Array($"`an.id`"),
+      Array($"`str.one`", $"`str.two`"),
+      variableColumnName = "var",
+      valueColumnName = "val")
+    checkAnswer(unpivoted, longDataRows.map(row => Row(
+      row.getInt(0),
+      row.getString(1) match {
+        case "str1" => "str.one"
+        case "str2" => "str.two"
+      },
+      row.getString(2))))
+
+    // without backticks, this references struct fields, which do not exist
+    val e = intercept[AnalysisException] {
+      ds.unpivot(
+        Array($"an.id"),
+        Array($"str.one", $"str.two"),
+        variableColumnName = "var",
+        valueColumnName = "val"
+      )
+    }
+    checkErrorClass(
+      exception = e,
+      errorClass = "UNRESOLVED_COLUMN",
+      // expected message is wrong: https://issues.apache.org/jira/browse/SPARK-39783
+      msg = "A column or function parameter with name `an`\\.`id` cannot be resolved\\. " +
+        "Did you mean one of the following\\? " +
+        "\\[`an`.`id`, `int1`, `long1`, `str`.`one`, `str`.`two`\\];(\n.*)*",
+      matchMsg = true)
+  }
+
+  test("unpivot with struct fields") {
+    checkAnswer(
+      wideStructDataDs.unpivot(
+        Array($"an.id"),
+        Array($"str.one", $"str.two"),
+        "var",
+        "val"),
+      longStructDataRows)
+  }
+
+  test("unpivot with struct ids star") {
+    checkAnswer(
+      wideStructDataDs.unpivot(
+        Array($"an.*"),
+        Array($"str.one", $"str.two"),
+        "var",
+        "val"),
+      longStructDataRows)
+  }
+
+  test("unpivot with struct values star") {
+    checkAnswer(
+      wideStructDataDs.unpivot(
+        Array($"an.id"),
+        Array($"str.*"),
+        "var",
+        "val"),
+      longStructDataRows)
+  }
+}
+
+case class WideData(id: Int, str1: String, str2: String, int1: Option[Int], long1: Option[Long])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala
index 895a72efeec..d78a6a91959 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala
@@ -37,7 +37,8 @@ trait QueryErrorsSuiteBase extends SharedSparkSession {
       errorClass
     }
     if (matchMsg) {
-      assert(exception.getMessage.matches(s"""\\[$fullErrorClass\\] """ + msg))
+      assert(exception.getMessage.matches(s"""\\[$fullErrorClass\\] """ + msg),
+        "exception is: " + exception.getMessage)
     } else {
       assert(exception.getMessage === s"""[$fullErrorClass] """ + msg)
     }
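In short, the value-type coercion the suite exercises behaves as in this sketch (hypothetical data; Option columns keep the value column nullable, matching the "unpivot with compatible value types" and "unpivot with incompatible value types" tests above):

    val df = Seq((1, Some(2), Some(3L), "x")).toDF("id", "i", "l", "s")

    // IntegerType and LongType widen to LongType (no string promotion is applied)
    df.unpivot(Array($"id"), Array($"i", $"l"), "var", "val").printSchema()
    // root
    //  |-- id: integer (nullable = false)
    //  |-- var: string (nullable = false)
    //  |-- val: long (nullable = true)

    // IntegerType and StringType share no least common type: analysis fails
    // with error class UNPIVOT_VALUE_DATA_TYPE_MISMATCH
    df.unpivot(Array($"id"), Array($"i", $"s"), "var", "val")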
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org