Repository: spark Updated Branches: refs/heads/master 1e07fff24 -> b70e483cb
[SPARK-22617][SQL] make splitExpressions extract current input of the context ## What changes were proposed in this pull request? Mostly when we call `CodegenContext.splitExpressions`, we want to split the code into methods and pass the current inputs of the codegen context to these methods so that the code in these methods can still be evaluated. This PR makes the expectation clear, while still keeping the advanced version of `splitExpressions` to customize the inputs to pass to generated methods. ## How was this patch tested? existing test Author: Wenchen Fan <wenc...@databricks.com> Closes #19827 from cloud-fan/codegen. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b70e483c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b70e483c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b70e483c Branch: refs/heads/master Commit: b70e483cb32d07eaab80739cd0cfcd8fe922547c Parents: 1e07fff Author: Wenchen Fan <wenc...@databricks.com> Authored: Tue Nov 28 22:57:30 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue Nov 28 22:57:30 2017 +0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/arithmetic.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 13 ++-- .../codegen/GenerateMutableProjection.scala | 4 +- .../codegen/GenerateSafeProjection.scala | 37 ++++++----- .../codegen/GenerateUnsafeProjection.scala | 67 ++++++++++---------- .../expressions/complexTypeCreator.scala | 31 ++++----- .../sql/catalyst/expressions/generators.scala | 2 +- .../spark/sql/catalyst/expressions/hash.scala | 26 +++++--- .../catalyst/expressions/nullExpressions.scala | 2 +- .../catalyst/expressions/objects/objects.scala | 6 +- .../expressions/stringExpressions.scala | 2 +- 11 files changed, 108 insertions(+), 86 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index e5a1096..d98f7b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -614,7 +614,7 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) + val codes = ctx.splitExpressions(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -680,7 +680,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } - val codes = ctx.splitExpressions(ctx.INPUT_ROW, evalChildren.map(updateEval)) + val codes = ctx.splitExpressions(evalChildren.map(updateEval)) ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0498e61..668c816 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -781,15 +781,18 @@ class CodegenContext { * 
beyond 1000kb, we declare a private, inner sub-class, and the function is inlined to it * instead, because classes have a constant pool limit of 65,536 named values. * - * @param row the variable name of row that is used by expressions + * Note that we will extract the current inputs of this context and pass them to the generated + * functions. The input is `INPUT_ROW` for normal codegen path, and `currentVars` for whole + * stage codegen path. Whole stage codegen path is not supported yet. + * * @param expressions the codes to evaluate expressions. */ - def splitExpressions(row: String, expressions: Seq[String]): String = { - if (row == null || currentVars != null) { - // Cannot split these expressions because they are not created from a row object. + def splitExpressions(expressions: Seq[String]): String = { + // TODO: support whole stage codegen + if (INPUT_ROW == null || currentVars != null) { return expressions.mkString("\n") } - splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", row) :: Nil) + splitExpressions(expressions, funcName = "apply", arguments = ("InternalRow", INPUT_ROW) :: Nil) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 802e8bd..5fdbda5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -91,8 +91,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP 
ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } - val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) - val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) + val allProjections = ctx.splitExpressions(projectionCodes) + val allUpdates = ctx.splitExpressions(updates) val codeBody = s""" public java.lang.Object generate(Object[] references) { http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 1e4ac3f..5d35cce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -45,7 +45,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx: CodegenContext, input: String, schema: StructType): ExprCode = { - val tmp = ctx.freshName("tmp") + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeRow") val values = ctx.freshName("values") // These expressions could be split into multiple functions @@ -54,17 +55,21 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val rowClass = classOf[GenericInternalRow].getName val fieldWriters = schema.map(_.dataType).zipWithIndex.map { case (dt, i) => - val converter = convertToSafe(ctx, ctx.getValue(tmp, dt, i.toString), dt) + val converter = convertToSafe(ctx, ctx.getValue(tmpInput, dt, i.toString), dt) s""" - if (!$tmp.isNullAt($i)) { + if (!$tmpInput.isNullAt($i)) { ${converter.code} $values[$i] = ${converter.value}; } """ } - val allFields = ctx.splitExpressions(tmp, fieldWriters) + val allFields = ctx.splitExpressions( + expressions = fieldWriters, + funcName = "writeFields", + arguments = Seq("InternalRow" -> tmpInput) + ) val code = s""" - final InternalRow $tmp = $input; + final InternalRow $tmpInput = $input; $values = new Object[${schema.length}]; $allFields final InternalRow $output = new $rowClass($values); @@ -78,20 +83,22 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] ctx: CodegenContext, input: String, elementType: DataType): ExprCode = { - val tmp = ctx.freshName("tmp") + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeArray") val values = ctx.freshName("values") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") val arrayClass = classOf[GenericArrayData].getName - val elementConverter = convertToSafe(ctx, ctx.getValue(tmp, elementType, index), elementType) + val elementConverter = convertToSafe( + ctx, ctx.getValue(tmpInput, elementType, index), elementType) val code = s""" - final ArrayData $tmp = $input; - final int $numElements = $tmp.numElements(); + final ArrayData $tmpInput = $input; + final int $numElements = $tmpInput.numElements(); final Object[] $values = new Object[$numElements]; for (int $index = 0; $index < $numElements; $index++) { - if (!$tmp.isNullAt($index)) { + if (!$tmpInput.isNullAt($index)) { ${elementConverter.code} $values[$index] = ${elementConverter.value}; } @@ -107,14 +114,14 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] input: String, keyType: DataType, valueType: DataType): ExprCode = { - val tmp = ctx.freshName("tmp") + val tmpInput = ctx.freshName("tmpInput") val output = ctx.freshName("safeMap") val mapClass = classOf[ArrayBasedMapData].getName - val keyConverter = createCodeForArray(ctx, s"$tmp.keyArray()", keyType) - val valueConverter = createCodeForArray(ctx, s"$tmp.valueArray()", valueType) + val keyConverter = createCodeForArray(ctx, s"$tmpInput.keyArray()", keyType) + val valueConverter = createCodeForArray(ctx, s"$tmpInput.valueArray()", valueType) val code = s""" - final MapData $tmp = $input; + final MapData $tmpInput = $input; ${keyConverter.code} ${valueConverter.code} final MapData $output = new $mapClass(${keyConverter.value}, ${valueConverter.value}); @@ -152,7 +159,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] } """ } - val allExpressions = ctx.splitExpressions(ctx.INPUT_ROW, expressionCodes) + val allExpressions = 
ctx.splitExpressions(expressionCodes) val codeBody = s""" public java.lang.Object generate(Object[] references) { http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 4bd50ae..b022457 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -36,7 +36,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case NullType => true case t: AtomicType => true case _: CalendarIntervalType => true - case t: StructType => t.toSeq.forall(field => canSupport(field.dataType)) + case t: StructType => t.forall(field => canSupport(field.dataType)) case t: ArrayType if canSupport(t.elementType) => true case MapType(kt, vt, _) if canSupport(kt) && canSupport(vt) => true case udt: UserDefinedType[_] => canSupport(udt.sqlType) @@ -49,25 +49,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, fieldTypes: Seq[DataType], bufferHolder: String): String = { + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) => - val javaType = ctx.javaType(dt) - val isNullVar = ctx.freshName("isNull") - val valueVar = ctx.freshName("value") - val defaultValue = ctx.defaultValue(dt) - val readValue = ctx.getValue(input, dt, i.toString) - val code = - s""" - boolean $isNullVar = $input.isNullAt($i); - $javaType $valueVar = $isNullVar ? $defaultValue : $readValue; - """ - ExprCode(code, isNullVar, valueVar) + ExprCode("", s"$tmpInput.isNullAt($i)", ctx.getValue(tmpInput, dt, i.toString)) } s""" - if ($input instanceof UnsafeRow) { - ${writeUnsafeData(ctx, s"((UnsafeRow) $input)", bufferHolder)} + final InternalRow $tmpInput = $input; + if ($tmpInput instanceof UnsafeRow) { + ${writeUnsafeData(ctx, s"((UnsafeRow) $tmpInput)", bufferHolder)} } else { - ${writeExpressionsToBuffer(ctx, input, fieldEvals, fieldTypes, bufferHolder)} + ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes, bufferHolder)} } """ } @@ -167,9 +160,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } } + val writeFieldsCode = if (isTopLevel && (row == null || ctx.currentVars != null)) { + // TODO: support whole stage codegen + writeFields.mkString("\n") + } else { + assert(row != null, "the input row name cannot be null when generating code to write it.") + ctx.splitExpressions( + expressions = writeFields, + funcName = "writeFields", + arguments = Seq("InternalRow" -> row)) + } + s""" $resetWriter - ${ctx.splitExpressions(row, writeFields)} + $writeFieldsCode """.trim } @@ -179,13 +183,14 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro input: String, elementType: DataType, bufferHolder: String): String = { + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val arrayWriterClass = classOf[UnsafeArrayWriter].getName val arrayWriter = ctx.freshName("arrayWriter") ctx.addMutableState(arrayWriterClass, arrayWriter, s"$arrayWriter = new $arrayWriterClass();") val numElements = ctx.freshName("numElements") val index = ctx.freshName("index") - val element = ctx.freshName("element") val et = elementType match { case udt: UserDefinedType[_] => udt.sqlType @@ -201,6 +206,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val tmpCursor = ctx.freshName("tmpCursor") + val element = ctx.getValue(tmpInput, et, index) val writeElement = et match { case t: StructType => s""" @@ -233,17 +239,17 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val primitiveTypeName = if (ctx.isPrimitiveType(jt)) ctx.primitiveTypeName(et) else "" s""" - if ($input instanceof UnsafeArrayData) { - ${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)} + final ArrayData $tmpInput = $input; + if ($tmpInput instanceof UnsafeArrayData) { + ${writeUnsafeData(ctx, s"((UnsafeArrayData) $tmpInput)", bufferHolder)} } else { - final int $numElements = $input.numElements(); + final int $numElements = $tmpInput.numElements(); $arrayWriter.initialize($bufferHolder, $numElements, $elementOrOffsetSize); for (int $index = 0; $index < $numElements; $index++) { - if ($input.isNullAt($index)) { + if ($tmpInput.isNullAt($index)) { $arrayWriter.setNull$primitiveTypeName($index); } else { - final $jt $element = ${ctx.getValue(input, et, index)}; $writeElement } } @@ -258,19 +264,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro keyType: DataType, valueType: DataType, bufferHolder: String): String = { - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") + // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement. 
+ val tmpInput = ctx.freshName("tmpInput") val tmpCursor = ctx.freshName("tmpCursor") - // Writes out unsafe map according to the format described in `UnsafeMapData`. s""" - if ($input instanceof UnsafeMapData) { - ${writeUnsafeData(ctx, s"((UnsafeMapData) $input)", bufferHolder)} + final MapData $tmpInput = $input; + if ($tmpInput instanceof UnsafeMapData) { + ${writeUnsafeData(ctx, s"((UnsafeMapData) $tmpInput)", bufferHolder)} } else { - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - // preserve 8 bytes to write the key array numBytes later. $bufferHolder.grow(8); $bufferHolder.cursor += 8; @@ -278,11 +281,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // Remember the current cursor so that we can write numBytes of key array later. final int $tmpCursor = $bufferHolder.cursor; - ${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.keyArray()", keyType, bufferHolder)} // Write the numBytes of key array into the first 8 bytes. 
Platform.putLong($bufferHolder.buffer, $tmpCursor - 8, $bufferHolder.cursor - $tmpCursor); - ${writeArrayToBuffer(ctx, values, valueType, bufferHolder)} + ${writeArrayToBuffer(ctx, s"$tmpInput.valueArray()", valueType, bufferHolder)} } """ } http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 2a00d57..57a7f2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + ctx.splitExpressions(ctx.INPUT_ROW, assigns) + postprocess, + code = preprocess + ctx.splitExpressions(assigns) + postprocess, value = arrayData, isNull = "false") } @@ -216,10 +216,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { s""" final boolean ${ev.isNull} = false; $preprocessKeyData - ${ctx.splitExpressions(ctx.INPUT_ROW, assignKeys)} + ${ctx.splitExpressions(assignKeys)} $postprocessKeyData $preprocessValueData - ${ctx.splitExpressions(ctx.INPUT_ROW, assignValues)} + ${ctx.splitExpressions(assignValues)} $postprocessValueData final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); """ @@ -351,24 +351,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") 
ctx.addMutableState("Object[]", values, s"$values = null;") - - ev.copy(code = s""" - $values = new Object[${valExprs.size}];""" + - ctx.splitExpressions( - ctx.INPUT_ROW, - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - eval.code + s""" + val valuesCode = ctx.splitExpressions( + valExprs.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + ${eval.code} if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; }""" - }) + + }) + + ev.copy(code = s""" - final InternalRow ${ev.value} = new $rowClass($values); - $values = null; - """, isNull = "false") + |$values = new Object[${valExprs.size}]; + |$valuesCode + |final InternalRow ${ev.value} = new $rowClass($values); + |$values = null; + """.stripMargin, isNull = "false") } override def prettyName: String = "named_struct" http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 8618f49..f1aa130 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -203,7 +203,7 @@ case class Stack(children: Seq[Expression]) extends Generator { ctx.addMutableState("InternalRow[]", rowData, s"$rowData = new InternalRow[$numRows];") val values = children.tail val dataTypes = values.take(numFields).map(_.dataType) - val code = ctx.splitExpressions(ctx.INPUT_ROW, Seq.tabulate(numRows) { row => + val code = ctx.splitExpressions(Seq.tabulate(numRows) { row => val fields = Seq.tabulate(numFields) { col => val index = row * numFields + col if (index < values.length) values(index) else 
Literal(null, dataTypes(col)) http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 9e0786e..c3289b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -270,7 +270,7 @@ abstract class HashExpression[E] extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" - val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => + val childrenHash = ctx.splitExpressions(children.map { child => val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) @@ -330,9 +330,9 @@ abstract class HashExpression[E] extends Expression { } else { val bytes = ctx.freshName("bytes") s""" - final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); - ${genHashBytes(bytes, result)} - """ + |final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + |${genHashBytes(bytes, result)} + """.stripMargin } } @@ -392,7 +392,10 @@ abstract class HashExpression[E] extends Expression { val hashes = fields.zipWithIndex.map { case (field, index) => nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) } - ctx.splitExpressions(input, hashes) + ctx.splitExpressions( + expressions = hashes, + funcName = "getHash", + arguments = Seq("InternalRow" -> input)) } @tailrec @@ -608,12 +611,17 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override def doGenCode(ctx: 
CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" val childHash = ctx.freshName("childHash") - val childrenHash = ctx.splitExpressions(ctx.INPUT_ROW, children.map { child => + val childrenHash = ctx.splitExpressions(children.map { child => val childGen = child.genCode(ctx) - childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { + val codeToComputeHash = ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, childHash, ctx) - } + s"${ev.value} = (31 * ${ev.value}) + $childHash;" + - s"\n$childHash = 0;" + } + s""" + |${childGen.code} + |$codeToComputeHash + |${ev.value} = (31 * ${ev.value}) + $childHash; + |$childHash = 0; + """.stripMargin }) ctx.addMutableState(ctx.javaType(dataType), ev.value) http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 5eaf3f2..173e171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -91,7 +91,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { ev.copy(code = s""" ${ev.isNull} = true; ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ctx.splitExpressions(ctx.INPUT_ROW, evals)}""") + ${ctx.splitExpressions(evals)}""") } } http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 006d37f..e2bc79d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -101,7 +101,7 @@ trait InvokeLike extends Expression with NonSQLExpression { """ } } - val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + val argCode = ctx.splitExpressions(argCodes) (argCode, argValues.mkString(", "), resultIsNull) } @@ -1119,7 +1119,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) """ } - val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) + val childrenCode = ctx.splitExpressions(childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) val code = s""" @@ -1254,7 +1254,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp ${javaBeanInstance}.$setterMethod(${fieldGen.value}); """ } - val initializeCode = ctx.splitExpressions(ctx.INPUT_ROW, initialize.toSeq) + val initializeCode = ctx.splitExpressions(initialize.toSeq) val code = s""" ${instanceGen.code} http://git-wip-us.apache.org/repos/asf/spark/blob/b70e483c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index d629eb7..ee5cf92 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -208,7 +208,7 @@ case class 
ConcatWs(children: Seq[Expression]) } }.unzip - val codes = ctx.splitExpressions(ctx.INPUT_ROW, evals.map(_.code)) + val codes = ctx.splitExpressions(evals.map(_.code)) val varargCounts = ctx.splitExpressions( expressions = varargCount, funcName = "varargCountsConcatWs", --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org