Repository: spark
Updated Branches:
  refs/heads/master 180f969c9 -> 7143e9d72


[SPARK-25829][SQL][FOLLOWUP] Refactor MapConcat in order to properly check the size limit

## What changes were proposed in this pull request?

This PR follows up on a
[comment](https://github.com/apache/spark/pull/23124#discussion_r236112390)
on the main PR and aims at:
 - simplifying the code for `MapConcat`;
 - checking the size limit more precisely (see the sketch after this list).
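
As a rough sketch of the resulting flow (simplified, not the exact Spark source): instead of pre-counting elements across all inputs, `MapConcat` now feeds every map into the shared `ArrayBasedMapBuilder`, and the builder rejects an insertion only when it would push the number of *distinct* keys past `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`. The old check summed `numElements()` over all inputs, which over-counted because duplicate keys collapse during concatenation.

```scala
// Minimal sketch of the builder-side limit check; a simplified stand-in
// for ArrayBasedMapBuilder (see the diff below), not the real class.
import scala.collection.mutable

class SketchMapBuilder(limit: Int) {
  private val keyToIndex = mutable.HashMap.empty[Any, Int]
  private val keys = mutable.ArrayBuffer.empty[Any]
  private val values = mutable.ArrayBuffer.empty[Any]

  def put(key: Any, value: Any): Unit = keyToIndex.get(key) match {
    case Some(i) =>
      // Duplicate key: the last value wins and the map does not grow,
      // so it must not count against the size limit.
      values(i) = value
    case None =>
      if (size >= limit) {
        throw new RuntimeException(s"Unsuccessful attempt to build maps with " +
          s"$size elements due to exceeding the map size limit $limit.")
      }
      keyToIndex.put(key, keys.length)
      keys += key
      values += value
  }

  def size: Int = keys.size
}
```

Since the check now lives in the builder, the interpreted path (`putAll` in `eval`) and the generated code share the same guard and can no longer drift apart.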

## How was this patch tested?

Existing tests.
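
For illustration only (not part of this patch), this is the deduplication behavior the builder-side check relies on; it assumes the builder exposes the `put(key, value)` entry point whose body appears in the hunk below, plus the `size` method this diff adds:

```scala
import org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder
import org.apache.spark.sql.types.IntegerType

val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 10)
builder.put(1, 11) // duplicate key: last value wins, the map does not grow
assert(builder.size == 1) // size counts distinct keys, not raw input elements
```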

Closes #23217 from mgaido91/SPARK-25829_followup.

Authored-by: Marco Gaido <marcogaid...@gmail.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7143e9d7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7143e9d7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7143e9d7

Branch: refs/heads/master
Commit: 7143e9d7220bd98ceb82c5c5f045108a8a664ec1
Parents: 180f969
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Wed Dec 5 09:12:24 2018 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Wed Dec 5 09:12:24 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 77 +-------------------
 .../catalyst/util/ArrayBasedMapBuilder.scala    | 10 +++
 2 files changed, 12 insertions(+), 75 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7143e9d7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index fa8e38a..67f6739 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
       return null
     }
 
-    val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
-    if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-      throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
-        s"elements due to exceeding the map size limit " +
-        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
-    }
-
     for (map <- maps) {
       mapBuilder.putAll(map.keyArray(), map.valueArray())
     }
@@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val mapCodes = children.map(_.genCode(ctx))
-    val keyType = dataType.keyType
-    val valueType = dataType.valueType
     val argsName = ctx.freshName("args")
     val hasNullName = ctx.freshName("hasNull")
     val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
@@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
     )
 
     val idxName = ctx.freshName("idx")
-    val numElementsName = ctx.freshName("numElems")
-    val finKeysName = ctx.freshName("finalKeys")
-    val finValsName = ctx.freshName("finalValues")
-
-    val keyConcat = genCodeForArrays(ctx, keyType, false)
-
-    val valueConcat =
-      if (valueType.sameType(keyType) &&
-          !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
-        keyConcat
-      } else {
-        genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
-      }
-
-    val keyArgsName = ctx.freshName("keyArgs")
-    val valArgsName = ctx.freshName("valArgs")
-
     val mapMerge =
       s"""
-        |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
-        |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
-        |long $numElementsName = 0;
         |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
-        |  $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
-        |  $valArgsName[$idxName] = $argsName[$idxName].valueArray();
-        |  $numElementsName += $argsName[$idxName].numElements();
+        |  $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
         |}
-        |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-        |  throw new RuntimeException("Unsuccessful attempt to concat maps with " +
-        |     $numElementsName + " elements due to exceeding the map size limit " +
-        |     "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
-        |}
-        |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
-        |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
-        |${ev.value} = $builderTerm.from($finKeysName, $finValsName);
+        |${ev.value} = $builderTerm.build();
       """.stripMargin
 
     ev.copy(
@@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
       """.stripMargin)
   }
 
-  private def genCodeForArrays(
-      ctx: CodegenContext,
-      elementType: DataType,
-      checkForNull: Boolean): String = {
-    val counter = ctx.freshName("counter")
-    val arrayData = ctx.freshName("arrayData")
-    val argsName = ctx.freshName("args")
-    val numElemName = ctx.freshName("numElements")
-    val y = ctx.freshName("y")
-    val z = ctx.freshName("z")
-
-    val allocation = CodeGenerator.createArrayData(
-      arrayData, elementType, numElemName, s" $prettyName failed.")
-    val assignment = CodeGenerator.createArrayAssignment(
-      arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)
-
-    val concat = ctx.freshName("concat")
-    val concatDef =
-      s"""
-         |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
-         |  $allocation
-         |  int $counter = 0;
-         |  for (int $y = 0; $y < ${children.length}; $y++) {
-         |    for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
-         |      $assignment
-         |      $counter++;
-         |    }
-         |  }
-         |  return $arrayData;
-         |}
-       """.stripMargin
-
-    ctx.addNewFunction(concat, concatDef)
-  }
-
   override def prettyName: String = "map_concat"
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7143e9d7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index e7cd616..9893436 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
 
 /**
  * A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
@@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
 
     val index = keyToIndex.getOrDefault(key, -1)
     if (index == -1) {
+      if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
+          s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+      }
       keyToIndex.put(key, values.length)
       keys.append(key)
       values.append(value)
@@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       build()
     }
   }
+
+  /**
+   * Returns the current size of the map which is going to be produced by the current builder.
+   */
+  def size: Int = keys.size
 }

