Repository: spark
Updated Branches:
  refs/heads/master d6c26d1c9 -> 5fea17b3b


[SPARK-23821][SQL] Collection function: flatten

## What changes were proposed in this pull request?

This PR adds a new collection function that transforms an array of arrays into
a single array (a brief usage sketch follows the list). The PR comprises:
- An expression for flattening the array structure
- The `flatten` function
- A wrapper for PySpark
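
For illustration, a minimal usage sketch of the new Scala API (the output shown is inferred from the semantics above, not captured from a real shell; column header approximate):
```
import org.apache.spark.sql.functions.flatten

val df = Seq(
  Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)),
  Seq(null, Seq(1))
).toDF("i")

// A null element at the outer level nulls out the whole result row.
df.select(flatten($"i")).show()
// +------------------+
// |        flatten(i)|
// +------------------+
// |[1, 2, 3, 4, 5, 6]|
// |              null|
// +------------------+
```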

## How was this patch tested?

New tests were added to:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite

## Codegen examples
### Primitive type
```
val df = Seq(
  Seq(Seq(1, 2), Seq(4, 5)),
  Seq(null, Seq(1))
).toDF("i")
df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         boolean filter_value = true;
/* 038 */
/* 039 */         if (!(!inputadapter_isNull)) {
/* 040 */           filter_value = inputadapter_isNull;
/* 041 */         }
/* 042 */         if (!filter_value) continue;
/* 043 */
/* 044 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 045 */
/* 046 */         boolean project_isNull = inputadapter_isNull;
/* 047 */         ArrayData project_value = null;
/* 048 */
/* 049 */         if (!inputadapter_isNull) {
/* 050 */           for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) {
/* 051 */             project_isNull |= inputadapter_value.isNullAt(z);
/* 052 */           }
/* 053 */           if (!project_isNull) {
/* 054 */             long project_numElements = 0;
/* 055 */             for (int z = 0; z < inputadapter_value.numElements(); z++) {
/* 056 */               project_numElements += inputadapter_value.getArray(z).numElements();
/* 057 */             }
/* 058 */             if (project_numElements > 2147483632) {
/* 059 */               throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 060 */                 project_numElements + " elements due to exceeding the array size limit 2147483632.");
/* 061 */             }
/* 062 */
/* 063 */             long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 064 */               project_numElements,
/* 065 */               4);
/* 066 */             if (project_size > 2147483632) {
/* 067 */               throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 068 */                 project_size + " bytes of data due to exceeding the limit 2147483632" +
/* 069 */                 " bytes for UnsafeArrayData.");
/* 070 */             }
/* 071 */
/* 072 */             byte[] project_array = new byte[(int)project_size];
/* 073 */             UnsafeArrayData project_tempArrayData = new UnsafeArrayData();
/* 074 */             Platform.putLong(project_array, 16, project_numElements);
/* 075 */             project_tempArrayData.pointTo(project_array, 16, (int)project_size);
/* 076 */             int project_counter = 0;
/* 077 */             for (int k = 0; k < inputadapter_value.numElements(); k++) {
/* 078 */               ArrayData arr = inputadapter_value.getArray(k);
/* 079 */               for (int l = 0; l < arr.numElements(); l++) {
/* 080 */                 if (arr.isNullAt(l)) {
/* 081 */                   project_tempArrayData.setNullAt(project_counter);
/* 082 */                 } else {
/* 083 */                   project_tempArrayData.setInt(
/* 084 */                     project_counter,
/* 085 */                     arr.getInt(l)
/* 086 */                   );
/* 087 */                 }
/* 088 */                 project_counter++;
/* 089 */               }
/* 090 */             }
/* 091 */             project_value = project_tempArrayData;
/* 092 */
/* 093 */           }
/* 094 */
/* 095 */         }
```
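
The seemingly redundant `isNotNull || isNull` filter above only forces whole-stage code generation so that `debugCodegen` prints the generated code. Note the null-element protection loop at generated lines 050-052: a null outer element makes the whole result null. A behavioural sketch, with the expected value taken from the DataFrameFunctionsSuite cases below:
```
Seq(Seq(null, Seq(1))).toDF("i").select(flatten($"i")).collect()
// => Array(Row(null)) -- the flattened value is null, not Seq(1)
```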
### Non-primitive type
```
val df = Seq(
  Seq(Seq("a", "b"), Seq(null, "d")),
  Seq(null, Seq("a"))
).toDF("s")
df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         boolean filter_value = true;
/* 038 */
/* 039 */         if (!(!inputadapter_isNull)) {
/* 040 */           filter_value = inputadapter_isNull;
/* 041 */         }
/* 042 */         if (!filter_value) continue;
/* 043 */
/* 044 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 045 */
/* 046 */         boolean project_isNull = inputadapter_isNull;
/* 047 */         ArrayData project_value = null;
/* 048 */
/* 049 */         if (!inputadapter_isNull) {
/* 050 */           for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) {
/* 051 */             project_isNull |= inputadapter_value.isNullAt(z);
/* 052 */           }
/* 053 */           if (!project_isNull) {
/* 054 */             long project_numElements = 0;
/* 055 */             for (int z = 0; z < inputadapter_value.numElements(); z++) {
/* 056 */               project_numElements += inputadapter_value.getArray(z).numElements();
/* 057 */             }
/* 058 */             if (project_numElements > 2147483632) {
/* 059 */               throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 060 */                 project_numElements + " elements due to exceeding the array size limit 2147483632.");
/* 061 */             }
/* 062 */
/* 063 */             Object[] project_arrayObject = new Object[(int)project_numElements];
/* 064 */             int project_counter = 0;
/* 065 */             for (int k = 0; k < inputadapter_value.numElements(); k++) {
/* 066 */               ArrayData arr = inputadapter_value.getArray(k);
/* 067 */               for (int l = 0; l < arr.numElements(); l++) {
/* 068 */                 project_arrayObject[project_counter] = arr.getUTF8String(l);
/* 069 */                 project_counter++;
/* 070 */               }
/* 071 */             }
/* 072 */             project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject);
/* 073 */
/* 074 */           }
/* 075 */
/* 076 */         }
```
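
Unlike the primitive path, the non-primitive path copies element references into an `Object[]` and wraps them in a `GenericArrayData` instead of sizing and filling an `UnsafeArrayData` buffer. Note also that only one level of nesting is removed; a sketch mirroring the `asm3` test case below:
```
val nested = Seq(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e")))).toDF("n")
nested.select(flatten($"n")).collect()
// => one row containing Seq(Seq("a", "b"), Seq("c"), Seq("d", "e")) -- inner arrays intact
```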

Author: mn-mikke <mrkAha12346github>

Closes #20938 from mn-mikke/feature/array-api-flatten-to-master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5fea17b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5fea17b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5fea17b3

Branch: refs/heads/master
Commit: 5fea17b3befc50aef59b799711d03b9552f21b19
Parents: d6c26d1
Author: mn-mikke <mrkAha12346github>
Authored: Wed Apr 25 11:19:08 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Wed Apr 25 11:19:08 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  17 ++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/collectionOperations.scala      | 176 +++++++++++++++++++
 .../CollectionExpressionsSuite.scala            |  95 ++++++++++
 .../scala/org/apache/spark/sql/functions.scala  |   8 +
 .../spark/sql/DataFrameFunctionsSuite.scala     |  79 +++++++++
 6 files changed, 376 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5fea17b3/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index da32ab2..de53b48 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2191,6 +2191,23 @@ def reverse(col):
     return Column(sc._jvm.functions.reverse(_to_java_column(col)))
 
 
+@since(2.4)
+def flatten(col):
+    """
+    Collection function: creates a single array from an array of arrays.
+    If a structure of nested arrays is deeper than two levels,
+    only one level of nesting is removed.
+
+    :param col: name of column or expression
+
+    >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
+    >>> df.select(flatten(df.data).alias('r')).collect()
+    [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.flatten(_to_java_column(col)))
+
+
 @since(2.3)
 def map_keys(col):
     """

http://git-wip-us.apache.org/repos/asf/spark/blob/5fea17b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c41f16c..6afcf30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -413,6 +413,7 @@ object FunctionRegistry {
     expression[ArrayMax]("array_max"),
     expression[Reverse]("reverse"),
     expression[Concat]("concat"),
+    expression[Flatten]("flatten"),
     CreateStruct.registryEntry,
 
     // misc functions

http://git-wip-us.apache.org/repos/asf/spark/blob/5fea17b3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c16793b..bc71b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression {
 
   override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
 }
+
+/**
+ * Transforms an array of arrays into a single array.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(array(1, 2), array(3, 4)));
+       [1,2,3,4]
+  """,
+  since = "2.4.0")
+case class Flatten(child: Expression) extends UnaryExpression {
+
+  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+  private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]
+
+  override def nullable: Boolean = child.nullable || childDataType.containsNull
+
+  override def dataType: DataType = childDataType.elementType
+
+  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+
+  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+    case ArrayType(_: ArrayType, _) =>
+      TypeCheckResult.TypeCheckSuccess
+    case _ =>
+      TypeCheckResult.TypeCheckFailure(
+        s"The argument should be an array of arrays, " +
+        s"but '${child.sql}' is of ${child.dataType.simpleString} type."
+      )
+  }
+
+  override def nullSafeEval(child: Any): Any = {
+    val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType)
+
+    if (elements.contains(null)) {
+      null
+    } else {
+      val arrayData = elements.map(_.asInstanceOf[ArrayData])
+      val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
+      if (numberOfElements > MAX_ARRAY_LENGTH) {
+        throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
+          s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+      }
+      }
+      val flattenedData = new Array(numberOfElements.toInt)
+      var position = 0
+      for (ad <- arrayData) {
+        val arr = ad.toObjectArray(elementType)
+        Array.copy(arr, 0, flattenedData, position, arr.length)
+        position += arr.length
+      }
+      new GenericArrayData(flattenedData)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => {
+      val code = if (CodeGenerator.isPrimitiveType(elementType)) {
+        genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value)
+      } else {
+        genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
+      }
+      if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
+    })
+  }
+
+  private def nullElementsProtection(
+      ev: ExprCode,
+      childVariableName: String,
+      coreLogic: String): String = {
+    s"""
+    |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
+    |  ${ev.isNull} |= $childVariableName.isNullAt(z);
+    |}
+    |if (!${ev.isNull}) {
+    |  $coreLogic
+    |}
+    """.stripMargin
+  }
+
+  private def genCodeForNumberOfElements(
+      ctx: CodegenContext,
+      childVariableName: String) : (String, String) = {
+    val variableName = ctx.freshName("numElements")
+    val code = s"""
+      |long $variableName = 0;
+      |for (int z = 0; z < $childVariableName.numElements(); z++) {
+      |  $variableName += $childVariableName.getArray(z).numElements();
+      |}
+      |if ($variableName > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
+      |    $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+      |}
+      """.stripMargin
+    (code, variableName)
+  }
+
+  private def genCodeForFlattenOfPrimitiveElements(
+      ctx: CodegenContext,
+      childVariableName: String,
+      arrayDataName: String): String = {
+    val arrayName = ctx.freshName("array")
+    val arraySizeName = ctx.freshName("size")
+    val counter = ctx.freshName("counter")
+    val tempArrayDataName = ctx.freshName("tempArrayData")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
+
+    val unsafeArraySizeInBytes = s"""
+      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+      |  $numElemName,
+      |  ${elementType.defaultSize});
+      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
+      |    $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
+      |    " bytes for UnsafeArrayData.");
+      |}
+      """.stripMargin
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+    s"""
+    |$numElemCode
+    |$unsafeArraySizeInBytes
+    |byte[] $arrayName = new byte[(int)$arraySizeName];
+    |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
+    |Platform.putLong($arrayName, $baseOffset, $numElemName);
+    |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+    |int $counter = 0;
+    |for (int k = 0; k < $childVariableName.numElements(); k++) {
+    |  ArrayData arr = $childVariableName.getArray(k);
+    |  for (int l = 0; l < arr.numElements(); l++) {
+    |   if (arr.isNullAt(l)) {
+    |     $tempArrayDataName.setNullAt($counter);
+    |   } else {
+    |     $tempArrayDataName.set$primitiveValueTypeName(
+    |       $counter,
+    |       ${CodeGenerator.getValue("arr", elementType, "l")}
+    |     );
+    |   }
+    |   $counter++;
+    | }
+    |}
+    |$arrayDataName = $tempArrayDataName;
+    """.stripMargin
+  }
+
+  private def genCodeForFlattenOfNonPrimitiveElements(
+      ctx: CodegenContext,
+      childVariableName: String,
+      arrayDataName: String): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayName = ctx.freshName("arrayObject")
+    val counter = ctx.freshName("counter")
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)
+
+    s"""
+    |$numElemCode
+    |Object[] $arrayName = new Object[(int)$numElemName];
+    |int $counter = 0;
+    |for (int k = 0; k < $childVariableName.numElements(); k++) {
+    |  ArrayData arr = $childVariableName.getArray(k);
+    |  for (int l = 0; l < arr.numElements(); l++) {
+    |    $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")};
+    |    $counter++;
+    |  }
+    |}
+    |$arrayDataName = new $genericArrayClass($arrayName);
+    """.stripMargin
+  }
+
+  override def prettyName: String = "flatten"
+}
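
Both codegen paths share the element-count guard emitted by `genCodeForNumberOfElements`. The constant 2147483632 seen in the codegen examples above is `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, i.e. `Int.MaxValue - 15`. A standalone sketch of the guard's logic, with the constant hard-coded for illustration (not Spark source):
```
// Sketch only: mirrors the generated guard, not the Spark implementation.
val MAX_ARRAY_LENGTH = 2147483632L // ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
def checkedElementCount(counts: Seq[Int]): Int = {
  // Accumulate as Long so the sum cannot overflow before the check.
  val total = counts.foldLeft(0L)(_ + _.toLong)
  if (total > MAX_ARRAY_LENGTH) {
    throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      s"$total elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
  }
  total.toInt
}
```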

http://git-wip-us.apache.org/repos/asf/spark/blob/5fea17b3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 43c5dda..b49fa76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
 
    checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
   }
+
+  test("Flatten") {
+    // Primitive-type test cases
+    val intArrayType = ArrayType(ArrayType(IntegerType))
+
+    // Main test cases (primitive type)
+    val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType)
+    val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType)
+
+    checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6))
+    checkEvaluation(Flatten(aim2), Seq(1, 2, 3))
+
+    // Test cases with an empty array (primitive type)
+    val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType)
+    val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType)
+    val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType)
+    val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType)
+    val aie5 = Literal.create(Seq(Seq.empty), intArrayType)
+    val aie6 = Literal.create(Seq.empty, intArrayType)
+
+    checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4))
+    checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4))
+    checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4))
+    checkEvaluation(Flatten(aie4), Seq.empty)
+    checkEvaluation(Flatten(aie5), Seq.empty)
+    checkEvaluation(Flatten(aie6), Seq.empty)
+
+    // Test cases with null elements (primitive type)
+    val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType)
+    val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType)
+    val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType)
+
+    checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null))
+    checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null))
+    checkEvaluation(Flatten(ain3), Seq(null, null, null, null))
+
+    // Test cases with a null array (primitive type)
+    val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType)
+    val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType)
+    val aia3 = Literal.create(Seq(null), intArrayType)
+    val aia4 = Literal.create(null, intArrayType)
+
+    checkEvaluation(Flatten(aia1), null)
+    checkEvaluation(Flatten(aia2), null)
+    checkEvaluation(Flatten(aia3), null)
+    checkEvaluation(Flatten(aia4), null)
+
+    // Non-primitive-type test cases
+    val strArrayType = ArrayType(ArrayType(StringType))
+    val arrArrayType = ArrayType(ArrayType(ArrayType(StringType)))
+
+    // Main test cases (non-primitive type)
+    val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType)
+    val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType)
+    val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType)
+
+    checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f"))
+    checkEvaluation(Flatten(asm2), Seq("a", "b"))
+    checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e")))
+
+    // Test cases with an empty array (non-primitive type)
+    val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType)
+    val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType)
+    val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType)
+    val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType)
+    val ase5 = Literal.create(Seq(Seq.empty), strArrayType)
+    val ase6 = Literal.create(Seq.empty, strArrayType)
+
+    checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d"))
+    checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d"))
+    checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d"))
+    checkEvaluation(Flatten(ase4), Seq.empty)
+    checkEvaluation(Flatten(ase5), Seq.empty)
+    checkEvaluation(Flatten(ase6), Seq.empty)
+
+    // Test cases with null elements (non-primitive type)
+    val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType)
+    val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType)
+    val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType)
+
+    checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null))
+    checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null))
+    checkEvaluation(Flatten(asn3), Seq(null, null, null, null))
+
+    // Test cases with a null array (non-primitive type)
+    val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType)
+    val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType)
+    val asa3 = Literal.create(Seq(null), strArrayType)
+    val asa4 = Literal.create(null, strArrayType)
+
+    checkEvaluation(Flatten(asa1), null)
+    checkEvaluation(Flatten(asa2), null)
+    checkEvaluation(Flatten(asa3), null)
+    checkEvaluation(Flatten(asa4), null)
+  }
 }
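
The cases above cover empty arrays, null elements, and null sub-arrays for both primitive and non-primitive element types; `checkEvaluation` (from `ExpressionEvalHelper`) exercises both the interpreted and the code-generated evaluation paths. A hypothetical additional case in the same style, deeper nesting over a primitive innermost type (not part of this patch):
```
val deep = Literal.create(
  Seq(Seq(Seq(1), Seq(2)), Seq(Seq(3))),
  ArrayType(ArrayType(ArrayType(IntegerType))))
checkEvaluation(Flatten(deep), Seq(Seq(1), Seq(2), Seq(3)))
```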

http://git-wip-us.apache.org/repos/asf/spark/blob/5fea17b3/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index bea8c0e..d2f0573 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3341,6 +3341,14 @@ object functions {
   def reverse(e: Column): Column = withExpr { Reverse(e.expr) }
 
   /**
+   * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than
+   * two levels, only one level of nesting is removed.
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def flatten(e: Column): Column = withExpr { Flatten(e.expr) }
+
+  /**
    * Returns an unordered array containing the keys of the map.
    * @group collection_funcs
    * @since 2.3.0

http://git-wip-us.apache.org/repos/asf/spark/blob/5fea17b3/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 25e5cd6..03605c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -691,6 +691,85 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("flatten function") {
+    val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on
+    val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+
+    // Test cases with a primitive type
+    val intDF = Seq(
+      (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))),
+      (Seq(Seq(1, 2))),
+      (Seq(Seq(1), Seq.empty)),
+      (Seq(Seq.empty, Seq(1))),
+      (Seq(Seq.empty, Seq.empty)),
+      (Seq(Seq(1), null)),
+      (Seq(null, Seq(1))),
+      (Seq(null, null))
+    ).toDF("i")
+
+    val intDFResult = Seq(
+      Row(Seq(1, 2, 3, 4, 5, 6)),
+      Row(Seq(1, 2)),
+      Row(Seq(1)),
+      Row(Seq(1)),
+      Row(Seq.empty),
+      Row(null),
+      Row(null),
+      Row(null))
+
+    checkAnswer(intDF.select(flatten($"i")), intDFResult)
+    checkAnswer(intDF.filter(dummyFilter($"i")).select(flatten($"i")), intDFResult)
+    checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
+    checkAnswer(
+      oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
+      Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+
+    // Test cases with non-primitive types
+    val strDF = Seq(
+      (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))),
+      (Seq(Seq("a", "b"))),
+      (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))),
+      (Seq(Seq("a"), Seq.empty)),
+      (Seq(Seq.empty, Seq("a"))),
+      (Seq(Seq.empty, Seq.empty)),
+      (Seq(Seq("a"), null)),
+      (Seq(null, Seq("a"))),
+      (Seq(null, null))
+    ).toDF("s")
+
+    val strDFResult = Seq(
+      Row(Seq("a", "b", "c", "d", "e", "f")),
+      Row(Seq("a", "b")),
+      Row(Seq("a", null, null, "b", null, null)),
+      Row(Seq("a")),
+      Row(Seq("a")),
+      Row(Seq.empty),
+      Row(null),
+      Row(null),
+      Row(null))
+
+    checkAnswer(strDF.select(flatten($"s")), strDFResult)
+    checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult)
+    checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
+    checkAnswer(
+      oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+      Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+
+    // Error test cases
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"arr"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"i"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"s"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.selectExpr("flatten(null)")
+    }
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
