Github user mn-mikke commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22017#discussion_r208527423
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
 ---
    @@ -442,3 +442,191 @@ case class ArrayAggregate(
     
       override def prettyName: String = "aggregate"
     }
    +
    +/**
    + * Merges two given maps into a single map by applying function to the 
pair of values with
    + * the same key.
    + */
    +@ExpressionDescription(
    +  usage =
    +    """
    +      _FUNC_(map1, map2, function) - Merges two given maps into a single 
map by applying
    +      function to the pair of values with the same key. For keys only 
presented in one map,
    +      NULL will be passed as the value for the missing key. If an input 
map contains duplicated
    +      keys, only the first entry of the duplicated key is passed into the 
lambda function.
    +    """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, 
v2) -> concat(v1, v2));
    +       {1:"ax",2:"by"}
    +  """,
    +  since = "2.4.0")
    +case class MapZipWith(left: Expression, right: Expression, function: 
Expression)
    +  extends HigherOrderFunction with CodegenFallback {
    +
    +  @transient lazy val functionForEval: Expression = functionsForEval.head
    +
    +  @transient lazy val (keyType, leftValueType, _) =
    +    HigherOrderFunction.mapKeyValueArgumentType(left.dataType)
    +
    +  @transient lazy val (_, rightValueType, _) =
    +    HigherOrderFunction.mapKeyValueArgumentType(right.dataType)
    +
    +  @transient lazy val ordering = TypeUtils.getInterpretedOrdering(keyType)
    +
    +  override def inputs: Seq[Expression] = left :: right :: Nil
    +
    +  override def functions: Seq[Expression] = function :: Nil
    +
    +  override def nullable: Boolean = left.nullable || right.nullable
    +
    +  override def dataType: DataType = MapType(keyType, function.dataType, 
function.nullable)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    (left.dataType, right.dataType) match {
    +      case (MapType(k1, _, _), MapType(k2, _, _)) if k1.sameType(k2) =>
    +        TypeUtils.checkForOrderingExpr(k1, s"function $prettyName")
    +      case _ => TypeCheckResult.TypeCheckFailure(s"The input to function 
$prettyName should have " +
    +        s"been two ${MapType.simpleString}s with the same key type, but 
it's " +
    +        s"[${left.dataType.catalogString}, 
${right.dataType.catalogString}].")
    +    }
    +  }
    +
    +  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => 
LambdaFunction): MapZipWith = {
    +    val arguments = Seq((keyType, false), (leftValueType, true), 
(rightValueType, true))
    +    copy(function = f(function, arguments))
    +  }
    +
    +  override def eval(input: InternalRow): Any = {
    +    val value1 = left.eval(input)
    +    if (value1 == null) {
    +      null
    +    } else {
    +      val value2 = right.eval(input)
    +      if (value2 == null) {
    +        null
    +      } else {
    +        nullSafeEval(input, value1, value2)
    +      }
    +    }
    +  }
    +
    +  @transient lazy val LambdaFunction(_, Seq(
    +    keyVar: NamedLambdaVariable,
    +    value1Var: NamedLambdaVariable,
    +    value2Var: NamedLambdaVariable),
    +    _) = function
    +
    +  private def keyTypeSupportsEquals = keyType match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +
    +  @transient private lazy val getKeysWithValueIndexes:
    +      (ArrayData, ArrayData) => Seq[(Any, Array[Option[Int]])] = {
    +    if (keyTypeSupportsEquals) {
    +      getKeysWithIndexesFast
    +    } else {
    +      getKeysWithIndexesBruteForce
    +    }
    +  }
    +
    +  private def assertSizeOfArrayBuffer(size: Int): Unit = {
    +    if (size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +      throw new RuntimeException(s"Unsuccessful try to zip maps with $size 
" +
    +        s"unique keys due to exceeding the array size limit " +
    +        s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
    +    }
    +  }
    +
    +  private def getKeysWithIndexesFast(keys1: ArrayData, keys2: ArrayData) = 
{
    +    val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
    +    val hashMap = new mutable.OpenHashMap[Any, Array[Option[Int]]]
    +    val keys = Array(keys1, keys2)
    +    var z = 0
    +    while(z < 2) {
    +      var i = 0
    +      val array = keys(z)
    +      while (i < array.numElements()) {
    +        val key = array.get(i, keyType)
    +        hashMap.get(key) match {
    +          case Some(indexes) =>
    +            if (indexes(z).isEmpty) indexes(z) = Some(i)
    +          case None =>
    +            assertSizeOfArrayBuffer(arrayBuffer.size)
    +            val indexes = Array[Option[Int]](None, None)
    +            indexes(z) = Some(i)
    +            hashMap.put(key, indexes)
    +            arrayBuffer += Tuple2(key, indexes)
    +        }
    +        i += 1
    +      }
    +      z += 1
    +    }
    +    arrayBuffer
    +  }
    +
    +  private def getKeysWithIndexesBruteForce(keys1: ArrayData, keys2: 
ArrayData) = {
    +    val arrayBuffer = new mutable.ArrayBuffer[(Any, Array[Option[Int]])]
    +    val keys = Array(keys1, keys2)
    +    var z = 0
    +    while(z < 2) {
    +      var i = 0
    +      val array = keys(z)
    +      while (i < array.numElements()) {
    +        val key = array.get(i, keyType)
    +        var found = false
    +        var j = 0
    +        while (!found && j < arrayBuffer.size) {
    +          val (bufferKey, indexes) = arrayBuffer(j)
    +          if (ordering.equiv(bufferKey, key)) {
    +            found = true
    +            if(indexes(z).isEmpty) indexes(z) = Some(i)
    +          }
    +          j += 1
    +        }
    +        if (!found) {
    +          assertSizeOfArrayBuffer(arrayBuffer.size)
    --- End diff --
    
    The purpose of this line is to avoid an `OutOfMemoryError` when the maximum
array size is exceeded, and to throw a more accurate exception instead. Maybe I'm
missing something, but wouldn't we break that if we checked the size only once, at
the end? The maximum size could be exceeded in any iteration.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to