Repository: spark
Updated Branches:
  refs/heads/master ced6ccf0d -> 132a3f470


[SPARK-22500][SQL][FOLLOWUP] Cast for struct can split code even with whole-stage codegen

## What changes were proposed in this pull request?

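As a follow-up to https://github.com/apache/spark/pull/19730, the code for casting structs can now be split into separate functions even when whole-stage codegen is enabled. This is possible because the generated per-field cast code reads its input from `tmpInput` and writes to `tmpResult`, and both are passed to the split functions as arguments.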

This PR also renames several variables and parameters to make the code easier to read.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19891 from cloud-fan/cast.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/132a3f47
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/132a3f47
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/132a3f47

Branch: refs/heads/master
Commit: 132a3f470811bb98f265d0c9ad2c161698e0237b
Parents: ced6ccf
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Dec 5 11:40:13 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Dec 5 11:40:13 2017 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 52 +++++++++-----------
 1 file changed, 24 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/132a3f47/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index f4ecbdb..b8d3661 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -548,8 +548,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast))
   }
 
-  // three function arguments are: child.primitive, result.primitive and result.isNull
-  // it returns the code snippets to be put in null safe evaluation region
+  // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull`
+  // in parameter list, because the returned code will be put in null safe evaluation region.
   private[this] type CastFunction = (String, String, String) => String
 
   private[this] def nullSafeCastFunction(
@@ -584,15 +584,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       throw new SparkException(s"Cannot cast $from to $to.")
   }
 
-  // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's
+  // Since we need to cast input expressions recursively inside ComplexTypes, such as Map's
   // Key and Value, Struct's field, we need to name out all the variable names involved in a cast.
-  private[this] def castCode(ctx: CodegenContext, childPrim: String, childNull: String,
-    resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
+  private[this] def castCode(ctx: CodegenContext, input: String, inputIsNull: String,
+    result: String, resultIsNull: String, resultType: DataType, cast: CastFunction): String = {
     s"""
-      boolean $resultNull = $childNull;
-      ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
-      if (!$childNull) {
-        ${cast(childPrim, resultPrim, resultNull)}
+      boolean $resultIsNull = $inputIsNull;
+      ${ctx.javaType(resultType)} $result = ${ctx.defaultValue(resultType)};
+      if (!$inputIsNull) {
+        ${cast(input, result, resultIsNull)}
       }
     """
   }
@@ -1014,8 +1014,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
     }
     val rowClass = classOf[GenericInternalRow].getName
-    val result = ctx.freshName("result")
-    val tmpRow = ctx.freshName("tmpRow")
+    val tmpResult = ctx.freshName("tmpResult")
+    val tmpInput = ctx.freshName("tmpInput")
 
     val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
       val fromFieldPrim = ctx.freshName("ffp")
@@ -1024,37 +1024,33 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
       val toFieldNull = ctx.freshName("tfn")
       val fromType = ctx.javaType(from.fields(i).dataType)
       s"""
-        boolean $fromFieldNull = $tmpRow.isNullAt($i);
+        boolean $fromFieldNull = $tmpInput.isNullAt($i);
         if ($fromFieldNull) {
-          $result.setNullAt($i);
+          $tmpResult.setNullAt($i);
         } else {
           $fromType $fromFieldPrim =
-            ${ctx.getValue(tmpRow, from.fields(i).dataType, i.toString)};
+            ${ctx.getValue(tmpInput, from.fields(i).dataType, i.toString)};
           ${castCode(ctx, fromFieldPrim,
             fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
           if ($toFieldNull) {
-            $result.setNullAt($i);
+            $tmpResult.setNullAt($i);
           } else {
-            ${ctx.setColumn(result, to.fields(i).dataType, i, toFieldPrim)};
+            ${ctx.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)};
           }
         }
        """
     }
-    val fieldsEvalCodes = if (ctx.currentVars == null) {
-      ctx.splitExpressions(
-        expressions = fieldsEvalCode,
-        funcName = "castStruct",
-        arguments = ("InternalRow", tmpRow) :: (rowClass, result) :: Nil)
-    } else {
-      fieldsEvalCode.mkString("\n")
-    }
+    val fieldsEvalCodes = ctx.splitExpressions(
+      expressions = fieldsEvalCode,
+      funcName = "castStruct",
+      arguments = ("InternalRow", tmpInput) :: (rowClass, tmpResult) :: Nil)
 
-    (c, evPrim, evNull) =>
+    (input, result, resultIsNull) =>
       s"""
-        final $rowClass $result = new $rowClass(${fieldsCasts.length});
-        final InternalRow $tmpRow = $c;
+        final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length});
+        final InternalRow $tmpInput = $input;
         $fieldsEvalCodes
-        $evPrim = $result;
+        $result = $tmpResult;
       """
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to