http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala new file mode 100644 index 0000000..4086a13 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -0,0 +1,796 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.net.URI +import java.util.logging.{Level, Logger => JLogger} +import java.util.{List => JList} + +import scala.collection.JavaConversions._ +import scala.collection.mutable +import scala.util.{Failure, Try} + +import com.google.common.base.Objects +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.parquet.filter2.predicate.FilterApi +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetRecordReader, _} +import org.apache.parquet.schema.MessageType +import org.apache.parquet.{Log => ParquetLog} + +import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} +import org.apache.spark.rdd.RDD._ +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.sources._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.{SerializableConfiguration, Utils} + + +private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { + + override def shortName(): String = "parquet" + + override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + schema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = { + new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + } +} + +// NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
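Since DefaultSource above mixes in DataSourceRegister, the relation can be resolved by the short name "parquet" alone; writes go through createRelation() and, per task, the executor-side ParquetOutputWriter defined next. A minimal usage sketch (assumes a SQLContext named sqlContext; the paths are hypothetical):

    val df = sqlContext.read.format("parquet").load("/data/events")
    df.write.format("parquet").mode("append").save("/data/events_copy")
    // "parquet" resolves to this DefaultSource through DataSourceRegister.shortName().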
+private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) + extends OutputWriter { + + private val recordWriter: RecordWriter[Void, InternalRow] = { + val outputFormat = { + new ParquetOutputFormat[InternalRow]() { + // Here we override `getDefaultWorkFile` for two reasons: + // + // 1. To allow appending. We need to generate unique output file names to avoid + // overwriting existing files (either exist before the write job, or are just written + // by other tasks within the same write job). + // + // 2. To allow dynamic partitioning. Default `getDefaultWorkFile` uses + // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all + // partitions in the case of dynamic partitioning. + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID") + val split = context.getTaskAttemptID.getTaskID.getId + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + } + } + } + + outputFormat.getRecordWriter(context) + } + + override def write(row: Row): Unit = throw new UnsupportedOperationException("call writeInternal") + + override protected[sql] def writeInternal(row: InternalRow): Unit = recordWriter.write(null, row) + + override def close(): Unit = recordWriter.close(context) +} + +private[sql] class ParquetRelation( + override val paths: Array[String], + private val maybeDataSchema: Option[StructType], + // This is for metastore conversion. + private val maybePartitionSpec: Option[PartitionSpec], + override val userDefinedPartitionColumns: Option[StructType], + parameters: Map[String, String])( + val sqlContext: SQLContext) + extends HadoopFsRelation(maybePartitionSpec) + with Logging { + + private[sql] def this( + paths: Array[String], + maybeDataSchema: Option[StructType], + maybePartitionSpec: Option[PartitionSpec], + parameters: Map[String, String])( + sqlContext: SQLContext) = { + this( + paths, + maybeDataSchema, + maybePartitionSpec, + maybePartitionSpec.map(_.partitionColumns), + parameters)(sqlContext) + } + + // Should we merge schemas from all Parquet part-files? 
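The question in the comment above is controlled by the option read just below: schema merging is off unless requested per read or enabled globally. A sketch of how a caller turns it on (the "mergeSchema" key is the MERGE_SCHEMA option defined later in this file; the global key is the one behind SQLConf.PARQUET_SCHEMA_MERGING_ENABLED):

    // Per read, as a data source option:
    val merged = sqlContext.read.option("mergeSchema", "true").parquet("/data/events")
    // Or globally:
    sqlContext.setConf("spark.sql.parquet.mergeSchema", "true")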
+ private val shouldMergeSchemas = + parameters + .get(ParquetRelation.MERGE_SCHEMA) + .map(_.toBoolean) + .getOrElse(sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED)) + + private val mergeRespectSummaries = + sqlContext.conf.getConf(SQLConf.PARQUET_SCHEMA_RESPECT_SUMMARIES) + + private val maybeMetastoreSchema = parameters + .get(ParquetRelation.METASTORE_SCHEMA) + .map(DataType.fromJson(_).asInstanceOf[StructType]) + + private lazy val metadataCache: MetadataCache = { + val meta = new MetadataCache + meta.refresh() + meta + } + + override def equals(other: Any): Boolean = other match { + case that: ParquetRelation => + val schemaEquality = if (shouldMergeSchemas) { + this.shouldMergeSchemas == that.shouldMergeSchemas + } else { + this.dataSchema == that.dataSchema && + this.schema == that.schema + } + + this.paths.toSet == that.paths.toSet && + schemaEquality && + this.maybeDataSchema == that.maybeDataSchema && + this.partitionColumns == that.partitionColumns + + case _ => false + } + + override def hashCode(): Int = { + if (shouldMergeSchemas) { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + maybeDataSchema, + partitionColumns) + } else { + Objects.hashCode( + Boolean.box(shouldMergeSchemas), + paths.toSet, + dataSchema, + schema, + maybeDataSchema, + partitionColumns) + } + } + + /** Constraints on schema of dataframe to be stored. */ + private def checkConstraints(schema: StructType): Unit = { + if (schema.fieldNames.length != schema.fieldNames.distinct.length) { + val duplicateColumns = schema.fieldNames.groupBy(identity).collect { + case (x, ys) if ys.length > 1 => "\"" + x + "\"" + }.mkString(", ") + throw new AnalysisException(s"Duplicate column(s) : $duplicateColumns found, " + + s"cannot save to parquet format") + } + } + + override def dataSchema: StructType = { + val schema = maybeDataSchema.getOrElse(metadataCache.dataSchema) + // check if schema satisfies the constraints + // before moving forward + checkConstraints(schema) + schema + } + + override private[sql] def refresh(): Unit = { + super.refresh() + metadataCache.refresh() + } + + // Parquet data source always uses Catalyst internal representations. + override val needConversion: Boolean = false + + override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum + + override def prepareJobForWrite(job: Job): OutputWriterFactory = { + val conf = ContextUtil.getConfiguration(job) + + val committerClass = + conf.getClass( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, + classOf[ParquetOutputCommitter], + classOf[ParquetOutputCommitter]) + + if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { + logInfo("Using default output committer for Parquet: " + + classOf[ParquetOutputCommitter].getCanonicalName) + } else { + logInfo("Using user defined output committer for Parquet: " + committerClass.getCanonicalName) + } + + conf.setClass( + SQLConf.OUTPUT_COMMITTER_CLASS.key, + committerClass, + classOf[ParquetOutputCommitter]) + + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + + // TODO There's no need to use two kinds of WriteSupport + // We should unify them. 
`SpecificMutableRow` can process both atomic (primitive) types and + // complex types. + val writeSupportClass = + if (dataSchema.map(_.dataType).forall(ParquetTypesConverter.isPrimitiveType)) { + classOf[MutableRowWriteSupport] + } else { + classOf[RowWriteSupport] + } + + ParquetOutputFormat.setWriteSupportClass(job, writeSupportClass) + RowWriteSupport.setSchema(dataSchema.toAttributes, conf) + + // Sets compression scheme + conf.set( + ParquetOutputFormat.COMPRESSION, + ParquetRelation + .shortParquetCompressionCodecNames + .getOrElse( + sqlContext.conf.parquetCompressionCodec.toUpperCase, + CompressionCodecName.UNCOMPRESSED).name()) + + new OutputWriterFactory { + override def newInstance( + path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, context) + } + } + } + + override def buildScan( + requiredColumns: Array[String], + filters: Array[Filter], + inputFiles: Array[FileStatus], + broadcastedConf: Broadcast[SerializableConfiguration]): RDD[Row] = { + val useMetadataCache = sqlContext.getConf(SQLConf.PARQUET_CACHE_METADATA) + val parquetFilterPushDown = sqlContext.conf.parquetFilterPushDown + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + + // Create the function to set variable Parquet confs at both driver and executor side. + val initLocalJobFuncOpt = + ParquetRelation.initializeLocalJobFunc( + requiredColumns, + filters, + dataSchema, + useMetadataCache, + parquetFilterPushDown, + assumeBinaryIsString, + assumeInt96IsTimestamp, + followParquetFormatSpec) _ + + // Create the function to set input paths at the driver side. + val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles) _ + + Utils.withDummyCallSite(sqlContext.sparkContext) { + new SqlNewHadoopRDD( + sc = sqlContext.sparkContext, + broadcastedConf = broadcastedConf, + initDriverSideJobFuncOpt = Some(setInputPaths), + initLocalJobFuncOpt = Some(initLocalJobFuncOpt), + inputFormatClass = classOf[ParquetInputFormat[InternalRow]], + valueClass = classOf[InternalRow]) { + + val cacheMetadata = useMetadataCache + + @transient val cachedStatuses = inputFiles.map { f => + // In order to encode the authority of a Path containing special characters such as '/' + // (which does happen in some S3N credentials), we need to use the string returned by the + // URI of the path to create a new Path. + val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) + new FileStatus( + f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) + }.toSeq + + private def escapePathUserInfo(path: Path): Path = { + val uri = path.toUri + new Path(new URI( + uri.getScheme, uri.getRawUserInfo, uri.getHost, uri.getPort, uri.getPath, + uri.getQuery, uri.getFragment)) + } + + // Overridden so we can inject our own cached files statuses. 
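For reference, the compression setup in prepareJobForWrite and the flags captured at the top of buildScan above all map onto ordinary SQL configuration keys; a hedged sketch (assumes a SQLContext named sqlContext):

    sqlContext.setConf("spark.sql.parquet.compression.codec", "snappy") // -> CompressionCodecName.SNAPPY; unknown names fall back to UNCOMPRESSED
    sqlContext.setConf("spark.sql.parquet.filterPushdown", "true")      // parquetFilterPushDown
    sqlContext.setConf("spark.sql.parquet.binaryAsString", "false")     // assumeBinaryIsString
    sqlContext.setConf("spark.sql.parquet.int96AsTimestamp", "true")    // assumeInt96IsTimestamp
    sqlContext.setConf("spark.sql.parquet.cacheMetadata", "true")       // useMetadataCache

(The getPartitions override that follows is the place where the cached file statuses mentioned in the comment above are injected.)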
+ override def getPartitions: Array[SparkPartition] = { + val inputFormat = new ParquetInputFormat[InternalRow] { + override def listStatus(jobContext: JobContext): JList[FileStatus] = { + if (cacheMetadata) cachedStatuses else super.listStatus(jobContext) + } + } + + val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val rawSplits = inputFormat.getSplits(jobContext) + + Array.tabulate[SparkPartition](rawSplits.size) { i => + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + } + }.asInstanceOf[RDD[Row]] // type erasure hack to pass RDD[InternalRow] as RDD[Row] + } + } + + private class MetadataCache { + // `FileStatus` objects of all "_metadata" files. + private var metadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all "_common_metadata" files. + private var commonMetadataStatuses: Array[FileStatus] = _ + + // `FileStatus` objects of all data files (Parquet part-files). + var dataStatuses: Array[FileStatus] = _ + + // Schema of the actual Parquet files, without partition columns discovered from partition + // directory paths. + var dataSchema: StructType = null + + // Schema of the whole table, including partition columns. + var schema: StructType = _ + + // Cached leaves + var cachedLeaves: Set[FileStatus] = null + + /** + * Refreshes `FileStatus`es, footers, partition spec, and table schema. + */ + def refresh(): Unit = { + val currentLeafStatuses = cachedLeafStatuses() + + // Check if cachedLeafStatuses is changed or not + val leafStatusesChanged = (cachedLeaves == null) || + !cachedLeaves.equals(currentLeafStatuses) + + if (leafStatusesChanged) { + cachedLeaves = currentLeafStatuses.toIterator.toSet + + // Lists `FileStatus`es of all leaf nodes (files) under all base directories. + val leaves = currentLeafStatuses.filter { f => + isSummaryFile(f.getPath) || + !(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith(".")) + }.toArray + + dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath)) + metadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE) + commonMetadataStatuses = + leaves.filter(_.getPath.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE) + + dataSchema = { + val dataSchema0 = maybeDataSchema + .orElse(readSchema()) + .orElse(maybeMetastoreSchema) + .getOrElse(throw new AnalysisException( + s"Failed to discover schema of Parquet file(s) in the following location(s):\n" + + paths.mkString("\n\t"))) + + // If this Parquet relation is converted from a Hive Metastore table, must reconcile case + // case insensitivity issue and possible schema mismatch (probably caused by schema + // evolution). + maybeMetastoreSchema + .map(ParquetRelation.mergeMetastoreParquetSchema(_, dataSchema0)) + .getOrElse(dataSchema0) + } + } + } + + private def isSummaryFile(file: Path): Boolean = { + file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || + file.getName == ParquetFileWriter.PARQUET_METADATA_FILE + } + + private def readSchema(): Option[StructType] = { + // Sees which file(s) we need to touch in order to figure out the schema. + // + // Always tries the summary files first if users don't require a merged schema. In this case, + // "_common_metadata" is more preferable than "_metadata" because it doesn't contain row + // groups information, and could be much smaller for large Parquet files with lots of row + // groups. If no summary file is available, falls back to some random part-file. 
+ // + // NOTE: Metadata stored in the summary files are merged from all part-files. However, for + // user defined key-value metadata (in which we store Spark SQL schema), Parquet doesn't know + // how to merge them correctly if some key is associated with different values in different + // part-files. When this happens, Parquet simply gives up generating the summary file. This + // implies that if a summary file presents, then: + // + // 1. Either all part-files have exactly the same Spark SQL schema, or + // 2. Some part-files don't contain Spark SQL schema in the key-value metadata at all (thus + // their schemas may differ from each other). + // + // Here we tend to be pessimistic and take the second case into account. Basically this means + // we can't trust the summary files if users require a merged schema, and must touch all part- + // files to do the merge. + val filesToTouch = + if (shouldMergeSchemas) { + // Also includes summary files, 'cause there might be empty partition directories. + + // If mergeRespectSummaries config is true, we assume that all part-files are the same for + // their schema with summary files, so we ignore them when merging schema. + // If the config is disabled, which is the default setting, we merge all part-files. + // In this mode, we only need to merge schemas contained in all those summary files. + // You should enable this configuration only if you are very sure that for the parquet + // part-files to read there are corresponding summary files containing correct schema. + + val needMerged: Seq[FileStatus] = + if (mergeRespectSummaries) { + Seq() + } else { + dataStatuses + } + (metadataStatuses ++ commonMetadataStatuses ++ needMerged).toSeq + } else { + // Tries any "_common_metadata" first. Parquet files written by old versions or Parquet + // don't have this. + commonMetadataStatuses.headOption + // Falls back to "_metadata" + .orElse(metadataStatuses.headOption) + // Summary file(s) not found, the Parquet file is either corrupted, or different part- + // files contain conflicting user defined metadata (two or more values are associated + // with a same key in different files). In either case, we fall back to any of the + // first part-file, and just assume all schemas are consistent. + .orElse(dataStatuses.headOption) + .toSeq + } + + assert( + filesToTouch.nonEmpty || maybeDataSchema.isDefined || maybeMetastoreSchema.isDefined, + "No predefined schema found, " + + s"and no Parquet data files or summary files found under ${paths.mkString(", ")}.") + + ParquetRelation.mergeSchemasInParallel(filesToTouch, sqlContext) + } + } +} + +private[sql] object ParquetRelation extends Logging { + // Whether we should merge schemas collected from all Parquet part-files. + private[sql] val MERGE_SCHEMA = "mergeSchema" + + // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used + // internally. + private[sql] val METASTORE_SCHEMA = "metastoreSchema" + + /** This closure sets various Parquet configurations at both driver side and executor side. 
*/ + private[parquet] def initializeLocalJobFunc( + requiredColumns: Array[String], + filters: Array[Filter], + dataSchema: StructType, + useMetadataCache: Boolean, + parquetFilterPushDown: Boolean, + assumeBinaryIsString: Boolean, + assumeInt96IsTimestamp: Boolean, + followParquetFormatSpec: Boolean)(job: Job): Unit = { + val conf = job.getConfiguration + conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) + + // Try to push down filters when filter push-down is enabled. + if (parquetFilterPushDown) { + filters + // Collects all converted Parquet filter predicates. Notice that not all predicates can be + // converted (`ParquetFilters.createFilter` returns an `Option`). That's why a `flatMap` + // is used here. + .flatMap(ParquetFilters.createFilter(dataSchema, _)) + .reduceOption(FilterApi.and) + .foreach(ParquetInputFormat.setFilterPredicate(conf, _)) + } + + conf.set(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA, { + val requestedSchema = StructType(requiredColumns.map(dataSchema(_))) + CatalystSchemaConverter.checkFieldNames(requestedSchema).json + }) + + conf.set( + RowWriteSupport.SPARK_ROW_SCHEMA, + CatalystSchemaConverter.checkFieldNames(dataSchema).json) + + // Tell FilteringParquetRowInputFormat whether it's okay to cache Parquet and FS metadata + conf.setBoolean(SQLConf.PARQUET_CACHE_METADATA.key, useMetadataCache) + + // Sets flags for Parquet schema conversion + conf.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING.key, assumeBinaryIsString) + conf.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, assumeInt96IsTimestamp) + conf.setBoolean(SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, followParquetFormatSpec) + } + + /** This closure sets input paths at the driver side. */ + private[parquet] def initializeDriverSideJobFunc( + inputFiles: Array[FileStatus])(job: Job): Unit = { + // We side the input paths at the driver side. + logInfo(s"Reading Parquet file(s) from ${inputFiles.map(_.getPath).mkString(", ")}") + if (inputFiles.nonEmpty) { + FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) + } + } + + private[parquet] def readSchema( + footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = { + + def parseParquetSchema(schema: MessageType): StructType = { + val converter = new CatalystSchemaConverter( + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.isParquetBinaryAsString, + sqlContext.conf.followParquetFormatSpec) + + converter.convert(schema) + } + + val seen = mutable.HashSet[String]() + val finalSchemas: Seq[StructType] = footers.flatMap { footer => + val metadata = footer.getParquetMetadata.getFileMetaData + val serializedSchema = metadata + .getKeyValueMetaData + .toMap + .get(CatalystReadSupport.SPARK_METADATA_KEY) + if (serializedSchema.isEmpty) { + // Falls back to Parquet schema if no Spark SQL schema found. + Some(parseParquetSchema(metadata.getSchema)) + } else if (!seen.contains(serializedSchema.get)) { + seen += serializedSchema.get + + // Don't throw even if we failed to parse the serialized Spark schema. Just fallback to + // whatever is available. 
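To make that fallback concrete: the value stored under CatalystReadSupport.SPARK_METADATA_KEY is a JSON-serialized StructType, and anything that fails to parse as JSON is retried with the deprecated case-class-string parser before giving up and converting the Parquet schema itself. A sketch of the happy path (illustrative column name):

    import org.apache.spark.sql.types.{DataType, StructType}
    val json = """{"type":"struct","fields":[{"name":"a","type":"integer","nullable":true,"metadata":{}}]}"""
    val schema = DataType.fromJson(json).asInstanceOf[StructType]   // struct<a:int>

The Try/recover chain implementing the fallback follows.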
+ Some(Try(DataType.fromJson(serializedSchema.get)) + .recover { case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(serializedSchema.get) + } + .recover { case cause: Throwable => + logWarning( + s"""Failed to parse serialized Spark schema in Parquet key-value metadata: + |\t$serializedSchema + """.stripMargin, + cause) + } + .map(_.asInstanceOf[StructType]) + .getOrElse { + // Falls back to Parquet schema if Spark SQL schema can't be parsed. + parseParquetSchema(metadata.getSchema) + }) + } else { + None + } + } + + finalSchemas.reduceOption { (left, right) => + try left.merge(right) catch { case e: Throwable => + throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e) + } + } + } + + /** + * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore + * schema and Parquet schema. + * + * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the + * schema read from Parquet files may be incomplete (e.g. older versions of Parquet doesn't + * distinguish binary and string). This method generates a correct schema by merging Metastore + * schema data types and Parquet schema field names. + */ + private[parquet] def mergeMetastoreParquetSchema( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + def schemaConflictMessage: String = + s"""Converting Hive Metastore Parquet, but detected conflicting schemas. Metastore schema: + |${metastoreSchema.prettyJson} + | + |Parquet schema: + |${parquetSchema.prettyJson} + """.stripMargin + + val mergedParquetSchema = mergeMissingNullableFields(metastoreSchema, parquetSchema) + + assert(metastoreSchema.size <= mergedParquetSchema.size, schemaConflictMessage) + + val ordinalMap = metastoreSchema.zipWithIndex.map { + case (field, index) => field.name.toLowerCase -> index + }.toMap + + val reorderedParquetSchema = mergedParquetSchema.sortBy(f => + ordinalMap.getOrElse(f.name.toLowerCase, metastoreSchema.size + 1)) + + StructType(metastoreSchema.zip(reorderedParquetSchema).map { + // Uses Parquet field names but retains Metastore data types. + case (mSchema, pSchema) if mSchema.name.toLowerCase == pSchema.name.toLowerCase => + mSchema.copy(name = pSchema.name) + case _ => + throw new SparkException(schemaConflictMessage) + }) + } + + /** + * Returns the original schema from the Parquet file with any missing nullable fields from the + * Hive Metastore schema merged in. + * + * When constructing a DataFrame from a collection of structured data, the resulting object has + * a schema corresponding to the union of the fields present in each element of the collection. + * Spark SQL simply assigns a null value to any field that isn't present for a particular row. + * In some cases, it is possible that a given table partition stored as a Parquet file doesn't + * contain a particular nullable field in its schema despite that field being present in the + * table schema obtained from the Hive Metastore. This method returns a schema representing the + * Parquet file schema along with any additional nullable fields from the Metastore schema + * merged in. 
+ */ + private[parquet] def mergeMissingNullableFields( + metastoreSchema: StructType, + parquetSchema: StructType): StructType = { + val fieldMap = metastoreSchema.map(f => f.name.toLowerCase -> f).toMap + val missingFields = metastoreSchema + .map(_.name.toLowerCase) + .diff(parquetSchema.map(_.name.toLowerCase)) + .map(fieldMap(_)) + .filter(_.nullable) + StructType(parquetSchema ++ missingFields) + } + + /** + * Figures out a merged Parquet schema with a distributed Spark job. + * + * Note that locality is not taken into consideration here because: + * + * 1. For a single Parquet part-file, in most cases the footer only resides in the last block of + * that file. Thus we only need to retrieve the location of the last block. However, Hadoop + * `FileSystem` only provides API to retrieve locations of all blocks, which can be + * potentially expensive. + * + * 2. This optimization is mainly useful for S3, where file metadata operations can be pretty + * slow. And basically locality is not available when using S3 (you can't run computation on + * S3 nodes). + */ + def mergeSchemasInParallel( + filesToTouch: Seq[FileStatus], sqlContext: SQLContext): Option[StructType] = { + val assumeBinaryIsString = sqlContext.conf.isParquetBinaryAsString + val assumeInt96IsTimestamp = sqlContext.conf.isParquetINT96AsTimestamp + val followParquetFormatSpec = sqlContext.conf.followParquetFormatSpec + val serializedConf = new SerializableConfiguration(sqlContext.sparkContext.hadoopConfiguration) + + // HACK ALERT: + // + // Parquet requires `FileStatus`es to read footers. Here we try to send cached `FileStatus`es + // to executor side to avoid fetching them again. However, `FileStatus` is not `Serializable` + // but only `Writable`. What makes it worth, for some reason, `FileStatus` doesn't play well + // with `SerializableWritable[T]` and always causes a weird `IllegalStateException`. These + // facts virtually prevents us to serialize `FileStatus`es. + // + // Since Parquet only relies on path and length information of those `FileStatus`es to read + // footers, here we just extract them (which can be easily serialized), send them to executor + // side, and resemble fake `FileStatus`es there. + val partialFileStatusInfo = filesToTouch.map(f => (f.getPath.toString, f.getLen)) + + // Issues a Spark job to read Parquet schema in parallel. + val partiallyMergedSchemas = + sqlContext + .sparkContext + .parallelize(partialFileStatusInfo) + .mapPartitions { iterator => + // Resembles fake `FileStatus`es with serialized path and length information. + val fakeFileStatuses = iterator.map { case (path, length) => + new FileStatus(length, false, 0, 0, 0, 0, null, null, null, new Path(path)) + }.toSeq + + // Skips row group information since we only need the schema + val skipRowGroups = true + + // Reads footers in multi-threaded manner within each task + val footers = + ParquetFileReader.readAllFootersInParallel( + serializedConf.value, fakeFileStatuses, skipRowGroups) + + // Converter used to convert Parquet `MessageType` to Spark SQL `StructType` + val converter = + new CatalystSchemaConverter( + assumeBinaryIsString = assumeBinaryIsString, + assumeInt96IsTimestamp = assumeInt96IsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + footers.map { footer => + ParquetRelation.readSchemaFromFooter(footer, converter) + }.reduceOption(_ merge _).iterator + }.collect() + + partiallyMergedSchemas.reduceOption(_ merge _) + } + + /** + * Reads Spark SQL schema from a Parquet footer. 
If a valid serialized Spark SQL schema string + * can be found in the file metadata, returns the deserialized [[StructType]], otherwise, returns + * a [[StructType]] converted from the [[MessageType]] stored in this footer. + */ + def readSchemaFromFooter( + footer: Footer, converter: CatalystSchemaConverter): StructType = { + val fileMetaData = footer.getParquetMetadata.getFileMetaData + fileMetaData + .getKeyValueMetaData + .toMap + .get(CatalystReadSupport.SPARK_METADATA_KEY) + .flatMap(deserializeSchemaString) + .getOrElse(converter.convert(fileMetaData.getSchema)) + } + + private def deserializeSchemaString(schemaString: String): Option[StructType] = { + // Tries to deserialize the schema string as JSON first, then falls back to the case class + // string parser (data generated by older versions of Spark SQL uses this format). + Try(DataType.fromJson(schemaString).asInstanceOf[StructType]).recover { + case _: Throwable => + logInfo( + s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + + "falling back to the deprecated DataType.fromCaseClassString parser.") + DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + }.recoverWith { + case cause: Throwable => + logWarning( + "Failed to parse and ignored serialized Spark schema in " + + s"Parquet key-value metadata:\n\t$schemaString", cause) + Failure(cause) + }.toOption + } + + def enableLogForwarding() { + // Note: the org.apache.parquet.Log class has a static initializer that + // sets the java.util.logging Logger for "org.apache.parquet". This + // checks first to see if there's any handlers already set + // and if not it creates them. If this method executes prior + // to that class being loaded then: + // 1) there's no handlers installed so there's none to + // remove. But when it IS finally loaded the desired affect + // of removing them is circumvented. + // 2) The parquet.Log static initializer calls setUseParentHandlers(false) + // undoing the attempt to override the logging here. + // + // Therefore we need to force the class to be loaded. + // This should really be resolved by Parquet. + Utils.classForName(classOf[ParquetLog].getName) + + // Note: Logger.getLogger("parquet") has a default logger + // that appends to Console which needs to be cleared. + val parquetLogger = JLogger.getLogger(classOf[ParquetLog].getPackage.getName) + parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) + parquetLogger.setUseParentHandlers(true) + + // Disables a WARN log message in ParquetOutputCommitter. We first ensure that + // ParquetOutputCommitter is loaded and the static LOG field gets initialized. + // See https://issues.apache.org/jira/browse/SPARK-5968 for details + Utils.classForName(classOf[ParquetOutputCommitter].getName) + JLogger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) + + // Similar as above, disables a unnecessary WARN log message in ParquetRecordReader. + // See https://issues.apache.org/jira/browse/PARQUET-220 for details + Utils.classForName(classOf[ParquetRecordReader[_]].getName) + JLogger.getLogger(classOf[ParquetRecordReader[_]].getName).setLevel(Level.OFF) + } + + // The parquet compression short names + val shortParquetCompressionCodecNames = Map( + "NONE" -> CompressionCodecName.UNCOMPRESSED, + "UNCOMPRESSED" -> CompressionCodecName.UNCOMPRESSED, + "SNAPPY" -> CompressionCodecName.SNAPPY, + "GZIP" -> CompressionCodecName.GZIP, + "LZO" -> CompressionCodecName.LZO) +}
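Before the next file, a worked example (illustrative names and types; both helpers are private[parquet], so this only illustrates their behaviour) of what mergeMissingNullableFields and mergeMetastoreParquetSchema above produce for a Hive Metastore conversion:

    import org.apache.spark.sql.types._
    val metastoreSchema = StructType(Seq(
      StructField("userid", StringType),   // Hive lower-cases field names
      StructField("flag", BooleanType)))   // nullable = true by default
    val parquetSchema = StructType(Seq(
      StructField("UserId", BinaryType)))  // older writers store strings as plain binary; "flag" absent here
    // mergeMissingNullableFields(metastoreSchema, parquetSchema)
    //   => struct<UserId:binary, flag:boolean>   (nullable Metastore-only field re-appended)
    // mergeMetastoreParquetSchema(metastoreSchema, parquetSchema)
    //   => struct<UserId:string, flag:boolean>   (Parquet field names, Metastore data types)
    // Field sets that cannot be reconciled fail with the "Converting Hive Metastore Parquet,
    // but detected conflicting schemas" message defined above.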
http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala new file mode 100644 index 0000000..3191cf3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTableSupport.scala @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.math.BigInteger +import java.nio.{ByteBuffer, ByteOrder} +import java.util.{HashMap => JHashMap} + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.column.ParquetProperties +import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.api.WriteSupport +import org.apache.parquet.io.api._ + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A `parquet.hadoop.api.WriteSupport` for Row objects. 
+ */ +private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Logging { + + private[parquet] var writer: RecordConsumer = null + private[parquet] var attributes: Array[Attribute] = null + + override def init(configuration: Configuration): WriteSupport.WriteContext = { + val origAttributesStr: String = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + val metadata = new JHashMap[String, String]() + metadata.put(CatalystReadSupport.SPARK_METADATA_KEY, origAttributesStr) + + if (attributes == null) { + attributes = ParquetTypesConverter.convertFromString(origAttributesStr).toArray + } + + log.debug(s"write support initialized for requested schema $attributes") + ParquetRelation.enableLogForwarding() + new WriteSupport.WriteContext(ParquetTypesConverter.convertFromAttributes(attributes), metadata) + } + + override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { + writer = recordConsumer + log.debug(s"preparing for write with schema $attributes") + } + + override def write(record: InternalRow): Unit = { + val attributesSize = attributes.size + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") + } + + var index = 0 + writer.startMessage() + while(index < attributesSize) { + // null values indicate optional fields but we do not check currently + if (!record.isNullAt(index)) { + writer.startField(attributes(index).name, index) + writeValue(attributes(index).dataType, record.get(index, attributes(index).dataType)) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } + + private[parquet] def writeValue(schema: DataType, value: Any): Unit = { + if (value != null) { + schema match { + case t: UserDefinedType[_] => writeValue(t.sqlType, value) + case t @ ArrayType(_, _) => writeArray( + t, + value.asInstanceOf[CatalystConverter.ArrayScalaType]) + case t @ MapType(_, _, _) => writeMap( + t, + value.asInstanceOf[CatalystConverter.MapScalaType]) + case t @ StructType(_) => writeStruct( + t, + value.asInstanceOf[CatalystConverter.StructScalaType]) + case _ => writePrimitive(schema.asInstanceOf[AtomicType], value) + } + } + } + + private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { + if (value != null) { + schema match { + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case ByteType => writer.addInteger(value.asInstanceOf[Byte]) + case ShortType => writer.addInteger(value.asInstanceOf[Short]) + case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) + case LongType => writer.addLong(value.asInstanceOf[Long]) + case TimestampType => writeTimestamp(value.asInstanceOf[Long]) + case FloatType => writer.addFloat(value.asInstanceOf[Float]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case StringType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) + case BinaryType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) + case DecimalType.Fixed(precision, _) => + writeDecimal(value.asInstanceOf[Decimal], precision) + case _ => sys.error(s"Do not know how to writer $schema to consumer") + } + } + } + + private[parquet] def writeStruct( + schema: StructType, + struct: CatalystConverter.StructScalaType): Unit = { + if (struct != null) { + val fields = schema.fields.toArray + writer.startGroup() + var i = 0 + while(i < fields.length) { + if 
(!struct.isNullAt(i)) { + writer.startField(fields(i).name, i) + writeValue(fields(i).dataType, struct.get(i, fields(i).dataType)) + writer.endField(fields(i).name, i) + } + i = i + 1 + } + writer.endGroup() + } + } + + private[parquet] def writeArray( + schema: ArrayType, + array: CatalystConverter.ArrayScalaType): Unit = { + val elementType = schema.elementType + writer.startGroup() + if (array.numElements() > 0) { + if (schema.containsNull) { + writer.startField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) + var i = 0 + while (i < array.numElements()) { + writer.startGroup() + if (!array.isNullAt(i)) { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + writeValue(elementType, array.get(i, elementType)) + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + } + writer.endGroup() + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, 0) + } else { + writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + var i = 0 + while (i < array.numElements()) { + writeValue(elementType, array.get(i, elementType)) + i = i + 1 + } + writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0) + } + } + writer.endGroup() + } + + private[parquet] def writeMap( + schema: MapType, + map: CatalystConverter.MapScalaType): Unit = { + writer.startGroup() + val length = map.numElements() + if (length > 0) { + writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0) + map.foreach(schema.keyType, schema.valueType, (key, value) => { + writer.startGroup() + writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + writeValue(schema.keyType, key) + writer.endField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0) + if (value != null) { + writer.startField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + writeValue(schema.valueType, value) + writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1) + } + writer.endGroup() + }) + writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0) + } + writer.endGroup() + } + + // Scratch array used to write decimals as fixed-length byte array + private[this] var reusableDecimalBytes = new Array[Byte](16) + + private[parquet] def writeDecimal(decimal: Decimal, precision: Int): Unit = { + val numBytes = CatalystSchemaConverter.minBytesForPrecision(precision) + + def longToBinary(unscaled: Long): Binary = { + var i = 0 + var shift = 8 * (numBytes - 1) + while (i < numBytes) { + reusableDecimalBytes(i) = (unscaled >> shift).toByte + i += 1 + shift -= 8 + } + Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) + } + + def bigIntegerToBinary(unscaled: BigInteger): Binary = { + unscaled.toByteArray match { + case bytes if bytes.length == numBytes => + Binary.fromByteArray(bytes) + + case bytes if bytes.length <= reusableDecimalBytes.length => + val signedByte = (if (bytes.head < 0) -1 else 0).toByte + java.util.Arrays.fill(reusableDecimalBytes, 0, numBytes - bytes.length, signedByte) + System.arraycopy(bytes, 0, reusableDecimalBytes, numBytes - bytes.length, bytes.length) + Binary.fromByteArray(reusableDecimalBytes, 0, numBytes) + + case bytes => + reusableDecimalBytes = new Array[Byte](bytes.length) + bigIntegerToBinary(unscaled) + } + } + + val binary = if (numBytes <= 8) { + longToBinary(decimal.toUnscaledLong) + } else { + bigIntegerToBinary(decimal.toJavaBigDecimal.unscaledValue()) + } + + writer.addBinary(binary) + } + + // array used to write Timestamp as Int96 (fixed-length binary) + private[this] val int96buf = new Array[Byte](12) + + private[parquet] def writeTimestamp(ts: Long): 
Unit = { + val (julianDay, timeOfDayNanos) = DateTimeUtils.toJulianDay(ts) + val buf = ByteBuffer.wrap(int96buf) + buf.order(ByteOrder.LITTLE_ENDIAN) + buf.putLong(timeOfDayNanos) + buf.putInt(julianDay) + writer.addBinary(Binary.fromByteArray(int96buf)) + } +} + +// Optimized for non-nested rows +private[parquet] class MutableRowWriteSupport extends RowWriteSupport { + override def write(record: InternalRow): Unit = { + val attributesSize = attributes.size + if (attributesSize > record.numFields) { + throw new IndexOutOfBoundsException("Trying to write more fields than contained in row " + + s"($attributesSize > ${record.numFields})") + } + + var index = 0 + writer.startMessage() + while(index < attributesSize) { + // null values indicate optional fields but we do not check currently + if (!record.isNullAt(index) && !record.isNullAt(index)) { + writer.startField(attributes(index).name, index) + consumeType(attributes(index).dataType, record, index) + writer.endField(attributes(index).name, index) + } + index = index + 1 + } + writer.endMessage() + } + + private def consumeType( + ctype: DataType, + record: InternalRow, + index: Int): Unit = { + ctype match { + case BooleanType => writer.addBoolean(record.getBoolean(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case IntegerType | DateType => writer.addInteger(record.getInt(index)) + case LongType => writer.addLong(record.getLong(index)) + case TimestampType => writeTimestamp(record.getLong(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) + case StringType => + writer.addBinary(Binary.fromByteArray(record.getUTF8String(index).getBytes)) + case BinaryType => + writer.addBinary(Binary.fromByteArray(record.getBinary(index))) + case DecimalType.Fixed(precision, scale) => + writeDecimal(record.getDecimal(index, precision, scale), precision) + case _ => sys.error(s"Unsupported datatype $ctype, cannot write to consumer") + } + } +} + +private[parquet] object RowWriteSupport { + val SPARK_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.attributes" + + def getSchema(configuration: Configuration): Seq[Attribute] = { + val schemaString = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA) + if (schemaString == null) { + throw new RuntimeException("Missing schema!") + } + ParquetTypesConverter.convertFromString(schemaString) + } + + def setSchema(schema: Seq[Attribute], configuration: Configuration) { + val encoded = ParquetTypesConverter.convertToString(schema) + configuration.set(SPARK_ROW_SCHEMA, encoded) + configuration.set( + ParquetOutputFormat.WRITER_VERSION, + ParquetProperties.WriterVersion.PARQUET_1_0.toString) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala new file mode 100644 index 0000000..019db34 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTypesConverter.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.IOException + +import scala.collection.JavaConversions._ +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.mapreduce.Job +import org.apache.parquet.format.converter.ParquetMetadataConverter +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} +import org.apache.parquet.hadoop.util.ContextUtil +import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.schema.MessageType + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types._ + + +private[parquet] object ParquetTypesConverter extends Logging { + def isPrimitiveType(ctype: DataType): Boolean = ctype match { + case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true + case _ => false + } + + /** + * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. 
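As a quick illustration of the rule stated above (and of the table defined just below), the smallest FIXED_LEN_BYTE_ARRAY sizes work out to:

    // Illustrative entries of BYTES_FOR_PRECISION:
    //   precision  1 ->  1 byte     precision  9 ->  4 bytes
    //   precision 18 ->  8 bytes    precision 37 -> 16 bytes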
+ */ + private[parquet] val BYTES_FOR_PRECISION = Array.tabulate[Int](38) { precision => + var length = 1 + while (math.pow(2.0, 8 * length - 1) < math.pow(10.0, precision)) { + length += 1 + } + length + } + + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val converter = new CatalystSchemaConverter() + converter.convert(StructType.fromAttributes(attributes)) + } + + def convertFromString(string: String): Seq[Attribute] = { + Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { + case s: StructType => s.toAttributes + case other => sys.error(s"Can convert $string to row") + } + } + + def convertToString(schema: Seq[Attribute]): String = { + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) + StructType.fromAttributes(schema).json + } + + def writeMetaData(attributes: Seq[Attribute], origPath: Path, conf: Configuration): Unit = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to write Parquet metadata: path is null") + } + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"Unable to write Parquet metadata: path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (fs.exists(path) && !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException(s"Expected to write to directory $path but found file") + } + val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + if (fs.exists(metadataPath)) { + try { + fs.delete(metadataPath, true) + } catch { + case e: IOException => + throw new IOException(s"Unable to delete previous PARQUET_METADATA_FILE at $metadataPath") + } + } + val extraMetadata = new java.util.HashMap[String, String]() + extraMetadata.put( + CatalystReadSupport.SPARK_METADATA_KEY, + ParquetTypesConverter.convertToString(attributes)) + // TODO: add extra data, e.g., table name, date, etc.? + + val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) + val metaData: FileMetaData = new FileMetaData( + parquetSchema, + extraMetadata, + "Spark") + + ParquetRelation.enableLogForwarding() + ParquetFileWriter.writeMetadataFile( + conf, + path, + new Footer(path, new ParquetMetadata(metaData, Nil)) :: Nil) + } + + /** + * Try to read Parquet metadata at the given Path. We first see if there is a summary file + * in the parent directory. If so, this is used. Else we read the actual footer at the given + * location. + * @param origPath The path at which we expect one (or more) Parquet files. + * @param configuration The Hadoop configuration to use. + * @return The `ParquetMetadata` containing among other things the schema. + */ + def readMetaData(origPath: Path, configuration: Option[Configuration]): ParquetMetadata = { + if (origPath == null) { + throw new IllegalArgumentException("Unable to read Parquet metadata: path is null") + } + val job = new Job() + val conf = configuration.getOrElse(ContextUtil.getConfiguration(job)) + val fs: FileSystem = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") + } + val path = origPath.makeQualified(fs) + + val children = + fs + .globStatus(path) + .flatMap { status => if (status.isDir) fs.listStatus(status.getPath) else List(status) } + .filterNot { status => + val name = status.getPath.getName + (name(0) == '.' 
|| name(0) == '_') && name != ParquetFileWriter.PARQUET_METADATA_FILE + } + + ParquetRelation.enableLogForwarding() + + // NOTE (lian): Parquet "_metadata" file can be very slow if the file consists of lots of row + // groups. Since Parquet schema is replicated among all row groups, we only need to touch a + // single row group to read schema related metadata. Notice that we are making assumptions that + // all data in a single Parquet file have the same schema, which is normally true. + children + // Try any non-"_metadata" file first... + .find(_.getPath.getName != ParquetFileWriter.PARQUET_METADATA_FILE) + // ... and fallback to "_metadata" if no such file exists (which implies the Parquet file is + // empty, thus normally the "_metadata" file is expected to be fairly small). + .orElse(children.find(_.getPath.getName == ParquetFileWriter.PARQUET_METADATA_FILE)) + .map(ParquetFileReader.readFooter(conf, _, ParquetMetadataConverter.NO_FILTER)) + .getOrElse( + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path")) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala new file mode 100644 index 0000000..1b51a5e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -0,0 +1,115 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.metric + +import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + * + * An implementation of SQLMetric should override `+=` and `add` to avoid boxing. + */ +private[sql] abstract class SQLMetric[R <: SQLMetricValue[T], T]( + name: String, val param: SQLMetricParam[R, T]) + extends Accumulable[R, T](param.zero, param, Some(name), true) + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. + */ +private[sql] trait SQLMetricParam[R <: SQLMetricValue[T], T] extends AccumulableParam[R, T] { + + def zero: R +} + +/** + * Create a layer for specialized metric. We cannot add `@specialized` to + * `Accumulable/AccumulableParam` because it will break Java source compatibility. 
+ */ +private[sql] trait SQLMetricValue[T] extends Serializable { + + def value: T + + override def toString: String = value.toString +} + +/** + * A wrapper of Long to avoid boxing and unboxing when using Accumulator + */ +private[sql] class LongSQLMetricValue(private var _value : Long) extends SQLMetricValue[Long] { + + def add(incr: Long): LongSQLMetricValue = { + _value += incr + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Long = _value +} + +/** + * A wrapper of Int to avoid boxing and unboxing when using Accumulator + */ +private[sql] class IntSQLMetricValue(private var _value: Int) extends SQLMetricValue[Int] { + + def add(term: Int): IntSQLMetricValue = { + _value += term + this + } + + // Although there is a boxing here, it's fine because it's only called in SQLListener + override def value: Int = _value +} + +/** + * A specialized long Accumulable to avoid boxing and unboxing when using Accumulator's + * `+=` and `add`. + */ +private[sql] class LongSQLMetric private[metric](name: String) + extends SQLMetric[LongSQLMetricValue, Long](name, LongSQLMetricParam) { + + override def +=(term: Long): Unit = { + localValue.add(term) + } + + override def add(term: Long): Unit = { + localValue.add(term) + } +} + +private object LongSQLMetricParam extends SQLMetricParam[LongSQLMetricValue, Long] { + + override def addAccumulator(r: LongSQLMetricValue, t: Long): LongSQLMetricValue = r.add(t) + + override def addInPlace(r1: LongSQLMetricValue, r2: LongSQLMetricValue): LongSQLMetricValue = + r1.add(r2.value) + + override def zero(initialValue: LongSQLMetricValue): LongSQLMetricValue = zero + + override def zero: LongSQLMetricValue = new LongSQLMetricValue(0L) +} + +private[sql] object SQLMetrics { + + def createLongMetric(sc: SparkContext, name: String): LongSQLMetric = { + val acc = new LongSQLMetric(name) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(acc)) + acc + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala index 66237f8..28fa231 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/package.scala @@ -18,12 +18,6 @@ package org.apache.spark.sql /** - * :: DeveloperApi :: - * An execution engine for relational query plans that runs on top Spark and returns RDDs. - * - * Note that the operators in this package are created automatically by a query planner using a - * [[SQLContext]] and are not intended to be used directly by end users of Spark SQL. They are - * documented here in order to make it easier for others to understand the performance - * characteristics of query plans that are generated by Spark SQL. + * The physical execution component of Spark SQL. Note that this is a private package. 
*/ package object execution http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala new file mode 100644 index 0000000..49646a9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.Logging +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with Logging { + + private val listener = parent.listener + + override def render(request: HttpServletRequest): Seq[Node] = { + val currentTime = System.currentTimeMillis() + val content = listener.synchronized { + val _content = mutable.ListBuffer[Node]() + if (listener.getRunningExecutions.nonEmpty) { + _content ++= + new RunningExecutionTable( + parent, "Running Queries", currentTime, + listener.getRunningExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + } + if (listener.getCompletedExecutions.nonEmpty) { + _content ++= + new CompletedExecutionTable( + parent, "Completed Queries", currentTime, + listener.getCompletedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + } + if (listener.getFailedExecutions.nonEmpty) { + _content ++= + new FailedExecutionTable( + parent, "Failed Queries", currentTime, + listener.getFailedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq + } + _content + } + UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) + } +} + +private[ui] abstract class ExecutionTable( + parent: SQLTab, + tableId: String, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData], + showRunningJobs: Boolean, + showSucceededJobs: Boolean, + showFailedJobs: Boolean) { + + protected def baseHeader: Seq[String] = Seq( + "ID", + "Description", + "Submitted", + "Duration") + + protected def header: Seq[String] + + protected def row(currentTime: Long, executionUIData: SQLExecutionUIData): Seq[Node] = { + val submissionTime = executionUIData.submissionTime + val duration = executionUIData.completionTime.getOrElse(currentTime) - submissionTime + + val runningJobs = executionUIData.runningJobs.map { jobId => + <a href={jobURL(jobId)}>{jobId.toString}</a><br/> + } + val succeededJobs = 
executionUIData.succeededJobs.sorted.map { jobId => + <a href={jobURL(jobId)}>{jobId.toString}</a><br/> + } + val failedJobs = executionUIData.failedJobs.sorted.map { jobId => + <a href={jobURL(jobId)}>{jobId.toString}</a><br/> + } + <tr> + <td> + {executionUIData.executionId.toString} + </td> + <td> + {descriptionCell(executionUIData)} + </td> + <td sorttable_customkey={submissionTime.toString}> + {UIUtils.formatDate(submissionTime)} + </td> + <td sorttable_customkey={duration.toString}> + {UIUtils.formatDuration(duration)} + </td> + {if (showRunningJobs) { + <td> + {runningJobs} + </td> + }} + {if (showSucceededJobs) { + <td> + {succeededJobs} + </td> + }} + {if (showFailedJobs) { + <td> + {failedJobs} + </td> + }} + {detailCell(executionUIData.physicalPlanDescription)} + </tr> + } + + private def descriptionCell(execution: SQLExecutionUIData): Seq[Node] = { + val details = if (execution.details.nonEmpty) { + <span onclick="this.parentNode.querySelector('.stage-details').classList.toggle('collapsed')" + class="expand-details"> + +details + </span> ++ + <div class="stage-details collapsed"> + <pre>{execution.details}</pre> + </div> + } else { + Nil + } + + val desc = { + <a href={executionURL(execution.executionId)}>{execution.description}</a> + } + + <div>{desc} {details}</div> + } + + private def detailCell(physicalPlan: String): Seq[Node] = { + val isMultiline = physicalPlan.indexOf('\n') >= 0 + val summary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + physicalPlan.substring(0, physicalPlan.indexOf('\n')) + } else { + physicalPlan + }) + val details = if (isMultiline) { + // scalastyle:off + <span onclick="this.parentNode.querySelector('.stacktrace-details').classList.toggle('collapsed')" + class="expand-details"> + +details + </span> ++ + <div class="stacktrace-details collapsed"> + <pre>{physicalPlan}</pre> + </div> + // scalastyle:on + } else { + "" + } + <td>{summary}{details}</td> + } + + def toNodeSeq: Seq[Node] = { + <div> + <h4>{tableName}</h4> + {UIUtils.listingTable[SQLExecutionUIData]( + header, row(currentTime, _), executionUIDatas, id = Some(tableId))} + </div> + } + + private def jobURL(jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) + + private def executionURL(executionID: Long): String = + s"${UIUtils.prependBaseUri(parent.basePath)}/${parent.prefix}/execution?id=$executionID" +} + +private[ui] class RunningExecutionTable( + parent: SQLTab, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData]) + extends ExecutionTable( + parent, + "running-execution-table", + tableName, + currentTime, + executionUIDatas, + showRunningJobs = true, + showSucceededJobs = true, + showFailedJobs = true) { + + override protected def header: Seq[String] = + baseHeader ++ Seq("Running Jobs", "Succeeded Jobs", "Failed Jobs", "Detail") +} + +private[ui] class CompletedExecutionTable( + parent: SQLTab, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData]) + extends ExecutionTable( + parent, + "completed-execution-table", + tableName, + currentTime, + executionUIDatas, + showRunningJobs = false, + showSucceededJobs = true, + showFailedJobs = false) { + + override protected def header: Seq[String] = baseHeader ++ Seq("Jobs", "Detail") +} + +private[ui] class FailedExecutionTable( + parent: SQLTab, + tableName: String, + currentTime: Long, + executionUIDatas: Seq[SQLExecutionUIData]) + extends ExecutionTable( + parent, + "failed-execution-table", + tableName, + currentTime, 
+ executionUIDatas, + showRunningJobs = false, + showSucceededJobs = true, + showFailedJobs = true) { + + override protected def header: Seq[String] = + baseHeader ++ Seq("Succeeded Jobs", "Failed Jobs", "Detail") +} http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala new file mode 100644 index 0000000..f0b56c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/ExecutionPage.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.ui + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{Node, Unparsed} + +import org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.Logging +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[sql] class ExecutionPage(parent: SQLTab) extends WebUIPage("execution") with Logging { + + private val listener = parent.listener + + override def render(request: HttpServletRequest): Seq[Node] = listener.synchronized { + val parameterExecutionId = request.getParameter("id") + require(parameterExecutionId != null && parameterExecutionId.nonEmpty, + "Missing execution id parameter") + + val executionId = parameterExecutionId.toLong + val content = listener.getExecution(executionId).map { executionUIData => + val currentTime = System.currentTimeMillis() + val duration = + executionUIData.completionTime.getOrElse(currentTime) - executionUIData.submissionTime + + val summary = + <div> + <ul class="unstyled"> + <li> + <strong>Submitted Time: </strong>{UIUtils.formatDate(executionUIData.submissionTime)} + </li> + <li> + <strong>Duration: </strong>{UIUtils.formatDuration(duration)} + </li> + {if (executionUIData.runningJobs.nonEmpty) { + <li> + <strong>Running Jobs: </strong> + {executionUIData.runningJobs.sorted.map { jobId => + <a href={jobURL(jobId)}>{jobId.toString}</a><span> </span> + }} + </li> + }} + {if (executionUIData.succeededJobs.nonEmpty) { + <li> + <strong>Succeeded Jobs: </strong> + {executionUIData.succeededJobs.sorted.map { jobId => + <a href={jobURL(jobId)}>{jobId.toString}</a><span> </span> + }} + </li> + }} + {if (executionUIData.failedJobs.nonEmpty) { + <li> + <strong>Failed Jobs: </strong> + {executionUIData.failedJobs.sorted.map { jobId => + <a href={jobURL(jobId)}>{jobId.toString}</a><span> </span> + }} + </li> + }} + <li> + <strong>Detail: </strong><br/> + <pre>{executionUIData.physicalPlanDescription}</pre> + </li> + </ul> + </div> + + val metrics = 
listener.getExecutionMetrics(executionId) + + summary ++ planVisualization(metrics, executionUIData.physicalPlanGraph) + }.getOrElse { + <div>No information to display for Plan {executionId}</div> + } + + UIUtils.headerSparkPage(s"Details for Query $executionId", content, parent, Some(5000)) + } + + + private def planVisualizationResources: Seq[Node] = { + // scalastyle:off + <link rel="stylesheet" href={UIUtils.prependBaseUri("/static/sql/spark-sql-viz.css")} type="text/css"/> + <script src={UIUtils.prependBaseUri("/static/d3.min.js")}></script> + <script src={UIUtils.prependBaseUri("/static/dagre-d3.min.js")}></script> + <script src={UIUtils.prependBaseUri("/static/graphlib-dot.min.js")}></script> + <script src={UIUtils.prependBaseUri("/static/sql/spark-sql-viz.js")}></script> + // scalastyle:on + } + + private def planVisualization(metrics: Map[Long, Any], graph: SparkPlanGraph): Seq[Node] = { + val metadata = graph.nodes.flatMap { node => + val nodeId = s"plan-meta-data-${node.id}" + <div id={nodeId}>{node.desc}</div> + } + + <div> + <div id="plan-viz-graph"></div> + <div id="plan-viz-metadata" style="display:none"> + <div class="dot-file"> + {graph.makeDotFile(metrics)} + </div> + <div id="plan-viz-metadata-size">{graph.nodes.size.toString}</div> + {metadata} + </div> + {planVisualizationResources} + <script>$(function(){{ renderPlanViz(); }})</script> + </div> + } + + private def jobURL(jobId: Long): String = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), jobId) +} http://git-wip-us.apache.org/repos/asf/spark/blob/c1838e43/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala new file mode 100644 index 0000000..0b9bad9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -0,0 +1,352 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.ui + +import scala.collection.mutable + +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{JobExecutionStatus, Logging} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.SQLExecution +import org.apache.spark.sql.execution.metric.{SQLMetricParam, SQLMetricValue} + +private[sql] class SQLListener(sqlContext: SQLContext) extends SparkListener with Logging { + + private val retainedExecutions = + sqlContext.sparkContext.conf.getInt("spark.sql.ui.retainedExecutions", 1000) + + private val activeExecutions = mutable.HashMap[Long, SQLExecutionUIData]() + + // Old data in the following fields must be removed in "trimExecutionsIfNecessary". + // If adding new fields, make sure "trimExecutionsIfNecessary" can clean up old data + private val _executionIdToData = mutable.HashMap[Long, SQLExecutionUIData]() + + /** + * Maintain the relation between job id and execution id so that we can get the execution id in + * the "onJobEnd" method. + */ + private val _jobIdToExecutionId = mutable.HashMap[Long, Long]() + + private val _stageIdToStageMetrics = mutable.HashMap[Long, SQLStageMetrics]() + + private val failedExecutions = mutable.ListBuffer[SQLExecutionUIData]() + + private val completedExecutions = mutable.ListBuffer[SQLExecutionUIData]() + + def executionIdToData: Map[Long, SQLExecutionUIData] = synchronized { + _executionIdToData.toMap + } + + def jobIdToExecutionId: Map[Long, Long] = synchronized { + _jobIdToExecutionId.toMap + } + + def stageIdToStageMetrics: Map[Long, SQLStageMetrics] = synchronized { + _stageIdToStageMetrics.toMap + } + + private def trimExecutionsIfNecessary( + executions: mutable.ListBuffer[SQLExecutionUIData]): Unit = { + if (executions.size > retainedExecutions) { + val toRemove = math.max(retainedExecutions / 10, 1) + executions.take(toRemove).foreach { execution => + for (executionUIData <- _executionIdToData.remove(execution.executionId)) { + for (jobId <- executionUIData.jobs.keys) { + _jobIdToExecutionId.remove(jobId) + } + for (stageId <- executionUIData.stages) { + _stageIdToStageMetrics.remove(stageId) + } + } + } + executions.trimStart(toRemove) + } + } + + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + val executionIdString = jobStart.properties.getProperty(SQLExecution.EXECUTION_ID_KEY) + if (executionIdString == null) { + // This is not a job created by SQL + return + } + val executionId = executionIdString.toLong + val jobId = jobStart.jobId + val stageIds = jobStart.stageIds + + synchronized { + activeExecutions.get(executionId).foreach { executionUIData => + executionUIData.jobs(jobId) = JobExecutionStatus.RUNNING + executionUIData.stages ++= stageIds + stageIds.foreach(stageId => + _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId = 0)) + _jobIdToExecutionId(jobId) = executionId + } + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized { + val jobId = jobEnd.jobId + for (executionId <- _jobIdToExecutionId.get(jobId); + executionUIData <- _executionIdToData.get(executionId)) { + jobEnd.jobResult match { + case JobSucceeded => executionUIData.jobs(jobId) = JobExecutionStatus.SUCCEEDED + case JobFailed(_) => executionUIData.jobs(jobId) = JobExecutionStatus.FAILED + } + if (executionUIData.completionTime.nonEmpty && !executionUIData.hasRunningJobs) { + // We are the last job of this execution, so mark the 
execution as finished. Note that
+        // `onExecutionEnd` also does this, but currently that can be called before `onJobEnd`
+        // since these are called on different threads.
+        markExecutionFinished(executionId)
+      }
+    }
+  }
+
+  override def onExecutorMetricsUpdate(
+      executorMetricsUpdate: SparkListenerExecutorMetricsUpdate): Unit = synchronized {
+    for ((taskId, stageId, stageAttemptID, metrics) <- executorMetricsUpdate.taskMetrics) {
+      updateTaskAccumulatorValues(taskId, stageId, stageAttemptID, metrics, finishTask = false)
+    }
+  }
+
+  override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = synchronized {
+    val stageId = stageSubmitted.stageInfo.stageId
+    val stageAttemptId = stageSubmitted.stageInfo.attemptId
+    // Always override the metrics of an old stage attempt.
+    _stageIdToStageMetrics(stageId) = new SQLStageMetrics(stageAttemptId)
+  }
+
+  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+    updateTaskAccumulatorValues(
+      taskEnd.taskInfo.taskId,
+      taskEnd.stageId,
+      taskEnd.stageAttemptId,
+      taskEnd.taskMetrics,
+      finishTask = true)
+  }
+
+  /**
+   * Update the accumulator values of a task with the latest metrics for this task. This is called
+   * every time we receive an executor heartbeat or when a task finishes.
+   */
+  private def updateTaskAccumulatorValues(
+      taskId: Long,
+      stageId: Int,
+      stageAttemptID: Int,
+      metrics: TaskMetrics,
+      finishTask: Boolean): Unit = {
+    if (metrics == null) {
+      return
+    }
+
+    _stageIdToStageMetrics.get(stageId) match {
+      case Some(stageMetrics) =>
+        if (stageAttemptID < stageMetrics.stageAttemptId) {
+          // A task of an old stage attempt. A newer attempt has already been submitted, so we
+          // can ignore it.
+        } else if (stageAttemptID > stageMetrics.stageAttemptId) {
+          logWarning(s"A task should not have a higher stageAttemptID ($stageAttemptID) than " +
+            s"what we have seen (${stageMetrics.stageAttemptId})")
+        } else {
+          // TODO We don't know the task attemptId, so the best we can do is override the stored
+          // accumulator updates. However, if two attempts of the same task run concurrently
+          // (e.g. because of speculation), their updates will keep overriding each other and the
+          // results will be wrong.
+          stageMetrics.taskIdToMetricUpdates.get(taskId) match {
+            case Some(taskMetrics) =>
+              if (finishTask) {
+                taskMetrics.finished = true
+                taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates()
+              } else if (!taskMetrics.finished) {
+                taskMetrics.accumulatorUpdates = metrics.accumulatorUpdates()
+              } else {
+                // If a task has finished, we should not override its accumulator updates with
+                // values from later heartbeat reports.
+              }
+            case None =>
+              // TODO We just set attemptId to 0 for now; this should be fixed once we can
get the attempt
+              // id from SparkListenerExecutorMetricsUpdate.
+              stageMetrics.taskIdToMetricUpdates(taskId) = new SQLTaskMetrics(
+                attemptId = 0, finished = finishTask, metrics.accumulatorUpdates())
+          }
+        }
+      case None =>
+        // This execution and its stages have been dropped.
+    }
+  }
+
+  def onExecutionStart(
+      executionId: Long,
+      description: String,
+      details: String,
+      physicalPlanDescription: String,
+      physicalPlanGraph: SparkPlanGraph,
+      time: Long): Unit = {
+    val sqlPlanMetrics = physicalPlanGraph.nodes.flatMap { node =>
+      node.metrics.map(metric => metric.accumulatorId -> metric)
+    }
+
+    val executionUIData = new SQLExecutionUIData(executionId, description, details,
+      physicalPlanDescription, physicalPlanGraph, sqlPlanMetrics.toMap, time)
+    synchronized {
+      activeExecutions(executionId) = executionUIData
+      _executionIdToData(executionId) = executionUIData
+    }
+  }
+
+  def onExecutionEnd(executionId: Long, time: Long): Unit = synchronized {
+    _executionIdToData.get(executionId).foreach { executionUIData =>
+      executionUIData.completionTime = Some(time)
+      if (!executionUIData.hasRunningJobs) {
+        // onExecutionEnd happens after all "onJobEnd"s, so we can update the execution lists
+        // right away.
+        markExecutionFinished(executionId)
+      } else {
+        // There are still running jobs, so onExecutionEnd happened before some "onJobEnd"s.
+        // We don't know yet whether the execution succeeded, so we let the last "onJobEnd"
+        // update the execution lists.
+      }
+    }
+  }
+
+  private def markExecutionFinished(executionId: Long): Unit = {
+    activeExecutions.remove(executionId).foreach { executionUIData =>
+      if (executionUIData.isFailed) {
+        failedExecutions += executionUIData
+        trimExecutionsIfNecessary(failedExecutions)
+      } else {
+        completedExecutions += executionUIData
+        trimExecutionsIfNecessary(completedExecutions)
+      }
+    }
+  }
+
+  def getRunningExecutions: Seq[SQLExecutionUIData] = synchronized {
+    activeExecutions.values.toSeq
+  }
+
+  def getFailedExecutions: Seq[SQLExecutionUIData] = synchronized {
+    failedExecutions
+  }
+
+  def getCompletedExecutions: Seq[SQLExecutionUIData] = synchronized {
+    completedExecutions
+  }
+
+  def getExecution(executionId: Long): Option[SQLExecutionUIData] = synchronized {
+    _executionIdToData.get(executionId)
+  }
+
+  /**
+   * Get all accumulator updates from all tasks which belong to this execution and merge them.
+   */
+  def getExecutionMetrics(executionId: Long): Map[Long, Any] = synchronized {
+    _executionIdToData.get(executionId) match {
+      case Some(executionUIData) =>
+        val accumulatorUpdates = {
+          for (stageId <- executionUIData.stages;
+               stageMetrics <- _stageIdToStageMetrics.get(stageId).toIterable;
+               taskMetrics <- stageMetrics.taskIdToMetricUpdates.values;
+               accumulatorUpdate <- taskMetrics.accumulatorUpdates.toSeq) yield {
+            accumulatorUpdate
+          }
+        }.filter { case (id, _) => executionUIData.accumulatorMetrics.contains(id) }
+        mergeAccumulatorUpdates(accumulatorUpdates, accumulatorId =>
+          executionUIData.accumulatorMetrics(accumulatorId).metricParam).
+          mapValues(_.asInstanceOf[SQLMetricValue[_]].value)
+      case None =>
+        // This execution has been dropped.
+        Map.empty
+    }
+  }
+
+  private def mergeAccumulatorUpdates(
+      accumulatorUpdates: Seq[(Long, Any)],
+      paramFunc: Long => SQLMetricParam[SQLMetricValue[Any], Any]): Map[Long, Any] = {
+    accumulatorUpdates.groupBy(_._1).map { case (accumulatorId, values) =>
+      val param = paramFunc(accumulatorId)
+      (accumulatorId,
+        values.map(_._2.asInstanceOf[SQLMetricValue[Any]]).foldLeft(param.zero)(param.addInPlace))
+    }
+  }
+
+}
+
+/**
+ * Represents all the data for an execution that is needed to render it in the web UI.
+ */
+private[ui] class SQLExecutionUIData(
+    val executionId: Long,
+    val description: String,
+    val details: String,
+    val physicalPlanDescription: String,
+    val physicalPlanGraph: SparkPlanGraph,
+    val accumulatorMetrics: Map[Long, SQLPlanMetric],
+    val submissionTime: Long,
+    var completionTime: Option[Long] = None,
+    val jobs: mutable.HashMap[Long, JobExecutionStatus] = mutable.HashMap.empty,
+    val stages: mutable.ArrayBuffer[Int] = mutable.ArrayBuffer()) {
+
+  /**
+   * Return whether there are running jobs in this execution.
+   */
+  def hasRunningJobs: Boolean = jobs.values.exists(_ == JobExecutionStatus.RUNNING)
+
+  /**
+   * Return whether there are any failed jobs in this execution.
+   */
+  def isFailed: Boolean = jobs.values.exists(_ == JobExecutionStatus.FAILED)
+
+  def runningJobs: Seq[Long] =
+    jobs.filter { case (_, status) => status == JobExecutionStatus.RUNNING }.keys.toSeq
+
+  def succeededJobs: Seq[Long] =
+    jobs.filter { case (_, status) => status == JobExecutionStatus.SUCCEEDED }.keys.toSeq
+
+  def failedJobs: Seq[Long] =
+    jobs.filter { case (_, status) => status == JobExecutionStatus.FAILED }.keys.toSeq
+}
+
+/**
+ * Represents a metric in a SQL plan.
+ *
+ * Because we cannot revert changes made to an "Accumulator", we keep the raw accumulator updates
+ * for each task, so that if a task is retried we can simply override the old updates with those
+ * of the new attempt. Since the updates are not added to an accumulator directly, we use the
+ * "AccumulatorParam" to compute the aggregated value.
+ */
+private[ui] case class SQLPlanMetric(
+    name: String,
+    accumulatorId: Long,
+    metricParam: SQLMetricParam[SQLMetricValue[Any], Any])
+
+/**
+ * Stores all accumulator updates for all tasks in a Spark stage.
+ */
+private[ui] class SQLStageMetrics(
+    val stageAttemptId: Long,
+    val taskIdToMetricUpdates: mutable.HashMap[Long, SQLTaskMetrics] = mutable.HashMap.empty)
+
+/**
+ * Stores all accumulator updates for a single Spark task.
+ */
+private[ui] class SQLTaskMetrics(
+    val attemptId: Long, // TODO not used yet
+    var finished: Boolean,
+    var accumulatorUpdates: Map[Long, Any])
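
To make the aggregation performed by getExecutionMetrics and mergeAccumulatorUpdates above easier to follow, here is a minimal, self-contained Scala sketch of the same group-and-fold pattern. It is illustrative only and not part of the patch: plain Longs stand in for SQLMetricValue, the MetricParam trait below carries only the zero/addInPlace pieces the merge needs, and the sample accumulator ids and task updates are made up.

object MergeAccumulatorUpdatesSketch {
  // Stand-in for SQLMetricParam: just the pieces the merge needs.
  trait MetricParam {
    def zero: Long
    def addInPlace(a: Long, b: Long): Long
  }

  object LongParam extends MetricParam {
    override def zero: Long = 0L
    override def addInPlace(a: Long, b: Long): Long = a + b
  }

  // Group the (accumulatorId, taskValue) pairs reported by tasks and fold each group,
  // mirroring mergeAccumulatorUpdates above.
  def merge(updates: Seq[(Long, Long)], paramFor: Long => MetricParam): Map[Long, Long] =
    updates.groupBy(_._1).map { case (accumulatorId, values) =>
      val param = paramFor(accumulatorId)
      accumulatorId -> values.map(_._2).foldLeft(param.zero)(param.addInPlace)
    }

  def main(args: Array[String]): Unit = {
    // Two tasks reported updates for accumulator 1, one task for accumulator 2 (made-up ids).
    val updates = Seq((1L, 10L), (1L, 32L), (2L, 5L))
    println(merge(updates, _ => LongParam)) // Map(1 -> 42, 2 -> 5)
  }
}
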
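Similarly, the retention policy implemented by trimExecutionsIfNecessary above can be shown in isolation. The sketch below is a simplified stand-in, not the listener itself: it omits the cleanup of the jobId/stageId lookup maps, keeps plain Long ids instead of SQLExecutionUIData objects, and uses a small limit in place of the "spark.sql.ui.retainedExecutions" default of 1000 so that the trimming is visible.

object RetainedExecutionsSketch {
  import scala.collection.mutable

  val retainedExecutions = 5 // the real value comes from "spark.sql.ui.retainedExecutions"

  def trim(executions: mutable.ListBuffer[Long]): Unit = {
    if (executions.size > retainedExecutions) {
      // Drop roughly 10% of the oldest entries in one go so we do not trim on every insert.
      val toRemove = math.max(retainedExecutions / 10, 1)
      executions.trimStart(toRemove)
    }
  }

  def main(args: Array[String]): Unit = {
    val completed = mutable.ListBuffer[Long]()
    (1L to 8L).foreach { executionId =>
      completed += executionId
      trim(completed)
    }
    println(completed) // the oldest execution ids have been dropped once the limit was exceeded
  }
}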