http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala deleted file mode 100644 index 9f88416..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWritSupport.scala +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} - -/** - * A [[BatchWriteSupport]] used to hook V2 stream writers into a microbatch plan. It implements - * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped - * streaming write support. - */ -class MicroBatchWritSupport(eppchId: Long, val writeSupport: StreamingWriteSupport) - extends BatchWriteSupport { - - override def commit(messages: Array[WriterCommitMessage]): Unit = { - writeSupport.commit(eppchId, messages) - } - - override def abort(messages: Array[WriterCommitMessage]): Unit = { - writeSupport.abort(eppchId, messages) - } - - override def createBatchWriterFactory(): DataWriterFactory = { - new MicroBatchWriterFactory(eppchId, writeSupport.createStreamingWriterFactory()) - } -} - -class MicroBatchWriterFactory(epochId: Long, streamingWriterFactory: StreamingDataWriterFactory) - extends DataWriterFactory { - - override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { - streamingWriterFactory.createWriter(partitionId, taskId, epochId) - } -}
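[Editor's note] The files in this part of the diff swap the short-lived BatchWriteSupport/StreamingWriteSupport abstraction back to the DataSourceWriter/StreamWriter one: the deleted MicroBatchWritSupport above is replaced by the MicroBatchWriter below, which pins the batch (epoch) ID at construction and forwards commit/abort calls to the wrapped streaming writer. The following dependency-free Scala sketch illustrates that adapter pattern only; the trait names are simplified stand-ins for Spark's StreamWriter and DataSourceWriter interfaces, not the real API.

// Sketch of the "pin the epoch at construction" adapter used by MicroBatchWriter.
object MicroBatchWrapperSketch {

  trait CommitMessage

  // Epoch-aware writer, analogous to StreamWriter in the diff below.
  trait EpochAwareWriter {
    def commit(epochId: Long, messages: Array[CommitMessage]): Unit
    def abort(epochId: Long, messages: Array[CommitMessage]): Unit
  }

  // Batch-style writer, analogous to DataSourceWriter: no epoch in its signature.
  trait BatchStyleWriter {
    def commit(messages: Array[CommitMessage]): Unit
    def abort(messages: Array[CommitMessage]): Unit
  }

  // The adapter: the epoch is fixed when the wrapper is built, so a micro-batch
  // engine can drive an epoch-aware sink through the non-streaming interface.
  class PerBatchWriter(epochId: Long, underlying: EpochAwareWriter) extends BatchStyleWriter {
    override def commit(messages: Array[CommitMessage]): Unit =
      underlying.commit(epochId, messages)
    override def abort(messages: Array[CommitMessage]): Unit =
      underlying.abort(epochId, messages)
  }

  def main(args: Array[String]): Unit = {
    // A toy epoch-aware writer that just logs what it is asked to do.
    val log = new EpochAwareWriter {
      def commit(epochId: Long, messages: Array[CommitMessage]): Unit =
        println(s"committed epoch $epochId with ${messages.length} task messages")
      def abort(epochId: Long, messages: Array[CommitMessage]): Unit =
        println(s"aborted epoch $epochId")
    }
    // One wrapper per micro-batch; here batch 42 is committed with no task messages.
    new PerBatchWriter(42L, log).commit(Array.empty[CommitMessage])
  }
}
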
http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala new file mode 100644 index 0000000..2d43a7b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWriter.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriterFactory, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter + +/** + * A [[DataSourceWriter]] used to hook V2 stream writers into a microbatch plan. It implements + * the non-streaming interface, forwarding the batch ID determined at construction to a wrapped + * streaming writer. 
+ */ +class MicroBatchWriter(batchId: Long, val writer: StreamWriter) extends DataSourceWriter { + override def commit(messages: Array[WriterCommitMessage]): Unit = { + writer.commit(batchId, messages) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = writer.abort(batchId, messages) + + override def createWriterFactory(): DataWriterFactory[InternalRow] = writer.createWriterFactory() +} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala index ac3c71c..f26e11d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/PackedRowWriterFactory.scala @@ -21,18 +21,17 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.sources.v2.writer.{BatchWriteSupport, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingDataWriterFactory +import org.apache.spark.sql.sources.v2.writer.{DataSourceWriter, DataWriter, DataWriterFactory, WriterCommitMessage} /** * A simple [[DataWriterFactory]] whose tasks just pack rows into the commit message for delivery - * to a [[BatchWriteSupport]] on the driver. + * to a [[DataSourceWriter]] on the driver. * * Note that, because it sends all rows to the driver, this factory will generally be unsuitable * for production-quality sinks. It's intended for use in tests. */ -case object PackedRowWriterFactory extends StreamingDataWriterFactory { - override def createWriter( +case object PackedRowWriterFactory extends DataWriterFactory[InternalRow] { + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala deleted file mode 100644 index 90680ea..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateControlMicroBatchReadSupport.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} - -// A special `MicroBatchReadSupport` that can get latestOffset with a start offset. -trait RateControlMicroBatchReadSupport extends MicroBatchReadSupport { - - override def latestOffset(): Offset = { - throw new IllegalAccessException( - "latestOffset should not be called for RateControlMicroBatchReadSupport") - } - - def latestOffset(start: Offset): Offset -} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala deleted file mode 100644 index f536404..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReadSupport.scala +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io._ -import java.nio.charset.StandardCharsets -import java.util.concurrent.TimeUnit - -import org.apache.commons.io.IOUtils - -import org.apache.spark.internal.Logging -import org.apache.spark.network.util.JavaUtils -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReadSupport, Offset} -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{ManualClock, SystemClock} - -class RateStreamMicroBatchReadSupport(options: DataSourceOptions, checkpointLocation: String) - extends MicroBatchReadSupport with Logging { - import RateStreamProvider._ - - private[sources] val clock = { - // The option to use a manual clock is provided only for unit testing purposes. 
- if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock - } - - private val rowsPerSecond = - options.get(ROWS_PER_SECOND).orElse("1").toLong - - private val rampUpTimeSeconds = - Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) - .map(JavaUtils.timeStringAsSec(_)) - .getOrElse(0L) - - private val maxSeconds = Long.MaxValue / rowsPerSecond - - if (rampUpTimeSeconds > maxSeconds) { - throw new ArithmeticException( - s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + - s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") - } - - private[sources] val creationTimeMs = { - val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) - require(session.isDefined) - - val metadataLog = - new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { - override def serialize(metadata: LongOffset, out: OutputStream): Unit = { - val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) - writer.write("v" + VERSION + "\n") - writer.write(metadata.json) - writer.flush - } - - override def deserialize(in: InputStream): LongOffset = { - val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) - // HDFSMetadataLog guarantees that it never creates a partial file. - assert(content.length != 0) - if (content(0) == 'v') { - val indexOfNewLine = content.indexOf("\n") - if (indexOfNewLine > 0) { - parseVersion(content.substring(0, indexOfNewLine), VERSION) - LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } else { - throw new IllegalStateException( - s"Log file was malformed: failed to detect the log file version line.") - } - } - } - - metadataLog.get(0).getOrElse { - val offset = LongOffset(clock.getTimeMillis()) - metadataLog.add(0, offset) - logInfo(s"Start time: $offset") - offset - }.offset - } - - @volatile private var lastTimeMs: Long = creationTimeMs - - override def initialOffset(): Offset = LongOffset(0L) - - override def latestOffset(): Offset = { - val now = clock.getTimeMillis() - if (lastTimeMs < now) { - lastTimeMs = now - } - LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) - } - - override def deserializeOffset(json: String): Offset = { - LongOffset(json.toLong) - } - - override def fullSchema(): StructType = SCHEMA - - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } - - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startSeconds = sc.start.asInstanceOf[LongOffset].offset - val endSeconds = sc.end.get.asInstanceOf[LongOffset].offset - assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") - if (endSeconds > maxSeconds) { - throw new ArithmeticException("Integer overflow. 
Max offset with " + - s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") - } - // Fix "lastTimeMs" for recovery - if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { - lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs - } - val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) - val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) - logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + - s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") - - if (rangeStart == rangeEnd) { - return Array.empty - } - - val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) - val relativeMsPerValue = - TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) - val numPartitions = { - val activeSession = SparkSession.getActiveSession - require(activeSession.isDefined) - Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) - .map(_.toInt) - .getOrElse(activeSession.get.sparkContext.defaultParallelism) - } - - (0 until numPartitions).map { p => - new RateStreamMicroBatchInputPartition( - p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) - }.toArray - } - - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - RateStreamMicroBatchReaderFactory - } - - override def commit(end: Offset): Unit = {} - - override def stop(): Unit = {} - - override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + - s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" -} - -case class RateStreamMicroBatchInputPartition( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends InputPartition - -object RateStreamMicroBatchReaderFactory extends PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val p = partition.asInstanceOf[RateStreamMicroBatchInputPartition] - new RateStreamMicroBatchPartitionReader(p.partitionId, p.numPartitions, p.rangeStart, - p.rangeEnd, p.localStartTimeMs, p.relativeMsPerValue) - } -} - -class RateStreamMicroBatchPartitionReader( - partitionId: Int, - numPartitions: Int, - rangeStart: Long, - rangeEnd: Long, - localStartTimeMs: Long, - relativeMsPerValue: Double) extends PartitionReader[InternalRow] { - private var count: Long = 0 - - override def next(): Boolean = { - rangeStart + partitionId + numPartitions * count < rangeEnd - } - - override def get(): InternalRow = { - val currValue = rangeStart + partitionId + numPartitions * count - count += 1 - val relative = math.round((currValue - rangeStart) * relativeMsPerValue) - InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) - } - - override def close(): Unit = {} -} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala new file mode 100644 index 0000000..9e0d954 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchReader.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io._ +import java.nio.charset.StandardCharsets +import java.util.Optional +import java.util.concurrent.TimeUnit + +import scala.collection.JavaConverters._ + +import org.apache.commons.io.IOUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.{ManualClock, SystemClock} + +class RateStreamMicroBatchReader(options: DataSourceOptions, checkpointLocation: String) + extends MicroBatchReader with Logging { + import RateStreamProvider._ + + private[sources] val clock = { + // The option to use a manual clock is provided only for unit testing purposes. + if (options.getBoolean("useManualClock", false)) new ManualClock else new SystemClock + } + + private val rowsPerSecond = + options.get(ROWS_PER_SECOND).orElse("1").toLong + + private val rampUpTimeSeconds = + Option(options.get(RAMP_UP_TIME).orElse(null.asInstanceOf[String])) + .map(JavaUtils.timeStringAsSec(_)) + .getOrElse(0L) + + private val maxSeconds = Long.MaxValue / rowsPerSecond + + if (rampUpTimeSeconds > maxSeconds) { + throw new ArithmeticException( + s"Integer overflow. Max offset with $rowsPerSecond rowsPerSecond" + + s" is $maxSeconds, but 'rampUpTimeSeconds' is $rampUpTimeSeconds.") + } + + private[sources] val creationTimeMs = { + val session = SparkSession.getActiveSession.orElse(SparkSession.getDefaultSession) + require(session.isDefined) + + val metadataLog = + new HDFSMetadataLog[LongOffset](session.get, checkpointLocation) { + override def serialize(metadata: LongOffset, out: OutputStream): Unit = { + val writer = new BufferedWriter(new OutputStreamWriter(out, StandardCharsets.UTF_8)) + writer.write("v" + VERSION + "\n") + writer.write(metadata.json) + writer.flush + } + + override def deserialize(in: InputStream): LongOffset = { + val content = IOUtils.toString(new InputStreamReader(in, StandardCharsets.UTF_8)) + // HDFSMetadataLog guarantees that it never creates a partial file. 
+ assert(content.length != 0) + if (content(0) == 'v') { + val indexOfNewLine = content.indexOf("\n") + if (indexOfNewLine > 0) { + parseVersion(content.substring(0, indexOfNewLine), VERSION) + LongOffset(SerializedOffset(content.substring(indexOfNewLine + 1))) + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } else { + throw new IllegalStateException( + s"Log file was malformed: failed to detect the log file version line.") + } + } + } + + metadataLog.get(0).getOrElse { + val offset = LongOffset(clock.getTimeMillis()) + metadataLog.add(0, offset) + logInfo(s"Start time: $offset") + offset + }.offset + } + + @volatile private var lastTimeMs: Long = creationTimeMs + + private var start: LongOffset = _ + private var end: LongOffset = _ + + override def readSchema(): StructType = SCHEMA + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + this.start = start.orElse(LongOffset(0L)).asInstanceOf[LongOffset] + this.end = end.orElse { + val now = clock.getTimeMillis() + if (lastTimeMs < now) { + lastTimeMs = now + } + LongOffset(TimeUnit.MILLISECONDS.toSeconds(lastTimeMs - creationTimeMs)) + }.asInstanceOf[LongOffset] + } + + override def getStartOffset(): Offset = { + if (start == null) throw new IllegalStateException("start offset not set") + start + } + override def getEndOffset(): Offset = { + if (end == null) throw new IllegalStateException("end offset not set") + end + } + + override def deserializeOffset(json: String): Offset = { + LongOffset(json.toLong) + } + + override def planInputPartitions(): java.util.List[InputPartition[InternalRow]] = { + val startSeconds = LongOffset.convert(start).map(_.offset).getOrElse(0L) + val endSeconds = LongOffset.convert(end).map(_.offset).getOrElse(0L) + assert(startSeconds <= endSeconds, s"startSeconds($startSeconds) > endSeconds($endSeconds)") + if (endSeconds > maxSeconds) { + throw new ArithmeticException("Integer overflow. 
Max offset with " + + s"$rowsPerSecond rowsPerSecond is $maxSeconds, but it's $endSeconds now.") + } + // Fix "lastTimeMs" for recovery + if (lastTimeMs < TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs) { + lastTimeMs = TimeUnit.SECONDS.toMillis(endSeconds) + creationTimeMs + } + val rangeStart = valueAtSecond(startSeconds, rowsPerSecond, rampUpTimeSeconds) + val rangeEnd = valueAtSecond(endSeconds, rowsPerSecond, rampUpTimeSeconds) + logDebug(s"startSeconds: $startSeconds, endSeconds: $endSeconds, " + + s"rangeStart: $rangeStart, rangeEnd: $rangeEnd") + + if (rangeStart == rangeEnd) { + return List.empty.asJava + } + + val localStartTimeMs = creationTimeMs + TimeUnit.SECONDS.toMillis(startSeconds) + val relativeMsPerValue = + TimeUnit.SECONDS.toMillis(endSeconds - startSeconds).toDouble / (rangeEnd - rangeStart) + val numPartitions = { + val activeSession = SparkSession.getActiveSession + require(activeSession.isDefined) + Option(options.get(NUM_PARTITIONS).orElse(null.asInstanceOf[String])) + .map(_.toInt) + .getOrElse(activeSession.get.sparkContext.defaultParallelism) + } + + (0 until numPartitions).map { p => + new RateStreamMicroBatchInputPartition( + p, numPartitions, rangeStart, rangeEnd, localStartTimeMs, relativeMsPerValue) + : InputPartition[InternalRow] + }.toList.asJava + } + + override def commit(end: Offset): Unit = {} + + override def stop(): Unit = {} + + override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + + s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" +} + +class RateStreamMicroBatchInputPartition( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends InputPartition[InternalRow] { + + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new RateStreamMicroBatchInputPartitionReader(partitionId, numPartitions, rangeStart, rangeEnd, + localStartTimeMs, relativeMsPerValue) +} + +class RateStreamMicroBatchInputPartitionReader( + partitionId: Int, + numPartitions: Int, + rangeStart: Long, + rangeEnd: Long, + localStartTimeMs: Long, + relativeMsPerValue: Double) extends InputPartitionReader[InternalRow] { + private var count: Long = 0 + + override def next(): Boolean = { + rangeStart + partitionId + numPartitions * count < rangeEnd + } + + override def get(): InternalRow = { + val currValue = rangeStart + partitionId + numPartitions * count + count += 1 + val relative = math.round((currValue - rangeStart) * relativeMsPerValue) + InternalRow(DateTimeUtils.fromMillis(relative + localStartTimeMs), currValue) + } + + override def close(): Unit = {} +} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 6942dfb..6bdd492 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util.Optional + import org.apache.spark.network.util.JavaUtils 
-import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReadSupport +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousReader import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader} import org.apache.spark.sql.types._ /** @@ -39,12 +42,13 @@ import org.apache.spark.sql.types._ * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ class RateStreamProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider with DataSourceRegister { + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister { import RateStreamProvider._ - override def createMicroBatchReadSupport( + override def createMicroBatchReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { + options: DataSourceOptions): MicroBatchReader = { if (options.get(ROWS_PER_SECOND).isPresent) { val rowsPerSecond = options.get(ROWS_PER_SECOND).get().toLong if (rowsPerSecond <= 0) { @@ -70,14 +74,17 @@ class RateStreamProvider extends DataSourceV2 } } - new RateStreamMicroBatchReadSupport(options, checkpointLocation) + if (schema.isPresent) { + throw new AnalysisException("The rate source does not support a user-specified schema.") + } + + new RateStreamMicroBatchReader(options, checkpointLocation) } - override def createContinuousReadSupport( + override def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { - new RateStreamContinuousReadSupport(options) - } + options: DataSourceOptions): ContinuousReader = new RateStreamContinuousReader(options) override def shortName(): String = "rate" } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index c50dc7b..cb76e86 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamWriteSupport} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ 
-42,15 +42,13 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider - with MemorySinkBase with Logging { - - override def createStreamingWriteSupport( +class MemorySinkV2 extends DataSourceV2 with StreamWriteSupport with MemorySinkBase with Logging { + override def createStreamWriter( queryId: String, schema: StructType, mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new MemoryStreamingWriteSupport(this, mode, schema) + options: DataSourceOptions): StreamWriter = { + new MemoryStreamWriter(this, mode, schema) } private case class AddedData(batchId: Long, data: Array[Row]) @@ -122,13 +120,10 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryStreamingWriteSupport( - val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport { +class MemoryStreamWriter(val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + extends StreamWriter { - override def createStreamingWriterFactory: MemoryWriterFactory = { - MemoryWriterFactory(outputMode, schema) - } + override def createWriterFactory: MemoryWriterFactory = MemoryWriterFactory(outputMode, schema) override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { @@ -143,19 +138,13 @@ class MemoryStreamingWriteSupport( } case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) - extends DataWriterFactory with StreamingDataWriterFactory { + extends DataWriterFactory[InternalRow] { - override def createWriter( - partitionId: Int, - taskId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) - } - - override def createWriter( + override def createDataWriter( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - createWriter(partitionId, taskId) + new MemoryDataWriter(partitionId, outputMode, schema) } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala index b2a573e..874c479 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/socket.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql.execution.streaming.sources import java.io.{BufferedReader, InputStreamReader, IOException} import java.net.Socket import java.text.SimpleDateFormat -import java.util.{Calendar, Locale} +import java.util.{Calendar, List => JList, Locale, Optional} import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.{Failure, Success, Try} @@ -31,15 +32,16 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.streaming.{LongOffset, SimpleStreamingScanConfig, SimpleStreamingScanConfigBuilder} -import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReadSupport +import org.apache.spark.sql.execution.streaming.LongOffset +import org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, DataSourceV2, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport, Offset} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, DataSourceV2, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.{InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, MicroBatchReader, Offset} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String +// Shared object for micro-batch and continuous reader object TextSocketReader { val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: @@ -48,12 +50,14 @@ object TextSocketReader { } /** - * A MicroBatchReadSupport that reads text lines through a TCP socket, designed only for tutorials - * and debugging. This MicroBatchReadSupport will *not* work in production applications due to - * multiple reasons, including no support for fault recovery. + * A MicroBatchReader that reads text lines through a TCP socket, designed only for tutorials and + * debugging. This MicroBatchReader will *not* work in production applications due to multiple + * reasons, including no support for fault recovery. 
*/ -class TextSocketMicroBatchReadSupport(options: DataSourceOptions) - extends MicroBatchReadSupport with Logging { +class TextSocketMicroBatchReader(options: DataSourceOptions) extends MicroBatchReader with Logging { + + private var startOffset: Offset = _ + private var endOffset: Offset = _ private val host: String = options.get("host").get() private val port: Int = options.get("port").get().toInt @@ -99,7 +103,7 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) logWarning(s"Stream closed by $host:$port") return } - TextSocketMicroBatchReadSupport.this.synchronized { + TextSocketMicroBatchReader.this.synchronized { val newData = ( UTF8String.fromString(line), DateTimeUtils.fromMillis(Calendar.getInstance().getTimeInMillis) @@ -116,15 +120,24 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) readThread.start() } - override def initialOffset(): Offset = LongOffset(-1L) + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = synchronized { + startOffset = start.orElse(LongOffset(-1L)) + endOffset = end.orElse(currentOffset) + } - override def latestOffset(): Offset = currentOffset + override def getStartOffset(): Offset = { + Option(startOffset).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(endOffset).getOrElse(throw new IllegalStateException("end offset not set")) + } override def deserializeOffset(json: String): Offset = { LongOffset(json.toLong) } - override def fullSchema(): StructType = { + override def readSchema(): StructType = { if (options.getBoolean("includeTimestamp", false)) { TextSocketReader.SCHEMA_TIMESTAMP } else { @@ -132,14 +145,12 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) } } - override def newScanConfigBuilder(start: Offset, end: Offset): ScanConfigBuilder = { - new SimpleStreamingScanConfigBuilder(fullSchema(), start, Some(end)) - } + override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + assert(startOffset != null && endOffset != null, + "start offset and end offset should already be set before create read tasks.") - override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { - val sc = config.asInstanceOf[SimpleStreamingScanConfig] - val startOrdinal = sc.start.asInstanceOf[LongOffset].offset.toInt + 1 - val endOrdinal = sc.end.get.asInstanceOf[LongOffset].offset.toInt + 1 + val startOrdinal = LongOffset.convert(startOffset).get.offset.toInt + 1 + val endOrdinal = LongOffset.convert(endOffset).get.offset.toInt + 1 // Internal buffer only holds the batches after lastOffsetCommitted val rawList = synchronized { @@ -161,29 +172,26 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) slices(idx % numPartitions).append(r) } - slices.map(TextSocketInputPartition) - } + (0 until numPartitions).map { i => + val slice = slices(i) + new InputPartition[InternalRow] { + override def createPartitionReader(): InputPartitionReader[InternalRow] = + new InputPartitionReader[InternalRow] { + private var currentIdx = -1 - override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { - new PartitionReaderFactory { - override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { - val slice = partition.asInstanceOf[TextSocketInputPartition].slice - new PartitionReader[InternalRow] { - private var currentIdx = -1 + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } - override def next(): 
Boolean = { - currentIdx += 1 - currentIdx < slice.size - } + override def get(): InternalRow = { + InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) + } - override def get(): InternalRow = { - InternalRow(slice(currentIdx)._1, slice(currentIdx)._2) + override def close(): Unit = {} } - - override def close(): Unit = {} - } } - } + }.toList.asJava } override def commit(end: Offset): Unit = synchronized { @@ -219,11 +227,8 @@ class TextSocketMicroBatchReadSupport(options: DataSourceOptions) override def toString: String = s"TextSocketV2[host: $host, port: $port]" } -case class TextSocketInputPartition(slice: ListBuffer[(UTF8String, Long)]) extends InputPartition - class TextSocketSourceProvider extends DataSourceV2 - with MicroBatchReadSupportProvider with ContinuousReadSupportProvider - with DataSourceRegister with Logging { + with MicroBatchReadSupport with ContinuousReadSupport with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! " + @@ -243,18 +248,27 @@ class TextSocketSourceProvider extends DataSourceV2 } } - override def createMicroBatchReadSupport( + override def createMicroBatchReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): MicroBatchReadSupport = { + options: DataSourceOptions): MicroBatchReader = { checkParameters(options) - new TextSocketMicroBatchReadSupport(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + new TextSocketMicroBatchReader(options) } - override def createContinuousReadSupport( + override def createContinuousReader( + schema: Optional[StructType], checkpointLocation: String, - options: DataSourceOptions): ContinuousReadSupport = { + options: DataSourceOptions): ContinuousReader = { checkParameters(options) - new TextSocketContinuousReadSupport(options) + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + new TextSocketContinuousReader(options) } /** String that represents the format that this data source provider uses. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 2a4db4a..7eb5db5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.streaming -import java.util.Locale +import java.util.{Locale, Optional} import scala.collection.JavaConverters._ @@ -28,8 +28,8 @@ import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider -import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReadSupport, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -172,21 +172,19 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case _ => None } ds match { - case s: MicroBatchReadSupportProvider => - var tempReadSupport: MicroBatchReadSupport = null + case s: MicroBatchReadSupport => + var tempReader: MicroBatchReader = null val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createMicroBatchReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) - } else { - s.createMicroBatchReadSupport(tmpCheckpointPath, options) - } - tempReadSupport.fullSchema() + tempReader = s.createMicroBatchReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) + tempReader.readSchema() } finally { // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null + if (tempReader != null) { + tempReader.stop() + tempReader = null } } Dataset.ofRows( @@ -194,28 +192,16 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo StreamingRelationV2( s, source, extraOptions.toMap, schema.toAttributes, v1Relation)(sparkSession)) - case s: ContinuousReadSupportProvider => - var tempReadSupport: ContinuousReadSupport = null - val schema = try { - val tmpCheckpointPath = Utils.createTempDir(namePrefix = s"tempCP").getCanonicalPath - tempReadSupport = if (userSpecifiedSchema.isDefined) { - s.createContinuousReadSupport(userSpecifiedSchema.get, tmpCheckpointPath, options) - } else { - s.createContinuousReadSupport(tmpCheckpointPath, options) - } - tempReadSupport.fullSchema() - } finally { - // Stop tempReader to avoid side-effect thing - if (tempReadSupport != null) { - tempReadSupport.stop() - tempReadSupport = null - } - } + case s: ContinuousReadSupport => + val tempReader = s.createContinuousReader( + Optional.ofNullable(userSpecifiedSchema.orNull), + 
Utils.createTempDir(namePrefix = s"temporaryReader").getCanonicalPath, + options) Dataset.ofRows( sparkSession, StreamingRelationV2( s, source, extraOptions.toMap, - schema.toAttributes, v1Relation)(sparkSession)) + tempReader.readSchema().toAttributes, v1Relation)(sparkSession)) case _ => // Code path for data source v1. Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 7866e4f..3b9a56f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -21,7 +21,7 @@ import java.util.Locale import scala.collection.JavaConverters._ -import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.annotation.{InterfaceStability, Since} import org.apache.spark.api.java.function.VoidFunction2 import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.StreamWriteSupport /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -270,7 +270,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriterProvider[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -299,8 +299,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") val sink = ds.newInstance() match { - case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => w + case w: StreamWriteSupport if !disabledSources.contains(w.getClass.getCanonicalName) => w case _ => val ds = DataSource( df.sparkSession, http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index cd52d99..25bb052 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.StreamWriteSupport import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -256,7 +256,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + case (v2Sink: StreamWriteSupport, trigger: ContinuousTrigger) => if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 5602310..e4cead9 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,71 +24,29 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; +import 
org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { +public class JavaAdvancedDataSourceV2 implements DataSourceV2, ReadSupport { - public class ReadSupport extends JavaSimpleReadSupport { - @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new AdvancedScanConfigBuilder(); - } - - @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - Filter[] filters = ((AdvancedScanConfigBuilder) config).filters; - List<InputPartition> res = new ArrayList<>(); - - Integer lowerBound = null; - for (Filter filter : filters) { - if (filter instanceof GreaterThan) { - GreaterThan f = (GreaterThan) filter; - if ("i".equals(f.attribute()) && f.value() instanceof Integer) { - lowerBound = (Integer) f.value(); - break; - } - } - } - - if (lowerBound == null) { - res.add(new JavaRangeInputPartition(0, 5)); - res.add(new JavaRangeInputPartition(5, 10)); - } else if (lowerBound < 4) { - res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); - res.add(new JavaRangeInputPartition(5, 10)); - } else if (lowerBound < 9) { - res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); - } - - return res.stream().toArray(InputPartition[]::new); - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - StructType requiredSchema = ((AdvancedScanConfigBuilder) config).requiredSchema; - return new AdvancedReaderFactory(requiredSchema); - } - } - - public static class AdvancedScanConfigBuilder implements ScanConfigBuilder, ScanConfig, - SupportsPushDownFilters, SupportsPushDownRequiredColumns { + public class Reader implements DataSourceReader, SupportsPushDownRequiredColumns, + SupportsPushDownFilters { // Exposed for testing. 
public StructType requiredSchema = new StructType().add("i", "int").add("j", "int"); public Filter[] filters = new Filter[0]; @Override - public void pruneColumns(StructType requiredSchema) { - this.requiredSchema = requiredSchema; + public StructType readSchema() { + return requiredSchema; } @Override - public StructType readSchema() { - return requiredSchema; + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; } @Override @@ -121,54 +79,79 @@ public class JavaAdvancedDataSourceV2 implements DataSourceV2, BatchReadSupportP } @Override - public ScanConfig build() { - return this; + public List<InputPartition<InternalRow>> planInputPartitions() { + List<InputPartition<InternalRow>> res = new ArrayList<>(); + + Integer lowerBound = null; + for (Filter filter : filters) { + if (filter instanceof GreaterThan) { + GreaterThan f = (GreaterThan) filter; + if ("i".equals(f.attribute()) && f.value() instanceof Integer) { + lowerBound = (Integer) f.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaAdvancedInputPartition(0, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); + } else if (lowerBound < 4) { + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 5, requiredSchema)); + res.add(new JavaAdvancedInputPartition(5, 10, requiredSchema)); + } else if (lowerBound < 9) { + res.add(new JavaAdvancedInputPartition(lowerBound + 1, 10, requiredSchema)); + } + + return res; } } - static class AdvancedReaderFactory implements PartitionReaderFactory { - StructType requiredSchema; + static class JavaAdvancedInputPartition implements InputPartition<InternalRow>, + InputPartitionReader<InternalRow> { + private int start; + private int end; + private StructType requiredSchema; - AdvancedReaderFactory(StructType requiredSchema) { + JavaAdvancedInputPartition(int start, int end, StructType requiredSchema) { + this.start = start; + this.end = end; this.requiredSchema = requiredSchema; } @Override - public PartitionReader<InternalRow> createReader(InputPartition partition) { - JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - return new PartitionReader<InternalRow>() { - private int current = p.start - 1; - - @Override - public boolean next() throws IOException { - current += 1; - return current < p.end; - } + public InputPartitionReader<InternalRow> createPartitionReader() { + return new JavaAdvancedInputPartition(start - 1, end, requiredSchema); + } - @Override - public InternalRow get() { - Object[] values = new Object[requiredSchema.size()]; - for (int i = 0; i < values.length; i++) { - if ("i".equals(requiredSchema.apply(i).name())) { - values[i] = current; - } else if ("j".equals(requiredSchema.apply(i).name())) { - values[i] = -current; - } - } - return new GenericInternalRow(values); + @Override + public boolean next() { + start += 1; + return start < end; + } + + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = start; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -start; } + } + return new GenericInternalRow(values); + } - @Override - public void close() throws IOException { + @Override + public void close() throws IOException { - } - }; } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public DataSourceReader 
createReader(DataSourceOptions options) { + return new Reader(); } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java new file mode 100644 index 0000000..97d6176 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaBatchDataSourceV2.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + + +public class JavaBatchDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceReader, SupportsScanColumnarBatch { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List<InputPartition<ColumnarBatch>> planBatchInputPartitions() { + return java.util.Arrays.asList( + new JavaBatchInputPartition(0, 50), new JavaBatchInputPartition(50, 90)); + } + } + + static class JavaBatchInputPartition + implements InputPartition<ColumnarBatch>, InputPartitionReader<ColumnarBatch> { + private int start; + private int end; + + private static final int BATCH_SIZE = 20; + + private OnHeapColumnVector i; + private OnHeapColumnVector j; + private ColumnarBatch batch; + + JavaBatchInputPartition(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public InputPartitionReader<ColumnarBatch> createPartitionReader() { + this.i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + this.j = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); + ColumnVector[] vectors = new ColumnVector[2]; + vectors[0] = i; + vectors[1] = j; + this.batch = new ColumnarBatch(vectors); + return this; + } + + @Override + public boolean next() { + i.reset(); + j.reset(); + int count = 0; + while (start < end && count < BATCH_SIZE) { + i.putInt(count, start); + j.putInt(count, -start); + start += 1; + count += 1; + } + + 
if (count == 0) { + return false; + } else { + batch.setNumRows(count); + return true; + } + } + + @Override + public ColumnarBatch get() { + return batch; + } + + @Override + public void close() throws IOException { + batch.close(); + } + } + + + @Override + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java deleted file mode 100644 index 28a9330..0000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; -import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; - - -public class JavaColumnarDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - - class ReadSupport extends JavaSimpleReadSupport { - - @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - InputPartition[] partitions = new InputPartition[2]; - partitions[0] = new JavaRangeInputPartition(0, 50); - partitions[1] = new JavaRangeInputPartition(50, 90); - return partitions; - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new ColumnarReaderFactory(); - } - } - - static class ColumnarReaderFactory implements PartitionReaderFactory { - private static final int BATCH_SIZE = 20; - - @Override - public boolean supportColumnarReads(InputPartition partition) { - return true; - } - - @Override - public PartitionReader<InternalRow> createReader(InputPartition partition) { - throw new UnsupportedOperationException(""); - } - - @Override - public PartitionReader<ColumnarBatch> createColumnarReader(InputPartition partition) { - JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - OnHeapColumnVector i = new OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - OnHeapColumnVector j = new 
OnHeapColumnVector(BATCH_SIZE, DataTypes.IntegerType); - ColumnVector[] vectors = new ColumnVector[2]; - vectors[0] = i; - vectors[1] = j; - ColumnarBatch batch = new ColumnarBatch(vectors); - - return new PartitionReader<ColumnarBatch>() { - private int current = p.start; - - @Override - public boolean next() throws IOException { - i.reset(); - j.reset(); - int count = 0; - while (current < p.end && count < BATCH_SIZE) { - i.putInt(count, current); - j.putInt(count, -current); - current += 1; - count += 1; - } - - if (count == 0) { - return false; - } else { - batch.setNumRows(count); - return true; - } - } - - @Override - public ColumnarBatch get() { - return batch; - } - - @Override - public void close() throws IOException { - batch.close(); - } - }; - } - } - - @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index 18a11dd..2d21324 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -19,34 +19,38 @@ package test.org.apache.spark.sql.sources.v2; import java.io.IOException; import java.util.Arrays; +import java.util.List; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.*; +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.ReadSupport; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; +import org.apache.spark.sql.types.StructType; -public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { - class ReadSupport extends JavaSimpleReadSupport implements SupportsReportPartitioning { + class Reader implements DataSourceReader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - InputPartition[] partitions = new InputPartition[2]; - partitions[0] = new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}); - partitions[1] = new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2}); - return partitions; + public StructType readSchema() { + return schema; } @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new SpecificReaderFactory(); + public List<InputPartition<InternalRow>> planInputPartitions() { + return java.util.Arrays.asList( + new SpecificInputPartition(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificInputPartition(new int[]{2, 4, 4}, new int[]{6, 2, 2})); } @Override - public Partitioning 
outputPartitioning(ScanConfig config) { + public Partitioning outputPartitioning() { return new MyPartitioning(); } } @@ -62,53 +66,50 @@ public class JavaPartitionAwareDataSource implements DataSourceV2, BatchReadSupp public boolean satisfy(Distribution distribution) { if (distribution instanceof ClusteredDistribution) { String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; - return Arrays.asList(clusteredCols).contains("i"); + return Arrays.asList(clusteredCols).contains("a"); } return false; } } - static class SpecificInputPartition implements InputPartition { - int[] i; - int[] j; + static class SpecificInputPartition implements InputPartition<InternalRow>, + InputPartitionReader<InternalRow> { + + private int[] i; + private int[] j; + private int current = -1; SpecificInputPartition(int[] i, int[] j) { assert i.length == j.length; this.i = i; this.j = j; } - } - static class SpecificReaderFactory implements PartitionReaderFactory { + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } @Override - public PartitionReader<InternalRow> createReader(InputPartition partition) { - SpecificInputPartition p = (SpecificInputPartition) partition; - return new PartitionReader<InternalRow>() { - private int current = -1; - - @Override - public boolean next() throws IOException { - current += 1; - return current < p.i.length; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {p.i[current], p.j[current]}); - } - - @Override - public void close() throws IOException { - - } - }; + public InputPartitionReader<InternalRow> createPartitionReader() { + return this; } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index cc9ac04..6fd6a44 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,39 +17,43 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import java.util.List; + +import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.DataSourceV2; -import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.types.StructType; -public class JavaSchemaRequiredDataSource implements DataSourceV2, BatchReadSupportProvider { +public class JavaSchemaRequiredDataSource implements DataSourceV2, ReadSupport { - class ReadSupport extends 
JavaSimpleReadSupport { + class Reader implements DataSourceReader { private final StructType schema; - ReadSupport(StructType schema) { + Reader(StructType schema) { this.schema = schema; } @Override - public StructType fullSchema() { + public StructType readSchema() { return schema; } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - return new InputPartition[0]; + public List<InputPartition<InternalRow>> planInputPartitions() { + return java.util.Collections.emptyList(); } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + public DataSourceReader createReader(DataSourceOptions options) { throw new IllegalArgumentException("requires a user-supplied schema"); } @Override - public BatchReadSupport createBatchReadSupport(StructType schema, DataSourceOptions options) { - return new ReadSupport(schema); + public DataSourceReader createReader(StructType schema, DataSourceOptions options) { + return new Reader(schema); } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 2cdbba8..274dc37 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,26 +17,72 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; +import java.io.IOException; +import java.util.List; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.DataSourceReader; +import org.apache.spark.sql.types.StructType; + +public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceReader { + private final StructType schema = new StructType().add("i", "int").add("j", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List<InputPartition<InternalRow>> planInputPartitions() { + return java.util.Arrays.asList( + new JavaSimpleInputPartition(0, 5), + new JavaSimpleInputPartition(5, 10)); + } + } + + static class JavaSimpleInputPartition implements InputPartition<InternalRow>, + InputPartitionReader<InternalRow> { -public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { + private int start; + private int end; - class ReadSupport extends JavaSimpleReadSupport { + JavaSimpleInputPartition(int start, int end) { + this.start = start; + this.end = end; + } + + @Override + public InputPartitionReader<InternalRow> createPartitionReader() { + return new JavaSimpleInputPartition(start - 1, end); + } @Override - public InputPartition[] planInputPartitions(ScanConfig config) { - InputPartition[] partitions = new InputPartition[2]; - 
partitions[0] = new JavaRangeInputPartition(0, 5); - partitions[1] = new JavaRangeInputPartition(5, 10); - return partitions; + public boolean next() { + start += 1; + return start < end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {start, -start}); + } + + @Override + public void close() throws IOException { + } } @Override - public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { - return new ReadSupport(); + public DataSourceReader createReader(DataSourceOptions options) { + return new Reader(); } } http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java deleted file mode 100644 index 685f9b9..0000000 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package test.org.apache.spark.sql.sources.v2; - -import java.io.IOException; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.reader.*; -import org.apache.spark.sql.types.StructType; - -abstract class JavaSimpleReadSupport implements BatchReadSupport { - - @Override - public StructType fullSchema() { - return new StructType().add("i", "int").add("j", "int"); - } - - @Override - public ScanConfigBuilder newScanConfigBuilder() { - return new JavaNoopScanConfigBuilder(fullSchema()); - } - - @Override - public PartitionReaderFactory createReaderFactory(ScanConfig config) { - return new JavaSimpleReaderFactory(); - } -} - -class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { - - private StructType schema; - - JavaNoopScanConfigBuilder(StructType schema) { - this.schema = schema; - } - - @Override - public ScanConfig build() { - return this; - } - - @Override - public StructType readSchema() { - return schema; - } -} - -class JavaSimpleReaderFactory implements PartitionReaderFactory { - - @Override - public PartitionReader<InternalRow> createReader(InputPartition partition) { - JavaRangeInputPartition p = (JavaRangeInputPartition) partition; - return new PartitionReader<InternalRow>() { - private int current = p.start - 1; - - @Override - public boolean next() throws IOException { - current += 1; - return current < p.end; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {current, -current}); - } - - @Override - public void close() throws IOException { - - } - }; - } -} - -class JavaRangeInputPartition implements InputPartition { - int start; - int end; - - JavaRangeInputPartition(int start, int end) { - this.start = start; - this.end = end; - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a36b0cf..46b38be 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider +org.apache.spark.sql.streaming.sources.FakeWrite org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback http://git-wip-us.apache.org/repos/asf/spark/blob/15d2e9d7/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 6185736..7bb2cf5 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -43,7 +43,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("streaming writer") { val sink = new MemorySinkV2 - val writeSupport = new MemoryStreamingWriteSupport( + val writeSupport = new MemoryStreamWriter( sink, OutputMode.Append(), new StructType().add("i", "int")) writeSupport.commit(0, Array(
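
As context for the ReadSupport/DataSourceReader shape that this revert restores, below is a minimal, hypothetical usage sketch (not part of the commit). It shows how a DataSourceV2 test source such as JavaSimpleDataSourceV2 from the diff above is typically loaded by its fully qualified class name. Only the source's package, class name, and expected output (rows (i, -i) for i in 0 to 9, from the two partitions (0, 5) and (5, 10)) come from the diff; the SparkSession setup and the SimpleSourceUsageSketch class are illustrative assumptions.

// Hedged usage sketch, not part of the commit: loads the ReadSupport-based
// test source by its fully qualified class name, the way DataSourceV2
// implementations are resolved when no short name is registered.
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

public class SimpleSourceUsageSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[2]")
        .appName("datasource-v2-sketch")
        .getOrCreate();

    // JavaSimpleDataSourceV2.Reader plans two partitions that together
    // produce rows (i, -i) for i in 0..9.
    Dataset<Row> df = spark.read()
        .format("test.org.apache.spark.sql.sources.v2.JavaSimpleDataSourceV2")
        .load();

    df.show();  // expected: ten rows of (i, j) with j == -i

    spark.stop();
  }
}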