http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala new file mode 100644 index 0000000..8f06de7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -0,0 +1,916 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.parquet.schema.MessageTypeParser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types._ + +abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { + val sqlContext = TestSQLContext + + /** + * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. + */ + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) + actual.checkContains(expected) + expected.checkContains(actual) + } + } + + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( + "basic types", + """ + |message root { + | required boolean _1; + | required int32 _2; + | required int64 _3; + | required float _4; + | required double _5; + | optional binary _6; + |} + """.stripMargin, + binaryAsString = false) + + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( + "logical integral types", + """ + |message root { + | required int32 _1 (INT_8); + | required int32 _2 (INT_16); + | required int32 _3 (INT_32); + | required int64 _4 (INT_64); + | optional int32 _5 (DATE); + |} + """.stripMargin) + + testSchemaInference[Tuple1[String]]( + "string", + """ + |message root { + | optional binary _1 (UTF8); + |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group bag { + | optional int32 array_element; + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Pair[Int, String]]]( + "struct", + """ + |message root { + | optional group _1 { + | required int32 _1; + | optional binary _2 (UTF8); + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", + """ + |message root { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 key; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group bag { + | optional group array_element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } + | } + | } + |} + """.stripMargin) + + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + | repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", + """ + |message root { + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. + ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} + +class ParquetSchemaSuite extends ParquetSchemaTest { + test("DataType string parser compatibility") { + // This is the generated string from previous versions of the Spark SQL, using the following: + // val schema = StructType(List( + // StructField("c1", IntegerType, false), + // StructField("c2", BinaryType, true))) + val caseClassString = + "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" + + // scalastyle:off + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" + // scalastyle:on + + val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) + val fromJson = ParquetTypesConverter.convertFromString(jsonString) + + (fromCaseClassString, fromJson).zipped.foreach { (a, b) => + assert(a.name == b.name) + assert(a.dataType === b.dataType) + assert(a.nullable === b.nullable) + } + } + + test("merge with metastore schema") { + // Field type conflict resolution + assertResult( + StructType(Seq( + StructField("lowerCase", StringType), + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("lowercase", StringType), + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // MetaStore schema is subset of parquet schema + assertResult( + StructType(Seq( + StructField("UPPERCase", DoubleType, nullable = false)))) { + + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false))), + + StructType(Seq( + StructField("lowerCase", BinaryType), + StructField("UPPERCase", IntegerType, nullable = true)))) + } + + // Metastore schema contains additional non-nullable fields. + assert(intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("uppercase", DoubleType, nullable = false), + StructField("lowerCase", BinaryType, nullable = false))), + + StructType(Seq( + StructField("UPPERCase", IntegerType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + + // Conflicting non-nullable field names + intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq(StructField("lower", StringType, nullable = false))), + StructType(Seq(StructField("lowerCase", BinaryType)))) + } + } + + test("merge missing nullable fields from Metastore schema") { + // Standard case: Metastore schema contains additional nullable fields not present + // in the Parquet file schema. + assertResult( + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true)))) { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = true))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + } + + // Merge should fail if the Metastore contains any additional fields that are not + // nullable. + assert(intercept[Throwable] { + ParquetRelation.mergeMetastoreParquetSchema( + StructType(Seq( + StructField("firstfield", StringType, nullable = true), + StructField("secondfield", StringType, nullable = true), + StructField("thirdfield", StringType, nullable = false))), + StructType(Seq( + StructField("firstField", StringType, nullable = true), + StructField("secondField", StringType, nullable = true)))) + }.getMessage.contains("detected conflicting schemas")) + } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 array_element; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 array; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 (DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) +}
http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala new file mode 100644 index 0000000..3c6e54d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.{DataFrame, SaveMode} + +/** + * A helper trait that provides convenient facilities for Parquet testing. + * + * NOTE: Considering classes `Tuple1` ... `Tuple22` all extend `Product`, it would be more + * convenient to use tuples rather than special case classes when writing test cases/suites. + * Especially, `Tuple1.apply` can be used to easily wrap a single type/value. + */ +private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite => + /** + * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withParquetFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = { + withTempPath { file => + sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + f(file.getCanonicalPath) + } + } + + /** + * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. + */ + protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: DataFrame => Unit): Unit = { + withParquetFile(data)(path => f(sqlContext.read.parquet(path))) + } + + /** + * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Parquet file will be dropped/deleted after `f` returns. + */ + protected def withParquetTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String) + (f: => Unit): Unit = { + withParquetDataFrame(data) { df => + sqlContext.registerDataFrameAsTable(df, tableName) + withTempTable(tableName)(f) + } + } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + data: Seq[T], path: File): Unit = { + sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + } + + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( + df: DataFrame, path: File): Unit = { + df.write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + } + + protected def makePartitionDir( + basePath: File, + defaultPartitionName: String, + partitionCols: (String, Any)*): File = { + val partNames = partitionCols.map { case (k, v) => + val valueString = if (v == null || v == "") defaultPartitionName else v.toString + s"$k=$valueString" + } + + val partDir = partNames.foldLeft(basePath) { (parent, child) => + new File(parent, child) + } + + assert(partDir.mkdirs(), s"Couldn't create directory $partDir") + partDir + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala new file mode 100644 index 0000000..92b1d82 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.{Row, SQLContext} + +class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest { + import ParquetCompatibilityTest._ + + override val sqlContext: SQLContext = TestSQLContext + + private val parquetFilePath = + Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet") + + test("Read Parquet file generated by parquet-thrift") { + logInfo( + s"""Schema of the Parquet file written by parquet-thrift: + |${readParquetSchema(parquetFilePath.toString)} + """.stripMargin) + + checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i) + + val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") + + Row( + i % 2 == 0, + i.toByte, + (i + 1).toShort, + i + 2, + i.toLong * 10, + i.toDouble + 0.2d, + // Thrift `BINARY` values are actually unencoded `STRING` values, and thus are always + // treated as `BINARY (UTF8)` in parquet-thrift, since parquet-thrift always assume + // Thrift `STRING`s are encoded using UTF-8. + s"val_$i", + s"val_$i", + // Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings + suits(i % 4), + + nullable(i % 2 == 0: java.lang.Boolean), + nullable(i.toByte: java.lang.Byte), + nullable((i + 1).toShort: java.lang.Short), + nullable(i + 2: Integer), + nullable((i * 10).toLong: java.lang.Long), + nullable(i.toDouble + 0.2d: java.lang.Double), + nullable(s"val_$i"), + nullable(s"val_$i"), + nullable(suits(i % 4)), + + Seq.tabulate(3)(n => s"arr_${i + n}"), + // Thrift `SET`s are converted to Parquet `LIST`s + Seq(i), + Seq.tabulate(3)(n => (i + n: Integer) -> s"val_${i + n}").toMap, + Seq.tabulate(3) { n => + (i + n) -> Seq.tabulate(3) { m => + Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}") + } + }.toMap) + }) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala new file mode 100644 index 0000000..953284c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -0,0 +1,135 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream} + +import scala.collection.mutable + +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ +import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils + + +class SQLMetricsSuite extends SparkFunSuite { + + test("LongSQLMetric should not box Long") { + val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long") + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.isEmpty, s"Found boxing: ${boxingFinder.boxingInvokes}") + } + } + + test("Normal accumulator should do boxing") { + // We need this test to make sure BoxingFinder works. + val l = TestSQLContext.sparkContext.accumulator(0L) + val f = () => { l += 1L } + BoxingFinder.getClassReader(f.getClass).foreach { cl => + val boxingFinder = new BoxingFinder() + cl.accept(boxingFinder, 0) + assert(boxingFinder.boxingInvokes.nonEmpty, "Found find boxing in this test") + } + } +} + +private case class MethodIdentifier[T](cls: Class[T], name: String, desc: String) + +/** + * If `method` is null, search all methods of this class recursively to find if they do some boxing. + * If `method` is specified, only search this method of the class to speed up the searching. + * + * This method will skip the methods in `visitedMethods` to avoid potential infinite cycles. + */ +private class BoxingFinder( + method: MethodIdentifier[_] = null, + val boxingInvokes: mutable.Set[String] = mutable.Set.empty, + visitedMethods: mutable.Set[MethodIdentifier[_]] = mutable.Set.empty) + extends ClassVisitor(ASM4) { + + private val primitiveBoxingClassName = + Set("java/lang/Long", + "java/lang/Double", + "java/lang/Integer", + "java/lang/Float", + "java/lang/Short", + "java/lang/Character", + "java/lang/Byte", + "java/lang/Boolean") + + override def visitMethod( + access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): + MethodVisitor = { + if (method != null && (method.name != name || method.desc != desc)) { + // If method is specified, skip other methods. + return new MethodVisitor(ASM4) {} + } + + new MethodVisitor(ASM4) { + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { + if (op == INVOKESPECIAL && name == "<init>" || op == INVOKESTATIC && name == "valueOf") { + if (primitiveBoxingClassName.contains(owner)) { + // Find boxing methods, e.g, new java.lang.Long(l) or java.lang.Long.valueOf(l) + boxingInvokes.add(s"$owner.$name") + } + } else { + // scalastyle:off classforname + val classOfMethodOwner = Class.forName(owner.replace('/', '.'), false, + Thread.currentThread.getContextClassLoader) + // scalastyle:on classforname + val m = MethodIdentifier(classOfMethodOwner, name, desc) + if (!visitedMethods.contains(m)) { + // Keep track of visited methods to avoid potential infinite cycles + visitedMethods += m + BoxingFinder.getClassReader(classOfMethodOwner).foreach { cl => + visitedMethods += m + cl.accept(new BoxingFinder(m, boxingInvokes, visitedMethods), 0) + } + } + } + } + } + } +} + +private object BoxingFinder { + + def getClassReader(cls: Class[_]): Option[ClassReader] = { + val className = cls.getName.replaceFirst("^.*\\.", "") + ".class" + val resourceStream = cls.getResourceAsStream(className) + val baos = new ByteArrayOutputStream(128) + // Copy data over, before delegating to ClassReader - + // else we can run out of open file handles. + Utils.copyStream(resourceStream, baos, true) + // ASM4 doesn't support Java 8 classes, which requires ASM5. + // So if the class is ASM5 (E.g., java.lang.Long when using JDK8 runtime to run these codes), + // then ClassReader will throw IllegalArgumentException, + // However, since this is only for testing, it's safe to skip these classes. + try { + Some(new ClassReader(new ByteArrayInputStream(baos.toByteArray))) + } catch { + case _: IllegalArgumentException => None + } + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/40ed2af5/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala new file mode 100644 index 0000000..41dd189 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import java.util.Properties + +import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.sql.execution.metric.LongSQLMetricValue +import org.apache.spark.scheduler._ +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.test.TestSQLContext + +class SQLListenerSuite extends SparkFunSuite { + + private def createTestDataFrame: DataFrame = { + import TestSQLContext.implicits._ + Seq( + (1, 1), + (2, 2) + ).toDF().filter("_1 > 1") + } + + private def createProperties(executionId: Long): Properties = { + val properties = new Properties() + properties.setProperty(SQLExecution.EXECUTION_ID_KEY, executionId.toString) + properties + } + + private def createStageInfo(stageId: Int, attemptId: Int): StageInfo = new StageInfo( + stageId = stageId, + attemptId = attemptId, + // The following fields are not used in tests + name = "", + numTasks = 0, + rddInfos = Nil, + parentIds = Nil, + details = "" + ) + + private def createTaskInfo(taskId: Int, attempt: Int): TaskInfo = new TaskInfo( + taskId = taskId, + attempt = attempt, + // The following fields are not used in tests + index = 0, + launchTime = 0, + executorId = "", + host = "", + taskLocality = null, + speculative = false + ) + + private def createTaskMetrics(accumulatorUpdates: Map[Long, Long]): TaskMetrics = { + val metrics = new TaskMetrics + metrics.setAccumulatorsUpdater(() => accumulatorUpdates.mapValues(new LongSQLMetricValue(_))) + metrics.updateAccumulators() + metrics + } + + test("basic") { + val listener = new SQLListener(TestSQLContext) + val executionId = 0 + val df = createTestDataFrame + val accumulatorIds = + SparkPlanGraph(df.queryExecution.executedPlan).nodes.flatMap(_.metrics.map(_.accumulatorId)) + // Assume all accumulators are long + var accumulatorValue = 0L + val accumulatorUpdates = accumulatorIds.map { id => + accumulatorValue += 1L + (id, accumulatorValue) + }.toMap + + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + + val executionUIData = listener.executionIdToData(0) + + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq( + createStageInfo(0, 0), + createStageInfo(1, 0) + ), + createProperties(executionId))) + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 0))) + + assert(listener.getExecutionMetrics(0).isEmpty) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 0, createTaskMetrics(accumulatorUpdates.mapValues(_ * 2))) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 3)) + + // Retrying a stage should reset the metrics + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(0, 1))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 0, 1, createTaskMetrics(accumulatorUpdates)), + (1L, 0, 1, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + // Ignore the task end for the first attempt + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 100)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 2)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 2)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 0, + stageAttemptId = 1, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 5)) + + // Summit a new stage + listener.onStageSubmitted(SparkListenerStageSubmitted(createStageInfo(1, 0))) + + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate("", Seq( + // (task id, stage id, stage attempt, metrics) + (0L, 1, 0, createTaskMetrics(accumulatorUpdates)), + (1L, 1, 0, createTaskMetrics(accumulatorUpdates)) + ))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 7)) + + // Finish two tasks + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(0, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + listener.onTaskEnd(SparkListenerTaskEnd( + stageId = 1, + stageAttemptId = 0, + taskType = "", + reason = null, + createTaskInfo(1, 0), + createTaskMetrics(accumulatorUpdates.mapValues(_ * 3)))) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + + assert(executionUIData.runningJobs === Seq(0)) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs.isEmpty) + + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + + assert(listener.getExecutionMetrics(0) === accumulatorUpdates.mapValues(_ * 11)) + } + + test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { + val listener = new SQLListener(TestSQLContext) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs === Seq(0)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { + val listener = new SQLListener(TestSQLContext) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobSucceeded + )) + + listener.onJobStart(SparkListenerJobStart( + jobId = 1, + time = System.currentTimeMillis(), + stageInfos = Nil, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 1, + time = System.currentTimeMillis(), + JobSucceeded + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.sorted === Seq(0, 1)) + assert(executionUIData.failedJobs.isEmpty) + } + + test("onExecutionEnd happens before onJobEnd(JobFailed)") { + val listener = new SQLListener(TestSQLContext) + val executionId = 0 + val df = createTestDataFrame + listener.onExecutionStart( + executionId, + "test", + "test", + df.queryExecution.toString, + SparkPlanGraph(df.queryExecution.executedPlan), + System.currentTimeMillis()) + listener.onJobStart(SparkListenerJobStart( + jobId = 0, + time = System.currentTimeMillis(), + stageInfos = Seq.empty, + createProperties(executionId))) + listener.onExecutionEnd(executionId, System.currentTimeMillis()) + listener.onJobEnd(SparkListenerJobEnd( + jobId = 0, + time = System.currentTimeMillis(), + JobFailed(new RuntimeException("Oops")) + )) + + val executionUIData = listener.executionIdToData(0) + assert(executionUIData.runningJobs.isEmpty) + assert(executionUIData.succeededJobs.isEmpty) + assert(executionUIData.failedJobs === Seq(0)) + } + + ignore("no memory leak") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new SparkContext(conf) + try { + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. + for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } + } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() + } + } + +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org