This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new f42c029fac5c [SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions
f42c029fac5c is described below

commit f42c029fac5c8015d80ad957fae325243a2ed30d
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Mon May 27 22:40:13 2024 -0700

    [SPARK-41049][SQL][FOLLOW-UP] Mark map related expressions as stateful expressions

    ### What changes were proposed in this pull request?

    MapConcat contains state, so it is stateful:

    ```
    private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
    ```

    Similarly, `MapFromEntries`, `CreateMap`, `MapFromArrays`, `StringToMap` and `TransformKeys` need the same change.

    ### Why are the changes needed?

    Stateful expressions should be marked as stateful.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    N/A

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #46721 from amaliujia/statefulexpr.

    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit af1ac1edc2a96c9aba949e3100ddae37b6f0e5b2)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
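For context on why a lazily cached `ArrayBasedMapBuilder` forces this marking, here is a minimal, self-contained Scala sketch. It is illustrative only: `SimpleMapBuilder`, `MapExprSketch`, `begin`/`finish` and `freshCopy` are made-up names, not Spark's actual internals, and the real `stateful` contract lives in Catalyst. The sketch shows the hazard the flag guards against: when a single expression instance holding a mutable builder is evaluated from two plan positions, as in `df.select(v3, v3, ...)` in the test below, state can leak between the evaluations.

```scala
import scala.collection.mutable.ArrayBuffer

// Stand-in for a mutable map builder that an expression caches lazily;
// like a real builder, it rejects duplicate keys when the map is built.
final class SimpleMapBuilder {
  private val keys = ArrayBuffer.empty[String]
  private val values = ArrayBuffer.empty[String]

  def put(k: String, v: String): Unit = { keys += k; values += v }

  def build(): Map[String, String] = {
    require(keys.distinct.size == keys.size, s"duplicate map keys: $keys")
    keys.iterator.zip(values.iterator).toMap
  }
}

// Stand-in for an expression that, like CreateMap, caches one builder
// per instance via a lazy val.
final class MapExprSketch(entries: Seq[(String, String)]) {
  private lazy val builder = new SimpleMapBuilder // per-instance mutable state

  // Evaluation is split into two phases to mimic a parent expression
  // interleaving evaluations of the same child instance.
  def begin(): Unit = entries.foreach { case (k, v) => builder.put(k, v) }
  def finish(): Map[String, String] = builder.build()

  // What marking an expression stateful buys: the engine can request a
  // fresh copy per occurrence instead of sharing one instance.
  def freshCopy(): MapExprSketch = new MapExprSketch(entries)
}

object StatefulSketch extends App {
  val e = new MapExprSketch(Seq("key" -> "value"))

  // One shared instance evaluated from two plan positions (like select(v, v)):
  e.begin()
  e.begin()
  try e.finish()
  catch {
    case ex: IllegalArgumentException =>
      println(s"shared instance leaked state: ${ex.getMessage}")
  }

  // A fresh copy keeps each occurrence's builder isolated:
  val copy = e.freshCopy()
  copy.begin()
  println(copy.finish()) // Map(key -> value)
}
```

The new `DataFrameSuite` test exercises exactly this shape: the same map expression appears twice in one projection, both directly and under `to_csv`, whose `CodegenFallback` path evaluates child expressions interpreted rather than through generated code, which is where a shared builder instance would otherwise be reused.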
---
 .../spark/sql/catalyst/expressions/collectionOperations.scala |  3 +++
 .../spark/sql/catalyst/expressions/complexTypeCreator.scala   |  6 ++++++
 .../spark/sql/catalyst/expressions/higherOrderFunctions.scala |  2 ++
 .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  | 10 +++++++++-
 4 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 3ddbe38fdedf..45896382af67 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -712,6 +712,7 @@ case class MapConcat(children: Seq[Expression])
     }
   }
 
+  override def stateful: Boolean = true
   override def nullable: Boolean = children.exists(_.nullable)
 
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
@@ -827,6 +828,8 @@ case class MapFromEntries(child: Expression)
 
   override def nullable: Boolean = child.nullable || nullEntries
 
+  override def stateful: Boolean = true
+
   @transient override lazy val dataType: MapType = dataTypeDetails.get._1
 
   override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index c95a0987330d..1b6f86984be7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -242,6 +242,8 @@ case class CreateMap(children: Seq[Expression], useStringTypeWhenEmpty: Boolean)
 
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
 
+  override def stateful: Boolean = true
+
   override def eval(input: InternalRow): Any = {
     var i = 0
     while (i < keys.length) {
@@ -317,6 +319,8 @@ case class MapFromArrays(left: Expression, right: Expression)
       valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull)
   }
 
+  override def stateful: Boolean = true
+
   private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType)
 
   override def nullSafeEval(keyArray: Any, valueArray: Any): Any = {
@@ -563,6 +567,8 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E
     this(child, Literal(","), Literal(":"))
   }
 
+  override def stateful: Boolean = true
+
   override def first: Expression = text
   override def second: Expression = pairDelim
   override def third: Expression = keyValueDelim
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index fec1df108bcc..5b10b401af98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -918,6 +918,8 @@ case class TransformKeys(
 
   override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull)
 
+  override def stateful: Boolean = true
+
   override def checkInputDataTypes(): TypeCheckResult = {
     TypeUtils.checkForMapKeyType(function.dataType)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index c586da6105fd..260ecaa5ece1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, CreateMap, EqualTo, ExpressionSet, GreaterThan, Literal, PythonUDF, ScalarSubquery, Uuid}
 import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, LocalRelation, LogicalPlan, OneRowRelation, Statistics}
@@ -3636,6 +3636,14 @@ class DataFrameSuite extends QueryTest
       assert(row.getInt(0).toString == row.getString(2))
       assert(row.getInt(0).toString == row.getString(3))
     }
+
+    val v3 = Column(CreateMap(Seq(Literal("key"), Literal("value"))))
+    val v4 = to_csv(struct(v3.as("a"))) // to_csv is CodegenFallback
+    df.select(v3, v3, v4, v4).collect().foreach { row =>
+      assert(row.getMap(0).toString() == row.getMap(1).toString())
+      assert(row.getString(2) == s"{key -> ${row.getMap(0).get("key").get}}")
+      assert(row.getString(3) == s"{key -> ${row.getMap(0).get("key").get}}")
+    }
   }
 
   test("SPARK-41219: IntegralDivide use decimal(1, 0) to represent 0") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org