haiyangsun-db commented on code in PR #55657: URL: https://github.com/apache/spark/pull/55657#discussion_r3236829670
########## udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/EchoProtocolSuite.scala: ########## @@ -0,0 +1,716 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean + +import com.google.protobuf.ByteString + +// NOTE: These imports require adding grpc-stub + grpc-inprocess to the pom. +// The proto module currently only runs protoc (no grpc-java plugin), so +// UdfWorkerGrpc does not exist yet. Adding it requires: +// 1. grpc-java codegen plugin in udf/worker/proto/pom.xml +// 2. 
grpc-stub + grpc-netty (or grpc-inprocess) deps in udf/worker/core/pom.xml +import io.grpc.stub.StreamObserver +import io.grpc.{ManagedChannel, Server, Status} +import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder} +import org.apache.spark.udf.worker.UdfWorkerGrpc + +import org.apache.spark.udf.worker.{ + Cancel, CancelResponse, DataRequest, DataResponse, + ExecutionError, UserError, WorkerError, ProtocolError, + Finish, FinishResponse, Heartbeat, HeartbeatResponse, + Init, InitResponse, ShutdownRequest, ShutdownResponse, + UDFWorkerDataFormat, UdfControlRequest, UdfControlResponse, + UdfPayload, UdfRequest, UdfResponse, WorkerRequest, WorkerResponse +} + +// scalastyle:off funsuite +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.BeforeAndAfterEach + +/** + * Validates the UDF gRPC protocol by implementing a dummy echo worker and + * an engine client. The worker echoes each DataRequest batch back as a + * DataResponse. Error paths (ExecutionError, Cancel) are exercised with fake + * triggers. + * + * FINDINGS -- things that were unclear or missing from the proto when + * writing this implementation: + * See the FINDING comments throughout the file. + */ +class EchoProtocolSuite extends AnyFunSuite with BeforeAndAfterEach { +// scalastyle:on funsuite + + private val SUPPORTED_VERSION: Int = 1 + // Trigger word: a DataRequest whose payload equals this string causes the + // worker to emit ExecutionError + FinishResponse instead of echoing. 
private val ERROR_TRIGGER: ByteString = ByteString.copyFromUtf8("ERROR")

// Per-test gRPC fixture: an in-process server hosting EchoWorkerService plus a
// channel and async stub connected to it. Created in beforeEach and released
// in afterEach so every test talks to a fresh server instance.
private var server: Server = _
private var channel: ManagedChannel = _
private var stub: UdfWorkerGrpc.UdfWorkerStub = _

// Upper bound on how long afterEach waits for the channel and the server to
// finish terminating after shutdownNow().
private val SHUTDOWN_TIMEOUT_SECONDS: Long = 5L

override def beforeEach(): Unit = {
  // Call super first so this fixture composes with other stackable
  // BeforeAndAfterEach-style traits mixed into the suite.
  super.beforeEach()
  // generateName() yields a fresh in-process server name, so servers created
  // by different tests never share a name.
  val serverName = InProcessServerBuilder.generateName()
  // Both ends use directExecutor(), i.e. callbacks run inline on the calling
  // thread instead of on a separate executor.
  server = InProcessServerBuilder.forName(serverName)
    .directExecutor()
    .addService(new EchoWorkerService)
    .build()
    .start()
  channel = InProcessChannelBuilder.forName(serverName).directExecutor().build()
  stub = UdfWorkerGrpc.newStub(channel)
}

override def afterEach(): Unit = {
  try {
    // Tear down the client side first, then the server. Waiting for
    // termination keeps one test's in-flight calls from overlapping the next
    // test's fixture. The null checks are defensive: the vars start out null.
    if (channel != null) {
      channel.shutdownNow()
      channel.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)
    }
    if (server != null) {
      server.shutdownNow()
      server.awaitTermination(SHUTDOWN_TIMEOUT_SECONDS, TimeUnit.SECONDS)
    }
  } finally {
    // Drop references so a stale stub cannot leak into the next test, and
    // always run the parent hook even if shutdown was interrupted.
    stub = null
    channel = null
    server = null
    super.afterEach()
  }
}

// ===========================================================================
// WORKER SIDE (gRPC server)
// ===========================================================================

/** Worker state machine for one Execute stream. */
private sealed trait WorkerState
// Initial state: nothing received on the stream yet; only Init is accepted.
private case object AwaitingInit extends WorkerState
// FINDING 1: The proto does not state explicitly what the worker does
// between Init and the first data message when no chunks are expected.
// The chunking section says "when the engine uses PayloadChunk at all,
// it MUST set last=true." So the worker enters Chunking only after it
// actually receives a PayloadChunk. Until then it stays in AwaitingData,
// ready to send InitResponse immediately when the first DataRequest or
// Finish arrives. The proto should state this explicitly:
// "If no PayloadChunk arrives, InitResponse MUST be sent before the
// first DataRequest or Finish is processed."
private case class AwaitingData(initPayload: ByteString) extends WorkerState
// A PayloadChunk without last=true arrived; holds the payload bytes
// concatenated so far.
private case class Chunking(accumulated: ByteString) extends WorkerState
// Payload complete and InitResponse sent; data messages are processed.
private case object Data extends WorkerState
// Finish received; FinishResponse not yet sent. Cancel may still win if it
// arrives before the drain completes and FinishResponse is written.
+ private case object Finishing extends WorkerState + private case object Done extends WorkerState + + private class EchoWorkerService extends UdfWorkerGrpc.UdfWorkerImplBase { + + override def execute( + responseObserver: StreamObserver[UdfResponse]): StreamObserver[UdfRequest] = + new ExecuteStreamHandler(responseObserver) + + override def manage( + request: WorkerRequest, + responseObserver: StreamObserver[WorkerResponse]): Unit = { + request.getManage match { + case WorkerRequest.Manage.Heartbeat(_) => + responseObserver.onNext(WorkerResponse.newBuilder() + .setHeartbeat(HeartbeatResponse.getDefaultInstance) + .build()) + responseObserver.onCompleted() + + case WorkerRequest.Manage.Shutdown(req) => + // FINDING 2: The proto says the worker SHOULD exit after all Execute + // streams terminate. But ShutdownResponse gives no way to indicate + // "acknowledged, draining" vs "acknowledged, already idle." A boolean + // or enum field on ShutdownResponse would make the worker's state + // visible to the engine without requiring a separate Heartbeat probe. + responseObserver.onNext(WorkerResponse.newBuilder() + .setShutdown(ShutdownResponse.getDefaultInstance) + .build()) + responseObserver.onCompleted() + + case _ => + responseObserver.onError( + Status.INVALID_ARGUMENT.withDescription("empty manage request") + .asRuntimeException()) + } + } + } + + private class ExecuteStreamHandler( + responseObserver: StreamObserver[UdfResponse]) extends StreamObserver[UdfRequest] { + + @volatile private var state: WorkerState = AwaitingInit + // Guards responseObserver: gRPC does not allow concurrent onNext calls. + // FINDING 3: The proto says DataRequest and DataResponse are "independent + // streams" and the worker may emit DataResponse at any time. In practice, + // if the worker dispatches processing to a thread pool, multiple threads + // could race to call responseObserver.onNext(). The proto does not mention + // this constraint. 
Worker implementations must serialize all writes to the + // response observer themselves. + private val responseLock = new Object + + override def onNext(request: UdfRequest): Unit = { + request.getRequest match { + case UdfRequest.Request.Control(ctrl) => handleControl(ctrl) + case UdfRequest.Request.Data(data) => handleDataRequest(data) + case _ => closeWithProtocolError("empty request oneof") + } + } + + private def handleControl(ctrl: UdfControlRequest): Unit = { + ctrl.getControl match { + case UdfControlRequest.Control.Init(init) => handleInit(init) + case UdfControlRequest.Control.Payload(chunk) => handleChunk(chunk) + case UdfControlRequest.Control.Finish(_) => handleFinish() + case UdfControlRequest.Control.Cancel(cancel) => handleCancel(cancel) + case _ => closeWithProtocolError("empty control oneof") + } + } + + private def handleInit(init: Init): Unit = state match { + case AwaitingInit => + // FINDING 4 (resolved): unsupported protocol_version is now surfaced + // via ExecutionError(ProtocolError). The worker sends ProtocolError + // + FinishResponse, keeping the stream lifecycle intact. 
+ if (init.hasProtocolVersion && init.getProtocolVersion != SUPPORTED_VERSION) { + sendControl(UdfControlResponse.newBuilder() + .setError(ExecutionError.newBuilder() + .setProtocol(ProtocolError.newBuilder() + .setMessage(s"unsupported protocol version: ${init.getProtocolVersion}") + .build()) + .build()) + .build()) + sendControl(UdfControlResponse.newBuilder() + .setFinish(FinishResponse.getDefaultInstance) + .build()) + responseLock.synchronized { responseObserver.onCompleted() } + state = Done + return + } + val inlinePayload = init.getUdf.getPayload + state = AwaitingData(inlinePayload) + + case _ => closeWithProtocolError(s"Init received in state $state") + } + + private def handleChunk(chunk: org.apache.spark.udf.worker.PayloadChunk): Unit = + state match { + case AwaitingData(existing) => + val updated = existing.concat(chunk.getData) + if (chunk.hasLast && chunk.getLast) { + // Payload is complete. Send InitResponse and move to Data. + sendInitResponse() + state = Data + } else { + state = Chunking(updated) + } + + case Chunking(existing) => + val updated = existing.concat(chunk.getData) Review Comment: yeah, this is a test-only dummy server as we would not like to do any serious serialization / deserialization. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
