This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 2699898fc [SEDONA-532] Correctly handle complex join conditions (#1325)
2699898fc is described below

commit 2699898fc4ec96a1a523b1fed83a1d23040488de
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Apr 10 00:50:40 2024 +0800

    [SEDONA-532] Correctly handle complex join conditions (#1325)
    
    * Add a violating test case to fix in later commit.
    
    * Refactored join condition matcher to make it work with complex join conditions
---
 .../sedona_sql/optimization/ExpressionUtils.scala  |  33 +++-
 .../strategy/join/JoinQueryDetector.scala          | 218 ++++++---------------
 .../strategy/join/OptimizableJoinCondition.scala   | 106 ++++++++++
 .../org/apache/sedona/sql/SpatialJoinSuite.scala   |  27 ++-
 .../org/apache/sedona/sql/functionTestScala.scala  |   2 -
 5 files changed, 223 insertions(+), 163 deletions(-)

diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
index e67a22ff9..377f97a2e 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
@@ -20,6 +20,8 @@
 package org.apache.spark.sql.sedona_sql.optimization
 
 import org.apache.spark.sql.catalyst.expressions.{And, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.sedona_sql.strategy.join.{JoinSide, LeftSide, 
RightSide}
 
 /**
  * This class contains helper methods for transforming catalyst expressions.
@@ -38,7 +40,36 @@ object ExpressionUtils {
     condition match {
       case And(cond1, cond2) =>
         splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
-      case other => other :: Nil
+      case other: Expression => other :: Nil
+    }
+  }
+
+  /**
+   * Returns true if specified expression has at least one reference and all 
its references
+   * map to the output of the specified plan.
+   */
+  def matches(expr: Expression, plan: LogicalPlan): Boolean =
+    expr.references.nonEmpty && expr.references.subsetOf(plan.outputSet)
+
+  def matchExpressionsToPlans(exprA: Expression,
+    exprB: Expression,
+    planA: LogicalPlan,
+    planB: LogicalPlan): Option[(LogicalPlan, LogicalPlan, Boolean)] =
+    if (matches(exprA, planA) && matches(exprB, planB)) {
+      Some((planA, planB, false))
+    } else if (matches(exprA, planB) && matches(exprB, planA)) {
+      Some((planB, planA, true))
+    } else {
+      None
+    }
+
+  def matchDistanceExpressionToJoinSide(distance: Expression, left: 
LogicalPlan, right: LogicalPlan): Option[JoinSide] = {
+    if (distance.references.isEmpty || matches(distance, left)) {
+      Some(LeftSide)
+    } else if (matches(distance, right)) {
+      Some(RightSide)
+    } else {
+      None
     }
   }
 }
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index f35bab885..6e71eea9e 100644
--- 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.sedona_sql.expressions._
 import org.apache.spark.sql.sedona_sql.expressions.raster._
 import 
org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
 import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.{matches, 
matchExpressionsToPlans, matchDistanceExpressionToJoinSide}
 
 
 case class JoinQueryDetection(
@@ -95,15 +96,15 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
   }
 
   def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) 
if optimizationEnabled(left, right, condition) => {
+    case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint)) 
if optimizationEnabled(left, right, condition) =>
       var broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST))
       var broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST))
 
       /*
-      If either side is small we can automatically broadcast just like Spark 
does.
-      This only applies to inner joins as there are no optimized fallback plan 
for other join types.
-      It's better that users are explicit about broadcasting for other join 
types than seeing wildly different behavior
-      depending on data size.
+       * If either side is small we can automatically broadcast just like 
Spark does.
+       * This only applies to inner joins as there are no optimized fallback 
plan for other join types.
+       * It's better that users are explicit about broadcasting for other join 
types than seeing wildly different behavior
+       * depending on data size.
        */
       if (!broadcastLeft && !broadcastRight && joinType == Inner) {
         val canAutoBroadCastLeft = canAutoBroadcastBySize(left)
@@ -118,131 +119,60 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
         }
       }
 
-      val queryDetection: Option[JoinQueryDetection] = condition match {
-        //For vector only joins
-        case Some(predicate: ST_Predicate) =>
-          getJoinDetection(left, right, predicate)
-        case Some(And(predicate: ST_Predicate, extraCondition)) =>
-          getJoinDetection(left, right, predicate, Some(extraCondition))
-        case Some(And(extraCondition, predicate: ST_Predicate)) =>
-          getJoinDetection(left, right, predicate, Some(extraCondition))
-          //For raster-vector joins
-        case Some(predicate: RS_Predicate) =>
-          getRasterJoinDetection(left, right, predicate, None)
-        case Some(And(predicate: RS_Predicate, extraCondition)) =>
-          getRasterJoinDetection(left, right, predicate, Some(extraCondition))
-        case Some(And(extraCondition, predicate: RS_Predicate)) =>
-          getRasterJoinDetection(left, right, predicate, Some(extraCondition))
-        // For distance joins we execute the actual predicate (condition) and 
not only extraConditions.
-        case Some(ST_DWithin(Seq(leftShape, rightShape, distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
-        case Some(And(ST_DWithin(Seq(leftShape, rightShape, distance)), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
-        case Some(And(_, ST_DWithin(Seq(leftShape, rightShape, distance)))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
-        case Some(ST_DWithin(Seq(leftShape, rightShape, distance, 
useSpheroid))) =>
-          try {
-            val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
-            Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = useSpheroidUnwrapped, condition, 
Some(distance)))
-          }catch {
-            case _: Throwable =>
-              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
-          }
-        case Some(And(ST_DWithin(Seq(leftShape, rightShape, distance, 
useSpheroid)), _)) =>
-          try {
-            val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
-            Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = useSpheroidUnwrapped, condition, 
Some(distance)))
-          }catch {
-            case _: Throwable =>
-              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
-          }
-        case Some(And(_, ST_DWithin(Seq(leftShape, rightShape, distance, 
useSpheroid)))) =>
-          try {
-            val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
-            Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = useSpheroidUnwrapped, condition, 
Some(distance)))
-          }catch {
-            case _: Throwable =>
-              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
+      val joinConditionMatcher = OptimizableJoinCondition(left, right)
+      val queryDetection: Option[JoinQueryDetection] = condition.flatMap {
+        case joinConditionMatcher(predicate, extraCondition) =>
+          predicate match {
+            case pred: ST_Predicate =>
+              getJoinDetection(left, right, pred, extraCondition)
+            case pred: RS_Predicate =>
+              getRasterJoinDetection(left, right, pred, extraCondition)
+            case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS,
+                isGeography = false, condition, Some(distance)))
+            case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) 
=>
+              val useSpheroidUnwrapped = 
useSpheroid.eval().asInstanceOf[Boolean]
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS,
+                isGeography = useSpheroidUnwrapped, condition, Some(distance)))
+
+            // For distance joins we execute the actual predicate (condition) 
and not only extraConditions.
+            // ST_Distance
+            case LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), 
distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+            case LessThan(ST_Distance(Seq(leftShape, rightShape)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some (distance)))
+
+            // ST_DistanceSphere
+            case LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, 
rightShape)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+            case LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)), 
distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+
+            // ST_DistanceSpheroid
+            case LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, 
rightShape)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+            case LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), 
distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+
+            // ST_HausdorffDistance
+            case LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some (distance)))
+            case LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape)), 
distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some (distance)))
+            case LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape, densityFrac)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+            case LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape, 
densityFrac)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+
+            // ST_FrechetDistance
+            case LessThanOrEqual(ST_FrechetDistance(Seq(leftShape, 
rightShape)), distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+            case LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)), 
distance) =>
+              Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+
+            case _ => None
           }
-        case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), 
distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)), 
distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance)) 
=>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)), 
distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)), 
distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        // ST_DistanceSphere
-        case Some(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, 
rightShape)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, 
rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(_, LessThanOrEqual(ST_DistanceSphere(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)), 
distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)), 
distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(_, LessThan(ST_DistanceSphere(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        // ST_DistanceSpheroid
-        case Some(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, 
rightShape)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, 
rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(_, LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)), 
distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(LessThan(ST_DistanceSpheroid(Seq(leftShape, 
rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        case Some(And(_, LessThan(ST_DistanceSpheroid(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
-        //ST_HausdorffDistanceDefault
-        case Some(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape)), 
distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThan(ST_HausdorffDistance(Seq(leftShape, 
rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThan(ST_HausdorffDistance(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        //ST_HausdorffDistanceDensityFrac
-        case Some(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape, densityFrac)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape, densityFrac)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape, 
rightShape, densityFrac)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape, 
densityFrac)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape, 
densityFrac)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThan(ST_HausdorffDistance(Seq(leftShape, 
rightShape, densityFrac)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        //ST_FrechetDistance
-        case Some(LessThanOrEqual(ST_FrechetDistance(Seq(leftShape, 
rightShape)), distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThanOrEqual(ST_FrechetDistance(Seq(leftShape, 
rightShape)), distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThanOrEqual(ST_FrechetDistance(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)), 
distance)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)), 
distance), _)) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case Some(And(_, LessThan(ST_FrechetDistance(Seq(leftShape, 
rightShape)), distance))) =>
-          Some(JoinQueryDetection(left, right, leftShape, rightShape, 
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
-        case _ =>
-          None
+        case _ => None
       }
 
       val sedonaConf = new SedonaConf(sparkSession.conf)
@@ -267,7 +197,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends 
Strategy {
             Nil
         }
       }
-    }
     case _ =>
       Nil
   }
@@ -285,35 +214,6 @@ class JoinQueryDetector(sparkSession: SparkSession) 
extends Strategy {
   private def canAutoBroadcastBySize(plan: LogicalPlan) =
     plan.stats.sizeInBytes != 0 && plan.stats.sizeInBytes <= 
SedonaConf.fromActiveSession.getAutoBroadcastJoinThreshold
 
-  /**
-    * Returns true if specified expression has at least one reference and all 
its references
-    * map to the output of the specified plan.
-    */
-  private def matches(expr: Expression, plan: LogicalPlan): Boolean =
-    expr.references.nonEmpty && expr.references.subsetOf(plan.outputSet)
-
-  private def matchExpressionsToPlans(exprA: Expression,
-                                      exprB: Expression,
-                                      planA: LogicalPlan,
-                                      planB: LogicalPlan): 
Option[(LogicalPlan, LogicalPlan, Boolean)] =
-    if (matches(exprA, planA) && matches(exprB, planB)) {
-      Some((planA, planB, false))
-    } else if (matches(exprA, planB) && matches(exprB, planA)) {
-      Some((planB, planA, true))
-    } else {
-      None
-    }
-
-  private def matchDistanceExpressionToJoinSide(distance: Expression, left: 
LogicalPlan, right: LogicalPlan): Option[JoinSide] = {
-    if (distance.references.isEmpty || matches(distance, left)) {
-      Some(LeftSide)
-    } else if (matches(distance, right)) {
-      Some(RightSide)
-    } else {
-      None
-    }
-  }
-
   private def planSpatialJoin(
     left: LogicalPlan,
     right: LogicalPlan,
diff --git 
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala
 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala
new file mode 100644
index 000000000..2211df003
--- /dev/null
+++ 
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.join
+
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan, 
LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.sedona_sql.expressions._
+import org.apache.spark.sql.sedona_sql.expressions.raster.RS_Predicate
+import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils
+
+case class OptimizableJoinCondition(left: LogicalPlan, right: LogicalPlan) {
+  /**
+   * An extractor that matches expressions that are optimizable join 
conditions. Join queries with optimizable join
+   * conditions will be executed as a spatial join (RangeJoin or DistanceJoin).
+   * @param expression the join condition
+   * @return an optional tuple containing the spatial predicate and the other 
predicates
+   */
+  def unapply(expression: Expression): Option[(Expression, 
Option[Expression])] = {
+    val predicates = ExpressionUtils.splitConjunctivePredicates(expression)
+    val (maybeSpatialPredicate, otherPredicates) = 
extractFirstOptimizablePredicate(predicates)
+    maybeSpatialPredicate match {
+      case Some(spatialPredicate) =>
+        val other = otherPredicates.reduceOption((l, r) => And(l, r))
+        Some(spatialPredicate, other)
+      case None => None
+    }
+  }
+
+  private def extractFirstOptimizablePredicate(expressions: Seq[Expression]): 
(Option[Expression], Seq[Expression]) = {
+    expressions match {
+      case Nil => (None, Nil)
+      case head :: tail =>
+        if (isOptimizablePredicate(head)) {
+          (Some(head), tail)
+        } else {
+          val (spatialPredicate, otherPredicates) = 
extractFirstOptimizablePredicate(tail)
+          (spatialPredicate, head +: otherPredicates)
+        }
+    }
+  }
+
+  private def isOptimizablePredicate(expression: Expression): Boolean = {
+    expression match {
+      case _: ST_Intersects |
+           _: ST_Contains |
+           _: ST_Covers |
+           _: ST_Within |
+           _: ST_CoveredBy |
+           _: ST_Overlaps |
+           _: ST_Touches |
+           _: ST_Equals |
+           _: ST_Crosses |
+           _: RS_Predicate =>
+        val leftShape = expression.children.head
+        val rightShape = expression.children(1)
+        ExpressionUtils.matchExpressionsToPlans(leftShape, rightShape, left, 
right).isDefined
+
+      case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
+        isDistanceJoinOptimizable(leftShape, rightShape, distance)
+      case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) =>
+        useSpheroid.isInstanceOf[Literal] && 
isDistanceJoinOptimizable(leftShape, rightShape, distance)
+
+      case _: LessThan | _: LessThanOrEqual =>
+        val (smaller, larger) = (expression.children.head, 
expression.children(1))
+        smaller match {
+          case _: ST_Distance |
+               _: ST_DistanceSphere |
+               _: ST_DistanceSpheroid |
+               _: ST_FrechetDistance =>
+            val leftShape = smaller.children.head
+            val rightShape = smaller.children(1)
+            isDistanceJoinOptimizable(leftShape, rightShape, larger)
+
+          case ST_HausdorffDistance(Seq(leftShape, rightShape)) =>
+            isDistanceJoinOptimizable(leftShape, rightShape, larger)
+          case ST_HausdorffDistance(Seq(leftShape, rightShape, densityFrac)) =>
+            isDistanceJoinOptimizable(leftShape, rightShape, larger)
+
+          case _ => false
+        }
+
+      case _ => false
+    }
+  }
+
+  private def isDistanceJoinOptimizable(leftShape: Expression, rightShape: 
Expression, distance: Expression): Boolean = {
+    ExpressionUtils.matchExpressionsToPlans(leftShape, rightShape, left, 
right).isDefined &&
+      ExpressionUtils.matchDistanceExpressionToJoinSide(distance, left, 
right).isDefined
+  }
+}
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 0fa61bfe7..952878ce9 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -64,7 +64,10 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
       "ST_Distance(df1.geom, df2.geom) < df1.dist",
       "ST_Distance(df1.geom, df2.geom) < df2.dist",
       "ST_Distance(df2.geom, df1.geom) < df1.dist",
-      "ST_Distance(df2.geom, df1.geom) < df2.dist"
+      "ST_Distance(df2.geom, df1.geom) < df2.dist",
+
+      "1.0 > ST_Distance(df1.geom, df2.geom)",
+      "1.0 >= ST_Distance(df1.geom, df2.geom)"
     )
 
     var spatialJoinPartitionSide = "left"
@@ -172,6 +175,22 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
     }
   }
 
+  describe("Spatial join optimizer should work with complex join conditions") {
+    it("Optimize spatial join with complex join conditions") {
+      withOptimizationMode("all") {
+        prepareTempViewsForTestData()
+        val df = sparkSession.sql(
+          """
+            |SELECT df1.id, df2.id FROM df1 JOIN df2 ON
+            |ST_Intersects(df1.geom, df2.geom) AND df1.id > df2.id AND df1.id 
< df2.id + 100""".stripMargin)
+        assert(isUsingOptimizedSpatialJoin(df))
+        val expectedResult = buildExpectedResult("ST_Intersects(df1.geom, 
df2.geom)")
+          .filter { case (id1, id2) => id1 > id2 && id1 < id2 + 100 }
+        verifyResult(expectedResult, df)
+      }
+    }
+  }
+
   private def withOptimizationMode(mode: String)(body: => Unit) : Unit = {
     val oldOptimizationMode = 
sparkSession.conf.get("sedona.join.optimizationmode", "nonequi")
     try {
@@ -227,6 +246,12 @@ class SpatialJoinSuite extends TestBaseScala with 
TableDrivenPropertyChecks {
             (l: Geometry, r: Geometry) => l.distance(r) < 1.0
           }
         }
+      case _ =>
+        if (udf.contains(">=")) {
+          (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
+        } else {
+          (l: Geometry, r: Geometry) => l.distance(r) < 1.0
+        }
     }
     left.flatMap { case (id, geom) =>
       right.filter { case (_, geom2) =>
diff --git 
a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala 
b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index f7836e18d..c1cd9cfcf 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -685,7 +685,6 @@ class functionTestScala extends TestBaseScala with Matchers 
with GeometrySample
       val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3 
-3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1, 
5 -1, 5 -3))') as b")
       testtable.createOrReplaceTempView("union_table")
       val union = sparkSession.sql("select ST_Union(a,b) from union_table")
-      println(union.take(1)(0).get(0).asInstanceOf[Geometry].toText)
       
assert(union.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("MULTIPOLYGON
 (((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)))"))
     }
 
@@ -939,7 +938,6 @@ class functionTestScala extends TestBaseScala with Matchers 
with GeometrySample
 
   it("Should pass ST_IsPolygonCW") {
     var actual = sparkSession.sql("SELECT 
ST_IsPolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20 
35),(30 20, 20 15, 20 25, 30 20))'))").first().getBoolean(0)
-    print(actual)
     assert(actual == false)
 
     actual = sparkSession.sql("SELECT ST_IsPolygonCW(ST_GeomFromWKT('POLYGON 
((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30 
20))'))").first().getBoolean(0)

Reply via email to