peter-toth commented on code in PR #55885:
URL: https://github.com/apache/spark/pull/55885#discussion_r3297542122
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TransformExpression.scala:
##########
@@ -17,46 +17,103 @@
package org.apache.spark.sql.catalyst.expressions
+import scala.util.Try
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.LogKeys.FUNCTION_NAME
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext,
ExprCode}
-import org.apache.spark.sql.connector.catalog.functions.{BoundFunction,
Reducer, ReducibleFunction, ScalarFunction}
+import org.apache.spark.sql.connector.catalog.functions.{BoundFunction,
Reducer, ReducibleFunction, ReducibleParameters, ScalarFunction}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types.{DataType, Decimal, DecimalType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
/**
* Represents a partition transform expression, for instance, `bucket`,
`days`, `years`, etc.
*
* @param function the transform function itself. Spark will use it to decide
whether two
* partition transform expressions are compatible.
- * @param numBucketsOpt the number of buckets if the transform is `bucket`.
Unset otherwise.
*/
-case class TransformExpression(
- function: BoundFunction,
- children: Seq[Expression],
- numBucketsOpt: Option[Int] = None) extends Expression {
+case class TransformExpression(function: BoundFunction, children:
Seq[Expression])
+ extends Expression with Logging {
override def nullable: Boolean = true
/**
- * Whether this [[TransformExpression]] has the same semantics as `other`.
- * For instance, `bucket(32, c)` is equal to `bucket(32, d)`, but not to
`bucket(16, d)` or
- * `year(c)`.
+ * Extract literal children (constant parameters) from this transform. These
are constant
+ * arguments like width in truncate(col, width). Literals are compared when
checking if two
+ * transforms are the same.
+ */
+ private lazy val literalChildren: Seq[Literal] =
+ children.collect { case l: Literal => l }
+
+ /**
+ * Whether this [[TransformExpression]] has the same semantics as `other`.
For instance,
+ * `bucket(32, c)` is equal to `bucket(32, d)`, but not to `bucket(16, d)`
or `year(c)`.
+ * Similarly, `truncate(c, 2)` is equal to `truncate(d, 2)`, but not to
`truncate(c, 4)`.
*
* This will be used, for instance, by Spark to determine whether
storage-partitioned join can
* be triggered, by comparing partition transforms from both sides of the
join and checking
* whether they are compatible.
*
- * @param other the transform expression to compare to
- * @return true if this and `other` has the same semantics w.r.t to
transform, false otherwise.
+ * Two transforms are considered the same if:
+ * 1. They have the same function name
+ * 2. They have the same literal arguments (e.g., numBuckets for bucket,
width for truncate)
+ *
+ * @param other
+ * the transform expression to compare to
+ * @return
+ * true if this and `other` have the same semantics w.r.t. the transform,
false otherwise.
*/
def isSameFunction(other: TransformExpression): Boolean = other match {
- case TransformExpression(otherFunction, _, otherNumBucketsOpt) =>
- function.canonicalName() == otherFunction.canonicalName() &&
- numBucketsOpt == otherNumBucketsOpt
+ case TransformExpression(otherFunction, _) =>
+ val sameFunctionName = function.canonicalName() ==
otherFunction.canonicalName()
+
+ // Compare literal arguments to ensure transforms with different
parameters
+ // (e.g., bucket(32, col) vs bucket(16, col), truncate(col, 2) vs
truncate(col, 4))
+ // are not considered the same
+ val otherLiterals = other.literalChildren
+ val sameLiterals = literalChildren.length == otherLiterals.length &&
+ literalChildren.zip(otherLiterals).forall { case (l1, l2) =>
+ l1.equals(l2)
+ }
+
+ sameFunctionName && sameLiterals
case _ =>
false
}
+ /**
+ * Override canonicalized to ensure transforms with the same function and
literals are
+ * considered semantically equal, regardless of which specific column
references they use.
+ *
+ * This is crucial for Storage Partitioned Joins - we need bucket(4,
tableA.id) and bucket(4,
+ * tableB.id) to be semantically equal so SPJ can be triggered.
+ */
+ override lazy val canonicalized: Expression = {
+ // Canonicalize only the non-literal children (i.e., column references)
+ val canonicalizedReferenceChildren = children.map {
+ case l: Literal => l
+ case other => other.canonicalized
+ }
+ TransformExpression(function, canonicalizedReferenceChildren)
+ }
+
+ /**
+ * Override collectLeaves to only return reference children (columns), not
literal parameters.
+ *
+ * For TransformExpression, literal children are metadata about the
transform function (e.g.,
+ * numBuckets=4 in bucket(4, col), width=2 in truncate(col, 2)). All
consumers of
+ * collectLeaves() expect only column references, not these metadata
literals.
+ *
+ */
+ override def collectLeaves(): Seq[Expression] = {
+ children.flatMap {
Review Comment:
Actually, I now think we don't need a new `Expression.collectAttributes()` at
all; we can use `Expression.references` in SPJ planning instead. I've adjusted
the PR accordingly.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]