[ https://issues.apache.org/jira/browse/SPARK-43781?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Jia Fan updated SPARK-43781:
----------------------------
    Affects Version/s: 3.4.0

> IllegalStateException when cogrouping two datasets derived from the same source
> --------------------------------------------------------------------------------
>
>                 Key: SPARK-43781
>                 URL: https://issues.apache.org/jira/browse/SPARK-43781
>             Project: Spark
>          Issue Type: Bug
>          Components: SQL
>    Affects Versions: 3.3.1, 3.4.0
>         Environment: Reproduces in a unit test, using Spark 3.3.1, the Java API, and a {{local[2]}} SparkSession.
>            Reporter: Derek Murray
>            Priority: Major
>
> Attempting to {{cogroup}} two datasets derived from the same source dataset yields an {{IllegalStateException}} when the query is executed.
> Minimal reproducer:
> {code:java}
> StructType inputType = DataTypes.createStructType(
>     new StructField[]{
>         DataTypes.createStructField("id", DataTypes.LongType, false),
>         DataTypes.createStructField("type", DataTypes.StringType, false)
>     }
> );
> StructType keyType = DataTypes.createStructType(
>     new StructField[]{
>         DataTypes.createStructField("id", DataTypes.LongType, false)
>     }
> );
> List<Row> inputRows = new ArrayList<>();
> inputRows.add(RowFactory.create(1L, "foo"));
> inputRows.add(RowFactory.create(1L, "bar"));
> inputRows.add(RowFactory.create(2L, "foo"));
> Dataset<Row> input = sparkSession.createDataFrame(inputRows, inputType);
> KeyValueGroupedDataset<Row, Row> fooGroups = input
>     .filter("type = 'foo'")
>     .groupBy("id")
>     .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
> KeyValueGroupedDataset<Row, Row> barGroups = input
>     .filter("type = 'bar'")
>     .groupBy("id")
>     .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
> Dataset<Row> result = fooGroups.cogroup(
>     barGroups,
>     (CoGroupFunction<Row, Row, Row, Row>) (row, iterator, iterator1) -> new ArrayList<Row>().iterator(),
>     RowEncoder.apply(inputType));
> result.explain();
> result.show();{code}
> Explain output (note the mismatch in column IDs between the Sort/Exchange operators and the LocalTableScan on the first input to the CoGroup):
> {code:java}
> == Physical Plan ==
> AdaptiveSparkPlan isFinalPlan=false
> +- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#37L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#38]
>    +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1478/1869116781@77856cc5, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#39L], [id#39L], [id#39L, type#40], [id#39L, type#40], obj#36: org.apache.spark.sql.Row
>       :- !Sort [id#39L ASC NULLS FIRST], false, 0
>       :  +- !Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=19]
>       :     +- LocalTableScan [id#16L, type#17]
>       +- Sort [id#39L ASC NULLS FIRST], false, 0
>          +- Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=20]
>             +- LocalTableScan [id#39L, type#40]{code}
> Exception:
> {code:java}
> java.lang.IllegalStateException: Couldn't find id#39L in [id#16L,type#17]
>     at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
>     at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
>     at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
>     at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
>     at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:698)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
>     at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1254)
>     at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1253)
>     at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:608)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
>     at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:528)
>     at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
>     at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
>     at scala.collection.immutable.List.map(List.scala:246)
>     at scala.collection.immutable.List.map(List.scala:79)
>     at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
>     at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:323)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13(ShuffleExchangeExec.scala:391)
>     at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13$adapted(ShuffleExchangeExec.scala:390)
>     at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877)
>     at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877)
>     at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
>     at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
>     at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
>     at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
>     at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
>     at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
>     at org.apache.spark.scheduler.Task.run(Task.scala:136)
>     at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
>     at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
>     at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
>     at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
>     at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
>     at java.lang.Thread.run(Thread.java:748){code}
> Other observations:
> * The same code works if I call {{createDataFrame()}} twice and use the two separate datasets as inputs to the cogroup (see the first sketch after this list).
> * The real code uses two different filters on the same cached dataset as the two inputs to the cogroup (see the second sketch after this list). However, this results in the same exception and the same apparent error in the physical plan, which looks as follows:
> {code:java}
> == Physical Plan ==
> AdaptiveSparkPlan isFinalPlan=false
> +- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#47L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#48]
>    +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1526/693211959@7b2e931, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#49L], [id#49L], [id#49L, type#50], [id#49L, type#50], obj#46: org.apache.spark.sql.Row
>       :- !Sort [id#49L ASC NULLS FIRST], false, 0
>       :  +- !Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=26]
>       :     +- Filter (type#17 = foo)
>       :        +- InMemoryTableScan [id#16L, type#17], [(type#17 = foo)]
>       :              +- InMemoryRelation [id#16L, type#17], StorageLevel(disk, memory, deserialized, 1 replicas)
>       :                    +- LocalTableScan [id#16L, type#17]
>       +- Sort [id#49L ASC NULLS FIRST], false, 0
>          +- Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=27]
>             +- Filter (type#50 = bar)
>                +- InMemoryTableScan [id#49L, type#50], [(type#50 = bar)]
>                      +- InMemoryRelation [id#49L, type#50], StorageLevel(disk, memory, deserialized, 1 replicas)
>                            +- LocalTableScan [id#16L, type#17]{code}
> * The issue doesn't arise if I write the same code in PySpark, using {{FlatMapCoGroupsInPandas}}.
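> For illustration, here is a minimal sketch of the first workaround, assuming the same {{sparkSession}}, {{inputRows}}, {{inputType}}, and {{keyType}} as in the reproducer. The only change is that each side of the cogroup gets its own independently created source dataset:
> {code:java}
> // Workaround sketch (assumes the reproducer's sparkSession, inputRows,
> // inputType, and keyType): create the source DataFrame twice so the two
> // cogroup inputs do not share a logical plan.
> Dataset<Row> fooInput = sparkSession.createDataFrame(inputRows, inputType);
> Dataset<Row> barInput = sparkSession.createDataFrame(inputRows, inputType);
> KeyValueGroupedDataset<Row, Row> fooGroups = fooInput
>     .filter("type = 'foo'")
>     .groupBy("id")
>     .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
> KeyValueGroupedDataset<Row, Row> barGroups = barInput
>     .filter("type = 'bar'")
>     .groupBy("id")
>     .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
> Dataset<Row> result = fooGroups.cogroup(
>     barGroups,
>     (CoGroupFunction<Row, Row, Row, Row>) (key, left, right) -> new ArrayList<Row>().iterator(),
>     RowEncoder.apply(inputType));
> result.show(); // completes without the IllegalStateException
> {code}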
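> And a minimal sketch of the cached-dataset variant from the second observation (same assumptions as above); this fails with the same exception:
> {code:java}
> // Cached-dataset variant sketch (assumes the reproducer's sparkSession,
> // inputRows, inputType, and keyType): both cogroup inputs are filters over
> // one shared, cached source, which reproduces the same failure.
> Dataset<Row> cached = sparkSession.createDataFrame(inputRows, inputType).cache();
> KeyValueGroupedDataset<Row, Row> fooGroups = cached
>     .filter("type = 'foo'")
>     .groupBy("id")
>     .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
> KeyValueGroupedDataset<Row, Row> barGroups = cached
>     .filter("type = 'bar'")
>     .groupBy("id")
>     .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));
> Dataset<Row> result = fooGroups.cogroup(
>     barGroups,
>     (CoGroupFunction<Row, Row, Row, Row>) (key, left, right) -> new ArrayList<Row>().iterator(),
>     RowEncoder.apply(inputType));
> result.show(); // throws the same IllegalStateException
> {code}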