Derek Murray created SPARK-43781:
------------------------------------

             Summary: IllegalStateException when cogrouping two datasets derived from the same source
                 Key: SPARK-43781
                 URL: https://issues.apache.org/jira/browse/SPARK-43781
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.3.1
         Environment: Reproduces in a unit test, using Spark 3.3.1, the Java API, and a {{local[2]}} SparkSession.
            Reporter: Derek Murray
Attempting to {{cogroup}} two datasets derived from the same source dataset yields an {{IllegalStateException}} when the query is executed.

Minimal reproducer:
{code:java}
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.function.CoGroupFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// The wrapper class and method names are illustrative; the body is the reproducer itself.
public class CogroupRepro {
  // sparkSession is the local[2] SparkSession used by the unit test.
  static void reproduce(SparkSession sparkSession) {
    StructType inputType = DataTypes.createStructType(
        new StructField[]{
            DataTypes.createStructField("id", DataTypes.LongType, false),
            DataTypes.createStructField("type", DataTypes.StringType, false)
        }
    );
    StructType keyType = DataTypes.createStructType(
        new StructField[]{
            DataTypes.createStructField("id", DataTypes.LongType, false)
        }
    );

    List<Row> inputRows = new ArrayList<>();
    inputRows.add(RowFactory.create(1L, "foo"));
    inputRows.add(RowFactory.create(1L, "bar"));
    inputRows.add(RowFactory.create(2L, "foo"));
    Dataset<Row> input = sparkSession.createDataFrame(inputRows, inputType);

    // Both grouped datasets are derived from the same `input` DataFrame.
    KeyValueGroupedDataset<Row, Row> fooGroups = input
        .filter("type = 'foo'")
        .groupBy("id")
        .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
    KeyValueGroupedDataset<Row, Row> barGroups = input
        .filter("type = 'bar'")
        .groupBy("id")
        .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

    // The function body is irrelevant: the exception is thrown while the shuffle
    // feeding the CoGroup is being prepared, before this function is ever called.
    Dataset<Row> result = fooGroups.cogroup(
        barGroups,
        (CoGroupFunction<Row, Row, Row, Row>) (key, fooRows, barRows) -> new ArrayList<Row>().iterator(),
        RowEncoder.apply(inputType));

    result.explain();
    result.show();
  }
}
{code}

Explain output. Note the mismatch in column IDs on the first input to the {{CoGroup}}: the {{Sort}} and {{Exchange}} reference {{id#39L}}, while the {{LocalTableScan}} below them produces {{id#16L}}.
{code:java}
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#37L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#38]
   +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1478/1869116781@77856cc5, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#39L], [id#39L], [id#39L, type#40], [id#39L, type#40], obj#36: org.apache.spark.sql.Row
      :- !Sort [id#39L ASC NULLS FIRST], false, 0
      :  +- !Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=19]
      :     +- LocalTableScan [id#16L, type#17]
      +- Sort [id#39L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=20]
            +- LocalTableScan [id#39L, type#40]
{code}

Exception. It is thrown in {{ShuffleExchangeExec}} while binding the partitioning key {{id#39L}} against the child's output {{[id#16L,type#17]}}:
{code:java}
java.lang.IllegalStateException: Couldn't find id#39L in [id#16L,type#17]
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
	at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
	at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:698)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
	at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1254)
	at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1253)
	at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:608)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:528)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
	at scala.collection.immutable.List.map(List.scala:246)
	at scala.collection.immutable.List.map(List.scala:79)
	at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
	at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160)
	at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:323)
	at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13(ShuffleExchangeExec.scala:391)
	at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13$adapted(ShuffleExchangeExec.scala:390)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:748)
{code}

Other observations:
* The same code works if I call {{createDataFrame()}} twice and use the two separate datasets as the inputs to the cogroup.
* The real code applies two different filters to the same cached dataset to produce the two inputs to the cogroup. This results in the same exception and the same apparent error in the physical plan. Here the first input's {{Sort}} and {{Exchange}} reference {{id#49L}}, while the {{Filter}} and {{InMemoryTableScan}} below them produce {{id#16L}}:
{code:java}
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#47L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#48]
   +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1526/693211959@7b2e931, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#49L], [id#49L], [id#49L, type#50], [id#49L, type#50], obj#46: org.apache.spark.sql.Row
      :- !Sort [id#49L ASC NULLS FIRST], false, 0
      :  +- !Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=26]
      :     +- Filter (type#17 = foo)
      :        +- InMemoryTableScan [id#16L, type#17], [(type#17 = foo)]
      :              +- InMemoryRelation [id#16L, type#17], StorageLevel(disk, memory, deserialized, 1 replicas)
      :                    +- LocalTableScan [id#16L, type#17]
      +- Sort [id#49L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=27]
            +- Filter (type#50 = bar)
               +- InMemoryTableScan [id#49L, type#50], [(type#50 = bar)]
                     +- InMemoryRelation [id#49L, type#50], StorageLevel(disk, memory, deserialized, 1 replicas)
                           +- LocalTableScan [id#16L, type#17]
{code}
* The issue doesn't arise if I write the same code in PySpark using {{FlatMapCoGroupsInPandas}}.

--
This message was sent by Atlassian Jira
(v8.20.10#820010)
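What follows is a minimal sketch of the workaround described in the first observation. It assumes Spark 3.3.x and the same Java API as the reproducer. Each side of the cogroup gets its own {{createDataFrame()}} call, so the two inputs no longer share one source plan; presumably each call gives its relation fresh attribute IDs. The class name {{CogroupWorkaround}}, the helper methods {{freshInput}} and {{groupedByType}}, and the cogroup function (which simply emits every row from both sides) are hypothetical. Because the two inputs are no longer the same DataFrame, this sketch likely does not carry over to the cached-dataset case, where both filters must read the same cached relation.

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.function.CoGroupFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Hypothetical class illustrating the "call createDataFrame() twice" workaround.
public class CogroupWorkaround {

    static final StructType INPUT_TYPE = DataTypes.createStructType(new StructField[]{
        DataTypes.createStructField("id", DataTypes.LongType, false),
        DataTypes.createStructField("type", DataTypes.StringType, false)
    });

    static final StructType KEY_TYPE = DataTypes.createStructType(new StructField[]{
        DataTypes.createStructField("id", DataTypes.LongType, false)
    });

    // Hypothetical helper: builds a brand-new DataFrame from the same rows on every call,
    // so the two cogroup inputs do not share a source plan.
    static Dataset<Row> freshInput(SparkSession spark) {
        List<Row> rows = new ArrayList<>();
        rows.add(RowFactory.create(1L, "foo"));
        rows.add(RowFactory.create(1L, "bar"));
        rows.add(RowFactory.create(2L, "foo"));
        return spark.createDataFrame(rows, INPUT_TYPE);
    }

    // Hypothetical helper: same filter/groupBy/as chain as the reproducer.
    static KeyValueGroupedDataset<Row, Row> groupedByType(Dataset<Row> input, String type) {
        return input
            .filter("type = '" + type + "'")
            .groupBy("id")
            .as(RowEncoder.apply(KEY_TYPE), RowEncoder.apply(INPUT_TYPE));
    }

    // Runs the cogroup on a local[2] session and returns the number of output rows.
    public static long cogroupedRowCount() {
        SparkSession spark = SparkSession.builder()
            .master("local[2]")
            .appName("cogroup-workaround-sketch")
            .getOrCreate();
        try {
            KeyValueGroupedDataset<Row, Row> fooGroups = groupedByType(freshInput(spark), "foo");
            KeyValueGroupedDataset<Row, Row> barGroups = groupedByType(freshInput(spark), "bar");
            Dataset<Row> result = fooGroups.cogroup(
                barGroups,
                // Emit every row from both sides of each key group.
                (CoGroupFunction<Row, Row, Row, Row>) (key, fooRows, barRows) -> {
                    List<Row> out = new ArrayList<>();
                    fooRows.forEachRemaining(out::add);
                    barRows.forEachRemaining(out::add);
                    return out.iterator();
                },
                RowEncoder.apply(INPUT_TYPE));
            return result.count();
        } finally {
            spark.stop();
        }
    }

    public static void main(String[] args) {
        System.out.println(cogroupedRowCount());
    }
}
```

If the workaround behaves as described in the first observation, {{main}} prints {{3}}. Key 1 contributes one {{foo}} row and one {{bar}} row, and key 2 contributes one {{foo}} row.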