This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch feature/crossjoin-array-contains-benchmark
in repository https://gitbox.apache.org/repos/asf/spark.git
commit b9c7217539ee7e3f6e031a7d4d512facd75b5b8f
Author: Kent Yao <[email protected]>
AuthorDate: Wed Feb 4 09:45:58 2026 +0000

    [SPARK-XXXX][SQL] Add CrossJoinArrayContainsToInnerJoin optimizer rule

    ### What changes were proposed in this pull request?

    This PR adds a new optimizer rule that rewrites a cross join filtered by an
    array_contains predicate into an inner equi-join over the exploded array,
    which can improve query performance significantly.

    ### Why are the changes needed?

    Evaluating an array_contains predicate over a cross join requires O(N*M)
    comparisons. Rewriting it to explode + inner equi-join reduces this to
    roughly O(N+M) work, plus the cost of unnesting each array.

    ### Does this PR introduce _any_ user-facing change?

    No. This is an internal optimization that applies automatically to eligible
    queries.

    ### How was this patch tested?

    - Unit tests in CrossJoinArrayContainsToInnerJoinSuite (15 tests)
    - Microbenchmark showing an 11-16X speedup on a representative workload

    ### Was this patch authored or co-authored using generative AI tooling?

    Yes, GitHub Copilot was used to assist with implementation.
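For illustration, a minimal spark-shell sketch of the rewrite this rule performs
(the data and view names here are made up for the example):

    // Toy data mirroring the orders/items shape used in the benchmark below.
    val orders = Seq((1L, Array(10, 20)), (2L, Array(20, 30))).toDF("order_id", "item_ids")
    val items = Seq((10, "a"), (20, "b"), (30, "c")).toDF("id", "name")
    orders.createOrReplaceTempView("orders")
    items.createOrReplaceTempView("items")
    // Without this rule the optimized logical plan keeps the condition-less join
    // plus a Filter; with it, explain() should instead show a
    // Generate(explode(array_distinct(...))) feeding an inner equi-join.
    spark.sql("SELECT * FROM orders, items WHERE array_contains(item_ids, id)").explain(true)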
---
 .../CrossJoinArrayContainsToInnerJoin.scala        | 236 ++++++++++++++++++
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   1 +
 .../CrossJoinArrayContainsToInnerJoinSuite.scala   | 274 +++++++++++++++++++++
 ...yContainsToInnerJoinBenchmark-jdk21-results.txt |  38 +++
 ...inArrayContainsToInnerJoinBenchmark-results.txt |  38 +++
 ...rossJoinArrayContainsToInnerJoinBenchmark.scala | 227 +++++++++++++++++
 6 files changed, 814 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala
new file mode 100644
index 000000000000..f0852c394798
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoin.scala
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{FILTER, JOIN}
+import org.apache.spark.sql.types._
+
+/**
+ * Converts cross joins with an array_contains filter into inner joins using explode.
+ *
+ * This optimization transforms queries of the form:
+ * {{{
+ *   SELECT * FROM left, right WHERE array_contains(left.arr, right.elem)
+ * }}}
+ *
+ * Into a more efficient form:
+ * {{{
+ *   SELECT * FROM (
+ *     SELECT *, explode(array_distinct(arr)) AS unnested FROM left
+ *   ) l
+ *   INNER JOIN right ON l.unnested = right.elem
+ * }}}
+ *
+ * This avoids the O(N*M) cross join by unnesting the array and performing an equi-join.
+ */
+object CrossJoinArrayContainsToInnerJoin extends Rule[LogicalPlan] with PredicateHelper {
+
+  /**
+   * Check if the element type supports proper equality semantics for the optimization.
+   * The type must have consistent equality behavior between array_contains and join
+   * conditions. We exclude floating point types (Float/Double) due to NaN semantics issues.
+   */
+  private def isSupportedElementType(dataType: DataType): Boolean = dataType match {
+    // Integral types - exact equality
+    case ByteType | ShortType | IntegerType | LongType => true
+    // Decimal - exact equality with proper precision
+    case _: DecimalType => true
+    // String with binary equality
+    case _: StringType => true
+    // Date and Timestamp - exact equality
+    case DateType | TimestampType | TimestampNTZType => true
+    // Boolean - exact equality (though low cardinality makes optimization less impactful)
+    case BooleanType => true
+    // Binary - Spark's join uses content-based hash/comparison via ByteArray.compareBinary
+    case BinaryType => true
+    // Float/Double excluded due to NaN != NaN semantics
+    // Complex types (Array, Map, Struct) excluded - not supported by array_contains anyway
+    case _ => false
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAllPatterns(FILTER, JOIN)) {
+    case filter @ Filter(condition, join @ Join(left, right, Cross | Inner, None, _))
+        if join.condition.isEmpty =>
+      transformFilterOverCrossJoin(filter, condition, join, left, right)
+        .getOrElse(filter)
+  }
+
+  /**
+   * Attempts to transform a Filter over a Cross/Inner join (with no condition) that has
+   * an array_contains predicate into an inner join with explode.
+   */
+  private def transformFilterOverCrossJoin(
+      filter: Filter,
+      condition: Expression,
+      join: Join,
+      left: LogicalPlan,
+      right: LogicalPlan): Option[LogicalPlan] = {
+
+    val conjuncts = splitConjunctivePredicates(condition)
+
+    // Find an array_contains predicate that spans both sides
+    val arrayContainsOpt = findArrayContainsPredicate(conjuncts, left, right)
+
+    arrayContainsOpt.flatMap { case (arrayContains, arrayExpr, elementExpr, arrayOnLeft) =>
+      // Get the remaining predicates (excluding the array_contains we're using)
+      val remainingPredicates = conjuncts.filterNot(_ == arrayContains)
+
+      // Build the transformation
+      buildTransformedPlan(
+        join, left, right, arrayExpr, elementExpr, arrayOnLeft, remainingPredicates)
+    }
+  }
+
+  /**
+   * Finds an array_contains predicate where the array comes from one side
+   * and the element from the other side of the join.
+   *
+   * @return Option of (ArrayContains expression, array expression, element expression,
+   *         true if the array is on the left side)
+   */
+  private def findArrayContainsPredicate(
+      conjuncts: Seq[Expression],
+      left: LogicalPlan,
+      right: LogicalPlan): Option[(ArrayContains, Expression, Expression, Boolean)] = {
+
+    val leftOutput = left.outputSet
+    val rightOutput = right.outputSet
+
+    conjuncts.collectFirst {
+      case ac @ ArrayContains(arrayExpr, elementExpr)
+          if isSupportedArrayContains(arrayExpr, elementExpr, leftOutput, rightOutput) =>
+
+        val arrayOnLeft = arrayExpr.references.subsetOf(leftOutput)
+        (ac, arrayExpr, elementExpr, arrayOnLeft)
+    }
+  }
+
+  /**
+   * Checks if the array_contains can be optimized:
+   * - Element type is in the supported types
+   * - The array's element type matches the element expression's type
+   * - The array and element expressions each reference columns from only one side
+   *   of the join, and those sides differ
+   */
+  private def isSupportedArrayContains(
+      arrayExpr: Expression,
+      elementExpr: Expression,
+      leftOutput: AttributeSet,
+      rightOutput: AttributeSet): Boolean = {
+
+    // Check element type is supported
+    val elementType = elementExpr.dataType
+    val isTypeSupported = isSupportedElementType(elementType)
+
+    // Check that array is an array type with matching element type
+    val arrayType = arrayExpr.dataType
+    val isArrayTypeValid = arrayType match {
+      case ArrayType(arrElemType, _) => arrElemType == elementType
+      case _ => false
+    }
+
+    // Check that array comes from one side and element from the other
+    val arrayRefs = arrayExpr.references
+    val elemRefs = elementExpr.references
+
+    val arrayFromLeft = arrayRefs.nonEmpty && arrayRefs.subsetOf(leftOutput)
+    val arrayFromRight = arrayRefs.nonEmpty && arrayRefs.subsetOf(rightOutput)
+    val elemFromLeft = elemRefs.nonEmpty && elemRefs.subsetOf(leftOutput)
+    val elemFromRight = elemRefs.nonEmpty && elemRefs.subsetOf(rightOutput)
+
+    val crossesSides = (arrayFromLeft && elemFromRight) || (arrayFromRight && elemFromLeft)
+
+    isTypeSupported && isArrayTypeValid && crossesSides
+  }
+
+  /**
+   * Builds the transformed plan with explode and inner join.
+   */
+  private def buildTransformedPlan(
+      originalJoin: Join,
+      left: LogicalPlan,
+      right: LogicalPlan,
+      arrayExpr: Expression,
+      elementExpr: Expression,
+      arrayOnLeft: Boolean,
+      remainingPredicates: Seq[Expression]): Option[LogicalPlan] = {
+
+    val elementType = elementExpr.dataType
+
+    // Create array_distinct to avoid duplicate matches
+    val distinctArray = ArrayDistinct(arrayExpr)
+
+    // Create the explode generator
+    val explodeExpr = Explode(distinctArray)
+
+    // Create output attribute for the exploded values
+    val unnestedAttr = AttributeReference("unnested", elementType, nullable = true)()
+
+    // Determine which side has the array and create Generate node
+    val (planWithGenerate: LogicalPlan, otherPlan: LogicalPlan, joinCondition: Expression) =
+      if (arrayOnLeft) {
+        val generate = Generate(
+          generator = explodeExpr,
+          unrequiredChildIndex = Nil,
+          outer = false,
+          qualifier = None,
+          generatorOutput = Seq(unnestedAttr),
+          child = left
+        )
+        val cond = EqualTo(unnestedAttr, elementExpr)
+        (generate, right, cond)
+      } else {
+        val generate = Generate(
+          generator = explodeExpr,
+          unrequiredChildIndex = Nil,
+          outer = false,
+          qualifier = None,
+          generatorOutput = Seq(unnestedAttr),
+          child = right
+        )
+        val cond = EqualTo(elementExpr, unnestedAttr)
+        (left, generate, cond)
+      }
+
+    // Create the inner join with the equi-join condition
+    val innerJoin = if (arrayOnLeft) {
+      Join(planWithGenerate, otherPlan, Inner, Some(joinCondition), JoinHint.NONE)
+    } else {
+      Join(otherPlan, planWithGenerate, Inner, Some(joinCondition), JoinHint.NONE)
+    }
+
+    // Project to match original output (excluding the unnested column)
+    val originalOutput = originalJoin.output
+    val projectList = originalOutput.map(a => Alias(a, a.name)(a.exprId))
+
+    val projected = Project(projectList, innerJoin)
+
+    // Add remaining filter predicates if any
+    val result = if (remainingPredicates.nonEmpty) {
+      Filter(remainingPredicates.reduceLeft(And), projected)
+    } else {
+      projected
+    }
+
+    Some(result)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index fe15819bd44a..0a018bfe08a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -261,6 +261,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan),
     // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
     Batch("Check Cartesian Products", Once,
+      CrossJoinArrayContainsToInnerJoin,
       CheckCartesianProducts),
     Batch("RewriteSubquery", Once,
       RewritePredicateSubquery,
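Note for reviewers: because the rule is registered in the optimizer by class name,
it can be switched off per session through the standard exclusion conf; a minimal
sketch (this is the same mechanism the benchmark below uses via
SQLConf.OPTIMIZER_EXCLUDED_RULES):

    // Turn the new rule off for the current session only.
    spark.conf.set("spark.sql.optimizer.excludedRules",
      "org.apache.spark.sql.catalyst.optimizer.CrossJoinArrayContainsToInnerJoin")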
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala
new file mode 100644
index 000000000000..81fe449be71f
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CrossJoinArrayContainsToInnerJoinSuite.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types._
+
+/**
+ * Test suite for the CrossJoinArrayContainsToInnerJoin optimizer rule.
+ *
+ * This rule converts cross joins with an array_contains filter into inner joins
+ * using explode/unnest, which is much more efficient.
+ *
+ * Example transformation:
+ * {{{
+ *   Filter(array_contains(arr, elem))
+ *     CrossJoin(left, right)
+ * }}}
+ * becomes:
+ * {{{
+ *   InnerJoin(unnested = elem)
+ *     Generate(Explode(ArrayDistinct(arr)), left)
+ *     right
+ * }}}
+ */
+class CrossJoinArrayContainsToInnerJoinSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches =
+      Batch("CrossJoinArrayContainsToInnerJoin", Once,
+        CrossJoinArrayContainsToInnerJoin) :: Nil
+  }
+
+  // Table with array column (simulates "orders" with item_ids array)
+  val ordersRelation: LocalRelation = LocalRelation(
+    $"order_id".int,
+    $"item_ids".array(IntegerType)
+  )
+
+  // Table with element column (simulates "items" with id)
+  val itemsRelation: LocalRelation = LocalRelation(
+    $"id".int,
+    $"name".string
+  )
+
+  test("converts cross join with array_contains to inner join with explode") {
+    // Original query: SELECT * FROM orders, items WHERE array_contains(item_ids, id)
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where(ArrayContains($"item_ids", $"id"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // After optimization, should be an inner join with explode.
+    // The plan should NOT contain a Cross join anymore
+    assert(!optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "Optimized plan should not contain Cross join")
+
+    // Should contain a Generate (explode) node
+    assert(optimized.exists {
+      case _: Generate => true
+      case _ => false
+    }, "Optimized plan should contain Generate (explode) node")
+
+    // Should contain an Inner join
+    assert(optimized.exists {
+      case j: Join if j.joinType == Inner => true
+      case _ => false
+    }, "Optimized plan should contain Inner join")
+  }
+
+  test("does not transform when array_contains is not present") {
+    // Query without array_contains: SELECT * FROM orders, items WHERE order_id = id
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where($"order_id" === $"id")
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should remain unchanged (still a cross join with filter)
+    assert(optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "Plan without array_contains should remain unchanged")
+  }
+
+  test("does not transform inner join with existing conditions") {
+    // Already an inner join with equi-condition
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Inner, Some($"order_id" === $"id"))
+      .where(ArrayContains($"item_ids", $"id"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should not add another explode since this is already an equi-join.
+    // The array_contains becomes just a filter
+    assert(optimized.isInstanceOf[Filter] || optimized.exists {
+      case _: Filter => true
+      case _ => false
+    })
+  }
+
+  test("handles array column on right side of join") {
+    // Swap the tables - array is on right side
+    val rightWithArray: LocalRelation = LocalRelation(
+      $"arr_id".int,
+      $"values".array(IntegerType)
+    )
+    val leftWithElement: LocalRelation = LocalRelation(
+      $"elem".int
+    )
+
+    val originalPlan = leftWithElement
+      .join(rightWithArray, Cross)
+      .where(ArrayContains($"values", $"elem"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should still be transformed
+    assert(!optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "Should transform even when array is on right side")
+  }
+
+  test("preserves remaining filter predicates") {
+    // Query with additional conditions beyond array_contains
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where(ArrayContains($"item_ids", $"id") && ($"order_id" > 100))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // Should still have a filter for the remaining predicate (order_id > 100)
+    assert(optimized.exists {
+      case Filter(cond, _) =>
+        cond.find {
+          case GreaterThan(_, Literal(100, IntegerType)) => true
+          case _ => false
+        }.isDefined
+      case _ => false
+    }, "Should preserve remaining filter predicates")
+  }
+
+  test("uses array_distinct to avoid duplicate matches") {
+    val originalPlan = ordersRelation
+      .join(itemsRelation, Cross)
+      .where(ArrayContains($"item_ids", $"id"))
+      .analyze
+
+    val optimized = Optimize.execute(originalPlan)
+
+    // The optimized plan should use ArrayDistinct before exploding
+    // to avoid duplicate rows when the array has duplicate elements
+    assert(optimized.exists {
+      case Generate(Explode(ArrayDistinct(_)), _, _, _, _, _) => true
+      case Project(_, Generate(Explode(ArrayDistinct(_)), _, _, _, _, _)) => true
+      case _ => false
+    }, "Should use ArrayDistinct before Explode")
+  }
+
+  test("supports ByteType elements") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(ByteType))
+    val rightRel = LocalRelation($"elem".byte)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists { case _: Generate => true; case _ => false })
+  }
+
+  test("supports ShortType elements") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(ShortType))
+    val rightRel = LocalRelation($"elem".short)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists { case _: Generate => true; case _ => false })
+  }
+
+  test("supports DecimalType elements") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(DecimalType(10, 2)))
+    val rightRel = LocalRelation($"elem".decimal(10, 2))
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists { case _: Generate => true; case _ => false })
+  }
+
+  test("supports TimestampType elements") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(TimestampType))
+    val rightRel = LocalRelation($"elem".timestamp)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists { case _: Generate => true; case _ => false })
+  }
+
+  test("supports BooleanType elements") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(BooleanType))
+    val rightRel = LocalRelation($"elem".boolean)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists { case _: Generate => true; case _ => false })
+  }
+
+  test("does not transform FloatType elements due to NaN semantics") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(FloatType))
+    val rightRel = LocalRelation($"elem".float)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    // Should NOT be transformed - still contains Cross join
+    assert(optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "FloatType should not be transformed due to NaN semantics")
+  }
+
+  test("does not transform DoubleType elements due to NaN semantics") {
+    val leftRel = LocalRelation($"id".int, $"arr".array(DoubleType))
+    val rightRel = LocalRelation($"elem".double)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    // Should NOT be transformed - still contains Cross join
+    assert(optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "DoubleType should not be transformed due to NaN semantics")
+  }
+
+  test("supports BinaryType elements") {
+    // BinaryType is safe because Spark's join uses content-based hash/comparison
+    // via ByteArray.compareBinary, not Java's Array.equals()
+    val leftRel = LocalRelation($"id".int, $"arr".array(BinaryType))
+    val rightRel = LocalRelation($"elem".binary)
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists { case _: Generate => true; case _ => false })
+  }
+
+  test("does not transform StructType elements") {
+    val structType = StructType(Seq(StructField("a", IntegerType), StructField("b", StringType)))
+    val leftRel = LocalRelation($"id".int, $"arr".array(structType))
+    val rightRel = LocalRelation($"elem".struct(structType))
+    val plan = leftRel.join(rightRel, Cross).where(ArrayContains($"arr", $"elem")).analyze
+    val optimized = Optimize.execute(plan)
+    assert(optimized.exists {
+      case j: Join if j.joinType == Cross => true
+      case _ => false
+    }, "StructType should not be transformed due to complex equality semantics")
+  }
+}
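In addition to the plan-level unit tests above, a quick result-equality spot
check can be sketched as follows (not part of this patch; assumes the
orders/items views from the earlier example and a session with no rules excluded
initially):

    val query =
      "SELECT o.order_id, i.id FROM orders o, items i " +
        "WHERE array_contains(o.item_ids, i.id)"
    // Rows produced with the rule active...
    val optimizedRows = spark.sql(query).collect().toSet
    // ...should match the raw cross join + filter baseline with the rule excluded.
    spark.conf.set("spark.sql.optimizer.excludedRules",
      "org.apache.spark.sql.catalyst.optimizer.CrossJoinArrayContainsToInnerJoin")
    val baselineRows = spark.sql(query).collect().toSet
    assert(optimizedRows == baselineRows)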
diff --git a/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt
new file mode 100644
index 000000000000..47b45cd26615
--- /dev/null
+++ b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-jdk21-results.txt
@@ -0,0 +1,38 @@
+================================================================================================
+CrossJoinArrayContainsToInnerJoin Benchmark
+================================================================================================
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (1000 orders, 100 items, array size 5):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+-----------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                                   52             69          15          1.9         520.8       1.0X
+Inner join with explode (optimized equivalent)                                     56             74          19          1.8         564.9       0.9X
+Inner join with explode (DataFrame API)                                            39             41           2          2.5         393.2       1.3X
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (10000 orders, 1000 items, array size 10):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                                     582            596          19         17.2          58.2       1.0X
+Inner join with explode (optimized equivalent)                                        36             39           3        276.2           3.6      16.1X
+Inner join with explode (DataFrame API)                                               34             39           5        297.8           3.4      17.3X
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Scalability: varying array sizes:         Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+array_size=1 with explode optimization              143            151           7          7.0         143.5       1.0X
+array_size=5 with explode optimization              145            146           1          6.9         145.4       1.0X
+array_size=10 with explode optimization             144            150          10          6.9         144.3       1.0X
+array_size=50 with explode optimization             142            152          15          7.0         142.1       1.0X
+
+OpenJDK 64-Bit Server VM 21.0.10+7-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Different data types in array:            Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Integer array                                        31             39           7         32.4          30.9       1.0X
+Long array                                           29             31           3         34.2          29.2       1.1X
+String array                                         37             37           1         27.2          36.8       0.8X
+
+
diff --git a/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt
new file mode 100644
index 000000000000..8df73c946310
--- /dev/null
+++ b/sql/core/benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt
@@ -0,0 +1,38 @@
+================================================================================================
+CrossJoinArrayContainsToInnerJoin Benchmark
+================================================================================================
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (1000 orders, 100 items, array size 5):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+-----------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                                   52             56           3          1.9         523.4       1.0X
+Inner join with explode (optimized equivalent)                                     60             62           2          1.7         598.3       0.9X
+Inner join with explode (DataFrame API)                                            44             47           3          2.3         440.5       1.2X
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Cross join with array_contains (10000 orders, 1000 items, array size 10):  Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------------
+Cross join + array_contains filter (unoptimized)                                     504            533          25         19.8          50.4       1.0X
+Inner join with explode (optimized equivalent)                                        45             45           0        221.9           4.5      11.2X
+Inner join with explode (DataFrame API)                                               36             40           4        279.7           3.6      14.1X
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Scalability: varying array sizes:         Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+array_size=1 with explode optimization              144            146           2          6.9         144.2       1.0X
+array_size=5 with explode optimization              145            146           1          6.9         145.4       1.0X
+array_size=10 with explode optimization             142            157          17          7.0         142.0       1.0X
+array_size=50 with explode optimization             139            141           2          7.2         138.7       1.0X
+
+OpenJDK 64-Bit Server VM 17.0.18+8-LTS on Linux 6.11.0-1018-azure
+AMD EPYC 7763 64-Core Processor
+Different data types in array:            Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+Integer array                                        29             33           6         34.2          29.3       1.0X
+Long array                                           35             37           3         28.9          34.6       0.8X
+String array                                         40             42           2         24.7          40.5       0.7X
+
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala
new file mode 100644
index 000000000000..2ac64d9099dc
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CrossJoinArrayContainsToInnerJoinBenchmark.scala
@@ -0,0 +1,227 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Benchmark to measure the performance improvement of the
+ * CrossJoinArrayContainsToInnerJoin optimization.
+ *
+ * This benchmark compares:
+ * 1. Cross join with array_contains filter (unoptimized)
+ * 2. Inner join with explode (manually optimized / what the rule produces)
+ *
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class>
+ *        --jars <spark core test jar>,<spark catalyst test jar> <spark sql test jar>
+ *   2. build/sbt "sql/Test/runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain <this class>"
+ *      Results will be written to
+ *      "benchmarks/CrossJoinArrayContainsToInnerJoinBenchmark-results.txt".
+ * }}}
+ */
+object CrossJoinArrayContainsToInnerJoinBenchmark extends SqlBasedBenchmark {
+
+  import spark.implicits._
+
+  private def crossJoinWithArrayContains(numOrders: Int, numItems: Int, arraySize: Int): Unit = {
+    val benchmark = new Benchmark(
+      s"Cross join with array_contains ($numOrders orders, $numItems items, array size $arraySize)",
+      numOrders.toLong * numItems,
+      output = output
+    )
+
+    // Create orders table with array of item IDs
+    val orders = spark.range(numOrders)
+      .selectExpr(
+        "id as order_id",
+        s"array_repeat(cast((id % $numItems) as int), $arraySize) as item_ids"
+      )
+      .cache()
+
+    // Create items table
+    val items = spark.range(numItems)
+      .selectExpr("cast(id as int) as item_id", "concat('item_', id) as item_name")
+      .cache()
+
+    // Force caching
+    orders.count()
+    items.count()
+
+    // Register as temp views for SQL queries
+    orders.createOrReplaceTempView("orders")
+    items.createOrReplaceTempView("items")
+
+    benchmark.addCase("Cross join + array_contains filter (unoptimized)", numIters = 3) { _ =>
+      // Disable the optimization to measure the true cross-join + filter baseline
+      withSQLConf(
+        SQLConf.CROSS_JOINS_ENABLED.key -> "true",
+        SQLConf.OPTIMIZER_EXCLUDED_RULES.key ->
+          "org.apache.spark.sql.catalyst.optimizer.CrossJoinArrayContainsToInnerJoin") {
+        // This query would be a cross join with filter without the optimization
+        val df = spark.sql(
+          """
+            |SELECT /*+ BROADCAST(items) */ o.order_id, i.item_id, i.item_name
+            |FROM orders o, items i
+            |WHERE array_contains(o.item_ids, i.item_id)
+          """.stripMargin)
+        df.noop()
+      }
+    }
+
+    benchmark.addCase("Inner join with explode (optimized equivalent)", numIters = 3) { _ =>
+      // This is what the optimization produces - explode + inner join
+      val df = spark.sql(
+        """
+          |SELECT o.order_id, i.item_id, i.item_name
+          |FROM (
+          |  SELECT order_id, explode(array_distinct(item_ids)) as unnested_id
+          |  FROM orders
+          |) o
+          |INNER JOIN items i ON o.unnested_id = i.item_id
+        """.stripMargin)
+      df.noop()
+    }
+
+    benchmark.addCase("Inner join with explode (DataFrame API)", numIters = 3) { _ =>
+      val ordersExploded = orders
+        .withColumn("unnested_id", explode(array_distinct($"item_ids")))
+        .select($"order_id", $"unnested_id")
+
+      val df = ordersExploded.join(items, $"unnested_id" === $"item_id")
+      df.noop()
+    }
+
+    benchmark.run()
+
+    orders.unpersist()
+    items.unpersist()
+  }
+
+  private def scalabilityBenchmark(): Unit = {
+    val benchmark = new Benchmark(
+      "Scalability: varying array sizes",
+      1000000L,
+      output = output
+    )
+
+    val numOrders = 10000
+    val numItems = 1000
+
+    Seq(1, 5, 10, 50).foreach { arraySize =>
+      val orders = spark.range(numOrders)
+        .selectExpr(
+          "id as order_id",
+          s"transform(sequence(0, $arraySize - 1), " +
+            s"x -> cast((id + x) % $numItems as int)) as item_ids"
+        )
+
+      val items = spark.range(numItems)
+        .selectExpr("cast(id as int) as item_id", "concat('item_', id) as item_name")
+
+      orders.createOrReplaceTempView("orders_scale")
+      items.createOrReplaceTempView("items_scale")
+
+      benchmark.addCase(s"array_size=$arraySize with explode optimization", numIters = 3) { _ =>
+        val df = spark.sql(
+          """
+            |SELECT o.order_id, i.item_id, i.item_name
+            |FROM (
+            |  SELECT order_id, explode(array_distinct(item_ids)) as unnested_id
+            |  FROM orders_scale
+            |) o
+            |INNER JOIN items_scale i ON o.unnested_id = i.item_id
+          """.stripMargin)
+        df.noop()
+      }
+    }
+
+    benchmark.run()
+  }
+
+  private def dataTypeBenchmark(): Unit = {
+    val benchmark = new Benchmark(
+      "Different data types in array",
+      1000000L,
+      output = output
+    )
+
+    val numRows = 10000
+    val numLookup = 1000
+    val arraySize = 10
+
+    // Integer arrays
+    benchmark.addCase("Integer array", numIters = 3) { _ =>
+      val left = spark.range(numRows)
+        .selectExpr("id", s"array_repeat(cast(id % $numLookup as int), $arraySize) as arr")
+      val right = spark.range(numLookup).selectExpr("cast(id as int) as elem")
+
+      val df = left
+        .withColumn("unnested", explode(array_distinct($"arr")))
+        .join(right, $"unnested" === $"elem")
+      df.noop()
+    }
+
+    // Long arrays
+    benchmark.addCase("Long array", numIters = 3) { _ =>
+      val left = spark.range(numRows)
+        .selectExpr("id", s"array_repeat(id % $numLookup, $arraySize) as arr")
+      val right = spark.range(numLookup).selectExpr("id as elem")
+
+      val df = left
+        .withColumn("unnested", explode(array_distinct($"arr")))
+        .join(right, $"unnested" === $"elem")
+      df.noop()
+    }
+
+    // String arrays
+    benchmark.addCase("String array", numIters = 3) { _ =>
+      val left = spark.range(numRows)
+        .selectExpr("id", s"array_repeat(concat('key_', id % $numLookup), $arraySize) as arr")
+      val right = spark.range(numLookup).selectExpr("concat('key_', id) as elem")
+
+      val df = left
+        .withColumn("unnested", explode(array_distinct($"arr")))
+        .join(right, $"unnested" === $"elem")
+      df.noop()
+    }
+
+    benchmark.run()
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    runBenchmark("CrossJoinArrayContainsToInnerJoin Benchmark") {
+      // Small scale test
+      crossJoinWithArrayContains(numOrders = 1000, numItems = 100, arraySize = 5)
+
+      // Medium scale test
+      crossJoinWithArrayContains(numOrders = 10000, numItems = 1000, arraySize = 10)
+
+      // Scalability test with varying array sizes
+      scalabilityBenchmark()
+
+      // Data type comparison
+      dataTypeBenchmark()
+    }
+  }
+}
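For reference, the fully qualified form of the generate-results command from the
scaladoc above (assuming the standard benchmark workflow):

    SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/Test/runMain org.apache.spark.sql.execution.benchmark.CrossJoinArrayContainsToInnerJoinBenchmark"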
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]