Repository: spark
Updated Branches:
  refs/heads/master 2ea17afb6 -> f2b3525c1


[SPARK-22771][SQL] Concatenate binary inputs into a binary output

## What changes were proposed in this pull request?
This PR modifies `concat` to concatenate binary inputs into a single binary output.
`concat` in the current master always outputs data as a string. However, in some 
databases (e.g., PostgreSQL), if all inputs are binary, `concat` also outputs 
binary.

## How was this patch tested?
Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`.

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #19977 from maropu/SPARK-22771.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f2b3525c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f2b3525c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f2b3525c

Branch: refs/heads/master
Commit: f2b3525c17d660cf6f082bbafea8632615b4f58e
Parents: 2ea17af
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Sat Dec 30 14:09:56 2017 +0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Sat Dec 30 14:09:56 2017 +0800

----------------------------------------------------------------------
 R/pkg/R/functions.R                             |   3 +-
 .../apache/spark/unsafe/types/ByteArray.java    |  25 ++
 docs/sql-programming-guide.md                   |   2 +
 python/pyspark/sql/functions.py                 |   3 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   2 +-
 .../sql/catalyst/analysis/TypeCoercion.scala    |  26 +-
 .../expressions/stringExpressions.scala         |  52 +++-
 .../sql/catalyst/optimizer/expressions.scala    |  15 +-
 .../org/apache/spark/sql/internal/SQLConf.scala |   8 +
 .../catalyst/analysis/TypeCoercionSuite.scala   |  54 +++++
 .../optimizer/CombineConcatsSuite.scala         |  14 +-
 .../scala/org/apache/spark/sql/functions.scala  |   3 +-
 .../sql-tests/inputs/string-functions.sql       |  23 ++
 .../inputs/typeCoercion/native/concat.sql       |  93 ++++++++
 .../sql-tests/results/string-functions.sql.out  |  45 +++-
 .../results/typeCoercion/native/concat.sql.out  | 239 +++++++++++++++++++
 16 files changed, 587 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/R/pkg/R/functions.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index fff230d..55365a4 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2133,7 +2133,8 @@ setMethod("countDistinct",
           })
 
 #' @details
-#' \code{concat}: Concatenates multiple input string columns together into a single string column.
+#' \code{concat}: Concatenates multiple input columns together into a single column.
+#' If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
 #'
 #' @rdname column_string_functions
 #' @aliases concat concat,Column-method

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
index 7ced13d..c03caf0 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/ByteArray.java
@@ -74,4 +74,29 @@ public final class ByteArray {
     }
     return Arrays.copyOfRange(bytes, start, end);
   }
+
+  public static byte[] concat(byte[]... inputs) {
+    // Compute the total length of the result
+    int totalLength = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      if (inputs[i] != null) {
+        totalLength += inputs[i].length;
+      } else {
+        return null;
+      }
+    }
+
+    // Allocate a new byte array, and copy the inputs one by one into it
+    final byte[] result = new byte[totalLength];
+    int offset = 0;
+    for (int i = 0; i < inputs.length; i++) {
+      int len = inputs[i].length;
+      Platform.copyMemory(
+        inputs[i], Platform.BYTE_ARRAY_OFFSET,
+        result, Platform.BYTE_ARRAY_OFFSET + offset,
+        len);
+      offset += len;
+    }
+    return result;
+  }
 }
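
The two-pass logic of `ByteArray.concat` above (sum the lengths and return null as soon as any input is null, then copy every input into one preallocated buffer) can be sketched outside the JVM. The Python below is only an illustration under assumptions: `concat_bytes` is a hypothetical name, and slice assignment into a `bytearray` stands in for `Platform.copyMemory`.

```python
from typing import Optional


def concat_bytes(*inputs: Optional[bytes]) -> Optional[bytes]:
    """Concatenate byte strings; return None if any input is None."""
    # First pass: compute the total length, propagating null (None).
    total_length = 0
    for chunk in inputs:
        if chunk is None:
            return None
        total_length += len(chunk)

    # Second pass: allocate once, then copy each input at its offset.
    result = bytearray(total_length)
    offset = 0
    for chunk in inputs:
        result[offset:offset + len(chunk)] = chunk
        offset += len(chunk)
    return bytes(result)
```

Checking for null in the first pass means the second pass never has to guard against it, and the result buffer is allocated exactly once.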

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/docs/sql-programming-guide.md
----------------------------------------------------------------------
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index f02f462..4b5f56c 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -1780,6 +1780,8 @@ options.
  
  - Since Spark 2.3, when either broadcast hash join or broadcast nested loop join is applicable, we prefer to broadcasting the table that is explicitly specified in a broadcast hint. For details, see the section [Broadcast Hint](#broadcast-hint-for-sql-queries) and [SPARK-22489](https://issues.apache.org/jira/browse/SPARK-22489).
 
+ - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`.
+
 ## Upgrading From Spark SQL 2.1 to 2.2
 
   - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access.

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 66ee033..a4ed562 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1374,7 +1374,8 @@ del _name, _doc
 @ignore_unicode_prefix
 def concat(*cols):
     """
-    Concatenates multiple input string columns together into a single string column.
+    Concatenates multiple input columns together into a single column.
+    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
 
     >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
     >>> df.select(concat(df.s, df.d).alias('s')).collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1f7191c..6d294d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -150,7 +150,7 @@ class Analyzer(
       TimeWindowing ::
       ResolveInlineTables(conf) ::
       ResolveTimeZone(conf) ::
-      TypeCoercion.typeCoercionRules ++
+      TypeCoercion.typeCoercionRules(conf) ++
       extendedResolutionRules : _*),
     Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*),
     Batch("View", Once,

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 1c4be54..e943636 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 
@@ -45,13 +46,14 @@ import org.apache.spark.sql.types._
  */
 object TypeCoercion {
 
-  val typeCoercionRules =
+  def typeCoercionRules(conf: SQLConf): List[Rule[LogicalPlan]] =
     InConversion ::
       WidenSetOperationTypes ::
       PromoteStrings ::
       DecimalPrecision ::
       BooleanEquality ::
       FunctionArgumentConversion ::
+      ConcatCoercion(conf) ::
       CaseWhenCoercion ::
       IfCoercion ::
       StackCoercion ::
@@ -661,6 +663,28 @@ object TypeCoercion {
   }
 
   /**
+   * Coerces the types of [[Concat]] children to expected ones.
+   *
+   * If `spark.sql.function.concatBinaryAsString` is false and all children types are binary,
+   * the expected types are binary. Otherwise, the expected ones are strings.
+   */
+  case class ConcatCoercion(conf: SQLConf) extends TypeCoercionRule {
+
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p =>
+      p transformExpressionsUp {
+        // Skip nodes if unresolved or empty children
+        case c @ Concat(children) if !c.childrenResolved || children.isEmpty => c
+        case c @ Concat(children) if conf.concatBinaryAsString ||
+            !children.map(_.dataType).forall(_ == BinaryType) =>
+          val newChildren = c.children.map { e =>
+            ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e)
+          }
+          c.copy(children = newChildren)
+      }
+    }
+  }
+
+  /**
    * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType
    * to TimeAdd/TimeSub
    */
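
The decision made by `ConcatCoercion` above can be modelled on child types alone. The sketch below is a simplification under stated assumptions: types are plain strings, `coerce_concat_children` is a hypothetical name, and every non-string type is assumed to be implicitly castable to string (the real rule leaves a child unchanged when `implicitCast` finds no cast).

```python
from typing import List

STRING = "string"
BINARY = "binary"


def coerce_concat_children(child_types: List[str],
                           concat_binary_as_string: bool) -> List[str]:
    """Return the child types a Concat node would have after coercion."""
    # An empty Concat is skipped, as in the rule's first case.
    if not child_types:
        return child_types
    all_binary = all(t == BINARY for t in child_types)
    if concat_binary_as_string or not all_binary:
        # Every child is cast to string (a no-op for string children).
        return [STRING for _ in child_types]
    # All children are binary and the flag is off: leave them untouched.
    return child_types
```

Under this model the output stays binary only when the flag is `false` and no child is anything other than binary; a single string or numeric child turns the whole expression into a string concatenation.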

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index c02c41d..b0da55a 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -24,11 +24,10 @@ import java.util.regex.Pattern
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -38,7 +37,8 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
 
 /**
- * An expression that concatenates multiple input strings into a single string.
+ * An expression that concatenates multiple inputs into a single output.
+ * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
  * If any input is null, concat returns null.
  */
 @ExpressionDescription(
@@ -48,17 +48,37 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
       > SELECT _FUNC_('Spark', 'SQL');
        SparkSQL
   """)
-case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes {
+case class Concat(children: Seq[Expression]) extends Expression {
 
-  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType)
-  override def dataType: DataType = StringType
+  private lazy val isBinaryMode: Boolean = dataType == BinaryType
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      val childTypes = children.map(_.dataType)
+      if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
+        TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have StringType or BinaryType, but it's " +
+            childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+    }
+  }
+
+  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
 
   override def nullable: Boolean = children.exists(_.nullable)
   override def foldable: Boolean = children.forall(_.foldable)
 
   override def eval(input: InternalRow): Any = {
-    val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-    UTF8String.concat(inputs : _*)
+    if (isBinaryMode) {
+      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+      ByteArray.concat(inputs: _*)
+    } else {
+      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+      UTF8String.concat(inputs : _*)
+    }
   }
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -73,17 +93,27 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas
         }
       """
     }
+
+    val (concatenator, initCode) = if (isBinaryMode) {
+      (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+    } else {
+      ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+    }
     val codes = ctx.splitExpressionsWithCurrentInputs(
       expressions = inputs,
       funcName = "valueConcat",
-      extraArguments = ("UTF8String[]", args) :: Nil)
+      extraArguments = (s"${ctx.javaType(dataType)}[]", args) :: Nil)
     ev.copy(s"""
-      UTF8String[] $args = new UTF8String[${evals.length}];
+      $initCode
       $codes
-      UTF8String ${ev.value} = UTF8String.concat($args);
+      ${ctx.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
       boolean ${ev.isNull} = ${ev.value} == null;
     """)
   }
+
+  override def toString: String = s"concat(${children.mkString(", ")})"
+
+  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 85295af..7d830bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -21,6 +21,7 @@ import scala.collection.immutable.HashSet
 import scala.collection.mutable.{ArrayBuffer, Stack}
 
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -645,6 +646,12 @@ object CombineConcats extends Rule[LogicalPlan] {
       stack.pop() match {
         case Concat(children) =>
           stack.pushAll(children.reverse)
+        // If `spark.sql.function.concatBinaryAsString` is false, nested `Concat` exprs possibly
+        // have `Concat`s with binary output. Since `TypeCoercion` casts them into strings,
+        // we need to handle the case to combine all nested `Concat`s.
+        case c @ Cast(Concat(children), StringType, _) =>
+          val newChildren = children.map { e => c.copy(child = e) }
+          stack.pushAll(newChildren.reverse)
         case child =>
           flattened += child
       }
@@ -652,8 +659,14 @@ object CombineConcats extends Rule[LogicalPlan] {
     Concat(flattened)
   }
 
+  private def hasNestedConcats(concat: Concat): Boolean = concat.children.exists {
+    case c: Concat => true
+    case c @ Cast(Concat(children), StringType, _) => true
+    case _ => false
+  }
+
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
-    case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
+    case concat: Concat if hasNestedConcats(concat) =>
       flattenConcats(concat)
   }
 }
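
The extended flattening in `CombineConcats` above can be illustrated with a toy expression tree. This is a sketch under assumptions, not the optimizer's code: `Leaf`, `Cast`, and `Concat` are hypothetical stand-ins for Catalyst expressions, the cast target is a plain string, and the time zone field of the real `Cast` is omitted.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class Leaf:
    name: str


@dataclass(frozen=True)
class Cast:
    child: "Expr"
    to: str


@dataclass(frozen=True)
class Concat:
    children: Tuple["Expr", ...]


Expr = Union[Leaf, Cast, Concat]


def flatten_concats(concat: Concat) -> Concat:
    """Flatten nested Concat nodes, including Concat wrapped in a string Cast."""
    stack: List[Expr] = [concat]
    flattened: List[Expr] = []
    while stack:
        node = stack.pop()
        if isinstance(node, Concat):
            # Push children in reverse so the first child is popped first.
            stack.extend(reversed(node.children))
        elif (isinstance(node, Cast) and node.to == "string"
              and isinstance(node.child, Concat)):
            # Push the cast down onto each child of the inner Concat.
            stack.extend(Cast(c, node.to) for c in reversed(node.child.children))
        else:
            flattened.append(node)
    return Concat(tuple(flattened))
```

Pushing the cast onto each inner child keeps the flattened expression string-typed while removing the intermediate binary `Concat`, which matches the shape of the physical plans in the `EXPLAIN` tests of this commit.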

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f16972e..4f77c54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1044,6 +1044,12 @@ object SQLConf {
         "When this conf is not set, the value from 
`spark.redaction.string.regex` is used.")
       .fallbackConf(org.apache.spark.internal.config.STRING_REDACTION_PATTERN)
 
+  val CONCAT_BINARY_AS_STRING = buildConf("spark.sql.function.concatBinaryAsString")
+    .doc("When this option is set to false and all inputs are binary, `functions.concat` returns " +
+      "an output as binary. Otherwise, it returns as a string. ")
+    .booleanConf
+    .createWithDefault(false)
+
   val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
     buildConf("spark.sql.streaming.continuous.executorQueueSize")
     .internal()
@@ -1378,6 +1384,8 @@ class SQLConf extends Serializable with Logging {
   def continuousStreamingExecutorPollIntervalMs: Long =
     getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS)
 
+  def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 5dcd653..3661530 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -869,6 +869,60 @@ class TypeCoercionSuite extends AnalysisTest {
         Literal.create(null, IntegerType), Literal.create(null, StringType))))
   }
 
+  test("type coercion for Concat") {
+    val rule = TypeCoercion.ConcatCoercion(conf)
+
+    ruleTest(rule,
+      Concat(Seq(Literal("ab"), Literal("cde"))),
+      Concat(Seq(Literal("ab"), Literal("cde"))))
+    ruleTest(rule,
+      Concat(Seq(Literal(null), Literal("abc"))),
+      Concat(Seq(Cast(Literal(null), StringType), Literal("abc"))))
+    ruleTest(rule,
+      Concat(Seq(Literal(1), Literal("234"))),
+      Concat(Seq(Cast(Literal(1), StringType), Literal("234"))))
+    ruleTest(rule,
+      Concat(Seq(Literal("1"), Literal("234".getBytes()))),
+      Concat(Seq(Literal("1"), Cast(Literal("234".getBytes()), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(1L), Literal(2.toByte), Literal(0.1))),
+      Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType),
+        Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(true), Literal(0.1f), Literal(3.toShort))),
+      Concat(Seq(Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType),
+        Cast(Literal(3.toShort), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(1L), Literal(0.1))),
+      Concat(Seq(Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(Decimal(10)))),
+      Concat(Seq(Cast(Literal(Decimal(10)), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(BigDecimal.valueOf(10)))),
+      Concat(Seq(Cast(Literal(BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(java.math.BigDecimal.valueOf(10)))),
+      Concat(Seq(Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType))))
+    ruleTest(rule,
+      Concat(Seq(Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))),
+      Concat(Seq(Cast(Literal(new java.sql.Date(0)), StringType),
+        Cast(Literal(new Timestamp(0)), StringType))))
+
+    withSQLConf("spark.sql.function.concatBinaryAsString" -> "true") {
+      ruleTest(rule,
+        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
+        Concat(Seq(Cast(Literal("123".getBytes), StringType),
+          Cast(Literal("456".getBytes), StringType))))
+    }
+
+    withSQLConf("spark.sql.function.concatBinaryAsString" -> "false") {
+      ruleTest(rule,
+        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))),
+        Concat(Seq(Literal("123".getBytes), Literal("456".getBytes))))
+    }
+  }
+
   test("BooleanEquality type cast") {
     val be = TypeCoercion.BooleanEquality
     // Use something more than a literal to avoid triggering the simplification rules.

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
index 412e199..441c153 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineConcatsSuite.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.StringType
 
 
 class CombineConcatsSuite extends PlanTest {
@@ -37,8 +36,10 @@ class CombineConcatsSuite extends PlanTest {
     comparePlans(actual, correctAnswer)
   }
 
+  def str(s: String): Literal = Literal(s)
+  def binary(s: String): Literal = Literal(s.getBytes)
+
   test("combine nested Concat exprs") {
-    def str(s: String): Literal = Literal(s, StringType)
     assertEquivalent(
       Concat(
         Concat(str("a") :: str("b") :: Nil) ::
@@ -72,4 +73,13 @@ class CombineConcatsSuite extends PlanTest {
         Nil),
       Concat(str("a") :: str("b") :: str("c") :: str("d") :: Nil))
   }
+
+  test("combine string and binary exprs") {
+    assertEquivalent(
+      Concat(
+        Concat(str("a") :: str("b") :: Nil) ::
+        Concat(binary("c") :: binary("d") :: Nil) ::
+        Nil),
+      Concat(str("a") :: str("b") :: binary("c") :: binary("d") :: Nil))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 052a3f5..530a525 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2171,7 +2171,8 @@ object functions {
   def base64(e: Column): Column = withExpr { Base64(e.expr) }
 
   /**
-   * Concatenates multiple input string columns together into a single string column.
+   * Concatenates multiple input columns together into a single column.
+   * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
    *
    * @group string_funcs
    * @since 1.5.0

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index 40d0c06..4113734 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -24,3 +24,26 @@ select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null);
 select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a');
 select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null);
 select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');
+
+-- turn off concatBinaryAsString
+set spark.sql.function.concatBinaryAsString=false;
+
+-- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false
+EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+EXPLAIN SELECT (col1 || (col3 || col4)) col
+FROM (
+  SELECT
+    string(id) col1,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
new file mode 100644
index 0000000..0beebec
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql
@@ -0,0 +1,93 @@
+-- Concatenate mixed inputs (output type is string)
+SELECT (col1 || col2 || col3) col
+FROM (
+  SELECT
+    id col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3
+  FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4) || col5) col
+FROM (
+  SELECT
+    'prefix_' col1,
+    id col2,
+    string(id + 1) col3,
+    encode(string(id + 2), 'utf-8') col4,
+    CAST(id AS DOUBLE) col5
+  FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+-- turn on concatBinaryAsString
+set spark.sql.function.concatBinaryAsString=true;
+
+SELECT (col1 || col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+);
+
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+-- turn off concatBinaryAsString
+set spark.sql.function.concatBinaryAsString=false;
+
+-- Concatenate binary inputs (output type is binary)
+SELECT (col1 || col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+);
+
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);
+
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+);

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index 2d9b3d7..d5f8705 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 12
+-- Number of queries: 15
 
 
 -- !query 0
@@ -118,3 +118,46 @@ select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a')
 struct<right(NULL, -2):string,right('abcd', -2):string,right('abcd', 0):string,right('abcd', 'a'):string>
 -- !query 11 output
 NULL                   NULL
+
+
+-- !query 12
+set spark.sql.function.concatBinaryAsString=false
+-- !query 12 schema
+struct<key:string,value:string>
+-- !query 12 output
+spark.sql.function.concatBinaryAsString        false
+
+
+-- !query 13
+EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 13 schema
+struct<plan:string>
+-- !query 13 output
+== Physical Plan ==
+*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
++- *Range (0, 10, step=1, splits=2)
+
+
+-- !query 14
+EXPLAIN SELECT (col1 || (col3 || col4)) col
+FROM (
+  SELECT
+    string(id) col1,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 14 schema
+struct<plan:string>
+-- !query 14 output
+== Physical Plan ==
+*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
++- *Range (0, 10, step=1, splits=2)

http://git-wip-us.apache.org/repos/asf/spark/blob/f2b3525c/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
new file mode 100644
index 0000000..09729fd
--- /dev/null
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out
@@ -0,0 +1,239 @@
+-- Automatically generated by SQLQueryTestSuite
+-- Number of queries: 11
+
+
+-- !query 0
+SELECT (col1 || col2 || col3) col
+FROM (
+  SELECT
+    id col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3
+  FROM range(10)
+)
+-- !query 0 schema
+struct<col:string>
+-- !query 0 output
+012
+123
+234
+345
+456
+567
+678
+789
+8910
+91011
+
+
+-- !query 1
+SELECT ((col1 || col2) || (col3 || col4) || col5) col
+FROM (
+  SELECT
+    'prefix_' col1,
+    id col2,
+    string(id + 1) col3,
+    encode(string(id + 2), 'utf-8') col4,
+    CAST(id AS DOUBLE) col5
+  FROM range(10)
+)
+-- !query 1 schema
+struct<col:string>
+-- !query 1 output
+prefix_0120.0
+prefix_1231.0
+prefix_2342.0
+prefix_3453.0
+prefix_4564.0
+prefix_5675.0
+prefix_6786.0
+prefix_7897.0
+prefix_89108.0
+prefix_910119.0
+
+
+-- !query 2
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    string(id) col1,
+    string(id + 1) col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 2 schema
+struct<col:string>
+-- !query 2 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 3
+set spark.sql.function.concatBinaryAsString=true
+-- !query 3 schema
+struct<key:string,value:string>
+-- !query 3 output
+spark.sql.function.concatBinaryAsString        true
+
+
+-- !query 4
+SELECT (col1 || col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+)
+-- !query 4 schema
+struct<col:string>
+-- !query 4 output
+01
+12
+23
+34
+45
+56
+67
+78
+89
+910
+
+
+-- !query 5
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 5 schema
+struct<col:string>
+-- !query 5 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 6
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 6 schema
+struct<col:string>
+-- !query 6 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 7
+set spark.sql.function.concatBinaryAsString=false
+-- !query 7 schema
+struct<key:string,value:string>
+-- !query 7 output
+spark.sql.function.concatBinaryAsString        false
+
+
+-- !query 8
+SELECT (col1 || col2) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2
+  FROM range(10)
+)
+-- !query 8 schema
+struct<col:binary>
+-- !query 8 output
+01
+12
+23
+34
+45
+56
+67
+78
+89
+910
+
+
+-- !query 9
+SELECT (col1 || col2 || col3 || col4) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 9 schema
+struct<col:binary>
+-- !query 9 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
+
+
+-- !query 10
+SELECT ((col1 || col2) || (col3 || col4)) col
+FROM (
+  SELECT
+    encode(string(id), 'utf-8') col1,
+    encode(string(id + 1), 'utf-8') col2,
+    encode(string(id + 2), 'utf-8') col3,
+    encode(string(id + 3), 'utf-8') col4
+  FROM range(10)
+)
+-- !query 10 schema
+struct<col:binary>
+-- !query 10 output
+0123
+1234
+2345
+3456
+4567
+5678
+6789
+78910
+891011
+9101112
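
The expected outputs above suggest a simple rule for the result type of `concat` / `||`: the result is binary only when every input is binary and `spark.sql.function.concatBinaryAsString` is false; otherwise the inputs are cast to string and the result is a string. The following is a minimal pure-Python sketch of that rule, not Spark's implementation. The function name `concat_model` is hypothetical, only `int`, `str`, and `bytes` inputs are modeled, and the handling of an empty input list is an assumption of the sketch.

```python
from typing import Sequence, Union

Value = Union[int, str, bytes]


def concat_model(values: Sequence[Value], concat_binary_as_string: bool) -> Union[str, bytes]:
    """Hypothetical model of the concat result-type rule seen in the test outputs.

    Not Spark code: it only mirrors the schema/output pairs in the golden files.
    """
    # Binary mode: every input is binary and the legacy flag is off.
    # Treating an empty input list as string mode is an assumption of this sketch.
    all_binary = len(values) > 0 and all(isinstance(v, bytes) for v in values)
    if all_binary and not concat_binary_as_string:
        return b"".join(values)  # result type: binary

    # String mode: cast each input to a string, then concatenate.
    parts = []
    for v in values:
        if isinstance(v, bytes):
            parts.append(v.decode("utf-8"))  # binary cast to string
        else:
            parts.append(str(v))  # int or str
    return "".join(parts)  # result type: string


# Mirrors query 8 (flag false, all binary) and query 4 (flag true, all binary) for id = 0,
# and query 0 (mixed bigint, string, binary inputs) for id = 0.
print(concat_model([b"0", b"1"], concat_binary_as_string=False))
print(concat_model([b"0", b"1"], concat_binary_as_string=True))
print(concat_model([0, "1", b"2"], concat_binary_as_string=False))
```

In this sketch a single non-binary input is enough to fall back to string mode, which matches the `struct<col:string>` schemas of queries 0 to 2.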


