This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 2699898fc [SEDONA-532] Correctly handle complex join conditions (#1325)
2699898fc is described below
commit 2699898fc4ec96a1a523b1fed83a1d23040488de
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Apr 10 00:50:40 2024 +0800
[SEDONA-532] Correctly handle complex join conditions (#1325)
* Add a violating test case to fix in later commit.
* Refactored join condition matcher to make it work with complex join conditions
---
.../sedona_sql/optimization/ExpressionUtils.scala | 33 +++-
.../strategy/join/JoinQueryDetector.scala | 218 ++++++---------------
.../strategy/join/OptimizableJoinCondition.scala | 106 ++++++++++
.../org/apache/sedona/sql/SpatialJoinSuite.scala | 27 ++-
.../org/apache/sedona/sql/functionTestScala.scala | 2 -
5 files changed, 223 insertions(+), 163 deletions(-)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
index e67a22ff9..377f97a2e 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/ExpressionUtils.scala
@@ -20,6 +20,8 @@
package org.apache.spark.sql.sedona_sql.optimization
import org.apache.spark.sql.catalyst.expressions.{And, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.sedona_sql.strategy.join.{JoinSide, LeftSide,
RightSide}
/**
* This class contains helper methods for transforming catalyst expressions.
@@ -38,7 +40,36 @@ object ExpressionUtils {
condition match {
case And(cond1, cond2) =>
splitConjunctivePredicates(cond1) ++ splitConjunctivePredicates(cond2)
- case other => other :: Nil
+ case other: Expression => other :: Nil
+ }
+ }
+
+ /**
+ * Returns true if specified expression has at least one reference and all
its references
+ * map to the output of the specified plan.
+ */
+ def matches(expr: Expression, plan: LogicalPlan): Boolean =
+ expr.references.nonEmpty && expr.references.subsetOf(plan.outputSet)
+
+ def matchExpressionsToPlans(exprA: Expression,
+ exprB: Expression,
+ planA: LogicalPlan,
+ planB: LogicalPlan): Option[(LogicalPlan, LogicalPlan, Boolean)] =
+ if (matches(exprA, planA) && matches(exprB, planB)) {
+ Some((planA, planB, false))
+ } else if (matches(exprA, planB) && matches(exprB, planA)) {
+ Some((planB, planA, true))
+ } else {
+ None
+ }
+
+ def matchDistanceExpressionToJoinSide(distance: Expression, left:
LogicalPlan, right: LogicalPlan): Option[JoinSide] = {
+ if (distance.references.isEmpty || matches(distance, left)) {
+ Some(LeftSide)
+ } else if (matches(distance, right)) {
+ Some(RightSide)
+ } else {
+ None
}
}
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
index f35bab885..6e71eea9e 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/JoinQueryDetector.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.sedona_sql.expressions._
import org.apache.spark.sql.sedona_sql.expressions.raster._
import
org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.splitConjunctivePredicates
import org.apache.spark.sql.{SparkSession, Strategy}
+import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils.{matches,
matchExpressionsToPlans, matchDistanceExpressionToJoinSide}
case class JoinQueryDetection(
@@ -95,15 +96,15 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint))
if optimizationEnabled(left, right, condition) => {
+ case Join(left, right, joinType, condition, JoinHint(leftHint, rightHint))
if optimizationEnabled(left, right, condition) =>
var broadcastLeft = leftHint.exists(_.strategy.contains(BROADCAST))
var broadcastRight = rightHint.exists(_.strategy.contains(BROADCAST))
/*
- If either side is small we can automatically broadcast just like Spark
does.
- This only applies to inner joins as there are no optimized fallback plan
for other join types.
- It's better that users are explicit about broadcasting for other join
types than seeing wildly different behavior
- depending on data size.
+ * If either side is small we can automatically broadcast just like
Spark does.
+ * This only applies to inner joins as there are no optimized fallback
plan for other join types.
+ * It's better that users are explicit about broadcasting for other join
types than seeing wildly different behavior
+ * depending on data size.
*/
if (!broadcastLeft && !broadcastRight && joinType == Inner) {
val canAutoBroadCastLeft = canAutoBroadcastBySize(left)
@@ -118,131 +119,60 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
}
}
- val queryDetection: Option[JoinQueryDetection] = condition match {
- //For vector only joins
- case Some(predicate: ST_Predicate) =>
- getJoinDetection(left, right, predicate)
- case Some(And(predicate: ST_Predicate, extraCondition)) =>
- getJoinDetection(left, right, predicate, Some(extraCondition))
- case Some(And(extraCondition, predicate: ST_Predicate)) =>
- getJoinDetection(left, right, predicate, Some(extraCondition))
- //For raster-vector joins
- case Some(predicate: RS_Predicate) =>
- getRasterJoinDetection(left, right, predicate, None)
- case Some(And(predicate: RS_Predicate, extraCondition)) =>
- getRasterJoinDetection(left, right, predicate, Some(extraCondition))
- case Some(And(extraCondition, predicate: RS_Predicate)) =>
- getRasterJoinDetection(left, right, predicate, Some(extraCondition))
- // For distance joins we execute the actual predicate (condition) and
not only extraConditions.
- case Some(ST_DWithin(Seq(leftShape, rightShape, distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
- case Some(And(ST_DWithin(Seq(leftShape, rightShape, distance)), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
- case Some(And(_, ST_DWithin(Seq(leftShape, rightShape, distance)))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
- case Some(ST_DWithin(Seq(leftShape, rightShape, distance,
useSpheroid))) =>
- try {
- val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = useSpheroidUnwrapped, condition,
Some(distance)))
- }catch {
- case _: Throwable =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
- }
- case Some(And(ST_DWithin(Seq(leftShape, rightShape, distance,
useSpheroid)), _)) =>
- try {
- val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = useSpheroidUnwrapped, condition,
Some(distance)))
- }catch {
- case _: Throwable =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
- }
- case Some(And(_, ST_DWithin(Seq(leftShape, rightShape, distance,
useSpheroid)))) =>
- try {
- val useSpheroidUnwrapped = useSpheroid.eval().asInstanceOf[Boolean]
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = useSpheroidUnwrapped, condition,
Some(distance)))
- }catch {
- case _: Throwable =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, isGeography = false, condition, Some(distance)))
+ val joinConditionMatcher = OptimizableJoinCondition(left, right)
+ val queryDetection: Option[JoinQueryDetection] = condition.flatMap {
+ case joinConditionMatcher(predicate, extraCondition) =>
+ predicate match {
+ case pred: ST_Predicate =>
+ getJoinDetection(left, right, pred, extraCondition)
+ case pred: RS_Predicate =>
+ getRasterJoinDetection(left, right, pred, extraCondition)
+ case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS,
+ isGeography = false, condition, Some(distance)))
+ case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid))
=>
+ val useSpheroidUnwrapped =
useSpheroid.eval().asInstanceOf[Boolean]
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS,
+ isGeography = useSpheroidUnwrapped, condition, Some(distance)))
+
+ // For distance joins we execute the actual predicate (condition)
and not only extraConditions.
+ // ST_Distance
+ case LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)),
distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+ case LessThan(ST_Distance(Seq(leftShape, rightShape)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some (distance)))
+
+ // ST_DistanceSphere
+ case LessThanOrEqual(ST_DistanceSphere(Seq(leftShape,
rightShape)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)),
distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+
+ // ST_DistanceSpheroid
+ case LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+ case LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)),
distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
+
+ // ST_HausdorffDistance
+ case LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some (distance)))
+ case LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape)),
distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some (distance)))
+ case LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape, densityFrac)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+ case LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape,
densityFrac)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+
+ // ST_FrechetDistance
+ case LessThanOrEqual(ST_FrechetDistance(Seq(leftShape,
rightShape)), distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+ case LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)),
distance) =>
+ Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
+
+ case _ => None
}
- case Some(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)),
distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThanOrEqual(ST_Distance(Seq(leftShape, rightShape)),
distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThanOrEqual(ST_Distance(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(LessThan(ST_Distance(Seq(leftShape, rightShape)), distance))
=>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThan(ST_Distance(Seq(leftShape, rightShape)),
distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThan(ST_Distance(Seq(leftShape, rightShape)),
distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- // ST_DistanceSphere
- case Some(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape,
rightShape)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(LessThanOrEqual(ST_DistanceSphere(Seq(leftShape,
rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(_, LessThanOrEqual(ST_DistanceSphere(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)),
distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(LessThan(ST_DistanceSphere(Seq(leftShape, rightShape)),
distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(_, LessThan(ST_DistanceSphere(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- // ST_DistanceSpheroid
- case Some(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(_, LessThanOrEqual(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(LessThan(ST_DistanceSpheroid(Seq(leftShape, rightShape)),
distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(LessThan(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- case Some(And(_, LessThan(ST_DistanceSpheroid(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, true, condition, Some(distance)))
- //ST_HausdorffDistanceDefault
- case Some(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape)),
distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThan(ST_HausdorffDistance(Seq(leftShape,
rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThan(ST_HausdorffDistance(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- //ST_HausdorffDistanceDensityFrac
- case Some(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape, densityFrac)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape, densityFrac)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThanOrEqual(ST_HausdorffDistance(Seq(leftShape,
rightShape, densityFrac)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape,
densityFrac)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThan(ST_HausdorffDistance(Seq(leftShape, rightShape,
densityFrac)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThan(ST_HausdorffDistance(Seq(leftShape,
rightShape, densityFrac)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- //ST_FrechetDistance
- case Some(LessThanOrEqual(ST_FrechetDistance(Seq(leftShape,
rightShape)), distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThanOrEqual(ST_FrechetDistance(Seq(leftShape,
rightShape)), distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThanOrEqual(ST_FrechetDistance(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)),
distance)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(LessThan(ST_FrechetDistance(Seq(leftShape, rightShape)),
distance), _)) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case Some(And(_, LessThan(ST_FrechetDistance(Seq(leftShape,
rightShape)), distance))) =>
- Some(JoinQueryDetection(left, right, leftShape, rightShape,
SpatialPredicate.INTERSECTS, false, condition, Some(distance)))
- case _ =>
- None
+ case _ => None
}
val sedonaConf = new SedonaConf(sparkSession.conf)
@@ -267,7 +197,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
Nil
}
}
- }
case _ =>
Nil
}
@@ -285,35 +214,6 @@ class JoinQueryDetector(sparkSession: SparkSession) extends Strategy {
private def canAutoBroadcastBySize(plan: LogicalPlan) =
plan.stats.sizeInBytes != 0 && plan.stats.sizeInBytes <=
SedonaConf.fromActiveSession.getAutoBroadcastJoinThreshold
- /**
- * Returns true if specified expression has at least one reference and all
its references
- * map to the output of the specified plan.
- */
- private def matches(expr: Expression, plan: LogicalPlan): Boolean =
- expr.references.nonEmpty && expr.references.subsetOf(plan.outputSet)
-
- private def matchExpressionsToPlans(exprA: Expression,
- exprB: Expression,
- planA: LogicalPlan,
- planB: LogicalPlan):
Option[(LogicalPlan, LogicalPlan, Boolean)] =
- if (matches(exprA, planA) && matches(exprB, planB)) {
- Some((planA, planB, false))
- } else if (matches(exprA, planB) && matches(exprB, planA)) {
- Some((planB, planA, true))
- } else {
- None
- }
-
- private def matchDistanceExpressionToJoinSide(distance: Expression, left:
LogicalPlan, right: LogicalPlan): Option[JoinSide] = {
- if (distance.references.isEmpty || matches(distance, left)) {
- Some(LeftSide)
- } else if (matches(distance, right)) {
- Some(RightSide)
- } else {
- None
- }
- }
-
private def planSpatialJoin(
left: LogicalPlan,
right: LogicalPlan,
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala
new file mode 100644
index 000000000..2211df003
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategy/join/OptimizableJoinCondition.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategy.join
+
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan,
LessThanOrEqual, Literal}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.sedona_sql.expressions._
+import org.apache.spark.sql.sedona_sql.expressions.raster.RS_Predicate
+import org.apache.spark.sql.sedona_sql.optimization.ExpressionUtils
+
+case class OptimizableJoinCondition(left: LogicalPlan, right: LogicalPlan) {
+ /**
+ * An extractor that matches expressions that are optimizable join
conditions. Join queries with optimizable join
+ * conditions will be executed as a spatial join (RangeJoin or DistanceJoin).
+ * @param expression the join condition
+ * @return an optional tuple containing the spatial predicate and the other
predicates
+ */
+ def unapply(expression: Expression): Option[(Expression,
Option[Expression])] = {
+ val predicates = ExpressionUtils.splitConjunctivePredicates(expression)
+ val (maybeSpatialPredicate, otherPredicates) =
extractFirstOptimizablePredicate(predicates)
+ maybeSpatialPredicate match {
+ case Some(spatialPredicate) =>
+ val other = otherPredicates.reduceOption((l, r) => And(l, r))
+ Some(spatialPredicate, other)
+ case None => None
+ }
+ }
+
+ private def extractFirstOptimizablePredicate(expressions: Seq[Expression]):
(Option[Expression], Seq[Expression]) = {
+ expressions match {
+ case Nil => (None, Nil)
+ case head :: tail =>
+ if (isOptimizablePredicate(head)) {
+ (Some(head), tail)
+ } else {
+ val (spatialPredicate, otherPredicates) =
extractFirstOptimizablePredicate(tail)
+ (spatialPredicate, head +: otherPredicates)
+ }
+ }
+ }
+
+ private def isOptimizablePredicate(expression: Expression): Boolean = {
+ expression match {
+ case _: ST_Intersects |
+ _: ST_Contains |
+ _: ST_Covers |
+ _: ST_Within |
+ _: ST_CoveredBy |
+ _: ST_Overlaps |
+ _: ST_Touches |
+ _: ST_Equals |
+ _: ST_Crosses |
+ _: RS_Predicate =>
+ val leftShape = expression.children.head
+ val rightShape = expression.children(1)
+ ExpressionUtils.matchExpressionsToPlans(leftShape, rightShape, left,
right).isDefined
+
+ case ST_DWithin(Seq(leftShape, rightShape, distance)) =>
+ isDistanceJoinOptimizable(leftShape, rightShape, distance)
+ case ST_DWithin(Seq(leftShape, rightShape, distance, useSpheroid)) =>
+ useSpheroid.isInstanceOf[Literal] &&
isDistanceJoinOptimizable(leftShape, rightShape, distance)
+
+ case _: LessThan | _: LessThanOrEqual =>
+ val (smaller, larger) = (expression.children.head,
expression.children(1))
+ smaller match {
+ case _: ST_Distance |
+ _: ST_DistanceSphere |
+ _: ST_DistanceSpheroid |
+ _: ST_FrechetDistance =>
+ val leftShape = smaller.children.head
+ val rightShape = smaller.children(1)
+ isDistanceJoinOptimizable(leftShape, rightShape, larger)
+
+ case ST_HausdorffDistance(Seq(leftShape, rightShape)) =>
+ isDistanceJoinOptimizable(leftShape, rightShape, larger)
+ case ST_HausdorffDistance(Seq(leftShape, rightShape, densityFrac)) =>
+ isDistanceJoinOptimizable(leftShape, rightShape, larger)
+
+ case _ => false
+ }
+
+ case _ => false
+ }
+ }
+
+ private def isDistanceJoinOptimizable(leftShape: Expression, rightShape:
Expression, distance: Expression): Boolean = {
+ ExpressionUtils.matchExpressionsToPlans(leftShape, rightShape, left,
right).isDefined &&
+ ExpressionUtils.matchDistanceExpressionToJoinSide(distance, left,
right).isDefined
+ }
+}
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
index 0fa61bfe7..952878ce9 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/SpatialJoinSuite.scala
@@ -64,7 +64,10 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
"ST_Distance(df1.geom, df2.geom) < df1.dist",
"ST_Distance(df1.geom, df2.geom) < df2.dist",
"ST_Distance(df2.geom, df1.geom) < df1.dist",
- "ST_Distance(df2.geom, df1.geom) < df2.dist"
+ "ST_Distance(df2.geom, df1.geom) < df2.dist",
+
+ "1.0 > ST_Distance(df1.geom, df2.geom)",
+ "1.0 >= ST_Distance(df1.geom, df2.geom)"
)
var spatialJoinPartitionSide = "left"
@@ -172,6 +175,22 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
}
}
+ describe("Spatial join optimizer should work with complex join conditions") {
+ it("Optimize spatial join with complex join conditions") {
+ withOptimizationMode("all") {
+ prepareTempViewsForTestData()
+ val df = sparkSession.sql(
+ """
+ |SELECT df1.id, df2.id FROM df1 JOIN df2 ON
+ |ST_Intersects(df1.geom, df2.geom) AND df1.id > df2.id AND df1.id
< df2.id + 100""".stripMargin)
+ assert(isUsingOptimizedSpatialJoin(df))
+ val expectedResult = buildExpectedResult("ST_Intersects(df1.geom,
df2.geom)")
+ .filter { case (id1, id2) => id1 > id2 && id1 < id2 + 100 }
+ verifyResult(expectedResult, df)
+ }
+ }
+ }
+
private def withOptimizationMode(mode: String)(body: => Unit) : Unit = {
val oldOptimizationMode =
sparkSession.conf.get("sedona.join.optimizationmode", "nonequi")
try {
@@ -227,6 +246,12 @@ class SpatialJoinSuite extends TestBaseScala with TableDrivenPropertyChecks {
(l: Geometry, r: Geometry) => l.distance(r) < 1.0
}
}
+ case _ =>
+ if (udf.contains(">=")) {
+ (l: Geometry, r: Geometry) => l.distance(r) <= 1.0
+ } else {
+ (l: Geometry, r: Geometry) => l.distance(r) < 1.0
+ }
}
left.flatMap { case (id, geom) =>
right.filter { case (_, geom2) =>
diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
index f7836e18d..c1cd9cfcf 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/functionTestScala.scala
@@ -685,7 +685,6 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
val testtable = sparkSession.sql("select ST_GeomFromWKT('POLYGON ((-3
-3, 3 -3, 3 3, -3 3, -3 -3))') as a,ST_GeomFromWKT('POLYGON ((5 -3, 7 -3, 7 -1,
5 -1, 5 -3))') as b")
testtable.createOrReplaceTempView("union_table")
val union = sparkSession.sql("select ST_Union(a,b) from union_table")
- println(union.take(1)(0).get(0).asInstanceOf[Geometry].toText)
assert(union.take(1)(0).get(0).asInstanceOf[Geometry].toText.equals("MULTIPOLYGON
(((-3 -3, -3 3, 3 3, 3 -3, -3 -3)), ((5 -3, 5 -1, 7 -1, 7 -3, 5 -3)))"))
}
@@ -939,7 +938,6 @@ class functionTestScala extends TestBaseScala with Matchers with GeometrySample
it("Should pass ST_IsPolygonCW") {
var actual = sparkSession.sql("SELECT
ST_IsPolygonCW(ST_GeomFromWKT('POLYGON ((20 35, 10 30, 10 10, 30 5, 45 20, 20
35),(30 20, 20 15, 20 25, 30 20))'))").first().getBoolean(0)
- print(actual)
assert(actual == false)
actual = sparkSession.sql("SELECT ST_IsPolygonCW(ST_GeomFromWKT('POLYGON
((20 35, 45 20, 30 5, 10 10, 10 30, 20 35), (30 20, 20 25, 20 15, 30
20))'))").first().getBoolean(0)