[ https://issues.apache.org/jira/browse/SPARK-44768?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Franck Tago updated SPARK-44768:
--------------------------------

Description:

The code sample below showcases the whole-stage code generation (WSCG) produced when exploding nested arrays. The data sample in the dataframe is quite small, so it will not trigger the out-of-memory error. However, if the arrays are larger and the rows are wide, you will definitely end up with an OOM error.

Consider a scenario where you flatten a nested array. For example, you can use the following steps to create the dataframe.

// create a partClass case class
case class partClass(PARTNAME: String, PartNumber: String, PartPrice: Double)

// create a case class containing nested arrays
case class array_array_class(
  col_int: Int,
  arr_arr_string: Seq[Seq[String]],
  arr_arr_bigint: Seq[Seq[Long]],
  col_string: String,
  parts_s: Seq[Seq[partClass]]
)

// create a dataframe
var df_array_array = sc.parallelize(
  Seq(
    (1, Seq(Seq("a","b","c","d"), Seq("aa","bb","cc","dd")),
        Seq(Seq(1000,20000), Seq(30000,-10000)), "ItemPart1",
        Seq(Seq(partClass("PNAME1","P1",20.75), partClass("PNAME1_1","P1_1",30.75)))),
    (2, Seq(Seq("ab","bc","cd","de"), Seq("aab","bbc","ccd","dde"), Seq("aaaaaabbbbb")),
        Seq(Seq(-1000,-20000,-1,-2), Seq(0,30000,-10000)), "ItemPart2",
        Seq(Seq(partClass("PNAME2","P2",170.75), partClass("PNAME2_1","P2_1",33.75), partClass("PNAME2_2","P2_2",73.75))))
  )
).toDF("c1","c2","c3","c4","c5")

// explode a nested array
var result = df_array_array.select(col("c1"), explode(col("c2"))).select('c1, explode('col))
result.explain
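For reference, the generated whole-stage code itself can be dumped with Spark's debug helpers. This is one quick way to obtain output like the attached spark-jira_wscg_code.txt (a sketch; it assumes the snippet above has already been run in spark-shell):

// dump the generated whole-stage code for the query
import org.apache.spark.sql.execution.debug._
result.debugCodegen()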
The physical plan for this query is shown below.

-------------------------------------
Physical plan

== Physical Plan ==
*(1) Generate explode(col#27), [c1#17], false, [col#30]
+- *(1) Filter ((size(col#27, true) > 0) AND isnotnull(col#27))
   +- *(1) Generate explode(c2#18), [c1#17], false, [col#27]
      +- *(1) Project [_1#6 AS c1#17, _2#7 AS c2#18]
         +- *(1) Filter ((size(_2#7, true) > 0) AND isnotnull(_2#7))
            +- *(1) SerializeFromObject [knownnotnull(assertnotnull(input[0, scala.Tuple5, true]))._1 AS _1#6, mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -2), staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -2), StringType, ObjectType(class java.lang.String)), true, false, true), validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -1), ArrayType(StringType,true), ObjectType(interface scala.collection.Seq)), None), knownnotnull(assertnotnull(input[0, scala.Tuple5, true]))._2, None) AS _2#7, mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -3), mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -4), assertnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -4), IntegerType, IntegerType)), validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -3), ArrayType(IntegerType,false), ObjectType(interface scala.collection.Seq)), None), knownnotnull(assertnotnull(input[0, scala.Tuple5, true]))._3, None) AS _3#8, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(assertnotnull(input[0, scala.Tuple5, true]))._4, true, false, true) AS _4#9, mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -5), mapobjects(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -6), if (isnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -6), StructField(PARTNAME,StringType,true), StructField(PartNumber,StringType,true), StructField(PartPrice,DoubleType,false), ObjectType(class $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$partClass)))) null else named_struct(PARTNAME, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -6), StructField(PARTNAME,StringType,true), StructField(PartNumber,StringType,true), StructField(PartPrice,DoubleType,false), ObjectType(class $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$partClass))).PARTNAME, true, false, true), PartNumber, staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, knownnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -6), StructField(PARTNAME,StringType,true), StructField(PartNumber,StringType,true), StructField(PartPrice,DoubleType,false), ObjectType(class $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$partClass))).PartNumber, true, false, true), PartPrice, knownnotnull(validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -6), StructField(PARTNAME,StringType,true), StructField(PartNumber,StringType,true), StructField(PartPrice,DoubleType,false), ObjectType(class $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$partClass))).PartPrice), validateexternaltype(lambdavariable(MapObject, ObjectType(class java.lang.Object), true, -5), ArrayType(StructType(StructField(PARTNAME,StringType,true),StructField(PartNumber,StringType,true),StructField(PartPrice,DoubleType,false)),true), ObjectType(interface scala.collection.Seq)), None), knownnotnull(assertnotnull(input[0, scala.Tuple5, true]))._5, None) AS _5#10]
               +- Scan[obj#5]
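Note that every operator carries the *(1) prefix: both Generate nodes are fused into a single whole-stage-codegen stage. Conceptually, that stage expands one input row with nested loops and appends every produced row to the stage's row buffer. Below is a simplified hand-written analogue of the fused stage, assuming the two-column projection above; this is not the actual generated code (see the attachment for that):

import scala.collection.mutable.ArrayBuffer

// Hand-written analogue of the fused stage: both explodes become nested
// loops over a single input row, and every produced row is appended to
// the same in-memory buffer, so the buffer grows with the total number
// of inner elements before anything is handed downstream.
def processRow(c1: Int, c2: Seq[Seq[String]], out: ArrayBuffer[(Int, String)]): Unit = {
  for (inner <- c2 if inner != null && inner.nonEmpty) { // Generate explode(c2) + Filter
    for (elem <- inner) {                                // Generate explode(col)
      out += ((c1, elem))                                // buffered, not yet consumed
    }
  }
}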
Because the explode function can create multiple rows from a single row, we should account for the memory available when adding rows to the buffer. This is even more important when we are exploding nested arrays.
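A minimal sketch of the idea follows. The class name, byte budget, and size estimator are hypothetical, and a real fix would draw its budget from the task memory manager and cooperate with the generated shouldStop()/consume path rather than live in a standalone class:

import scala.collection.mutable.ArrayBuffer

// Hypothetical memory-aware row buffer: appends are refused once an
// estimated byte budget is exhausted, forcing the caller to drain the
// buffered rows downstream before generating more, instead of letting
// the buffer grow without bound.
final class BoundedRowBuffer[T](maxBytes: Long)(sizeOf: T => Long) {
  private val rows = ArrayBuffer.empty[T]
  private var usedBytes = 0L

  // Returns false when adding `row` would exceed the budget (and the
  // buffer already holds at least one row); the caller should drain()
  // and then retry the append.
  def tryAppend(row: T): Boolean = {
    val sz = sizeOf(row)
    if (rows.nonEmpty && usedBytes + sz > maxBytes) false
    else { rows += row; usedBytes += sz; true }
  }

  // Hands back all buffered rows and resets the accounting.
  def drain(): Seq[T] = {
    val out = rows.toVector
    rows.clear(); usedBytes = 0L
    out
  }
}

With a guard of this shape in the consume path, a stage that explodes nested arrays would yield control downstream every few megabytes instead of accumulating the full expansion of the arrays in memory.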
> Improve WSCG handling of row buffer by accounting for executor memory.
> Exploding nested arrays can easily lead to out-of-memory errors.
> -------------------------------------------------------------------------------------------------------------------------------------------
>
>                 Key: SPARK-44768
>                 URL: https://issues.apache.org/jira/browse/SPARK-44768
>             Project: Spark
>          Issue Type: Bug
>          Components: Optimizer
>    Affects Versions: 3.3.2, 3.4.0, 3.4.1
>            Reporter: Franck Tago
>            Priority: Major
>         Attachments: image-2023-08-10-20-32-55-684.png, spark-jira_wscg_code.txt
>
--
This message was sent by Atlassian Jira
(v8.20.10#820010)

---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org
For additional commands, e-mail: issues-h...@spark.apache.org