Github user mn-mikke commented on a diff in the pull request: https://github.com/apache/spark/pull/22013#discussion_r208167785 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala --- @@ -365,3 +365,69 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Transform Keys in a map using the transform_keys function. + */ +@ExpressionDescription( + usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.", + examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k,v) -> k + 1); + map(array(2, 3, 4), array(1, 2, 3)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); + map(array(2, 4, 6), array(1, 2, 3)) + """, + since = "2.4.0") +case class TransformKeys( + input: Expression, + function: Expression) + extends ArrayBasedHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: DataType = { + val valueType = input.dataType.asInstanceOf[MapType].valueType + MapType(function.dataType, valueType, input.nullable) + } + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): + TransformKeys = { + val (keyElementType, valueElementType, containsNull) = input.dataType match { + case MapType(keyType, valueType, containsNullValue) => + (keyType, valueType, containsNullValue) + case _ => + val MapType(keyType, valueType, containsNullValue) = MapType.defaultConcreteType + (keyType, valueType, containsNullValue) + } + copy(function = f(function, (keyElementType, false) :: (valueElementType, containsNull) :: Nil)) + } + + @transient lazy val (keyVar, valueVar) = { + val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + (keyVar, valueVar) + } + + override def eval(input: InternalRow): Any = { + 
val arr = this.input.eval(input).asInstanceOf[MapData] + if (arr == null) { + null + } else { + val f = functionForEval + val resultKeys = new GenericArrayData(new Array[Any](arr.numElements)) + var i = 0 + while (i < arr.numElements) { + keyVar.value.set(arr.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(arr.valueArray().get(i, valueVar.dataType)) + resultKeys.update(i, f.eval(input)) --- End diff -- Maybe I'm missing something, but couldn't ```f.eval(input)``` be evaluated to ```null```? Keys are not allowed to be ```null```. Other functions usually have a ```null``` check and throw ```RuntimeException``` for such cases.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org