This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 48adde999b9 [#29917][prism] Initial TestStream support (#30072)
48adde999b9 is described below

commit 48adde999b9212c5bae8a330111fe8739fc1fbde
Author: Robert Burke <lostl...@users.noreply.github.com>
AuthorDate: Fri Feb 16 11:49:29 2024 -0800

    [#29917][prism] Initial TestStream support (#30072)
---
 .../prism/internal/engine/elementmanager.go        |  99 ++++++--
 .../runners/prism/internal/engine/engine_test.go   |  47 ++++
 .../runners/prism/internal/engine/teststream.go    | 269 +++++++++++++++++++++
 sdks/go/pkg/beam/runners/prism/internal/execute.go |  54 +++++
 .../prism/internal/jobservices/management.go       |  18 ++
 .../runners/prism/internal/unimplemented_test.go   |  43 +++-
 sdks/go/pkg/beam/testing/teststream/teststream.go  |   8 +-
 sdks/go/test/integration/integration.go            |   7 +-
 sdks/go/test/integration/primitives/teststream.go  |  43 +++-
 .../test/integration/primitives/teststream_test.go |  10 +
 10 files changed, 548 insertions(+), 50 deletions(-)

diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index 077d6386315..28ea75ac9e5 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -166,6 +166,8 @@ type ElementManager struct {
 
        livePending     atomic.Int64   // An accessible live pending count. 
DEBUG USE ONLY
        pendingElements sync.WaitGroup // pendingElements counts all 
unprocessed elements in a job. Jobs with no pending elements terminate 
successfully.
+
+       testStreamHandler *testStreamHandler // Optional test stream handler 
when a test stream is in the pipeline.
 }
 
 func (em *ElementManager) addPending(v int) {
@@ -223,6 +225,15 @@ func (em *ElementManager) StageStateful(ID string) {
        em.stages[ID].stateful = true
 }
 
+// AddTestStream registers a TestStream stage and returns a builder the execution layer uses to
+// populate its events from the protos. The handler (and its stage ID) is created by the first call only.
+func (em *ElementManager) AddTestStream(id string, tagToPCol map[string]string) TestStreamBuilder {
+       impl := &testStreamImpl{em: em}
+       impl.initHandler(id)
+       impl.TagsToPCollections(tagToPCol)
+       return impl
+}
+
 // Impulse marks and initializes the given stage as an impulse which
 // is a root transform that starts processing.
 func (em *ElementManager) Impulse(stageID string) {
@@ -319,37 +330,72 @@ func (em *ElementManager) Bundles(ctx context.Context, 
nextBundID func() string)
                                        em.refreshCond.L.Lock()
                                }
                        }
-                       if len(em.inprogressBundles) == 0 && 
len(em.watermarkRefreshes) == 0 {
-                               v := em.livePending.Load()
-                               slog.Debug("Bundles: nothing in progress and no 
refreshes", slog.Int64("pendingElementCount", v))
-                               if v > 0 {
-                                       var stageState []string
-                                       ids := maps.Keys(em.stages)
-                                       sort.Strings(ids)
-                                       for _, id := range ids {
-                                               ss := em.stages[id]
-                                               inW := ss.InputWatermark()
-                                               outW := ss.OutputWatermark()
-                                               upPCol, upW := 
ss.UpstreamWatermark()
-                                               upS := em.pcolParents[upPCol]
-                                               stageState = append(stageState, 
fmt.Sprintln(id, "watermark in", inW, "out", outW, "upstream", upW, "from", 
upS, "pending", ss.pending, "byKey", ss.pendingByKeys, "inprogressKeys", 
ss.inprogressKeys, "byBundle", ss.inprogressKeysByBundle, "holds", 
ss.watermarkHoldHeap, "holdCounts", ss.watermarkHoldsCounts))
-                                       }
-                                       panic(fmt.Sprintf("nothing in progress 
and no refreshes with non zero pending elements: %v\n%v", v, 
strings.Join(stageState, "")))
-                               }
-                       } else if len(em.inprogressBundles) == 0 {
-                               v := em.livePending.Load()
-                               slog.Debug("Bundles: nothing in progress after 
advance",
-                                       slog.Any("advanced", advanced),
-                                       slog.Int("refreshCount", 
len(em.watermarkRefreshes)),
-                                       slog.Int64("pendingElementCount", v),
-                               )
-                       }
-                       em.refreshCond.L.Unlock()
+                       em.checkForQuiescence(advanced)
                }
        }()
        return runStageCh
 }
 
+// checkForQuiescence checks whether this element manager is no longer able to do any pending work or make progress.
+//
+// Quiescence can happen if there are no inprogress bundles, and there are no further watermark refreshes, which
+// are the only way to access new pending elements. If there are no pending elements, then the pipeline will
+// terminate successfully.
+//
+// Otherwise, execute the next TestStream event if there is one, and if not, panic with debugging information
+// describing why the pipeline is stuck.
+//
+// Must be called while holding em.refreshCond.L, which is always released (via defer) before returning.
+func (em *ElementManager) checkForQuiescence(advanced set[string]) {
+       defer em.refreshCond.L.Unlock()
+       if len(em.inprogressBundles) > 0 {
+               // If there are bundles in progress, then there may be watermark refreshes when they terminate.
+               return
+       }
+       if len(em.watermarkRefreshes) > 0 {
+               // If there are watermarks to refresh, we aren't yet stuck.
+               v := em.livePending.Load()
+               slog.Debug("Bundles: nothing in progress after advance",
+                       slog.Any("advanced", advanced),
+                       slog.Int("refreshCount", len(em.watermarkRefreshes)),
+                       slog.Int64("pendingElementCount", v),
+               )
+               return
+       }
+       // The job has quiesced!
+
+       // There are no further incoming watermark changes, see if there are test stream events for this job.
+       nextEvent := em.testStreamHandler.NextEvent()
+       if nextEvent != nil {
+               nextEvent.Execute(em)
+               // Decrement pending for the event being processed.
+               em.addPending(-1)
+               return
+       }
+
+       v := em.livePending.Load()
+       if v == 0 {
+               // Since there are no further pending elements, the job will be terminating successfully.
+               return
+       }
+       // The job is officially stuck. Fail fast and produce debugging information.
+       // Jobs must never get stuck so this indicates a bug in prism to be investigated.
+
+       slog.Debug("Bundles: nothing in progress and no refreshes", slog.Int64("pendingElementCount", v))
+       var stageState []string
+       ids := maps.Keys(em.stages)
+       sort.Strings(ids)
+       for _, id := range ids {
+               ss := em.stages[id]
+               inW := ss.InputWatermark()
+               outW := ss.OutputWatermark()
+               upPCol, upW := ss.UpstreamWatermark()
+               upS := em.pcolParents[upPCol]
+               stageState = append(stageState, fmt.Sprintln(id, "watermark in", inW, "out", outW, "upstream", upW, "from", upS, "pending", ss.pending, "byKey", ss.pendingByKeys, "inprogressKeys", ss.inprogressKeys, "byBundle", ss.inprogressKeysByBundle, "holds", ss.watermarkHoldHeap, "holdCounts", ss.watermarkHoldsCounts))
+       }
+       panic(fmt.Sprintf("nothing in progress and no refreshes with non zero pending elements: %v\n%v", v, strings.Join(stageState, "")))
+}
+
 // InputForBundle returns pre-allocated data for the given bundle, encoding 
the elements using
 // the PCollection's coders.
 func (em *ElementManager) InputForBundle(rb RunBundle, info PColInfo) [][]byte 
{
@@ -429,6 +475,7 @@ const (
        BlockTimer                  // BlockTimer represents timers for the 
bundle.
 )
 
+// Block represents a contiguous set of data or timers for the same 
destination.
 type Block struct {
        Kind              BlockKind
        Bytes             [][]byte
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go
index 6a39b9d2070..04269e3dd6a 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go
@@ -169,3 +169,50 @@ func TestElementManagerCoverage(t *testing.T) {
                })
        }
 }
+
+func TestTestStream(t *testing.T) {
+       initRunner(t)
+
+       tests := []struct {
+               pipeline func(s beam.Scope)
+       }{
+               {pipeline: primitives.TestStreamBoolSequence},
+               {pipeline: primitives.TestStreamByteSliceSequence},
+               {pipeline: primitives.TestStreamFloat64Sequence},
+               {pipeline: primitives.TestStreamInt64Sequence},
+               {pipeline: primitives.TestStreamInt16Sequence},
+               {pipeline: primitives.TestStreamStrings},
+               {pipeline: primitives.TestStreamTwoBoolSequences},
+               {pipeline: primitives.TestStreamTwoFloat64Sequences},
+               {pipeline: primitives.TestStreamTwoInt64Sequences},
+               {pipeline: primitives.TestStreamTwoUserTypeSequences},
+       }
+
+       configs := []struct { // Bundling configurations; every pipeline above runs under each one.
+               name                              string
+               OneElementPerKey, OneKeyPerBundle bool
+       }{
+               {"Greedy", false, false},
+               {"AllElementsPerKey", false, true},
+               {"OneElementPerKey", true, false},
+               {"OneElementPerBundle", true, true},
+       }
+       for _, config := range configs {
+               for _, test := range tests {
+                       t.Run(initTestName(test.pipeline)+"_"+config.name, func(t *testing.T) {
+                               t.Cleanup(func() { // Reset the engine's package-level bundling flags to false after each subtest.
+                                       engine.OneElementPerKey = false
+                                       engine.OneKeyPerBundle = false
+                               })
+                               engine.OneElementPerKey = config.OneElementPerKey
+                               engine.OneKeyPerBundle = config.OneKeyPerBundle
+                               p, s := beam.NewPipelineWithRoot()
+                               test.pipeline(s)
+                               _, err := executeWithT(context.Background(), t, p)
+                               if err != nil {
+                                       t.Fatalf("pipeline failed, but feature should be implemented in Prism: %v", err)
+                               }
+                       })
+               }
+       }
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go
new file mode 100644
index 00000000000..c0a0ff8ebe7
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go
@@ -0,0 +1,269 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package engine
+
+import (
+       "time"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+)
+
+// We define our own element wrapper and related types to avoid depending on the protos within the
+// engine package. This improves compile times and the readability of this package.
+
+// testStreamHandler manages TestStreamEvents for the ElementManager.
+//
+// TestStreams are a pipeline root like an Impulse. They kick off computation, and
+// strictly manage Watermark advancements.
+//
+// A given pipeline can only have a single TestStream due to test streams
+// requiring a single source of truth for Relative Processing Time advancements
+// and ordering emissions of Elements.
+// All operations with testStreamHandler are expected to be in the element manager's
+// refresh lock critical section.
+type testStreamHandler struct {
+       ID string // Stage ID of the TestStream; the final event refreshes this stage's watermarks.
+
+       nextEventIndex int
+       events         []tsEvent
+       // NOTE(review): makeTestStreamHandler leaves this at the zero time.Time (not time.Now); confirm intent.
+       processingTime time.Time // Override for the processing time clock, for triggers and ProcessContinuations.
+
+       tagState map[string]tagState // Map from event tag to related outputs.
+
+       completed bool // Set once NextEvent has returned the final event, so that event is only sent once.
+}
+
+func makeTestStreamHandler(id string) *testStreamHandler {
+       return &testStreamHandler{
+               ID:       id,
+               tagState: map[string]tagState{},
+       }
+}
+
+// tagState tracks state for a given tag.
+type tagState struct {
+       watermark   mtime.Time // Current Watermark for this tag.
+       pcollection string     // ID for the pcollection of this tag to look up consumers.
+}
+
+// Now represents the overridden ProcessingTime, which is only advanced when directed by an event.
+// Overrides the elementManager "clock".
+func (ts *testStreamHandler) Now() time.Time {
+       return ts.processingTime
+}
+
+// TagsToPCollections receives the map of local output tags to global pcollection ids.
+func (ts *testStreamHandler) TagsToPCollections(tagToPcol map[string]string) {
+       for tag, pcol := range tagToPcol {
+               ts.tagState[tag] = tagState{
+                       watermark:   mtime.MinTimestamp,
+                       pcollection: pcol,
+               }
+               // If there is only one output pcollection, also copy its initial state to the
+               // empty tag, so untagged events resolve to it (the two entries are then tracked independently).
+               if len(tagToPcol) == 1 {
+                       ts.tagState[""] = ts.tagState[tag]
+               }
+       }
+}
+
+// AddElementEvent adds an element event to the test stream event queue.
+func (ts *testStreamHandler) AddElementEvent(tag string, elements []TestStreamElement) {
+       ts.events = append(ts.events, tsElementEvent{
+               Tag:      tag,
+               Elements: elements,
+       })
+}
+
+// AddWatermarkEvent adds a watermark event to the test stream event queue.
+func (ts *testStreamHandler) AddWatermarkEvent(tag string, newWatermark mtime.Time) {
+       ts.events = append(ts.events, tsWatermarkEvent{
+               Tag:          tag,
+               NewWatermark: newWatermark,
+       })
+}
+
+// AddProcessingTimeEvent adds a processing time event to the test stream event queue.
+func (ts *testStreamHandler) AddProcessingTimeEvent(d time.Duration) {
+       ts.events = append(ts.events, tsProcessingTimeEvent{
+               AdvanceBy: d,
+       })
+}
+
+// NextEvent returns the next queued event, then a single synthetic final event once the queue is drained.
+// After that, or when called on a nil handler, it returns nil.
+func (ts *testStreamHandler) NextEvent() tsEvent {
+       if ts == nil {
+               return nil
+       }
+       if ts.nextEventIndex >= len(ts.events) {
+               if !ts.completed {
+                       ts.completed = true
+                       return tsFinalEvent{stageID: ts.ID}
+               }
+               return nil
+       }
+       ev := ts.events[ts.nextEventIndex]
+       ts.nextEventIndex++
+       return ev
+}
+
+// TestStreamElement wraps the provided bytes and timestamp for ingestion and use.
+type TestStreamElement struct {
+       Encoded   []byte     // Encoded element bytes.
+       EventTime mtime.Time // Event timestamp assigned to the element.
+}
+
+// tsEvent abstracts over the different TestStream Event kinds so we can keep
+// them in the same queue.
+type tsEvent interface {
+       // Execute the associated event on this element manager.
+       Execute(*ElementManager)
+}
+
+// tsElementEvent implements an element event, inserting additional elements
+// to be pending for consuming stages.
+type tsElementEvent struct {
+       Tag      string
+       Elements []TestStreamElement
+}
+
+// Execute this ElementEvent by adding its elements (global window, no firing pane) as pending to main and side input consumers.
+func (ev tsElementEvent) Execute(em *ElementManager) {
+       t := em.testStreamHandler.tagState[ev.Tag]
+
+       var pending []element
+       for _, e := range ev.Elements {
+               pending = append(pending, element{
+                       window:    window.GlobalWindow{},
+                       timestamp: e.EventTime,
+                       elmBytes:  e.Encoded,
+                       pane:      typex.NoFiringPane(),
+               })
+       }
+
+       // Main-input consumers: count the new pending elements and schedule a watermark refresh.
+       for _, sID := range em.consumers[t.pcollection] {
+               ss := em.stages[sID]
+               added := ss.AddPending(pending)
+               em.addPending(added)
+               em.watermarkRefreshes.insert(sID)
+       }
+
+       for _, link := range em.sideConsumers[t.pcollection] { // Side-input consumers: not counted as pending.
+               ss := em.stages[link.Global]
+               ss.AddPendingSide(pending, link.Transform, link.Local)
+               em.watermarkRefreshes.insert(link.Global)
+       }
+}
+
+// tsWatermarkEvent advances the watermark of the PCollection for a single tag.
+type tsWatermarkEvent struct {
+       Tag          string
+       NewWatermark mtime.Time
+}
+
+// Execute this WatermarkEvent by recording the tag's new watermark (panicking on regression) and propagating it to main-input consumers.
+func (ev tsWatermarkEvent) Execute(em *ElementManager) {
+       t := em.testStreamHandler.tagState[ev.Tag]
+
+       if ev.NewWatermark < t.watermark {
+               panic("test stream event decreases watermark. Watermarks cannot go backwards.")
+       }
+       t.watermark = ev.NewWatermark
+       em.testStreamHandler.tagState[ev.Tag] = t
+
+       // Set each consumer's upstream watermark and schedule it for a watermark refresh.
+       for _, sID := range em.consumers[t.pcollection] {
+               ss := em.stages[sID]
+               ss.updateUpstreamWatermark(ss.inputID, t.watermark)
+               em.watermarkRefreshes.insert(sID)
+       }
+}
+
+// tsProcessingTimeEvent implements advancing the synthetic processing time.
+type tsProcessingTimeEvent struct {
+       AdvanceBy time.Duration
+}
+
+// Execute this ProcessingTime event by advancing the synthetic processing time.
+func (ev tsProcessingTimeEvent) Execute(em *ElementManager) {
+       em.testStreamHandler.processingTime = em.testStreamHandler.processingTime.Add(ev.AdvanceBy)
+}
+
+// tsFinalEvent is the "last" event we perform after all preceding events.
+// It's automatically inserted once the user defined events have all been executed.
+// It refreshes the TestStream stage's watermarks (expected to reach infinity) and kicks affected stages.
+type tsFinalEvent struct {
+       stageID string
+}
+
+func (ev tsFinalEvent) Execute(em *ElementManager) { // Runs once, when NextEvent finds the event queue drained.
+       em.addPending(1) // The caller subtracts one pending after executing any event, so balance it here.
+       ss := em.stages[ev.stageID]
+       kickSet := ss.updateWatermarks(em)
+       kickSet.insert(ev.stageID)
+       em.watermarkRefreshes.merge(kickSet)
+}
+
+// TestStreamBuilder builds a synthetic sequence of events for the engine to execute.
+// A pipeline may contain at most one TestStream.
+type TestStreamBuilder interface {
+       AddElementEvent(tag string, elements []TestStreamElement)
+       AddWatermarkEvent(tag string, newWatermark mtime.Time)
+       AddProcessingTimeEvent(d time.Duration)
+}
+
+type testStreamImpl struct { // Builder returned by ElementManager.AddTestStream; also tracks pending work.
+       em *ElementManager
+}
+
+var (
+       _ TestStreamBuilder = (*testStreamImpl)(nil)
+       _ TestStreamBuilder = (*testStreamHandler)(nil)
+)
+
+func (tsi *testStreamImpl) initHandler(id string) {
+       if tsi.em.testStreamHandler == nil { // Only the first call creates the handler.
+               tsi.em.testStreamHandler = makeTestStreamHandler(id)
+       }
+}
+
+// TagsToPCollections receives the map of local output tags to global pcollection ids.
+func (tsi *testStreamImpl) TagsToPCollections(tagToPcol map[string]string) {
+       tsi.em.testStreamHandler.TagsToPCollections(tagToPcol)
+}
+
+// AddElementEvent queues an element event and counts it as one pending unit of work.
+func (tsi *testStreamImpl) AddElementEvent(tag string, elements []TestStreamElement) {
+       tsi.em.testStreamHandler.AddElementEvent(tag, elements)
+       tsi.em.addPending(1)
+}
+
+// AddWatermarkEvent queues a watermark event and counts it as one pending unit of work.
+func (tsi *testStreamImpl) AddWatermarkEvent(tag string, newWatermark mtime.Time) {
+       tsi.em.testStreamHandler.AddWatermarkEvent(tag, newWatermark)
+       tsi.em.addPending(1)
+}
+
+// AddProcessingTimeEvent queues a processing time event and counts it as one pending unit of work.
+func (tsi *testStreamImpl) AddProcessingTimeEvent(d time.Duration) {
+       tsi.em.testStreamHandler.AddProcessingTimeEvent(d)
+       tsi.em.addPending(1)
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go 
b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 1aa95bc6ee1..504125a2bd6 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -16,6 +16,7 @@
 package internal
 
 import (
+       "bytes"
        "context"
        "errors"
        "fmt"
@@ -24,6 +25,7 @@ import (
        "sync/atomic"
        "time"
 
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
        pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
@@ -238,6 +240,58 @@ func executePipeline(ctx context.Context, wks 
map[string]*worker.W, j *jobservic
                        case urns.TransformImpulse:
                                impulses = append(impulses, stage.ID)
                                em.AddStage(stage.ID, nil, 
[]string{getOnlyValue(t.GetOutputs())}, nil)
+                       case urns.TransformTestStream:
+                               // Add a synthetic stage that should largely be unused.
+                               em.AddStage(stage.ID, nil, maps.Values(t.GetOutputs()), nil)
+                               // Decode the test stream, and convert it to the various events for the ElementManager.
+                               var pyld pipepb.TestStreamPayload
+                               if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pyld); err != nil {
+                                       return fmt.Errorf("prism error building stage %v - decoding TestStreamPayload: \n%w", stage.ID, err)
+                               }
+
+                               // Ensure awareness of the coder used for the teststream. Errors are returned, not panicked.
+                               cID, err := lpUnknownCoders(pyld.GetCoderId(), coders, comps.GetCoders())
+                               if err != nil {
+                                       return fmt.Errorf("prism error building stage %v - length prefixing TestStream coder %v: \n%w", stage.ID, pyld.GetCoderId(), err)
+                               }
+                               mayLP := func(v []byte) []byte {
+                                       return v
+                               }
+                               if cID != pyld.GetCoderId() {
+                                       // The coder needed length prefixing. For simplicity, add a length prefix to each
+                                       // encoded element, since we will be sending a length prefixed coder to consume
+                                       // this anyway. This is simpler than trying to find all the re-written coders after the fact.
+                                       mayLP = func(v []byte) []byte {
+                                               var buf bytes.Buffer
+                                               if err := coder.EncodeVarInt((int64)(len(v)), &buf); err != nil {
+                                                       panic(err)
+                                               }
+                                               if _, err := buf.Write(v); err != nil {
+                                                       panic(err)
+                                               }
+                                               return buf.Bytes()
+                                       }
+                               }
+
+                               // Register the test stream, then translate each proto event into an engine event, preserving order.
+                               tsb := em.AddTestStream(stage.ID, t.GetOutputs())
+                               for _, e := range pyld.GetEvents() {
+                                       switch ev := e.GetEvent().(type) {
+                                       case *pipepb.TestStreamPayload_Event_ElementEvent:
+                                               var elms []engine.TestStreamElement
+                                               for _, elm := range ev.ElementEvent.GetElements() {
+                                                       elms = append(elms, engine.TestStreamElement{Encoded: mayLP(elm.GetEncodedElement()), EventTime: mtime.Time(elm.GetTimestamp())})
+                                               }
+                                               tsb.AddElementEvent(ev.ElementEvent.GetTag(), elms)
+                                       case *pipepb.TestStreamPayload_Event_WatermarkEvent:
+                                               tsb.AddWatermarkEvent(ev.WatermarkEvent.GetTag(), mtime.Time(ev.WatermarkEvent.GetNewWatermark()))
+                                       case *pipepb.TestStreamPayload_Event_ProcessingTimeEvent:
+                                               tsb.AddProcessingTimeEvent(time.Duration(ev.ProcessingTimeEvent.GetAdvanceDuration()) * time.Millisecond) // AdvanceDuration is in milliseconds.
+                                       default:
+                                               return fmt.Errorf("prism error building stage %v - unknown TestStream event type: %T", stage.ID, ev)
+                                       }
+                               }
+
                        case urns.TransformFlatten:
                                inputs := maps.Values(t.GetInputs())
                                sort.Strings(inputs)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go 
b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 1c7e280dcdd..4cff2ae92e7 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -117,6 +117,7 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo
        // Inspect Transforms for unsupported features.
        bypassedWindowingStrategies := map[string]bool{}
        ts := job.Pipeline.GetComponents().GetTransforms()
+       var testStreamIds []string
        for tid, t := range ts {
                urn := t.GetSpec().GetUrn()
                switch urn {
@@ -170,10 +171,27 @@ func (s *Server) Prepare(ctx context.Context, req *jobpb.PrepareJobRequest) (*jo
                                continue
                        }
-                       fallthrough
+                       check("PTransform.Spec.Urn", urn+" "+t.GetUniqueName(), "<doesn't exist>") // Same as default; a fallthrough here would now enter the TestStream case.
+               case urns.TransformTestStream:
+                       var testStream pipepb.TestStreamPayload
+                       if err := proto.Unmarshal(t.GetSpec().GetPayload(), &testStream); err != nil {
+                               return nil, fmt.Errorf("unable to unmarshal TestStreamPayload for %v - %q: %w", tid, t.GetUniqueName(), err)
+                       }
+                       for _, ev := range testStream.GetEvents() {
+                               if ev.GetProcessingTimeEvent() != nil {
+                                       check("TestStream.Event - ProcessingTimeEvents unsupported.", ev.GetProcessingTimeEvent())
+                               }
+                       }
+
+                       t.EnvironmentId = "" // Unset the environment, to ensure it's handled prism side.
+                       testStreamIds = append(testStreamIds, tid)
                default:
                        check("PTransform.Spec.Urn", urn+" "+t.GetUniqueName(), "<doesn't exist>")
                }
        }
+       // At most one test stream per pipeline.
+       if len(testStreamIds) > 1 {
+               check("Multiple TestStream Transforms in Pipeline", testStreamIds)
+       }
 
        // Inspect Windowing strategies for unsupported features.
        for wsID, ws := range 
job.Pipeline.GetComponents().GetWindowingStrategies() {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
index a50a7fe21b0..7be5f340dde 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
@@ -43,18 +43,6 @@ func TestUnimplemented(t *testing.T) {
        }{
                // {pipeline: primitives.Drain}, // Can't test drain 
automatically yet.
 
-               {pipeline: primitives.TestStreamBoolSequence},
-               {pipeline: primitives.TestStreamByteSliceSequence},
-               {pipeline: primitives.TestStreamFloat64Sequence},
-               {pipeline: primitives.TestStreamInt64Sequence},
-               {pipeline: primitives.TestStreamStrings},
-               {pipeline: primitives.TestStreamTwoBoolSequences},
-               {pipeline: primitives.TestStreamTwoFloat64Sequences},
-               {pipeline: primitives.TestStreamTwoInt64Sequences},
-
-               // Needs teststream
-               {pipeline: primitives.Panes},
-
                // Triggers (Need teststream and are unimplemented.)
                {pipeline: primitives.TriggerAlways},
                {pipeline: primitives.TriggerAfterAll},
@@ -68,7 +56,8 @@ func TestUnimplemented(t *testing.T) {
                {pipeline: primitives.TriggerOrFinally},
                {pipeline: primitives.TriggerRepeat},
 
-               // TODO: Timers integration tests.
+               // Needs triggers.
+               {pipeline: primitives.Panes},
        }
 
        for _, test := range tests {
@@ -163,3 +152,31 @@ func TestTimers(t *testing.T) {
                })
        }
 }
+
+func TestTestStream(t *testing.T) {
+       initRunner(t)
+       // Every TestStream pipeline listed here is expected to run on Prism without error.
+       tests := []struct {
+               pipeline func(s beam.Scope)
+       }{
+               {pipeline: primitives.TestStreamBoolSequence},
+               {pipeline: primitives.TestStreamByteSliceSequence},
+               {pipeline: primitives.TestStreamFloat64Sequence},
+               {pipeline: primitives.TestStreamInt64Sequence},
+               {pipeline: primitives.TestStreamStrings},
+               {pipeline: primitives.TestStreamTwoBoolSequences},
+               {pipeline: primitives.TestStreamTwoFloat64Sequences},
+               {pipeline: primitives.TestStreamTwoInt64Sequences},
+       }
+       // NOTE(review): consider also listing TestStreamInt16Sequence and TestStreamTwoUserTypeSequences.
+       for _, test := range tests {
+               t.Run(initTestName(test.pipeline), func(t *testing.T) {
+                       p, s := beam.NewPipelineWithRoot()
+                       test.pipeline(s)
+                       _, err := executeWithT(context.Background(), t, p)
+                       if err != nil {
+                               t.Fatalf("pipeline failed, but feature should 
be implemented in Prism: %v", err)
+                       }
+               })
+       }
+}
diff --git a/sdks/go/pkg/beam/testing/teststream/teststream.go 
b/sdks/go/pkg/beam/testing/teststream/teststream.go
index 050e57bf04c..c13e2cee9e0 100644
--- a/sdks/go/pkg/beam/testing/teststream/teststream.go
+++ b/sdks/go/pkg/beam/testing/teststream/teststream.go
@@ -18,11 +18,9 @@
 //
 // See https://beam.apache.org/blog/test-stream/ for more information.
 //
-// TestStream is supported on the Flink runner and currently supports int64,
-// float64, and boolean types.
-//
-// TODO(BEAM-12753): Flink currently displays unexpected behavior with 
TestStream,
-// should not be used until this issue is resolved.
+// TestStream is supported on the Flink and Prism runners.
+// Flink currently supports only int64, float64, and boolean types, while
+// Prism supports arbitrary types.
 package teststream
 
 import (
diff --git a/sdks/go/test/integration/integration.go 
b/sdks/go/test/integration/integration.go
index 622689c40d0..8f90ffda9e8 100644
--- a/sdks/go/test/integration/integration.go
+++ b/sdks/go/test/integration/integration.go
@@ -139,8 +139,6 @@ var portableFilters = []string{
 var prismFilters = []string{
        // The prism runner does not yet support Java's CoGBK.
        "TestXLang_CoGroupBy",
-       // The prism runner does not support the TestStream primitive
-       "TestTestStream.*",
        // The trigger and pane tests uses TestStream
        "TestTrigger.*",
        "TestPanes",
@@ -183,6 +181,11 @@ var flinkFilters = []string{
        "TestSetStateClear",
        "TestSetState",
 
+       // With TestStream, Flink adds extra length prefixes to some data types, 
causing SDK-side failures.
+       "TestTestStreamStrings",
+       "TestTestStreamByteSliceSequence",
+       "TestTestStreamTwoUserTypeSequences",
+
        "TestTimers_EventTime_Unbounded", // (failure when comparing on side 
inputs (NPE on window lookup))
 }
 
diff --git a/sdks/go/test/integration/primitives/teststream.go 
b/sdks/go/test/integration/primitives/teststream.go
index d30ec9fe11b..c8ba9b565c0 100644
--- a/sdks/go/test/integration/primitives/teststream.go
+++ b/sdks/go/test/integration/primitives/teststream.go
@@ -31,18 +31,22 @@ func TestStreamStrings(s beam.Scope) {
        col := teststream.Create(s, con)
 
        passert.Count(s, col, "teststream strings", 3)
+       passert.Equals(s, col, "a", "b", "c")
 }
 
 // TestStreamByteSliceSequence tests the TestStream primitive by inserting 
byte slice elements
 // then advancing the watermark to infinity and comparing the output..
 func TestStreamByteSliceSequence(s beam.Scope) {
        con := teststream.NewConfig()
-       b := []byte{91, 92, 93}
-       con.AddElements(1, b)
+       // Three distinct slices; passert.Equals checks each one comes back unmodified.
+       a := []byte{91, 92, 93}
+       b := []byte{94, 95, 96}
+       c := []byte{97, 98, 99}
+       con.AddElements(1, a, b, c)
        con.AdvanceWatermarkToInfinity()
        col := teststream.Create(s, con)
-       passert.Count(s, col, "teststream byte", 1)
-       passert.Equals(s, col, append([]byte{3}, b...))
+       passert.Count(s, col, "teststream byte", 3)
+       passert.Equals(s, col, a, b, c)
 }
 
 // TestStreamInt64Sequence tests the TestStream primitive by inserting int64 
elements
@@ -137,3 +141,34 @@ func TestStreamTwoBoolSequences(s beam.Scope) {
        passert.Count(s, col, "teststream bool", 6)
        passert.EqualsList(s, col, append(eo, et...))
 }
+
+// TestStreamTwoUserTypeSequences tests the TestStream primitive by inserting 
two sets of
+// user-defined stringPair elements that arrive on-time into the TestStream.
+func TestStreamTwoUserTypeSequences(s beam.Scope) {
+       con := teststream.NewConfig()
+       eo := []stringPair{{"a", "b"}, {"b", "c"}, {"c", "a"}}
+       et := []stringPair{{"b", "a"}, {"c", "b"}, {"a", "c"}}
+       con.AddElementList(100, eo)
+       con.AdvanceWatermark(110)
+       con.AddElementList(120, et)
+       con.AdvanceWatermark(130)
+       // Each set is added before the watermark passes its timestamp, so all six are on time.
+       col := teststream.Create(s, con)
+
+       passert.Count(s, col, "teststream usertype", 6)
+       passert.EqualsList(s, col, append(eo, et...))
+}
+
+// TestStreamInt16Sequence validates that an element type without a standard
+// Beam coder (int16) works with TestStream.
+func TestStreamInt16Sequence(s beam.Scope) {
+       con := teststream.NewConfig()
+       ele := []int16{91, 92, 93}
+       con.AddElementList(100, ele)
+       con.AdvanceWatermarkToInfinity()
+
+       col := teststream.Create(s, con)
+
+       passert.Count(s, col, "teststream int16", 3)
+       passert.EqualsList(s, col, ele)
+}
diff --git a/sdks/go/test/integration/primitives/teststream_test.go 
b/sdks/go/test/integration/primitives/teststream_test.go
index 90a2120294e..b0144f148cb 100644
--- a/sdks/go/test/integration/primitives/teststream_test.go
+++ b/sdks/go/test/integration/primitives/teststream_test.go
@@ -37,6 +37,11 @@ func TestTestStreamInt64Sequence(t *testing.T) {
        ptest.BuildAndRun(t, TestStreamInt64Sequence)
 }
 
+func TestTestStreamInt16Sequence(t *testing.T) {
+       integration.CheckFilters(t) // presumably skips when the active runner's filter list matches this test
+       ptest.BuildAndRun(t, TestStreamInt16Sequence)
+}
+
 func TestTestStreamTwoInt64Sequences(t *testing.T) {
        integration.CheckFilters(t)
        ptest.BuildAndRun(t, TestStreamTwoInt64Sequences)
@@ -61,3 +66,8 @@ func TestTestStreamTwoBoolSequences(t *testing.T) {
        integration.CheckFilters(t)
        ptest.BuildAndRun(t, TestStreamTwoBoolSequences)
 }
+
+func TestTestStreamTwoUserTypeSequences(t *testing.T) {
+       integration.CheckFilters(t) // listed in flinkFilters, so presumably skipped on Flink here
+       ptest.BuildAndRun(t, TestStreamTwoUserTypeSequences)
+}

Reply via email to