Repository: spark Updated Branches: refs/heads/master 932bd09c8 -> 999ec137a
[SPARK-22570][SQL] Avoid to create a lot of global variables by using a local variable with allocation of an object in generated code ## What changes were proposed in this pull request? This PR reduces the number of global variables in generated code by replacing a global variable with a local variable that is assigned a newly allocated object each time. When a lot of global variables are generated, the generated code may hit the 64K constant pool limit. This PR reduces the number of generated global variables in the following three operations: * `Cast` with String to primitive byte/short/int/long * `RegExpReplace` * `CreateArray` I intentionally left [this part](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala#L595-L603) unchanged. This is because this variable keeps a class that is dynamically generated. In other words, it is not possible to reuse one class. ## How was this patch tested? Added test cases Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Closes #19797 from kiszk/SPARK-22570. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/999ec137 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/999ec137 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/999ec137 Branch: refs/heads/master Commit: 999ec137a97844abbbd483dd98c7ded2f8ff356c Parents: 932bd09 Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Authored: Fri Dec 1 02:28:24 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Dec 1 02:28:24 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/expressions/Cast.scala | 24 ++++++------- .../expressions/complexTypeCreator.scala | 36 ++++++++++++-------- .../expressions/regexpExpressions.scala | 8 ++--- .../sql/catalyst/expressions/CastSuite.scala | 8 +++++ .../expressions/RegexpExpressionsSuite.scala | 11 +++++- .../catalyst/optimizer/complexTypesSuite.scala | 7 ++++ 6 files changed, 61 insertions(+), 33 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8cafaef..f4ecbdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -799,16 +799,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new 
UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toByte($wrapper)) { $evPrim = (byte) $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (byte) 1 : (byte) 0;" @@ -826,16 +826,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toShort($wrapper)) { $evPrim = (short) $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? (short) 1 : (short) 0;" @@ -851,16 +851,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.IntWrapper", wrapper, - s"$wrapper = new UTF8String.IntWrapper();") + val wrapper = ctx.freshName("intWrapper") (c, evPrim, evNull) => s""" + UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper(); if ($c.toInt($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 
1 : 0;" @@ -876,17 +876,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => - val wrapper = ctx.freshName("wrapper") - ctx.addMutableState("UTF8String.LongWrapper", wrapper, - s"$wrapper = new UTF8String.LongWrapper();") + val wrapper = ctx.freshName("longWrapper") (c, evPrim, evNull) => s""" + UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper(); if ($c.toLong($wrapper)) { $evPrim = $wrapper.value; } else { $evNull = true; } + $wrapper = null; """ case BooleanType => (c, evPrim, evNull) => s"$evPrim = $c ? 1L : 0L;" http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 57a7f2e..fc68bf4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -63,7 +63,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { val (preprocess, assigns, postprocess, arrayData) = GenArrayData.genCodeToCreateArrayData(ctx, et, evals, false) ev.copy( - code = preprocess + ctx.splitExpressions(assigns) + postprocess, + code = preprocess + assigns + postprocess, value = arrayData, isNull = "false") } @@ -77,24 +77,22 @@ private [sql] object GenArrayData { * * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements - * @param elementsCode a set of [[ExprCode]] for each element of an underlying array + * @param elementsCode concatenated set of 
[[ExprCode]] for each element of an underlying array * @param isMapKey if true, throw an exception when the element is null - * @return (code pre-assignments, assignments to each array elements, code post-assignments, - * arrayData name) + * @return (code pre-assignments, concatenated assignments to each array elements, + * code post-assignments, arrayData name) */ def genCodeToCreateArrayData( ctx: CodegenContext, elementType: DataType, elementsCode: Seq[ExprCode], - isMapKey: Boolean): (String, Seq[String], String, String) = { - val arrayName = ctx.freshName("array") + isMapKey: Boolean): (String, String, String, String) = { val arrayDataName = ctx.freshName("arrayData") val numElements = elementsCode.length if (!ctx.isPrimitiveType(elementType)) { + val arrayName = ctx.freshName("arrayObject") val genericArrayClass = classOf[GenericArrayData].getName - ctx.addMutableState("Object[]", arrayName, - s"$arrayName = new Object[$numElements];") val assignments = elementsCode.zipWithIndex.map { case (eval, i) => val isNullAssignment = if (!isMapKey) { @@ -110,17 +108,21 @@ private [sql] object GenArrayData { } """ } + val assignmentString = ctx.splitExpressions( + expressions = assignments, + funcName = "apply", + extraArguments = ("Object[]", arrayDataName) :: Nil) - ("", - assignments, + (s"Object[] $arrayName = new Object[$numElements];", + assignmentString, s"final ArrayData $arrayDataName = new $genericArrayClass($arrayName);", arrayDataName) } else { + val arrayName = ctx.freshName("array") val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) + ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements) val baseOffset = Platform.BYTE_ARRAY_OFFSET - ctx.addMutableState("UnsafeArrayData", arrayDataName) val primitiveValueTypeName = ctx.primitiveTypeName(elementType) val assignments = elementsCode.zipWithIndex.map { case (eval, i) => @@ -137,14 +139,18 @@ private [sql] object GenArrayData { } """ } + val 
assignmentString = ctx.splitExpressions( + expressions = assignments, + funcName = "apply", + extraArguments = ("UnsafeArrayData", arrayDataName) :: Nil) (s""" byte[] $arrayName = new byte[$unsafeArraySizeInBytes]; - $arrayDataName = new UnsafeArrayData(); + UnsafeArrayData $arrayDataName = new UnsafeArrayData(); Platform.putLong($arrayName, $baseOffset, $numElements); $arrayDataName.pointTo($arrayName, $baseOffset, $unsafeArraySizeInBytes); """, - assignments, + assignmentString, "", arrayDataName) } @@ -216,10 +222,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression { s""" final boolean ${ev.isNull} = false; $preprocessKeyData - ${ctx.splitExpressions(assignKeys)} + $assignKeys $postprocessKeyData $preprocessValueData - ${ctx.splitExpressions(assignValues)} + $assignValues $postprocessValueData final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); """ http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index d0d663f..53d7096 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -321,8 +321,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termLastReplacement = ctx.freshName("lastReplacement") val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") - - val termResult = ctx.freshName("result") + val termResult = ctx.freshName("termResult") val classNamePattern = classOf[Pattern].getCanonicalName val classNameStringBuffer = 
classOf[java.lang.StringBuffer].getCanonicalName @@ -334,8 +333,6 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") - ctx.addMutableState(classNameStringBuffer, - termResult, s"${termResult} = new $classNameStringBuffer();") val setEvNotNull = if (nullable) { s"${ev.isNull} = false;" @@ -355,7 +352,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio ${termLastReplacementInUTF8} = $rep.clone(); ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); } - ${termResult}.delete(0, ${termResult}.length()); + $classNameStringBuffer ${termResult} = new $classNameStringBuffer(); java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($subject.toString()); while (${matcher}.find()) { @@ -363,6 +360,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio } ${matcher}.appendTail(${termResult}); ${ev.value} = UTF8String.fromString(${termResult}.toString()); + ${termResult} = null; $setEvNotNull """ }) http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 7837d65..65617be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -23,6 +23,7 @@ import java.util.{Calendar, Locale, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import 
org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeTestUtils._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils.TimeZoneGMT @@ -845,4 +846,11 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { val outputOuter = Row.fromSeq((1 to N).map(_ => outputInner)) checkEvaluation(cast(Literal.create(inputOuter, fromOuter), toOuter), outputOuter) } + + test("SPARK-22570: Cast should not create a lot of global variables") { + val ctx = new CodegenContext + cast("1", IntegerType).genCode(ctx) + cast("2", LongType).genCode(ctx) + assert(ctx.mutableStates.length == 0) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index 1ce150e..4fa61fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.StringType /** * Unit tests for regular expression (regexp) related SQL expressions. 
@@ -178,6 +179,14 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(nonNullExpr, "num-num", row1) } + test("SPARK-22570: RegExpReplace should not create a lot of global variables") { + val ctx = new CodegenContext + RegExpReplace(Literal("100"), Literal("(\\d+)"), Literal("num")).genCode(ctx) + // four global variables (lastRegex, pattern, lastReplacement, and lastReplacementInUTF8) + // are always required + assert(ctx.mutableStates.length == 4) + } + test("RegexExtract") { val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) http://git-wip-us.apache.org/repos/asf/spark/blob/999ec137/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 3634acc..e367536 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -164,6 +165,12 @@ class ComplexTypesSuite extends PlanTest{ comparePlans(Optimizer execute query, expected) } + test("SPARK-22570: CreateArray should not create a lot of global variables") { + val ctx = new CodegenContext + 
CreateArray(Seq(Literal(1))).genCode(ctx) + assert(ctx.mutableStates.length == 0) + } + test("simplify map ops") { val rel = relation .select( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org