Repository: spark
Updated Branches:
  refs/heads/branch-1.5 2cd96329f -> 3ed219f62


[SPARK-9738] [SQL] remove FromUnsafe and add its codegen version to GenerateSafe

In https://github.com/apache/spark/pull/7752 we added `FromUnsafe` to convert 
nested unsafe data like array/map/struct to safe versions. It was a quick 
solution, and we already have `GenerateSafe`, which performs this conversion 
with generated code. So we should remove `FromUnsafe` and implement its codegen 
version in `GenerateSafe`.

Author: Wenchen Fan <cloud0...@outlook.com>

Closes #8029 from cloud-fan/from-unsafe and squashes the following commits:

ed40d8f [Wenchen Fan] add the copy back
a93fd4b [Wenchen Fan] cogengen FromUnsafe

(cherry picked from commit 106c0789d8c83c7081bc9a335df78ba728e95872)
Signed-off-by: Davies Liu <davies....@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3ed219f6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3ed219f6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3ed219f6

Branch: refs/heads/branch-1.5
Commit: 3ed219f62b956b8f49c175a5e51e86873e67fdc5
Parents: 2cd9632
Author: Wenchen Fan <cloud0...@outlook.com>
Authored: Sat Aug 8 08:33:14 2015 -0700
Committer: Davies Liu <davies....@gmail.com>
Committed: Sat Aug 8 08:33:28 2015 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/FromUnsafe.scala   |  70 -----------
 .../sql/catalyst/expressions/Projection.scala   |   8 +-
 .../codegen/GenerateSafeProjection.scala        | 120 ++++++++++++++-----
 .../execution/RowFormatConvertersSuite.scala    |   4 +-
 4 files changed, 95 insertions(+), 107 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3ed219f6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
deleted file mode 100644
index 9b960b1..0000000
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FromUnsafe.scala
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.expressions
-
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
-
-case class FromUnsafe(child: Expression) extends UnaryExpression
-  with ExpectsInputTypes with CodegenFallback {
-
-  override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(ArrayType, StructType, MapType))
-
-  override def dataType: DataType = child.dataType
-
-  private def convert(value: Any, dt: DataType): Any = dt match {
-    case StructType(fields) =>
-      val row = value.asInstanceOf[UnsafeRow]
-      val result = new Array[Any](fields.length)
-      fields.map(_.dataType).zipWithIndex.foreach { case (dt, i) =>
-        if (!row.isNullAt(i)) {
-          result(i) = convert(row.get(i, dt), dt)
-        }
-      }
-      new GenericInternalRow(result)
-
-    case ArrayType(elementType, _) =>
-      val array = value.asInstanceOf[UnsafeArrayData]
-      val length = array.numElements()
-      val result = new Array[Any](length)
-      var i = 0
-      while (i < length) {
-        if (!array.isNullAt(i)) {
-          result(i) = convert(array.get(i, elementType), elementType)
-        }
-        i += 1
-      }
-      new GenericArrayData(result)
-
-    case StringType => value.asInstanceOf[UTF8String].clone()
-
-    case MapType(kt, vt, _) =>
-      val map = value.asInstanceOf[UnsafeMapData]
-      val safeKeyArray = convert(map.keys, 
ArrayType(kt)).asInstanceOf[GenericArrayData]
-      val safeValueArray = convert(map.values, 
ArrayType(vt)).asInstanceOf[GenericArrayData]
-      new ArrayBasedMapData(safeKeyArray, safeValueArray)
-
-    case _ => value
-  }
-
-  override def nullSafeEval(input: Any): Any = {
-    convert(input, dataType)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/3ed219f6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 796bc32..afe52e6 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -152,13 +152,7 @@ object FromUnsafeProjection {
    */
   def apply(fields: Seq[DataType]): Projection = {
     create(fields.zipWithIndex.map(x => {
-      val b = new BoundReference(x._2, x._1, true)
-      // todo: this is quite slow, maybe remove this whole projection after 
remove generic getter of
-      // InternalRow?
-      b.dataType match {
-        case _: StructType | _: ArrayType | _: MapType => FromUnsafe(b)
-        case _ => b
-      }
+      new BoundReference(x._2, x._1, true)
     }))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3ed219f6/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index f06ffc5..ef08ddf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
-import org.apache.spark.sql.types.{StringType, StructType, DataType}
+import org.apache.spark.sql.types._
 
 
 /**
@@ -36,34 +36,94 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
   protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): 
Seq[Expression] =
     in.map(BindReferences.bindReference(_, inputSchema))
 
-  private def genUpdater(
+  private def createCodeForStruct(
       ctx: CodeGenContext,
-      setter: String,
-      dataType: DataType,
-      ordinal: Int,
-      value: String): String = {
-    dataType match {
-      case struct: StructType =>
-        val rowTerm = ctx.freshName("row")
-        val updates = struct.map(_.dataType).zipWithIndex.map { case (dt, i) =>
-          val colTerm = ctx.freshName("col")
-          s"""
-            if ($value.isNullAt($i)) {
-              $rowTerm.setNullAt($i);
-            } else {
-              ${ctx.javaType(dt)} $colTerm = ${ctx.getValue(value, dt, s"$i")};
-              ${genUpdater(ctx, rowTerm, dt, i, colTerm)};
-            }
-           """
-        }.mkString("\n")
-        s"""
-          $genericMutableRowType $rowTerm = new 
$genericMutableRowType(${struct.fields.length});
-          $updates
-          $setter.update($ordinal, $rowTerm.copy());
-        """
-      case _ =>
-        ctx.setColumn(setter, dataType, ordinal, value)
-    }
+      input: String,
+      schema: StructType): GeneratedExpressionCode = {
+    val tmp = ctx.freshName("tmp")
+    val output = ctx.freshName("safeRow")
+    val values = ctx.freshName("values")
+    val rowClass = classOf[GenericInternalRow].getName
+
+    val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) 
=>
+      val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt)
+      s"""
+        if (!$tmp.isNullAt($i)) {
+          ${converter.code}
+          $values[$i] = ${converter.primitive};
+        }
+      """
+    }.mkString("\n")
+
+    val code = s"""
+      final InternalRow $tmp = $input;
+      final Object[] $values = new Object[${schema.length}];
+      $fieldWriters
+      final InternalRow $output = new $rowClass($values);
+    """
+
+    GeneratedExpressionCode(code, "false", output)
+  }
+
+  private def createCodeForArray(
+      ctx: CodeGenContext,
+      input: String,
+      elementType: DataType): GeneratedExpressionCode = {
+    val tmp = ctx.freshName("tmp")
+    val output = ctx.freshName("safeArray")
+    val values = ctx.freshName("values")
+    val numElements = ctx.freshName("numElements")
+    val index = ctx.freshName("index")
+    val arrayClass = classOf[GenericArrayData].getName
+
+    val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, 
index), elementType)
+    val code = s"""
+      final ArrayData $tmp = $input;
+      final int $numElements = $tmp.numElements();
+      final Object[] $values = new Object[$numElements];
+      for (int $index = 0; $index < $numElements; $index++) {
+        if (!$tmp.isNullAt($index)) {
+          ${elementConverter.code}
+          $values[$index] = ${elementConverter.primitive};
+        }
+      }
+      final ArrayData $output = new $arrayClass($values);
+    """
+
+    GeneratedExpressionCode(code, "false", output)
+  }
+
+  private def createCodeForMap(
+      ctx: CodeGenContext,
+      input: String,
+      keyType: DataType,
+      valueType: DataType): GeneratedExpressionCode = {
+    val tmp = ctx.freshName("tmp")
+    val output = ctx.freshName("safeMap")
+    val mapClass = classOf[ArrayBasedMapData].getName
+
+    val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType)
+    val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", 
valueType)
+    val code = s"""
+      final MapData $tmp = $input;
+      ${keyConverter.code}
+      ${valueConverter.code}
+      final MapData $output = new $mapClass(${keyConverter.primitive}, 
${valueConverter.primitive});
+    """
+
+    GeneratedExpressionCode(code, "false", output)
+  }
+
+  private def convertToSafe(
+      ctx: CodeGenContext,
+      input: String,
+      dataType: DataType): GeneratedExpressionCode = dataType match {
+    case s: StructType => createCodeForStruct(ctx, input, s)
+    case ArrayType(elementType, _) => createCodeForArray(ctx, input, 
elementType)
+    case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, 
keyType, valueType)
+    // UTF8String act as a pointer if it's inside UnsafeRow, so copy it to 
make it safe.
+    case StringType => GeneratedExpressionCode("", "false", s"$input.clone()")
+    case _ => GeneratedExpressionCode("", "false", input)
   }
 
   protected def create(expressions: Seq[Expression]): Projection = {
@@ -72,12 +132,14 @@ object GenerateSafeProjection extends 
CodeGenerator[Seq[Expression], Projection]
       case (NoOp, _) => ""
       case (e, i) =>
         val evaluationCode = e.gen(ctx)
+        val converter = convertToSafe(ctx, evaluationCode.primitive, 
e.dataType)
         evaluationCode.code +
           s"""
             if (${evaluationCode.isNull}) {
               mutableRow.setNullAt($i);
             } else {
-              ${genUpdater(ctx, "mutableRow", e.dataType, i, 
evaluationCode.primitive)};
+              ${converter.code}
+              ${ctx.setColumn("mutableRow", e.dataType, i, 
converter.primitive)};
             }
           """
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/3ed219f6/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index 322966f..dd08e90 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -112,7 +112,9 @@ case class DummyPlan(child: SparkPlan) extends UnaryNode {
 
   override protected def doExecute(): RDD[InternalRow] = {
     child.execute().mapPartitions { iter =>
-      // cache all strings to make sure we have deep copied UTF8String inside 
incoming
+      // This `DummyPlan` is in safe mode, so we don't need to do copy even we 
hold some
+      // values gotten from the incoming rows.
+      // we cache all strings here to make sure we have deep copied UTF8String 
inside incoming
       // safe InternalRow.
       val strings = new scala.collection.mutable.ArrayBuffer[UTF8String]
       iter.foreach { row =>


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to