http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala new file mode 100644 index 0000000..4218fd5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.{ForeachWriter, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.PythonForeachWriter +import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType + +/** + * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified + * [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @param converter An object to convert internal rows to target type T. Either it can be + * a [[ExpressionEncoder]] or a direct converter function. + * @tparam T The expected type of the sink. 
+ */ +case class ForeachWriteSupportProvider[T]( + writer: ForeachWriter[T], + converter: Either[ExpressionEncoder[T], InternalRow => T]) + extends StreamingWriteSupportProvider { + + override def createStreamingWriteSupport( + queryId: String, + schema: StructType, + mode: OutputMode, + options: DataSourceOptions): StreamingWriteSupport = { + new StreamingWriteSupport { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + schema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) + } + + override def toString: String = "ForeachSink" + } + } +} + +object ForeachWriteSupportProvider { + def apply[T]( + writer: ForeachWriter[T], + encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + writer match { + case pythonWriter: PythonForeachWriter => + new ForeachWriteSupportProvider[UnsafeRow]( + pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) + case _ => + new ForeachWriteSupportProvider[T](writer, Left(encoder)) + } + } +} + +case class ForeachWriterFactory[T]( + writer: ForeachWriter[T], + rowConverter: InternalRow => T) + extends StreamingDataWriterFactory { + override def createWriter( + partitionId: Int, + taskId: Long, + epochId: Long): ForeachDataWriter[T] = { + new ForeachDataWriter(writer, rowConverter, partitionId, epochId) + } +} + +/** + * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. + * + * @param writer The [[ForeachWriter]] to process all data. + * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] + * @param partitionId + * @param epochId + * @tparam T The type expected by the writer. + */ +class ForeachDataWriter[T]( + writer: ForeachWriter[T], + rowConverter: InternalRow => T, + partitionId: Int, + epochId: Long) + extends DataWriter[InternalRow] { + + // If open returns false, we should skip writing rows. + private val opened = writer.open(partitionId, epochId) + + override def write(record: InternalRow): Unit = { + if (!opened) return + + try { + writer.process(rowConverter(record)) + } catch { + case t: Throwable => + writer.close(t) + throw t + } + } + + override def commit(): WriterCommitMessage = { + writer.close(null) + ForeachWriterCommitMessage + } + + override def abort(): Unit = {} +} + +/** + * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. + */ +case object ForeachWriterCommitMessage extends WriterCommitMessage
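
For context, the ForeachWriteSupportProvider added above is reached through the user-facing DataStreamWriter.foreach API. Below is a minimal, illustrative sketch -- not part of this patch -- assuming a locally running SparkSession and a socket source on localhost:9999; the writer class and endpoint are hypothetical and only show how rows end up flowing through ForeachDataWriter.

import org.apache.spark.sql.{ForeachWriter, SparkSession}

// Hypothetical user-defined writer; the name and behaviour are illustrative only.
class PrintlnForeachWriter extends ForeachWriter[String] {
  // Returning false makes ForeachDataWriter.write skip the rows of this partition.
  override def open(partitionId: Long, epochId: Long): Boolean = true
  override def process(value: String): Unit = println(value)
  override def close(errorOrNull: Throwable): Unit = ()
}

val spark = SparkSession.builder().appName("foreach-sketch").master("local[*]").getOrCreate()
import spark.implicits._

// DataStreamWriter.foreach wraps the user's ForeachWriter in ForeachWriteSupportProvider
// (see the DataStreamWriter.scala hunk later in this commit), which uses the Dataset's
// encoder to convert each InternalRow before calling process().
val query = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .load()
  .as[String]
  .writeStream
  .foreach(new PrintlnForeachWriter)
  .start()
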
http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala deleted file mode 100644 index e8ce21c..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterProvider.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.{ForeachWriter, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamWriteSupport} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType - -/** - * A [[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified - * [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @param converter An object to convert internal rows to target type T. Either it can be - * a [[ExpressionEncoder]] or a direct converter function. - * @tparam T The expected type of the sink. 
- */ -case class ForeachWriterProvider[T]( - writer: ForeachWriter[T], - converter: Either[ExpressionEncoder[T], InternalRow => T]) extends StreamWriteSupport { - - override def createStreamWriter( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new StreamWriter { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - override def createWriterFactory(): DataWriterFactory[InternalRow] = { - val rowConverter: InternalRow => T = converter match { - case Left(enc) => - val boundEnc = enc.resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - boundEnc.fromRow - case Right(func) => - func - } - ForeachWriterFactory(writer, rowConverter) - } - - override def toString: String = "ForeachSink" - } - } -} - -object ForeachWriterProvider { - def apply[T]( - writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriterProvider[_] = { - writer match { - case pythonWriter: PythonForeachWriter => - new ForeachWriterProvider[UnsafeRow]( - pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) - case _ => - new ForeachWriterProvider[T](writer, Left(encoder)) - } - } -} - -case class ForeachWriterFactory[T]( - writer: ForeachWriter[T], - rowConverter: InternalRow => T) - extends DataWriterFactory[InternalRow] { - override def createDataWriter( - partitionId: Int, - taskId: Long, - epochId: Long): ForeachDataWriter[T] = { - new ForeachDataWriter(writer, rowConverter, partitionId, epochId) - } -} - -/** - * A [[DataWriter]] which writes data in this partition to a [[ForeachWriter]]. - * - * @param writer The [[ForeachWriter]] to process all data. - * @param rowConverter A function which can convert [[InternalRow]] to the required type [[T]] - * @param partitionId - * @param epochId - * @tparam T The type expected by the writer. - */ -class ForeachDataWriter[T]( - writer: ForeachWriter[T], - rowConverter: InternalRow => T, - partitionId: Int, - epochId: Long) - extends DataWriter[InternalRow] { - - // If open returns false, we should skip writing rows. - private val opened = writer.open(partitionId, epochId) - - override def write(record: InternalRow): Unit = { - if (!opened) return - - try { - writer.process(rowConverter(record)) - } catch { - case t: Throwable => - writer.close(t) - throw t - } - } - - override def commit(): WriterCommitMessage = { - writer.close(null) - ForeachWriterCommitMessage - } - - override def abort(): Unit = {} -} - -/** - * An empty [[WriterCommitMessage]]. [[ForeachWriter]] implementations have no global coordination. - */ -case object ForeachWriterCommitMessage extends WriterCommitMessage http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala new file mode 100644 index 0000000..9f88416 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} + +/** + * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped + * streaming write support. + */ +class MicroBatchWritSupport(eppchId: Long, val writeSupport: StreamingWriteSupport) + extends BatchWriteSupport { + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.commit(eppchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + writeSupport.abort(eppchId, messages) + } + + override def createBatchWriterFactory(): DataWriterFactory = { + new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory()) + } +} + +class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory) + extends DataWriterFactory { + + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + streamingWriterFactory.createWriter(partitionId, taskId, epochId) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala deleted file mode 100644 index 2d43a7b..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter - -/** - * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped - * streaming writer. - */ -class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter { - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writer.commit(batchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) - - override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() -} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index f26e11d..ac3c71c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,17 +21,18 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[DataSourceWriter]] on the driver. + * to a [[BatchWriteSupport]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { - override def createDataWriter( +case object PackedRowWriterFactory extends StreamingDataWriterFactory { + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala new file mode 100644 index 0000000..90680ea --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} + +// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. +trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { + + override def latestOffset(): Offset = { + throw new IllegalAccessException( + "latestOffset should not be called for RateControlMicroBatchReadSupport") + } + + def latestOffset(start: Offset): Offset +} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala new file mode 100644 index 0000000..f536404 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.TimeUnit + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReadSupport with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. 
+ assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + override def initialOffset(): Offset = LongOffset(0L) + + override def latestOffset(): Offset = { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def fullSchema(): StructType = SCHEMA + + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startSeconds = sc.start.asInstanceOf[LongOffset].offset + val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return Array.empty + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchInputPartition( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + }.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + RateStreamMicroBatchReaderFactory + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +case class RateStreamMicroBatchInputPartition( + partitionId: Int, + numPartitions: Int, + rangeStart: 
Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends InputPartition + +object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] + new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, + p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) + } +} + +class RateStreamMicroBatchPartitionReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends PartitionReader[InternalRow] { + private var count: Long = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): InternalRow = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) + } + + override def close(): Unit = {} +} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala deleted file mode 100644 index 9e0d954..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.Optional -import java.util.concurrent.TimeUnit - -import scala.collection.JavaConverters._ - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.{Row, SparkSession} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ManualClock, SystemClock} - -class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReader with Logging { - import RateStreamProvider._ - - private[sources] val clock = { - // The option to use a manual clock is provided only for unit testing purposes. - if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock - } - - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private[sources] val creationTimeMs = { - val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) - require(session.isDefined) - - val metadataLog = - new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. 
- assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - @volatile private var lastTimeMs: Long = creationTimeMs - - private var start: LongOffset = _ - private var end: LongOffset = _ - - override def readSchema(): StructType = SCHEMA - - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { - this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] - this.end = end.orElse { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - }.asInstanceOf[LongOffset] - } - - override def getStartOffset(): Offset = { - if (start == null) throw new IllegalStateException("start offset not set") - start - } - override def getEndOffset(): Offset = { - if (end == null) throw new IllegalStateException("end offset not set") - end - } - - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) - } - - override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { - val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) - val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. 
Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return List.empty.asJava - } - - val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } - - (0 until numPartitions).map { p => - new RateStreamMicroBatchInputPartition( - p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - : InputPartition[InternalRow] - }.toList.asJava - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" -} - -class RateStreamMicroBatchInputPartition( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition[InternalRow] { - - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, - localStartTimeMs, relativeMsPerValue) -} - -class RateStreamMicroBatchInputPartitionReader( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] { - private var count: Long = 0 - - override def next(): Boolean = { - rangeStart + partitionId + numPartitions * count < rangeEnd - } - - override def get(): InternalRow = { - val currValue = rangeStart + partitionId + numPartitions * count - count += 1 - val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) - } - - override def close(): Unit = {} -} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6bdd492..6942dfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,14 +17,11 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional - import org.apache.spark.network.util.JavaUtils 
-import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types._ /** @@ -42,13 +39,12 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -74,17 +70,14 @@ class RateStreamProvider extends DataSourceV2 } } - if (schema.isPresent) { - throw new AnalysisException("The rate source does not support a user-specified schema.") - } - - new RateStreamMicroBatchReader(options, checkpointLocation) + new RateStreamMicroBatchReadSupport(options, checkpointLocation) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) + options: DataSourceOptions): ContinuousReadSupport = { + new RateStreamContinuousReadSupport(options) + } override def shortName(): String = "rate" } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 2a5d21f..2509450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamWriteSupport} +import org.apache.spark.sql.sources.v2.{CustomMetrics, DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamWriter, SupportsCustomWriterMetrics} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport, SupportsCustomWriterMetrics} import 
org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -45,13 +45,15 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { - override def createStreamWriter( +class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider + with MemorySinkBase with Logging { + + override def createStreamingWriteSupport( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamWriter = { - new MemoryStreamWriter(this, mode, schema) + options: DataSourceOptions): StreamingWriteSupport = { + new MemoryStreamingWriteSupport(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -132,35 +134,15 @@ class MemoryV2CustomMetrics(sink: MemorySinkV2) extends CustomMetrics { override def json(): String = Serialization.write(Map("numRows" -> sink.numRows)) } -class MemoryWriter(sink: MemorySinkV2, batchId: Long, outputMode: OutputMode, schema: StructType) - extends DataSourceWriter with SupportsCustomWriterMetrics with Logging { - - private val memoryV2CustomMetrics = new MemoryV2CustomMetrics(sink) - - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) - - def commit(messages: Array[WriterCommitMessage]): Unit = { - val newRows = messages.flatMap { - case message: MemoryWriterCommitMessage => message.data - } - sink.write(batchId, outputMode, newRows) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = { - // Don't accept any of the new input. - } - - override def getCustomMetrics: CustomMetrics = { - memoryV2CustomMetrics - } -} - -class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamWriter with SupportsCustomWriterMetrics { +class MemoryStreamingWriteSupport( + val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamingWriteSupport with SupportsCustomWriterMetrics { private val customMemoryV2Metrics = new MemoryV2CustomMetrics(sink) - override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) + override def createStreamingWriterFactory: MemoryWriterFactory = { + MemoryWriterFactory(outputMode, schema) + } override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -173,19 +155,23 @@ class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: // Don't accept any of the new input. 
} - override def getCustomMetrics: CustomMetrics = { - customMemoryV2Metrics - } + override def getCustomMetrics: CustomMetrics = customMemoryV2Metrics } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) - extends DataWriterFactory[InternalRow] { + extends DataWriterFactory with StreamingDataWriterFactory { - override def createDataWriter( + override def createWriter( + partitionId: Int, + taskId: Long): DataWriter[InternalRow] = { + new MemoryDataWriter(partitionId, outputMode, schema) + } + + override def createWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) + createWriter(partitionId, taskId) } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index 874c479..b2a573e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.text.SimpleDateFormat -import java.util.{Calendar, List => JList, Locale, Optional} +import java.util.{Calendar, Locale} import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy -import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} @@ -32,16 +31,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.LongOffset -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader +import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String -// Shared object for micro-batch and continuous reader object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -50,14 +48,12 @@ object TextSocketReader { } /** - * A MicroBatchReader that reads text lines through a TCP 
socket, designed only for tutorials and - * debugging. This MicroBatchReader will *not* work in production applications due to multiple - * reasons, including no support for fault recovery. + * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials + * and debugging. This MicroBatchReadSupport will *not* work in production applications due to + * multiple reasons, including no support for fault recovery. */ -class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { - - private var startOffset: Offset = _ - private var endOffset: Offset = _ +class TextSocketMicroBatchReadSupport(options: DataSourceOptions) + extends MicroBatchReadSupport with Logging { private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -103,7 +99,7 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReader.this.synchronized { + TextSocketMicroBatchReadSupport.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -120,24 +116,15 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR readThread.start() } - override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { - startOffset = start.orElse(LongOffset(-1L)) - endOffset = end.orElse(currentOffset) - } + override def initialOffset(): Offset = LongOffset(-1L) - override def getStartOffset(): Offset = { - Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) - } - - override def getEndOffset(): Offset = { - Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) - } + override def latestOffset(): Offset = currentOffset override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def readSchema(): StructType = { + override def fullSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -145,12 +132,14 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR } } - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - assert(startOffset != null && endOffset != null, - "start offset and end offset should already be set before create read tasks.") + override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { + new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) + } - val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 - val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val sc = config.asInstanceOf[SimpleStreamingScanConfig] + val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 + val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -172,26 +161,29 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR slices(idx % numPartitions).append(r) } - (0 until numPartitions).map { i => - val slice = slices(i) - new InputPartition[InternalRow] { - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new InputPartitionReader[InternalRow] { - private 
var currentIdx = -1 + slices.map(TextSocketInputPartition) + } - override def next(): Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + new PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val slice = partition.asInstanceOf[TextSocketInputPartition].slice + new PartitionReader[InternalRow] { + private var currentIdx = -1 - override def get(): InternalRow = { - InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) - } + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def close(): Unit = {} + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) } + + override def close(): Unit = {} + } } - }.toList.asJava + } } override def commit(end: Offset): Unit = synchronized { @@ -227,8 +219,11 @@ class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchR override def toString: String = s"TextSocketV2[host: $host, port: $port]" } +case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition + class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister with Logging { + with MicroBatchReadSupportProvider with ContinuousReadSupportProvider + with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -248,27 +243,18 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReader( - schema: Optional[StructType], + override def createMicroBatchReadSupport( checkpointLocation: String, - options: DataSourceOptions): MicroBatchReader = { + options: DataSourceOptions): MicroBatchReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } - - new TextSocketMicroBatchReader(options) + new TextSocketMicroBatchReadSupport(options) } - override def createContinuousReader( - schema: Optional[StructType], + override def createContinuousReadSupport( checkpointLocation: String, - options: DataSourceOptions): ContinuousReader = { + options: DataSourceOptions): ContinuousReadSupport = { checkParameters(options) - if (schema.isPresent) { - throw new AnalysisException("The socket source does not support a user-specified schema.") - } - new TextSocketContinuousReader(options) + new TextSocketContinuousReadSupport(options) } /** String that represents the format that this data source provider uses. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index ef8dc3a..39e9e1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.{Locale, Optional} +import java.util.Locale import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,19 +172,21 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupport => - var tempReader: MicroBatchReader = null + case s: MicroBatchReadSupportProvider => + var tempReadSupport: MicroBatchReadSupport = null val schema = try { - tempReader = s.createMicroBatchReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) - tempReader.readSchema() + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createMicroBatchReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReader != null) { - tempReader.stop() - tempReader = null + if (tempReadSupport != null) { + tempReadSupport.stop() + tempReadSupport = null } } Dataset.ofRows( @@ -192,16 +194,28 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupport => - val tempReader = s.createContinuousReader( - Optional.ofNullable(userSpecifiedSchema.orNull), - Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, - options) + case s: ContinuousReadSupportProvider => + var tempReadSupport: ContinuousReadSupport = null + val schema = try { + val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath + tempReadSupport = if (userSpecifiedSchema.isDefined) { + s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) + } else { + s.createContinuousReadSupport(tmpCheckpointPath, options) + } + tempReadSupport.fullSchema() + } finally { + // Stop tempReader to avoid side-effect thing + if (tempReadSupport != 
null) { + tempReadSupport.stop() + tempReadSupport = null + } + } Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) + schema.toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 3b9a56f..7866e4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.{InterfaceStability, Since} +import org.apache.spark.annotation.InterfaceStability import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -270,7 +270,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -299,7 +299,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamingWriteSupportProvider + if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 25bb052..cd52d99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamWriteSupport +import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -256,7 +256,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => + case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index e4cead9..5602310 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,29 +24,71 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import 
org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, - SupportsPushDownFilters { + public class ReadSupport extends JavaSimpleReadSupport { + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new AdvancedScanConfigBuilder(); + } + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; + List<InputPartition> res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; + return new AdvancedReaderFactory(requiredSchema); + } + } + + public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, + SupportsPushDownFilters, SupportsPushDownRequiredColumns { // Exposed for testing. 
public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override @@ -79,79 +121,54 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { } @Override - public List<InputPartition<InternalRow>> planInputPartitions() { - List<InputPartition<InternalRow>> res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 4) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); - res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); - } else if (lowerBound < 9) { - res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); - } - - return res; + public ScanConfig build() { + return this; } } - static class JavaAdvancedInputPartition implements InputPartition<InternalRow>, - InputPartitionReader<InternalRow> { - private int start; - private int end; - private StructType requiredSchema; + static class AdvancedReaderFactory implements PartitionReaderFactory { + StructType requiredSchema; - JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { - this.start = start; - this.end = end; + AdvancedReaderFactory(StructType requiredSchema) { this.requiredSchema = requiredSchema; } @Override - public InputPartitionReader<InternalRow> createPartitionReader() { - return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); - } - - @Override - public boolean next() { - start += 1; - return start < end; - } + public PartitionReader<InternalRow> createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader<InternalRow>() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } - @Override - public InternalRow get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = start; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -start; + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); } - } - return new GenericInternalRow(values); - } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport 
createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java deleted file mode 100644 index 97d6176..0000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader, SupportsScanColumnarBatch { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() { - return java.util.Arrays.asList( - new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); - } - } - - static class JavaBatchInputPartition - implements InputPartition<ColumnarBatch>, InputPartitionReader<ColumnarBatch> { - private int start; - private int end; - - private static final int BATCH_SIZE = 20; - - private OnHeapColumnVector i; - private OnHeapColumnVector j; - private ColumnarBatch batch; - - JavaBatchInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader<ColumnarBatch> createPartitionReader() { - this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - this.batch = new ColumnarBatch(vectors); - return this; - } - - @Override - public boolean next() { - i.reset(); - j.reset(); - int count = 0; - while (start < end && count < BATCH_SIZE) { - i.putInt(count, start); - j.putInt(count, -start); - start += 1; - 
count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - } - - - @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java new file mode 100644 index 0000000..28a9330 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + + class ReadSupport extends JavaSimpleReadSupport { + + @Override + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 50); + partitions[1] = new JavaRangeInputPartition(50, 90); + return partitions; + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new ColumnarReaderFactory(); + } + } + + static class ColumnarReaderFactory implements PartitionReaderFactory { + private static final int BATCH_SIZE = 20; + + @Override + public boolean supportColumnarReads(InputPartition partition) { + return true; + } + + @Override + public PartitionReader<InternalRow> createReader(InputPartition partition) { + throw new UnsupportedOperationException(""); + } + + @Override + public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + OnHeapColumnVector j = 
new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + ColumnarBatch batch = new ColumnarBatch(vectors); + + return new PartitionReader<ColumnarBatch>() { + private int current = p.start; + + @Override + public boolean next() throws IOException { + i.reset(); + j.reset(); + int count = 0; + while (current < p.end && count < BATCH_SIZE) { + i.putInt(count, current); + j.putInt(count, -current); + current += 1; + count += 1; + } + + if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + }; + } + } + + @Override + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 2d21324..18a11dd 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,38 +19,34 @@ package test.org.apache.spark.sql.sources.v2; import java.io.IOException; import java.util.Arrays; -import java.util.List; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.*; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; -import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { +public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader implements DataSourceReader, SupportsReportPartitioning { - private final StructType schema = new StructType().add("a", "int").add("b", "int"); + class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { @Override - public StructType readSchema() { - return schema; + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); + partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); + return partitions; } @Override - public List<InputPartition<InternalRow>> planInputPartitions() { - return java.util.Arrays.asList( - new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), - new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new SpecificReaderFactory(); } @Override - public 
Partitioning outputPartitioning() { + public Partitioning outputPartitioning(ScanConfig config) { return new MyPartitioning(); } } @@ -66,50 +62,53 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("a"); + return Arrays.asList(clusteredCols).contains("i"); } return false; } } - static class SpecificInputPartition implements InputPartition<InternalRow>, - InputPartitionReader<InternalRow> { - - private int[] i; - private int[] j; - private int current = -1; + static class SpecificInputPartition implements InputPartition { + int[] i; + int[] j; SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; } + } - @Override - public boolean next() throws IOException { - current += 1; - return current < i.length; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {i[current], j[current]}); - } - - @Override - public void close() throws IOException { - - } + static class SpecificReaderFactory implements PartitionReaderFactory { @Override - public InputPartitionReader<InternalRow> createPartitionReader() { - return this; + public PartitionReader<InternalRow> createReader(InputPartition partition) { + SpecificInputPartition p = (SpecificInputPartition) partition; + return new PartitionReader<InternalRow>() { + private int current = -1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); + } + + @Override + public void close() throws IOException { + + } + }; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 6fd6a44..cc9ac04 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,43 +17,39 @@ package test.org.apache.spark.sql.sources.v2; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupport { +public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { - class Reader 
implements DataSourceReader { + class ReadSupport extends JavaSimpleReadSupport { private final StructType schema; - Reader(StructType schema) { + ReadSupport(StructType schema) { this.schema = schema; } @Override - public StructType readSchema() { + public StructType fullSchema() { return schema; } @Override - public List<InputPartition<InternalRow>> planInputPartitions() { - return java.util.Collections.emptyList(); + public InputPartition[] planInputPartitions(ScanConfig config) { + return new InputPartition[0]; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { throw new IllegalArgumentException("requires a user-supplied schema"); } @Override - public DataSourceReader createReader(StructType schema, DataSourceOptions options) { - return new Reader(schema); + public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { + return new ReadSupport(schema); } }
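The test-source hunks above all follow the same migrated read path: a source implements BatchReadSupportProvider, hands back a BatchReadSupport whose newScanConfigBuilder(), planInputPartitions(ScanConfig) and createReaderFactory(ScanConfig) replace the old DataSourceReader/InputPartition pairing, and per-partition row production moves into PartitionReader instances built by a PartitionReaderFactory. The sketch below pulls those pieces into one self-contained class. It is only an approximation assembled from the signatures visible in this patch: the JavaSimpleReadSupport base class used by the tests is not included in this message, the sketch assumes BatchReadSupport and ScanConfig are interfaces exposing exactly the methods exercised above, and MinimalRangeSource, NoopScanConfigBuilder, RangePartition and RangeReaderFactory are illustrative names rather than anything added by the commit.

import java.io.IOException;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.sources.v2.BatchReadSupportProvider;
import org.apache.spark.sql.sources.v2.DataSourceOptions;
import org.apache.spark.sql.sources.v2.DataSourceV2;
import org.apache.spark.sql.sources.v2.reader.*;
import org.apache.spark.sql.types.StructType;

// Hypothetical example, not part of the patch: a fixed-range source with no pushdown,
// emitting rows (i, -i) for i in [0, 10) across two partitions.
public class MinimalRangeSource implements DataSourceV2, BatchReadSupportProvider {

  // Fixed schema (i int, j int), mirroring the test sources above.
  private static final StructType SCHEMA = new StructType().add("i", "int").add("j", "int");

  // A ScanConfigBuilder with nothing to push down: it is its own ScanConfig.
  static class NoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig {
    @Override
    public StructType readSchema() {
      return SCHEMA;
    }

    @Override
    public ScanConfig build() {
      return this;
    }
  }

  // Serializable description of one partition's work: the half-open range [start, end).
  static class RangePartition implements InputPartition {
    final int start;
    final int end;

    RangePartition(int start, int end) {
      this.start = start;
      this.end = end;
    }
  }

  // Builds one PartitionReader per RangePartition; the reader does the actual row production.
  static class RangeReaderFactory implements PartitionReaderFactory {
    @Override
    public PartitionReader<InternalRow> createReader(InputPartition partition) {
      RangePartition p = (RangePartition) partition;
      return new PartitionReader<InternalRow>() {
        private int current = p.start - 1;

        @Override
        public boolean next() throws IOException {
          current += 1;
          return current < p.end;
        }

        @Override
        public InternalRow get() {
          // Row (i, -i), as in the range-based test sources in this patch.
          return new GenericInternalRow(new Object[] {current, -current});
        }

        @Override
        public void close() throws IOException {
        }
      };
    }
  }

  class ReadSupport implements BatchReadSupport {
    @Override
    public StructType fullSchema() {
      return SCHEMA;
    }

    @Override
    public ScanConfigBuilder newScanConfigBuilder() {
      return new NoopScanConfigBuilder();
    }

    @Override
    public InputPartition[] planInputPartitions(ScanConfig config) {
      return new InputPartition[] {new RangePartition(0, 5), new RangePartition(5, 10)};
    }

    @Override
    public PartitionReaderFactory createReaderFactory(ScanConfig config) {
      return new RangeReaderFactory();
    }
  }

  @Override
  public BatchReadSupport createBatchReadSupport(DataSourceOptions options) {
    return new ReadSupport();
  }
}

As in the patch's own test sources, the ScanConfigBuilder doubles as the ScanConfig when there is no pushdown to record, and the InputPartition carries only the serializable description of the split while the PartitionReaderFactory turns it into a PartitionReader on the executor side.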