Repository: spark
Updated Branches:
  refs/heads/master 64c14618d -> 365c14055


[SPARK-8748][SQL] Move castability test out from Cast case class into Cast object.

This patch moves the `resolve` function from the `Cast` case class into its companion object
and renames it `canCast`. The analyzer can then check whether one type can be cast to another
without constructing a `Cast` expression.

Author: Reynold Xin <r...@databricks.com>

Closes #7145 from rxin/cast and squashes the following commits:

cd086a9 [Reynold Xin] Whitespace changes.
4d2d989 [Reynold Xin] [SPARK-8748][SQL] Move castability test out from Cast 
case class into Cast object.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/365c1405
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/365c1405
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/365c1405

Branch: refs/heads/master
Commit: 365c14055e90db5ea4b25afec03022be81c8a704
Parents: 64c1461
Author: Reynold Xin <r...@databricks.com>
Authored: Tue Jun 30 23:04:54 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Tue Jun 30 23:04:54 2015 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 144 ++++++++++---------
 1 file changed, 78 insertions(+), 66 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/365c1405/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d69d490..2d99d1a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression 
with Logging {
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (resolve(child.dataType, dataType)) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        s"cannot cast ${child.dataType} to $dataType")
-    }
-  }
+object Cast {
 
-  override def foldable: Boolean = child.foldable
+  /**
+   * Returns true iff we can cast `from` type to `to` type.
+   */
+  def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
+    case (fromType, toType) if fromType == toType => true
+
+    case (NullType, _) => true
+
+    case (_, StringType) => true
 
-  override def nullable: Boolean = forceNullable(child.dataType, dataType) || 
child.nullable
+    case (StringType, BinaryType) => true
 
-  private[this] def forceNullable(from: DataType, to: DataType) = (from, to) 
match {
+    case (StringType, BooleanType) => true
+    case (DateType, BooleanType) => true
+    case (TimestampType, BooleanType) => true
+    case (_: NumericType, BooleanType) => true
+
+    case (StringType, TimestampType) => true
+    case (BooleanType, TimestampType) => true
+    case (DateType, TimestampType) => true
+    case (_: NumericType, TimestampType) => true
+
+    case (_, DateType) => true
+
+    case (StringType, _: NumericType) => true
+    case (BooleanType, _: NumericType) => true
+    case (DateType, _: NumericType) => true
+    case (TimestampType, _: NumericType) => true
+    case (_: NumericType, _: NumericType) => true
+
+    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+      canCast(fromType, toType) &&
+        resolvableNullability(fn || forceNullable(fromType, toType), tn)
+
+    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+      canCast(fromKey, toKey) &&
+        (!forceNullable(fromKey, toKey)) &&
+        canCast(fromValue, toValue) &&
+        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
+
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall {
+          case (fromField, toField) =>
+            canCast(fromField.dataType, toField.dataType) &&
+              resolvableNullability(
+                fromField.nullable || forceNullable(fromField.dataType, 
toField.dataType),
+                toField.nullable)
+        }
+
+    case _ => false
+  }
+
+  private def resolvableNullability(from: Boolean, to: Boolean) = !from || to
+
+  private def forceNullable(from: DataType, to: DataType) = (from, to) match {
     case (StringType, _: NumericType) => true
     case (StringType, TimestampType) => true
     case (DoubleType, TimestampType) => true
@@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
     case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here 
can really give null
     case _ => false
   }
+}
 
-  private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from 
|| to
-
-  private[this] def resolve(from: DataType, to: DataType): Boolean = {
-    (from, to) match {
-      case (from, to) if from == to => true
-
-      case (NullType, _) => true
-
-      case (_, StringType) => true
-
-      case (StringType, BinaryType) => true
-
-      case (StringType, BooleanType) => true
-      case (DateType, BooleanType) => true
-      case (TimestampType, BooleanType) => true
-      case (_: NumericType, BooleanType) => true
-
-      case (StringType, TimestampType) => true
-      case (BooleanType, TimestampType) => true
-      case (DateType, TimestampType) => true
-      case (_: NumericType, TimestampType) => true
-
-      case (_, DateType) => true
-
-      case (StringType, _: NumericType) => true
-      case (BooleanType, _: NumericType) => true
-      case (DateType, _: NumericType) => true
-      case (TimestampType, _: NumericType) => true
-      case (_: NumericType, _: NumericType) => true
-
-      case (ArrayType(from, fn), ArrayType(to, tn)) =>
-        resolve(from, to) &&
-          resolvableNullability(fn || forceNullable(from, to), tn)
-
-      case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
-        resolve(fromKey, toKey) &&
-          (!forceNullable(fromKey, toKey)) &&
-          resolve(fromValue, toValue) &&
-          resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
-
-      case (StructType(fromFields), StructType(toFields)) =>
-        fromFields.size == toFields.size &&
-          fromFields.zip(toFields).forall {
-            case (fromField, toField) =>
-              resolve(fromField.dataType, toField.dataType) &&
-                resolvableNullability(
-                  fromField.nullable || forceNullable(fromField.dataType, 
toField.dataType),
-                  toField.nullable)
-          }
+/** Cast the child expression to the target data type. */
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression 
with Logging {
 
-      case _ => false
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (Cast.canCast(child.dataType, dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"cannot cast ${child.dataType} to $dataType")
     }
   }
 
+  override def foldable: Boolean = child.foldable
+
+  override def nullable: Boolean = Cast.forceNullable(child.dataType, 
dataType) || child.nullable
+
   override def toString: String = s"CAST($child, $dataType)"
 
   // [[func]] assumes the input is no longer null because eval already does 
the null check.
@@ -172,7 +177,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
         catch { case _: java.lang.IllegalArgumentException => null }
       })
     case BooleanType =>
-      buildCast[Boolean](_, b => (if (b) 1L else 0))
+      buildCast[Boolean](_, b => if (b) 1L else 0)
     case LongType =>
       buildCast[Long](_, l => longToTimestamp(l))
     case IntegerType =>
@@ -388,7 +393,7 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
       case (fromField, toField) => cast(fromField.dataType, toField.dataType)
     }
     // TODO: Could be faster?
-    val newRow = new GenericMutableRow(from.fields.size)
+    val newRow = new GenericMutableRow(from.fields.length)
     buildCast[InternalRow](_, row => {
       var i = 0
       while (i < row.length) {
@@ -427,20 +432,23 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): 
String = {
-    // TODO(cg): Add support for more data types.
+    // TODO: Add support for more data types.
     (child.dataType, dataType) match {
 
       case (BinaryType, StringType) =>
         defineCodeGen (ctx, ev, c =>
           s"${ctx.stringType}.fromBytes($c)")
+
       case (DateType, StringType) =>
         defineCodeGen(ctx, ev, c =>
           s"""${ctx.stringType}.fromString(
                 
org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""")
+
       case (TimestampType, StringType) =>
         defineCodeGen(ctx, ev, c =>
           s"""${ctx.stringType}.fromString(
                 
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""")
+
       case (_, StringType) =>
         defineCodeGen(ctx, ev, c => 
s"${ctx.stringType}.fromString(String.valueOf($c))")
 
@@ -450,12 +458,16 @@ case class Cast(child: Expression, dataType: DataType) 
extends UnaryExpression w
 
       case (BooleanType, dt: NumericType) =>
         defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)")
+
       case (dt: DecimalType, BooleanType) =>
         defineCodeGen(ctx, ev, c => s"!$c.isZero()")
+
       case (dt: NumericType, BooleanType) =>
         defineCodeGen(ctx, ev, c => s"$c != 0")
+
       case (_: DecimalType, dt: NumericType) =>
         defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()")
+
       case (_: NumericType, dt: NumericType) =>
         defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)")
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to