GitHub user rxin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/7365#discussion_r34967532
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---
    @@ -418,51 +418,508 @@ case class Cast(child: Expression, dataType: DataType)
       protected override def nullSafeEval(input: Any): Any = cast(input)
     
       override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    -    // TODO: Add support for more data types.
    -    (child.dataType, dataType) match {
    +    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
    +    if (nullSafeCast != null) {
    +      val eval = child.gen(ctx)
    +      eval.code +
    +        castCode(ctx, eval.primitive, eval.isNull, ev.primitive, ev.isNull, dataType, nullSafeCast)
    +    } else {
    +      super.genCode(ctx, ev)
    +    }
    +  }
    +
    +  // The three function arguments are: child.primitive, result.primitive and result.isNull.
    +  // It returns the code snippets to be put in the null-safe evaluation region.
    +  private[this] type CastFunction = (String, String, String) => String
    +
    +  private[this] def nullSafeCastFunction(
    +      from: DataType,
    +      to: DataType,
    +      ctx: CodeGenContext): CastFunction = to match {
    +
    +    case _ if from == NullType => (c, evPrim, evNull) => s"$evNull = true;"
    +    case _ if to == from => (c, evPrim, evNull) => s"$evPrim = $c;"
    +    case StringType => castToStringCode(from, ctx)
    +    case BinaryType => castToBinaryCode(from)
    +    case DateType => castToDateCode(from)
    +    case decimal: DecimalType => castToDecimalCode(from, decimal)
    +    case TimestampType => castToTimestampCode(from)
    +    case IntervalType => castToIntervalCode(from)
    +    case BooleanType => castToBooleanCode(from)
    +    case ByteType => castToByteCode(from)
    +    case ShortType => castToShortCode(from)
    +    case IntegerType => castToIntCode(from)
    +    case FloatType => castToFloatCode(from)
    +    case LongType => castToLongCode(from)
    +    case DoubleType => castToDoubleCode(from)
    +
    +    case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx)
    +    case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
    +    case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
    +    case other => null
    +  }
    +
    +  private[this] def castCode(ctx: CodeGenContext, childPrim: String, childNull: String,
    +    resultPrim: String, resultNull: String, resultType: DataType, cast: CastFunction): String = {
    +    s"""
    +      boolean $resultNull = $childNull;
    +      ${ctx.javaType(resultType)} $resultPrim = ${ctx.defaultValue(resultType)};
    +      if (!${childNull}) {
    +        ${cast(childPrim, resultPrim, resultNull)}
    +      }
    +    """
    +  }
    +
    +  private[this] def castToStringCode(from: DataType, ctx: CodeGenContext): CastFunction = {
    +    from match {
    +      case BinaryType =>
    +        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromBytes($c);"
    +      case DateType =>
    +        (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
    +          org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c));"""
    +      case TimestampType =>
    +        (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
    +          org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c));"""
    +      case _ =>
    +        (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
    +    }
    +  }
    +
    +  private[this] def castToBinaryCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) => s"$evPrim = $c.getBytes();"
    +  }
    +
    +  private[this] def castToDateCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) => s"""
    +        try {
    +          $evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDateTE($c);
    +        } catch (java.lang.IllegalArgumentException e) {
    +          $evNull = true;
    +        }
    +       """
    +    case TimestampType =>
    +      (c, evPrim, evNull) =>
    +        s"$evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.millisToDays($c / 1000L);"
    +    case _ =>
    +      (c, evPrim, evNull) => s"$evNull = true;"
    +  }
    +
    +  private[this] def changePrecision(d: String, decimalType: DecimalType,
    +      evPrim: String, evNull: String): String = {
    +    decimalType match {
    +      case DecimalType.Unlimited =>
    +        s"$evPrim = $d;"
    +      case DecimalType.Fixed(precision, scale) =>
    +        s"""
    +          if ($d.changePrecision($precision, $scale)) {
    +            $evPrim = $d;
    +          } else {
    +            $evNull = true;
    +          }
    +        """
    +    }
    +  }
    +
    +  private[this] def castToDecimalCode(from: DataType, target: DecimalType): CastFunction = {
    +    from match {
    +      case StringType =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            try {
    +              org.apache.spark.sql.types.Decimal tmpDecimal =
    +                new org.apache.spark.sql.types.Decimal().set(
    +                  new scala.math.BigDecimal(
    +                    new java.math.BigDecimal($c.toString())));
    +              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +            } catch (java.lang.NumberFormatException e) {
    +              $evNull = true;
    +            }
    +          """
    +      case BooleanType =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal = null;
    +            if ($c) {
    +              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(1);
    +            } else {
    +              tmpDecimal = new org.apache.spark.sql.types.Decimal().set(0);
    +            }
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case DateType =>
    +        // date can't cast to decimal in Hive
    +        (c, evPrim, evNull) => s"$evNull = true;"
    +      case TimestampType =>
    +        // Note that we lose precision here.
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal =
    +              new org.apache.spark.sql.types.Decimal().set(
    +                scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case DecimalType() =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal = $c.clone();
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case LongType =>
    +        (c, evPrim, evNull) =>
    +          s"""
    +            org.apache.spark.sql.types.Decimal tmpDecimal =
    +              new org.apache.spark.sql.types.Decimal().set($c);
    +            ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +          """
    +      case x: NumericType =>
    +        // All other numeric types can be represented precisely as Doubles
    +        (c, evPrim, evNull) =>
    +          s"""
    +            try {
    +              org.apache.spark.sql.types.Decimal tmpDecimal =
    +                new org.apache.spark.sql.types.Decimal().set(
    +                  scala.math.BigDecimal.valueOf((double) $c));
    +              ${changePrecision("tmpDecimal", target, evPrim, evNull)}
    +            } catch (java.lang.NumberFormatException e) {
    +              $evNull = true;
    +            }
    +          """
    +    }
    +  }
    +
    +  private[this] def castToTimestampCode(from: DataType): CastFunction = from match {
    +    case StringType =>
    +      (c, evPrim, evNull) =>
    +        s"""
    +          try {
    +            $evPrim = org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestampTE($c);
    +          } catch (java.lang.IllegalArgumentException e) {
    +            $evNull = true;
    +          }
    --- End diff --
    
    Ah, never mind - I see what's going on. If you are adding a wrapper just to
    support try/catch, I think it's best to remove the wrapper and use the Option
    directly.
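    
    For illustration, here is a minimal sketch of that suggestion, assuming (as
    in the interpreted path) that DateTimeUtils.stringToTimestamp returns a
    scala.Option[Long]; the variable name longOpt and the exact shape of the
    generated code are illustrative, not necessarily what the PR will end up
    with. The generated Java branches on Option.isDefined instead of catching an
    exception from a throwing wrapper:
    
        case StringType =>
          // Sketch: call the Option-returning parser directly from the
          // generated Java code and turn None into a null result, instead of
          // adding a throwing wrapper just so the generated code can try/catch.
          (c, evPrim, evNull) =>
            s"""
              scala.Option<Long> longOpt =
                org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c);
              if (longOpt.isDefined()) {
                $evPrim = ((Long) longOpt.get()).longValue();
              } else {
                $evNull = true;
              }
            """
    
    This keeps exception handling out of the generated code entirely and avoids
    adding a wrapper method to DateTimeUtils solely for codegen's benefit.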

