Github user maropu commented on a diff in the pull request: https://github.com/apache/spark/pull/22468#discussion_r238894837 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala --- @@ -535,4 +535,98 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers with PlanTestB assert(unsafeRow.getSizeInBytes == 8 + 8 * 2 + roundedSize(field1.getSizeInBytes) + roundedSize(field2.getSizeInBytes)) } + + testBothCodegenAndInterpreted("SPARK-25374 converts back into safe representation") { + def convertBackToInternalRow(inputRow: InternalRow, fields: Array[DataType]): InternalRow = { + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow = unsafeProj(inputRow) + val safeProj = SafeProjection.create(fields) + safeProj(unsafeRow) + } + + // Simple tests + val inputRow = InternalRow.fromSeq(Seq( + false, 3.toByte, 15.toShort, -83, 129L, 1.0f, 8.0, UTF8String.fromString("test"), + Decimal(255), CalendarInterval.fromString("interval 1 day"), Array[Byte](1, 2) + )) + val fields1 = Array( + BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, StringType, DecimalType.defaultConcreteType, CalendarIntervalType, + BinaryType) + + assert(convertBackToInternalRow(inputRow, fields1) === inputRow) + + // Array tests + val arrayRow = InternalRow.fromSeq(Seq( + createArray(1, 2, 3), + createArray( + createArray(Seq("a", "b", "c").map(UTF8String.fromString): _*), + createArray(Seq("d").map(UTF8String.fromString): _*)) + )) + val fields2 = Array[DataType]( + ArrayType(IntegerType), + ArrayType(ArrayType(StringType))) + + assert(convertBackToInternalRow(arrayRow, fields2) === arrayRow) + + // Struct tests + val structRow = InternalRow.fromSeq(Seq( + InternalRow.fromSeq(Seq[Any](1, 4.0)), + InternalRow.fromSeq(Seq( + UTF8String.fromString("test"), + InternalRow.fromSeq(Seq( + 1, + createArray(Seq("2", "3").map(UTF8String.fromString): _*) + )) + )) + )) + val fields3 = Array[DataType]( + StructType( + StructField("c0", IntegerType) :: + StructField("c1", DoubleType) :: + Nil), + StructType( + StructField("c2", StringType) :: + StructField("c3", StructType( + StructField("c4", IntegerType) :: + StructField("c5", ArrayType(StringType)) :: + Nil)) :: + Nil)) + + assert(convertBackToInternalRow(structRow, fields3) === structRow) + + // Map tests + val mapRow = InternalRow.fromSeq(Seq( + createMap(Seq("k1", "k2").map(UTF8String.fromString): _*)(1, 2), + createMap( + createMap(3, 5)(Seq("v1", "v2").map(UTF8String.fromString): _*), + createMap(7, 9)(Seq("v3", "v4").map(UTF8String.fromString): _*) + )( + createMap(Seq("k3", "k4").map(UTF8String.fromString): _*)(3.toShort, 4.toShort), + createMap(Seq("k5", "k6").map(UTF8String.fromString): _*)(5.toShort, 6.toShort) + ))) + val fields4 = Array[DataType]( + MapType(StringType, IntegerType), + MapType(MapType(IntegerType, StringType), MapType(StringType, ShortType))) + + val mapResultRow = convertBackToInternalRow(mapRow, fields4).toSeq(fields4) + val mapExpectedRow = mapRow.toSeq(fields4) + // Since `ArrayBasedMapData` does not override `equals` and `hashCode`, --- End diff -- Aha, thanks. I remember that its related to SPARK-18134.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org