Repository: spark Updated Branches: refs/heads/master 03fdc92e4 -> ced6ccf0d
[SPARK-22701][SQL] add ctx.splitExpressionsWithCurrentInputs ## What changes were proposed in this pull request? This pattern appears many times in the codebase: ``` if (ctx.INPUT_ROW == null || ctx.currentVars != null) { exprs.mkString("\n") } else { ctx.splitExpressions(...) } ``` This PR adds `ctx.splitExpressionsWithCurrentInputs` to encapsulate this pattern. ## How was this patch tested? Existing tests. Author: Wenchen Fan <wenc...@databricks.com> Closes #19895 from cloud-fan/splitExpression. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ced6ccf0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ced6ccf0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ced6ccf0 Branch: refs/heads/master Commit: ced6ccf0d6f362e299f270ed2a474f2e14f845da Parents: 03fdc92 Author: Wenchen Fan <wenc...@databricks.com> Authored: Tue Dec 5 10:15:15 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Tue Dec 5 10:15:15 2017 -0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 44 ++++----- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 2 +- .../expressions/complexTypeCreator.scala | 6 +- .../expressions/conditionalExpressions.scala | 84 ++++++++--------- .../sql/catalyst/expressions/generators.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 55 +++++------- .../catalyst/expressions/nullExpressions.scala | 94 +++++++++----------- .../catalyst/expressions/objects/objects.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 47 +++++----- .../expressions/stringExpressions.scala | 37 ++++---- 12 files changed, 179 insertions(+), 206 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index d98f7b3..739bd13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(evalChildren.map(updateEval)) + val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(evalChildren.map(updateEval)) + val codes = ctx.splitExpressionsWithCurrentInputs(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1645db1..670c82e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -781,29 +781,26 @@ class CodegenContext { 
* beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * - * Note that we will extract the current inputs of this context and pass them to the generated - * functions. The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole - * stage codegen path. Whole stage codegen path is not supported yet. - * - * @param expressions the codes to evaluate expressions. - */ - def splitExpressions(expressions: Seq[String]): String = { - splitExpressions(expressions, funcName = "apply", extraArguments = Nil) - } - - /** - * Similar to [[splitExpressions(expressions: Seq[String])]], but has customized function name - * and extra arguments. + * Note that different from `splitExpressions`, we will extract the current inputs of this + * context and pass them to the generated functions. The input is `INPUT_ROW` for normal codegen + * path, and `currentVars` for whole stage codegen path. Whole stage codegen path is not + * supported yet. * * @param expressions the codes to evaluate expressions. * @param funcName the split function name base. - * @param extraArguments the list of (type, name) of the arguments of the split function - * except for ctx.INPUT_ROW - */ - def splitExpressions( + * @param extraArguments the list of (type, name) of the arguments of the split function, + * except for the current inputs like `ctx.INPUT_ROW`. + * @param returnType the return type of the split function. + * @param makeSplitFunction makes split function body, e.g. add preparation or cleanup. + * @param foldFunctions folds the split function calls. 
+ */ + def splitExpressionsWithCurrentInputs( expressions: Seq[String], - funcName: String, - extraArguments: Seq[(String, String)]): String = { + funcName: String = "apply", + extraArguments: Seq[(String, String)] = Nil, + returnType: String = "void", + makeSplitFunction: String => String = identity, + foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { // TODO: support whole stage codegen if (INPUT_ROW == null || currentVars != null) { expressions.mkString("\n") @@ -811,13 +808,18 @@ class CodegenContext { splitExpressions( expressions, funcName, - arguments = ("InternalRow", INPUT_ROW) +: extraArguments) + ("InternalRow", INPUT_ROW) +: extraArguments, + returnType, + makeSplitFunction, + foldFunctions) } } /** * Splits the generated code of expressions into multiple functions, because function has - * 64kb code size limit in JVM + * 64kb code size limit in JVM. If the class to which the function would be inlined would grow + * beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it + * instead, because classes have a constant pool limit of 65,536 named values. * * @param expressions the codes to evaluate expressions. * @param funcName the split function name base. 
http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 5fdbda5..bd8312e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressions(projectionCodes) - val allUpdates = ctx.splitExpressions(updates) + val allProjections = ctx.splitExpressionsWithCurrentInputs(projectionCodes) + val allUpdates = ctx.splitExpressionsWithCurrentInputs(updates) val codeBody = s""" public java.lang.Object generate(Object[] references) { http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 5d35cce..44e7148 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -159,7 +159,7 @@ object 
GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - val allExpressions = ctx.splitExpressions(expressionCodes) + val allExpressions = ctx.splitExpressionsWithCurrentInputs(expressionCodes) val codeBody = s""" public java.lang.Object generate(Object[] references) { http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index fc68bf4..087b210 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -108,7 +108,7 @@ private [sql] object GenArrayData { } """ } - val assignmentString = ctx.splitExpressions( + val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", extraArguments = ("Object[]", arrayDataName) :: Nil) @@ -139,7 +139,7 @@ private [sql] object GenArrayData { } """ } - val assignmentString = ctx.splitExpressions( + val assignmentString = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "apply", extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) @@ -357,7 +357,7 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, s"$values = null;") - val valuesCode = ctx.splitExpressions( + val valuesCode = ctx.splitExpressionsWithCurrentInputs( valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) s""" 
http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 43e6431..ae5f714 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -219,57 +219,51 @@ case class CaseWhen( val allConditions = cases ++ elseCode - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - allConditions.mkString("\n") - } else { - // This generates code like: - // conditionMet = caseWhen_1(i); - // if(conditionMet) { - // continue; - // } - // conditionMet = caseWhen_2(i); - // if(conditionMet) { - // continue; - // } - // ... - // and the declared methods are: - // private boolean caseWhen_1234() { - // boolean conditionMet = false; - // do { - // // here the evaluation of the conditions - // } while (false); - // return conditionMet; - // } - ctx.splitExpressions(allConditions, "caseWhen", - ("InternalRow", ctx.INPUT_ROW) :: Nil, - returnType = ctx.JAVA_BOOLEAN, - makeSplitFunction = { - func => - s""" - ${ctx.JAVA_BOOLEAN} $conditionMet = false; - do { - $func - } while (false); - return $conditionMet; - """ - }, - foldFunctions = { funcCalls => - funcCalls.map { funcCall => - s""" - $conditionMet = $funcCall; - if ($conditionMet) { - continue; - }""" - }.mkString - }) - } + // This generates code like: + // conditionMet = caseWhen_1(i); + // if(conditionMet) { + // continue; + // } + // conditionMet = caseWhen_2(i); + // if(conditionMet) { + // continue; + // } + // ... 
+ // and the declared methods are: + // private boolean caseWhen_1234() { + // boolean conditionMet = false; + // do { + // // here the evaluation of the conditions + // } while (false); + // return conditionMet; + // } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = allConditions, + funcName = "caseWhen", + returnType = ctx.JAVA_BOOLEAN, + makeSplitFunction = func => + s""" + |${ctx.JAVA_BOOLEAN} $conditionMet = false; + |do { + | $func + |} while (false); + |return $conditionMet; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$conditionMet = $funcCall; + |if ($conditionMet) { + | continue; + |} + """.stripMargin + }.mkString) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; ${ctx.JAVA_BOOLEAN} $conditionMet = false; do { - $code + $codes } while (false);""") } } http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f1aa130..cd38783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) - val code = ctx.splitExpressions(Seq.tabulate(numRows) { row => + val code = ctx.splitExpressionsWithCurrentInputs(Seq.tabulate(numRows) { row => val fields = Seq.tabulate(numFields) { col => val index = row * numFields + col if (index < values.length) values(index) 
else Literal(null, dataTypes(col)) http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index d0ed2ab..055ebf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -279,21 +279,17 @@ abstract class HashExpression[E] extends Expression { } val hashResultType = ctx.javaType(dataType) - val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - childrenHash.mkString("\n") - } else { - ctx.splitExpressions( - expressions = childrenHash, - funcName = "computeHash", - arguments = Seq("InternalRow" -> ctx.INPUT_ROW, hashResultType -> ev.value), - returnType = hashResultType, - makeSplitFunction = body => - s""" - |$body - |return ${ev.value}; - """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenHash, + funcName = "computeHash", + extraArguments = Seq(hashResultType -> ev.value), + returnType = hashResultType, + makeSplitFunction = body => + s""" + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) ev.copy(code = s""" @@ -652,22 +648,19 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - val codes = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - childrenHash.mkString("\n") - } else { - ctx.splitExpressions( - expressions = childrenHash, - funcName = "computeHash", - arguments = Seq("InternalRow" -> ctx.INPUT_ROW, ctx.JAVA_INT -> ev.value), - 
returnType = ctx.JAVA_INT, - makeSplitFunction = body => - s""" - |${ctx.JAVA_INT} $childHash = 0; - |$body - |return ${ev.value}; - """.stripMargin, - foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = childrenHash, + funcName = "computeHash", + extraArguments = Seq(ctx.JAVA_INT -> ev.value), + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |${ctx.JAVA_INT} $childHash = 0; + |$body + |return ${ev.value}; + """.stripMargin, + foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n")) + ev.copy(code = s""" http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 3b52a0e..26c9a41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -87,37 +87,32 @@ case class Coalesce(children: Seq[Expression]) extends Expression { |} """.stripMargin } - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions(evals, "coalesce", - ("InternalRow", ctx.INPUT_ROW) :: Nil, - makeSplitFunction = { - func => - s""" - |do { - | $func - |} while (false); - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map { funcCall => - s""" - |$funcCall; - |if (!${ev.isNull}) { - | continue; - |} - """.stripMargin - }.mkString - }) - } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "coalesce", + makeSplitFunction = func => + 
s""" + |do { + | $func + |} while (false); + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$funcCall; + |if (!${ev.isNull}) { + | continue; + |} + """.stripMargin + }.mkString) + ev.copy(code = s""" |${ev.isNull} = true; |${ev.value} = ${ctx.defaultValue(dataType)}; |do { - | $code + | $codes |} while (false); """.stripMargin) } @@ -415,39 +410,32 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate } } - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - evals.mkString("\n") - } else { - ctx.splitExpressions( - expressions = evals, - funcName = "atLeastNNonNulls", - arguments = ("InternalRow", ctx.INPUT_ROW) :: (ctx.JAVA_INT, nonnull) :: Nil, - returnType = ctx.JAVA_INT, - makeSplitFunction = { body => - s""" - |do { - | $body - |} while (false); - |return $nonnull; - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => - s""" - |$nonnull = $funcCall; - |if ($nonnull >= $n) { - | continue; - |} - """.stripMargin).mkString("\n") - } - ) - } + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = evals, + funcName = "atLeastNNonNulls", + extraArguments = (ctx.JAVA_INT, nonnull) :: Nil, + returnType = ctx.JAVA_INT, + makeSplitFunction = body => + s""" + |do { + | $body + |} while (false); + |return $nonnull; + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$nonnull = $funcCall; + |if ($nonnull >= $n) { + | continue; + |} + """.stripMargin + }.mkString) ev.copy(code = s""" |${ctx.JAVA_INT} $nonnull = 0; |do { - | $code + | $codes |} while (false); |${ctx.JAVA_BOOLEAN} ${ev.value} = $nonnull >= $n; """.stripMargin, isNull = "false") http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index e2bc79d..730b2ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression { """ } } - val argCode = ctx.splitExpressions(argCodes) + val argCode = ctx.splitExpressionsWithCurrentInputs(argCodes) (argCode, argValues.mkString(", "), resultIsNull) } @@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) """ } - val childrenCode = ctx.splitExpressions(childrenCodes) + val childrenCode = ctx.splitExpressionsWithCurrentInputs(childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" @@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp ${javaBeanInstance}.$setterMethod(${fieldGen.value}); """ } - val initializeCode = ctx.splitExpressions(initialize.toSeq) + val initializeCode = ctx.splitExpressionsWithCurrentInputs(initialize.toSeq) val code = s""" ${instanceGen.code} http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 75cc9b3..04e6694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -253,31 +253,26 @@ case class In(value: Expression, list: 
Seq[Expression]) extends Predicate { | continue; |} """.stripMargin) - val code = if (ctx.INPUT_ROW == null || ctx.currentVars != null) { - listCode.mkString("\n") - } else { - ctx.splitExpressions( - expressions = listCode, - funcName = "valueIn", - arguments = ("InternalRow", ctx.INPUT_ROW) :: (javaDataType, valueArg) :: Nil, - makeSplitFunction = { body => - s""" - |do { - | $body - |} while (false); - """.stripMargin - }, - foldFunctions = { funcCalls => - funcCalls.map(funcCall => - s""" - |$funcCall; - |if (${ev.value}) { - | continue; - |} - """.stripMargin).mkString("\n") - } - ) - } + + val codes = ctx.splitExpressionsWithCurrentInputs( + expressions = listCode, + funcName = "valueIn", + extraArguments = (javaDataType, valueArg) :: Nil, + makeSplitFunction = body => + s""" + |do { + | $body + |} while (false); + """.stripMargin, + foldFunctions = _.map { funcCall => + s""" + |$funcCall; + |if (${ev.value}) { + | continue; + |} + """.stripMargin + }.mkString("\n")) + ev.copy(code = s""" |${valueGen.code} @@ -286,7 +281,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { |if (!${ev.isNull}) { | $javaDataType $valueArg = ${valueGen.value}; | do { - | $code + | $codes | } while (false); |} """.stripMargin) http://git-wip-us.apache.org/repos/asf/spark/blob/ced6ccf0/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 34917ac..47f0b57 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -73,7 +73,7 @@ case class Concat(children: Seq[Expression]) extends Expression 
with ImplicitCas } """ } - val codes = ctx.splitExpressions( + val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", extraArguments = ("UTF8String[]", args) :: Nil) @@ -152,7 +152,7 @@ case class ConcatWs(children: Seq[Expression]) "" } } - val codes = ctx.splitExpressions( + val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcatWs", extraArguments = ("UTF8String[]", args) :: Nil) @@ -200,31 +200,32 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressions(evals.map(_.code)) - val varargCounts = ctx.splitExpressions( + val codes = ctx.splitExpressionsWithCurrentInputs(evals.map(_.code)) + + val varargCounts = ctx.splitExpressionsWithCurrentInputs( expressions = varargCount, funcName = "varargCountsConcatWs", - arguments = ("InternalRow", ctx.INPUT_ROW) :: Nil, returnType = "int", makeSplitFunction = body => s""" - int $varargNum = 0; - $body - return $varargNum; - """, - foldFunctions = _.mkString(s"$varargNum += ", s";\n$varargNum += ", ";")) - val varargBuilds = ctx.splitExpressions( + |int $varargNum = 0; + |$body + |return $varargNum; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$varargNum += $funcCall;").mkString("\n")) + + val varargBuilds = ctx.splitExpressionsWithCurrentInputs( expressions = varargBuild, funcName = "varargBuildsConcatWs", - arguments = - ("InternalRow", ctx.INPUT_ROW) :: ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, + extraArguments = ("UTF8String []", array) :: ("int", idxInVararg) :: Nil, returnType = "int", makeSplitFunction = body => s""" - $body - return $idxInVararg; - """, - foldFunctions = _.mkString(s"$idxInVararg = ", s";\n$idxInVararg = ", ";")) + |$body + |return $idxInVararg; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$idxInVararg = $funcCall;").mkString("\n")) + ev.copy( s""" $codes @@ -1380,7 +1381,7 @@ case class FormatString(children: Expression*) extends Expression 
with ImplicitC $argList[$index] = $value; """ } - val argListCodes = ctx.splitExpressions( + val argListCodes = ctx.splitExpressionsWithCurrentInputs( expressions = argListCode, funcName = "valueFormatString", extraArguments = ("Object[]", argList) :: Nil) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org