Repository: spark
Updated Branches:
  refs/heads/master 8086acc2f -> f9f055afa


http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
new file mode 100644
index 0000000..d2c6420
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeBlockSuite.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.types.{BooleanType, IntegerType}
+
+class CodeBlockSuite extends SparkFunSuite {
+
+  test("Block interpolates string and ExprValue inputs") {
+    val isNull = JavaCode.isNullVariable("expr1_isNull")
+    val stringLiteral = "false"
+    val code = code"boolean $isNull = $stringLiteral;"
+    assert(code.toString == "boolean expr1_isNull = false;")
+  }
+
+  test("Literals are folded into string code parts instead of block inputs") {
+    val value = JavaCode.variable("expr1", IntegerType)
+    val intLiteral = 1
+    val code = code"int $value = $intLiteral;"
+    assert(code.asInstanceOf[CodeBlock].blockInputs === Seq(value))
+  }
+
+  test("Block.stripMargin") {
+    val isNull = JavaCode.isNullVariable("expr1_isNull")
+    val value = JavaCode.variable("expr1", IntegerType)
+    val code1 =
+      code"""
+           |boolean $isNull = false;
+           |int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin
+    val expected =
+      s"""
+        |boolean expr1_isNull = false;
+        |int expr1 = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin.trim
+    assert(code1.toString == expected)
+
+    val code2 =
+      code"""
+           >boolean $isNull = false;
+           >int $value = ${JavaCode.defaultLiteral(IntegerType)};""".stripMargin('>')
+    assert(code2.toString == expected)
+  }
+
+  test("Block can capture input expr values") {
+    val isNull = JavaCode.isNullVariable("expr1_isNull")
+    val value = JavaCode.variable("expr1", IntegerType)
+    val code =
+      code"""
+           |boolean $isNull = false;
+           |int $value = -1;
+          """.stripMargin
+    val exprValues = code.exprValues
+    assert(exprValues.size == 2)
+    assert(exprValues === Set(value, isNull))
+  }
+
+  test("concatenate blocks") {
+    val isNull1 = JavaCode.isNullVariable("expr1_isNull")
+    val value1 = JavaCode.variable("expr1", IntegerType)
+    val isNull2 = JavaCode.isNullVariable("expr2_isNull")
+    val value2 = JavaCode.variable("expr2", IntegerType)
+    val literal = JavaCode.literal("100", IntegerType)
+
+    val code =
+      code"""
+           |boolean $isNull1 = false;
+           |int $value1 = -1;""".stripMargin +
+      code"""
+           |boolean $isNull2 = true;
+           |int $value2 = $literal;""".stripMargin
+
+    val expected =
+      """
+       |boolean expr1_isNull = false;
+       |int expr1 = -1;
+       |boolean expr2_isNull = true;
+       |int expr2 = 100;""".stripMargin.trim
+
+    assert(code.toString == expected)
+
+    val exprValues = code.exprValues
+    assert(exprValues.size == 5)
+    assert(exprValues === Set(isNull1, value1, isNull2, value2, literal))
+  }
+
+  test("Throws exception when interpolating unexcepted object in code block") {
+    val obj = Tuple2(1, 1)
+    val e = intercept[IllegalArgumentException] {
+      code"$obj"
+    }
+    assert(e.getMessage().contains(s"Can not interpolate ${obj.getClass.getName}"))
+  }
+
+  test("replace expr values in code block") {
+    val expr = JavaCode.expression("1 + 1", IntegerType)
+    val isNull = JavaCode.isNullVariable("expr1_isNull")
+    val exprInFunc = JavaCode.variable("expr1", IntegerType)
+
+    val code =
+      code"""
+           |callFunc(int $expr) {
+           |  boolean $isNull = false;
+           |  int $exprInFunc = $expr + 1;
+           |}""".stripMargin
+
+    val aliasedParam = JavaCode.variable("aliased", expr.javaType)
+    val aliasedInputs = code.asInstanceOf[CodeBlock].blockInputs.map {
+      case _: SimpleExprValue => aliasedParam
+      case other => other
+    }
+    val aliasedCode = CodeBlock(code.asInstanceOf[CodeBlock].codeParts, aliasedInputs).stripMargin
+    val expected =
+      code"""
+           |callFunc(int $aliasedParam) {
+           |  boolean $isNull = false;
+           |  int $exprInFunc = $aliasedParam + 1;
+           |}""".stripMargin
+    assert(aliasedCode.toString == expected.toString)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
index fc3dbc1..48abad9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
@@ -58,14 +59,14 @@ private[sql] trait ColumnarBatchScan extends CodegenSupport {
     }
     val valueVar = ctx.freshName("value")
     val str = s"columnVector[$columnVar, $ordinal, ${dataType.simpleString}]"
-    val code = s"${ctx.registerComment(str)}\n" + (if (nullable) {
-      s"""
+    val code = code"${ctx.registerComment(str)}" + (if (nullable) {
+      code"""
         boolean $isNullVar = $columnVar.isNullAt($ordinal);
        $javaType $valueVar = $isNullVar ? ${CodeGenerator.defaultValue(dataType)} : ($value);
       """
     } else {
-      s"$javaType $valueVar = $value;"
-    }).trim
+      code"$javaType $valueVar = $value;"
+    })
     ExprCode(code, isNullVar, JavaCode.variable(valueVar, dataType))
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
index e4812f3..5b4edf5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 
@@ -152,7 +153,7 @@ case class ExpandExec(
       } else {
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
-        val code = s"""
+        val code = code"""
           |boolean $isNull = true;
           |${CodeGenerator.javaType(firstExpr.dataType)} $value =
           |  ${CodeGenerator.defaultValue(firstExpr.dataType)};

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
index f40c50d..2549b9e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala
@@ -21,6 +21,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types._
@@ -313,13 +314,13 @@ case class GenerateExec(
     if (checks.nonEmpty) {
       val isNull = ctx.freshName("isNull")
       val code =
-        s"""
+        code"""
            |boolean $isNull = ${checks.mkString(" || ")};
           |$javaType $value = $isNull ? ${CodeGenerator.defaultValue(dt)} : $getter;
          """.stripMargin
      ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, dt))
     } else {
-      ExprCode(s"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
+      ExprCode(code"$javaType $value = $getter;", FalseLiteral, JavaCode.variable(value, dt))
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 828b51f..372dc3d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -122,10 +123,10 @@ trait CodegenSupport extends SparkPlan {
         ctx.INPUT_ROW = row
         ctx.currentVars = colVars
         val ev = GenerateUnsafeProjection.createCode(ctx, colExprs, false)
-        val code = s"""
+        val code = code"""
           |$evaluateInputs
-          |${ev.code.trim}
-         """.stripMargin.trim
+          |${ev.code}
+         """.stripMargin
         ExprCode(code, FalseLiteral, ev.value)
       } else {
         // There are no columns
@@ -259,8 +260,8 @@ trait CodegenSupport extends SparkPlan {
    * them to be evaluated twice.
    */
   protected def evaluateVariables(variables: Seq[ExprCode]): String = {
-    val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n")
-    variables.foreach(_.code = "")
+    val evaluate = variables.filter(_.code.nonEmpty).map(_.code.toString).mkString("\n")
+    variables.foreach(_.code = EmptyBlock)
     evaluate
   }
 
@@ -275,8 +276,8 @@ trait CodegenSupport extends SparkPlan {
     val evaluateVars = new StringBuilder
     variables.zipWithIndex.foreach { case (ev, i) =>
       if (ev.code != "" && required.contains(attributes(i))) {
-        evaluateVars.append(ev.code.trim + "\n")
-        ev.code = ""
+        evaluateVars.append(ev.code.toString + "\n")
+        ev.code = EmptyBlock
       }
     }
     evaluateVars.toString()

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 6a8ec4f..8c7b2c1 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -190,7 +191,7 @@ case class HashAggregateExec(
       val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
       // The initial expression should not access any column
       val ev = e.genCode(ctx)
-      val initVars = s"""
+      val initVars = code"""
          | $isNull = ${ev.isNull};
          | $value = ${ev.value};
        """.stripMargin
@@ -773,8 +774,8 @@ case class HashAggregateExec(
     val findOrInsertRegularHashMap: String =
       s"""
          |// generate grouping key
-         |${unsafeRowKeyCode.code.trim}
-         |${hashEval.code.trim}
+         |${unsafeRowKeyCode.code}
+         |${hashEval.code}
          |if ($checkFallbackForBytesToBytesMap) {
          |  // try to get the buffer from hash map
          |  $unsafeRowBuffer =

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
index de2d630..e1c8582 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.aggregate
 
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.types._
 
 /**
@@ -50,7 +51,7 @@ abstract class HashMapGenerator(
      val value = ctx.addMutableState(CodeGenerator.javaType(e.dataType), "bufValue")
       val ev = e.genCode(ctx)
       val initVars =
-        s"""
+        code"""
            | $isNull = ${ev.isNull};
            | $value = ${ev.value};
        """.stripMargin

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index 6fa716d..0da0e86 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
@@ -183,7 +184,7 @@ case class BroadcastHashJoinExec(
         val isNull = ctx.freshName("isNull")
         val value = ctx.freshName("value")
         val javaType = CodeGenerator.javaType(a.dataType)
-        val code = s"""
+        val code = code"""
           |boolean $isNull = true;
           |$javaType $value = ${CodeGenerator.defaultValue(a.dataType)};
           |if ($matched != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index d8261f0..f4b9d13 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
@@ -521,7 +522,7 @@ case class SortMergeJoinExec(
       if (a.nullable) {
         val isNull = ctx.freshName("isNull")
         val code =
-          s"""
+          code"""
              |$isNull = $leftRow.isNullAt($i);
              |$value = $isNull ? $defaultValue : ($valueCode);
            """.stripMargin
@@ -533,7 +534,7 @@ case class SortMergeJoinExec(
        (ExprCode(code, JavaCode.isNullVariable(isNull), JavaCode.variable(value, a.dataType)),
           leftVarsDecl)
       } else {
-        val code = s"$value = $valueCode;"
+        val code = code"$value = $valueCode;"
         val leftVarsDecl = s"""$javaType $value = $defaultValue;"""
        (ExprCode(code, FalseLiteral, JavaCode.variable(value, a.dataType)), leftVarsDecl)
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
index 109fcf9..8280a3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, Generator}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -315,6 +316,7 @@ case class EmptyGenerator() extends Generator {
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = Seq.empty
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val iteratorClass = classOf[Iterator[_]].getName
-    ev.copy(code = s"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
+    ev.copy(code =
+      code"$iteratorClass<InternalRow> ${ev.value} = $iteratorClass$$.MODULE$$.empty();")
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to