jinchengchenghh commented on code in PR #12000:
URL: https://github.com/apache/gluten/pull/12000#discussion_r3172863648
##########
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala:
##########
@@ -40,4 +46,131 @@ case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
case Validator.Passed =>
}
}
+
+ /**
+ * Traverses the plan tree looking for join nodes (SortMergeJoin,
ShuffledHashJoin,
+ * BroadcastHashJoin) whose join keys include at least one decimal column.
+ *
+ * For each such join, delegates to [[setFallbackTagForOtherSide]] to ensure
that if one side's
+ * scan ([[FileSourceScanExec]] or `HiveTableScanExec`) cannot be offloaded
to the native engine,
+ * the other side is also forced to fall back. This prevents a decimal-value
mismatch that would
+ * produce incorrect (typically empty) join results when one side applies
Spark's precision
+ * coercion and the other side reads raw native values.
+ *
+ * AdaptiveSparkPlanExec is handled by descending into its `initialPlan`;
all other non-join nodes
+ * are handled recursively through their children.
+ */
+ private def validateJoin(plan: SparkPlan): Unit =
+ plan match {
+ case smj: SortMergeJoinExec
+ if (smj.leftKeys ++
smj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+ setFallbackTagForOtherSide(smj.left, smj.right)
+ case shj: ShuffledHashJoinExec
+ if (shj.leftKeys ++
shj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+ setFallbackTagForOtherSide(shj.left, shj.right)
+ case bhj: BroadcastHashJoinExec
+ if (bhj.leftKeys ++
bhj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+ setFallbackTagForOtherSide(bhj.left, bhj.right)
+ case a: AdaptiveSparkPlanExec =>
+ validateJoin(a.initialPlan)
+ case _ => plan.children.foreach(validateJoin(_))
+ }
+
+ /**
+ * Enforces symmetric scan fallback for the two sides of a decimal-key join.
+ *
+ * When the join key is a decimal type, a native (Velox) scan and a vanilla
Spark scan
+ * ([[FileSourceScanExec]] or `HiveTableScanExec`) may produce different
representations of the
+ * same decimal value: the native reader may surface raw uncoerced int128_t
values while the
Review Comment:
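Can we fix this on the native side instead, so that the native reader applies the same decimal precision coercion as Spark (rather than forcing both sides of the join to fall back)?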
##########
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala:
##########
@@ -40,4 +46,131 @@ case class AddFallbackTags(validator: Validator) extends Rule[SparkPlan] {
case Validator.Passed =>
}
}
+
+ /**
+ * Traverses the plan tree looking for join nodes (SortMergeJoin,
ShuffledHashJoin,
+ * BroadcastHashJoin) whose join keys include at least one decimal column.
+ *
+ * For each such join, delegates to [[setFallbackTagForOtherSide]] to ensure
that if one side's
+ * scan ([[FileSourceScanExec]] or `HiveTableScanExec`) cannot be offloaded
to the native engine,
+ * the other side is also forced to fall back. This prevents a decimal-value
mismatch that would
+ * produce incorrect (typically empty) join results when one side applies
Spark's precision
+ * coercion and the other side reads raw native values.
+ *
+ * AdaptiveSparkPlanExec is handled by descending into its `initialPlan`;
all other non-join nodes
+ * are handled recursively through their children.
+ */
+ private def validateJoin(plan: SparkPlan): Unit =
+ plan match {
+ case smj: SortMergeJoinExec
+ if (smj.leftKeys ++
smj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+ setFallbackTagForOtherSide(smj.left, smj.right)
+ case shj: ShuffledHashJoinExec
+ if (shj.leftKeys ++
shj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+ setFallbackTagForOtherSide(shj.left, shj.right)
+ case bhj: BroadcastHashJoinExec
+ if (bhj.leftKeys ++
bhj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+ setFallbackTagForOtherSide(bhj.left, bhj.right)
+ case a: AdaptiveSparkPlanExec =>
+ validateJoin(a.initialPlan)
+ case _ => plan.children.foreach(validateJoin(_))
+ }
+
+ /**
+ * Enforces symmetric scan fallback for the two sides of a decimal-key join.
+ *
+ * When the join key is a decimal type, a native (Velox) scan and a vanilla
Spark scan
+ * ([[FileSourceScanExec]] or `HiveTableScanExec`) may produce different
representations of the
+ * same decimal value: the native reader may surface raw uncoerced int128_t
values while the
+ * vanilla reader applies Spark's precision coercion (returning NULL for
out-of-range values). If
+ * only one side falls back, the join key values diverge and the join
returns 0 rows.
+ *
+ * This method detects the asymmetric case (exactly one side contains a
fallback scan) and adds a
+ * fallback tag to the native scan on the other side, so that both sides end
up using the same
+ * read path.
+ *
+ * @param leftChild
+ * the left subtree of the join
+ * @param rightChild
+ * the right subtree of the join
+ */
+ private def setFallbackTagForOtherSide(leftChild: SparkPlan, rightChild:
SparkPlan): Unit = {
+ val leftHasFallbackScan = hasFallbackScan(leftChild)
Review Comment:
Scan fallback is not the only possible cause of this issue. Couldn't the same mismatch also occur when an operator above the scan (e.g. a Filter) falls back?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]