Repository: spark
Updated Branches:
  refs/heads/master 517bdf36a -> ce7ddabbc


[SPARK-6368][SQL] Build a specialized serializer for Exchange operator.

JIRA: https://issues.apache.org/jira/browse/SPARK-6368

Author: Yin Huai <yh...@databricks.com>

Closes #5497 from yhuai/serializer2 and squashes the following commits:

da562c5 [Yin Huai] Merge remote-tracking branch 'upstream/master' into 
serializer2
50e0c3d [Yin Huai] When no field is emitted to shuffle, use SparkSqlSerializer 
for now.
9f1ed92 [Yin Huai] Merge remote-tracking branch 'upstream/master' into 
serializer2
6d07678 [Yin Huai] Address comments.
4273b8c [Yin Huai] Enabled SparkSqlSerializer2.
09e587a [Yin Huai] Remove TODO.
791b96a [Yin Huai] Use UTF8String.
60a1487 [Yin Huai] Merge remote-tracking branch 'upstream/master' into 
serializer2
3e09655 [Yin Huai] Use getAs for Date column.
43b9fb4 [Yin Huai] Test.
8297732 [Yin Huai] Fix test.
c9373c8 [Yin Huai] Support DecimalType.
2379eeb [Yin Huai] ASF header.
39704ab [Yin Huai] Specialized serializer for Exchange.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ce7ddabb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ce7ddabb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ce7ddabb

Branch: refs/heads/master
Commit: ce7ddabbcd330b19f6d0c17082304dfa6e1621b2
Parents: 517bdf3
Author: Yin Huai <yh...@databricks.com>
Authored: Mon Apr 20 18:42:50 2015 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Apr 20 18:42:50 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLConf.scala    |   4 +
 .../apache/spark/sql/execution/Exchange.scala   |  59 ++-
 .../sql/execution/SparkSqlSerializer2.scala     | 421 +++++++++++++++++++
 .../execution/SparkSqlSerializer2Suite.scala    | 195 +++++++++
 4 files changed, 673 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ce7ddabb/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 5c65f04..4fc5de7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -64,6 +64,8 @@ private[spark] object SQLConf {
   // Set to false when debugging requires the ability to look at invalid query 
plans.
   val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
 
+  val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -147,6 +149,8 @@ private[sql] class SQLConf extends Serializable {
    */
   private[spark] def codegenEnabled: Boolean = getConf(CODEGEN_ENABLED, 
"false").toBoolean
 
+  private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, 
"true").toBoolean
+
   /**
    * Upper bound on the sizes (in bytes) of the tables qualified for the auto 
conversion to
    * a broadcast value during the physical executions of join operations.  
Setting this to -1

http://git-wip-us.apache.org/repos/asf/spark/blob/ce7ddabb/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 69a620e..5b2e469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -19,13 +19,15 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, 
SparkConf}
+import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.MutablePair
 
 object Exchange {
@@ -77,9 +79,48 @@ case class Exchange(
     }
   }
 
-  override def execute(): RDD[Row] = attachTree(this , "execute") {
-    lazy val sparkConf = child.sqlContext.sparkContext.getConf
+  @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
+
+  def serializer(
+      keySchema: Array[DataType],
+      valueSchema: Array[DataType],
+      numPartitions: Int): Serializer = {
+    // In ExternalSorter's spillToMergeableFile function, key-value pairs are 
written out
+    // through write(key) and then write(value) instead of write((key, 
value)). Because
+    // SparkSqlSerializer2 assumes that objects passed in are Product2, we 
cannot safely use
+    // it when spillToMergeableFile in ExternalSorter will be used.
+    // So, we will not use SparkSqlSerializer2 when
+    //  - Sort-based shuffle is enabled and the number of reducers 
(numPartitions) is greater
+    //     than the bypassMergeThreshold; or
+    //  - newOrdering is defined.
+    val cannotUseSqlSerializer2 =
+      (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || 
newOrdering.nonEmpty
+
+    // It is true when there is no field that needs to be written out.
+    // For now, we will not use SparkSqlSerializer2 when noField is true.
+    val noField =
+      (keySchema == null || keySchema.length == 0) &&
+      (valueSchema == null || valueSchema.length == 0)
+
+    val useSqlSerializer2 =
+        child.sqlContext.conf.useSqlSerializer2 &&   // SparkSqlSerializer2 is 
enabled.
+        !cannotUseSqlSerializer2 &&                  // Safe to use 
Serializer2.
+        SparkSqlSerializer2.support(keySchema) &&    // The schema of key is 
supported.
+        SparkSqlSerializer2.support(valueSchema) &&  // The schema of value is 
supported.
+        !noField
+
+    val serializer = if (useSqlSerializer2) {
+      logInfo("Using SparkSqlSerializer2.")
+      new SparkSqlSerializer2(keySchema, valueSchema)
+    } else {
+      logInfo("Using SparkSqlSerializer.")
+      new SparkSqlSerializer(sparkConf)
+    }
+
+    serializer
+  }
 
+  override def execute(): RDD[Row] = attachTree(this , "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
         // TODO: Eliminate redundant expressions in grouping key and value.
@@ -111,7 +152,10 @@ case class Exchange(
           } else {
             new ShuffledRDD[Row, Row, Row](rdd, part)
           }
-        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
+        val keySchema = expressions.map(_.dataType).toArray
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, valueSchema, 
numPartitions))
+
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
@@ -134,7 +178,9 @@ case class Exchange(
           } else {
             new ShuffledRDD[Row, Null, Null](rdd, part)
           }
-        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
+        val keySchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+
         shuffled.map(_._1)
 
       case SinglePartition =>
@@ -152,7 +198,8 @@ case class Exchange(
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        shuffled.setSerializer(new SparkSqlSerializer(sparkConf))
+        val valueSchema = child.output.map(_.dataType).toArray
+        shuffled.setSerializer(serializer(null, valueSchema, 1))
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")

http://git-wip-us.apache.org/repos/asf/spark/blob/ce7ddabb/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
new file mode 100644
index 0000000..cec97de
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -0,0 +1,421 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.io._
+import java.math.{BigDecimal, BigInteger}
+import java.nio.ByteBuffer
+import java.sql.Timestamp
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.serializer._
+import org.apache.spark.Logging
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.types._
+
+/**
+ * The serialization stream for [[SparkSqlSerializer2]]. It assumes that the 
object passed in
+ * its `writeObject` are [[Product2]]. The serialization functions for the key 
and value of the
+ * [[Product2]] are constructed based on their schemata.
+ * The benefit of this serialization stream is that compared with 
general-purpose serializers like
+ * Kryo and Java serializer, it can significantly reduce the size of 
serialized data and has a lower
+ * allocation cost, which can benefit the shuffle operation. Right now, its 
main limitations are:
+ *  1. It does not support complex types, i.e. Map, Array, and Struct.
+ *  2. It assumes that the objects passed in are [[Product2]]. So, it cannot 
be used when
+ *     [[org.apache.spark.util.collection.ExternalSorter]]'s merge sort 
operation is used because
+ *     the objects passed in the serializer are not in the type of 
[[Product2]]. Also see
+ *     the comment of the `serializer` method in [[Exchange]] for more 
information on it.
+ */
+private[sql] class Serializer2SerializationStream(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    out: OutputStream)
+  extends SerializationStream with Logging {
+
+  val rowOut = new DataOutputStream(out)
+  val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, 
rowOut)
+  val writeValue = 
SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+
+  def writeObject[T: ClassTag](t: T): SerializationStream = {
+    val kv = t.asInstanceOf[Product2[Row, Row]]
+    writeKey(kv._1)
+    writeValue(kv._2)
+
+    this
+  }
+
+  def flush(): Unit = {
+    rowOut.flush()
+  }
+
+  def close(): Unit = {
+    rowOut.close()
+  }
+}
+
+/**
+ * The corresponding deserialization stream for 
[[Serializer2SerializationStream]].
+ */
+private[sql] class Serializer2DeserializationStream(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    in: InputStream)
+  extends DeserializationStream with Logging  {
+
+  val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
+  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) 
else null
+  val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, 
rowIn, key)
+  val readValue = 
SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+
+  def readObject[T: ClassTag](): T = {
+    readKey()
+    readValue()
+
+    (key, value).asInstanceOf[T]
+  }
+
+  def close(): Unit = {
+    rowIn.close()
+  }
+}
+
+private[sql] class ShuffleSerializerInstance(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType])
+  extends SerializerInstance {
+
+  def serialize[T: ClassTag](t: T): ByteBuffer =
+    throw new UnsupportedOperationException("Not supported.")
+
+  def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+    throw new UnsupportedOperationException("Not supported.")
+
+  def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
+    throw new UnsupportedOperationException("Not supported.")
+
+  def serializeStream(s: OutputStream): SerializationStream = {
+    new Serializer2SerializationStream(keySchema, valueSchema, s)
+  }
+
+  def deserializeStream(s: InputStream): DeserializationStream = {
+    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+  }
+}
+
+/**
+ * SparkSqlSerializer2 is a special serializer that creates serialization 
function and
+ * deserialization function based on the schema of data. It assumes that 
values passed in
+ * are key/value pairs and values returned from it are also key/value pairs.
+ * The schema of keys is represented by `keySchema` and that of values is 
represented by
+ * `valueSchema`.
+ */
+private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], 
valueSchema: Array[DataType])
+  extends Serializer
+  with Logging
+  with Serializable{
+
+  def newInstance(): SerializerInstance = new 
ShuffleSerializerInstance(keySchema, valueSchema)
+}
+
+private[sql] object SparkSqlSerializer2 {
+
+  final val NULL = 0
+  final val NOT_NULL = 1
+
+  /**
+   * Check if rows with the given schema can be serialized with 
SparkSqlSerializer2.
+   */
+  def support(schema: Array[DataType]): Boolean = {
+    if (schema == null) return true
+
+    var i = 0
+    while (i < schema.length) {
+      schema(i) match {
+        case udt: UserDefinedType[_] => return false
+        case array: ArrayType => return false
+        case map: MapType => return false
+        case struct: StructType => return false
+        case _ =>
+      }
+      i += 1
+    }
+
+    return true
+  }
+
+  /**
+   * The util function to create the serialization function based on the given 
schema.
+   */
+  def createSerializationFunction(schema: Array[DataType], out: 
DataOutputStream): Row => Unit = {
+    (row: Row) =>
+      // If the schema is null, the returned function does nothing when it gets 
called.
+      if (schema != null) {
+        var i = 0
+        while (i < schema.length) {
+          schema(i) match {
+            // When we write values to the underlying stream, we first 
write the null byte.
+            // Then, if the value is not null, we write the contents 
out.
+
+            case NullType => // Write nothing.
+
+            case BooleanType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeBoolean(row.getBoolean(i))
+              }
+
+            case ByteType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeByte(row.getByte(i))
+              }
+
+            case ShortType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeShort(row.getShort(i))
+              }
+
+            case IntegerType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeInt(row.getInt(i))
+              }
+
+            case LongType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeLong(row.getLong(i))
+              }
+
+            case FloatType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeFloat(row.getFloat(i))
+              }
+
+            case DoubleType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeDouble(row.getDouble(i))
+              }
+
+            case decimal: DecimalType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                val value = row.apply(i).asInstanceOf[Decimal]
+                val javaBigDecimal = value.toJavaBigDecimal
+                // First, write out the unscaled value.
+                val bytes: Array[Byte] = 
javaBigDecimal.unscaledValue().toByteArray
+                out.writeInt(bytes.length)
+                out.write(bytes)
+                // Then, write out the scale.
+                out.writeInt(javaBigDecimal.scale())
+              }
+
+            case DateType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                out.writeInt(row.getAs[Int](i))
+              }
+
+            case TimestampType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                val timestamp = row.getAs[java.sql.Timestamp](i)
+                val time = timestamp.getTime
+                val nanos = timestamp.getNanos
+                out.writeLong(time - (nanos / 1000000)) // Write the 
milliseconds value.
+                out.writeInt(nanos)                     // Write the 
nanoseconds part.
+              }
+
+            case StringType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                val bytes = row.getAs[UTF8String](i).getBytes
+                out.writeInt(bytes.length)
+                out.write(bytes)
+              }
+
+            case BinaryType =>
+              if (row.isNullAt(i)) {
+                out.writeByte(NULL)
+              } else {
+                out.writeByte(NOT_NULL)
+                val bytes = row.getAs[Array[Byte]](i)
+                out.writeInt(bytes.length)
+                out.write(bytes)
+              }
+          }
+          i += 1
+        }
+      }
+  }
+
+  /**
+   * The util function to create the deserialization function based on the 
given schema.
+   */
+  def createDeserializationFunction(
+      schema: Array[DataType],
+      in: DataInputStream,
+      mutableRow: SpecificMutableRow): () => Unit = {
+    () => {
+      // If the schema is null, the returned function does nothing when it gets 
called.
+      if (schema != null) {
+        var i = 0
+        while (i < schema.length) {
+          schema(i) match {
+            // When we read values from the underlying stream, we first 
read the null byte.
+            // Then, if the value is not null, we update the field of 
the mutable row.
+
+            case NullType => mutableRow.setNullAt(i) // Read nothing.
+
+            case BooleanType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setBoolean(i, in.readBoolean())
+              }
+
+            case ByteType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setByte(i, in.readByte())
+              }
+
+            case ShortType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setShort(i, in.readShort())
+              }
+
+            case IntegerType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setInt(i, in.readInt())
+              }
+
+            case LongType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setLong(i, in.readLong())
+              }
+
+            case FloatType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setFloat(i, in.readFloat())
+              }
+
+            case DoubleType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.setDouble(i, in.readDouble())
+              }
+
+            case decimal: DecimalType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                // First, read in the unscaled value.
+                val length = in.readInt()
+                val bytes = new Array[Byte](length)
+                in.readFully(bytes)
+                val unscaledVal = new BigInteger(bytes)
+                // Then, read the scale.
+                val scale = in.readInt()
+                // Finally, create the Decimal object and set it in the row.
+                mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, 
scale)))
+              }
+
+            case DateType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow.update(i, in.readInt())
+              }
+
+            case TimestampType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                val time = in.readLong() // Read the milliseconds value.
+                val nanos = in.readInt() // Read the nanoseconds part.
+                val timestamp = new Timestamp(time)
+                timestamp.setNanos(nanos)
+                mutableRow.update(i, timestamp)
+              }
+
+            case StringType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                val length = in.readInt()
+                val bytes = new Array[Byte](length)
+                in.readFully(bytes)
+                mutableRow.update(i, UTF8String(bytes))
+              }
+
+            case BinaryType =>
+              if (in.readByte() == NULL) {
+                mutableRow.setNullAt(i)
+              } else {
+                val length = in.readInt()
+                val bytes = new Array[Byte](length)
+                in.readFully(bytes)
+                mutableRow.update(i, bytes)
+              }
+          }
+          i += 1
+        }
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/ce7ddabb/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
new file mode 100644
index 0000000..27f063d
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.sql.{Timestamp, Date}
+
+import org.scalatest.{FunSuite, BeforeAndAfterAll}
+
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.ShuffleDependency
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.{MyDenseVectorUDT, QueryTest}
+
+class SparkSqlSerializer2DataTypeSuite extends FunSuite {
+  // Make sure that we will not use serializer2 for unsupported data types.
+  def checkSupported(dataType: DataType, isSupported: Boolean): Unit = {
+    val testName =
+      s"${if (dataType == null) null else dataType.toString} is " +
+        s"${if (isSupported) "supported" else "unsupported"}"
+
+    test(testName) {
+      assert(SparkSqlSerializer2.support(Array(dataType)) === isSupported)
+    }
+  }
+
+  checkSupported(null, isSupported = true)
+  checkSupported(NullType, isSupported = true)
+  checkSupported(BooleanType, isSupported = true)
+  checkSupported(ByteType, isSupported = true)
+  checkSupported(ShortType, isSupported = true)
+  checkSupported(IntegerType, isSupported = true)
+  checkSupported(LongType, isSupported = true)
+  checkSupported(FloatType, isSupported = true)
+  checkSupported(DoubleType, isSupported = true)
+  checkSupported(DateType, isSupported = true)
+  checkSupported(TimestampType, isSupported = true)
+  checkSupported(StringType, isSupported = true)
+  checkSupported(BinaryType, isSupported = true)
+  checkSupported(DecimalType(10, 5), isSupported = true)
+  checkSupported(DecimalType.Unlimited, isSupported = true)
+
+  // For now, ArrayType, MapType, and StructType are not supported.
+  checkSupported(ArrayType(DoubleType, true), isSupported = false)
+  checkSupported(ArrayType(StringType, false), isSupported = false)
+  checkSupported(MapType(IntegerType, StringType, true), isSupported = false)
+  checkSupported(MapType(IntegerType, ArrayType(DoubleType), false), 
isSupported = false)
+  checkSupported(StructType(StructField("a", IntegerType, true) :: Nil), 
isSupported = false)
+  // UDTs are not supported right now.
+  checkSupported(new MyDenseVectorUDT, isSupported = false)
+}
+
+abstract class SparkSqlSerializer2Suite extends QueryTest with 
BeforeAndAfterAll {
+  var allColumns: String = _
+  val serializerClass: Class[Serializer] =
+    classOf[SparkSqlSerializer2].asInstanceOf[Class[Serializer]]
+  var numShufflePartitions: Int = _
+  var useSerializer2: Boolean = _
+
+  override def beforeAll(): Unit = {
+    numShufflePartitions = conf.numShufflePartitions
+    useSerializer2 = conf.useSqlSerializer2
+
+    sql("set spark.sql.useSerializer2=true")
+
+    val supportedTypes =
+      Seq(StringType, BinaryType, NullType, BooleanType,
+        ByteType, ShortType, IntegerType, LongType,
+        FloatType, DoubleType, DecimalType.Unlimited, DecimalType(6, 5),
+        DateType, TimestampType)
+
+    val fields = supportedTypes.zipWithIndex.map { case (dataType, index) =>
+      StructField(s"col$index", dataType, true)
+    }
+    allColumns = fields.map(_.name).mkString(",")
+    val schema = StructType(fields)
+
+    // Create a RDD with all data types supported by SparkSqlSerializer2.
+    val rdd =
+      sparkContext.parallelize((1 to 1000), 10).map { i =>
+        Row(
+          s"str${i}: test serializer2.",
+          s"binary${i}: test serializer2.".getBytes("UTF-8"),
+          null,
+          i % 2 == 0,
+          i.toByte,
+          i.toShort,
+          i,
+          Long.MaxValue - i.toLong,
+          (i + 0.25).toFloat,
+          (i + 0.75),
+          BigDecimal(Long.MaxValue.toString + ".12345"),
+          new java.math.BigDecimal(s"${i % 9 + 1}" + ".23456"),
+          new Date(i),
+          new Timestamp(i))
+      }
+
+    createDataFrame(rdd, schema).registerTempTable("shuffle")
+
+    super.beforeAll()
+  }
+
+  override def afterAll(): Unit = {
+    dropTempTable("shuffle")
+    sql(s"set spark.sql.shuffle.partitions=$numShufflePartitions")
+    sql(s"set spark.sql.useSerializer2=$useSerializer2")
+    super.afterAll()
+  }
+
+  def checkSerializer[T <: Serializer](
+      executedPlan: SparkPlan,
+      expectedSerializerClass: Class[T]): Unit = {
+    executedPlan.foreach {
+      case exchange: Exchange =>
+        val shuffledRDD = 
exchange.execute().firstParent.asInstanceOf[ShuffledRDD[_, _, _]]
+        val dependency = 
shuffledRDD.getDependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+        val serializerNotSetMessage =
+          s"Expected $expectedSerializerClass as the serializer of Exchange. " 
+
+          s"However, the serializer was not set."
+        val serializer = 
dependency.serializer.getOrElse(fail(serializerNotSetMessage))
+        assert(serializer.getClass === expectedSerializerClass)
+      case _ => // Ignore other nodes.
+    }
+  }
+
+  test("key schema and value schema are not nulls") {
+    val df = sql(s"SELECT DISTINCT ${allColumns} FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    checkAnswer(
+      df,
+      table("shuffle").collect())
+  }
+
+  test("value schema is null") {
+    val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    assert(
+      df.map(r => r.getString(0)).collect().toSeq ===
+      table("shuffle").select("col0").map(r => 
r.getString(0)).collect().sorted.toSeq)
+  }
+
+  test("no map output field") {
+    val df = sql(s"SELECT 1 + 1 FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, 
classOf[SparkSqlSerializer])
+  }
+}
+
+/** Tests SparkSqlSerializer2 with sort based shuffle without sort merge. */
+class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // Sort merge will not be triggered.
+    sql("set spark.sql.shuffle.partitions = 200")
+  }
+
+  test("key schema is null") {
+    val aggregations = allColumns.split(",").map(c => 
s"COUNT($c)").mkString(",")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    checkAnswer(
+      df,
+      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 
1000, 1000))
+  }
+}
+
+/** For now, we will use SparkSqlSerializer for sort based shuffle with sort 
merge. */
+class SparkSqlSerializer2SortMergeShuffleSuite extends 
SparkSqlSerializer2Suite {
+
+  // We are expecting SparkSqlSerializer.
+  override val serializerClass: Class[Serializer] =
+    classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    // To trigger the sort merge.
+    sql("set spark.sql.shuffle.partitions = 201")
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to