Github user tdas commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20552#discussion_r167080181
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
 ---
    @@ -17,52 +17,119 @@
     
     package org.apache.spark.sql.execution.streaming
     
    +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, 
ObjectInputStream, ObjectOutputStream}
    +
     import org.apache.spark.TaskContext
    -import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
    -import org.apache.spark.sql.catalyst.encoders.encoderFor
    +import org.apache.spark.sql._
    +import org.apache.spark.sql.catalyst.InternalRow
    +import org.apache.spark.sql.catalyst.encoders.{encoderFor, 
ExpressionEncoder}
    +import 
org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
    +import org.apache.spark.sql.sources.v2.{DataSourceOptions, 
StreamWriteSupport}
    +import org.apache.spark.sql.sources.v2.writer._
    +import org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter
    +import org.apache.spark.sql.streaming.OutputMode
    +import org.apache.spark.sql.types.StructType
     
    -/**
    - * A [[Sink]] that forwards all data into [[ForeachWriter]] according to 
the contract defined by
    - * [[ForeachWriter]].
    - *
    - * @param writer The [[ForeachWriter]] to process all data.
    - * @tparam T The expected type of the sink.
    - */
    -class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with 
Serializable {
    -
    -  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    -    // This logic should've been as simple as:
    -    // ```
    -    //   data.as[T].foreachPartition { iter => ... }
    -    // ```
    -    //
    -    // Unfortunately, doing that would just break the incremental planing. 
The reason is,
    -    // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, 
but `Dataset.rdd()` will
    -    // create a new plan. Because StreamExecution uses the existing plan 
to collect metrics and
    -    // update watermark, we should never create a new plan. Otherwise, 
metrics and watermark are
    -    // updated in the new plan, and StreamExecution cannot retrieval them.
    -    //
    -    // Hence, we need to manually convert internal rows to objects using 
encoder.
    +
    +case class ForeachWriterProvider[T: Encoder](writer: ForeachWriter[T]) 
extends StreamWriteSupport {
    +  override def createStreamWriter(
    +      queryId: String,
    +      schema: StructType,
    +      mode: OutputMode,
    +      options: DataSourceOptions): StreamWriter = {
         val encoder = encoderFor[T].resolveAndBind(
    -      data.logicalPlan.output,
    -      data.sparkSession.sessionState.analyzer)
    -    data.queryExecution.toRdd.foreachPartition { iter =>
    -      if (writer.open(TaskContext.getPartitionId(), batchId)) {
    -        try {
    -          while (iter.hasNext) {
    -            writer.process(encoder.fromRow(iter.next()))
    -          }
    -        } catch {
    -          case e: Throwable =>
    -            writer.close(e)
    -            throw e
    -        }
    -        writer.close(null)
    -      } else {
    -        writer.close(null)
    +      schema.toAttributes,
    +      SparkSession.getActiveSession.get.sessionState.analyzer)
    +    ForeachInternalWriter(writer, encoder)
    +  }
    +}
    +
    +case class ForeachInternalWriter[T: Encoder](
    +    writer: ForeachWriter[T], encoder: ExpressionEncoder[T])
    +    extends StreamWriter with SupportsWriteInternalRow {
    +  override def commit(epochId: Long, messages: 
Array[WriterCommitMessage]): Unit = {}
    +  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): 
Unit = {}
    +
    +  override def createInternalRowWriterFactory(): 
DataWriterFactory[InternalRow] = {
    +    ForeachWriterFactory(writer, encoder)
    +  }
    +}
    +
    +case class ForeachWriterFactory[T: Encoder](writer: ForeachWriter[T], 
encoder: ExpressionEncoder[T])
    +    extends DataWriterFactory[InternalRow] {
    +  override def createDataWriter(partitionId: Int, attemptNumber: Int): 
ForeachDataWriter[T] = {
    +    new ForeachDataWriter(writer, encoder, partitionId)
    +  }
    +}
    +
    +class ForeachDataWriter[T : Encoder](
    +    private var writer: ForeachWriter[T], encoder: ExpressionEncoder[T], 
partitionId: Int)
    +    extends DataWriter[InternalRow] {
    +  private val initialEpochId: Long = {
    +    // Start with the microbatch ID. If it's not there, we're in 
continuous execution,
    +    // so get the start epoch.
    +    // This ID will be incremented as commits happen.
    +    TaskContext.get().getLocalProperty(MicroBatchExecution.BATCH_ID_KEY) 
match {
    +      case null => 
TaskContext.get().getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong
    +      case batch => batch.toLong
    +    }
    +  }
    +  private var currentEpochId = initialEpochId
    +
    +  // The lifecycle of the ForeachWriter is incompatible with the lifecycle 
of DataSourceV2 writers.
    +  // Unfortunately, we cannot migrate ForeachWriter, as its 
implementations live in user code. So
    +  // we need a small state machine to shim between them.
    +  //  * CLOSED means close() has been called.
    +  //  * OPENED
    +  private object WriterState extends Enumeration {
    +    type WriterState = Value
    +    val CLOSED, OPENED, OPENED_SKIP_PROCESSING = Value
    +  }
    +  import WriterState._
    +
    +  private var state = CLOSED
    +
    +  private def openAndSetState(epochId: Long) = {
    +    // Create a new writer by roundtripping through the serialization for 
compatibility.
    +    // In the old API, a writer instantiation would never get reused.
    +    val byteStream = new ByteArrayOutputStream()
    --- End diff --
    
    Why are you serializing and deserializing here? If you are reserializing 
the ForeachWriter, doesn't this mean that you are going to retain state (of the 
non-transient fields) across them? Is that what you want?
    
    It seems the best thing to do is to serialize the writer at the driver, send 
the bytes to the task, and then deserialize repeatedly. Then you only incur the 
cost of deserializing between epochs, and you always start with a fresh copy of 
the ForeachWriter.



---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to