Repository: spark
Updated Branches:
  refs/heads/master 044f7ecbf -> ffc57b011
[SPARK-20302][SQL] Short circuit cast when from and to types are structurally the same

## What changes were proposed in this pull request?

When we perform a cast expression and the from and to types are structurally the same (having the same structure but different field names), we should be able to skip the actual cast.

## How was this patch tested?

Added unit tests for the newly introduced functions.

Author: Reynold Xin <r...@databricks.com>

Closes #17614 from rxin/SPARK-20302.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ffc57b01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ffc57b01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ffc57b01

Branch: refs/heads/master
Commit: ffc57b0118b58de57520967d8e8730b11baad507
Parents: 044f7ecb
Author: Reynold Xin <r...@databricks.com>
Authored: Wed Apr 12 01:30:00 2017 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Apr 12 01:30:00 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 65 +++++++++++++-------
 .../org/apache/spark/sql/types/DataType.scala   | 26 ++++++++
 .../sql/catalyst/expressions/CastSuite.scala    | 14 +++++
 .../apache/spark/sql/types/DataTypeSuite.scala  | 31 ++++++++++
 4 files changed, 113 insertions(+), 23 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1049915..bb1273f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     })
   }
 
-  private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
-    case dt if dt == from => identity[Any]
-    case StringType => castToString(from)
-    case BinaryType => castToBinary(from)
-    case DateType => castToDate(from)
-    case decimal: DecimalType => castToDecimal(from, decimal)
-    case TimestampType => castToTimestamp(from)
-    case CalendarIntervalType => castToInterval(from)
-    case BooleanType => castToBoolean(from)
-    case ByteType => castToByte(from)
-    case ShortType => castToShort(from)
-    case IntegerType => castToInt(from)
-    case FloatType => castToFloat(from)
-    case LongType => castToLong(from)
-    case DoubleType => castToDouble(from)
-    case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
-    case map: MapType => castMap(from.asInstanceOf[MapType], map)
-    case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
-    case udt: UserDefinedType[_]
-      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
-      identity[Any]
-    case _: UserDefinedType[_] =>
-      throw new SparkException(s"Cannot cast $from to $to.")
+  private[this] def cast(from: DataType, to: DataType): Any => Any = {
+    // If the cast does not change the structure, then we don't really need to cast anything.
+    // We can return what the children return. Same thing should happen in the codegen path.
+    if (DataType.equalsStructurally(from, to)) {
+      identity
+    } else {
+      to match {
+        case dt if dt == from => identity[Any]
+        case StringType => castToString(from)
+        case BinaryType => castToBinary(from)
+        case DateType => castToDate(from)
+        case decimal: DecimalType => castToDecimal(from, decimal)
+        case TimestampType => castToTimestamp(from)
+        case CalendarIntervalType => castToInterval(from)
+        case BooleanType => castToBoolean(from)
+        case ByteType => castToByte(from)
+        case ShortType => castToShort(from)
+        case IntegerType => castToInt(from)
+        case FloatType => castToFloat(from)
+        case LongType => castToLong(from)
+        case DoubleType => castToDouble(from)
+        case array: ArrayType =>
+          castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
+        case map: MapType => castMap(from.asInstanceOf[MapType], map)
+        case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+        case udt: UserDefinedType[_]
+          if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+          identity[Any]
+        case _: UserDefinedType[_] =>
+          throw new SparkException(s"Cannot cast $from to $to.")
+      }
+    }
   }
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
 
   protected override def nullSafeEval(input: Any): Any = cast(input)
 
+  override def genCode(ctx: CodegenContext): ExprCode = {
+    // If the cast does not change the structure, then we don't really need to cast anything.
+    // We can return what the children return. Same thing should happen in the interpreted path.
+    if (DataType.equalsStructurally(child.dataType, dataType)) {
+      child.genCode(ctx)
+    } else {
+      super.genCode(ctx)
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
     val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
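[Editor's note] The effect of the short circuit is easiest to see at the DataFrame level. The sketch below is purely illustrative and not part of the patch (the object, column, and field names are invented): it builds a struct column and casts it to a structurally identical type that only renames the fields, a cast which after this commit compiles down to the child expression's own code.

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types._

    object ShortCircuitCastDemo {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
        import spark.implicits._

        // Column "s" has type struct<_1: bigint (not null), _2: string>.
        val df = Seq((1, (2L, "x"))).toDF("id", "s")

        // Same shape, different field names, matching nullability.
        val renamed = new StructType()
          .add("amount", LongType, nullable = false)
          .add("label", StringType, nullable = true)

        // equalsStructurally(from, to) holds, so this Cast evaluates to the
        // child's value unchanged; only the reported schema differs.
        df.select(col("s").cast(renamed).as("s2")).printSchema()
        spark.stop()
      }
    }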
http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 520aff5..30745c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -288,4 +288,30 @@ object DataType {
       case (fromDataType, toDataType) => fromDataType == toDataType
     }
   }
+
+  /**
+   * Returns true if the two data types share the same "shape", i.e. the types (including
+   * nullability) are the same, but the field names don't need to be the same.
+   */
+  def equalsStructurally(from: DataType, to: DataType): Boolean = {
+    (from, to) match {
+      case (left: ArrayType, right: ArrayType) =>
+        equalsStructurally(left.elementType, right.elementType) &&
+          left.containsNull == right.containsNull
+
+      case (left: MapType, right: MapType) =>
+        equalsStructurally(left.keyType, right.keyType) &&
+          equalsStructurally(left.valueType, right.valueType) &&
+          left.valueContainsNull == right.valueContainsNull
+
+      case (StructType(fromFields), StructType(toFields)) =>
+        fromFields.length == toFields.length &&
+          fromFields.zip(toFields)
+            .forall { case (l, r) =>
+              equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
+            }
+
+      case (fromDataType, toDataType) => fromDataType == toDataType
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 8eccadb..a7ffa88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
     assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
   }
+
+  test("SPARK-20302 cast with same structure") {
+    val from = new StructType()
+      .add("a", IntegerType)
+      .add("b", new StructType().add("b1", LongType))
+
+    val to = new StructType()
+      .add("a1", IntegerType)
+      .add("b1", new StructType().add("b11", LongType))
+
+    val input = Row(10, Row(12L))
+
+    checkEvaluation(cast(Literal.create(input, from), to), input)
+  }
 }
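[Editor's note] Outside the suites, the new helper can be exercised directly. A minimal standalone sketch follows (the object name is invented; it assumes spark-catalyst on the classpath and that equalsStructurally is callable from user code, as the unmodified `def` in the diff suggests):

    import org.apache.spark.sql.types._

    object EqualsStructurallyDemo extends App {
      val a = new StructType()
        .add("a", IntegerType)
        .add("b", ArrayType(LongType, containsNull = true))
      val b = new StructType()
        .add("x", IntegerType)
        .add("y", ArrayType(LongType, containsNull = true))

      // Same shape, different field names: structurally equal.
      assert(DataType.equalsStructurally(a, b))
      // Different leaf types never match.
      assert(!DataType.equalsStructurally(IntegerType, LongType))
      // Maps recurse into key and value types and compare valueContainsNull.
      assert(DataType.equalsStructurally(
        MapType(StringType, a, valueContainsNull = true),
        MapType(StringType, b, valueContainsNull = true)))
    }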
http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index f078ef0..c4635c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite {
   checkCatalogString(ArrayType(createStruct(40)))
   checkCatalogString(MapType(IntegerType, StringType))
   checkCatalogString(MapType(IntegerType, createStruct(40)))
+
+  def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
+    val testName = s"equalsStructurally: (from: $from, to: $to)"
+    test(testName) {
+      assert(DataType.equalsStructurally(from, to) === expected)
+    }
+  }
+
+  checkEqualsStructurally(BooleanType, BooleanType, true)
+  checkEqualsStructurally(IntegerType, IntegerType, true)
+  checkEqualsStructurally(IntegerType, LongType, false)
+  checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true)
+  checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false)
+
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType),
+    new StructType().add("f2", IntegerType),
+    true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType),
+    new StructType().add("f2", IntegerType, false),
+    false)
+
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    false)
 }
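[Editor's note] The tests above deliberately cover nullability: equalsStructurally treats containsNull, valueContainsNull, and struct field nullability as part of the "shape", so a cast that differs only in nullability is not short-circuited and keeps the full cast path. A tiny self-contained sketch of that boundary case, mirroring the fifth test above (the object name is invented):

    import org.apache.spark.sql.types._

    object NullabilityShapeDemo extends App {
      // Shapes match, but containsNull differs, so the pair is not
      // structurally equal and such a cast is not short-circuited.
      assert(!DataType.equalsStructurally(
        ArrayType(IntegerType, containsNull = true),
        ArrayType(IntegerType, containsNull = false)))
    }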