This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 3463aa3fde5 [prism] Fail jobs on SDK disconnect. (#28193)
3463aa3fde5 is described below

commit 3463aa3fde5d1c6a3f735ed0951046829142eb41
Author: Robert Burke <lostl...@users.noreply.github.com>
AuthorDate: Fri Sep 1 13:39:48 2023 -0700

    [prism] Fail jobs on SDK disconnect. (#28193)
    
    * [prism] Fail jobs on SDK disconnect.
    
    * Reduce flaky short name for passert test.
    
    * [prism] better workerID, warn on pre-bundle fail, buffer done chan
    
    * Add causes, extract bundle failures to RunPipeline
    
    * Return bundle errors through execPipeline.
    
    ---------
    
    Co-authored-by: lostluck <13907733+lostl...@users.noreply.github.com>
---
 .../prism/internal/engine/elementmanager.go        | 17 ++++-
 sdks/go/pkg/beam/runners/prism/internal/execute.go | 43 +++++++----
 .../beam/runners/prism/internal/jobservices/job.go | 12 ++-
 .../prism/internal/jobservices/management.go       |  2 +-
 sdks/go/pkg/beam/runners/prism/internal/stage.go   | 31 +++-----
 .../beam/runners/prism/internal/worker/bundle.go   | 24 +++---
 .../runners/prism/internal/worker/bundle_test.go   |  2 +-
 .../beam/runners/prism/internal/worker/worker.go   | 85 +++++++++++++++-------
 .../runners/prism/internal/worker/worker_test.go   | 10 +--
 sdks/go/pkg/beam/testing/passert/equals_test.go    |  2 +-
 10 files changed, 142 insertions(+), 86 deletions(-)
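
A pattern that recurs in several hunks below: the commit swaps context.WithCancel for context.WithCancelCause, so cancellations carry a reason that callers recover with context.Cause instead of the generic ctx.Err(). A minimal stand-alone sketch of that standard-library behaviour (not Beam code; requires Go 1.20+):

    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    func main() {
        // Record *why* the context was cancelled, not just that it was.
        ctx, cancel := context.WithCancelCause(context.Background())
        cancel(errors.New("SDK disconnected"))
        <-ctx.Done()

        fmt.Println(ctx.Err())          // context canceled
        fmt.Println(context.Cause(ctx)) // SDK disconnected
    }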

diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index fb9c9802502..df53bce8ac5 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -209,11 +209,11 @@ func (rb RunBundle) LogValue() slog.Value {
 // remaining.
 func (em *ElementManager) Bundles(ctx context.Context, nextBundID func() 
string) <-chan RunBundle {
        runStageCh := make(chan RunBundle)
-       ctx, cancelFn := context.WithCancel(ctx)
+       ctx, cancelFn := context.WithCancelCause(ctx)
        go func() {
                em.pendingElements.Wait()
-               slog.Info("no more pending elements: terminating pipeline")
-               cancelFn()
+               slog.Debug("no more pending elements: terminating pipeline")
+               cancelFn(fmt.Errorf("elementManager out of elements, cleaning 
up"))
                // Ensure the watermark evaluation goroutine exits.
                em.refreshCond.Broadcast()
        }()
@@ -394,6 +394,17 @@ func (em *ElementManager) PersistBundle(rb RunBundle, 
col2Coders map[string]PCol
        em.addRefreshAndClearBundle(stage.ID, rb.BundleID)
 }
 
+// FailBundle clears the extant data allowing the execution to shut down.
+func (em *ElementManager) FailBundle(rb RunBundle) {
+       stage := em.stages[rb.StageID]
+       stage.mu.Lock()
+       completed := stage.inprogress[rb.BundleID]
+       em.pendingElements.Add(-len(completed.es))
+       delete(stage.inprogress, rb.BundleID)
+       stage.mu.Unlock()
+       em.addRefreshAndClearBundle(rb.StageID, rb.BundleID)
+}
+
 // ReturnResiduals is called after a successful split, so the remaining work
 // can be re-assigned to a new bundle.
 func (em *ElementManager) ReturnResiduals(rb RunBundle, firstRsIndex int, 
inputInfo PColInfo, residuals [][]byte) {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 42327a0209d..b2f9d866603 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -54,7 +54,7 @@ func RunPipeline(j *jobservices.Job) {
                return
        }
        env, _ := getOnlyPair(envs)
-       wk := worker.New(env) // Cheating by having the worker id match the 
environment id.
+       wk := worker.New(j.String()+"_"+env, env) // Cheating by having the 
worker id match the environment id.
        go wk.Serve()
        timeout := time.Minute
        time.AfterFunc(timeout, func() {
@@ -69,7 +69,7 @@ func RunPipeline(j *jobservices.Job) {
        // When this function exits, we cancel the context to clear
        // any related job resources.
        defer func() {
-               j.CancelFn(nil)
+               j.CancelFn(fmt.Errorf("runPipeline returned, cleaning up"))
        }()
        go runEnvironment(j.RootCtx, j, env, wk)
 
@@ -102,10 +102,10 @@ func runEnvironment(ctx context.Context, j 
*jobservices.Job, env string, wk *wor
        case urns.EnvExternal:
                ep := &pipepb.ExternalPayload{}
                if err := (proto.UnmarshalOptions{}).Unmarshal(e.GetPayload(), 
ep); err != nil {
-                       slog.Error("unmarshing environment payload", err, 
slog.String("envID", wk.ID))
+                       slog.Error("unmarshing environment payload", err, 
slog.String("envID", wk.Env))
                }
                externalEnvironment(ctx, ep, wk)
-               slog.Info("environment stopped", slog.String("envID", 
wk.String()), slog.String("job", j.String()))
+               slog.Debug("environment stopped", slog.String("envID", 
wk.String()), slog.String("job", j.String()))
        default:
                panic(fmt.Sprintf("environment %v with urn %v unimplemented", 
env, e.GetUrn()))
        }
@@ -271,7 +271,7 @@ func executePipeline(ctx context.Context, wk *worker.W, j 
*jobservices.Job) erro
                        }
                        stages[stage.ID] = stage
                        wk.Descriptors[stage.ID] = stage.desc
-               case wk.ID:
+               case wk.Env:
                        // Great! this is for this environment. // Broken 
abstraction.
                        if err := buildDescriptor(stage, comps, wk); err != nil 
{
                                return fmt.Errorf("prism error building stage 
%v: \n%w", stage.ID, err)
@@ -296,16 +296,31 @@ func executePipeline(ctx context.Context, wk *worker.W, j 
*jobservices.Job) erro
        // Use a channel to limit max parallelism for the pipeline.
        maxParallelism := make(chan struct{}, 8)
        // Execute stages here
-       for rb := range em.Bundles(ctx, wk.NextInst) {
-               maxParallelism <- struct{}{}
-               go func(rb engine.RunBundle) {
-                       defer func() { <-maxParallelism }()
-                       s := stages[rb.StageID]
-                       s.Execute(ctx, j, wk, comps, em, rb)
-               }(rb)
+       bundleFailed := make(chan error)
+       bundles := em.Bundles(ctx, wk.NextInst)
+       for {
+               select {
+               case <-ctx.Done():
+                       return context.Cause(ctx)
+               case rb, ok := <-bundles:
+                       if !ok {
+                               slog.Debug("pipeline done!", slog.String("job", 
j.String()))
+                               return nil
+                       }
+                       maxParallelism <- struct{}{}
+                       go func(rb engine.RunBundle) {
+                               defer func() { <-maxParallelism }()
+                               s := stages[rb.StageID]
+                               if err := s.Execute(ctx, j, wk, comps, em, rb); 
err != nil {
+                                       // Ensure we clean up on bundle failure
+                                       em.FailBundle(rb)
+                                       bundleFailed <- err
+                               }
+                       }(rb)
+               case err := <-bundleFailed:
+                       return err
+               }
        }
-       slog.Info("pipeline done!", slog.String("job", j.String()))
-       return nil
 }
 
 func collectionPullDecoder(coldCId string, coders map[string]*pipepb.Coder, 
comps *pipepb.Components) func(io.Reader) []byte {
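
The executePipeline rewrite above replaces the simple range over bundles with a select loop, so a failed bundle or a cancelled context ends the run with a real error instead of silently completing. A rough stand-alone sketch of that loop shape under the same bounded-parallelism scheme (runAll and the string bundle IDs are illustrative, not Beam APIs):

    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    // runAll drains bundle IDs with at most 8 in flight, returning the first
    // bundle error, the context's cancellation cause, or nil once the producer
    // closes the channel.
    func runAll(ctx context.Context, bundles <-chan string, run func(string) error) error {
        sem := make(chan struct{}, 8)
        failed := make(chan error, 1) // buffered: a failing worker never blocks
        for {
            select {
            case <-ctx.Done():
                return context.Cause(ctx)
            case id, ok := <-bundles:
                if !ok {
                    return nil // channel closed: everything scheduled, nothing failed yet
                }
                sem <- struct{}{}
                go func(id string) {
                    defer func() { <-sem }()
                    if err := run(id); err != nil {
                        select {
                        case failed <- err:
                        default: // another failure already reported
                        }
                    }
                }(id)
            case err := <-failed:
                return err
            }
        }
    }

    func main() {
        bundles := make(chan string, 2)
        bundles <- "b1"
        bundles <- "b2" // left open so the error path is what terminates the loop
        err := runAll(context.Background(), bundles, func(id string) error {
            if id == "b2" {
                return errors.New("bundle b2: SDK disconnected")
            }
            return nil
        })
        fmt.Println(err) // bundle b2: SDK disconnected
    }
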
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index fe4f18bd38e..10d36066391 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -137,9 +137,13 @@ func (j *Job) SendMsg(msg string) {
 func (j *Job) sendState(state jobpb.JobState_Enum) {
        j.streamCond.L.Lock()
        defer j.streamCond.L.Unlock()
-       j.stateTime = time.Now()
-       j.stateIdx++
-       j.state.Store(state)
+       old := j.state.Load()
+       // Never overwrite a failed state with another one.
+       if old != jobpb.JobState_FAILED {
+               j.state.Store(state)
+               j.stateTime = time.Now()
+               j.stateIdx++
+       }
        j.streamCond.Broadcast()
 }
 
@@ -163,5 +167,5 @@ func (j *Job) Failed(err error) {
        slog.Error("job failed", slog.Any("job", j), slog.Any("error", err))
        j.failureErr = err
        j.sendState(jobpb.JobState_FAILED)
-       j.CancelFn(err)
+       j.CancelFn(fmt.Errorf("jobFailed %v: %w", j, err))
 }
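
The sendState change above makes JobState_FAILED terminal: a late RUNNING or DONE update can no longer mask a failure that already happened. A stripped-down sketch of that guard, with plain strings standing in for the jobpb state enum:

    package main

    import (
        "fmt"
        "sync"
    )

    // jobState keeps the last state sent, but treats "FAILED" as sticky.
    type jobState struct {
        mu    sync.Mutex
        state string
    }

    func (j *jobState) send(next string) {
        j.mu.Lock()
        defer j.mu.Unlock()
        if j.state == "FAILED" {
            return // never overwrite a failed state
        }
        j.state = next
    }

    func main() {
        j := &jobState{}
        j.send("RUNNING")
        j.send("FAILED")
        j.send("DONE")        // ignored: the failure stays visible
        fmt.Println(j.state)  // FAILED
    }
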
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index d347e88ec60..e626a05b51e 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -226,7 +226,7 @@ func (s *Server) GetMessageStream(req 
*jobpb.JobMessagesRequest, stream jobpb.Jo
                        job.streamCond.Wait()
                        select { // Quit out if the external connection is done.
                        case <-stream.Context().Done():
-                               return stream.Context().Err()
+                               return context.Cause(stream.Context())
                        default:
                        }
                }
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index 3f4451d7db3..4d8d4621168 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -75,12 +75,7 @@ type stage struct {
        OutputsToCoders   map[string]engine.PColInfo
 }
 
-func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, 
comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) {
-       select {
-       case <-ctx.Done():
-               return
-       default:
-       }
+func (s *stage) Execute(ctx context.Context, j *jobservices.Job, wk *worker.W, 
comps *pipepb.Components, em *engine.ElementManager, rb engine.RunBundle) error 
{
        slog.Debug("Execute: starting bundle", "bundle", rb)
 
        var b *worker.B
@@ -103,7 +98,7 @@ func (s *stage) Execute(ctx context.Context, j 
*jobservices.Job, wk *worker.W, c
                closed := make(chan struct{})
                close(closed)
                dataReady = closed
-       case wk.ID:
+       case wk.Env:
                b = &worker.B{
                        PBDID:  s.ID,
                        InstID: rb.BundleID,
@@ -122,15 +117,10 @@ func (s *stage) Execute(ctx context.Context, j 
*jobservices.Job, wk *worker.W, c
 
                slog.Debug("Execute: processing", "bundle", rb)
                defer b.Cleanup(wk)
-               b.Fail = func(errMsg string) {
-                       slog.Error("job failed", "bundle", rb, "job", j)
-                       err := fmt.Errorf("%v", errMsg)
-                       j.Failed(err)
-               }
                dataReady = b.ProcessOn(ctx, wk)
        default:
                err := fmt.Errorf("unknown environment[%v]", s.envID)
-               slog.Error("Execute", err)
+               slog.Error("Execute", "error", err)
                panic(err)
        }
 
@@ -145,20 +135,20 @@ progress:
                        progTick.Stop()
                        break progress // exit progress loop on close.
                case <-progTick.C:
-                       resp, err := b.Progress(wk)
+                       resp, err := b.Progress(ctx, wk)
                        if err != nil {
                                slog.Debug("SDK Error from progress, aborting 
progress", "bundle", rb, "error", err.Error())
                                break progress
                        }
                        index, unknownIDs := j.ContributeTentativeMetrics(resp)
                        if len(unknownIDs) > 0 {
-                               md := wk.MonitoringMetadata(unknownIDs)
+                               md := wk.MonitoringMetadata(ctx, unknownIDs)
                                j.AddMetricShortIDs(md)
                        }
                        slog.Debug("progress report", "bundle", rb, "index", 
index)
                        // Progress for the bundle hasn't advanced. Try 
splitting.
                        if previousIndex == index && !splitsDone {
-                               sr, err := b.Split(wk, 0.5 /* fraction of 
remainder */, nil /* allowed splits */)
+                               sr, err := b.Split(ctx, wk, 0.5 /* fraction of 
remainder */, nil /* allowed splits */)
                                if err != nil {
                                        slog.Warn("SDK Error from split, 
aborting splits", "bundle", rb, "error", err.Error())
                                        break progress
@@ -200,16 +190,18 @@ progress:
        var resp *fnpb.ProcessBundleResponse
        select {
        case resp = <-b.Resp:
+               if b.BundleErr != nil {
+                       return b.BundleErr
+               }
        case <-ctx.Done():
-               // Ensures we clean up on failure, if the response is blocked.
-               return
+               return context.Cause(ctx)
        }
 
        // Tally metrics immeadiately so they're available before
        // pipeline termination.
        unknownIDs := j.ContributeFinalMetrics(resp)
        if len(unknownIDs) > 0 {
-               md := wk.MonitoringMetadata(unknownIDs)
+               md := wk.MonitoringMetadata(ctx, unknownIDs)
                j.AddMetricShortIDs(md)
        }
        // TODO handle side input data properly.
@@ -239,6 +231,7 @@ progress:
        }
        em.PersistBundle(rb, s.OutputsToCoders, b.OutputData, s.inputInfo, 
residualData, minOutputWatermark)
        b.OutputData = engine.TentativeData{} // Clear the data.
+       return nil
 }
 
 func getSideInputs(t *pipepb.PTransform) (map[string]*pipepb.SideInput, error) 
{
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
index d17deedec8d..98479e3db07 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
@@ -52,14 +52,11 @@ type B struct {
        dataSema   atomic.Int32
        OutputData engine.TentativeData
 
-       // TODO move response channel to an atomic and an additional
-       // block on the DataWait channel, to allow progress & splits for
-       // no output DoFns.
-       Resp chan *fnpb.ProcessBundleResponse
+       Resp      chan *fnpb.ProcessBundleResponse
+       BundleErr error
+       responded bool
 
        SinkToPCollection map[string]string
-
-       Fail func(err string) // Called if bundle returns an error.
 }
 
 // Init initializes the bundle's internal state for waiting on all
@@ -90,8 +87,13 @@ func (b *B) LogValue() slog.Value {
 }
 
 func (b *B) Respond(resp *fnpb.InstructionResponse) {
+       if b.responded {
+               slog.Warn("additional bundle response", "bundle", b, "resp", 
resp)
+               return
+       }
+       b.responded = true
        if resp.GetError() != "" {
-               b.Fail(resp.GetError())
+               b.BundleErr = fmt.Errorf("bundle %v failed:%v", 
resp.GetInstructionId(), resp.GetError())
                close(b.Resp)
                return
        }
@@ -152,8 +154,8 @@ func (b *B) Cleanup(wk *W) {
 }
 
 // Progress sends a progress request for the given bundle to the passed in 
worker, blocking on the response.
-func (b *B) Progress(wk *W) (*fnpb.ProcessBundleProgressResponse, error) {
-       resp := wk.sendInstruction(&fnpb.InstructionRequest{
+func (b *B) Progress(ctx context.Context, wk *W) 
(*fnpb.ProcessBundleProgressResponse, error) {
+       resp := wk.sendInstruction(ctx, &fnpb.InstructionRequest{
                Request: &fnpb.InstructionRequest_ProcessBundleProgress{
                        ProcessBundleProgress: 
&fnpb.ProcessBundleProgressRequest{
                                InstructionId: b.InstID,
@@ -167,8 +169,8 @@ func (b *B) Progress(wk *W) 
(*fnpb.ProcessBundleProgressResponse, error) {
 }
 
 // Split sends a split request for the given bundle to the passed in worker, 
blocking on the response.
-func (b *B) Split(wk *W, fraction float64, allowedSplits []int64) 
(*fnpb.ProcessBundleSplitResponse, error) {
-       resp := wk.sendInstruction(&fnpb.InstructionRequest{
+func (b *B) Split(ctx context.Context, wk *W, fraction float64, allowedSplits 
[]int64) (*fnpb.ProcessBundleSplitResponse, error) {
+       resp := wk.sendInstruction(ctx, &fnpb.InstructionRequest{
                Request: &fnpb.InstructionRequest_ProcessBundleSplit{
                        ProcessBundleSplit: &fnpb.ProcessBundleSplitRequest{
                                InstructionId: b.InstID,
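
The bundle.go change above replaces the Fail callback with a BundleErr field and a responded guard: the first response wins, an error response closes Resp, and the waiter inspects BundleErr after receiving. A rough sketch of that shape with placeholder types (result is illustrative, not the FnAPI ProcessBundleResponse):

    package main

    import (
        "errors"
        "fmt"
    )

    type result struct{ id string }

    type bundle struct {
        Resp      chan *result // closed on error, receives the result on success
        BundleErr error
        responded bool
    }

    func (b *bundle) respond(res *result, errMsg string) {
        if b.responded {
            return // ignore duplicate or late responses
        }
        b.responded = true
        if errMsg != "" {
            b.BundleErr = errors.New(errMsg)
            close(b.Resp) // unblock the waiter with a nil result
            return
        }
        b.Resp <- res
    }

    func main() {
        b := &bundle{Resp: make(chan *result, 1)}
        b.respond(nil, "SDK disconnected")
        b.respond(&result{id: "late"}, "") // dropped: already responded

        <-b.Resp // nil: the channel was closed without a result
        if b.BundleErr != nil {
            fmt.Println("bundle failed:", b.BundleErr)
        }
    }
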
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
index c747711f8f0..ba5b10f5fd3 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle_test.go
@@ -23,7 +23,7 @@ import (
 )
 
 func TestBundle_ProcessOn(t *testing.T) {
-       wk := New("test")
+       wk := New("test", "testEnv")
        b := &B{
                InstID:      "testInst",
                PBDID:       "testPBDID",
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index 405c1e812a4..0ad7ccb3703 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -55,15 +55,15 @@ type W struct {
        fnpb.UnimplementedBeamFnLoggingServer
        fnpb.UnimplementedProvisionServiceServer
 
-       ID string
+       ID, Env string
 
        // Server management
        lis    net.Listener
        server *grpc.Server
 
        // These are the ID sources
-       inst, bund uint64
-       connected  atomic.Bool
+       inst, bund         uint64
+       connected, stopped atomic.Bool
 
        InstReqs chan *fnpb.InstructionRequest
        DataReqs chan *fnpb.Elements
@@ -80,7 +80,7 @@ type controlResponder interface {
 }
 
 // New starts the worker server components of FnAPI Execution.
-func New(id string) *W {
+func New(id, env string) *W {
        lis, err := net.Listen("tcp", ":0")
        if err != nil {
                panic(fmt.Sprintf("failed to listen: %v", err))
@@ -90,6 +90,7 @@ func New(id string) *W {
        }
        wk := &W{
                ID:     id,
+               Env:    env,
                lis:    lis,
                server: grpc.NewServer(opts...),
 
@@ -133,6 +134,7 @@ func (wk *W) LogValue() slog.Value {
 // Stop the GRPC server.
 func (wk *W) Stop() {
        slog.Debug("stopping", "worker", wk)
+       wk.stopped.Store(true)
        close(wk.InstReqs)
        close(wk.DataReqs)
        wk.server.Stop()
@@ -246,6 +248,7 @@ func (wk *W) GetProcessBundleDescriptor(ctx 
context.Context, req *fnpb.GetProces
        return desc, nil
 }
 
+// Connected indicates whether the worker has connected to the control RPC.
 func (wk *W) Connected() bool {
        return wk.connected.Load()
 }
@@ -255,27 +258,26 @@ func (wk *W) Connected() bool {
 // Requests come from the runner, and are sent to the client in the SDK.
 func (wk *W) Control(ctrl fnpb.BeamFnControl_ControlServer) error {
        wk.connected.Store(true)
-       done := make(chan struct{})
+       done := make(chan error, 1)
        go func() {
                for {
                        resp, err := ctrl.Recv()
                        if err == io.EOF {
                                slog.Debug("ctrl.Recv finished; marking done", 
"worker", wk)
-                               done <- struct{}{} // means stream is finished
+                               done <- nil // means stream is finished
                                return
                        }
                        if err != nil {
                                switch status.Code(err) {
                                case codes.Canceled:
-                                       done <- struct{}{} // means stream is 
finished
+                                       done <- err // means stream is finished
                                        return
                                default:
-                                       slog.Error("ctrl.Recv failed", err, 
"worker", wk)
+                                       slog.Error("ctrl.Recv failed", "error", 
err, "worker", wk)
                                        panic(err)
                                }
                        }
 
-                       // TODO: Do more than assume these are 
ProcessBundleResponses.
                        wk.mu.Lock()
                        if b, ok := 
wk.activeInstructions[resp.GetInstructionId()]; ok {
                                b.Respond(resp)
@@ -288,19 +290,33 @@ func (wk *W) Control(ctrl 
fnpb.BeamFnControl_ControlServer) error {
 
        for {
                select {
-               case req := <-wk.InstReqs:
-                       err := ctrl.Send(req)
-                       if err != nil {
-                               go func() { <-done }()
+               case req, ok := <-wk.InstReqs:
+                       if !ok {
+                               slog.Debug("Worker shutting down.", "worker", 
wk)
+                               return nil
+                       }
+                       if err := ctrl.Send(req); err != nil {
                                return err
                        }
                case <-ctrl.Context().Done():
-                       slog.Debug("Control context canceled")
-                       go func() { <-done }()
-                       return ctrl.Context().Err()
-               case <-done:
-                       slog.Debug("Control done")
-                       return nil
+                       wk.mu.Lock()
+                       // Fail extant instructions
+                       slog.Debug("SDK Disconnected", "worker", wk, 
"ctx_error", ctrl.Context().Err(), "outstanding_instructions", 
len(wk.activeInstructions))
+                       for instID, b := range wk.activeInstructions {
+                               b.Respond(&fnpb.InstructionResponse{
+                                       InstructionId: instID,
+                                       Error:         "SDK Disconnected",
+                               })
+                       }
+                       wk.mu.Unlock()
+                       return context.Cause(ctrl.Context())
+               case err := <-done:
+                       if err != nil {
+                               slog.Warn("Control done", "error", err, 
"worker", wk)
+                       } else {
+                               slog.Debug("Control done", "worker", wk)
+                       }
+                       return err
                }
        }
 }
@@ -359,7 +375,7 @@ func (wk *W) Data(data fnpb.BeamFnData_DataServer) error {
                        }
                case <-data.Context().Done():
                        slog.Debug("Data context canceled")
-                       return data.Context().Err()
+                       return context.Cause(data.Context())
                }
        }
 }
@@ -394,8 +410,12 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
 
                                // State requests are always for an active 
ProcessBundle instruction
                                wk.mu.Lock()
-                               b := 
wk.activeInstructions[req.GetInstructionId()].(*B)
+                               b, ok := 
wk.activeInstructions[req.GetInstructionId()].(*B)
                                wk.mu.Unlock()
+                               if !ok {
+                                       slog.Warn("state request after bundle 
inactive", "instruction", req.GetInstructionId(), "worker", wk)
+                                       continue
+                               }
                                key := req.GetStateKey()
                                slog.Debug("StateRequest_Get", 
prototext.Format(req), "bundle", b)
 
@@ -490,7 +510,7 @@ func (cr *chanResponder) Respond(resp 
*fnpb.InstructionResponse) {
 
 // sendInstruction is a helper for creating and sending worker single RPCs, 
blocking
 // until the response returns.
-func (wk *W) sendInstruction(req *fnpb.InstructionRequest) 
*fnpb.InstructionResponse {
+func (wk *W) sendInstruction(ctx context.Context, req 
*fnpb.InstructionRequest) *fnpb.InstructionResponse {
        cr := chanResponderPool.Get().(*chanResponder)
        progInst := wk.NextInst()
        wk.mu.Lock()
@@ -506,15 +526,26 @@ func (wk *W) sendInstruction(req 
*fnpb.InstructionRequest) *fnpb.InstructionResp
 
        req.InstructionId = progInst
 
-       // Tell the SDK to start processing the bundle.
+       if wk.stopped.Load() {
+               return nil
+       }
        wk.InstReqs <- req
-       // Protos are safe as nil, so just return directly.
-       return <-cr.Resp
+
+       select {
+       case <-ctx.Done():
+               return &fnpb.InstructionResponse{
+                       InstructionId: progInst,
+                       Error:         "context canceled before receive",
+               }
+       case resp := <-cr.Resp:
+               // Protos are safe as nil, so just return directly.
+               return resp
+       }
 }
 
 // MonitoringMetadata is a convenience method to request the metadata for 
monitoring shortIDs.
-func (wk *W) MonitoringMetadata(unknownIDs []string) 
*fnpb.MonitoringInfosMetadataResponse {
-       return wk.sendInstruction(&fnpb.InstructionRequest{
+func (wk *W) MonitoringMetadata(ctx context.Context, unknownIDs []string) 
*fnpb.MonitoringInfosMetadataResponse {
+       return wk.sendInstruction(ctx, &fnpb.InstructionRequest{
                Request: &fnpb.InstructionRequest_MonitoringInfos{
                        MonitoringInfos: &fnpb.MonitoringInfosMetadataRequest{
                                MonitoringInfoId: unknownIDs,
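
The core of the fix is in Control above: when the SDK's control stream ends, every in-flight instruction is answered with a synthetic "SDK Disconnected" error instead of waiting forever, which is what lets the job fail promptly. A hedged sketch of that cleanup step over a generic responder map (the interface and types are placeholders, not the FnAPI ones):

    package main

    import (
        "fmt"
        "sync"
    )

    // responder is whatever is waiting on an instruction's reply.
    type responder interface {
        Respond(instID, errMsg string)
    }

    type worker struct {
        mu     sync.Mutex
        active map[string]responder
    }

    // failOutstanding answers every outstanding instruction with an error so
    // waiters unblock and the job can be failed instead of hanging.
    func (w *worker) failOutstanding(reason string) {
        w.mu.Lock()
        defer w.mu.Unlock()
        for id, r := range w.active {
            r.Respond(id, reason)
        }
    }

    type printResponder struct{}

    func (printResponder) Respond(instID, errMsg string) {
        fmt.Printf("instruction %s failed: %s\n", instID, errMsg)
    }

    func main() {
        w := &worker{active: map[string]responder{"inst-1": printResponder{}}}
        w.failOutstanding("SDK Disconnected")
    }
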
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
index 060c073fa12..ed61f484481 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker_test.go
@@ -32,14 +32,14 @@ import (
 )
 
 func TestWorker_New(t *testing.T) {
-       w := New("test")
+       w := New("test", "testEnv")
        if got, want := w.ID, "test"; got != want {
                t.Errorf("New(%q) = %v, want %v", want, got, want)
        }
 }
 
 func TestWorker_NextInst(t *testing.T) {
-       w := New("test")
+       w := New("test", "testEnv")
 
        instIDs := map[string]struct{}{}
        for i := 0; i < 100; i++ {
@@ -51,7 +51,7 @@ func TestWorker_NextInst(t *testing.T) {
 }
 
 func TestWorker_NextStage(t *testing.T) {
-       w := New("test")
+       w := New("test", "testEnv")
 
        stageIDs := map[string]struct{}{}
        for i := 0; i < 100; i++ {
@@ -63,7 +63,7 @@ func TestWorker_NextStage(t *testing.T) {
 }
 
 func TestWorker_GetProcessBundleDescriptor(t *testing.T) {
-       w := New("test")
+       w := New("test", "testEnv")
 
        id := "available"
        w.Descriptors[id] = &fnpb.ProcessBundleDescriptor{
@@ -93,7 +93,7 @@ func serveTestWorker(t *testing.T) (context.Context, *W, 
*grpc.ClientConn) {
        ctx, cancelFn := context.WithCancel(context.Background())
        t.Cleanup(cancelFn)
 
-       w := New("test")
+       w := New("test", "testEnv")
        lis := bufconn.Listen(2048)
        w.lis = lis
        t.Cleanup(func() { w.Stop() })
diff --git a/sdks/go/pkg/beam/testing/passert/equals_test.go b/sdks/go/pkg/beam/testing/passert/equals_test.go
index a8a5c835f8f..6e6578dd3e4 100644
--- a/sdks/go/pkg/beam/testing/passert/equals_test.go
+++ b/sdks/go/pkg/beam/testing/passert/equals_test.go
@@ -184,7 +184,7 @@ func ExampleEqualsList_mismatch() {
        err = unwrapError(err)
 
        // Process error for cleaner example output, demonstrating the diff.
-       processedErr := strings.SplitAfter(err.Error(), 
"/passert.failIfBadEntries] failed:")
+       processedErr := strings.SplitAfter(err.Error(), ".failIfBadEntries] 
failed:")
        fmt.Println(processedErr[1])
 
        // Output:
