http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 274dc37..2cdbba8 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,72 +17,26 @@ package test.org.apache.spark.sql.sources.v2; -import java.io.IOException; -import java.util.List; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.BatchReadSupportProvider; import org.apache.spark.sql.sources.v2.DataSourceV2; import org.apache.spark.sql.sources.v2.DataSourceOptions; -import org.apache.spark.sql.sources.v2.ReadSupport; -import org.apache.spark.sql.sources.v2.reader.InputPartitionReader; -import org.apache.spark.sql.sources.v2.reader.InputPartition; -import org.apache.spark.sql.sources.v2.reader.DataSourceReader; -import org.apache.spark.sql.types.StructType; - -public class JavaSimpleDataSourceV2 implements DataSourceV2, ReadSupport { - - class Reader implements DataSourceReader { - private final StructType schema = new StructType().add("i", "int").add("j", "int"); - - @Override - public StructType readSchema() { - return schema; - } - - @Override - public List<InputPartition<InternalRow>> planInputPartitions() { - return java.util.Arrays.asList( - new JavaSimpleInputPartition(0, 5), - new JavaSimpleInputPartition(5, 10)); - } - } - - static class JavaSimpleInputPartition implements InputPartition<InternalRow>, - InputPartitionReader<InternalRow> { +import org.apache.spark.sql.sources.v2.reader.*; - 
private int start; - private int end; +public class JavaSimpleDataSourceV2 implements DataSourceV2, BatchReadSupportProvider { - JavaSimpleInputPartition(int start, int end) { - this.start = start; - this.end = end; - } - - @Override - public InputPartitionReader<InternalRow> createPartitionReader() { - return new JavaSimpleInputPartition(start - 1, end); - } + class ReadSupport extends JavaSimpleReadSupport { @Override - public boolean next() { - start += 1; - return start < end; - } - - @Override - public InternalRow get() { - return new GenericInternalRow(new Object[] {start, -start}); - } - - @Override - public void close() throws IOException { - + public InputPartition[] planInputPartitions(ScanConfig config) { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; } } @Override - public DataSourceReader createReader(DataSourceOptions options) { - return new Reader(); + public BatchReadSupport createBatchReadSupport(DataSourceOptions options) { + return new ReadSupport(); } }
http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java new file mode 100644 index 0000000..685f9b9 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleReadSupport.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +abstract class JavaSimpleReadSupport implements BatchReadSupport { + + @Override + public StructType fullSchema() { + return new StructType().add("i", "int").add("j", "int"); + } + + @Override + public ScanConfigBuilder newScanConfigBuilder() { + return new JavaNoopScanConfigBuilder(fullSchema()); + } + + @Override + public PartitionReaderFactory createReaderFactory(ScanConfig config) { + return new JavaSimpleReaderFactory(); + } +} + +class JavaNoopScanConfigBuilder implements ScanConfigBuilder, ScanConfig { + + private StructType schema; + + JavaNoopScanConfigBuilder(StructType schema) { + this.schema = schema; + } + + @Override + public ScanConfig build() { + return this; + } + + @Override + public StructType readSchema() { + return schema; + } +} + +class JavaSimpleReaderFactory implements PartitionReaderFactory { + + @Override + public PartitionReader<InternalRow> createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader<InternalRow>() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + return new GenericInternalRow(new Object[] {current, -current}); + } + + @Override + public void close() throws IOException { + + } + }; + } +} + +class JavaRangeInputPartition implements InputPartition { + int start; + int end; + + JavaRangeInputPartition(int start, int end) { + this.start = start; + this.end = end; + } +} 
http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index 46b38be..a36b0cf 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWrite +org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider org.apache.spark.sql.streaming.sources.FakeNoWrite -org.apache.spark.sql.streaming.sources.FakeWriteV1Fallback +org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 1efaead..50f13be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -41,10 +41,11 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { assert(writer.commit().data.isEmpty) } - test("continuous writer") { 
+ test("streaming writer") { val sink = new MemorySinkV2 - val writer = new MemoryStreamWriter(sink, OutputMode.Append(), new StructType().add("i", "int")) - writer.commit(0, + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), new StructType().add("i", "int")) + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -52,29 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writer.commit(19, - Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) - } - - test("microbatch writer") { - val sink = new MemorySinkV2 - val schema = new StructType().add("i", "int") - new MemoryWriter(sink, 0, OutputMode.Append(), schema).commit( - Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - new MemoryWriter(sink, 19, OutputMode.Append(), schema).commit( + writeSupport.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) @@ -88,22 +67,21 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("writer metrics") { val sink = new MemorySinkV2 val schema = new StructType().add("i", "int") + val writeSupport = new MemoryStreamingWriteSupport( + sink, OutputMode.Append(), schema) // batch 0 - var writer = new MemoryWriter(sink, 0, OutputMode.Append(), schema) - 
writer.commit( + writeSupport.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), MemoryWriterCommitMessage(2, Seq(Row(5), Row(6))) )) - assert(writer.getCustomMetrics.json() == "{\"numRows\":6}") + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":6}") // batch 1 - writer = new MemoryWriter(sink, 1, OutputMode.Append(), schema - ) - writer.commit( + writeSupport.commit(1, Array( MemoryWriterCommitMessage(0, Seq(Row(7), Row(8))) )) - assert(writer.getCustomMetrics.json() == "{\"numRows\":8}") + assert(writeSupport.getCustomMetrics.json() == "{\"numRows\":8}") } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala new file mode 100644 index 0000000..5884380 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupportSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.ByteArrayOutputStream + +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming.{StreamTest, Trigger} + +class ConsoleWriteSupportSuite extends StreamTest { + import testImplicits._ + + test("microbatch - default") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + input.addData(4, 5, 6) + query.processAllAvailable() + input.addData() + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + || 3| + |+-----+ + | + |------------------------------------------- + |Batch: 1 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 4| + || 5| + || 6| + |+-----+ + | + |------------------------------------------- + |Batch: 2 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + |+-----+ + | + |""".stripMargin) + } + + test("microbatch - with numRows") { + val input = MemoryStream[Int] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() + try { + input.addData(1, 2, 3) + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+-----+ + ||value| + |+-----+ + || 1| + || 2| + |+-----+ + |only showing top 2 rows + | + |""".stripMargin) + } + + 
test("microbatch - truncation") { + val input = MemoryStream[String] + + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() + try { + input.addData("123456789012345678901234567890") + query.processAllAvailable() + } finally { + query.stop() + } + } + + assert(captured.toString() == + """------------------------------------------- + |Batch: 0 + |------------------------------------------- + |+--------------------+ + || value| + |+--------------------+ + ||12345678901234567...| + |+--------------------+ + | + |""".stripMargin) + } + + test("continuous - default") { + val captured = new ByteArrayOutputStream() + Console.withOut(captured) { + val input = spark.readStream + .format("rate") + .option("numPartitions", "1") + .option("rowsPerSecond", "5") + .load() + .select('value) + + val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() + assert(query.isActive) + query.stop() + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala deleted file mode 100644 index 55acf2b..0000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriterSuite.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.sources - -import java.io.ByteArrayOutputStream - -import org.scalatest.time.SpanSugar._ - -import org.apache.spark.sql.execution.streaming.MemoryStream -import org.apache.spark.sql.streaming.{StreamTest, Trigger} - -class ConsoleWriterSuite extends StreamTest { - import testImplicits._ - - test("microbatch - default") { - val input = MemoryStream[Int] - - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val query = input.toDF().writeStream.format("console").start() - try { - input.addData(1, 2, 3) - query.processAllAvailable() - input.addData(4, 5, 6) - query.processAllAvailable() - input.addData() - query.processAllAvailable() - } finally { - query.stop() - } - } - - assert(captured.toString() == - """------------------------------------------- - |Batch: 0 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - || 1| - || 2| - || 3| - |+-----+ - | - |------------------------------------------- - |Batch: 1 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - || 4| - || 5| - || 6| - |+-----+ - | - |------------------------------------------- - |Batch: 2 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - |+-----+ - | - |""".stripMargin) - } - - test("microbatch - with numRows") { - val input = MemoryStream[Int] - - val 
captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val query = input.toDF().writeStream.format("console").option("NUMROWS", 2).start() - try { - input.addData(1, 2, 3) - query.processAllAvailable() - } finally { - query.stop() - } - } - - assert(captured.toString() == - """------------------------------------------- - |Batch: 0 - |------------------------------------------- - |+-----+ - ||value| - |+-----+ - || 1| - || 2| - |+-----+ - |only showing top 2 rows - | - |""".stripMargin) - } - - test("microbatch - truncation") { - val input = MemoryStream[String] - - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val query = input.toDF().writeStream.format("console").option("TRUNCATE", true).start() - try { - input.addData("123456789012345678901234567890") - query.processAllAvailable() - } finally { - query.stop() - } - } - - assert(captured.toString() == - """------------------------------------------- - |Batch: 0 - |------------------------------------------- - |+--------------------+ - || value| - |+--------------------+ - ||12345678901234567...| - |+--------------------+ - | - |""".stripMargin) - } - - test("continuous - default") { - val captured = new ByteArrayOutputStream() - Console.withOut(captured) { - val input = spark.readStream - .format("rate") - .option("numPartitions", "1") - .option("rowsPerSecond", "5") - .load() - .select('value) - - val query = input.writeStream.format("console").trigger(Trigger.Continuous(200)).start() - assert(query.isActive) - query.stop() - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index 7e53da1..9c1756d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql.execution.streaming.sources -import java.util.Optional import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, MicroBatchReadSupport} +import org.apache.spark.sql.sources.v2.{ContinuousReadSupportProvider, DataSourceOptions, MicroBatchReadSupportProvider} import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.util.ManualClock @@ -42,7 +40,7 @@ class RateSourceSuite extends StreamTest { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: RateStreamMicroBatchReader, _) => source + case StreamingExecutionRelation(source: RateStreamMicroBatchReadSupport, _) => source }.head rateSource.clock.asInstanceOf[ManualClock].advance(TimeUnit.SECONDS.toMillis(seconds)) @@ -55,10 +53,10 @@ class RateSourceSuite extends StreamTest { test("microbatch in registry") { withTempDir { temp => DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => - val 
reader = ds.createMicroBatchReader( - Optional.empty(), temp.getCanonicalPath, DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamMicroBatchReader]) + case ds: MicroBatchReadSupportProvider => + val readSupport = ds.createMicroBatchReadSupport( + temp.getCanonicalPath, DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamMicroBatchReadSupport]) case _ => throw new IllegalStateException("Could not find read support for rate") } @@ -68,7 +66,7 @@ class RateSourceSuite extends StreamTest { test("compatible with old path in registry") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.RateSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[RateStreamProvider]) case _ => throw new IllegalStateException("Could not find read support for rate") @@ -109,30 +107,19 @@ class RateSourceSuite extends StreamTest { ) } - test("microbatch - set offset") { - withTempDir { temp => - val reader = new RateStreamMicroBatchReader(DataSourceOptions.empty(), temp.getCanonicalPath) - val startOffset = LongOffset(0L) - val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - assert(reader.getStartOffset() == startOffset) - assert(reader.getEndOffset() == endOffset) - } - } - test("microbatch - infer offsets") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions( Map("numPartitions" -> "1", "rowsPerSecond" -> "100", "useManualClock" -> "true").asJava), temp.getCanonicalPath) - reader.clock.asInstanceOf[ManualClock].advance(100000) - reader.setOffsetRange(Optional.empty(), Optional.empty()) - reader.getStartOffset() match { + readSupport.clock.asInstanceOf[ManualClock].advance(100000) + val startOffset = readSupport.initialOffset() + startOffset match { case r: LongOffset => 
assert(r.offset === 0L) case _ => throw new IllegalStateException("unexpected offset type") } - reader.getEndOffset() match { + readSupport.latestOffset() match { case r: LongOffset => assert(r.offset >= 100) case _ => throw new IllegalStateException("unexpected offset type") } @@ -141,15 +128,16 @@ class RateSourceSuite extends StreamTest { test("microbatch - predetermined batch size") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) assert(tasks.size == 1) - val dataReader = tasks.get(0).createPartitionReader() + val dataReader = readerFactory.createReader(tasks(0)) val data = ArrayBuffer[InternalRow]() while (dataReader.next()) { data.append(dataReader.get()) @@ -160,24 +148,25 @@ class RateSourceSuite extends StreamTest { test("microbatch - data read") { withTempDir { temp => - val reader = new RateStreamMicroBatchReader( + val readSupport = new RateStreamMicroBatchReadSupport( new DataSourceOptions(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava), temp.getCanonicalPath) val startOffset = LongOffset(0L) val endOffset = LongOffset(1L) - reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset)) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(startOffset, endOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createReaderFactory(config) assert(tasks.size == 11) - val readData = tasks.asScala 
- .map(_.createPartitionReader()) + val readData = tasks + .map(readerFactory.createReader) .flatMap { reader => val buf = scala.collection.mutable.ListBuffer[InternalRow]() while (reader.next()) buf.append(reader.get()) buf } - assert(readData.map(_.getLong(1)).sorted == Range(0, 33)) + assert(readData.map(_.getLong(1)).sorted === 0.until(33).toArray) } } @@ -288,41 +277,44 @@ class RateSourceSuite extends StreamTest { } test("user-specified schema given") { - val exception = intercept[AnalysisException] { + val exception = intercept[UnsupportedOperationException] { spark.readStream .format("rate") .schema(spark.range(1).schema) .load() } assert(exception.getMessage.contains( - "rate source does not support a user-specified schema")) + "rate source does not support user-specified schema")) } test("continuous in registry") { DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match { - case ds: ContinuousReadSupport => - val reader = ds.createContinuousReader(Optional.empty(), "", DataSourceOptions.empty()) - assert(reader.isInstanceOf[RateStreamContinuousReader]) + case ds: ContinuousReadSupportProvider => + val readSupport = ds.createContinuousReadSupport( + "", DataSourceOptions.empty()) + assert(readSupport.isInstanceOf[RateStreamContinuousReadSupport]) case _ => throw new IllegalStateException("Could not find read support for continuous rate") } } test("continuous data") { - val reader = new RateStreamContinuousReader( + val readSupport = new RateStreamContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "rowsPerSecond" -> "20").asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val config = readSupport.newScanConfigBuilder(readSupport.initialOffset).build() + val tasks = readSupport.planInputPartitions(config) + val readerFactory = readSupport.createContinuousReaderFactory(config) assert(tasks.size == 2) val data = scala.collection.mutable.ListBuffer[InternalRow]() - 
tasks.asScala.foreach { + tasks.foreach { case t: RateStreamContinuousInputPartition => - val startTimeMs = reader.getStartOffset() + val startTimeMs = readSupport.initialOffset() .asInstanceOf[RateStreamOffset] .partitionToValueAndRunTimeMs(t.partitionIndex) .runTimeMs - val r = t.createPartitionReader().asInstanceOf[RateStreamContinuousInputPartitionReader] + val r = readerFactory.createReader(t) + .asInstanceOf[RateStreamContinuousPartitionReader] for (rowIndex <- 0 to 9) { r.next() data.append(r.get()) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 48e5cf7..409156e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -21,7 +21,6 @@ import java.net.{InetSocketAddress, SocketException} import java.nio.ByteBuffer import java.nio.channels.ServerSocketChannel import java.sql.Timestamp -import java.util.Optional import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ @@ -34,8 +33,8 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupport} -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, MicroBatchReadSupportProvider} +import 
org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -49,14 +48,9 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread.join() serverThread = null } - if (batchReader != null) { - batchReader.stop() - batchReader = null - } } private var serverThread: ServerThread = null - private var batchReader: MicroBatchReader = null case class AddSocketData(data: String*) extends AddData { override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { @@ -65,7 +59,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before "Cannot add data when there is no query for finding the active socket source") val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source: TextSocketMicroBatchReader, _) => source + case StreamingExecutionRelation(source: TextSocketMicroBatchReadSupport, _) => source } if (sources.isEmpty) { throw new Exception( @@ -91,7 +85,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("backward compatibility with old path") { DataSource.lookupDataSource("org.apache.spark.sql.execution.streaming.TextSocketSourceProvider", spark.sqlContext.conf).newInstance() match { - case ds: MicroBatchReadSupport => + case ds: MicroBatchReadSupportProvider => assert(ds.isInstanceOf[TextSocketSourceProvider]) case _ => throw new IllegalStateException("Could not find socket source") @@ -181,16 +175,16 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map.empty[String, String].asJava)) + provider.createMicroBatchReadSupport( + "", new 
DataSourceOptions(Map.empty[String, String].asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.createMicroBatchReader(Optional.empty(), "", - new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.createMicroBatchReadSupport( + "", new DataSourceOptions(Map("port" -> "1234").asJava)) } } @@ -199,7 +193,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { val a = new DataSourceOptions(params.asJava) - provider.createMicroBatchReader(Optional.empty(), "", a) + provider.createMicroBatchReadSupport("", a) } } @@ -209,12 +203,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("name", StringType) :: StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") - val exception = intercept[AnalysisException] { - provider.createMicroBatchReader( - Optional.of(userSpecifiedSchema), "", new DataSourceOptions(params.asJava)) + val exception = intercept[UnsupportedOperationException] { + provider.createMicroBatchReadSupport( + userSpecifiedSchema, "", new DataSourceOptions(params.asJava)) } assert(exception.getMessage.contains( - "socket source does not support a user-specified schema")) + "socket source does not support user-specified schema")) } test("input row metrics") { @@ -305,25 +299,27 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", 
"host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) assert(tasks.size == 2) val numRecords = 10 val data = scala.collection.mutable.ListBuffer[Int]() val offsets = scala.collection.mutable.ListBuffer[Int]() + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) import org.scalatest.time.SpanSugar._ failAfter(5 seconds) { // inject rows, read and check the data and offsets for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.asScala.foreach { + tasks.foreach { case t: TextSocketContinuousInputPartition => - val r = t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { r.next() offsets.append(r.getOffset().asInstanceOf[ContinuousRecordPartitionOffset].offset) @@ -339,16 +335,15 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before data.clear() case _ => throw new IllegalStateException("Unexpected task type") } - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(3, 3)) - reader.commit(TextSocketOffset(List(5, 5))) - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == List(5, 5)) + assert(readSupport.startOffset.offsets == List(3, 3)) + readSupport.commit(TextSocketOffset(List(5, 5))) + assert(readSupport.startOffset.offsets == List(5, 5)) } def commitOffset(partition: Int, offset: Int): Unit = { - val offsetsToCommit = reader.getStartOffset.asInstanceOf[TextSocketOffset] - .offsets.updated(partition, offset) - reader.commit(TextSocketOffset(offsetsToCommit)) - assert(reader.getStartOffset.asInstanceOf[TextSocketOffset].offsets == offsetsToCommit) + val 
offsetsToCommit = readSupport.startOffset.offsets.updated(partition, offset) + readSupport.commit(TextSocketOffset(offsetsToCommit)) + assert(readSupport.startOffset.offsets == offsetsToCommit) } } @@ -356,14 +351,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) - // ok to commit same offset - reader.setStartOffset(Optional.of(TextSocketOffset(List(5, 5)))) + + readSupport.startOffset = TextSocketOffset(List(5, 5)) assertThrows[IllegalStateException] { - reader.commit(TextSocketOffset(List(6, 6))) + readSupport.commit(TextSocketOffset(List(6, 6))) } } @@ -371,12 +365,12 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before serverThread = new ServerThread() serverThread.start() - val reader = new TextSocketContinuousReader( + val readSupport = new TextSocketContinuousReadSupport( new DataSourceOptions(Map("numPartitions" -> "2", "host" -> "localhost", "includeTimestamp" -> "true", "port" -> serverThread.port.toString).asJava)) - reader.setStartOffset(Optional.empty()) - val tasks = reader.planInputPartitions() + val scanConfig = readSupport.newScanConfigBuilder(readSupport.initialOffset()).build() + val tasks = readSupport.planInputPartitions(scanConfig) assert(tasks.size == 2) val numRecords = 4 @@ -384,9 +378,10 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before for (i <- 0 until numRecords) { serverThread.enqueue(i.toString) } - tasks.asScala.foreach { + val readerFactory = readSupport.createContinuousReaderFactory(scanConfig) + tasks.foreach { case t: TextSocketContinuousInputPartition => - val r = 
t.createPartitionReader().asInstanceOf[TextSocketContinuousInputPartitionReader] + val r = readerFactory.createReader(t).asInstanceOf[TextSocketContinuousPartitionReader] for (i <- 0 until numRecords / 2) { r.next() assert(r.get().get(0, TextSocketReader.SCHEMA_TIMESTAMP) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index aa5f723..5edeff5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.sources.v2 -import java.util.{ArrayList, List => JList} - import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -38,6 +36,21 @@ import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { import testImplicits._ + private def getScanConfig(query: DataFrame): AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[AdvancedScanConfigBuilder] + }.head + } + + private def getJavaScanConfig( + query: DataFrame): JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder = { + query.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => + d.scanConfig.asInstanceOf[JavaAdvancedDataSourceV2.AdvancedScanConfigBuilder] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -50,18 +63,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("advanced implementation") { - def 
getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - - def getJavaReader(query: DataFrame): JavaAdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[JavaAdvancedDataSourceV2#Reader] - }.head - } - Seq(classOf[AdvancedDataSourceV2], classOf[JavaAdvancedDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() @@ -70,58 +71,58 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val q1 = df.select('j) checkAnswer(q1, (0 until 10).map(i => Row(-i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q1) - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + val config = getJavaScanConfig(q1) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } val q2 = df.filter('i > 3) checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i", "j")) } else { - val reader = getJavaReader(q2) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i", "j")) + val config = getJavaScanConfig(q2) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + 
assert(config.requiredSchema.fieldNames === Seq("i", "j")) } val q3 = df.select('i).filter('i > 6) checkAnswer(q3, (7 until 10).map(i => Row(i))) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } else { - val reader = getJavaReader(q3) - assert(reader.filters.flatMap(_.references).toSet == Set("i")) - assert(reader.requiredSchema.fieldNames === Seq("i")) + val config = getJavaScanConfig(q3) + assert(config.filters.flatMap(_.references).toSet == Set("i")) + assert(config.requiredSchema.fieldNames === Seq("i")) } val q4 = df.select('j).filter('j < -10) checkAnswer(q4, Nil) if (cls == classOf[AdvancedDataSourceV2]) { - val reader = getReader(q4) + val config = getScanConfig(q4) // 'j < 10 is not supported by the testing data source. - assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } else { - val reader = getJavaReader(q4) + val config = getJavaScanConfig(q4) // 'j < 10 is not supported by the testing data source. 
- assert(reader.filters.isEmpty) - assert(reader.requiredSchema.fieldNames === Seq("j")) + assert(config.filters.isEmpty) + assert(config.requiredSchema.fieldNames === Seq("j")) } } } } test("columnar batch scan implementation") { - Seq(classOf[BatchDataSourceV2], classOf[JavaBatchDataSourceV2]).foreach { cls => + Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { val df = spark.read.format(cls.getName).load() checkAnswer(df, (0 until 90).map(i => Row(i, -i))) @@ -153,25 +154,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val df = spark.read.format(cls.getName).load() checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) - val groupByColA = df.groupBy('a).agg(sum('b)) + val groupByColA = df.groupBy('i).agg(sum('j)) checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) assert(groupByColA.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + val groupByColAB = df.groupBy('i, 'j).agg(count("*")) checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) assert(groupByColAB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isEmpty) - val groupByColB = df.groupBy('b).agg(sum('a)) + val groupByColB = df.groupBy('j).agg(sum('i)) checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) assert(groupByColB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e }.isDefined) - val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + val groupByAPlusB = df.groupBy('i + 'j).agg(count("*")) checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { case e: ShuffleExchangeExec => e @@ -272,36 +273,30 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } test("SPARK-23301: 
column pruning with arbitrary expressions") { - def getReader(query: DataFrame): AdvancedDataSourceV2#Reader = { - query.queryExecution.executedPlan.collect { - case d: DataSourceV2ScanExec => d.reader.asInstanceOf[AdvancedDataSourceV2#Reader] - }.head - } - val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() val q1 = df.select('i + 1) checkAnswer(q1, (1 until 11).map(i => Row(i))) - val reader1 = getReader(q1) - assert(reader1.requiredSchema.fieldNames === Seq("i")) + val config1 = getScanConfig(q1) + assert(config1.requiredSchema.fieldNames === Seq("i")) val q2 = df.select(lit(1)) checkAnswer(q2, (0 until 10).map(i => Row(1))) - val reader2 = getReader(q2) - assert(reader2.requiredSchema.isEmpty) + val config2 = getScanConfig(q2) + assert(config2.requiredSchema.isEmpty) // 'j === 1 can't be pushed down, but we should still be able do column pruning val q3 = df.filter('j === -1).select('j * 2) checkAnswer(q3, Row(-2)) - val reader3 = getReader(q3) - assert(reader3.filters.isEmpty) - assert(reader3.requiredSchema.fieldNames === Seq("j")) + val config3 = getScanConfig(q3) + assert(config3.filters.isEmpty) + assert(config3.requiredSchema.fieldNames === Seq("j")) // column pruning should work with other operators. 
val q4 = df.sort('i).limit(1).select('i + 1) checkAnswer(q4, Row(1)) - val reader4 = getReader(q4) - assert(reader4.requiredSchema.fieldNames === Seq("i")) + val config4 = getScanConfig(q4) + assert(config4.requiredSchema.fieldNames === Seq("i")) } test("SPARK-23315: get output from canonicalized data source v2 related plans") { @@ -324,240 +319,290 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } -class SimpleSinglePartitionSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") +case class RangeInputPartition(start: Int, end: Int) extends InputPartition - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5)) - } - } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader +case class NoopScanConfigBuilder(readSchema: StructType) extends ScanConfigBuilder with ScanConfig { + override def build(): ScanConfig = this } -class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { +object SimpleReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } - class Reader extends DataSourceReader { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + override def get(): InternalRow = InternalRow(current, -current) - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { - java.util.Arrays.asList(new SimpleInputPartition(0, 5), new SimpleInputPartition(5, 10)) + override def close(): Unit = {} } } - - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader } 
-class SimpleInputPartition(start: Int, end: Int) - extends InputPartition[InternalRow] - with InputPartitionReader[InternalRow] { - private var current = start - 1 - - override def createPartitionReader(): InputPartitionReader[InternalRow] = - new SimpleInputPartition(start, end) +abstract class SimpleReadSupport extends BatchReadSupport { + override def fullSchema(): StructType = new StructType().add("i", "int").add("j", "int") - override def next(): Boolean = { - current += 1 - current < end + override def newScanConfigBuilder(): ScanConfigBuilder = { + NoopScanConfigBuilder(fullSchema()) } - override def get(): InternalRow = InternalRow(current, -current) - - override def close(): Unit = {} + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SimpleReaderFactory + } } +class SimpleSinglePartitionSource extends DataSourceV2 with BatchReadSupportProvider { -class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5)) + } + } + + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - class Reader extends DataSourceReader - with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - var requiredSchema = new StructType().add("i", "int").add("j", "int") - var filters = Array.empty[Filter] +class SimpleDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - override def pruneColumns(requiredSchema: StructType): Unit = { - this.requiredSchema = requiredSchema + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) } + } - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - val (supported, unsupported) = filters.partition { - case 
GreaterThan("i", _: Int) => true - case _ => false - } - this.filters = supported - unsupported - } + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } +} - override def pushedFilters(): Array[Filter] = filters - override def readSchema(): StructType = { - requiredSchema - } +class AdvancedDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { + + class ReadSupport extends SimpleReadSupport { + override def newScanConfigBuilder(): ScanConfigBuilder = new AdvancedScanConfigBuilder() + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + val filters = config.asInstanceOf[AdvancedScanConfigBuilder].filters - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { val lowerBound = filters.collectFirst { case GreaterThan("i", v: Int) => v } - val res = new ArrayList[InputPartition[InternalRow]] + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] if (lowerBound.isEmpty) { - res.add(new AdvancedInputPartition(0, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 4) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 5, requiredSchema)) - res.add(new AdvancedInputPartition(5, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 5)) + res.append(RangeInputPartition(5, 10)) } else if (lowerBound.get < 9) { - res.add(new AdvancedInputPartition(lowerBound.get + 1, 10, requiredSchema)) + res.append(RangeInputPartition(lowerBound.get + 1, 10)) } - res + res.toArray + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val requiredSchema = config.asInstanceOf[AdvancedScanConfigBuilder].requiredSchema + new AdvancedReaderFactory(requiredSchema) } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override 
def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class AdvancedInputPartition(start: Int, end: Int, requiredSchema: StructType) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { +class AdvancedScanConfigBuilder extends ScanConfigBuilder with ScanConfig + with SupportsPushDownRequiredColumns with SupportsPushDownFilters { - private var current = start - 1 + var requiredSchema = new StructType().add("i", "int").add("j", "int") + var filters = Array.empty[Filter] - override def createPartitionReader(): InputPartitionReader[InternalRow] = { - new AdvancedInputPartition(start, end, requiredSchema) + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema } - override def close(): Unit = {} + override def readSchema(): StructType = requiredSchema - override def next(): Boolean = { - current += 1 - current < end + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val (supported, unsupported) = filters.partition { + case GreaterThan("i", _: Int) => true + case _ => false + } + this.filters = supported + unsupported } - override def get(): InternalRow = { - val values = requiredSchema.map(_.name).map { - case "i" => current - case "j" => -current + override def pushedFilters(): Array[Filter] = filters + + override def build(): ScanConfig = this +} + +class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[InternalRow] { + private var current = start - 1 + + override def next(): Boolean = { + current += 1 + current < end + } + + override def get(): InternalRow = { + val values = requiredSchema.map(_.name).map { + case "i" => current + case "j" => -current + } + InternalRow.fromSeq(values) + } + + override def close(): Unit = {} } - 
InternalRow.fromSeq(values) } } -class SchemaRequiredDataSource extends DataSourceV2 with ReadSupport { +class SchemaRequiredDataSource extends DataSourceV2 with BatchReadSupportProvider { - class Reader(val readSchema: StructType) extends DataSourceReader { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = - java.util.Collections.emptyList() + class ReadSupport(val schema: StructType) extends SimpleReadSupport { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = + Array.empty } - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def createReader(schema: StructType, options: DataSourceOptions): DataSourceReader = { - new Reader(schema) + override def createBatchReadSupport( + schema: StructType, options: DataSourceOptions): BatchReadSupport = { + new ReadSupport(schema) } } -class BatchDataSourceV2 extends DataSourceV2 with ReadSupport { +class ColumnarDataSourceV2 extends DataSourceV2 with BatchReadSupportProvider { - class Reader extends DataSourceReader with SupportsScanColumnarBatch { - override def readSchema(): StructType = new StructType().add("i", "int").add("j", "int") + class ReadSupport extends SimpleReadSupport { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { + Array(RangeInputPartition(0, 50), RangeInputPartition(50, 90)) + } - override def planBatchInputPartitions(): JList[InputPartition[ColumnarBatch]] = { - java.util.Arrays.asList( - new BatchInputPartitionReader(0, 50), new BatchInputPartitionReader(50, 90)) + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + ColumnarReaderFactory } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def 
createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class BatchInputPartitionReader(start: Int, end: Int) - extends InputPartition[ColumnarBatch] with InputPartitionReader[ColumnarBatch] { - +object ColumnarReaderFactory extends PartitionReaderFactory { private final val BATCH_SIZE = 20 - private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) - private lazy val batch = new ColumnarBatch(Array(i, j)) - private var current = start + override def supportColumnarReads(partition: InputPartition): Boolean = true - override def createPartitionReader(): InputPartitionReader[ColumnarBatch] = this + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + throw new UnsupportedOperationException + } - override def next(): Boolean = { - i.reset() - j.reset() + override def createColumnarReader(partition: InputPartition): PartitionReader[ColumnarBatch] = { + val RangeInputPartition(start, end) = partition + new PartitionReader[ColumnarBatch] { + private lazy val i = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val j = new OnHeapColumnVector(BATCH_SIZE, IntegerType) + private lazy val batch = new ColumnarBatch(Array(i, j)) + + private var current = start + + override def next(): Boolean = { + i.reset() + j.reset() + + var count = 0 + while (current < end && count < BATCH_SIZE) { + i.putInt(count, current) + j.putInt(count, -current) + current += 1 + count += 1 + } - var count = 0 - while (current < end && count < BATCH_SIZE) { - i.putInt(count, current) - j.putInt(count, -current) - current += 1 - count += 1 - } + if (count == 0) { + false + } else { + batch.setNumRows(count) + true + } + } - if (count == 0) { - false - } else { - batch.setNumRows(count) - true - } - } + override def get(): ColumnarBatch = batch - override def get(): ColumnarBatch = { - batch + override def close(): Unit = 
batch.close() + } } - - override def close(): Unit = batch.close() } -class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { - class Reader extends DataSourceReader with SupportsReportPartitioning { - override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") +class PartitionAwareDataSource extends DataSourceV2 with BatchReadSupportProvider { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + class ReadSupport extends SimpleReadSupport with SupportsReportPartitioning { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { // Note that we don't have same value of column `a` across partitions. - java.util.Arrays.asList( - new SpecificInputPartitionReader(Array(1, 1, 3), Array(4, 4, 6)), - new SpecificInputPartitionReader(Array(2, 4, 4), Array(6, 2, 2))) + Array( + SpecificInputPartition(Array(1, 1, 3), Array(4, 4, 6)), + SpecificInputPartition(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + SpecificReaderFactory } - override def outputPartitioning(): Partitioning = new MyPartitioning + override def outputPartitioning(config: ScanConfig): Partitioning = new MyPartitioning } class MyPartitioning extends Partitioning { override def numPartitions(): Int = 2 override def satisfy(distribution: Distribution): Boolean = distribution match { - case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case c: ClusteredDistribution => c.clusteredColumns.contains("i") case _ => false } } - override def createReader(options: DataSourceOptions): DataSourceReader = new Reader + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { + new ReadSupport + } } -class SpecificInputPartitionReader(i: Array[Int], j: Array[Int]) - extends InputPartition[InternalRow] - with InputPartitionReader[InternalRow] { - assert(i.length == j.length) - - private var current = -1 +case 
class SpecificInputPartition(i: Array[Int], j: Array[Int]) extends InputPartition - override def createPartitionReader(): InputPartitionReader[InternalRow] = this +object SpecificReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val p = partition.asInstanceOf[SpecificInputPartition] + new PartitionReader[InternalRow] { + private var current = -1 - override def next(): Boolean = { - current += 1 - current < i.length - } + override def next(): Boolean = { + current += 1 + current < p.i.length + } - override def get(): InternalRow = InternalRow(i(current), j(current)) + override def get(): InternalRow = InternalRow(p.i(current), p.j(current)) - override def close(): Unit = {} + override def close(): Unit = {} + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index e1b8e9c..952241b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,34 +18,36 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util.{Collections, List => JList, Optional} +import java.util.Optional import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.sources.v2.reader.{DataSourceReader, InputPartition, InputPartitionReader} +import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration /** * A HDFS based transactional writable data source. - * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. - * Each job moves files from `target/_temporary/jobId/` to `target`. + * Each task writes data to `target/_temporary/queryId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/queryId/` to `target`. */ -class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { +class SimpleWritableDataSource extends DataSourceV2 + with BatchReadSupportProvider with BatchWriteSupportProvider { private val schema = new StructType().add("i", "long").add("j", "long") - class Reader(path: String, conf: Configuration) extends DataSourceReader { - override def readSchema(): StructType = schema + class ReadSupport(path: String, conf: Configuration) extends SimpleReadSupport { - override def planInputPartitions(): JList[InputPartition[InternalRow]] = { + override def fullSchema(): StructType = schema + + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { val dataPath = new Path(path) val fs = dataPath.getFileSystem(conf) if (fs.exists(dataPath)) { @@ -53,21 +55,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS val name = status.getPath.getName name.startsWith("_") || name.startsWith(".") }.map { f => - val serializableConf = new SerializableConfiguration(conf) - new SimpleCSVInputPartitionReader( - f.getPath.toUri.toString, - serializableConf): InputPartition[InternalRow] - }.toList.asJava + CSVInputPartitionReader(f.getPath.toUri.toString) + }.toArray } else { - Collections.emptyList() + Array.empty } } + + override 
def createReaderFactory(config: ScanConfig): PartitionReaderFactory = { + val serializableConf = new SerializableConfiguration(conf) + new CSVReaderFactory(serializableConf) + } } - class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { - override def createWriterFactory(): DataWriterFactory[InternalRow] = { + class WritSupport(queryId: String, path: String, conf: Configuration) extends BatchWriteSupport { + override def createBatchWriterFactory(): DataWriterFactory = { SimpleCounter.resetCounter - new CSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + new CSVDataWriterFactory(path, queryId, new SerializableConfiguration(conf)) } override def onDataWriterCommit(message: WriterCommitMessage): Unit = { @@ -76,7 +80,7 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS override def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) - val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val jobPath = new Path(new Path(finalPath, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) try { for (file <- fs.listStatus(jobPath).map(_.getPath)) { @@ -91,23 +95,23 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } override def abort(messages: Array[WriterCommitMessage]): Unit = { - val jobPath = new Path(new Path(path, "_temporary"), jobId) + val jobPath = new Path(new Path(path, "_temporary"), queryId) val fs = jobPath.getFileSystem(conf) fs.delete(jobPath, true) } } - override def createReader(options: DataSourceOptions): DataSourceReader = { + override def createBatchReadSupport(options: DataSourceOptions): BatchReadSupport = { val path = new Path(options.get("path").get()) val conf = SparkContext.getActive.get.hadoopConfiguration - new Reader(path.toUri.toString, conf) + new ReadSupport(path.toUri.toString, conf) } - override def createWriter( - jobId: String, + override def 
createBatchWriteSupport( + queryId: String, schema: StructType, mode: SaveMode, - options: DataSourceOptions): Optional[DataSourceWriter] = { + options: DataSourceOptions): Optional[BatchWriteSupport] = { assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) @@ -130,39 +134,42 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS } val pathStr = path.toUri.toString - Optional.of(new Writer(jobId, pathStr, conf)) + Optional.of(new WritSupport(queryId, pathStr, conf)) } } -class SimpleCSVInputPartitionReader(path: String, conf: SerializableConfiguration) - extends InputPartition[InternalRow] with InputPartitionReader[InternalRow] { +case class CSVInputPartitionReader(path: String) extends InputPartition - @transient private var lines: Iterator[String] = _ - @transient private var currentLine: String = _ - @transient private var inputStream: FSDataInputStream = _ +class CSVReaderFactory(conf: SerializableConfiguration) + extends PartitionReaderFactory { - override def createPartitionReader(): InputPartitionReader[InternalRow] = { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + val path = partition.asInstanceOf[CSVInputPartitionReader].path val filePath = new Path(path) val fs = filePath.getFileSystem(conf.value) - inputStream = fs.open(filePath) - lines = new BufferedReader(new InputStreamReader(inputStream)) - .lines().iterator().asScala - this - } - override def next(): Boolean = { - if (lines.hasNext) { - currentLine = lines.next() - true - } else { - false - } - } + new PartitionReader[InternalRow] { + private val inputStream = fs.open(filePath) + private val lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala - override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + private var currentLine: String = _ - 
override def close(): Unit = { - inputStream.close() + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): InternalRow = InternalRow(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } + } } } @@ -183,12 +190,11 @@ private[v2] object SimpleCounter { } class CSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) - extends DataWriterFactory[InternalRow] { + extends DataWriterFactory { - override def createDataWriter( + override def createWriter( partitionId: Int, - taskId: Long, - epochId: Long): DataWriter[InternalRow] = { + taskId: Long): DataWriter[InternalRow] = { val jobPath = new Path(new Path(path, "_temporary"), jobId) val filePath = new Path(jobPath, s"$jobId-$partitionId-$taskId") val fs = filePath.getFileSystem(conf.value) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index df22bc1..b528006 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -686,7 +686,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be plan .collect { case r: StreamingExecutionRelation => r.source - case r: StreamingDataSourceV2Relation => r.reader + case r: StreamingDataSourceV2Relation => r.readSupport } .zipWithIndex .find(_._1 == source) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala ---------------------------------------------------------------------- diff 
--git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index 0f15cd6..fe77a1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -299,9 +299,9 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { try { val input = new MemoryStream[Int](0, sqlContext) { @volatile var numTriggers = 0 - override def getEndOffset: OffsetV2 = { + override def latestOffset(): OffsetV2 = { numTriggers += 1 - super.getEndOffset + super.latestOffset() } } val clock = new StreamManualClock() http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 268ed58..7359252 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.streaming -import java.{util => ju} -import java.util.Optional import java.util.concurrent.CountDownLatch import scala.collection.mutable import org.apache.commons.lang3.RandomStringUtils import org.json4s.NoTypeHints -import org.json4s.jackson.JsonMethods._ import org.json4s.jackson.Serialization import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter @@ -35,13 +32,12 @@ import org.scalatest.mockito.MockitoSugar import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} 
-import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.InputPartition +import org.apache.spark.sql.sources.v2.reader.{InputPartition, ScanConfig} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider, StreamManualClock} import org.apache.spark.sql.types.StructType @@ -218,25 +214,17 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi private def dataAdded: Boolean = currentOffset.offset != -1 - // setOffsetRange should take 50 ms the first time it is called after data is added - override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = { - synchronized { - if (dataAdded) clock.waitTillTime(1050) - super.setOffsetRange(start, end) - } - } - - // getEndOffset should take 100 ms the first time it is called after data is added - override def getEndOffset(): OffsetV2 = synchronized { - if (dataAdded) clock.waitTillTime(1150) - super.getEndOffset() + // latestOffset should take 50 ms the first time it is called after data is added + override def latestOffset(): OffsetV2 = synchronized { + if (dataAdded) clock.waitTillTime(1050) + super.latestOffset() } // getBatch should take 100 ms the first time it is called - override def planInputPartitions(): ju.List[InputPartition[InternalRow]] = { + override def planInputPartitions(config: ScanConfig): Array[InputPartition] = { synchronized { - clock.waitTillTime(1350) - super.planInputPartitions() + clock.waitTillTime(1150) + super.planInputPartitions(config) } } } @@ -277,34 +265,26 @@ class StreamingQuerySuite extends StreamTest with 
BeforeAndAfter with Logging wi AssertOnQuery(_.status.message === "Waiting for next trigger"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - // Test status and progress when setOffsetRange is being called + // Test status and progress when `latestOffset` is being called AddData(inputData, 1, 2), - AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on setOffsetRange + AdvanceManualClock(1000), // time = 1000 to start new trigger, will block on `latestOffset` AssertStreamExecThreadIsWaitingForTime(1050), AssertOnQuery(_.status.isDataAvailable === false), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message.startsWith("Getting offsets from")), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(50), // time = 1050 to unblock setOffsetRange + AdvanceManualClock(50), // time = 1050 to unblock `latestOffset` AssertClockTime(1050), - AssertStreamExecThreadIsWaitingForTime(1150), // will block on getEndOffset that needs 1150 - AssertOnQuery(_.status.isDataAvailable === false), - AssertOnQuery(_.status.isTriggerActive === true), - AssertOnQuery(_.status.message.startsWith("Getting offsets from")), - AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - - AdvanceManualClock(100), // time = 1150 to unblock getEndOffset - AssertClockTime(1150), - // will block on planInputPartitions that needs 1350 - AssertStreamExecThreadIsWaitingForTime(1350), + // will block on `planInputPartitions` that needs 1150 + AssertStreamExecThreadIsWaitingForTime(1150), AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), AssertOnQuery(_.status.message === "Processing new data"), AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), - AdvanceManualClock(200), // time = 1350 to unblock planInputPartitions - AssertClockTime(1350), + AdvanceManualClock(100), // time = 1150 to unblock `planInputPartitions` +
AssertClockTime(1150), AssertStreamExecThreadIsWaitingForTime(1500), // will block on map task that needs 1500 AssertOnQuery(_.status.isDataAvailable === true), AssertOnQuery(_.status.isTriggerActive === true), @@ -312,7 +292,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.recentProgress.count(_.numInputRows > 0) === 0), // Test status and progress while batch processing has completed - AdvanceManualClock(150), // time = 1500 to unblock map task + AdvanceManualClock(350), // time = 1500 to unblock map task AssertClockTime(1500), CheckAnswer(2), AssertStreamExecThreadIsWaitingForTime(2000), // will block until the next trigger @@ -332,11 +312,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(progress.numInputRows === 2) assert(progress.processedRowsPerSecond === 4.0) - assert(progress.durationMs.get("setOffsetRange") === 50) - assert(progress.durationMs.get("getEndOffset") === 100) - assert(progress.durationMs.get("queryPlanning") === 200) + assert(progress.durationMs.get("latestOffset") === 50) + assert(progress.durationMs.get("queryPlanning") === 100) assert(progress.durationMs.get("walCommit") === 0) - assert(progress.durationMs.get("addBatch") === 150) + assert(progress.durationMs.get("addBatch") === 350) assert(progress.durationMs.get("triggerExecution") === 500) assert(progress.sources.length === 1) http://git-wip-us.apache.org/repos/asf/spark/blob/e7548871/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index 4f19881..d6819ea 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -22,16 +22,15 @@ import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue} import org.mockito.Mockito._ import org.scalatest.mockito.MockitoSugar -import org.apache.spark.{SparkEnv, SparkFunSuite, TaskContext} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ -import org.apache.spark.sql.sources.v2.reader.InputPartition -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousInputPartitionReader, ContinuousReader, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousReadSupport, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.types.{DataType, IntegerType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { case class LongPartitionOffset(offset: Long) extends PartitionOffset @@ -44,8 +43,8 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamWriter], - mock[ContinuousReader], + mock[StreamingWriteSupport], + mock[ContinuousReadSupport], mock[ContinuousExecution], coordinatorId, startEpoch, @@ -73,26 +72,26 @@ class ContinuousQueuedDataReaderSuite 
extends StreamTest with MockitoSugar { */ private def setup(): (BlockingQueue[UnsafeRow], ContinuousQueuedDataReader) = { val queue = new ArrayBlockingQueue[UnsafeRow](1024) - val factory = new InputPartition[InternalRow] { - override def createPartitionReader() = new ContinuousInputPartitionReader[InternalRow] { - var index = -1 - var curr: UnsafeRow = _ - - override def next() = { - curr = queue.take() - index += 1 - true - } + val partitionReader = new ContinuousPartitionReader[InternalRow] { + var index = -1 + var curr: UnsafeRow = _ + + override def next() = { + curr = queue.take() + index += 1 + true + } - override def get = curr + override def get = curr - override def getOffset = LongPartitionOffset(index) + override def getOffset = LongPartitionOffset(index) - override def close() = {} - } + override def close() = {} } val reader = new ContinuousQueuedDataReader( - new ContinuousDataSourceRDDPartition(0, factory), + 0, + partitionReader, + new StructType().add("i", "int"), mockContext, dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize, epochPollIntervalMs = sqlContext.conf.continuousStreamingExecutorPollIntervalMs) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org