[SPARK-24121][SQL] Add API for handling expression code generation ## What changes were proposed in this pull request?
This patch tries to implement this [proposal](https://github.com/apache/spark/pull/19813#issuecomment-354045400) to add an API for handling expression code generation. It should allow us to manipulate how code is generated for expressions. In detail, this adds a new abstraction `CodeBlock` to `JavaCode`. `CodeBlock` holds the code snippet and the inputs for generating actual Java code. For example, in the following Java code: ```java int ${variable} = 1; boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; ``` `variable` and `isNull` are two `VariableValue`s, and `CodeGenerator.defaultValue(BooleanType)` is a string. They are all inputs to this code block and are held by the `CodeBlock` representing this code. For codegen, we provide a special string interpolator `code`, so you can define a code block like this: ```scala val codeBlock = code""" |int ${variable} = 1; |boolean ${isNull} = ${CodeGenerator.defaultValue(BooleanType)}; """.stripMargin // Generates actual java code. codeBlock.toString ``` Because those inputs are held separately in `CodeBlock` before generating code, we can safely manipulate them, e.g., replacing statements with aliased variables, etc. ## How was this patch tested? Added tests. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #21193 from viirya/SPARK-24121. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f9f055af Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f9f055af Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f9f055af Branch: refs/heads/master Commit: f9f055afa47412eec8228c843b34a90decb9be43 Parents: 8086acc Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Wed May 23 01:50:22 2018 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Wed May 23 01:50:22 2018 +0800 ---------------------------------------------------------------------- .../catalyst/expressions/BoundAttribute.scala | 5 +- .../spark/sql/catalyst/expressions/Cast.scala | 10 +- .../sql/catalyst/expressions/Expression.scala | 26 ++-- .../expressions/MonotonicallyIncreasingID.scala | 3 +- .../sql/catalyst/expressions/ScalaUDF.scala | 3 +- .../sql/catalyst/expressions/SortOrder.scala | 3 +- .../catalyst/expressions/SparkPartitionID.scala | 3 +- .../sql/catalyst/expressions/TimeWindow.scala | 3 +- .../sql/catalyst/expressions/arithmetic.scala | 13 +- .../expressions/codegen/CodeGenerator.scala | 25 ++-- .../expressions/codegen/CodegenFallback.scala | 5 +- .../codegen/GenerateSafeProjection.scala | 7 +- .../codegen/GenerateUnsafeProjection.scala | 5 +- .../catalyst/expressions/codegen/javaCode.scala | 145 ++++++++++++++++++- .../expressions/collectionOperations.scala | 19 +-- .../expressions/complexTypeCreator.scala | 7 +- .../expressions/conditionalExpressions.scala | 5 +- .../expressions/datetimeExpressions.scala | 23 +-- .../expressions/decimalExpressions.scala | 5 +- .../sql/catalyst/expressions/generators.scala | 3 +- .../spark/sql/catalyst/expressions/hash.scala | 5 +- .../catalyst/expressions/inputFileBlock.scala | 14 +- .../catalyst/expressions/mathExpressions.scala | 5 +- .../spark/sql/catalyst/expressions/misc.scala | 5 +- .../catalyst/expressions/nullExpressions.scala | 9 +- .../catalyst/expressions/objects/objects.scala | 48 +++--- 
.../sql/catalyst/expressions/predicates.scala | 15 +- .../expressions/randomExpressions.scala | 5 +- .../expressions/regexpExpressions.scala | 9 +- .../expressions/stringExpressions.scala | 25 ++-- .../expressions/ExpressionEvalHelperSuite.scala | 3 +- .../expressions/codegen/CodeBlockSuite.scala | 136 +++++++++++++++++ .../spark/sql/execution/ColumnarBatchScan.scala | 9 +- .../apache/spark/sql/execution/ExpandExec.scala | 3 +- .../spark/sql/execution/GenerateExec.scala | 5 +- .../sql/execution/WholeStageCodegenExec.scala | 15 +- .../execution/aggregate/HashAggregateExec.scala | 7 +- .../execution/aggregate/HashMapGenerator.scala | 3 +- .../execution/joins/BroadcastHashJoinExec.scala | 3 +- .../sql/execution/joins/SortMergeJoinExec.scala | 5 +- .../spark/sql/GeneratorFunctionSuite.scala | 4 +- 41 files changed, 479 insertions(+), 172 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 4cc84b2..df3ab05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -56,13 +57,13 @@ case class 
BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) val value = CodeGenerator.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (nullable) { ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); |$javaType ${ev.value} = ${ev.isNull} ? | ${CodeGenerator.defaultValue(dataType)} : ($value); """.stripMargin) } else { - ev.copy(code = s"$javaType ${ev.value} = $value;", isNull = FalseLiteral) + ev.copy(code = code"$javaType ${ev.value} = $value;", isNull = FalseLiteral) } } } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 12330bf..699ea53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -623,8 +624,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - ev.copy(code = eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) + + ev.copy(code = + 
code""" + ${eval.code} + // This comment is added for manually tracking reference of ${eval.value}, ${eval.isNull} + ${castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)} + """) } // The function arguments are: `input`, `result` and `resultIsNull`. We don't need `inputIsNull` http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 97dff6a..9b9fa41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -22,6 +22,7 @@ import java.util.Locale import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -108,9 +109,9 @@ abstract class Expression extends TreeNode[Expression] { JavaCode.isNullVariable(isNull), JavaCode.variable(value, dataType))) reduceCodeSize(ctx, eval) - if (eval.code.nonEmpty) { + if (eval.code.toString.nonEmpty) { // Add `this` in the comment. 
- eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim) + eval.copy(code = ctx.registerComment(this.toString) + eval.code) } else { eval } @@ -119,7 +120,7 @@ abstract class Expression extends TreeNode[Expression] { private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = { // TODO: support whole stage codegen too - if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { + if (eval.code.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) { val setIsNull = if (!eval.isNull.isInstanceOf[LiteralValue]) { val globalIsNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, "globalIsNull") val localIsNull = eval.isNull @@ -136,14 +137,14 @@ abstract class Expression extends TreeNode[Expression] { val funcFullName = ctx.addNewFunction(funcName, s""" |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | $setIsNull | return ${eval.value}; |} """.stripMargin) eval.value = JavaCode.variable(newValue, dataType) - eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" + eval.code = code"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});" } } @@ -437,15 +438,14 @@ abstract class UnaryExpression extends Expression { if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - ev.copy(code = s""" + ev.copy(code = code""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $resultCode""", isNull = FalseLiteral) @@ -537,14 +537,13 @@ abstract class BinaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; 
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval """) } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${rightGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -681,13 +680,12 @@ abstract class TernaryExpression extends Expression { } } - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; $nullSafeEval""") } else { - ev.copy(code = s""" - boolean ${ev.isNull} = false; + ev.copy(code = code""" ${leftGen.code} ${midGen.code} ${rightGen.code} http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 9f07796..f1da592 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType} /** @@ -72,7 +73,7 @@ case class MonotonicallyIncreasingID() extends LeafExpression with Stateful { ctx.addPartitionInitializationStatement(s"$countTerm = 0L;") ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = 
((long) partitionIndex) << 33;") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; $countTerm++;""", isNull = FalseLiteral) } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index e869258..3e7ca88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.DataType /** @@ -1030,7 +1031,7 @@ case class ScalaUDF( """.stripMargin ev.copy(code = - s""" + code""" |$evalCode |${initArgs.mkString("\n")} |$callFunc http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index ff7c98f..2ce9d07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -20,6 +20,7 @@ package 
org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators._ @@ -181,7 +182,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { } ev.copy(code = childCode.code + - s""" + code""" |long ${ev.value} = 0L; |boolean ${ev.isNull} = ${childCode.isNull}; |if (!${childCode.isNull}) { http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 787bcaf..9856b37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -46,7 +47,7 @@ case class SparkPartitionID() extends LeafExpression with Nondeterministic { val idTerm = "partitionId" ctx.addImmutableStateIfNotExists(CodeGenerator.JAVA_INT, idTerm) ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", 
+ ev.copy(code = code"final ${CodeGenerator.javaType(dataType)} ${ev.value} = $idTerm;", isNull = FalseLiteral) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 6c4a360..84e38a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -164,7 +165,7 @@ case class PreciseTimestampConversion( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) ev.copy(code = eval.code + - s"""boolean ${ev.isNull} = ${eval.isNull}; + code"""boolean ${ev.isNull} = ${eval.isNull}; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${eval.value}; """.stripMargin) } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index efd4e99..fe91e52 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -259,7 +260,7 @@ trait DivModLike extends BinaryArithmetic { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -270,7 +271,7 @@ trait DivModLike extends BinaryArithmetic { ${ev.value} = $operation; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -436,7 +437,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { } if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -447,7 +448,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { $result }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -569,7 +570,7 @@ case class Least(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + 
code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes @@ -644,7 +645,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { """.stripMargin, foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |$codes http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d382d9a..66315e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -38,6 +38,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.CodegenMetrics import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -57,19 +58,19 @@ import org.apache.spark.util.{ParentClassLoader, Utils} * @param value A term for a (possibly primitive) value of the result of the evaluation. Not * valid if `isNull` is set to `true`. 
*/ -case class ExprCode(var code: String, var isNull: ExprValue, var value: ExprValue) +case class ExprCode(var code: Block, var isNull: ExprValue, var value: ExprValue) object ExprCode { def apply(isNull: ExprValue, value: ExprValue): ExprCode = { - ExprCode(code = "", isNull, value) + ExprCode(code = EmptyBlock, isNull, value) } def forNullValue(dataType: DataType): ExprCode = { - ExprCode(code = "", isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) + ExprCode(code = EmptyBlock, isNull = TrueLiteral, JavaCode.defaultLiteral(dataType)) } def forNonNullValue(value: ExprValue): ExprCode = { - ExprCode(code = "", isNull = FalseLiteral, value = value) + ExprCode(code = EmptyBlock, isNull = FalseLiteral, value = value) } } @@ -330,9 +331,9 @@ class CodegenContext { def addBufferedState(dataType: DataType, variableName: String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = dataType match { - case StringType => s"$value = $initCode.clone();" - case _: StructType | _: ArrayType | _: MapType => s"$value = $initCode.copy();" - case _ => s"$value = $initCode;" + case StringType => code"$value = $initCode.clone();" + case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" + case _ => code"$value = $initCode;" } ExprCode(code, FalseLiteral, JavaCode.global(value, dataType)) } @@ -1056,7 +1057,7 @@ class CodegenContext { val eval = expr.genCode(this) val state = SubExprEliminationState(eval.isNull, eval.value) e.foreach(localSubExprEliminationExprs.put(_, state)) - eval.code.trim + eval.code.toString } SubExprCodes(codes, localSubExprEliminationExprs.toMap) } @@ -1084,7 +1085,7 @@ class CodegenContext { val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code.trim} + | ${eval.code} | $isNull = ${eval.isNull}; | $value = ${eval.value}; |} @@ -1141,7 +1142,7 @@ class CodegenContext { def registerComment( text: => String, placeholderId: String = "", - force: Boolean = 
false): String = { + force: Boolean = false): Block = { // By default, disable comments in generated code because computing the comments themselves can // be extremely expensive in certain cases, such as deeply-nested expressions which operate over // inputs with wide schemas. For more details on the performance issues that motivated this @@ -1160,9 +1161,9 @@ class CodegenContext { s"// $text" } placeHolderToComments += (name -> comment) - s"/*$name*/" + code"/*$name*/" } else { - "" + EmptyBlock } } } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index a91989e..3f4704d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression, Nondeterministic} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ /** * A trait that can be used to provide a fallback mode for expression code generation. 
@@ -46,7 +47,7 @@ trait CodegenFallback extends Expression { val placeHolder = ctx.registerComment(this.toString) val javaType = CodeGenerator.javaType(this.dataType) if (nullable) { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; @@ -55,7 +56,7 @@ trait CodegenFallback extends Expression { ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; }""") } else { - ev.copy(code = s""" + ev.copy(code = code""" $placeHolder Object $objectTerm = ((Expression) references[$idx]).eval($input); $javaType ${ev.value} = (${CodeGenerator.boxedType(this.dataType)}) $objectTerm; http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 01c350e..3977866 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -22,6 +22,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -71,7 +72,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] arguments = Seq("InternalRow" -> tmpInput, "Object[]" 
-> values) ) val code = - s""" + code""" |final InternalRow $tmpInput = $input; |final Object[] $values = new Object[${schema.length}]; |$allFields @@ -97,7 +98,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx, JavaCode.expression(CodeGenerator.getValue(tmpInput, elementType, index), elementType), elementType) - val code = s""" + val code = code""" final ArrayData $tmpInput = $input; final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; @@ -124,7 +125,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) - val code = s""" + val code = code""" final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 01b4d6c..8f2a5a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ /** @@ -286,7 +287,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro ctx, ctx.INPUT_ROW, 
exprEvals, exprTypes, rowWriter, isTopLevel = true) val code = - s""" + code""" |$rowWriter.reset(); |$evalSubexpr |$writeExpressions @@ -343,7 +344,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro | } | | public UnsafeRow apply(InternalRow ${ctx.INPUT_ROW}) { - | ${eval.code.trim} + | ${eval.code} | return ${eval.value}; | } | http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala index 74ff018..250ce48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/javaCode.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import java.lang.{Boolean => JBool} +import scala.collection.mutable.ArrayBuffer import scala.language.{existentials, implicitConversions} import org.apache.spark.sql.types.{BooleanType, DataType} @@ -115,6 +116,147 @@ object JavaCode { } /** + * A trait representing a block of java code. + */ +trait Block extends JavaCode { + + // The expressions to be evaluated inside this block. + def exprValues: Set[ExprValue] + + // Returns java code string for this code block. + override def toString: String = _marginChar match { + case Some(c) => code.stripMargin(c).trim + case _ => code.trim + } + + def length: Int = toString.length + + def nonEmpty: Boolean = toString.nonEmpty + + // The leading prefix that should be stripped from each line. + // By default we strip blanks or control characters followed by '|' from the line. 
+ var _marginChar: Option[Char] = Some('|') + + def stripMargin(c: Char): this.type = { + _marginChar = Some(c) + this + } + + def stripMargin: this.type = { + _marginChar = Some('|') + this + } + + // Concatenates this block with other block. + def + (other: Block): Block +} + +object Block { + + val CODE_BLOCK_BUFFER_LENGTH: Int = 512 + + implicit def blocksToBlock(blocks: Seq[Block]): Block = Blocks(blocks) + + implicit class BlockHelper(val sc: StringContext) extends AnyVal { + def code(args: Any*): Block = { + sc.checkLengths(args) + if (sc.parts.length == 0) { + EmptyBlock + } else { + args.foreach { + case _: ExprValue => + case _: Int | _: Long | _: Float | _: Double | _: String => + case _: Block => + case other => throw new IllegalArgumentException( + s"Can not interpolate ${other.getClass.getName} into code block.") + } + + val (codeParts, blockInputs) = foldLiteralArgs(sc.parts, args) + CodeBlock(codeParts, blockInputs) + } + } + } + + // Folds eagerly the literal args into the code parts. + private def foldLiteralArgs(parts: Seq[String], args: Seq[Any]): (Seq[String], Seq[JavaCode]) = { + val codeParts = ArrayBuffer.empty[String] + val blockInputs = ArrayBuffer.empty[JavaCode] + + val strings = parts.iterator + val inputs = args.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + + buf.append(strings.next) + while (strings.hasNext) { + val input = inputs.next + input match { + case _: ExprValue | _: Block => + codeParts += buf.toString + buf.clear + blockInputs += input.asInstanceOf[JavaCode] + case _ => + buf.append(input) + } + buf.append(strings.next) + } + if (buf.nonEmpty) { + codeParts += buf.toString + } + + (codeParts.toSeq, blockInputs.toSeq) + } +} + +/** + * A block of java code. Including a sequence of code parts and some inputs to this block. + * The actual java code is generated by embedding the inputs into the code parts. 
+ */ +case class CodeBlock(codeParts: Seq[String], blockInputs: Seq[JavaCode]) extends Block { + override lazy val exprValues: Set[ExprValue] = { + blockInputs.flatMap { + case b: Block => b.exprValues + case e: ExprValue => Set(e) + }.toSet + } + + override lazy val code: String = { + val strings = codeParts.iterator + val inputs = blockInputs.iterator + val buf = new StringBuilder(Block.CODE_BLOCK_BUFFER_LENGTH) + buf.append(StringContext.treatEscapes(strings.next)) + while (strings.hasNext) { + buf.append(inputs.next) + buf.append(StringContext.treatEscapes(strings.next)) + } + buf.toString + } + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(Seq(this, c)) + case b: Blocks => Blocks(Seq(this) ++ b.blocks) + case EmptyBlock => this + } +} + +case class Blocks(blocks: Seq[Block]) extends Block { + override lazy val exprValues: Set[ExprValue] = blocks.flatMap(_.exprValues).toSet + override lazy val code: String = blocks.map(_.toString).mkString("\n") + + override def + (other: Block): Block = other match { + case c: CodeBlock => Blocks(blocks :+ c) + case b: Blocks => Blocks(blocks ++ b.blocks) + case EmptyBlock => this + } +} + +object EmptyBlock extends Block with Serializable { + override val code: String = "" + override val exprValues: Set[ExprValue] = Set.empty + + override def + (other: Block): Block = other +} + +/** * A typed java fragment that must be a valid java expression. */ trait ExprValue extends JavaCode { @@ -123,10 +265,9 @@ trait ExprValue extends JavaCode { } object ExprValue { - implicit def exprValueToString(exprValue: ExprValue): String = exprValue.toString + implicit def exprValueToString(exprValue: ExprValue): String = exprValue.code } - /** * A java expression fragment. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 7da4c3c..c28eab7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -91,7 +92,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = false; ${childGen.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? 
-1 : @@ -1177,14 +1178,14 @@ case class ArrayJoin( } if (nullable) { ev.copy( - s""" + code""" |boolean ${ev.isNull} = true; |UTF8String ${ev.value} = null; |$code """.stripMargin) } else { ev.copy( - s""" + code""" |UTF8String ${ev.value} = null; |$code """.stripMargin, FalseLiteral) @@ -1269,11 +1270,11 @@ case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1334,11 +1335,11 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast val childGen = child.genCode(ctx) val javaType = CodeGenerator.javaType(dataType) val i = ctx.freshName("i") - val item = ExprCode("", + val item = ExprCode(EmptyBlock, isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) ev.copy(code = - s""" + code""" |${childGen.code} |boolean ${ev.isNull} = true; |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1653,7 +1654,7 @@ case class Concat(children: Seq[Expression]) extends Expression { expressions = inputs, funcName = "valueConcat", extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" $initCode $codes $javaType ${ev.value} = $concatenator.concat($args); @@ -1963,7 +1964,7 @@ case class ArrayRepeat(left: Expression, right: Expression) val resultCode = nullElementsProtection(ev, rightGen.isNull, coreLogic) ev.copy(code = - s""" + code""" |boolean ${ev.isNull} = false; |${leftGen.code} |${rightGen.code} 
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 67876a8..a9867aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -63,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + assigns + postprocess, + code = code"${preprocess}${assigns}${postprocess}", value = JavaCode.variable(arrayData, dataType), isNull = FalseLiteral) } @@ -219,7 +220,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { val (preprocessValueData, assignValues, postprocessValueData, valueArrayData) = GenArrayData.genCodeToCreateArrayData(ctx, valueDt, evalValues, false) val code = - s""" + code""" final boolean ${ev.isNull} = false; $preprocessKeyData $assignKeys @@ -373,7 +374,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc extraArguments = 
"Object[]" -> values :: Nil) ev.copy(code = - s""" + code""" |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 205d77f..77ac6c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ // scalastyle:off line.size.limit @@ -66,7 +67,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi val falseEval = falseValue.genCode(ctx) val code = - s""" + code""" |${condEval.code} |boolean ${ev.isNull} = false; |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -265,7 +266,7 @@ case class CaseWhen( }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_BYTE} $resultState = $NOT_MATCHED; |do { | $codes http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03422fe..e8d85f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -27,6 +27,7 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -717,7 +718,7 @@ abstract class UnixTime } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -746,7 +747,7 @@ abstract class UnixTime }) case TimestampType => val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -757,7 +758,7 @@ abstract class UnixTime val tz = ctx.addReferenceObj("timeZone", timeZone) val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -852,7 +853,7 @@ case class FromUnixTime(sec: Expression, format: Expression, timeZoneId: Option[ } else { val formatterName = ctx.addReferenceObj("formatter", formatter, df) val t = left.genCode(ctx) - ev.copy(code = s""" + 
ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -1042,7 +1043,7 @@ case class StringToTimestampWithoutTimezone(child: Expression, timeZoneId: Optio val tz = ctx.addReferenceObj("timeZone", timeZone) val longOpt = ctx.freshName("longOpt") val eval = child.genCode(ctx) - val code = s""" + val code = code""" |${eval.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = true; |${CodeGenerator.JAVA_LONG} ${ev.value} = ${CodeGenerator.defaultValue(TimestampType)}; @@ -1090,7 +1091,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1104,7 +1105,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1287,7 +1288,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) if (right.foldable) { val tz = right.eval().asInstanceOf[UTF8String] if (tz == null) { - ev.copy(code = s""" + ev.copy(code = code""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; """.stripMargin) @@ -1301,7 +1302,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) ctx.addImmutableStateIfNotExists(tzClass, utcTerm, v => s"""$v = $dtu.getTimeZone("UTC");""") val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -1444,13 +1445,13 @@ trait TruncInstant extends BinaryExpression with ImplicitCastInputTypes { val javaType = CodeGenerator.javaType(dataType) if (format.foldable) { if 
(truncLevel == DateTimeUtils.TRUNC_INVALID || truncLevel > maxLevel) { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") } else { val t = instant.genCode(ctx) val truncFuncStr = truncFunc(t.value, truncLevel.toString) - ev.copy(code = s""" + ev.copy(code = code""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index db1579b..04de833 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, EmptyBlock, ExprCode} import org.apache.spark.sql.types._ /** @@ -72,7 +72,8 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple pass-through for code generation. 
*/ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + ev.copy(EmptyBlock) override def prettyName: String = "promote_precision" override def sql: String = child.sql override lazy val canonicalized: Expression = child.canonicalized http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 3af4bfe..b7c52f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ @@ -215,7 +216,7 @@ case class Stack(children: Seq[Expression]) extends Generator { // Create the collection. 
val wrapperClass = classOf[mutable.WrappedArray[_]].getName ev.copy(code = - s""" + code""" |$code |$wrapperClass<InternalRow> ${ev.value} = $wrapperClass$$.MODULE$$.make($rowData); """.stripMargin, isNull = FalseLiteral) http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index ef79033..cec00b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -28,6 +28,7 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform @@ -293,7 +294,7 @@ abstract class HashExpression[E] extends Expression { foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = - s""" + code""" |$hashResultType ${ev.value} = $seed; |$codes """.stripMargin) @@ -674,7 +675,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} ${ev.value} = $seed; |${CodeGenerator.JAVA_INT} $childHash = 0; |$codes http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala index 2a3cc58..3b0141a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/inputFileBlock.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.InputFileBlockHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -42,8 +43,9 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getInputFilePath();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getInputFilePath();", + isNull = FalseLiteral) } } @@ -65,8 +67,8 @@ case class InputFileBlockStart() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getStartOffset();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getStartOffset();", isNull = FalseLiteral) } } @@ -88,7 +90,7 @@ case class 
InputFileBlockLength() extends LeafExpression with Nondeterministic { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val className = InputFileBlockHolder.getClass.getName.stripSuffix("$") - ev.copy(code = s"final ${CodeGenerator.javaType(dataType)} ${ev.value} = " + - s"$className.getLength();", isNull = FalseLiteral) + val typeDef = s"final ${CodeGenerator.javaType(dataType)}" + ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bc4cfce..c2e1720 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1191,11 +1192,11 @@ abstract class RoundBase(child: Expression, scale: Expression, val javaType = CodeGenerator.javaType(dataType) if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = 
${CodeGenerator.defaultValue(dataType)};""") } else { - ev.copy(code = s""" + ev.copy(code = code""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index b783469..5d98dac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.RandomUUIDGenerator import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -88,7 +89,7 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa // Use unnamed reference that doesn't create a local field here to reduce the number of fields // because errMsgField is used only when the value is null or false. 
val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - ExprCode(code = s"""${eval.code} + ExprCode(code = code"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException($errMsgField); |}""".stripMargin, isNull = TrueLiteral, @@ -151,7 +152,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Sta ctx.addPartitionInitializationStatement(s"$randomGen = " + "new org.apache.spark.sql.catalyst.util.RandomUUIDGenerator(" + s"${randomSeed.get}L + partitionIndex);") - ev.copy(code = s"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", + ev.copy(code = code"final UTF8String ${ev.value} = $randomGen.getNextUUIDUTF8String();", isNull = FalseLiteral) } http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 0787342..2eeed3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -111,7 +112,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = - s""" + code""" |${ev.isNull} = true; |$resultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; |do { @@ -232,7 +233,7 @@ 
case class IsNaN(child: Expression) extends UnaryExpression val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""", isNull = FalseLiteral) @@ -278,7 +279,7 @@ case class NaNvl(left: Expression, right: Expression) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - ev.copy(code = s""" + ev.copy(code = code""" ${leftGen.code} boolean ${ev.isNull} = false; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -440,7 +441,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate }.mkString) ev.copy(code = - s""" + code""" |${CodeGenerator.JAVA_INT} $nonnull = 0; |do { | $codes http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f974fd8..2bf4203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} 
import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -269,7 +270,7 @@ case class StaticInvoke( s"${ev.value} = $callFunc;" } - val code = s""" + val code = code""" $argCode $prepareIsNull $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -385,8 +386,7 @@ case class Invoke( """ } - val code = s""" - ${obj.code} + val code = obj.code + code""" boolean ${ev.isNull} = true; $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${obj.isNull}) { @@ -492,7 +492,7 @@ case class NewInstance( s"new $className($argString)" } - val code = s""" + val code = code""" $argCode ${outer.map(_.code).getOrElse("")} final $javaType ${ev.value} = ${ev.isNull} ? @@ -532,9 +532,7 @@ case class UnwrapOption( val javaType = CodeGenerator.javaType(dataType) val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${CodeGenerator.defaultValue(dataType)} : (${CodeGenerator.boxedType(javaType)}) ${inputObject.value}.get(); @@ -564,9 +562,7 @@ case class WrapOption(child: Expression, optType: DataType) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - val code = s""" - ${inputObject.code} - + val code = inputObject.code + code""" scala.Option ${ev.value} = ${inputObject.isNull} ? 
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); @@ -935,8 +931,7 @@ case class MapObjects private( ) } - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1147,8 +1142,7 @@ case class CatalystToExternalMap private( """ val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();" - val code = s""" - ${genInputData.code} + val code = genInputData.code + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${genInputData.isNull}) { @@ -1391,9 +1385,8 @@ case class ExternalMapToCatalyst private( val mapCls = classOf[ArrayBasedMapData].getName val convertedKeyType = CodeGenerator.boxedType(keyConverter.dataType) val convertedValueType = CodeGenerator.boxedType(valueConverter.dataType) - val code = - s""" - ${inputMap.code} + val code = inputMap.code + + code""" ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${inputMap.isNull}) { final int $length = ${inputMap.value}.size(); @@ -1471,7 +1464,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val schemaField = ctx.addReferenceObj("schema", schema) val code = - s""" + code""" |Object[] $values = new Object[${children.size}]; |$childrenCode |final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); @@ -1499,8 +1492,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) val javaType = CodeGenerator.javaType(dataType) val serialize = s"$serializer.serialize(${input.value}, null).array()" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? 
${CodeGenerator.defaultValue(dataType)} : $serialize; """ @@ -1532,8 +1524,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B val deserialize = s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" - val code = s""" - ${input.code} + val code = input.code + code""" final $javaType ${ev.value} = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; """ @@ -1614,9 +1605,8 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp funcName = "initializeJavaBean", extraArguments = beanInstanceJavaType -> javaBeanInstance :: Nil) - val code = - s""" - |${instanceGen.code} + val code = instanceGen.code + + code""" |$beanInstanceJavaType $javaBeanInstance = ${instanceGen.value}; |if (!${instanceGen.isNull}) { | $initializeCode @@ -1664,9 +1654,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String] = Nil) // because errMsgField is used only when the value is null. val errMsgField = ctx.addReferenceObj("errMsg", errMsg) - val code = s""" - ${childGen.code} - + val code = childGen.code + code""" if (${childGen.isNull}) { throw new NullPointerException($errMsgField); } @@ -1709,7 +1697,7 @@ case class GetExternalRowField( // because errMsgField is used only when the field is null. 
val errMsgField = ctx.addReferenceObj("errMsg", errMsg) val row = child.genCode(ctx) - val code = s""" + val code = code""" ${row.code} if (${row.isNull}) { @@ -1784,7 +1772,7 @@ case class ValidateExternalType(child: Expression, expected: DataType) s"$obj instanceof ${CodeGenerator.boxedType(dataType)}" } - val code = s""" + val code = code""" ${input.code} ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${input.isNull}) { http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index f8c6dc4..f54103c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import scala.collection.immutable.TreeSet import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ @@ -290,7 +291,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { }.mkString("\n")) ev.copy(code = - s""" + code""" |${valueGen.code} |byte $tmpResult = $HAS_NULL; |if (!${valueGen.isNull}) { @@ -354,7 +355,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with 
"" } ev.copy(code = - s""" + code""" |${childGen.code} |${CodeGenerator.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; |${CodeGenerator.JAVA_BOOLEAN} ${ev.value} = false; @@ -406,7 +407,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = false; @@ -415,7 +416,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -470,7 +471,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = FalseLiteral - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.value} = true; @@ -479,7 +480,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.value} = ${eval2.value}; }""", isNull = FalseLiteral) } else { - ev.copy(code = s""" + ev.copy(code = code""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -621,7 +622,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.copy(code = eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + code""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || (!${eval1.isNull} && !${eval2.isNull} && $equalCode);""", isNull = FalseLiteral) } 
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 2653b28..926c2f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom @@ -82,7 +83,7 @@ case class Rand(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = FalseLiteral) } @@ -120,7 +121,7 @@ case class Randn(child: Expression) extends RDG { val rngTerm = ctx.addMutableState(className, "rng") ctx.addPartitionInitializationStatement( s"$rngTerm = new $className(${seed}L + partitionIndex);") - ev.copy(code = s""" + ev.copy(code = code""" final ${CodeGenerator.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = FalseLiteral) } 
http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index ad0c079..7b68bb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -23,6 +23,7 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -123,7 +124,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -132,7 +133,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) @@ -198,7 +199,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
val eval = left.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; @@ -207,7 +208,7 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress } """) } else { - ev.copy(code = s""" + ev.copy(code = code""" boolean ${ev.isNull} = true; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; """) http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index ea005a2..9823b2f 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -27,6 +27,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -105,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) - ev.copy(s""" + ev.copy(code""" UTF8String[] $args = new UTF8String[$numArgs]; ${separator.code} $codes @@ -149,7 +150,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val 
codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code.toString)) val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, @@ -176,7 +177,7 @@ case class ConcatWs(children: Seq[Expression]) foldFunctions = _.map(funcCall => s"$idxVararg = $funcCall;").mkString("\n")) ev.copy( - s""" + code""" $codes int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxVararg = 0; @@ -288,7 +289,7 @@ case class Elt(children: Seq[Expression]) extends Expression { }.mkString) ev.copy( - s""" + code""" |${index.code} |final int $indexVal = ${index.value}; |${CodeGenerator.JAVA_BOOLEAN} $indexMatched = false; @@ -654,7 +655,7 @@ case class StringTrim( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -671,7 +672,7 @@ case class StringTrim( } else { ${ev.value} = ${srcString.value}.trim(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -754,7 +755,7 @@ case class StringTrimLeft( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -771,7 +772,7 @@ case class StringTrimLeft( } else { ${ev.value} = ${srcString.value}.trimLeft(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -856,7 +857,7 @@ case class StringTrimRight( val srcString = evals(0) if (evals.length == 1) { - ev.copy(evals.map(_.code).mkString + s""" + 
ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -873,7 +874,7 @@ case class StringTrimRight( } else { ${ev.value} = ${srcString.value}.trimRight(${trimString.value}); }""" - ev.copy(evals.map(_.code).mkString + s""" + ev.copy(evals.map(_.code) :+ code""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = null; if (${srcString.isNull}) { @@ -1024,7 +1025,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - ev.copy(code = s""" + ev.copy(code = code""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -1350,7 +1351,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - ev.copy(code = s""" + ev.copy(code = code""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; http://git-wip-us.apache.org/repos/asf/spark/blob/f9f055af/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala index 64b65e2..7c7c4cc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelperSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.types.{DataType, IntegerType} /** @@ -45,7 +46,7 @@ case class BadCodegenExpression() extends LeafExpression { override def eval(input: InternalRow): Any = 10 override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.copy(code = - s""" + code""" |int some_variable = 11; |int ${ev.value} = 10; """.stripMargin) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org