Repository: spark Updated Branches: refs/heads/master e127ec34d -> 1221849f9
[SPARK-8005][SQL] Input file name Users can now get the file name of the partition being read in. A thread local variable is in `SQLNewHadoopRDD` and is set when the partition is computed. `SQLNewHadoopRDD` is moved to core so that the catalyst package can reach it. This supports: `df.select(inputFileName())` and `sqlContext.sql("select input_file_name() from table")` Author: Joseph Batchik <josephbatc...@gmail.com> Closes #7743 from JDrit/input_file_name and squashes the following commits: abb8609 [Joseph Batchik] fixed failing test and changed the default value to be an empty string d2f323d [Joseph Batchik] updates per review 102061f [Joseph Batchik] updates per review 75313f5 [Joseph Batchik] small fixes c7f7b5a [Joseph Batchik] addeding input file name to Spark SQL Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1221849f Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1221849f Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1221849f Branch: refs/heads/master Commit: 1221849f91739454b8e495889cba7498ba8beea7 Parents: e127ec3 Author: Joseph Batchik <josephbatc...@gmail.com> Authored: Wed Jul 29 23:35:55 2015 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Wed Jul 29 23:35:55 2015 -0700 ---------------------------------------------------------------------- .../org/apache/spark/rdd/SqlNewHadoopRDD.scala | 297 +++++++++++++++++++ .../catalyst/analysis/FunctionRegistry.scala | 3 +- .../catalyst/expressions/InputFileName.scala | 49 +++ .../catalyst/expressions/SparkPartitionID.scala | 2 + .../expressions/NondeterministicSuite.scala | 4 + .../spark/sql/execution/SqlNewHadoopRDD.scala | 273 ----------------- .../scala/org/apache/spark/sql/functions.scala | 9 + .../spark/sql/parquet/ParquetRelation.scala | 3 +- .../spark/sql/ColumnExpressionSuite.scala | 17 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 17 +- .../org/apache/spark/sql/hive/UDFSuite.scala | 6 - 11 files 
changed, 396 insertions(+), 284 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala new file mode 100644 index 0000000..35e44cb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDD.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import java.text.SimpleDateFormat +import java.util.Date + +import scala.reflect.ClassTag + +import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.DataReadMethod +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.{Partition => SparkPartition, _} +import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{SerializableConfiguration, Utils} + + +private[spark] class SqlNewHadoopPartition( + rddId: Int, + val index: Int, + @transient rawSplit: InputSplit with Writable) + extends SparkPartition { + + val serializableHadoopSplit = new SerializableWritable(rawSplit) + + override def hashCode(): Int = 41 * (41 + rddId) + index +} + +/** + * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, + * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). + * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. + * 1. A shared broadcast Hadoop Configuration. + * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side + * to the shared Hadoop Configuration. + * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side + * and the executor side to the shared Hadoop Configuration. + * + * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with + * changes based on [[org.apache.spark.rdd.HadoopRDD]]. 
In future, this functionality will be + * folded into core. + */ +private[spark] class SqlNewHadoopRDD[K, V]( + @transient sc : SparkContext, + broadcastedConf: Broadcast[SerializableConfiguration], + @transient initDriverSideJobFuncOpt: Option[Job => Unit], + initLocalJobFuncOpt: Option[Job => Unit], + inputFormatClass: Class[_ <: InputFormat[K, V]], + keyClass: Class[K], + valueClass: Class[V]) + extends RDD[(K, V)](sc, Nil) + with SparkHadoopMapReduceUtil + with Logging { + + protected def getJob(): Job = { + val conf: Configuration = broadcastedConf.value.value + // "new Job" will make a copy of the conf. Then, it is + // safe to mutate conf properties with initLocalJobFuncOpt + // and initDriverSideJobFuncOpt. + val newJob = new Job(conf) + initLocalJobFuncOpt.map(f => f(newJob)) + newJob + } + + def getConf(isDriverSide: Boolean): Configuration = { + val job = getJob() + if (isDriverSide) { + initDriverSideJobFuncOpt.map(f => f(job)) + } + job.getConfiguration + } + + private val jobTrackerId: String = { + val formatter = new SimpleDateFormat("yyyyMMddHHmm") + formatter.format(new Date()) + } + + @transient protected val jobId = new JobID(jobTrackerId, id) + + override def getPartitions: Array[SparkPartition] = { + val conf = getConf(isDriverSide = true) + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[SparkPartition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = + new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } + + override def compute( + theSplit: SparkPartition, + context: TaskContext): InterruptibleIterator[(K, V)] = { + val iter = new Iterator[(K, V)] { + val split = theSplit.asInstanceOf[SqlNewHadoopPartition] + logInfo("Input split: " + 
split.serializableHadoopSplit) + val conf = getConf(isDriverSide = false) + + val inputMetrics = context.taskMetrics + .getInputMetricsForReadMethod(DataReadMethod.Hadoop) + + // Sets the thread local variable for the file's name + split.serializableHadoopSplit.value match { + case fs: FileSplit => SqlNewHadoopRDD.setInputFileName(fs.getPath.toString) + case _ => SqlNewHadoopRDD.unsetInputFileName() + } + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { + split.serializableHadoopSplit.value match { + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + case _ => None + } + } + inputMetrics.setBytesReadCallback(bytesReadCallback) + + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) + val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val format = inputFormatClass.newInstance + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + private var reader = format.createRecordReader( + split.serializableHadoopSplit.value, hadoopAttemptContext) + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) + + // Register an on-task-completion callback to close the input stream. + context.addTaskCompletionListener(context => close()) + var havePair = false + var finished = false + var recordsSinceMetricsUpdate = 0 + + override def hasNext: Boolean = { + if (!finished && !havePair) { + finished = !reader.nextKeyValue + if (finished) { + // Close and release the reader here; close() will also be called when the task + // completes, but for tasks that read from many files, it helps to release the + // resources early. 
+ close() + } + havePair = !finished + } + !finished + } + + override def next(): (K, V) = { + if (!hasNext) { + throw new java.util.NoSuchElementException("End of stream") + } + havePair = false + if (!finished) { + inputMetrics.incRecordsRead(1) + } + (reader.getCurrentKey, reader.getCurrentValue) + } + + private def close() { + try { + if (reader != null) { + reader.close() + reader = null + + SqlNewHadoopRDD.unsetInputFileName() + + if (bytesReadCallback.isDefined) { + inputMetrics.updateBytesRead() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } + } + } catch { + case e: Exception => { + if (!Utils.inShutdown()) { + logWarning("Exception in RecordReader.close()", e) + } + } + } + } + } + new InterruptibleIterator(context, iter) + } + + /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
*/ + @DeveloperApi + def mapPartitionsWithInputSplit[U: ClassTag]( + f: (InputSplit, Iterator[(K, V)]) => Iterator[U], + preservesPartitioning: Boolean = false): RDD[U] = { + new SqlNewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) + } + + override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { + val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value + val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { + case Some(c) => + try { + val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] + Some(HadoopRDD.convertSplitLocationInfo(infos)) + } catch { + case e : Exception => + logDebug("Failed to use InputSplit#getLocationInfo.", e) + None + } + case None => None + } + locs.getOrElse(split.getLocations.filter(_ != "localhost")) + } + + override def persist(storageLevel: StorageLevel): this.type = { + if (storageLevel.deserialized) { + logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + + " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + + " Use a map transformation to make copies of the records.") + } + super.persist(storageLevel) + } +} + +private[spark] object SqlNewHadoopRDD { + + /** + * The thread-local variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + + /** + * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to + * the given function rather than the index of the partition.
+ */ + private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( + prev: RDD[T], + f: (InputSplit, Iterator[T]) => Iterator[U], + preservesPartitioning: Boolean = false) + extends RDD[U](prev) { + + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None + + override def getPartitions: Array[SparkPartition] = firstParent[T].partitions + + override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { + val partition = split.asInstanceOf[SqlNewHadoopPartition] + val inputSplit = partition.serializableHadoopSplit.value + f(inputSplit, firstParent[T].iterator(split, context)) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 372f80d..378df4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -230,7 +230,8 @@ object FunctionRegistry { expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), - expression[SparkPartitionID]("spark_partition_id") + expression[SparkPartitionID]("spark_partition_id"), + expression[InputFileName]("input_file_name") ) val builtin: FunctionRegistry = { http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala new file mode 100644 index 0000000..1e74f71 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.rdd.SqlNewHadoopRDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + */ +case class InputFileName() extends LeafExpression with Nondeterministic { + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override val prettyName = "INPUT_FILE_NAME" + + override protected def initInternal(): Unit = {} + + override protected def evalInternal(input: InternalRow): UTF8String = { + SqlNewHadoopRDD.getInputFileName() + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + ev.isNull = "false" + s"final ${ctx.javaType(dataType)} ${ev.primitive} = " + + "org.apache.spark.rdd.SqlNewHadoopRDD.getInputFileName();" + } + +} http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 3f6480b..4b1772a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -34,6 +34,8 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm @transient private[this] var partitionId: Int = _ + override val prettyName = "SPARK_PARTITION_ID" + override protected def initInternal(): Unit = { 
partitionId = TaskContext.getPartitionId() } http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala index 8289482..bf1c930 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala @@ -27,4 +27,8 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper { test("SparkPartitionID") { checkEvaluation(SparkPartitionID(), 0) } + + test("InputFileName") { + checkEvaluation(InputFileName(), "") + } } http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala deleted file mode 100644 index 3d75b6a..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SqlNewHadoopRDD.scala +++ /dev/null @@ -1,273 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.text.SimpleDateFormat -import java.util.Date - -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil -import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.rdd.{HadoopRDD, RDD} -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, Utils} - -import scala.reflect.ClassTag - -private[spark] class SqlNewHadoopPartition( - rddId: Int, - val index: Int, - @transient rawSplit: InputSplit with Writable) - extends SparkPartition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = 41 * (41 + rddId) + index -} - -/** - * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). - * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. - * 1. A shared broadcast Hadoop Configuration. - * 2. 
An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side - * to the shared Hadoop Configuration. - * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side - * and the executor side to the shared Hadoop Configuration. - * - * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. In future, this functionality will be - * folded into core. - */ -private[sql] class SqlNewHadoopRDD[K, V]( - @transient sc : SparkContext, - broadcastedConf: Broadcast[SerializableConfiguration], - @transient initDriverSideJobFuncOpt: Option[Job => Unit], - initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[K, V]], - keyClass: Class[K], - valueClass: Class[V]) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { - - protected def getJob(): Job = { - val conf: Configuration = broadcastedConf.value.value - // "new Job" will make a copy of the conf. Then, it is - // safe to mutate conf properties with initLocalJobFuncOpt - // and initDriverSideJobFuncOpt. 
- val newJob = new Job(conf) - initLocalJobFuncOpt.map(f => f(newJob)) - newJob - } - - def getConf(isDriverSide: Boolean): Configuration = { - val job = getJob() - if (isDriverSide) { - initDriverSideJobFuncOpt.map(f => f(job)) - } - job.getConfiguration - } - - private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient protected val jobId = new JobID(jobTrackerId, id) - - override def getPartitions: Array[SparkPartition] = { - val conf = getConf(isDriverSide = true) - val inputFormat = inputFormatClass.newInstance - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = newJobContext(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute( - theSplit: SparkPartition, - context: TaskContext): InterruptibleIterator[(K, V)] = { - val iter = new Iterator[(K, V)] { - val split = theSplit.asInstanceOf[SqlNewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf(isDriverSide = false) - - val inputMetrics = context.taskMetrics - .getInputMetricsForReadMethod(DataReadMethod.Hadoop) - - // Find a function that will return the FileSystem bytes read by this thread. 
Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { - split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } - } - inputMetrics.setBytesReadCallback(bytesReadCallback) - - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - private var reader = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) - var havePair = false - var finished = false - var recordsSinceMetricsUpdate = 0 - - override def hasNext: Boolean = { - if (!finished && !havePair) { - finished = !reader.nextKeyValue - if (finished) { - // Close and release the reader here; close() will also be called when the task - // completes, but for tasks that read from many files, it helps to release the - // resources early. 
- close() - } - havePair = !finished - } - !finished - } - - override def next(): (K, V) = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - if (!finished) { - inputMetrics.incRecordsRead(1) - } - (reader.getCurrentKey, reader.getCurrentValue) - } - - private def close() { - try { - if (reader != null) { - reader.close() - reader = null - - if (bytesReadCallback.isDefined) { - inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. - try { - inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - } - } - } catch { - case e: Exception => { - if (!Utils.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } - } - } - } - new InterruptibleIterator(context, iter) - } - - /** Maps over a partition, providing the InputSplit that was used as the base of the partition. 
*/ - @DeveloperApi - def mapPartitionsWithInputSplit[U: ClassTag]( - f: (InputSplit, Iterator[(K, V)]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = { - new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning) - } - - override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { - val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value - val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) - } catch { - case e : Exception => - logDebug("Failed to use InputSplit#getLocationInfo.", e) - None - } - case None => None - } - locs.getOrElse(split.getLocations.filter(_ != "localhost")) - } - - override def persist(storageLevel: StorageLevel): this.type = { - if (storageLevel.deserialized) { - logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + - " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + - " Use a map transformation to make copies of the records.") - } - super.persist(storageLevel) - } -} - -private[spark] object SqlNewHadoopRDD { - /** - * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to - * the given function rather than the index of the partition. 
 - */ - private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], - f: (InputSplit, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[SparkPartition] = firstParent[T].partitions - - override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { - val partition = split.asInstanceOf[SqlNewHadoopPartition] - val inputSplit = partition.serializableHadoopSplit.value - f(inputSplit, firstParent[T].iterator(split, context)) - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4e68a88..a2fece6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -744,6 +744,15 @@ object functions { def sparkPartitionId(): Column = SparkPartitionID() /** + * Creates a string column for the file name of the current Spark task. + * + * Note that this is nondeterministic because it depends on what is currently being read in. + * + * @group normal_funcs + */ + def inputFileName(): Column = InputFileName() + + /** + * Computes the square root of the specified float value.
* * @group math_funcs http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cc6fa2b..1a8176d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -39,11 +39,10 @@ import org.apache.parquet.{Log => ParquetLog} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{SqlNewHadoopPartition, SqlNewHadoopRDD, RDD} import org.apache.spark.rdd.RDD._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.{SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 1f9f711..5c11024 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -22,13 +22,16 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import 
org.apache.spark.sql.test.SQLTestUtils -class ColumnExpressionSuite extends QueryTest { +class ColumnExpressionSuite extends QueryTest with SQLTestUtils { import org.apache.spark.sql.TestData._ private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -489,6 +492,18 @@ class ColumnExpressionSuite extends QueryTest { ) } + test("InputFileName") { + withTempPath { dir => + val data = sqlContext.sparkContext.parallelize(0 to 10).toDF("id") + data.write.parquet(dir.getCanonicalPath) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + .head.getString(0) + assert(answer.contains(dir.getCanonicalPath)) + + checkAnswer(data.select(inputFileName()).limit(1), Row("")) + } + } + test("lift alias out of cast") { compareExpressions( col("1234").as("name").cast("int").expr, http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index d9c8b38..183dc34 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest { +class UDFSuite extends QueryTest with SQLTestUtils { private lazy val ctx = org.apache.spark.sql.test.TestSQLContext import ctx.implicits._ + override def sqlContext(): SQLContext = ctx + test("built-in fixed arity expressions") { val df = ctx.emptyDataFrame df.selectExpr("rand()", "randn()", "rand(5)", "randn(50)") 
@@ -58,6 +61,18 @@ class UDFSuite extends QueryTest { ctx.dropTempTable("tmp_table") } + test("SPARK-8005 input_file_name") { + withTempPath { dir => + val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id") + data.write.parquet(dir.getCanonicalPath) + ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table") + val answer = ctx.sql("select input_file_name() from test_table").head().getString(0) + assert(answer.contains(dir.getCanonicalPath)) + assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2) + ctx.dropTempTable("test_table") + } + } + test("error reporting for incorrect number of arguments") { val df = ctx.emptyDataFrame val e = intercept[AnalysisException] { http://git-wip-us.apache.org/repos/asf/spark/blob/1221849f/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 37afc21..9b3ede4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -34,10 +34,4 @@ class UDFSuite extends QueryTest { assert(ctx.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(ctx.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } - - test("SPARK-8003 spark_partition_id") { - val df = Seq((1, "Two Fiiiiive")).toDF("id", "saying") - ctx.registerDataFrameAsTable(df, "test_table") - checkAnswer(ctx.sql("select spark_partition_id() from test_table LIMIT 1").toDF(), Row(0)) - } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org