Github user jose-torres commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20382#discussion_r164943885
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/socket.scala 
---
    @@ -47,130 +48,141 @@ object TextSocketSource {
      * This source will *not* work in production applications due to multiple 
reasons, including no
      * support for fault recovery and keeping all of the text read in memory 
forever.
      */
    -class TextSocketSource(host: String, port: Int, includeTimestamp: Boolean, 
sqlContext: SQLContext)
    -  extends Source with Logging {
    -
    -  @GuardedBy("this")
    -  private var socket: Socket = null
    -
    -  @GuardedBy("this")
    -  private var readThread: Thread = null
    -
    -  /**
    -   * All batches from `lastCommittedOffset + 1` to `currentOffset`, 
inclusive.
    -   * Stored in a ListBuffer to facilitate removing committed batches.
    -   */
    -  @GuardedBy("this")
    -  protected val batches = new ListBuffer[(String, Timestamp)]
    -
    -  @GuardedBy("this")
    -  protected var currentOffset: LongOffset = new LongOffset(-1)
    -
    -  @GuardedBy("this")
    -  protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
    +class TextSocketSource(
    +    protected val host: String,
    +    protected val port: Int,
    +    includeTimestamp: Boolean,
    +    sqlContext: SQLContext)
    +  extends Source with TextSocketReader with Logging {
     
       initialize()
     
    -  private def initialize(): Unit = synchronized {
    -    socket = new Socket(host, port)
    -    val reader = new BufferedReader(new 
InputStreamReader(socket.getInputStream))
    -    readThread = new Thread(s"TextSocketSource($host, $port)") {
    -      setDaemon(true)
    -
    -      override def run(): Unit = {
    -        try {
    -          while (true) {
    -            val line = reader.readLine()
    -            if (line == null) {
    -              // End of file reached
    -              logWarning(s"Stream closed by $host:$port")
    -              return
    -            }
    -            TextSocketSource.this.synchronized {
    -              val newData = (line,
    -                Timestamp.valueOf(
    -                  
TextSocketSource.DATE_FORMAT.format(Calendar.getInstance().getTime()))
    -                )
    -              currentOffset = currentOffset + 1
    -              batches.append(newData)
    -            }
    -          }
    -        } catch {
    -          case e: IOException =>
    -        }
    -      }
    -    }
    -    readThread.start()
    -  }
    -
       /** Returns the schema of the data from this source */
    -  override def schema: StructType = if (includeTimestamp) 
TextSocketSource.SCHEMA_TIMESTAMP
    -  else TextSocketSource.SCHEMA_REGULAR
    -
    -  override def getOffset: Option[Offset] = synchronized {
    -    if (currentOffset.offset == -1) {
    -      None
    -    } else {
    -      Some(currentOffset)
    -    }
    -  }
    +  override def schema: StructType =
    +    if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP else 
TextSocketSource.SCHEMA_REGULAR
    +
    +  override def getOffset: Option[Offset] = 
getOffsetInternal.map(LongOffset(_))
     
       /** Returns the data that is between the offsets (`start`, `end`]. */
    -  override def getBatch(start: Option[Offset], end: Offset): DataFrame = 
synchronized {
    -    val startOrdinal =
    -      
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
    -    val endOrdinal = 
LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1
    -
    -    // Internal buffer only holds the batches after lastOffsetCommitted
    -    val rawList = synchronized {
    -      val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
    -      val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
    -      batches.slice(sliceStart, sliceEnd)
    -    }
    +  override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
    +    val rawList = 
getBatchInternal(start.flatMap(LongOffset.convert).map(_.offset),
    +      LongOffset.convert(end).map(_.offset))
     
         val rdd = sqlContext.sparkContext
           .parallelize(rawList)
           .map { case (v, ts) => InternalRow(UTF8String.fromString(v), 
ts.getTime) }
         sqlContext.internalCreateDataFrame(rdd, schema, isStreaming = true)
       }
     
    -  override def commit(end: Offset): Unit = synchronized {
    +  override def commit(end: Offset): Unit = {
         val newOffset = LongOffset.convert(end).getOrElse(
           sys.error(s"TextSocketStream.commit() received an offset ($end) that 
did not " +
             s"originate with an instance of this class")
         )
     
    -    val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt
    +    commitInternal(newOffset.offset)
    +  }
     
    -    if (offsetDiff < 0) {
    -      sys.error(s"Offsets committed out of order: $lastOffsetCommitted 
followed by $end")
    -    }
    +  override def toString: String = s"TextSocketSource[host: $host, port: 
$port]"
    +}
    +
    +case class TextSocketOffset(offset: Long) extends V2Offset {
    +  override def json(): String = offset.toString
    +}
    +
    +class TextSocketMicroBatchReader(options: DataSourceV2Options)
    +  extends MicroBatchReader with TextSocketReader with Logging {
    +
    +  private var startOffset: TextSocketOffset = _
    +  private var endOffset: TextSocketOffset = _
    +
    +  protected val host: String = options.get("host").get()
    +  protected val port: Int = options.get("port").get().toInt
     
    -    batches.trimStart(offsetDiff)
    -    lastOffsetCommitted = newOffset
    +  initialize()
    +
    +  override def setOffsetRange(start: Optional[V2Offset], end: 
Optional[V2Offset]): Unit = {
    +    startOffset = 
start.orElse(TextSocketOffset(-1L)).asInstanceOf[TextSocketOffset]
    +    endOffset = end.orElse(
    +      
TextSocketOffset(getOffsetInternal.getOrElse(-1L))).asInstanceOf[TextSocketOffset]
       }
     
    -  /** Stop this source. */
    -  override def stop(): Unit = synchronized {
    -    if (socket != null) {
    -      try {
    -        // Unfortunately, BufferedReader.readLine() cannot be interrupted, 
so the only way to
    -        // stop the readThread is to close the socket.
    -        socket.close()
    -      } catch {
    -        case e: IOException =>
    -      }
    -      socket = null
    +  override def getStartOffset(): V2Offset = {
    +    Option(startOffset).getOrElse(throw new IllegalStateException("start 
offset not set"))
    +  }
    +
    +  override def getEndOffset(): V2Offset = {
    +    Option(endOffset).getOrElse(throw new IllegalStateException("end 
offset not set"))
    +  }
    +
    +  override def deserializeOffset(json: String): V2Offset = {
    +    TextSocketOffset(json.toLong)
    +  }
    +
    +  override def readSchema(): StructType = {
    +    val includeTimestamp = options.getBoolean("includeTimestamp", false)
    +    if (includeTimestamp) TextSocketSource.SCHEMA_TIMESTAMP else 
TextSocketSource.SCHEMA_REGULAR
    +  }
    +
    +  override def createReadTasks(): JList[ReadTask[Row]] = {
    +    val rawList = getBatchInternal(Option(startOffset.offset), 
Option(endOffset.offset))
    +
    +    assert(SparkSession.getActiveSession.isDefined)
    +    val spark = SparkSession.getActiveSession.get
    +    val numPartitions = spark.sparkContext.defaultParallelism
    +
    +    val slices = Array.fill(numPartitions)(new ListBuffer[(String, 
Timestamp)])
    +    rawList.zipWithIndex.foreach { case (r, idx) =>
    +      slices(idx % numPartitions).append(r)
         }
    +
    +    (0 until numPartitions).map { i =>
    +      val slice = slices(i)
    +      new ReadTask[Row] {
    +        override def createDataReader(): DataReader[Row] = new 
DataReader[Row] {
    +          private var currentIdx = -1
    +
    +          override def next(): Boolean = {
    +            currentIdx += 1
    +            currentIdx < slice.size
    +          }
    +
    +          override def get(): Row = {
    +            Row(slice(currentIdx)._1, slice(currentIdx)._2)
    +          }
    +
    +          override def close(): Unit = {}
    +        }
    +      }
    +    }.toList.asJava
       }
     
    -  override def toString: String = s"TextSocketSource[host: $host, port: 
$port]"
    +  override def commit(end: V2Offset): Unit = {
    +    val newOffset = end.asInstanceOf[TextSocketOffset]
    +    commitInternal(newOffset.offset)
    +  }
    +
    +  override def toString: String = s"TextSocketMicroBatchReader[host: 
$host, port: $port]"
     }
     
    -class TextSocketSourceProvider extends StreamSourceProvider with 
DataSourceRegister with Logging {
    -  private def parseIncludeTimestamp(params: Map[String, String]): Boolean 
= {
    -    Try(params.getOrElse("includeTimestamp", "false").toBoolean) match {
    -      case Success(bool) => bool
    +class TextSocketSourceProvider extends DataSourceV2
    +  with MicroBatchReadSupport with StreamSourceProvider with 
DataSourceRegister with Logging {
    --- End diff --
    
    TD and I discussed this offline. It should be fine to remove the V1 StreamSourceProvider implementation, because:
    
    * this isn't a production-quality source, so users shouldn't need to fall back to it
    * this source won't be particularly useful for exercising the V1 execution pipeline once we transition all sources to V2


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to