This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch feature/crossjoin-array-contains-benchmark
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 7da480e5ca16a7dca3cd5d48aa7e30fda07a3b5c
Author: Kent Yao <[email protected]>
AuthorDate: Wed Feb 4 09:45:58 2026 +0000

    [SPARK-XXXX][SQL] Add CrossJoinArrayContainsToInnerJoin optimizer rule
    
    ### What changes were proposed in this pull request?
    
    This PR adds a new optimizer rule that converts cross joins with 
array_contains filter into inner joins using explode, improving query 
performance significantly.
    
    ### Why are the changes needed?
    
    Cross joins with array_contains predicates result in O(N*M) complexity. By 
transforming to explode + inner join, we achieve O(N+M) complexity.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No. This is an internal optimization that automatically applies to 
applicable queries.
    
    ### How was this patch tested?
    
    - Unit tests in CrossJoinArrayContainsToInnerJoinSuite (6 tests)
    - Microbenchmark showing 11-16X speedup on representative workload
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes, GitHub Copilot was used to assist with implementation.
---
 .../CrossJoinArrayContainsToInnerJoin.scala        | 219 ++++++++++++++++++++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   1 +
 .../CrossJoinArrayContainsToInnerJoinSuite.scala   | 188 +++++++++++++++++
 ...yContainsToInnerJoinBenchmark-jdk21-results.txt |  38 ++++
 ...inArrayContainsToInnerJoinBenchmark-results.txt |  38 ++++
 ...rossJoinArrayContainsToInnerJoinBenchmark.scala | 227 +++++++++++++++++++++
 6 files changed, 711 insertions(+)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala
new file mode 100644
index 000000000000..10ed24468202
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, JOIN}
+import org.apache.spark.sql.types._
+
+/**
+ * Converts cross joins with array_contains filter into inner joins using 
explode.
+ *
+ * This optimization transforms queries of the form:
+ * {{{
+ * SELECT * FROM left, right WHERE array_contains(left.arr, right.elem)
+ * }}}
+ *
+ * Into a more efficient form:
+ * {{{
+ * SELECT * FROM (
+ *   SELECT *, explode(array_distinct(arr)) AS unnested FROM left
+ * ) l
+ * INNER JOIN right ON l.unnested = right.elem
+ * }}}
+ *
+ * This avoids the O(N*M) cross join by using unnesting and equi-join.
+ */
+object CrossJoinArrayContainsToInnerJoin extends Rule[LogicalPlan] with 
PredicateHelper {
+
+  // Supported element types for the optimization
+  private val supportedTypes: Set[DataType] = Set(
+    IntegerType, LongType, StringType, DateType
+  )
+
+  override def apply(plan: LogicalPlan): LogicalPlan = 
plan.transformUpWithPruning(
+    _.containsAllPatterns(FILTER, JOIN)) {
+    case filter @ Filter(condition, join @ Join(left, right, Cross | Inner, 
None, _))
+        if join.condition.isEmpty =>
+      transformFilterOverCrossJoin(filter, condition, join, left, right)
+        .getOrElse(filter)
+  }
+
+  /**
+   * Attempts to transform a Filter over a Cross/Inner join (with no 
condition) that has
+   * an array_contains predicate into an inner join with explode.
+   */
+  private def transformFilterOverCrossJoin(
+      filter: Filter,
+      condition: Expression,
+      join: Join,
+      left: LogicalPlan,
+      right: LogicalPlan): Option[LogicalPlan] = {
+
+    val conjuncts = splitConjunctivePredicates(condition)
+
+    // Find an array_contains predicate that spans both sides
+    val arrayContainsOpt = findArrayContainsPredicate(conjuncts, left, right)
+
+    arrayContainsOpt.flatMap { case (arrayContains, arrayExpr, elementExpr, 
arrayOnLeft) =>
+      // Get the remaining predicates (excluding the array_contains we're 
using)
+      val remainingPredicates = conjuncts.filterNot(_ == arrayContains)
+
+      // Build the transformation
+      buildTransformedPlan(
+        join, left, right, arrayExpr, elementExpr, arrayOnLeft, 
remainingPredicates)
+    }
+  }
+
+  /**
+   * Finds an array_contains predicate where the array comes from one side
+   * and the element from the other side of the join.
+   *
+   * @return Option of (ArrayContains expression, array expression, element 
expression,
+   *         true if array is on left side)
+   */
+  private def findArrayContainsPredicate(
+      conjuncts: Seq[Expression],
+      left: LogicalPlan,
+      right: LogicalPlan): Option[(ArrayContains, Expression, Expression, 
Boolean)] = {
+
+    val leftOutput = left.outputSet
+    val rightOutput = right.outputSet
+
+    conjuncts.collectFirst {
+      case ac @ ArrayContains(arrayExpr, elementExpr)
+          if isSupportedArrayContains(arrayExpr, elementExpr, leftOutput, 
rightOutput) =>
+
+        val arrayOnLeft = arrayExpr.references.subsetOf(leftOutput)
+        (ac, arrayExpr, elementExpr, arrayOnLeft)
+    }
+  }
+
+  /**
+   * Checks if the array_contains can be optimized:
+   * - Element type is in supported types
+   * - Array and element come from different sides of the join
+   * - Both array and element are simple column references
+   */
+  private def isSupportedArrayContains(
+      arrayExpr: Expression,
+      elementExpr: Expression,
+      leftOutput: AttributeSet,
+      rightOutput: AttributeSet): Boolean = {
+
+    // Check element type is supported
+    val elementType = elementExpr.dataType
+    val isTypeSupported = supportedTypes.contains(elementType) ||
+      elementType.isInstanceOf[StringType]
+
+    // Check that array is an array type with matching element type
+    val arrayType = arrayExpr.dataType
+    val isArrayTypeValid = arrayType match {
+      case ArrayType(arrElemType, _) => arrElemType == elementType
+      case _ => false
+    }
+
+    // Check that array comes from one side and element from the other
+    val arrayRefs = arrayExpr.references
+    val elemRefs = elementExpr.references
+
+    val arrayFromLeft = arrayRefs.nonEmpty && arrayRefs.subsetOf(leftOutput)
+    val arrayFromRight = arrayRefs.nonEmpty && arrayRefs.subsetOf(rightOutput)
+    val elemFromLeft = elemRefs.nonEmpty && elemRefs.subsetOf(leftOutput)
+    val elemFromRight = elemRefs.nonEmpty && elemRefs.subsetOf(rightOutput)
+
+    val crossesSides = (arrayFromLeft && elemFromRight) || (arrayFromRight && 
elemFromLeft)
+
+    isTypeSupported && isArrayTypeValid && crossesSides
+  }
+
+  /**
+   * Builds the transformed plan with explode and inner join.
+   */
+  private def buildTransformedPlan(
+      originalJoin: Join,
+      left: LogicalPlan,
+      right: LogicalPlan,
+      arrayExpr: Expression,
+      elementExpr: Expression,
+      arrayOnLeft: Boolean,
+      remainingPredicates: Seq[Expression]): Option[LogicalPlan] = {
+
+    val elementType = elementExpr.dataType
+
+    // Create array_distinct to avoid duplicate matches
+    val distinctArray = ArrayDistinct(arrayExpr)
+
+    // Create the explode generator
+    val explodeExpr = Explode(distinctArray)
+
+    // Create output attribute for the exploded values
+    val unnestedAttr = AttributeReference("unnested", elementType, nullable = 
true)()
+
+    // Determine which side has the array and create Generate node
+    val (planWithGenerate: LogicalPlan, otherPlan: LogicalPlan, joinCondition: 
Expression) =
+      if (arrayOnLeft) {
+        val generate = Generate(
+          generator = explodeExpr,
+          unrequiredChildIndex = Nil,
+          outer = false,
+          qualifier = None,
+          generatorOutput = Seq(unnestedAttr),
+          child = left
+        )
+        val cond = EqualTo(unnestedAttr, elementExpr)
+        (generate, right, cond)
+      } else {
+        val generate = Generate(
+          generator = explodeExpr,
+          unrequiredChildIndex = Nil,
+          outer = false,
+          qualifier = None,
+          generatorOutput = Seq(unnestedAttr),
+          child = right
+        )
+        val cond = EqualTo(elementExpr, unnestedAttr)
+        (left, generate, cond)
+      }
+
+    // Create the inner join with the equi-join condition
+    val innerJoin = if (arrayOnLeft) {
+      Join(planWithGenerate, otherPlan, Inner, Some(joinCondition), 
JoinHint.NONE)
+    } else {
+      Join(otherPlan, planWithGenerate, Inner, Some(joinCondition), 
JoinHint.NONE)
+    }
+
+    // Project to match original output (excluding the unnested column)
+    val originalOutput = originalJoin.output
+    val projectList = originalOutput.map(a => Alias(a, a.name)(a.exprId))
+
+    val projected = Project(projectList, innerJoin)
+
+    // Add remaining filter predicates if any
+    val result = if (remainingPredicates.nonEmpty) {
+      Filter(remainingPredicates.reduceLeft(And), projected)
+    } else {
+      projected
+    }
+
+    Some(result)
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe15819bd44a..0a018bfe08a5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -261,6 +261,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan),
     // The following batch should be executed after batch "Join Reorder" and 
"LocalRelation".
     Batch("Check Cartesian Products", Once,
+      CrossJoinArrayContainsToInnerJoin,
       CheckCartesianProducts),
     Batch("RewriteSubquery", Once,
       RewritePredicateSubquery,
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala
new file mode 100644
index 000000000000..6eebc23b46ef
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for CrossJoinArrayContainsToInnerJoin optimizer rule.
+ *
+ * This rule converts cross joins with array_contains filter into inner joins
+ * using explode/unnest, which is much more efficient.
+ *
+ * Example transformation:
+ * {{{
+ * Filter(array_contains(arr, elem))
+ *   CrossJoin(left, right)
+ * }}}
+ * becomes:
+ * {{{
+ * InnerJoin(unnested = elem)
+ *   Generate(Explode(ArrayDistinct(arr)), left)
+ *   right
+ * }}}
+ */
+class CrossJoinArrayContainsToInnerJoinSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("CrossJoinArrayContainsToInnerJoin", Once,
+        CrossJoinArrayContainsToInnerJoin) :: Nil
+  }
+
+  // Table with array column (simulates "orders" with item_ids array)
+  val ordersRelation: LocalRelation = LocalRelation(
+    $"order_id".int,
+    $"item_ids".array(IntegerType)
+  )
+
+  // Table with element column (simulates "items" with id)
+  val itemsRelation: LocalRelation = LocalRelation(
+    $"id".int,
+    $"name".string
+  )
+
+  test("converts cross join with array_contains to inner join with explode") {
+    // Original query: SELECT * FROM orders, items WHERE 
array_contains(item_ids, id)
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where(ArrayContains($"item_ids", $"id"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // After optimization, should be an inner join with explode
+    // The plan should NOT contain a Cross join anymore
+    assert(!optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "Optimized plan should not contain Cross join")
+
+    // Should contain a Generate (explode) node
+    assert(optimized.exists {
+      case _: Generate => true
+      case _ => false
+    }, "Optimized plan should contain Generate (explode) node")
+
+    // Should contain an Inner join
+    assert(optimized.exists {
+      case j: Join if j.joinType == Inner => true
+      case _ => false
+    }, "Optimized plan should contain Inner join")
+  }
+
+  test("does not transform when array_contains is not present") {
+    // Query without array_contains: SELECT * FROM orders, items WHERE 
order_id = id
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where($"order_id" === $"id")
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should remain unchanged (still a cross join with filter)
+    assert(optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "Plan without array_contains should remain unchanged")
+  }
+
+  test("does not transform inner join with existing conditions") {
+    // Already an inner join with equi-condition
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Inner, Some($"order_id" === $"id"))
+      .where(ArrayContains($"item_ids", $"id"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should not add another explode since this is already an equi-join
+    // The array_contains becomes just a filter
+    assert(optimized.isInstanceOf[Filter] || optimized.exists {
+      case _: Filter => true
+      case _ => false
+    })
+  }
+
+  test("handles array column on right side of join") {
+    // Swap the tables - array is on right side
+    val rightWithArray: LocalRelation = LocalRelation(
+      $"arr_id".int,
+      $"values".array(IntegerType)
+    )
+    val leftWithElement: LocalRelation = LocalRelation(
+      $"elem".int
+    )
+
+    val originalPlan = leftWithElement
+      .join(rightWithArray, Cross)
+      .where(ArrayContains($"values", $"elem"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should still be transformed
+    assert(!optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "Should transform even when array is on right side")
+  }
+
+  test("preserves remaining filter predicates") {
+    // Query with additional conditions beyond array_contains
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where(ArrayContains($"item_ids", $"id") && ($"order_id" > 100))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should still have a filter for the remaining predicate (order_id > 100)
+    assert(optimized.exists {
+      case Filter(cond, _) =>
+        cond.find {
+          case GreaterThan(_, Literal(100, IntegerType)) => true
+          case _ => false
+        }.isDefined
+      case _ => false
+    }, "Should preserve remaining filter predicates")
+  }
+
+  test("uses array_distinct to avoid duplicate matches") {
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where(ArrayContains($"item_ids", $"id"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // The optimized plan should use ArrayDistinct before exploding
+    // to avoid duplicate rows when array has duplicate elements
+    assert(optimized.exists {
+      case Generate(Explode(ArrayDistinct(_)), _, _, _, _, _) => true
+      case Project(_, Generate(Explode(ArrayDistinct(_)), _, _, _, _, _)) => 
true
+      case _ => false
+    }, "Should use ArrayDistinct before Explode")
+  }
+}
diff --git 
a/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt
 
b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt
new file mode 100644
index 000000000000..47b45cd26615
--- /dev/null
+++ 
b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt
@@ -0,0 +1,38 @@
+================================================================================================
+CrossJoinArrayContainsToInnerJoin Benchmark
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (1000 orders, 100 items, array size 5):  Best 
Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+-----------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                               
   52             69          15          1.9         520.8       1.0X
+Inner join with explode (optimized equivalent)                                 
   56             74          19          1.8         564.9       0.9X
+Inner join with explode (DataFrame API)                                        
   39             41           2          2.5         393.2       1.3X
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (10000 orders, 1000 items, array size 10):  
Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                               
     582            596          19         17.2          58.2       1.0X
+Inner join with explode (optimized equivalent)                                 
      36             39           3        276.2           3.6      16.1X
+Inner join with explode (DataFrame API)                                        
      34             39           5        297.8           3.4      17.3X
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Scalability: varying array sizes:         Best Time(ms)   Avg Time(ms)   
Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+array_size=1 with explode optimization              143            151         
  7          7.0         143.5       1.0X
+array_size=5 with explode optimization              145            146         
  1          6.9         145.4       1.0X
+array_size=10 with explode optimization             144            150         
 10          6.9         144.3       1.0X
+array_size=50 with explode optimization             142            152         
 15          7.0         142.1       1.0X
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Different data types in array:            Best Time(ms)   Avg Time(ms)   
Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Integer array                                        31             39         
  7         32.4          30.9       1.0X
+Long array                                           29             31         
  3         34.2          29.2       1.1X
+String array                                         37             37         
  1         27.2          36.8       0.8X
+
+
diff --git 
a/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt 
b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt
new file mode 100644
index 000000000000..8df73c946310
--- /dev/null
+++ b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt
@@ -0,0 +1,38 @@
+================================================================================================
+CrossJoinArrayContainsToInnerJoin Benchmark
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (1000 orders, 100 items, array size 5):  Best 
Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+-----------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                               
   52             56           3          1.9         523.4       1.0X
+Inner join with explode (optimized equivalent)                                 
   60             62           2          1.7         598.3       0.9X
+Inner join with explode (DataFrame API)                                        
   44             47           3          2.3         440.5       1.2X
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (10000 orders, 1000 items, array size 10):  
Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                               
     504            533          25         19.8          50.4       1.0X
+Inner join with explode (optimized equivalent)                                 
      45             45           0        221.9           4.5      11.2X
+Inner join with explode (DataFrame API)                                        
      36             40           4        279.7           3.6      14.1X
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Scalability: varying array sizes:         Best Time(ms)   Avg Time(ms)   
Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+array_size=1 with explode optimization              144            146         
  2          6.9         144.2       1.0X
+array_size=5 with explode optimization              145            146         
  1          6.9         145.4       1.0X
+array_size=10 with explode optimization             142            157         
 17          7.0         142.0       1.0X
+array_size=50 with explode optimization             139            141         
  2          7.2         138.7       1.0X
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Different data types in array:            Best Time(ms)   Avg Time(ms)   
Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Integer array                                        29             33         
  6         34.2          29.3       1.0X
+Long array                                           35             37         
  3         28.9          34.6       0.8X
+String array                                         40             42         
  2         24.7          40.5       0.7X
+
+
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala
new file mode 100644
index 000000000000..2ac64d9099dc
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Benchmark to measure performance improvement of 
CrossJoinArrayContainsToInnerJoin optimization.
+ *
+ * This benchmark compares:
+ * 1. Cross join with array_contains filter (unoptimized)
+ * 2. Inner join with explode (manually optimized / what the rule produces)
+ *
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class>
+ *        --jars <spark core test jar>,<spark catalyst test jar> <spark sql 
test jar>
+ *   2. build/sbt "sql/Test/runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this 
class>"
+ *      Results will be written to
+ *        "benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt".
+ * }}}
+ */
+object CrossJoinArrayContainsToInnerJoinBenchmark extends SqlBasedBenchmark {
+
+  import spark.implicits._
+
+  private def crossJoinWithArrayContains(numOrders: Int, numItems: Int, 
arraySize: Int): Unit = {
+    val benchmark = new Benchmark(
+      s"Cross join with array_contains ($numOrders orders, $numItems items, 
array size $arraySize)",
+      numOrders.toLong * numItems,
+      output = output
+    )
+
+    // Create orders table with array of item IDs
+    val orders = spark.range(numOrders)
+      .selectExpr(
+        "id as order_id",
+        s"array_repeat(cast((id % $numItems) as int), $arraySize) as item_ids"
+      )
+      .cache()
+
+    // Create items table
+    val items = spark.range(numItems)
+      .selectExpr("cast(id as int) as item_id", "concat('item_', id) as 
item_name")
+      .cache()
+
+    // Force caching
+    orders.count()
+    items.count()
+
+    // Register as temp views for SQL queries
+    orders.createOrReplaceTempView("orders")
+    items.createOrReplaceTempView("items")
+
+    benchmark.addCase("Cross join + array_contains filter (unoptimized)", 
numIters = 3) { _ =>
+      // Disable the optimization to measure the true cross-join+filter 
baseline
+      withSQLConf(
+        SQLConf.CROSS_JOINS_ENABLED.key -> "true",
+        SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+          
"org.apache.spark.sql.catalyst.optimizer.CrossJoinArrayContainsToInnerJoin") {
+        // This query would be a cross join with filter without optimization
+        val df = spark.sql(
+          """
+            |SELECT /*+ BROADCAST(items) */ o.order_id, i.item_id, i.item_name
+            |FROM orders o, items i
+            |WHERE array_contains(o.item_ids, i.item_id)
+          """.stripMargin)
+        df.noop()
+      }
+    }
+
+    benchmark.addCase("Inner join with explode (optimized equivalent)", 
numIters = 3) { _ =>
+      // This is what the optimization produces - explode + inner join
+      val df = spark.sql(
+        """
+          |SELECT o.order_id, i.item_id, i.item_name
+          |FROM (
+          |  SELECT order_id, explode(array_distinct(item_ids)) as unnested_id
+          |  FROM orders
+          |) o
+          |INNER JOIN items i ON o.unnested_id = i.item_id
+        """.stripMargin)
+      df.noop()
+    }
+
+    benchmark.addCase("Inner join with explode (DataFrame API)", numIters = 3) 
{ _ =>
+      val ordersExploded = orders
+        .withColumn("unnested_id", explode(array_distinct($"item_ids")))
+        .select($"order_id", $"unnested_id")
+
+      val df = ordersExploded.join(items, $"unnested_id" === $"item_id")
+      df.noop()
+    }
+
+    benchmark.run()
+
+    orders.unpersist()
+    items.unpersist()
+  }
+
+  private def scalabilityBenchmark(): Unit = {
+    val benchmark = new Benchmark(
+      "Scalability: varying array sizes",
+      1000000L,
+      output = output
+    )
+
+    val numOrders = 10000
+    val numItems = 1000
+
+    Seq(1, 5, 10, 50).foreach { arraySize =>
+      val orders = spark.range(numOrders)
+        .selectExpr(
+          "id as order_id",
+          s"transform(sequence(0, $arraySize - 1), " +
+            s"x -> cast((id + x) % $numItems as int)) as item_ids"
+        )
+
+      val items = spark.range(numItems)
+        .selectExpr("cast(id as int) as item_id", "concat('item_', id) as 
item_name")
+
+      orders.createOrReplaceTempView("orders_scale")
+      items.createOrReplaceTempView("items_scale")
+
+      benchmark.addCase(s"array_size=$arraySize with explode optimization", 
numIters = 3) { _ =>
+        val df = spark.sql(
+          """
+            |SELECT o.order_id, i.item_id, i.item_name
+            |FROM (
+            |  SELECT order_id, explode(array_distinct(item_ids)) as 
unnested_id
+            |  FROM orders_scale
+            |) o
+            |INNER JOIN items_scale i ON o.unnested_id = i.item_id
+          """.stripMargin)
+        df.noop()
+      }
+    }
+
+    benchmark.run()
+  }
+
+  private def dataTypeBenchmark(): Unit = {
+    val benchmark = new Benchmark(
+      "Different data types in array",
+      1000000L,
+      output = output
+    )
+
+    val numRows = 10000
+    val numLookup = 1000
+    val arraySize = 10
+
+    // Integer arrays
+    benchmark.addCase("Integer array", numIters = 3) { _ =>
+      val left = spark.range(numRows)
+        .selectExpr("id", s"array_repeat(cast(id % $numLookup as int), 
$arraySize) as arr")
+      val right = spark.range(numLookup).selectExpr("cast(id as int) as elem")
+
+      val df = left
+        .withColumn("unnested", explode(array_distinct($"arr")))
+        .join(right, $"unnested" === $"elem")
+      df.noop()
+    }
+
+    // Long arrays
+    benchmark.addCase("Long array", numIters = 3) { _ =>
+      val left = spark.range(numRows)
+        .selectExpr("id", s"array_repeat(id % $numLookup, $arraySize) as arr")
+      val right = spark.range(numLookup).selectExpr("id as elem")
+
+      val df = left
+        .withColumn("unnested", explode(array_distinct($"arr")))
+        .join(right, $"unnested" === $"elem")
+      df.noop()
+    }
+
+    // String arrays
+    benchmark.addCase("String array", numIters = 3) { _ =>
+      val left = spark.range(numRows)
+        .selectExpr("id", s"array_repeat(concat('key_', id % $numLookup), 
$arraySize) as arr")
+      val right = spark.range(numLookup).selectExpr("concat('key_', id) as 
elem")
+
+      val df = left
+        .withColumn("unnested", explode(array_distinct($"arr")))
+        .join(right, $"unnested" === $"elem")
+      df.noop()
+    }
+
+    benchmark.run()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    runBenchmark("CrossJoinArrayContainsToInnerJoin Benchmark") {
+      // Small scale test
+      crossJoinWithArrayContains(numOrders = 1000, numItems = 100, arraySize = 
5)
+
+      // Medium scale test
+      crossJoinWithArrayContains(numOrders = 10000, numItems = 1000, arraySize 
= 10)
+
+      // Scalability test with varying array sizes
+      scalabilityBenchmark()
+
+      // Data type comparison
+      dataTypeBenchmark()
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to