Github user xuanyuanking commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21199#discussion_r186765402
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala ---
    @@ -0,0 +1,304 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.streaming.continuous
    +
    +import java.io.{BufferedReader, InputStreamReader, IOException}
    +import java.net.Socket
    +import java.sql.Timestamp
    +import java.text.SimpleDateFormat
    +import java.util.{Calendar, List => JList, Locale}
    +import javax.annotation.concurrent.GuardedBy
    +
    +import scala.collection.JavaConverters._
    +import scala.collection.mutable.ListBuffer
    +
    +import org.json4s.{DefaultFormats, NoTypeHints}
    +import org.json4s.jackson.Serialization
    +
    +import org.apache.spark.SparkEnv
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, 
ThreadSafeRpcEndpoint}
    +import org.apache.spark.sql._
    +import 
org.apache.spark.sql.execution.streaming.continuous.TextSocketContinuousReader.GetRecord
    +import org.apache.spark.sql.sources.v2.DataSourceOptions
    +import org.apache.spark.sql.sources.v2.reader.{DataReader, 
DataReaderFactory}
    +import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, 
ContinuousReader, Offset, PartitionOffset}
    +import org.apache.spark.sql.types.{StringType, StructField, StructType, 
TimestampType}
    +import org.apache.spark.util.RpcUtils
    +
    +
    +object TextSocketContinuousReader {
    +  /** Row schema when timestamps are not requested: a single string column `value`. */
    +  val SCHEMA_REGULAR = StructType(Seq(StructField("value", StringType)))
    +
    +  /** Row schema when timestamps are requested: `value` followed by `timestamp`. */
    +  val SCHEMA_TIMESTAMP = StructType(Seq(
    +    StructField("value", StringType),
    +    StructField("timestamp", TimestampType)))
    +
    +  // NOTE(review): SimpleDateFormat is not thread-safe. Any caller that shares this single
    +  // instance across threads must synchronize on it or use a per-thread copy.
    +  val DATE_FORMAT = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss", Locale.US)
    +
    +  /** RPC message for the driver-side record endpoint, carrying a partition offset. */
    +  case class GetRecord(offset: TextSocketPartitionOffset)
    +}
    +
    +/**
    + * A ContinuousReader that reads text lines through a TCP socket, designed 
only for tutorials and
    + * debugging. This ContinuousReader will *not* work in production 
applications due to multiple
    + * reasons, including no support for fault recovery.
    + *
    + * The driver maintains a socket connection to the host-port, keeps the 
received messages in
    + * buckets and serves the messages to the executors via a RPC endpoint.
    + */
    +class TextSocketContinuousReader(options: DataSourceOptions) extends 
ContinuousReader with Logging {
    +  // json4s formats; picked up implicitly by Serialization.read in deserializeOffset.
    +  implicit val defaultFormats: DefaultFormats = DefaultFormats
    +
    +  // Required options. Optional.get() fails if an option is missing, and .toInt fails on a
    +  // non-numeric port.
    +  private val host: String = options.get("host").get()
    +  private val port: Int = options.get("port").get().toInt
    +
    +  // The reader must be constructed while a SparkSession is active; the partition count is
    +  // fixed to that session's default parallelism.
    +  assert(SparkSession.getActiveSession.isDefined)
    +  private val spark = SparkSession.getActiveSession.get
    +  private val numPartitions = spark.sparkContext.defaultParallelism
    +
    +  // Driver-side socket connection to host:port.
    +  // NOTE(review): presumably opened by initialize(), which is not shown here; confirm.
    +  @GuardedBy("this")
    +  private var socket: Socket = _
    +
    +  // Thread that reads from `socket` (presumably started by initialize(); confirm).
    +  @GuardedBy("this")
    +  private var readThread: Thread = _
    +
    +  // One buffer of (line, timestamp) pairs per partition. commit() trims each buffer from
    +  // the front by (end offset - start offset) for that partition.
    +  @GuardedBy("this")
    +  private val buckets = Seq.fill(numPartitions)(new ListBuffer[(String, 
Timestamp)])
    +
    +  // Starts at -1. NOTE(review): presumably the index of the last record received. It is
    +  // not updated anywhere in the visible code, so confirm against the read thread.
    +  @GuardedBy("this")
    +  private var currentOffset: Int = -1
    +
    +  // Offsets the current batch starts from. Written by setStartOffset (unsynchronized) and
    +  // advanced by commit (synchronized), so it is not consistently guarded by `this`.
    +  private var startOffset: TextSocketOffset = _
    +
    +  // RPC endpoint through which executors fetch buffered records. endpointRef is replaced
    +  // on every createDataReaderFactories() call.
    +  private val recordEndpoint = new RecordEndpoint()
    +  @volatile private var endpointRef: RpcEndpointRef = _
    +
    +  // Runs as the last statement of the constructor, so every field above is already
    +  // initialized when it executes. Keep it after all field declarations.
    +  initialize()
    +
    +  /**
    +   * Folds the per-partition offsets reported by the readers into one global offset.
    +   *
    +   * Every element is expected to be a [[TextSocketPartitionOffset]]. The resulting list is
    +   * ordered by ascending partition id.
    +   */
    +  override def mergeOffsets(offsets: Array[PartitionOffset]): Offset = {
    +    assert(offsets.length == numPartitions)
    +    val partitionOffsets = offsets.toList.map(_.asInstanceOf[TextSocketPartitionOffset])
    +    val orderedOffsets = partitionOffsets.sortBy(_.partitionId).map(_.offset)
    +    TextSocketOffset(orderedOffsets)
    +  }
    +
    +  /** Rebuilds a [[TextSocketOffset]] from JSON that encodes a list of per-partition ints. */
    +  override def deserializeOffset(json: String): Offset = {
    +    val partitionOffsets: List[Int] = Serialization.read[List[Int]](json)
    +    TextSocketOffset(partitionOffsets)
    +  }
    +
    +  /**
    +   * Records the offset to start reading from.
    +   *
    +   * When `offset` is empty, every partition starts at 0. A non-empty offset must be a
    +   * [[TextSocketOffset]].
    +   */
    +  override def setStartOffset(offset: java.util.Optional[Offset]): Unit = {
    +    val chosen: Offset =
    +      if (offset.isPresent) offset.get() else TextSocketOffset(List.fill(numPartitions)(0))
    +    startOffset = chosen.asInstanceOf[TextSocketOffset]
    +  }
    +
    +  /** Offset last set by `setStartOffset` or advanced by `commit`; null if never set. */
    +  override def getStartOffset: Offset = this.startOffset
    +
    +  /** Row schema; includes the timestamp column only when `includeTimestamp` is set. */
    +  override def readSchema(): StructType = {
    +    import TextSocketContinuousReader.{SCHEMA_REGULAR, SCHEMA_TIMESTAMP}
    +    if (!includeTimestamp) SCHEMA_REGULAR else SCHEMA_TIMESTAMP
    +  }
    +
    +  /**
    +   * Plans one reader factory per partition, each starting at that partition's start offset.
    +   *
    +   * After validation, registers a fresh RPC endpoint under a random name. The factories use
    +   * that name to fetch records from the driver-side buckets.
    +   *
    +   * @throws IllegalArgumentException if no start offset has been set, or if the number of
    +   *                                  stored partition offsets differs from `numPartitions`.
    +   */
    +  override def createDataReaderFactories(): JList[DataReaderFactory[Row]] = {
    +    // `startOffset` is statically a TextSocketOffset, so the only value it can hold that is
    +    // not one is null (setStartOffset never called). Calling getClass on it would throw an
    +    // NPE instead of the intended IllegalArgumentException.
    +    val offsets = Option(startOffset).map(_.offsets).getOrElse {
    +      throw new IllegalArgumentException(
    +        "start offset has not been set for TextSocketContinuousReader")
    +    }
    +
    +    // NOTE(review): numPartitions comes from defaultParallelism, not a reader option, so
    +    // the wording of this message may mislead users; confirm.
    +    if (offsets.size != numPartitions) {
    +      throw new IllegalArgumentException(
    +        s"The previous run contained ${offsets.size} partitions, but" +
    +          s" $numPartitions partitions are currently configured. The numPartitions option" +
    +          " cannot be changed.")
    +    }
    +
    +    // Register only after validation, so a rejected plan does not leave an orphaned
    +    // endpoint in the RpcEnv.
    +    // NOTE(review): the endpoint registered by a previous call is not stopped here.
    +    val endpointName = s"TextSocketContinuousReaderEndpoint-${java.util.UUID.randomUUID()}"
    +    endpointRef = recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
    +
    +    offsets.zipWithIndex.map {
    +      case (offset, i) =>
    +        TextSocketContinuousDataReaderFactory(
    +          endpointName, i, offset): DataReaderFactory[Row]
    +    }.asJava
    +  }
    +
    +  override def commit(end: Offset): Unit = synchronized {
    +    val endOffset = end.asInstanceOf[TextSocketOffset]
    +    endOffset.offsets.zipWithIndex.foreach {
    +      case (offset, partition) =>
    +        buckets(partition).trimStart(offset - 
startOffset.offsets(partition))
    +    }
    +    startOffset = endOffset
    --- End diff --
    
    If I understand correctly, this `commit` method is never invoked by your added test case, so the bucket-trimming logic here is left untested.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to