Repository: spark
Updated Branches:
  refs/heads/master c0abb1d99 -> c6f01cade


[SPARK-22750][SQL] Reuse mutable states when possible

## What changes were proposed in this pull request?

The PR introduces a new method `addImmutableStateIfNotExists` to 
`CodeGenerator` to allow reusing and sharing the same global variable between 
different Expressions. This helps reduce the number of global variables 
needed, which is important to limit the impact on the constant pool.

## How was this patch tested?

added UTs

Author: Marco Gaido <marcogaid...@gmail.com>
Author: Marco Gaido <mga...@hortonworks.com>

Closes #19940 from mgaido91/SPARK-22750.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c6f01cad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c6f01cad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c6f01cad

Branch: refs/heads/master
Commit: c6f01cadede490bde987d067becef14442f1e4a1
Parents: c0abb1d
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Fri Dec 22 10:13:26 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Dec 22 10:13:26 2017 +0800

----------------------------------------------------------------------
 .../expressions/MonotonicallyIncreasingID.scala |  3 +-
 .../catalyst/expressions/SparkPartitionID.scala |  3 +-
 .../expressions/codegen/CodeGenerator.scala     | 40 ++++++++++++++++++++
 .../expressions/datetimeExpressions.scala       | 12 ++++--
 .../catalyst/expressions/objects/objects.scala  | 24 ++++++++----
 .../expressions/CodeGenerationSuite.scala       | 12 ++++++
 6 files changed, 80 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
index 784eaf8..11fb579 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala
@@ -66,7 +66,8 @@ case class MonotonicallyIncreasingID() extends LeafExpression 
with Nondeterminis
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val countTerm = ctx.addMutableState(ctx.JAVA_LONG, "count")
-    val partitionMaskTerm = ctx.addMutableState(ctx.JAVA_LONG, "partitionMask")
+    val partitionMaskTerm = "partitionMask"
+    ctx.addImmutableStateIfNotExists(ctx.JAVA_LONG, partitionMaskTerm)
     ctx.addPartitionInitializationStatement(s"$countTerm = 0L;")
     ctx.addPartitionInitializationStatement(s"$partitionMaskTerm = ((long) 
partitionIndex) << 33;")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
index 736ca37..a160b9b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala
@@ -43,7 +43,8 @@ case class SparkPartitionID() extends LeafExpression with 
Nondeterministic {
   override protected def evalInternal(input: InternalRow): Int = partitionId
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val idTerm = ctx.addMutableState(ctx.JAVA_INT, "partitionId")
+    val idTerm = "partitionId"
+    ctx.addImmutableStateIfNotExists(ctx.JAVA_INT, idTerm)
     ctx.addPartitionInitializationStatement(s"$idTerm = partitionIndex;")
     ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", 
isNull = "false")
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 9adf632..d6eccad 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -208,6 +208,14 @@ class CodegenContext {
   }
 
   /**
+   * A map containing the mutable states which have been defined so far using
+   * `addImmutableStateIfNotExists`. Each entry contains the name of the 
mutable state as key and
+   * its Java type and init code as value.
+   */
+  private val immutableStates: mutable.Map[String, (String, String)] =
+    mutable.Map.empty[String, (String, String)]
+
+  /**
    * Add a mutable state as a field to the generated class. c.f. the comments 
above.
    *
    * @param javaType Java type of the field. Note that short names can be used 
for some types,
@@ -266,6 +274,38 @@ class CodegenContext {
   }
 
   /**
+   * Add an immutable state as a field to the generated class only if it does 
not exist yet a field
+   * with that name. This helps reducing the number of the generated class' 
fields, since the same
+   * variable can be reused by many functions.
+   *
+   * Even though the added variables are not declared as final, they should 
never be reassigned in
+   * the generated code to prevent errors and unexpected behaviors.
+   *
+   * Internally, this method calls `addMutableState`.
+   *
+   * @param javaType Java type of the field.
+   * @param variableName Name of the field.
+   * @param initFunc Function includes statement(s) to put into the init() 
method to initialize
+   *                 this field. The argument is the name of the mutable state 
variable.
+   */
+  def addImmutableStateIfNotExists(
+      javaType: String,
+      variableName: String,
+      initFunc: String => String = _ => ""): Unit = {
+    val existingImmutableState = immutableStates.get(variableName)
+    if (existingImmutableState.isEmpty) {
+      addMutableState(javaType, variableName, initFunc, useFreshName = false, 
forceInline = true)
+      immutableStates(variableName) = (javaType, initFunc(variableName))
+    } else {
+      val (prevJavaType, prevInitCode) = existingImmutableState.get
+      assert(prevJavaType == javaType, s"$variableName has already been 
defined with type " +
+        s"$prevJavaType and now it is tried to define again with type 
$javaType.")
+      assert(prevInitCode == initFunc(variableName), s"$variableName has 
already been defined " +
+        s"with different initialization statements.")
+    }
+  }
+
+  /**
    * Add buffer variable which stores data coming from an [[InternalRow]]. 
This methods guarantees
    * that the variable is safely stored, which is important for (potentially) 
byte array backed
    * data types like: UTF8String, ArrayData, MapData & InternalRow.

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
index 59c3e3d..7a674ea 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala
@@ -443,7 +443,8 @@ case class DayOfWeek(child: Expression) extends 
UnaryExpression with ImplicitCas
     nullSafeCodeGen(ctx, ev, time => {
       val cal = classOf[Calendar].getName
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      val c = ctx.addMutableState(cal, "cal",
+      val c = "calDayOfWeek"
+      ctx.addImmutableStateIfNotExists(cal, c,
         v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""")
       s"""
         $c.setTimeInMillis($time * 1000L * 3600L * 24L);
@@ -484,8 +485,9 @@ case class WeekOfYear(child: Expression) extends 
UnaryExpression with ImplicitCa
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     nullSafeCodeGen(ctx, ev, time => {
       val cal = classOf[Calendar].getName
+      val c = "calWeekOfYear"
       val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
-      val c = ctx.addMutableState(cal, "cal", v =>
+      ctx.addImmutableStateIfNotExists(cal, c, v =>
         s"""
            |$v = $cal.getInstance($dtu.getTimeZone("UTC"));
            |$v.setFirstDayOfWeek($cal.MONDAY);
@@ -1017,7 +1019,8 @@ case class FromUTCTimestamp(left: Expression, right: 
Expression)
         val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
         val tzTerm = ctx.addMutableState(tzClass, "tz",
           v => s"""$v = $dtu.getTimeZone("$tz");""")
-        val utcTerm = ctx.addMutableState(tzClass, "utc",
+        val utcTerm = "tzUTC"
+        ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
           v => s"""$v = $dtu.getTimeZone("UTC");""")
         val eval = left.genCode(ctx)
         ev.copy(code = s"""
@@ -1193,7 +1196,8 @@ case class ToUTCTimestamp(left: Expression, right: 
Expression)
         val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
         val tzTerm = ctx.addMutableState(tzClass, "tz",
           v => s"""$v = $dtu.getTimeZone("$tz");""")
-        val utcTerm = ctx.addMutableState(tzClass, "utc",
+        val utcTerm = "tzUTC"
+        ctx.addImmutableStateIfNotExists(tzClass, utcTerm,
           v => s"""$v = $dtu.getTimeZone("UTC");""")
         val eval = left.genCode(ctx)
         ev.copy(code = s"""

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index a59aad5..4af8134 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -1148,17 +1148,21 @@ case class EncodeUsingSerializer(child: Expression, 
kryo: Boolean)
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     // Code to initialize the serializer.
-    val (serializerClass, serializerInstanceClass) = {
+    val (serializer, serializerClass, serializerInstanceClass) = {
       if (kryo) {
-        (classOf[KryoSerializer].getName, 
classOf[KryoSerializerInstance].getName)
+        ("kryoSerializer",
+          classOf[KryoSerializer].getName,
+          classOf[KryoSerializerInstance].getName)
       } else {
-        (classOf[JavaSerializer].getName, 
classOf[JavaSerializerInstance].getName)
+        ("javaSerializer",
+          classOf[JavaSerializer].getName,
+          classOf[JavaSerializerInstance].getName)
       }
     }
     // try conf from env, otherwise create a new one
     val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    val serializer = ctx.addMutableState(serializerInstanceClass, 
"serializerForEncode", v =>
+    ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
       s"""
          |if ($env == null) {
          |  $v = ($serializerInstanceClass) new 
$serializerClass($sparkConf).newInstance();
@@ -1193,17 +1197,21 @@ case class DecodeUsingSerializer[T](child: Expression, 
tag: ClassTag[T], kryo: B
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     // Code to initialize the serializer.
-    val (serializerClass, serializerInstanceClass) = {
+    val (serializer, serializerClass, serializerInstanceClass) = {
       if (kryo) {
-        (classOf[KryoSerializer].getName, 
classOf[KryoSerializerInstance].getName)
+        ("kryoSerializer",
+          classOf[KryoSerializer].getName,
+          classOf[KryoSerializerInstance].getName)
       } else {
-        (classOf[JavaSerializer].getName, 
classOf[JavaSerializerInstance].getName)
+        ("javaSerializer",
+          classOf[JavaSerializer].getName,
+          classOf[JavaSerializerInstance].getName)
       }
     }
     // try conf from env, otherwise create a new one
     val env = s"${classOf[SparkEnv].getName}.get()"
     val sparkConf = s"new ${classOf[SparkConf].getName}()"
-    val serializer = ctx.addMutableState(serializerInstanceClass, 
"serializerForDecode", v =>
+    ctx.addImmutableStateIfNotExists(serializerInstanceClass, serializer, v =>
       s"""
          |if ($env == null) {
          |  $v = ($serializerInstanceClass) new 
$serializerClass($sparkConf).newInstance();

http://git-wip-us.apache.org/repos/asf/spark/blob/c6f01cad/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b1a4452..676ba39 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -424,4 +424,16 @@ class CodeGenerationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     assert(ctx2.arrayCompactedMutableStates("InternalRow[]").getCurrentIndex 
== 10)
     assert(ctx2.mutableStateInitCode.size == 
CodeGenerator.MUTABLESTATEARRAY_SIZE_LIMIT + 10)
   }
+
+  test("SPARK-22750: addImmutableStateIfNotExists") {
+    val ctx = new CodegenContext
+    val mutableState1 = "field1"
+    val mutableState2 = "field2"
+    ctx.addImmutableStateIfNotExists("int", mutableState1)
+    ctx.addImmutableStateIfNotExists("int", mutableState1)
+    ctx.addImmutableStateIfNotExists("String", mutableState2)
+    ctx.addImmutableStateIfNotExists("int", mutableState1)
+    ctx.addImmutableStateIfNotExists("String", mutableState2)
+    assert(ctx.inlinedMutableStates.length == 2)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to