Repository: spark
Updated Branches:
  refs/heads/master 214d1be4f -> 0513c3ac9


[SPARK-14637][SQL] object expressions cleanup

## What changes were proposed in this pull request?

Simplify and clean up some object expressions:

1. simplify the logic to handle `propagateNull`
2. add `propagateNull` parameter to `Invoke`
3. simplify the unbox logic in `Invoke`
4. other minor cleanup

TODO: simplify `MapObjects`

## How was this patch tested?

existing tests.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #12399 from cloud-fan/object.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0513c3ac
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0513c3ac
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0513c3ac

Branch: refs/heads/master
Commit: 0513c3ac93e0a25d6eedbafe6c0561e71c92880a
Parents: 214d1be
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon May 2 10:21:14 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon May 2 10:21:14 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/expressions/objects.scala      | 218 +++++++++----------
 1 file changed, 100 insertions(+), 118 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0513c3ac/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
index 1e41854..523eed8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala
@@ -64,33 +64,29 @@ case class StaticInvoke(
     val argGen = arguments.map(_.genCode(ctx))
     val argString = argGen.map(_.value).mkString(", ")
 
-    if (propagateNull) {
-      val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
-        s"${ev.isNull} = ${ev.value} == null;"
-      } else {
-        ""
-      }
-
-      val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
-      ev.copy(code = s"""
-        ${argGen.map(_.code).mkString("\n")}
-
-        boolean ${ev.isNull} = !$argsNonNull;
-        $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
+    val callFunc = s"$objectName.$functionName($argString)"
 
-        if ($argsNonNull) {
-          ${ev.value} = $objectName.$functionName($argString);
-          $objNullCheck
-        }
-       """)
+    val setIsNull = if (propagateNull && arguments.nonEmpty) {
+      s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
     } else {
-      ev.copy(code = s"""
-        ${argGen.map(_.code).mkString("\n")}
+      s"boolean ${ev.isNull} = false;"
+    }
 
-        $javaType ${ev.value} = $objectName.$functionName($argString);
-        final boolean ${ev.isNull} = ${ev.value} == null;
-      """)
+    // If the function can return null, we do an extra check to make sure our null bit is still set
+    // correctly.
+    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
+      s"${ev.isNull} = ${ev.value} == null;"
+    } else {
+      ""
     }
+
+    val code = s"""
+      ${argGen.map(_.code).mkString("\n")}
+      $setIsNull
+      final $javaType ${ev.value} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : $callFunc;
+      $postNullCheck
+     """
+    ev.copy(code = code)
   }
 }
 
@@ -111,7 +107,8 @@ case class Invoke(
     targetObject: Expression,
     functionName: String,
     dataType: DataType,
-    arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression 
{
+    arguments: Seq[Expression] = Nil,
+    propagateNull: Boolean = true) extends Expression with NonSQLExpression {
 
   override def nullable: Boolean = true
   override def children: Seq[Expression] = targetObject +: arguments
@@ -130,60 +127,53 @@ case class Invoke(
     case _ => None
   }
 
-  lazy val unboxer = (dataType, 
method.map(_.getReturnType.getName).getOrElse("")) match {
-    case (IntegerType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Integer)$s).intValue()"
-    case (LongType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Long)$s).longValue()"
-    case (FloatType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Float)$s).floatValue()"
-    case (ShortType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Short)$s).shortValue()"
-    case (ByteType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Byte)$s).byteValue()"
-    case (DoubleType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Double)$s).doubleValue()"
-    case (BooleanType, "java.lang.Object") => (s: String) =>
-      s"((java.lang.Boolean)$s).booleanValue()"
-    case _ => identity[String] _
-  }
-
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val javaType = ctx.javaType(dataType)
     val obj = targetObject.genCode(ctx)
     val argGen = arguments.map(_.genCode(ctx))
     val argString = argGen.map(_.value).mkString(", ")
 
-    // If the function can return null, we do an extra check to make sure our 
null bit is still set
-    // correctly.
-    val objNullCheck = if (ctx.defaultValue(dataType) == "null") {
-      s"boolean ${ev.isNull} = ${ev.value} == null;"
+    val callFunc = if (method.isDefined && 
method.get.getReturnType.isPrimitive) {
+      s"${obj.value}.$functionName($argString)"
     } else {
-      ev.isNull = obj.isNull
-      ""
+      s"(${ctx.boxedType(javaType)}) ${obj.value}.$functionName($argString)"
     }
 
-    val value = unboxer(s"${obj.value}.$functionName($argString)")
+    val setIsNull = if (propagateNull && arguments.nonEmpty) {
+      s"boolean ${ev.isNull} = ${obj.isNull} || 
${argGen.map(_.isNull).mkString(" || ")};"
+    } else {
+      s"boolean ${ev.isNull} = ${obj.isNull};"
+    }
 
     val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) {
-      s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} 
: ($javaType) $value;"
+      s"final $javaType ${ev.value} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : $callFunc;"
     } else {
       s"""
         $javaType ${ev.value} = ${ctx.defaultValue(javaType)};
         try {
-          ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : 
($javaType) $value;
+          ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : 
$callFunc;
         } catch (Exception e) {
           org.apache.spark.unsafe.Platform.throwException(e);
         }
       """
     }
 
-    ev.copy(code = s"""
+    // If the function can return null, we do an extra check to make sure our null bit is still set
+    // correctly.
+    val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
+      s"${ev.isNull} = ${ev.value} == null;"
+    } else {
+      ""
+    }
+
+    val code = s"""
       ${obj.code}
       ${argGen.map(_.code).mkString("\n")}
+      $setIsNull
       $evaluate
-      $objNullCheck
-    """)
+      $postNullCheck
+     """
+    ev.copy(code = code)
   }
 
   override def toString: String = s"$targetObject.$functionName"
@@ -246,11 +236,13 @@ case class NewInstance(
 
     val outer = outerPointer.map(func => 
Literal.fromObject(func()).genCode(ctx))
 
-    val setup =
-      s"""
-         ${argGen.map(_.code).mkString("\n")}
-         ${outer.map(_.code).getOrElse("")}
-       """.stripMargin
+    var isNull = ev.isNull
+    val setIsNull = if (propagateNull && arguments.nonEmpty) {
+      s"final boolean $isNull = ${argGen.map(_.isNull).mkString(" || ")};"
+    } else {
+      isNull = "false"
+      ""
+    }
 
     val constructorCall = outer.map { gen =>
       s"""${gen.value}.new ${cls.getSimpleName}($argString)"""
@@ -258,27 +250,13 @@ case class NewInstance(
       s"new $className($argString)"
     }
 
-    if (propagateNull && argGen.nonEmpty) {
-      val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"
-
-      ev.copy(code = s"""
-        $setup
-
-        boolean ${ev.isNull} = true;
-        $javaType ${ev.value} = ${ctx.defaultValue(dataType)};
-        if ($argsNonNull) {
-          ${ev.value} = $constructorCall;
-          ${ev.isNull} = false;
-        }
-       """)
-    } else {
-      ev.copy(code = s"""
-        $setup
-
-        final $javaType ${ev.value} = $constructorCall;
-        final boolean ${ev.isNull} = false;
-      """)
-    }
+    val code = s"""
+      ${argGen.map(_.code).mkString("\n")}
+      ${outer.map(_.code).getOrElse("")}
+      $setIsNull
+      final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : 
$constructorCall;
+     """
+    ev.copy(code = code, isNull = isNull)
   }
 
   override def toString: String = s"newInstance($cls)"
@@ -306,13 +284,14 @@ case class UnwrapOption(
     val javaType = ctx.javaType(dataType)
     val inputObject = child.genCode(ctx)
 
-    ev.copy(code = s"""
+    val code = s"""
       ${inputObject.code}
 
-      boolean ${ev.isNull} = ${inputObject.value} == null || 
${inputObject.value}.isEmpty();
+      final boolean ${ev.isNull} = ${inputObject.isNull} || 
${inputObject.value}.isEmpty();
       $javaType ${ev.value} =
-        ${ev.isNull} ? ${ctx.defaultValue(dataType)} : 
($javaType)${inputObject.value}.get();
-    """)
+        ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) 
${inputObject.value}.get();
+    """
+    ev.copy(code = code)
   }
 }
 
@@ -338,14 +317,14 @@ case class WrapOption(child: Expression, optType: DataType)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val inputObject = child.genCode(ctx)
 
-    ev.copy(code = s"""
+    val code = s"""
       ${inputObject.code}
 
-      boolean ${ev.isNull} = false;
       scala.Option ${ev.value} =
         ${inputObject.isNull} ?
         scala.Option$$.MODULE$$.apply(null) : new 
scala.Some(${inputObject.value});
-    """)
+    """
+    ev.copy(code = code, isNull = "false")
   }
 }
 
@@ -474,7 +453,7 @@ case class MapObjects private(
       s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == 
null;"
     }
 
-    ev.copy(code = s"""
+    val code = s"""
       ${genInputData.code}
 
       boolean ${ev.isNull} = ${genInputData.value} == null;
@@ -504,7 +483,8 @@ case class MapObjects private(
         ${ev.isNull} = false;
         ${ev.value} = new 
${classOf[GenericArrayData].getName}($convertedArray);
       }
-    """)
+    """
+    ev.copy(code = code)
   }
 }
 
@@ -539,14 +519,16 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
           }
          """
     }
+
     val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes)
     val schemaField = ctx.addReferenceObj("schema", schema)
-    ev.copy(code = s"""
-      boolean ${ev.isNull} = false;
+
+    val code = s"""
       $values = new Object[${children.size}];
       $childrenCode
       final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, 
this.$schemaField);
-      """)
+      """
+    ev.copy(code = code, isNull = "false")
   }
 }
 
@@ -579,14 +561,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean)
 
     // Code to serialize.
     val input = child.genCode(ctx)
-    ev.copy(code = s"""
+    val javaType = ctx.javaType(dataType)
+    val serialize = s"$serializer.serialize(${input.value}, null).array()"
+
+    val code = s"""
       ${input.code}
-      final boolean ${ev.isNull} = ${input.isNull};
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.value} = $serializer.serialize(${input.value}, null).array();
-      }
-     """)
+      final $javaType ${ev.value} = ${input.isNull} ? 
${ctx.defaultValue(javaType)} : $serialize;
+     """
+    ev.copy(code = code, isNull = input.isNull)
   }
 
   override def dataType: DataType = BinaryType
@@ -617,17 +599,17 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B
       serializer,
       s"$serializer = ($serializerInstanceClass) new 
$serializerClass($sparkConf).newInstance();")
 
-    // Code to serialize.
+    // Code to deserialize.
     val input = child.genCode(ctx)
-    ev.copy(code = s"""
+    val javaType = ctx.javaType(dataType)
+    val deserialize =
+      s"($javaType) 
$serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)"
+
+    val code = s"""
       ${input.code}
-      final boolean ${ev.isNull} = ${input.isNull};
-      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-      if (!${ev.isNull}) {
-        ${ev.value} = (${ctx.javaType(dataType)})
-          $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), 
null);
-      }
-     """)
+      final $javaType ${ev.value} = ${input.isNull} ? 
${ctx.defaultValue(javaType)} : $deserialize;
+     """
+    ev.copy(code = code, isNull = input.isNull)
   }
 
   override def dataType: DataType = ObjectType(tag.runtimeClass)
@@ -658,15 +640,13 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
          """
     }
 
-    ev.isNull = instanceGen.isNull
-    ev.value = instanceGen.value
-
-    ev.copy(code = s"""
+    val code = s"""
       ${instanceGen.code}
       if (!${instanceGen.isNull}) {
         ${initialize.mkString("\n")}
       }
-     """)
+     """
+    ev.copy(code = code, isNull = instanceGen.isNull, value = 
instanceGen.value)
   }
 }
 
@@ -696,13 +676,15 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
       "If the schema is inferred from a Scala tuple/case class, or a Java 
bean, " +
       "please try to use scala.Option[_] or other nullable types " +
       "(e.g. java.lang.Integer instead of int/scala.Int)."
-    val idx = ctx.references.length
-    ctx.references += errMsg
-    ExprCode(code = s"""
+    val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
+
+    val code = s"""
       ${childGen.code}
 
       if (${childGen.isNull}) {
-        throw new RuntimeException((String) references[$idx]);
-      }""", isNull = "false", value = childGen.value)
+        throw new RuntimeException(this.$errMsgField);
+      }
+     """
+    ev.copy(code = code, isNull = "false", value = childGen.value)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to