Derek Murray created SPARK-43781:
------------------------------------

             Summary: IllegalStateException when cogrouping two datasets derived from the same source
                 Key: SPARK-43781
                 URL: https://issues.apache.org/jira/browse/SPARK-43781
             Project: Spark
          Issue Type: Bug
          Components: SQL
    Affects Versions: 3.3.1
          Environment: Reproduces in a unit test, using Spark 3.3.1, the Java API, and a {{local[2]}} SparkSession.
            Reporter: Derek Murray


Attempting to {{cogroup}} two datasets derived from the same source dataset 
yields an {{IllegalStateException}} when the query is executed.

Minimal reproducer:
{code:java}
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.function.CoGroupFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

StructType inputType = DataTypes.createStructType(
    new StructField[]{
        DataTypes.createStructField("id", DataTypes.LongType, false),
        DataTypes.createStructField("type", DataTypes.StringType, false)
    }
);

StructType keyType = DataTypes.createStructType(
    new StructField[]{
        DataTypes.createStructField("id", DataTypes.LongType, false)
    }
);

List<Row> inputRows = new ArrayList<>();
inputRows.add(RowFactory.create(1L, "foo"));
inputRows.add(RowFactory.create(1L, "bar"));
inputRows.add(RowFactory.create(2L, "foo"));
Dataset<Row> input = sparkSession.createDataFrame(inputRows, inputType);

KeyValueGroupedDataset<Row, Row> fooGroups = input
    .filter("type = 'foo'")
    .groupBy("id")
    .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

KeyValueGroupedDataset<Row, Row> barGroups = input
    .filter("type = 'bar'")
    .groupBy("id")
    .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

// Cogroup the two groupings. The cogroup function body is irrelevant to the bug,
// so it simply returns an empty iterator.
Dataset<Row> result = fooGroups.cogroup(
    barGroups,
    (CoGroupFunction<Row, Row, Row, Row>) (key, left, right) -> new ArrayList<Row>().iterator(),
    RowEncoder.apply(inputType));

result.explain();
result.show();{code}
Explain output (note the mismatch in column IDs between the Sort/Exchange operators and the LocalTableScan on the first input to the CoGroup):
{code:java}
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#37L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#38]
   +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1478/1869116781@77856cc5, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#39L], [id#39L], [id#39L, type#40], [id#39L, type#40], obj#36: org.apache.spark.sql.Row
      :- !Sort [id#39L ASC NULLS FIRST], false, 0
      :  +- !Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=19]
      :     +- LocalTableScan [id#16L, type#17]
      +- Sort [id#39L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#39L, 2), ENSURE_REQUIREMENTS, [plan_id=20]
            +- LocalTableScan [id#39L, type#40]{code}
Exception:
{code:java}
java.lang.IllegalStateException: Couldn't find id#39L in [id#16L,type#17]
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:80)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:73)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
        at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
        at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:75)
        at scala.collection.immutable.ArraySeq.map(ArraySeq.scala:35)
        at org.apache.spark.sql.catalyst.trees.TreeNode.mapChildren(TreeNode.scala:698)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
        at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$3(TreeNode.scala:589)
        at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren(TreeNode.scala:1254)
        at org.apache.spark.sql.catalyst.trees.BinaryLike.mapChildren$(TreeNode.scala:1253)
        at org.apache.spark.sql.catalyst.expressions.BinaryExpression.mapChildren(Expression.scala:608)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:589)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
        at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:528)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:73)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.$anonfun$bindReferences$1(BoundAttribute.scala:94)
        at scala.collection.immutable.List.map(List.scala:246)
        at scala.collection.immutable.List.map(List.scala:79)
        at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReferences(BoundAttribute.scala:94)
        at org.apache.spark.sql.catalyst.expressions.UnsafeProjection$.create(Projection.scala:160)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.getPartitionKeyExtractor$1(ShuffleExchangeExec.scala:323)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13(ShuffleExchangeExec.scala:391)
        at org.apache.spark.sql.execution.exchange.ShuffleExchangeExec$.$anonfun$prepareShuffleDependency$13$adapted(ShuffleExchangeExec.scala:390)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2(RDD.scala:877)
        at org.apache.spark.rdd.RDD.$anonfun$mapPartitionsWithIndexInternal$2$adapted(RDD.scala:877)
        at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
        at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:365)
        at org.apache.spark.rdd.RDD.iterator(RDD.scala:329)
        at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
        at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
        at org.apache.spark.scheduler.Task.run(Task.scala:136)
        at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
        at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
        at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
        at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
        at java.lang.Thread.run(Thread.java:748) {code}
Other observations:
 * The same code works if I call {{createDataFrame()}} twice and use two separate datasets as the inputs to the cogroup; a minimal sketch of that working variant is shown below.
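For reference, a minimal sketch of that working variant (illustrative only; it reuses the {{inputRows}}, {{inputType}}, {{keyType}}, and no-op cogroup function from the reproducer above, and the {{fooInput}}/{{barInput}} names are just for this example):
{code:java}
// Workaround sketch: build the two cogroup inputs from separate createDataFrame()
// calls instead of deriving both from the same Dataset.
Dataset<Row> fooInput = sparkSession.createDataFrame(inputRows, inputType);
Dataset<Row> barInput = sparkSession.createDataFrame(inputRows, inputType);

KeyValueGroupedDataset<Row, Row> fooGroups = fooInput
    .filter("type = 'foo'")
    .groupBy("id")
    .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

KeyValueGroupedDataset<Row, Row> barGroups = barInput
    .filter("type = 'bar'")
    .groupBy("id")
    .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

// Same cogroup call as in the reproducer; this variant executes without the exception.
Dataset<Row> result = fooGroups.cogroup(
    barGroups,
    (CoGroupFunction<Row, Row, Row, Row>) (key, left, right) -> new ArrayList<Row>().iterator(),
    RowEncoder.apply(inputType));
result.show();{code}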
 * The real code uses two different filters on the same cached dataset as the two inputs to the cogroup. This results in the same exception and the same apparent error in the physical plan, which looks as follows (a reduced sketch of this variant appears after the plan):
{code:java}
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- SerializeFromObject [validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 0, id), LongType, false) AS id#47L, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(getexternalrowfield(assertnotnull(input[0, org.apache.spark.sql.Row, true]), 1, type), StringType, false), true, false, true) AS type#48]
   +- CoGroup org.apache.spark.sql.KeyValueGroupedDataset$$Lambda$1526/693211959@7b2e931, createexternalrow(id#16L, StructField(id,LongType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), createexternalrow(id#16L, type#17.toString, StructField(id,LongType,false), StructField(type,StringType,false)), [id#49L], [id#49L], [id#49L, type#50], [id#49L, type#50], obj#46: org.apache.spark.sql.Row
      :- !Sort [id#49L ASC NULLS FIRST], false, 0
      :  +- !Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=26]
      :     +- Filter (type#17 = foo)
      :        +- InMemoryTableScan [id#16L, type#17], [(type#17 = foo)]
      :              +- InMemoryRelation [id#16L, type#17], StorageLevel(disk, memory, deserialized, 1 replicas)
      :                    +- LocalTableScan [id#16L, type#17]
      +- Sort [id#49L ASC NULLS FIRST], false, 0
         +- Exchange hashpartitioning(id#49L, 2), ENSURE_REQUIREMENTS, [plan_id=27]
            +- Filter (type#50 = bar)
               +- InMemoryTableScan [id#49L, type#50], [(type#50 = bar)]
                     +- InMemoryRelation [id#49L, type#50], StorageLevel(disk, memory, deserialized, 1 replicas)
                           +- LocalTableScan [id#16L, type#17] {code}
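A reduced sketch of that cached-dataset variant, consistent with the plan above (it reuses the {{inputRows}}, {{inputType}}, and {{keyType}} definitions and the same {{cogroup}} call as the reproducer; the real code uses a different schema and cogroup function):
{code:java}
// Cached-dataset variant: both cogroup inputs derive from the same cached Dataset,
// which fails with the same IllegalStateException as the minimal reproducer.
Dataset<Row> input = sparkSession.createDataFrame(inputRows, inputType).cache();

KeyValueGroupedDataset<Row, Row> fooGroups = input
    .filter("type = 'foo'")
    .groupBy("id")
    .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));

KeyValueGroupedDataset<Row, Row> barGroups = input
    .filter("type = 'bar'")
    .groupBy("id")
    .as(RowEncoder.apply(keyType), RowEncoder.apply(inputType));{code}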

 * The issue doesn't arise if I write the same code in PySpark, using {{FlatMapCoGroupsInPandas}}.


