Repository: spark Updated Branches: refs/heads/master dc8a6befa -> 92c2f00bd
[SPARK-23934][SQL] Adding map_from_entries function ## What changes were proposed in this pull request? The PR adds the `map_from_entries` function that returns a map created from the given array of entries. ## How was this patch tested? New tests added into: - `CollectionExpressionSuite` - `DataFrameFunctionSuite` ## CodeGen Examples ### Primitive-type Keys and Values ``` val idf = Seq( Seq((1, 10), (2, 20), (3, 10)), Seq((1, 10), null, (2, 20)) ).toDF("a") idf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_2 = 0; !project_isNull_0 && project_idx_2 < inputadapter_value_0.numElements(); project_idx_2++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_2); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final long project_keySectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 052 */ final long project_valueSectionSize_0 = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(project_numEntries_0, 4); /* 053 */ final long project_byteArraySize_0 = 8 + project_keySectionSize_0 + project_valueSectionSize_0; /* 054 */ if (project_byteArraySize_0 > 2147483632) { /* 055 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 056 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 057 */ /* 058 */ for (int project_idx_1 = 0; project_idx_1 < project_numEntries_0; project_idx_1++) { /* 059 */ InternalRow project_entry_1 = inputadapter_value_0.getStruct(project_idx_1, 2); /* 060 */ /* 061 */ project_keys_0[project_idx_1] = project_entry_1.getInt(0); /* 062 */ project_values_0[project_idx_1] = project_entry_1.getInt(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } else { /* 068 */ final byte[] project_byteArray_0 = new byte[(int)project_byteArraySize_0]; /* 069 */ UnsafeMapData project_unsafeMapData_0 = new UnsafeMapData(); /* 070 */ Platform.putLong(project_byteArray_0, 16, project_keySectionSize_0); /* 071 */ Platform.putLong(project_byteArray_0, 24, project_numEntries_0); /* 072 */ Platform.putLong(project_byteArray_0, 24 + project_keySectionSize_0, project_numEntries_0); /* 073 */ project_unsafeMapData_0.pointTo(project_byteArray_0, 16, (int)project_byteArraySize_0); /* 074 */ ArrayData project_keyArrayData_0 = project_unsafeMapData_0.keyArray(); /* 075 */ ArrayData project_valueArrayData_0 = project_unsafeMapData_0.valueArray(); /* 076 */ /* 077 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 078 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 079 */ /* 080 */ project_keyArrayData_0.setInt(project_idx_0, project_entry_0.getInt(0)); /* 081 */ project_valueArrayData_0.setInt(project_idx_0, project_entry_0.getInt(1)); /* 082 */ } /* 083 */ /* 084 */ project_value_0 = project_unsafeMapData_0; /* 085 */ } /* 086 */ /* 087 */ } ``` ### Non-primitive-type Keys and Values ``` val sdf = Seq( Seq(("a", null), ("b", "bb"), ("c", "aa")), Seq(("a", "aa"), null, (null, "bb")) ).toDF("a") sdf.filter('a.isNotNull).select(map_from_entries('a)).debugCodegen ``` Result: ``` /* 042 */ boolean project_isNull_0 = false; /* 043 */ MapData project_value_0 = null; /* 044 */ /* 045 */ for (int project_idx_1 = 0; !project_isNull_0 && project_idx_1 < inputadapter_value_0.numElements(); project_idx_1++) { /* 046 */ project_isNull_0 |= inputadapter_value_0.isNullAt(project_idx_1); /* 047 */ } /* 048 */ if (!project_isNull_0) { /* 049 */ final int project_numEntries_0 = inputadapter_value_0.numElements(); /* 050 */ /* 051 */ final Object[] project_keys_0 = new Object[project_numEntries_0]; /* 052 */ final Object[] project_values_0 = new Object[project_numEntries_0]; /* 053 */ /* 054 */ for (int project_idx_0 = 0; project_idx_0 < project_numEntries_0; project_idx_0++) { /* 055 */ InternalRow project_entry_0 = inputadapter_value_0.getStruct(project_idx_0, 2); /* 056 */ /* 057 */ if (project_entry_0.isNullAt(0)) { /* 058 */ throw new RuntimeException("The first field from a struct (key) can't be null."); /* 059 */ } /* 060 */ /* 061 */ project_keys_0[project_idx_0] = project_entry_0.getUTF8String(0); /* 062 */ project_values_0[project_idx_0] = project_entry_0.getUTF8String(1); /* 063 */ } /* 064 */ /* 065 */ project_value_0 = org.apache.spark.sql.catalyst.util.ArrayBasedMapData.apply(project_keys_0, project_values_0); /* 066 */ /* 067 */ } ``` Author: Marek Novotny <mn.mi...@gmail.com> Closes #21282 from mn-mikke/feature/array-api-map_from_entries-to-master. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/92c2f00b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/92c2f00b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/92c2f00b Branch: refs/heads/master Commit: 92c2f00bd275a90b6912fb8c8cf542002923629c Parents: dc8a6be Author: Marek Novotny <mn.mi...@gmail.com> Authored: Fri Jun 22 16:18:22 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Fri Jun 22 16:18:22 2018 +0900 ---------------------------------------------------------------------- python/pyspark/sql/functions.py | 20 ++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/codegen/CodeGenerator.scala | 30 +++ .../expressions/collectionOperations.scala | 235 +++++++++++++++++-- .../CollectionExpressionsSuite.scala | 51 ++++ .../scala/org/apache/spark/sql/functions.scala | 7 + .../spark/sql/DataFrameFunctionsSuite.scala | 50 ++++ 7 files changed, 378 insertions(+), 16 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/python/pyspark/sql/functions.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 11b179f..5f5d733 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2412,6 +2412,26 @@ def map_entries(col): return Column(sc._jvm.functions.map_entries(_to_java_column(col))) +@since(2.4) +def map_from_entries(col): + """ + Collection function: Returns a map created from the given array of entries. + + :param col: name of column or expression + + >>> from pyspark.sql.functions import map_from_entries + >>> df = spark.sql("SELECT array(struct(1, 'a'), struct(2, 'b')) as data") + >>> df.select(map_from_entries("data").alias("map")).show() + +----------------+ + | map| + +----------------+ + |[1 -> a, 2 -> b]| + +----------------+ + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.map_from_entries(_to_java_column(col))) + + @ignore_unicode_prefix @since(2.4) def array_repeat(col, count): http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 4b09b9a..8abc616 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -421,6 +421,7 @@ object FunctionRegistry { expression[MapKeys]("map_keys"), expression[MapValues]("map_values"), expression[MapEntries]("map_entries"), + expression[MapFromEntries]("map_from_entries"), expression[Size]("size"), expression[Slice]("slice"), expression[Size]("cardinality"), http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 66315e5..4cc0968 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -820,6 +820,36 @@ class CodegenContext { } /** + * Generates code to do null safe execution when accessing properties of complex + * ArrayData elements. + * + * @param nullElements used to decide whether the ArrayData might contain null or not. + * @param isNull a variable indicating whether the result will be evaluated to null or not. + * @param arrayData a variable name representing the ArrayData. + * @param execute the code that should be executed only if the ArrayData doesn't contain + * any null. + */ + def nullArrayElementsSaveExec( + nullElements: Boolean, + isNull: String, + arrayData: String)( + execute: String): String = { + val i = freshName("idx") + if (nullElements) { + s""" + |for (int $i = 0; !$isNull && $i < $arrayData.numElements(); $i++) { + | $isNull |= $arrayData.isNullAt($i); + |} + |if (!$isNull) { + | $execute + |} + """.stripMargin + } else { + execute + } + } + + /** * Splits the generated code of expressions into multiple functions, because function has * 64kb code size limit in JVM. If the class to which the function would be inlined would grow * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7c064a1..3afabe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -476,6 +476,223 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp } /** + * Returns a map created from the given array of entries. + */ +@ExpressionDescription( + usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.", + examples = """ + Examples: + > SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b'))); + {1:"a",2:"b"} + """, + since = "2.4.0") +case class MapFromEntries(child: Expression) extends UnaryExpression { + + @transient + private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match { + case ArrayType( + StructType(Array( + StructField(_, keyType, keyNullable, _), + StructField(_, valueType, valueNullable, _))), + containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull)) + case _ => None + } + + private def nullEntries: Boolean = dataTypeDetails.get._3 + + override def nullable: Boolean = child.nullable || nullEntries + + override def dataType: MapType = dataTypeDetails.get._1 + + override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match { + case Some(_) => TypeCheckResult.TypeCheckSuccess + case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " + + s"${child.dataType.simpleString} type. $prettyName accepts only arrays of pair structs.") + } + + override protected def nullSafeEval(input: Any): Any = { + val arrayData = input.asInstanceOf[ArrayData] + val numEntries = arrayData.numElements() + var i = 0 + if(nullEntries) { + while (i < numEntries) { + if (arrayData.isNullAt(i)) return null + i += 1 + } + } + val keyArray = new Array[AnyRef](numEntries) + val valueArray = new Array[AnyRef](numEntries) + i = 0 + while (i < numEntries) { + val entry = arrayData.getStruct(i, 2) + val key = entry.get(0, dataType.keyType) + if (key == null) { + throw new RuntimeException("The first field from a struct (key) can't be null.") + } + keyArray.update(i, key) + val value = entry.get(1, dataType.valueType) + valueArray.update(i, value) + i += 1 + } + ArrayBasedMapData(keyArray, valueArray) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => { + val numEntries = ctx.freshName("numEntries") + val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val code = if (isKeyPrimitive && isValuePrimitive) { + genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) + } else { + genCodeForAnyElements(ctx, c, ev.value, numEntries) + } + ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { + s""" + |final int $numEntries = $c.numElements(); + |$code + """.stripMargin + } + }) + } + + private def genCodeForAssignmentLoop( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String, + keyAssignment: (String, String) => String, + valueAssignment: (String, String) => String): String = { + val entry = ctx.freshName("entry") + val i = ctx.freshName("idx") + + val nullKeyCheck = if (dataTypeDetails.get._2) { + s""" + |if ($entry.isNullAt(0)) { + | throw new RuntimeException("The first field from a struct (key) can't be null."); + |} + """.stripMargin + } else { + "" + } + + s""" + |for (int $i = 0; $i < $numEntries; $i++) { + | InternalRow $entry = $childVariable.getStruct($i, 2); + | $nullKeyCheck + | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} + | ${valueAssignment(entry, i)} + |} + """.stripMargin + } + + private def genCodeForPrimitiveElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val byteArraySize = ctx.freshName("byteArraySize") + val keySectionSize = ctx.freshName("keySectionSize") + val valueSectionSize = ctx.freshName("valueSectionSize") + val data = ctx.freshName("byteArray") + val unsafeMapData = ctx.freshName("unsafeMapData") + val keyArrayData = ctx.freshName("keyArrayData") + val valueArrayData = ctx.freshName("valueArrayData") + + val baseOffset = Platform.BYTE_ARRAY_OFFSET + val keySize = dataType.keyType.defaultSize + val valueSize = dataType.valueType.defaultSize + val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" + val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" + val keyTypeName = CodeGenerator.primitiveTypeName(dataType.keyType) + val valueTypeName = CodeGenerator.primitiveTypeName(dataType.valueType) + + val keyAssignment = (key: String, idx: String) => s"$keyArrayData.set$keyTypeName($idx, $key);" + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + val valueNullUnsafeAssignment = s"$valueArrayData.set$valueTypeName($idx, $value);" + if (dataType.valueContainsNull) { + s""" + |if ($entry.isNullAt(1)) { + | $valueArrayData.setNullAt($idx); + |} else { + | $valueNullUnsafeAssignment + |} + """.stripMargin + } else { + valueNullUnsafeAssignment + } + } + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment + ) + + s""" + |final long $keySectionSize = $kByteSize; + |final long $valueSectionSize = $vByteSize; + |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; + |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} + |} else { + | final byte[] $data = new byte[(int)$byteArraySize]; + | UnsafeMapData $unsafeMapData = new UnsafeMapData(); + | Platform.putLong($data, $baseOffset, $keySectionSize); + | Platform.putLong($data, ${baseOffset + 8}, $numEntries); + | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); + | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); + | ArrayData $keyArrayData = $unsafeMapData.keyArray(); + | ArrayData $valueArrayData = $unsafeMapData.valueArray(); + | $assignmentLoop + | $mapData = $unsafeMapData; + |} + """.stripMargin + } + + private def genCodeForAnyElements( + ctx: CodegenContext, + childVariable: String, + mapData: String, + numEntries: String): String = { + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val mapDataClass = classOf[ArrayBasedMapData].getName() + + val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) + val valueAssignment = (entry: String, idx: String) => { + val value = CodeGenerator.getValue(entry, dataType.valueType, "1") + if (dataType.valueContainsNull && isValuePrimitive) { + s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;" + } else { + s"$values[$idx] = $value;" + } + } + val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" + val assignmentLoop = genCodeForAssignmentLoop( + ctx, + childVariable, + mapData, + numEntries, + keyAssignment, + valueAssignment) + + s""" + |final Object[] $keys = new Object[$numEntries]; + |final Object[] $values = new Object[$numEntries]; + |$assignmentLoop + |$mapData = $mapDataClass.apply($keys, $values); + """.stripMargin + } + + override def prettyName: String = "map_from_entries" +} + + +/** * Common base class for [[SortArray]] and [[ArraySort]]. */ trait ArraySortLike extends ExpectsInputTypes { @@ -1990,24 +2207,10 @@ case class Flatten(child: Expression) extends UnaryExpression { } else { genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value) } - if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code + ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code) }) } - private def nullElementsProtection( - ev: ExprCode, - childVariableName: String, - coreLogic: String): String = { - s""" - |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) { - | ${ev.isNull} |= $childVariableName.isNullAt(z); - |} - |if (!${ev.isNull}) { - | $coreLogic - |} - """.stripMargin - } - private def genCodeForNumberOfElements( ctx: CodegenContext, childVariableName: String) : (String, String) = { http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f377f9c..5b8cf51 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -80,6 +80,57 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapEntries(ms2), null) } + test("MapFromEntries") { + def arrayType(keyType: DataType, valueType: DataType) : DataType = { + ArrayType( + StructType(Seq( + StructField("a", keyType), + StructField("b", valueType))), + true) + } + def r(values: Any*): InternalRow = create_row(values: _*) + + // Primitive-type keys and values + val aiType = arrayType(IntegerType, IntegerType) + val ai0 = Literal.create(Seq(r(1, 10), r(2, 20), r(3, 20)), aiType) + val ai1 = Literal.create(Seq(r(1, null), r(2, 20), r(3, null)), aiType) + val ai2 = Literal.create(Seq.empty, aiType) + val ai3 = Literal.create(null, aiType) + val ai4 = Literal.create(Seq(r(1, 10), r(1, 20)), aiType) + val ai5 = Literal.create(Seq(r(1, 10), r(null, 20)), aiType) + val ai6 = Literal.create(Seq(null, r(2, 20), null), aiType) + + checkEvaluation(MapFromEntries(ai0), Map(1 -> 10, 2 -> 20, 3 -> 20)) + checkEvaluation(MapFromEntries(ai1), Map(1 -> null, 2 -> 20, 3 -> null)) + checkEvaluation(MapFromEntries(ai2), Map.empty) + checkEvaluation(MapFromEntries(ai3), null) + checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(ai5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(ai6), null) + + // Non-primitive-type keys and values + val asType = arrayType(StringType, StringType) + val as0 = Literal.create(Seq(r("a", "aa"), r("b", "bb"), r("c", "bb")), asType) + val as1 = Literal.create(Seq(r("a", null), r("b", "bb"), r("c", null)), asType) + val as2 = Literal.create(Seq.empty, asType) + val as3 = Literal.create(null, asType) + val as4 = Literal.create(Seq(r("a", "aa"), r("a", "bb")), asType) + val as5 = Literal.create(Seq(r("a", "aa"), r(null, "bb")), asType) + val as6 = Literal.create(Seq(null, r("b", "bb"), null), asType) + + checkEvaluation(MapFromEntries(as0), Map("a" -> "aa", "b" -> "bb", "c" -> "bb")) + checkEvaluation(MapFromEntries(as1), Map("a" -> null, "b" -> "bb", "c" -> null)) + checkEvaluation(MapFromEntries(as2), Map.empty) + checkEvaluation(MapFromEntries(as3), null) + checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) + checkExceptionInExpression[RuntimeException]( + MapFromEntries(as5), + "The first field from a struct (key) can't be null.") + checkEvaluation(MapFromEntries(as6), null) + } + test("Sort Array") { val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType)) val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index c296a1b..f792a6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3527,6 +3527,13 @@ object functions { def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) } /** + * Returns a map created from the given array of entries. + * @group collection_funcs + * @since 2.4.0 + */ + def map_from_entries(e: Column): Column = withExpr { MapFromEntries(e.expr) } + + /** * Returns a merged array of structs in which the N-th struct contains all N-th values of input * arrays. * @group collection_funcs http://git-wip-us.apache.org/repos/asf/spark/blob/92c2f00b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index fcdd33f..25fdbab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -633,6 +633,56 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(sdf.filter(dummyFilter('m)).select(map_entries('m)), sExpected) } + test("map_from_entries function") { + def dummyFilter(c: Column): Column = c.isNull || c.isNotNull + val oneRowDF = Seq(3215).toDF("i") + + // Test cases with primitive-type keys and values + val idf = Seq( + Seq((1, 10), (2, 20), (3, 10)), + Seq((1, 10), null, (2, 20)), + Seq.empty, + null + ).toDF("a") + val iExpected = Seq( + Row(Map(1 -> 10, 2 -> 20, 3 -> 10)), + Row(null), + Row(Map.empty), + Row(null)) + + checkAnswer(idf.select(map_from_entries('a)), iExpected) + checkAnswer(idf.selectExpr("map_from_entries(a)"), iExpected) + checkAnswer(idf.filter(dummyFilter('a)).select(map_from_entries('a)), iExpected) + checkAnswer( + oneRowDF.selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + checkAnswer( + oneRowDF.filter(dummyFilter('i)) + .selectExpr("map_from_entries(array(struct(1, null), struct(2, null)))"), + Seq(Row(Map(1 -> null, 2 -> null))) + ) + + // Test cases with non-primitive-type keys and values + val sdf = Seq( + Seq(("a", "aa"), ("b", "bb"), ("c", "aa")), + Seq(("a", "aa"), null, ("b", "bb")), + Seq(("a", null), ("b", null)), + Seq.empty, + null + ).toDF("a") + val sExpected = Seq( + Row(Map("a" -> "aa", "b" -> "bb", "c" -> "aa")), + Row(null), + Row(Map("a" -> null, "b" -> null)), + Row(Map.empty), + Row(null)) + + checkAnswer(sdf.select(map_from_entries('a)), sExpected) + checkAnswer(sdf.selectExpr("map_from_entries(a)"), sExpected) + checkAnswer(sdf.filter(dummyFilter('a)).select(map_from_entries('a)), sExpected) + } + test("array contains function") { val df = Seq( (Seq[Int](1, 2), "x", 1), --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org