Github user jose-torres commented on a diff in the pull request: https://github.com/apache/spark/pull/20382#discussion_r163616762 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSourceV2.scala --- @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.sources + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.net.Socket +import java.sql.Timestamp +import java.text.SimpleDateFormat +import java.util._ +import java.util.{List => JList} +import java.util.concurrent.atomic.AtomicLong +import javax.annotation.concurrent.GuardedBy + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ListBuffer +import scala.util.{Failure, Success, Try} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.sources.DataSourceRegister +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options} +import org.apache.spark.sql.sources.v2.reader.{DataReader, ReadTask} +import org.apache.spark.sql.sources.v2.streaming.MicroBatchReadSupport +import org.apache.spark.sql.sources.v2.streaming.reader.{MicroBatchReader, Offset} +import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} + + +object TextSocketSourceProviderV2 { + val HOST = "host" + val PORT = "port" + val INCLUDE_TIMESTAMP = "includeTimestamp" + val NUM_PARTITIONS = "numPartitions" + val SCHEMA_REGULAR = StructType(StructField("value", StringType) :: Nil) + val SCHEMA_TIMESTAMP = StructType(StructField("value", StringType) :: + StructField("timestamp", TimestampType) :: Nil) + val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US) +} + +class TextSocketSourceProviderV2 extends DataSourceV2 + with MicroBatchReadSupport with DataSourceRegister with Logging { + override def shortName(): String = "socketv2" + + override def createMicroBatchReader( + schema: Optional[StructType], + checkpointLocation: String, + options: DataSourceV2Options): MicroBatchReader = { + logWarning("The socket source should not be used for production applications! 
" + + "It does not support recovery.") + if (!options.get(TextSocketSourceProviderV2.HOST).isPresent) { + throw new AnalysisException("Set a host to read from with option(\"host\", ...).") + } + if (!options.get(TextSocketSourceProviderV2.PORT).isPresent) { + throw new AnalysisException("Set a port to read from with option(\"port\", ...).") + } + if (schema.isPresent) { + throw new AnalysisException("The socket source does not support a user-specified schema.") + } + + if (options.get(TextSocketSourceProviderV2.INCLUDE_TIMESTAMP).isPresent) { + Try(options.get(TextSocketSourceProviderV2.INCLUDE_TIMESTAMP).get().toBoolean) match { + case Success(bool) => + case Failure(_) => + throw new AnalysisException( + "includeTimestamp must be set to either \"true\" or \"false\"") + } + } + + new TextSocketStreamMicroBatchReader(options) + } +} + +case class TextSocketStreamOffset(offset: Long) extends Offset { + override def json(): String = offset.toString +} + +class TextSocketStreamMicroBatchReader(options: DataSourceV2Options) + extends MicroBatchReader with Logging { + + import TextSocketSourceProviderV2._ + + private var start: TextSocketStreamOffset = _ + private var end: TextSocketStreamOffset = _ + + private val host = options.get(HOST).get() + private val port = options.get(PORT).get().toInt + private val includeTimestamp = options.getBoolean(INCLUDE_TIMESTAMP, false) + private val numPartitions = options.getInt(NUM_PARTITIONS, 1) + + @GuardedBy("this") + private var socket: Socket = _ + + @GuardedBy("this") + private var readThread: Thread = _ + + @GuardedBy("this") + private val batches = new ListBuffer[(String, Timestamp)] + + private val currentOffset = new AtomicLong(-1L) + + private var initialized = false + + @GuardedBy("this") + private var lastOffsetCommitted: Long = -1L + + override def setOffsetRange(start: Optional[Offset], end: Optional[Offset]): Unit = { + if (!initialized) { + initialize() + initialized = true + } + + this.start = 
start.orElse(TextSocketStreamOffset(-1L)).asInstanceOf[TextSocketStreamOffset] + this.end = + end.orElse(TextSocketStreamOffset(currentOffset.get())).asInstanceOf[TextSocketStreamOffset] + } + + override def getStartOffset(): Offset = { + Option(start).getOrElse(throw new IllegalStateException("start offset not set")) + } + + override def getEndOffset(): Offset = { + Option(end).getOrElse(throw new IllegalStateException("end offset not set")) + } + + override def deserializeOffset(json: String): Offset = { + TextSocketStreamOffset(json.toLong) + } + + override def readSchema(): StructType = { + if (includeTimestamp) { + SCHEMA_TIMESTAMP + } else { + SCHEMA_REGULAR + } + } + + override def createReadTasks(): JList[ReadTask[Row]] = { + val startOrdinal = start.offset.toInt + 1 + val endOrdinal = end.offset.toInt + 1 + val sliceStart = startOrdinal - lastOffsetCommitted.toInt - 1 + val sliceEnd = endOrdinal - lastOffsetCommitted.toInt - 1 + + val rawList = TextSocketStreamMicroBatchReader.this.synchronized { + batches.slice(sliceStart, sliceEnd) + } + + val slices = Array.fill(numPartitions)(new ListBuffer[(String, Timestamp)]) + rawList.zipWithIndex.foreach { case (r, idx) => + slices(idx % numPartitions).append(r) + } + + (0 until numPartitions).map { i => + val slice = slices(i) + new ReadTask[Row] { + override def createDataReader(): DataReader[Row] = new DataReader[Row] { + private var currentIdx = -1 + + override def next(): Boolean = { + currentIdx += 1 + currentIdx < slice.size + } + + override def get(): Row = { + Row(slice(currentIdx)._1, slice(currentIdx)._2) + } + + override def close(): Unit = {} + } + } + }.toList.asJava + } + + override def commit(end: Offset): Unit = synchronized { + val newOffset = end.asInstanceOf[TextSocketStreamOffset] + val offsetDiff = (newOffset.offset - lastOffsetCommitted).toInt --- End diff -- nit: the `.toInt` conversion on this line is unnecessary.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org