jinchengchenghh commented on code in PR #12000:
URL: https://github.com/apache/gluten/pull/12000#discussion_r3172863648


##########
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala:
##########
@@ -40,4 +46,131 @@ case class AddFallbackTags(validator: Validator) extends 
Rule[SparkPlan] {
       case Validator.Passed =>
     }
   }
+
+  /**
+   * Traverses the plan tree looking for join nodes (SortMergeJoin, 
ShuffledHashJoin,
+   * BroadcastHashJoin) whose join keys include at least one decimal column.
+   *
+   * For each such join, delegates to [[setFallbackTagForOtherSide]] to ensure 
that if one side's
+   * scan ([[FileSourceScanExec]] or `HiveTableScanExec`) cannot be offloaded 
to the native engine,
+   * the other side is also forced to fall back. This prevents a decimal-value 
mismatch that would
+   * produce incorrect (typically empty) join results when one side applies 
Spark's precision
+   * coercion and the other side reads raw native values.
+   *
+   * AdaptiveSparkPlanExec is handled by descending into its `initialPlan`; 
all other non-join nodes
+   * are handled recursively through their children.
+   */
+  private def validateJoin(plan: SparkPlan): Unit =
+    plan match {
+      case smj: SortMergeJoinExec
+          if (smj.leftKeys ++ 
smj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(smj.left, smj.right)
+      case shj: ShuffledHashJoinExec
+          if (shj.leftKeys ++ 
shj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(shj.left, shj.right)
+      case bhj: BroadcastHashJoinExec
+          if (bhj.leftKeys ++ 
bhj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(bhj.left, bhj.right)
+      case a: AdaptiveSparkPlanExec =>
+        validateJoin(a.initialPlan)
+      case _ => plan.children.foreach(validateJoin(_))
+    }
+
+  /**
+   * Enforces symmetric scan fallback for the two sides of a decimal-key join.
+   *
+   * When the join key is a decimal type, a native (Velox) scan and a vanilla 
Spark scan
+   * ([[FileSourceScanExec]] or `HiveTableScanExec`) may produce different 
representations of the
+   * same decimal value: the native reader may surface raw uncoerced int128_t 
values while the

Review Comment:
   Can we update the native side to support this case?



##########
gluten-core/src/main/scala/org/apache/gluten/extension/columnar/heuristic/AddFallbackTags.scala:
##########
@@ -40,4 +46,131 @@ case class AddFallbackTags(validator: Validator) extends 
Rule[SparkPlan] {
       case Validator.Passed =>
     }
   }
+
+  /**
+   * Traverses the plan tree looking for join nodes (SortMergeJoin, 
ShuffledHashJoin,
+   * BroadcastHashJoin) whose join keys include at least one decimal column.
+   *
+   * For each such join, delegates to [[setFallbackTagForOtherSide]] to ensure 
that if one side's
+   * scan ([[FileSourceScanExec]] or `HiveTableScanExec`) cannot be offloaded 
to the native engine,
+   * the other side is also forced to fall back. This prevents a decimal-value 
mismatch that would
+   * produce incorrect (typically empty) join results when one side applies 
Spark's precision
+   * coercion and the other side reads raw native values.
+   *
+   * AdaptiveSparkPlanExec is handled by descending into its `initialPlan`; 
all other non-join nodes
+   * are handled recursively through their children.
+   */
+  private def validateJoin(plan: SparkPlan): Unit =
+    plan match {
+      case smj: SortMergeJoinExec
+          if (smj.leftKeys ++ 
smj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(smj.left, smj.right)
+      case shj: ShuffledHashJoinExec
+          if (shj.leftKeys ++ 
shj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(shj.left, shj.right)
+      case bhj: BroadcastHashJoinExec
+          if (bhj.leftKeys ++ 
bhj.rightKeys).exists(_.dataType.isInstanceOf[DecimalType]) =>
+        setFallbackTagForOtherSide(bhj.left, bhj.right)
+      case a: AdaptiveSparkPlanExec =>
+        validateJoin(a.initialPlan)
+      case _ => plan.children.foreach(validateJoin(_))
+    }
+
+  /**
+   * Enforces symmetric scan fallback for the two sides of a decimal-key join.
+   *
+   * When the join key is a decimal type, a native (Velox) scan and a vanilla 
Spark scan
+   * ([[FileSourceScanExec]] or `HiveTableScanExec`) may produce different 
representations of the
+   * same decimal value: the native reader may surface raw uncoerced int128_t 
values while the
+   * vanilla reader applies Spark's precision coercion (returning NULL for 
out-of-range values). If
+   * only one side falls back, the join key values diverge and the join 
returns 0 rows.
+   *
+   * This method detects the asymmetric case (exactly one side contains a 
fallback scan) and adds a
+   * fallback tag to the native scan on the other side, so that both sides end 
up using the same
+   * read path.
+   *
+   * @param leftChild
+   *   the left subtree of the join
+   * @param rightChild
+   *   the right subtree of the join
+   */
+  private def setFallbackTagForOtherSide(leftChild: SparkPlan, rightChild: 
SparkPlan): Unit = {
+    val leftHasFallbackScan = hasFallbackScan(leftChild)

Review Comment:
   Scan fallback is not the only cause of this issue, is it? Couldn't the same mismatch also occur when an operator after the scan (e.g. a filter) falls back?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to