This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 644f5399dce Move to a conditionVariable for messages+state stream + test. (#27060)
644f5399dce is described below

commit 644f5399dce646ba9b6b0146b44ff7778fc0a8c1
Author: Robert Burke <[email protected]>
AuthorDate: Thu Jun 15 12:31:41 2023 -0700

    Move to a conditionVariable for messages+state stream + test. (#27060)
    
    Co-authored-by: lostluck <[email protected]>
---
 .../beam/runners/prism/internal/jobservices/job.go |  45 +++-
 .../prism/internal/jobservices/management.go       |  83 +++++---
 .../prism/internal/jobservices/management_test.go  | 230 ++++++++++++++++++++-
 3 files changed, 315 insertions(+), 43 deletions(-)

diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index 5b8e786ac6f..4ac37c5db59 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -29,6 +29,7 @@ import (
        "fmt"
        "sort"
        "strings"
+       "sync"
        "sync/atomic"
 
        fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
@@ -70,9 +71,12 @@ type Job struct {
        options  *structpb.Struct
 
        // Management side concerns.
-       msgChan   chan string
-       state     atomic.Value // jobpb.JobState_Enum
-       stateChan chan jobpb.JobState_Enum
+       streamCond *sync.Cond
+       // TODO, consider unifying messages and state to a single ordered 
buffer.
+       minMsg, maxMsg int // logical indices into the message slice
+       msgs           []string
+       stateIdx       int
+       state          atomic.Value // jobpb.JobState_Enum
 
        // Context used to terminate this job.
        RootCtx  context.Context
@@ -107,25 +111,50 @@ func (j *Job) LogValue() slog.Value {
 }
 
 func (j *Job) SendMsg(msg string) {
-       j.msgChan <- msg
+       j.streamCond.L.Lock()
+       defer j.streamCond.L.Unlock()
+       j.maxMsg++
+       // Trim so we never have more than 120 messages, keeping the last 100 
for sure
+       // but amortize it so that messages are only trimmed every 20 messages 
beyond
+       // that.
+       // TODO, make this configurable
+       const buffered, trigger = 100, 20
+       if len(j.msgs) > buffered+trigger {
+               copy(j.msgs[0:], j.msgs[trigger:])
+               for k, n := len(j.msgs)-trigger, len(j.msgs); k < n; k++ {
+                       j.msgs[k] = ""
+               }
+               j.msgs = j.msgs[:len(j.msgs)-trigger]
+               j.minMsg += trigger // increase the "min" message higher as a 
result.
+       }
+       j.msgs = append(j.msgs, msg)
+       j.streamCond.Broadcast()
+}
+
+func (j *Job) sendState(state jobpb.JobState_Enum) {
+       j.streamCond.L.Lock()
+       defer j.streamCond.L.Unlock()
+       j.stateIdx++
+       j.state.Store(state)
+       j.streamCond.Broadcast()
 }
 
 // Start indicates that the job is preparing to execute.
 func (j *Job) Start() {
-       j.stateChan <- jobpb.JobState_STARTING
+       j.sendState(jobpb.JobState_STARTING)
 }
 
 // Running indicates that the job is executing.
 func (j *Job) Running() {
-       j.stateChan <- jobpb.JobState_RUNNING
+       j.sendState(jobpb.JobState_RUNNING)
 }
 
 // Done indicates that the job completed successfully.
 func (j *Job) Done() {
-       j.stateChan <- jobpb.JobState_DONE
+       j.sendState(jobpb.JobState_DONE)
 }
 
 // Failed indicates that the job completed unsuccessfully.
 func (j *Job) Failed() {
-       j.stateChan <- jobpb.JobState_FAILED
+       j.sendState(jobpb.JobState_FAILED)
 }
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 6e774332f0e..cecd95536ae 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -18,6 +18,7 @@ package jobservices
 import (
        "context"
        "fmt"
+       "sync"
        "sync/atomic"
 
        jobpb 
"github.com/apache/beam/sdks/v2/go/pkg/beam/model/jobmanagement_v1"
@@ -69,20 +70,17 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo
        // Since jobs execute in the background, they should not be tied to a 
request's context.
        rootCtx, cancelFn := context.WithCancel(context.Background())
        job := &Job{
-               key:      s.nextId(),
-               Pipeline: req.GetPipeline(),
-               jobName:  req.GetJobName(),
-               options:  req.GetPipelineOptions(),
-
-               msgChan:   make(chan string, 100),
-               stateChan: make(chan jobpb.JobState_Enum, 1),
-               RootCtx:   rootCtx,
-               CancelFn:  cancelFn,
+               key:        s.nextId(),
+               Pipeline:   req.GetPipeline(),
+               jobName:    req.GetJobName(),
+               options:    req.GetPipelineOptions(),
+               streamCond: sync.NewCond(&sync.Mutex{}),
+               RootCtx:    rootCtx,
+               CancelFn:   cancelFn,
        }
 
        // Queue initial state of the job.
        job.state.Store(jobpb.JobState_STOPPED)
-       job.stateChan <- job.state.Load().(jobpb.JobState_Enum)
 
        if err := isSupported(job.Pipeline.GetRequirements()); err != nil {
                slog.Error("unable to run job", slog.String("error", 
err.Error()), slog.String("jobname", req.GetJobName()))
@@ -165,15 +163,45 @@ func (s *Server) Run(ctx context.Context, req *jobpb.RunJobRequest) (*jobpb.RunJ
        }, nil
 }
 
-// GetMessageStream subscribes to a stream of state changes and messages from 
the job
+// GetMessageStream subscribes to a stream of state changes and messages from 
the job. If throughput
+// is high, this may cause losses of messages.
 func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream 
jobpb.JobService_GetMessageStreamServer) error {
        s.mu.Lock()
-       job := s.jobs[req.GetJobId()]
+       job, ok := s.jobs[req.GetJobId()]
        s.mu.Unlock()
+       if !ok {
+               return fmt.Errorf("job with id %v not found", req.GetJobId())
+       }
 
+       job.streamCond.L.Lock()
+       defer job.streamCond.L.Unlock()
+       curMsg := job.minMsg
+       curState := job.stateIdx
+
+       stream.Context()
+
+       state := job.state.Load().(jobpb.JobState_Enum)
        for {
-               select {
-               case msg := <-job.msgChan:
+               for (curMsg >= job.maxMsg || len(job.msgs) == 0) && curState > 
job.stateIdx {
+                       switch state {
+                       case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, 
jobpb.JobState_DRAINED, jobpb.JobState_FAILED, jobpb.JobState_UPDATED:
+                               // Reached terminal state.
+                               return nil
+                       }
+                       job.streamCond.Wait()
+                       select { // Quit out if the external connection is done.
+                       case <-stream.Context().Done():
+                               return stream.Context().Err()
+                       default:
+                       }
+               }
+
+               if curMsg < job.minMsg {
+                       // TODO report missed messages for this stream.
+                       curMsg = job.minMsg
+               }
+               for curMsg < job.maxMsg && len(job.msgs) > 0 {
+                       msg := job.msgs[curMsg-job.minMsg]
                        stream.Send(&jobpb.JobMessagesResponse{
                                Response: 
&jobpb.JobMessagesResponse_MessageResponse{
                                        MessageResponse: &jobpb.JobMessage{
@@ -182,20 +210,12 @@ func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.Jo
                                        },
                                },
                        })
-
-               case state, ok := <-job.stateChan:
-                       // TODO: Don't block job execution if WaitForCompletion 
isn't being run.
-                       // The state channel means the job may only execute if 
something is observing
-                       // the message stream, as the send on the state or 
message channel may block
-                       // once full.
-                       // Not a problem for tests or short lived batch, but 
would be hazardous for
-                       // asynchronous jobs.
-
-                       // Channel is closed, so the job must be done.
-                       if !ok {
-                               state = jobpb.JobState_DONE
-                       }
-                       job.state.Store(state)
+                       curMsg++
+               }
+               if curState <= job.stateIdx {
+                       state = job.state.Load().(jobpb.JobState_Enum)
+                       curState = job.stateIdx + 1
+                       job.streamCond.L.Unlock()
                        stream.Send(&jobpb.JobMessagesResponse{
                                Response: 
&jobpb.JobMessagesResponse_StateResponse{
                                        StateResponse: &jobpb.JobStateEvent{
@@ -203,14 +223,9 @@ func (s *Server) GetMessageStream(req *jobpb.JobMessagesRequest, stream jobpb.Jo
                                        },
                                },
                        })
-                       switch state {
-                       case jobpb.JobState_CANCELLED, jobpb.JobState_DONE, 
jobpb.JobState_DRAINED, jobpb.JobState_FAILED, jobpb.JobState_UPDATED:
-                               // Reached terminal state.
-                               return nil
-                       }
+                       job.streamCond.L.Lock()
                }
        }
-
 }
 
 // GetJobMetrics Fetch metrics for a given job.
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
index b7861276702..5813e6ef73e 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management_test.go
@@ -17,6 +17,10 @@ package jobservices
 
 import (
        "context"
+       "errors"
+       "fmt"
+       "io"
+       "net"
        "sync"
        "testing"
 
@@ -27,6 +31,9 @@ import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
        "github.com/google/go-cmp/cmp"
        "github.com/google/go-cmp/cmp/cmpopts"
+       "google.golang.org/grpc"
+       "google.golang.org/grpc/credentials/insecure"
+       "google.golang.org/grpc/test/bufconn"
        "google.golang.org/protobuf/testing/protocmp"
 )
 
@@ -218,4 +225,225 @@ func TestServer(t *testing.T) {
        }
 }
 
-// TODO impelment message stream test, once message/State implementation is 
sync.Cond based.
+func TestGetMessageStream(t *testing.T) {
+       wantName := "testJob"
+       wantPipeline := &pipepb.Pipeline{
+               Requirements: []string{urns.RequirementSplittableDoFn},
+       }
+       var called sync.WaitGroup
+       called.Add(1)
+       ctx, _, clientConn := serveTestServer(t, func(j *Job) {
+               j.Start()
+               j.SendMsg("job starting")
+               j.Running()
+               j.SendMsg("job running")
+               j.SendMsg("job finished")
+               j.Done()
+               j.SendMsg("job done")
+               called.Done()
+       })
+       jobCli := jobpb.NewJobServiceClient(clientConn)
+
+       // PreJob submission
+       msgStream, err := jobCli.GetMessageStream(ctx, 
&jobpb.JobMessagesRequest{
+               JobId: "job-001",
+       })
+       if err != nil {
+               t.Errorf("GetMessageStream: wanted successful connection, got 
%v", err)
+       }
+       _, err = msgStream.Recv()
+       if err == nil {
+               t.Error("wanted error on non-existent job, but didn't happen.")
+       }
+
+       prepResp, err := jobCli.Prepare(ctx, &jobpb.PrepareJobRequest{
+               Pipeline: wantPipeline,
+               JobName:  wantName,
+       })
+       if err != nil {
+               t.Fatalf("Prepare(%v) = %v, want nil", wantName, err)
+       }
+
+       // Post Job submission
+       msgStream, err = jobCli.GetMessageStream(ctx, &jobpb.JobMessagesRequest{
+               JobId: "job-001",
+       })
+       if err != nil {
+               t.Errorf("GetMessageStream: wanted successful connection, got 
%v", err)
+       }
+       stateResponse, err := msgStream.Recv()
+       if err != nil {
+               t.Errorf("GetMessageStream().Recv() = %v, want nil", err)
+       }
+       if got, want := stateResponse.GetStateResponse().GetState(), 
jobpb.JobState_STOPPED; got != want {
+               t.Errorf("GetMessageStream().Recv() = %v, want %v", got, want)
+       }
+
+       _, err = jobCli.Run(ctx, &jobpb.RunJobRequest{
+               PreparationId: prepResp.GetPreparationId(),
+       })
+       if err != nil {
+               t.Fatalf("Run(%v) = %v, want nil", wantName, err)
+       }
+
+       called.Wait() // Wait for the job to terminate.
+
+       receivedDone := false
+       var msgCount int
+       for {
+               // Continue with the same message stream.
+               resp, err := msgStream.Recv()
+               if err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break // successful message stream completion
+                       }
+                       t.Errorf("GetMessageStream().Recv() = %v, want nil", 
err)
+               }
+               switch {
+
+               case resp.GetMessageResponse() != nil:
+                       msgCount++
+               case resp.GetStateResponse() != nil:
+                       if resp.GetStateResponse().GetState() == 
jobpb.JobState_DONE {
+                               receivedDone = true
+                       }
+               }
+       }
+       if got, want := msgCount, 4; got != want {
+               t.Errorf("GetMessageStream() didn't correct number of messages, 
got %v, want %v", got, want)
+       }
+       if !receivedDone {
+               t.Error("GetMessageStream() didn't return job done state")
+       }
+       msgStream.CloseSend()
+
+       // Create a new message stream, we should still get a tail of messages 
(in this case, all of them)
+       // And the final state.
+       msgStream, err = jobCli.GetMessageStream(ctx, &jobpb.JobMessagesRequest{
+               JobId: "job-001",
+       })
+       if err != nil {
+               t.Errorf("GetMessageStream: wanted successful connection, got 
%v", err)
+       }
+
+       receivedDone = false
+       msgCount = 0
+       for {
+               // Continue with the same message stream.
+               resp, err := msgStream.Recv()
+               if err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break // successful message stream completion
+                       }
+                       t.Errorf("GetMessageStream().Recv() = %v, want nil", 
err)
+               }
+               switch {
+
+               case resp.GetMessageResponse() != nil:
+                       msgCount++
+               case resp.GetStateResponse() != nil:
+                       if resp.GetStateResponse().GetState() == 
jobpb.JobState_DONE {
+                               receivedDone = true
+                       }
+               }
+       }
+       if got, want := msgCount, 4; got != want {
+               t.Errorf("GetMessageStream() didn't correct number of messages, 
got %v, want %v", got, want)
+       }
+       if !receivedDone {
+               t.Error("GetMessageStream() didn't return job done state")
+       }
+}
+
+func TestGetMessageStream_BufferCycling(t *testing.T) {
+       wantName := "testJob"
+       wantPipeline := &pipepb.Pipeline{
+               Requirements: []string{urns.RequirementSplittableDoFn},
+       }
+       var called sync.WaitGroup
+       called.Add(1)
+       ctx, _, clientConn := serveTestServer(t, func(j *Job) {
+               j.Start()
+               // Using an offset from the trigger amount to ensure expected
+               // behavior (we can sometimes get more than the last 100 
messages).
+               for i := 0; i < 512; i++ {
+                       j.SendMsg(fmt.Sprintf("message number %v", i))
+               }
+               j.Done()
+               called.Done()
+       })
+       jobCli := jobpb.NewJobServiceClient(clientConn)
+
+       prepResp, err := jobCli.Prepare(ctx, &jobpb.PrepareJobRequest{
+               Pipeline: wantPipeline,
+               JobName:  wantName,
+       })
+       if err != nil {
+               t.Fatalf("Prepare(%v) = %v, want nil", wantName, err)
+       }
+       _, err = jobCli.Run(ctx, &jobpb.RunJobRequest{
+               PreparationId: prepResp.GetPreparationId(),
+       })
+       if err != nil {
+               t.Fatalf("Run(%v) = %v, want nil", wantName, err)
+       }
+
+       called.Wait() // Wait for the job to terminate.
+
+       // Create a new message stream, we should still get a tail of messages 
(in this case, all of them)
+       // And the final state.
+       msgStream, err := jobCli.GetMessageStream(ctx, 
&jobpb.JobMessagesRequest{
+               JobId: "job-001",
+       })
+       if err != nil {
+               t.Errorf("GetMessageStream: wanted successful connection, got 
%v", err)
+       }
+
+       receivedDone := false
+       var msgCount int
+       for {
+               // Continue with the same message stream.
+               resp, err := msgStream.Recv()
+               if err != nil {
+                       if errors.Is(err, io.EOF) {
+                               break // successful message stream completion
+                       }
+                       t.Errorf("GetMessageStream().Recv() = %v, want nil", 
err)
+               }
+               switch {
+               case resp.GetMessageResponse() != nil:
+                       msgCount++
+               case resp.GetStateResponse() != nil:
+                       if resp.GetStateResponse().GetState() == 
jobpb.JobState_DONE {
+                               receivedDone = true
+                       }
+               }
+       }
+       if got, want := msgCount, 112; got != want {
+               t.Errorf("GetMessageStream() didn't correct number of messages, 
got %v, want %v", got, want)
+       }
+       if !receivedDone {
+               t.Error("GetMessageStream() didn't return job done state")
+       }
+
+}
+
+func serveTestServer(t *testing.T, execute func(j *Job)) (context.Context, 
*Server, *grpc.ClientConn) {
+       t.Helper()
+       ctx, cancelFn := context.WithCancel(context.Background())
+       t.Cleanup(cancelFn)
+
+       s := NewServer(0, execute)
+       lis := bufconn.Listen(1024 * 64)
+       s.lis = lis
+       t.Cleanup(func() { s.Stop() })
+       go s.Serve()
+
+       clientConn, err := grpc.DialContext(ctx, "", 
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
+               return lis.DialContext(ctx)
+       }), grpc.WithTransportCredentials(insecure.NewCredentials()), 
grpc.WithBlock())
+       if err != nil {
+               t.Fatal("couldn't create bufconn grpc connection:", err)
+       }
+       return ctx, s, clientConn
+}

Reply via email to