This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bc42a637d56 [#28543][prism] Implement State API (#29712)
bc42a637d56 is described below

commit bc42a637d566b1dbc042632114afb539555ff353
Author: Robert Burke <lostl...@users.noreply.github.com>
AuthorDate: Thu Dec 14 10:45:17 2023 -0800

    [#28543][prism] Implement State API (#29712)
---
 .../pkg/beam/runners/prism/internal/engine/data.go | 176 ++++++++++++++++
 .../prism/internal/engine/elementmanager.go        | 226 ++++++++++++++++++++-
 .../runners/prism/internal/engine/engine_test.go   | 159 +++++++++++++++
 sdks/go/pkg/beam/runners/prism/internal/execute.go |  22 ++
 .../beam/runners/prism/internal/execute_test.go    |   8 +-
 .../pkg/beam/runners/prism/internal/handlepardo.go |   9 +-
 .../beam/runners/prism/internal/jobservices/job.go |   3 +-
 .../prism/internal/jobservices/management.go       |  20 +-
 .../pkg/beam/runners/prism/internal/preprocess.go  |  14 +-
 sdks/go/pkg/beam/runners/prism/internal/stage.go   |  22 +-
 .../runners/prism/internal/unimplemented_test.go   |  48 +++--
 .../pkg/beam/runners/prism/internal/urns/urns.go   |   5 +
 .../beam/runners/prism/internal/worker/bundle.go   |   6 +-
 .../beam/runners/prism/internal/worker/worker.go   |  89 ++++++--
 sdks/go/test/integration/primitives/state.go       |  53 +++--
 15 files changed, 769 insertions(+), 91 deletions(-)

diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
index 6fc192ac83b..6679f484aa2 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/data.go
@@ -15,10 +15,30 @@
 
 package engine
 
+import (
+       "bytes"
+       "fmt"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/window"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/exec"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
+       "golang.org/x/exp/slog"
+)
+
+// StateData is a "union" between Bag state and MultiMap state to increase common code.
+// Presumably only one of the two fields is populated for a given state ID; the bag
+// accessors only touch Bag and the multimap accessors only touch Multimap.
+type StateData struct {
+       // Bag holds the values appended to a bag state, in append order.
+       Bag [][]byte
+       // Multimap holds multimap state values, keyed by the encoded map key.
+       // A key mapped to a nil slice is a cleared key (see ClearMultimapState).
+       Multimap map[string][][]byte
+}
+
 // TentativeData is where data for in progress bundles is put
 // until the bundle executes successfully.
 type TentativeData struct {
        Raw map[string][][]byte
+
+       // state is a map from transformID + UserStateID (a LinkID), to window, to userKey, to data values.
+       // It may be nil; appendState allocates it on the first write.
+       state map[LinkID]map[typex.Window]map[string]StateData
 }
 
 // WriteData adds data to a given global collectionID.
@@ -28,3 +48,159 @@ func (d *TentativeData) WriteData(colID string, data 
[]byte) {
        }
        d.Raw[colID] = append(d.Raw[colID], data)
 }
+
+// toWindow decodes an encoded state window key into a typex.Window.
+// An empty key is treated as the global window; any other key is decoded
+// as an interval window. It panics if the key cannot be decoded.
+func (d *TentativeData) toWindow(wKey []byte) typex.Window {
+       if len(wKey) == 0 {
+               return window.GlobalWindow{}
+       }
+       // TODO: Custom Window handling.
+       // Only interval windows are decoded here; other window coders are not handled yet.
+       w, err := exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
+       if err != nil {
+               // NOTE(review): the message mentions "append bag", but toWindow is shared by the
+               // bag and multimap get/append/clear paths, so the wording can mislead.
+               panic(fmt.Sprintf("error decoding append bag user state window key %v: %v", wKey, err))
+       }
+       return w
+}
+
+// GetBagState retrieves available state from the tentative bundle data.
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+//
+// Missing state at any level yields a nil slice, since indexing a nil map
+// returns the zero value rather than panicking.
+func (d *TentativeData) GetBagState(stateID LinkID, wKey, uKey []byte) [][]byte {
+       winMap := d.state[stateID]
+       w := d.toWindow(wKey)
+       data := winMap[w][string(uKey)]
+       slog.Debug("State() Bag.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", w), slog.Any("Data", data))
+       return data.Bag
+}
+
+// appendState returns the user-key map for the given state ID and window key,
+// creating any missing intermediate maps so the caller can write into it.
+func (d *TentativeData) appendState(stateID LinkID, wKey []byte) map[string]StateData {
+       if d.state == nil {
+               d.state = map[LinkID]map[typex.Window]map[string]StateData{}
+       }
+       winMap, ok := d.state[stateID]
+       if !ok {
+               winMap = map[typex.Window]map[string]StateData{}
+               d.state[stateID] = winMap
+       }
+       w := d.toWindow(wKey)
+       kmap, ok := winMap[w]
+       if !ok {
+               kmap = map[string]StateData{}
+               winMap[w] = kmap
+       }
+       return kmap
+}
+
+// AppendBagState appends the incoming data to the existing tentative data bundle.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) AppendBagState(stateID LinkID, wKey, uKey, data []byte) {
+       kmap := d.appendState(stateID, wKey)
+       // StateData is stored by value, so the grown bag must be written back into kmap.
+       kmap[string(uKey)] = StateData{Bag: append(kmap[string(uKey)].Bag, data)}
+       slog.Debug("State() Bag.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey), slog.Any("NewData", data))
+}
+
+// clearState returns the existing user-key map for the given state ID and window key,
+// or nil when no tentative state exists for them. Unlike appendState, it never
+// allocates maps, so clearing absent state creates nothing.
+func (d *TentativeData) clearState(stateID LinkID, wKey []byte) map[string]StateData {
+       if d.state == nil {
+               return nil
+       }
+       winMap, ok := d.state[stateID]
+       if !ok {
+               return nil
+       }
+       w := d.toWindow(wKey)
+       return winMap[w]
+}
+
+// ClearBagState clears any tentative data for the state. Since state data is only initialized if any exists,
+// Clear takes the approach to not create state that doesn't already exist. Existing state is zeroed
+// to allow that to be committed post bundle completion.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) ClearBagState(stateID LinkID, wKey, uKey []byte) {
+       kmap := d.clearState(stateID, wKey)
+       if kmap == nil {
+               return
+       }
+       // Zero the current entry to clear.
+       // Delete makes it difficult to delete the persisted stage state for the key:
+       // PersistBundle copies each tentative key entry over the stage's state, so an
+       // explicit zero entry overwrites the committed bag while a missing one would not.
+       kmap[string(uKey)] = StateData{}
+       slog.Debug("State() Bag.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey))
+}
+
+// GetMultimapState retrieves available state from the tentative bundle data.
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+//
+// It returns the values stored under mapKey, or nil when there are none (missing
+// maps at any level index to zero values).
+func (d *TentativeData) GetMultimapState(stateID LinkID, wKey, uKey, mapKey []byte) [][]byte {
+       winMap := d.state[stateID]
+       w := d.toWindow(wKey)
+       data := winMap[w][string(uKey)].Multimap[string(mapKey)]
+       slog.Debug("State() Multimap.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", w), slog.Any("Data", data))
+       return data
+}
+
+// AppendMultimapState appends the incoming data to the existing tentative data bundle.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) AppendMultimapState(stateID LinkID, wKey, uKey, mapKey, data []byte) {
+       kmap := d.appendState(stateID, wKey)
+       stateData, ok := kmap[string(uKey)]
+       if !ok || stateData.Multimap == nil { // In case of All Key Clear tombstones, we may have a nil map.
+               stateData = StateData{Multimap: map[string][][]byte{}}
+               kmap[string(uKey)] = stateData
+       }
+       stateData.Multimap[string(mapKey)] = append(stateData.Multimap[string(mapKey)], data)
+       // The Multimap field is aliased to the instance we stored in kmap,
+       // so we don't need to re-assign back to kmap after appending the data to mapKey.
+       slog.Debug("State() Multimap.Append", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("MapKey", mapKey), slog.Any("Window", wKey), slog.Any("NewData", data))
+}
+
+// ClearMultimapState clears any tentative data for the state. Since state data is only initialized if any exists,
+// Clear takes the approach to not create state that doesn't already exist. Existing state is zeroed
+// to allow that to be committed post bundle completion.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) ClearMultimapState(stateID LinkID, wKey, uKey, mapKey []byte) {
+       kmap := d.clearState(stateID, wKey)
+       if kmap == nil {
+               return
+       }
+       // Nil the current entry to clear.
+       // Delete makes it difficult to delete the persisted stage state for the key.
+       userMap, ok := kmap[string(uKey)]
+       if !ok || userMap.Multimap == nil {
+               return
+       }
+       // NOTE(review): mapKey stays present in the Multimap with a nil value, so code
+       // that ranges over the map keys sees cleared keys unless it skips empty entries.
+       userMap.Multimap[string(mapKey)] = nil
+       // The Multimap field is aliased to the instance we stored in kmap,
+       // so we don't need to re-assign back to kmap after clearing the data from mapKey.
+       slog.Debug("State() Multimap.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", wKey))
+}
+
+// GetMultimapKeysState retrieves all available user map keys.
+//
+// ClearMultimapState leaves a cleared map key present with a nil value, and that
+// entry is persisted with the rest of the user key's state. A cleared key has no
+// values left in the multimap, so such entries are not reported. Appends always
+// add at least one value, so only cleared keys have empty value lists.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) GetMultimapKeysState(stateID LinkID, wKey, uKey []byte) [][]byte {
+       winMap := d.state[stateID]
+       w := d.toWindow(wKey)
+       userMap := winMap[w][string(uKey)]
+       var keys [][]byte
+       for k, vs := range userMap.Multimap {
+               // Skip tombstones left behind by ClearMultimapState.
+               if len(vs) == 0 {
+                       continue
+               }
+               keys = append(keys, []byte(k))
+       }
+       slog.Debug("State() MultimapKeys.Get", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("Window", w), slog.Any("Keys", keys))
+       return keys
+}
+
+// ClearMultimapKeysState clears tentative data for all user map keys. Since state data is only initialized if any exists,
+// Clear takes the approach to not create state that doesn't already exist. Existing state is zeroed
+// to allow that to be committed post bundle completion.
+//
+// The stateID has the Transform and Local fields populated, for the Transform and UserStateID respectively.
+func (d *TentativeData) ClearMultimapKeysState(stateID LinkID, wKey, uKey []byte) {
+       kmap := d.clearState(stateID, wKey)
+       if kmap == nil {
+               return
+       }
+       // Zero the current entry to clear.
+       // Delete makes it difficult to delete the persisted stage state for the key.
+       // The zero entry has a nil Multimap, which AppendMultimapState treats as a tombstone.
+       kmap[string(uKey)] = StateData{}
+       slog.Debug("State() MultimapKeys.Clear", slog.Any("StateID", stateID), slog.Any("UserKey", uKey), slog.Any("WindowKey", wKey))
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
index 656525c6704..6cb55235418 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go
@@ -23,6 +23,7 @@ import (
        "context"
        "fmt"
        "io"
+       "strings"
        "sync"
        "sync/atomic"
 
@@ -39,6 +40,7 @@ type element struct {
        pane      typex.PaneInfo
 
        elmBytes []byte
+       keyBytes []byte
 }
 
 type elements struct {
@@ -51,6 +53,7 @@ type PColInfo struct {
        WDec     exec.WindowDecoder
        WEnc     exec.WindowEncoder
        EDec     func(io.Reader) []byte
+       KeyDec   func(io.Reader) []byte
 }
 
 // ToData recodes the elements with their approprate windowed value header.
@@ -182,6 +185,12 @@ func (em *ElementManager) StageAggregates(ID string) {
        em.stages[ID].aggregate = true
 }
 
+// StageStateful marks the given stage as stateful, which means elements are
+// processed by key.
+//
+// Presumably the stage must already be registered in em.stages; an unknown ID
+// would dereference a nil stage — TODO confirm callers guarantee this.
+func (em *ElementManager) StageStateful(ID string) {
+       em.stages[ID].stateful = true
+}
+
 // Impulse marks and initializes the given stage as an impulse which
 // is a root transform that starts processing.
 func (em *ElementManager) Impulse(stageID string) {
@@ -257,10 +266,13 @@ func (em *ElementManager) Bundles(ctx context.Context, 
nextBundID func() string)
                                ss := em.stages[stageID]
                                watermark, ready := ss.bundleReady(em)
                                if ready {
-                                       bundleID, ok := 
ss.startBundle(watermark, nextBundID)
+                                       bundleID, ok, reschedule := 
ss.startBundle(watermark, nextBundID)
                                        if !ok {
                                                continue
                                        }
+                                       if reschedule {
+                                               
em.watermarkRefreshes.insert(stageID)
+                                       }
                                        rb := RunBundle{StageID: stageID, 
BundleID: bundleID, Watermark: watermark}
 
                                        em.inprogressBundles.insert(rb.BundleID)
@@ -278,7 +290,11 @@ func (em *ElementManager) Bundles(ctx context.Context, 
nextBundID func() string)
                                v := em.livePending.Load()
                                slog.Debug("Bundles: nothing in progress and no 
refreshes", slog.Int64("pendingElementCount", v))
                                if v > 0 {
-                                       panic(fmt.Sprintf("nothing in progress 
and no refreshes with non zero pending elements: %v", v))
+                                       var stageState []string
+                                       for id, ss := range em.stages {
+                                               stageState = append(stageState, 
fmt.Sprintln(id, ss.pending, ss.pendingByKeys, ss.inprogressKeys, 
ss.inprogressKeysByBundle))
+                                       }
+                                       panic(fmt.Sprintf("nothing in progress 
and no refreshes with non zero pending elements: %v\n%v", v, 
strings.Join(stageState, "")))
                                }
                        } else if len(em.inprogressBundles) == 0 {
                                v := em.livePending.Load()
@@ -304,6 +320,56 @@ func (em *ElementManager) InputForBundle(rb RunBundle, 
info PColInfo) [][]byte {
        return es.ToData(info)
 }
 
+// StateForBundle retrieves relevant state for the given bundle, WRT the data in the bundle.
+// Only state for the keys assigned to the bundle is copied, across every state ID and window.
+// The stage lock is held for the duration of the copy.
+//
+// TODO(lostluck): Consider unifying with InputForBundle, to reduce lock contention.
+func (em *ElementManager) StateForBundle(rb RunBundle) TentativeData {
+       ss := em.stages[rb.StageID]
+       ss.mu.Lock()
+       defer ss.mu.Unlock()
+       var ret TentativeData
+       keys := ss.inprogressKeysByBundle[rb.BundleID]
+       // TODO(lostluck): Also track windows per bundle, to reduce copying.
+       if len(ss.state) > 0 {
+               ret.state = map[LinkID]map[typex.Window]map[string]StateData{}
+       }
+       for link, winMap := range ss.state {
+               for w, keyMap := range winMap {
+                       for key := range keys {
+                               data, ok := keyMap[key]
+                               if !ok {
+                                       continue
+                               }
+                               linkMap, ok := ret.state[link]
+                               if !ok {
+                                       linkMap = map[typex.Window]map[string]StateData{}
+                                       ret.state[link] = linkMap
+                               }
+                               wlinkMap, ok := linkMap[w]
+                               if !ok {
+                                       wlinkMap = map[string]StateData{}
+                                       linkMap[w] = wlinkMap
+                               }
+                               // The Multimap map itself must be cloned: the tentative multimap
+                               // writers mutate it in place, which would otherwise alter committed stage state.
+                               var mm map[string][][]byte
+                               if len(data.Multimap) > 0 {
+                                       mm = map[string][][]byte{}
+                                       for uk, v := range data.Multimap {
+                                               // Clone the "holding" slice, but refer to the existing data bytes.
+                                               mm[uk] = append([][]byte(nil), v...)
+                                       }
+                               }
+                               // Clone the "holding" slice, but refer to the existing data bytes.
+                               wlinkMap[key] = StateData{
+                                       Bag:      append([][]byte(nil), data.Bag...),
+                                       Multimap: mm,
+                               }
+                       }
+               }
+       }
+
+       return ret
+}
+
 // reElementResiduals extracts the windowed value header from residual bytes, 
and explodes them
 // back out to their windows.
 func reElementResiduals(residuals [][]byte, inputInfo PColInfo, rb RunBundle) 
[]element {
@@ -322,6 +388,15 @@ func reElementResiduals(residuals [][]byte, inputInfo 
PColInfo, rb RunBundle) []
                        slog.Error("reElementResiduals: sdk provided a windowed 
value header 0 windows", "bundle", rb)
                        panic("error decoding residual header: sdk provided a 
windowed value header 0 windows")
                }
+               // POSSIBLY BAD PATTERN: The buffer is invalidated on the next 
call, which doesn't always happen.
+               // But the decoder won't be mutating the buffer bytes, just 
reading the data. So the elmBytes
+               // should remain pointing to the whole element, and we should 
have a copy of the key bytes.
+               // Ideally, we're simply refering to the key part of the 
existing buffer.
+               elmBytes := buf.Bytes()
+               var keyBytes []byte
+               if inputInfo.KeyDec != nil {
+                       keyBytes = inputInfo.KeyDec(buf)
+               }
 
                for _, w := range ws {
                        unprocessedElements = append(unprocessedElements,
@@ -329,7 +404,8 @@ func reElementResiduals(residuals [][]byte, inputInfo 
PColInfo, rb RunBundle) []
                                        window:    w,
                                        timestamp: et,
                                        pane:      pn,
-                                       elmBytes:  buf.Bytes(),
+                                       elmBytes:  elmBytes,
+                                       keyBytes:  keyBytes,
                                })
                }
        }
@@ -373,6 +449,11 @@ func (em *ElementManager) PersistBundle(rb RunBundle, 
col2Coders map[string]PCol
                                }
                                // TODO: Optimize unnecessary copies. This is 
doubleteeing.
                                elmBytes := info.EDec(tee)
+                               var keyBytes []byte
+                               if info.KeyDec != nil {
+                                       kbuf := bytes.NewBuffer(elmBytes)
+                                       keyBytes = info.KeyDec(kbuf) // TODO: 
Optimize unnecessary copies. This is tripleteeing?
+                               }
                                for _, w := range ws {
                                        newPending = append(newPending,
                                                element{
@@ -380,6 +461,7 @@ func (em *ElementManager) PersistBundle(rb RunBundle, 
col2Coders map[string]PCol
                                                        timestamp: et,
                                                        pane:      pn,
                                                        elmBytes:  elmBytes,
+                                                       keyBytes:  keyBytes,
                                                })
                                }
                        }
@@ -412,6 +494,10 @@ func (em *ElementManager) PersistBundle(rb RunBundle, 
col2Coders map[string]PCol
        completed := stage.inprogress[rb.BundleID]
        em.addPending(-len(completed.es))
        delete(stage.inprogress, rb.BundleID)
+       for k := range stage.inprogressKeysByBundle[rb.BundleID] {
+               delete(stage.inprogressKeys, k)
+       }
+       delete(stage.inprogressKeysByBundle, rb.BundleID)
        // If there are estimated output watermarks, set the estimated
        // output watermark for the stage.
        if len(estimatedOWM) > 0 {
@@ -421,6 +507,25 @@ func (em *ElementManager) PersistBundle(rb RunBundle, 
col2Coders map[string]PCol
                }
                stage.estimatedOutput = estimate
        }
+
+       // Handle persisting.
+       for link, winMap := range d.state {
+               linkMap, ok := stage.state[link]
+               if !ok {
+                       linkMap = map[typex.Window]map[string]StateData{}
+                       stage.state[link] = linkMap
+               }
+               for w, keyMap := range winMap {
+                       wlinkMap, ok := linkMap[w]
+                       if !ok {
+                               wlinkMap = map[string]StateData{}
+                               linkMap[w] = wlinkMap
+                       }
+                       for key, data := range keyMap {
+                               wlinkMap[key] = data
+                       }
+               }
+       }
        stage.mu.Unlock()
 
        // TODO support state/timer watermark holds.
@@ -499,6 +604,11 @@ func (em *ElementManager) refreshWatermarks() set[string] {
 
 type set[K comparable] map[K]struct{}
 
+// present reports whether k is a member of the set.
+func (s set[K]) present(k K) bool {
+       _, ok := s[k]
+       return ok
+}
+
 func (s set[K]) remove(k K) {
        delete(s, k)
 }
@@ -525,7 +635,8 @@ type stageState struct {
        sides     []LinkID // PCollection IDs of side inputs that can block 
execution.
 
        // Special handling bits
-       aggregate bool     // whether this state needs to block for aggregation.
+       stateful  bool     // whether this stage uses state or timers, and 
needs keyed processing.
+       aggregate bool     // whether this stage needs to block for aggregation.
        strat     winStrat // Windowing Strategy for aggregation fireings.
 
        mu                 sync.Mutex
@@ -537,6 +648,12 @@ type stageState struct {
        pending    elementHeap                          // pending input 
elements for this stage that are to be processesd
        inprogress map[string]elements                  // inprogress elements 
by active bundles, keyed by bundle
        sideInputs map[LinkID]map[typex.Window][][]byte // side input data for 
this stage, from {tid, inputID} -> window
+
+       // Fields for stateful stages which need to be per key.
+       pendingByKeys          map[string]elementHeap                           
// pending input elements by Key, if stateful.
+       inprogressKeys         set[string]                                      
// all keys that are assigned to bundles.
+       inprogressKeysByBundle map[string]set[string]                           
// bundle to key assignments.
+       state                  map[LinkID]map[typex.Window]map[string]StateData 
// state data for this stage, from {tid, stateID} -> window -> userKey
 }
 
 // makeStageState produces an initialized stageState.
@@ -546,6 +663,7 @@ func makeStageState(ID string, inputIDs, outputIDs 
[]string, sides []LinkID) *st
                outputIDs: outputIDs,
                sides:     sides,
                strat:     defaultStrat{},
+               state:     map[LinkID]map[typex.Window]map[string]StateData{},
 
                input:           mtime.MinTimestamp,
                output:          mtime.MinTimestamp,
@@ -566,8 +684,22 @@ func makeStageState(ID string, inputIDs, outputIDs 
[]string, sides []LinkID) *st
 func (ss *stageState) AddPending(newPending []element) {
        ss.mu.Lock()
        defer ss.mu.Unlock()
-       ss.pending = append(ss.pending, newPending...)
-       heap.Init(&ss.pending)
+       // Stateful stages keep one pending heap per key so that whole keys can be
+       // assigned to bundles; other stages share the single pending heap.
+       if ss.stateful {
+               if ss.pendingByKeys == nil {
+                       ss.pendingByKeys = map[string]elementHeap{}
+               }
+               for _, e := range newPending {
+                       if len(e.keyBytes) == 0 {
+                               panic(fmt.Sprintf("zero length key: %v %v", ss.ID, ss.inputID))
+                       }
+                       h := ss.pendingByKeys[string(e.keyBytes)]
+                       // Use heap.Push rather than the Push method: the method only appends,
+                       // while startBundle and minPendingTimestamp rely on h[0] being the
+                       // minimum-timestamp element of a properly ordered heap.
+                       heap.Push(&h, e)
+                       // The push may reallocate the backing array, so the updated slice
+                       // header must be stored back into the map.
+                       ss.pendingByKeys[string(e.keyBytes)] = h
+               }
+       } else {
+               ss.pending = append(ss.pending, newPending...)
+               heap.Init(&ss.pending)
+       }
 }
 
 // AddPendingSide adds elements to be consumed as side inputs.
@@ -647,10 +779,16 @@ func (ss *stageState) OutputWatermark() mtime.Time {
        return ss.output
 }
 
+// TODO: Move to better place for configuration
+//
+// These flags are read by startBundle when assigning a stateful stage's pending
+// keys to a bundle. Both default to false, which is the greedy approach: every
+// key not already in progress, with all of its pending elements, goes into one bundle.
+var (
+       OneKeyPerBundle  bool // OneKeyPerBundle sets if a bundle is restricted to a single key.
+       OneElementPerKey bool // OneElementPerKey sets if a key in a bundle is restricted to one element.
+)
+
 // startBundle initializes a bundle with elements if possible.
 // A bundle only starts if there are elements at all, and if it's
 // an aggregation stage, if the windowing stratgy allows it.
-func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() 
string) (string, bool) {
+func (ss *stageState) startBundle(watermark mtime.Time, genBundID func() 
string) (string, bool, bool) {
        defer func() {
                if e := recover(); e != nil {
                        panic(fmt.Sprintf("generating bundle for stage %v at %v 
panicked\n%v", ss.ID, watermark, e))
@@ -669,21 +807,73 @@ func (ss *stageState) startBundle(watermark mtime.Time, 
genBundID func() string)
        }
        ss.pending = notYet
        heap.Init(&ss.pending)
+       if ss.inprogressKeys == nil {
+               ss.inprogressKeys = set[string]{}
+       }
+       minTs := mtime.MaxTimestamp
+       // TODO: Allow configurable limit of keys per bundle, and elements per 
key to improve parallelism.
+       // TODO: when we do, we need to ensure that the stage remains 
schedualable for bundle execution, for remaining pending elements and keys.
+       // With the greedy approach, we don't need to since "new data" triggers 
a refresh, and so should completing processing of a bundle.
+       newKeys := set[string]{}
+       stillSchedulable := true
+
+keysPerBundle:
+       for k, h := range ss.pendingByKeys {
+               if ss.inprogressKeys.present(k) {
+                       continue
+               }
+               newKeys.insert(k)
+               // Track the min-timestamp for later watermark handling.
+               if h[0].timestamp < minTs {
+                       minTs = h[0].timestamp
+               }
+
+               if OneElementPerKey {
+                       hp := &h
+                       toProcess = append(toProcess, heap.Pop(hp).(element))
+                       if hp.Len() == 0 {
+                               // Once we've taken all the elements for a key,
+                               // we must delete them from pending as well.
+                               delete(ss.pendingByKeys, k)
+                       } else {
+                               ss.pendingByKeys[k] = *hp
+                       }
+               } else {
+                       toProcess = append(toProcess, h...)
+                       delete(ss.pendingByKeys, k)
+               }
+               if OneKeyPerBundle {
+                       break keysPerBundle
+               }
+       }
+       if len(ss.pendingByKeys) == 0 {
+               stillSchedulable = false
+       }
 
        if len(toProcess) == 0 {
-               return "", false
+               return "", false, false
+       }
+
+       if toProcess[0].timestamp < minTs {
+               // Catch the ordinary case.
+               minTs = toProcess[0].timestamp
        }
-       // Is THIS is where basic splits should happen/per element processing?
+
        es := elements{
                es:           toProcess,
-               minTimestamp: toProcess[0].timestamp,
+               minTimestamp: minTs,
        }
        if ss.inprogress == nil {
                ss.inprogress = make(map[string]elements)
        }
+       if ss.inprogressKeysByBundle == nil {
+               ss.inprogressKeysByBundle = make(map[string]set[string])
+       }
        bundID := genBundID()
        ss.inprogress[bundID] = es
-       return bundID, true
+       ss.inprogressKeysByBundle[bundID] = newKeys
+       ss.inprogressKeys.merge(newKeys)
+       return bundID, true, stillSchedulable
 }
 
 func (ss *stageState) splitBundle(rb RunBundle, firstResidual int) {
@@ -713,6 +903,12 @@ func (ss *stageState) minPendingTimestamp() mtime.Time {
        if len(ss.pending) != 0 {
                minPending = ss.pending[0].timestamp
        }
+       if len(ss.pendingByKeys) != 0 {
+               // TODO(lostluck): Can we figure out how to avoid checking 
every key on every watermark refresh?
+               for _, h := range ss.pendingByKeys {
+                       minPending = mtime.Min(minPending, h[0].timestamp)
+               }
+       }
        for _, es := range ss.inprogress {
                minPending = mtime.Min(minPending, es.minTimestamp)
        }
@@ -785,6 +981,14 @@ func (ss *stageState) updateWatermarks(minPending, 
minStateHold mtime.Time, em *
                                }
                        }
                }
+               for _, wins := range ss.state {
+                       for win := range wins {
+                               // Clear out anything we've already used.
+                               if win.MaxTimestamp() < newOut {
+                                       delete(wins, win)
+                               }
+                       }
+               }
        }
        return refreshes
 }
diff --git a/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go
new file mode 100644
index 00000000000..af41e089a2e
--- /dev/null
+++ b/sdks/go/pkg/beam/runners/prism/internal/engine/engine_test.go
@@ -0,0 +1,159 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package engine_test ensures coverage of the element manager via pipeline 
actuation.
+package engine_test
+
+import (
+       "context"
+       "fmt"
+       "math/rand"
+       "os"
+       "strings"
+       "testing"
+       "time"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/options/jobopts"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal"
+       
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
+       
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/jobservices"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/universal"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/testing/ptest"
+       "github.com/apache/beam/sdks/v2/go/test/integration/primitives"
+)
+
+func init() {
+       // Not actually being used, but explicitly registering
+       // will avoid accidentally using a different runner for
+       // the tests if I change things later.
+       beam.RegisterRunner("testlocal", execute)
+}
+
+func TestMain(m *testing.M) {
+       ptest.MainWithDefault(m, "testlocal")
+}
+
+func initRunner(t testing.TB) {
+       t.Helper()
+       if *jobopts.Endpoint == "" {
+               s := jobservices.NewServer(0, internal.RunPipeline)
+               *jobopts.Endpoint = s.Endpoint()
+               go s.Serve()
+               t.Cleanup(func() {
+                       *jobopts.Endpoint = ""
+                       s.Stop()
+               })
+       }
+       if !jobopts.IsLoopback() {
+               *jobopts.EnvironmentType = "loopback"
+       }
+       // Since we force loopback, avoid cross-compilation.
+       f, err := os.CreateTemp("", "dummy")
+       if err != nil {
+               t.Fatal(err)
+       }
+       t.Cleanup(func() { os.Remove(f.Name()) })
+       *jobopts.WorkerBinary = f.Name()
+}
+
+func execute(ctx context.Context, p *beam.Pipeline) (beam.PipelineResult, 
error) {
+       return universal.Execute(ctx, p)
+}
+
+func executeWithT(ctx context.Context, t testing.TB, p *beam.Pipeline) 
(beam.PipelineResult, error) {
+       t.Helper()
+       t.Log("startingTest - ", t.Name())
+       s1 := rand.NewSource(time.Now().UnixNano())
+       r1 := rand.New(s1)
+       *jobopts.JobName = fmt.Sprintf("%v-%v", strings.ToLower(t.Name()), 
r1.Intn(1000))
+       return execute(ctx, p)
+}
+
+func initTestName(fn any) string {
+       name := reflectx.FunctionName(fn)
+       n := strings.LastIndex(name, "/")
+       return name[n+1:]
+}
+
+func TestStateAPI(t *testing.T) {
+       initRunner(t)
+
+       tests := []struct {
+               pipeline func(s beam.Scope)
+       }{
+               {pipeline: primitives.BagStateParDo},
+               {pipeline: primitives.BagStateParDoClear},
+               {pipeline: primitives.CombiningStateParDo},
+               {pipeline: primitives.ValueStateParDo},
+               {pipeline: primitives.ValueStateParDoClear},
+               {pipeline: primitives.ValueStateParDoWindowed},
+               {pipeline: primitives.MapStateParDo},
+               {pipeline: primitives.MapStateParDoClear},
+               {pipeline: primitives.SetStateParDo},
+               {pipeline: primitives.SetStateParDoClear},
+       }
+
+       configs := []struct {
+               name                              string
+               OneElementPerKey, OneKeyPerBundle bool
+       }{
+               {"Greedy", false, false},
+               {"AllElementsPerKey", false, true},
+               {"OneElementPerKey", true, false},
+               {"OneElementPerBundle", true, true},
+       }
+       for _, config := range configs {
+               for _, test := range tests {
+                       t.Run(initTestName(test.pipeline)+"_"+config.name, 
func(t *testing.T) {
+                               t.Cleanup(func() {
+                                       engine.OneElementPerKey = false
+                                       engine.OneKeyPerBundle = false
+                               })
+                               engine.OneElementPerKey = 
config.OneElementPerKey
+                               engine.OneKeyPerBundle = config.OneKeyPerBundle
+                               p, s := beam.NewPipelineWithRoot()
+                               test.pipeline(s)
+                               _, err := executeWithT(context.Background(), t, 
p)
+                               if err != nil {
+                                       t.Fatalf("pipeline failed, but feature 
should be implemented in Prism: %v", err)
+                               }
+                       })
+               }
+       }
+}
+
+func TestElementManagerCoverage(t *testing.T) {
+       initRunner(t)
+
+       tests := []struct {
+               pipeline func(s beam.Scope)
+       }{
+               {pipeline: primitives.Checkpoints}, // (Doesn't run long enough 
to split.)
+               {pipeline: primitives.WindowSums_Lifted},
+       }
+
+       for _, test := range tests {
+               t.Run(initTestName(test.pipeline), func(t *testing.T) {
+                       p, s := beam.NewPipelineWithRoot()
+                       test.pipeline(s)
+                       _, err := executeWithT(context.Background(), t, p)
+                       if err != nil {
+                               t.Fatalf("pipeline failed, but feature should 
be implemented in Prism: %v", err)
+                       }
+               })
+       }
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute.go 
b/sdks/go/pkg/beam/runners/prism/internal/execute.go
index 89512238385..b8bc68dcd1b 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute.go
@@ -179,11 +179,16 @@ func executePipeline(ctx context.Context, wks 
map[string]*worker.W, j *jobservic
                        ed := collectionPullDecoder(col.GetCoderId(), coders, 
comps)
                        wDec, wEnc := getWindowValueCoders(comps, col, coders)
 
+                       var kd func(io.Reader) []byte
+                       if kcid, ok := extractKVCoderID(col.GetCoderId(), 
coders); ok {
+                               kd = collectionPullDecoder(kcid, coders, comps)
+                       }
                        stage.OutputsToCoders[onlyOut] = engine.PColInfo{
                                GlobalID: onlyOut,
                                WDec:     wDec,
                                WEnc:     wEnc,
                                EDec:     ed,
+                               KeyDec:   kd,
                        }
 
                        // There's either 0, 1 or many inputs, but they should 
be all the same
@@ -208,11 +213,17 @@ func executePipeline(ctx context.Context, wks 
map[string]*worker.W, j *jobservic
                                        col := comps.GetPcollections()[global]
                                        ed := 
collectionPullDecoder(col.GetCoderId(), coders, comps)
                                        wDec, wEnc := 
getWindowValueCoders(comps, col, coders)
+
+                                       var kd func(io.Reader) []byte
+                                       if kcid, ok := 
extractKVCoderID(col.GetCoderId(), coders); ok {
+                                               kd = 
collectionPullDecoder(kcid, coders, comps)
+                                       }
                                        stage.inputInfo = engine.PColInfo{
                                                GlobalID: global,
                                                WDec:     wDec,
                                                WEnc:     wEnc,
                                                EDec:     ed,
+                                               KeyDec:   kd,
                                        }
                                }
                                em.StageAggregates(stage.ID)
@@ -234,6 +245,9 @@ func executePipeline(ctx context.Context, wks 
map[string]*worker.W, j *jobservic
                        outputs := maps.Keys(stage.OutputsToCoders)
                        sort.Strings(outputs)
                        em.AddStage(stage.ID, []string{stage.primaryInput}, 
outputs, stage.sideInputs)
+                       if stage.stateful {
+                               em.StageStateful(stage.ID)
+                       }
                default:
                        err := fmt.Errorf("unknown environment[%v]", 
t.GetEnvironmentId())
                        slog.Error("Execute", err)
@@ -286,6 +300,14 @@ func collectionPullDecoder(coldCId string, coders 
map[string]*pipepb.Coder, comp
        return pullDecoder(coders[cID], coders)
 }
 
+func extractKVCoderID(coldCId string, coders map[string]*pipepb.Coder) 
(string, bool) {
+       c := coders[coldCId]
+       if c.GetSpec().GetUrn() == urns.CoderKV {
+               return c.GetComponentCoderIds()[0], true
+       }
+       return "", false
+}
+
 func getWindowValueCoders(comps *pipepb.Components, col *pipepb.PCollection, 
coders map[string]*pipepb.Coder) (exec.WindowDecoder, exec.WindowEncoder) {
        ws := comps.GetWindowingStrategies()[col.GetWindowingStrategyId()]
        wcID, err := lpUnknownCoders(ws.GetWindowCoderId(), coders, 
comps.GetCoders())
diff --git a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
index fe3da83c67e..29fccaeb238 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/execute_test.go
@@ -37,6 +37,10 @@ import (
        "github.com/apache/beam/sdks/v2/go/test/integration/primitives"
 )
 
+func TestMain(m *testing.M) {
+       ptest.MainWithDefault(m, "testlocal")
+}
+
 func initRunner(t testing.TB) {
        t.Helper()
        if *jobopts.Endpoint == "" {
@@ -585,10 +589,6 @@ func init() {
 // There's a doubling bug since we re-use the same pcollection IDs for the 
source & sink, and
 // don't do any re-writing.
 
-func TestMain(m *testing.M) {
-       ptest.MainWithDefault(m, "testlocal")
-}
-
 func init() {
        // Basic Registration
        // beam.RegisterFunction(identity)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go 
b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go
index 45223c1b2bc..38e7e9454df 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/handlepardo.go
@@ -82,19 +82,26 @@ func (h *pardo) PrepareTransform(tid string, t 
*pipepb.PTransform, comps *pipepb
                !pdo.RequestsFinalization &&
                !pdo.RequiresStableInput &&
                !pdo.RequiresTimeSortedInput &&
-               len(pdo.StateSpecs) == 0 &&
                len(pdo.TimerFamilySpecs) == 0 &&
                pdo.RestrictionCoderId == "" {
                // Which inputs are Side inputs don't change the graph further,
                // so they're not included here. Any nearly any ParDo can have 
them.
 
                // At their simplest, we don't need to do anything special at 
pre-processing time, and simply pass through as normal.
+
+               // StatefulDoFns need to be marked as being roots.
+               var forcedRoots []string
+               if len(pdo.StateSpecs)+len(pdo.TimerFamilySpecs) > 0 {
+                       forcedRoots = append(forcedRoots, tid)
+               }
+
                return prepareResult{
                        SubbedComps: &pipepb.Components{
                                Transforms: map[string]*pipepb.PTransform{
                                        tid: t,
                                },
                        },
+                       ForcedRoots: forcedRoots,
                }
        }
 
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go 
b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
index cd302a70fcc..d6e906bee59 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/job.go
@@ -42,7 +42,8 @@ import (
 )
 
 var supportedRequirements = map[string]struct{}{
-       urns.RequirementSplittableDoFn: {},
+       urns.RequirementSplittableDoFn:     {},
+       urns.RequirementStatefulProcessing: {},
 }
 
 // TODO, move back to main package, and key off of executor handlers?
diff --git a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go 
b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
index 0fd7381e17f..d3727b65086 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/jobservices/management.go
@@ -26,6 +26,7 @@ import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
        "golang.org/x/exp/maps"
        "golang.org/x/exp/slog"
+       "google.golang.org/protobuf/proto"
        "google.golang.org/protobuf/types/known/timestamppb"
 )
 
@@ -110,11 +111,10 @@ func (s *Server) Prepare(ctx context.Context, req 
*jobpb.PrepareJobRequest) (*jo
        // Inspect Transforms for unsupported features.
        bypassedWindowingStrategies := map[string]bool{}
        ts := job.Pipeline.GetComponents().GetTransforms()
-       for _, t := range ts {
+       for tid, t := range ts {
                urn := t.GetSpec().GetUrn()
                switch urn {
                case urns.TransformImpulse,
-                       urns.TransformParDo,
                        urns.TransformGBK,
                        urns.TransformFlatten,
                        urns.TransformCombinePerKey,
@@ -140,6 +140,22 @@ func (s *Server) Prepare(ctx context.Context, req 
*jobpb.PrepareJobRequest) (*jo
                                wsID := pcs[col].GetWindowingStrategyId()
                                bypassedWindowingStrategies[wsID] = true
                        }
+
+               case urns.TransformParDo:
+                       var pardo pipepb.ParDoPayload
+                       if err := proto.Unmarshal(t.GetSpec().GetPayload(), 
&pardo); err != nil {
+                               return nil, fmt.Errorf("unable to unmarshal 
ParDoPayload for %v - %q: %w", tid, t.GetUniqueName(), err)
+                       }
+
+                       // Validate all the state features
+                       for _, spec := range pardo.GetStateSpecs() {
+                               check("StateSpec.Protocol.Urn", 
spec.GetProtocol().GetUrn(), urns.UserStateBag, urns.UserStateMultiMap)
+                       }
+                       // Validate all the timer features
+                       for _, spec := range pardo.GetTimerFamilySpecs() {
+                               check("TimerFamilySpecs.TimeDomain.Urn", 
spec.GetTimeDomain())
+                       }
+
                case "":
                        // Composites can often have no spec
                        if len(t.GetSubtransforms()) > 0 {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go 
b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
index 494baa5b4a9..ea4cf2c9969 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/preprocess.go
@@ -26,6 +26,7 @@ import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
        "golang.org/x/exp/maps"
        "golang.org/x/exp/slog"
+       "google.golang.org/protobuf/proto"
 )
 
 // transformPreparer is an interface for handling different urns in the 
preprocessor
@@ -440,7 +441,18 @@ func finalizeStage(stg *stage, comps *pipepb.Components, 
pipelineFacts *fusionFa
                inputs[pid] = true
                for _, link := range plinks {
                        t := comps.GetTransforms()[link.Transform]
-                       sis, _ := getSideInputs(t)
+
+                       var sis map[string]*pipepb.SideInput
+                       if t.GetSpec().GetUrn() == urns.TransformParDo {
+                               pardo := &pipepb.ParDoPayload{}
+                               if err := 
(proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != 
nil {
+                                       return fmt.Errorf("unable to decode 
ParDoPayload for %v", link.Transform)
+                               }
+                               if 
len(pardo.GetTimerFamilySpecs())+len(pardo.GetStateSpecs()) > 0 {
+                                       stg.stateful = true
+                               }
+                               sis = pardo.GetSideInputs()
+                       }
                        if _, ok := sis[link.Local]; ok {
                                sideInputs = append(sideInputs, 
engine.LinkID{Transform: link.Transform, Global: link.Global, Local: 
link.Local})
                        } else {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/stage.go 
b/sdks/go/pkg/beam/runners/prism/internal/stage.go
index 1ce10240638..b415f5c241d 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/stage.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/stage.go
@@ -19,6 +19,7 @@ import (
        "bytes"
        "context"
        "fmt"
+       "io"
        "time"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
@@ -61,6 +62,7 @@ type stage struct {
        sideInputs   []engine.LinkID // Non-parallel input PCollections and 
their consumers
        internalCols []string        // PCollections that escape. Used for 
precise coder sending.
        envID        string
+       stateful     bool
 
        exe              transformExecuter
        inputTransformID string
@@ -77,6 +79,7 @@ func (s *stage) Execute(ctx context.Context, j 
*jobservices.Job, wk *worker.W, c
 
        var b *worker.B
        inputData := em.InputForBundle(rb, s.inputInfo)
+       initialState := em.StateForBundle(rb)
        var dataReady <-chan struct{}
        switch s.envID {
        case "": // Runner Transforms
@@ -102,8 +105,8 @@ func (s *stage) Execute(ctx context.Context, j 
*jobservices.Job, wk *worker.W, c
 
                        InputTransformID: s.inputTransformID,
 
-                       // TODO Here's where we can split data for processing 
in multiple bundles.
-                       InputData: inputData,
+                       InputData:  inputData,
+                       OutputData: initialState,
 
                        SinkToPCollection: s.SinkToPCollection,
                        OutputCount:       len(s.outputs),
@@ -300,6 +303,12 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, 
wk *worker.W, em *eng
                }
                sinkID := o.Transform + "_" + o.Local
                ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
+
+               var kd func(io.Reader) []byte
+               if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
+                       kd = collectionPullDecoder(kcid, coders, comps)
+               }
+
                wDec, wEnc := getWindowValueCoders(comps, col, coders)
                sink2Col[sinkID] = o.Global
                col2Coders[o.Global] = engine.PColInfo{
@@ -307,6 +316,7 @@ func buildDescriptor(stg *stage, comps *pipepb.Components, 
wk *worker.W, em *eng
                        WDec:     wDec,
                        WEnc:     wEnc,
                        EDec:     ed,
+                       KeyDec:   kd,
                }
                transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, 
wk), o.Global)
        }
@@ -350,14 +360,20 @@ func buildDescriptor(stg *stage, comps 
*pipepb.Components, wk *worker.W, em *eng
        if err != nil {
                return fmt.Errorf("buildDescriptor: failed to handle coder on 
stage %v for primary input, pcol %q %v:\n%w\n%v", stg.ID, stg.primaryInput, 
prototext.Format(col), err, stg.transforms)
        }
-
        ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
        wDec, wEnc := getWindowValueCoders(comps, col, coders)
+
+       var kd func(io.Reader) []byte
+       if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
+               kd = collectionPullDecoder(kcid, coders, comps)
+       }
+
        inputInfo := engine.PColInfo{
                GlobalID: stg.primaryInput,
                WDec:     wDec,
                WEnc:     wEnc,
                EDec:     ed,
+               KeyDec:   kd,
        }
 
        stg.inputTransformID = stg.ID + "_source"
diff --git a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
index b8a04a7306b..323773bd4cd 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go
@@ -27,7 +27,7 @@ import (
 
 // This file covers pipelines with features that aren't yet supported by Prism.
 
-func intTestName(fn any) string {
+func initTestName(fn any) string {
        name := reflectx.FunctionName(fn)
        n := strings.LastIndex(name, "/")
        return name[n+1:]
@@ -68,23 +68,11 @@ func TestUnimplemented(t *testing.T) {
                {pipeline: primitives.TriggerOrFinally},
                {pipeline: primitives.TriggerRepeat},
 
-               // State API
-               {pipeline: primitives.BagStateParDo},
-               {pipeline: primitives.BagStateParDoClear},
-               {pipeline: primitives.MapStateParDo},
-               {pipeline: primitives.MapStateParDoClear},
-               {pipeline: primitives.SetStateParDo},
-               {pipeline: primitives.SetStateParDoClear},
-               {pipeline: primitives.CombiningStateParDo},
-               {pipeline: primitives.ValueStateParDo},
-               {pipeline: primitives.ValueStateParDoClear},
-               {pipeline: primitives.ValueStateParDoWindowed},
-
                // TODO: Timers integration tests.
        }
 
        for _, test := range tests {
-               t.Run(intTestName(test.pipeline), func(t *testing.T) {
+               t.Run(initTestName(test.pipeline), func(t *testing.T) {
                        p, s := beam.NewPipelineWithRoot()
                        test.pipeline(s)
                        _, err := executeWithT(context.Background(), t, p)
@@ -113,7 +101,37 @@ func TestImplemented(t *testing.T) {
        }
 
        for _, test := range tests {
-               t.Run(intTestName(test.pipeline), func(t *testing.T) {
+               t.Run(initTestName(test.pipeline), func(t *testing.T) {
+                       p, s := beam.NewPipelineWithRoot()
+                       test.pipeline(s)
+                       _, err := executeWithT(context.Background(), t, p)
+                       if err != nil {
+                               t.Fatalf("pipeline failed, but feature should 
be implemented in Prism: %v", err)
+                       }
+               })
+       }
+}
+
+func TestStateAPI(t *testing.T) {
+       initRunner(t)
+
+       tests := []struct {
+               pipeline func(s beam.Scope)
+       }{
+               {pipeline: primitives.BagStateParDo},
+               {pipeline: primitives.BagStateParDoClear},
+               {pipeline: primitives.CombiningStateParDo},
+               {pipeline: primitives.ValueStateParDo},
+               {pipeline: primitives.ValueStateParDoClear},
+               {pipeline: primitives.ValueStateParDoWindowed},
+               {pipeline: primitives.MapStateParDo},
+               {pipeline: primitives.MapStateParDoClear},
+               {pipeline: primitives.SetStateParDo},
+               {pipeline: primitives.SetStateParDoClear},
+       }
+
+       for _, test := range tests {
+               t.Run(initTestName(test.pipeline), func(t *testing.T) {
                        p, s := beam.NewPipelineWithRoot()
                        test.pipeline(s)
                        _, err := executeWithT(context.Background(), t, p)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go 
b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
index bf1e3665666..5312fd799c8 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns.go
@@ -51,6 +51,7 @@ var (
        reqUrn     = toUrn[pipepb.StandardRequirements_Enum]()
        runProcUrn = toUrn[pipepb.StandardRunnerProtocols_Enum]()
        envUrn     = toUrn[pipepb.StandardEnvironments_Environments]()
+       usUrn      = toUrn[pipepb.StandardUserStateTypes_Enum]()
 )
 
 var (
@@ -93,6 +94,10 @@ var (
        SideInputIterable = siUrn(pipepb.StandardSideInputTypes_ITERABLE)
        SideInputMultiMap = siUrn(pipepb.StandardSideInputTypes_MULTIMAP)
 
+       // UserState kinds
+       UserStateBag      = usUrn(pipepb.StandardUserStateTypes_BAG)
+       UserStateMultiMap = usUrn(pipepb.StandardUserStateTypes_MULTIMAP)
+
        // WindowsFns
        WindowFnGlobal  = quickUrn(pipepb.GlobalWindowsPayload_PROPERTIES)
        WindowFnFixed   = quickUrn(pipepb.FixedWindowsPayload_PROPERTIES)
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go 
b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
index 97250092940..6ef3a81e623 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/bundle.go
@@ -42,11 +42,13 @@ type B struct {
        InputTransformID string
        InputData        [][]byte // Data specifically for this bundle.
 
-       // IterableSideInputData is a map from transformID, to inputID, to 
window, to data.
+       // IterableSideInputData is a map from transformID + inputID, to 
window, to data.
        IterableSideInputData map[SideInputKey]map[typex.Window][][]byte
-       // MultiMapSideInputData is a map from transformID, to inputID, to 
window, to data key, to data values.
+       // MultiMapSideInputData is a map from transformID + inputID, to 
window, to data key, to data values.
        MultiMapSideInputData 
map[SideInputKey]map[typex.Window]map[string][][]byte
 
+       // State lives in OutputData
+
        // OutputCount is the number of data or timer outputs this bundle has.
        // We need to see this many closed data channels before the bundle is 
complete.
        OutputCount int
diff --git a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go 
b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
index beee5e896ff..2859dfe2356 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/worker/worker.go
@@ -36,6 +36,7 @@ import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
        fnpb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/fnexecution_v1"
        pipepb "github.com/apache/beam/sdks/v2/go/pkg/beam/model/pipeline_v1"
+       
"github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/engine"
        "github.com/apache/beam/sdks/v2/go/pkg/beam/runners/prism/internal/urns"
        "golang.org/x/exp/slog"
        "google.golang.org/grpc"
@@ -412,21 +413,21 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
                                        panic(err)
                                }
                        }
+
+                       // State requests are always for an active 
ProcessBundle instruction
+                       wk.mu.Lock()
+                       b, ok := 
wk.activeInstructions[req.GetInstructionId()].(*B)
+                       wk.mu.Unlock()
+                       if !ok {
+                               slog.Warn("state request after bundle 
inactive", "instruction", req.GetInstructionId(), "worker", wk)
+                               continue
+                       }
                        switch req.GetRequest().(type) {
                        case *fnpb.StateRequest_Get:
                                // TODO: move data handling to be pcollection 
based.
 
-                               // State requests are always for an active 
ProcessBundle instruction
-                               wk.mu.Lock()
-                               b, ok := 
wk.activeInstructions[req.GetInstructionId()].(*B)
-                               wk.mu.Unlock()
-                               if !ok {
-                                       slog.Warn("state request after bundle 
inactive", "instruction", req.GetInstructionId(), "worker", wk)
-                                       continue
-                               }
                                key := req.GetStateKey()
                                slog.Debug("StateRequest_Get", 
prototext.Format(req), "bundle", b)
-
                                var data [][]byte
                                switch key.GetType().(type) {
                                case *fnpb.StateKey_IterableSideInput_:
@@ -442,11 +443,13 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
                                                }
                                        }
                                        winMap := 
b.IterableSideInputData[SideInputKey{TransformID: ikey.GetTransformId(), Local: 
ikey.GetSideInputId()}]
+
                                        var wins []typex.Window
                                        for w := range winMap {
                                                wins = append(wins, w)
                                        }
                                        slog.Debug(fmt.Sprintf("side 
input[%v][%v] I Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, 
wins))
+
                                        data = winMap[w]
 
                                case *fnpb.StateKey_MultimapSideInput_:
@@ -458,37 +461,81 @@ func (wk *W) State(state fnpb.BeamFnState_StateServer) 
error {
                                        } else {
                                                w, err = 
exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
                                                if err != nil {
-                                                       
panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, 
err))
+                                                       
panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, 
err))
                                                }
                                        }
                                        dKey := mmkey.GetKey()
                                        winMap := 
b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), 
Local: mmkey.GetSideInputId()}]
-                                       var wins []typex.Window
-                                       for w := range winMap {
-                                               wins = append(wins, w)
-                                       }
-                                       slog.Debug(fmt.Sprintf("side 
input[%v][%v] MM Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, 
wins))
+
+                                       slog.Debug(fmt.Sprintf("side 
input[%v][%v] MultiMap Window: %v", req.GetId(), req.GetInstructionId(), w))
 
                                        data = winMap[w][string(dKey)]
 
+                               case *fnpb.StateKey_BagUserState_:
+                                       bagkey := key.GetBagUserState()
+                                       data = 
b.OutputData.GetBagState(engine.LinkID{Transform: bagkey.GetTransformId(), 
Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey())
+                               case *fnpb.StateKey_MultimapUserState_:
+                                       mmkey := key.GetMultimapUserState()
+                                       data = 
b.OutputData.GetMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), 
Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), 
mmkey.GetMapKey())
+                               case *fnpb.StateKey_MultimapKeysUserState_:
+                                       mmkey := key.GetMultimapKeysUserState()
+                                       data = 
b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey())
                                default:
-                                       panic(fmt.Sprintf("unsupported StateKey 
Access type: %T: %v", key.GetType(), prototext.Format(key)))
+                                       panic(fmt.Sprintf("unsupported StateKey 
Get type: %T: %v", key.GetType(), prototext.Format(key)))
                                }
 
                                // Encode the runner iterable (no length, just 
consecutive elements), and send it out.
                                // This is also where we can handle things like 
State Backed Iterables.
-                               var buf bytes.Buffer
-                               for _, value := range data {
-                                       buf.Write(value)
-                               }
                                responses <- &fnpb.StateResponse{
                                        Id: req.GetId(),
                                        Response: &fnpb.StateResponse_Get{
                                                Get: &fnpb.StateGetResponse{
-                                                       Data: buf.Bytes(),
+                                                       Data: bytes.Join(data, 
[]byte{}),
                                                },
                                        },
                                }
+
+                       case *fnpb.StateRequest_Append:
+                               key := req.GetStateKey()
+                               switch key.GetType().(type) {
+                               case *fnpb.StateKey_BagUserState_:
+                                       bagkey := key.GetBagUserState()
+                                       
b.OutputData.AppendBagState(engine.LinkID{Transform: bagkey.GetTransformId(), 
Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey(), 
req.GetAppend().GetData())
+                               case *fnpb.StateKey_MultimapUserState_:
+                                       mmkey := key.GetMultimapUserState()
+                                       
b.OutputData.AppendMultimapState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData())
+                               default:
+                                       panic(fmt.Sprintf("unsupported StateKey 
Append type: %T: %v", key.GetType(), prototext.Format(key)))
+                               }
+                               responses <- &fnpb.StateResponse{
+                                       Id: req.GetId(),
+                                       Response: &fnpb.StateResponse_Append{
+                                               Append: 
&fnpb.StateAppendResponse{},
+                                       },
+                               }
+
+                       case *fnpb.StateRequest_Clear:
+                               key := req.GetStateKey()
+                               switch key.GetType().(type) {
+                               case *fnpb.StateKey_BagUserState_:
+                                       bagkey := key.GetBagUserState()
+                                       
b.OutputData.ClearBagState(engine.LinkID{Transform: bagkey.GetTransformId(), 
Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey())
+                               case *fnpb.StateKey_MultimapUserState_:
+                                       mmkey := key.GetMultimapUserState()
+                                       
b.OutputData.ClearMultimapState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey(), mmkey.GetMapKey())
+                               case *fnpb.StateKey_MultimapKeysUserState_:
+                                       mmkey := key.GetMultimapKeysUserState() // accessor must match this case's StateKey type (as in the Get branch); GetMultimapUserState() is nil here
+                                       
b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: 
mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), 
mmkey.GetKey())
+                               default:
+                                       panic(fmt.Sprintf("unsupported StateKey 
Clear type: %T: %v", key.GetType(), prototext.Format(key)))
+                               }
+                               responses <- &fnpb.StateResponse{
+                                       Id: req.GetId(),
+                                       Response: &fnpb.StateResponse_Clear{
+                                               Clear: 
&fnpb.StateClearResponse{},
+                                       },
+                               }
+
                        default:
                                panic(fmt.Sprintf("unsupported StateRequest 
kind %T: %v", req.GetRequest(), prototext.Format(req)))
                        }
diff --git a/sdks/go/test/integration/primitives/state.go 
b/sdks/go/test/integration/primitives/state.go
index 5f105597ba3..acf1bf8fa66 100644
--- a/sdks/go/test/integration/primitives/state.go
+++ b/sdks/go/test/integration/primitives/state.go
@@ -39,6 +39,7 @@ func init() {
        register.DoFn3x1[state.Provider, string, int, 
string](&mapStateClearFn{})
        register.DoFn3x1[state.Provider, string, int, string](&setStateFn{})
        register.DoFn3x1[state.Provider, string, int, 
string](&setStateClearFn{})
+       register.Function2x0(pairWithOne)
        register.Emitter2[string, int]()
        register.Combiner1[int](&combine1{})
        register.Combiner2[string, int](&combine2{})
@@ -78,12 +79,14 @@ func (f *valueStateFn) ProcessElement(s state.Provider, w 
string, c int) string
        return fmt.Sprintf("%s: %v, %s", w, i, j)
 }
 
+func pairWithOne(w string, emit func(string, int)) {
+       emit(w, 1)
+}
+
 // ValueStateParDo tests a DoFn that uses value state.
 func ValueStateParDo(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &valueStateFn{}, keyed)
        passert.Equals(s, counts, "apple: 1, I", "pear: 1, I", "peach: 1, I", 
"apple: 2, II", "apple: 3, III", "pear: 2, II")
 }
@@ -124,9 +127,7 @@ func (f *valueStateClearFn) ProcessElement(s 
state.Provider, w string, c int) st
 // ValueStateParDoClear tests that a DoFn that uses value state can be cleared.
 func ValueStateParDoClear(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", 
"pear", "pear", "apple")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &valueStateClearFn{State1: 
state.MakeValueState[int]("key1")}, keyed)
        passert.Equals(s, counts, "apple: 0,false", "pear: 0,false", "peach: 
0,false", "apple: 1,true", "apple: 0,false", "pear: 1,true", "pear: 0,false", 
"apple: 1,true")
 }
@@ -170,9 +171,7 @@ func (f *bagStateFn) ProcessElement(s state.Provider, w 
string, c int) string {
 // BagStateParDo tests a DoFn that uses bag state.
 func BagStateParDo(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &bagStateFn{}, keyed)
        passert.Equals(s, counts, "apple: 0, ", "pear: 0, ", "peach: 0, ", 
"apple: 1, I", "apple: 2, I,I", "pear: 1, I")
 }
@@ -207,9 +206,7 @@ func (f *bagStateClearFn) ProcessElement(s state.Provider, 
w string, c int) stri
 // BagStateParDoClear tests a DoFn that uses bag state.
 func BagStateParDoClear(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "apple", "apple", "pear", 
"apple", "apple", "pear", "pear", "pear", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &bagStateClearFn{State1: 
state.MakeBagState[int]("key1")}, keyed)
        passert.Equals(s, counts, "apple: 0", "pear: 0", "apple: 1", "apple: 
2", "pear: 1", "apple: 3", "apple: 0", "pear: 2", "pear: 3", "pear: 0", "apple: 
1", "pear: 1")
 }
@@ -312,16 +309,20 @@ func (f *combiningStateFn) ProcessElement(s 
state.Provider, w string, c int) str
        return fmt.Sprintf("%s: %v %v %v %v %v", w, i, i1, i2, i3, i4)
 }
 
+func init() {
+       register.Function2x1(sumInt)
+}
+
+func sumInt(a, b int) int {
+       return a + b
+}
+
 // CombiningStateParDo tests a DoFn that uses value state.
 func CombiningStateParDo(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &combiningStateFn{
-               State0: state.MakeCombiningState[int, int, int]("key0", func(a, 
b int) int {
-                       return a + b
-               }),
+               State0: state.MakeCombiningState[int, int, int]("key0", sumInt),
                State1: state.Combining[int, int, 
int](state.MakeCombiningState[int, int, int]("key1", &combine1{})),
                State2: state.Combining[string, string, 
int](state.MakeCombiningState[string, string, int]("key2", &combine2{})),
                State3: state.Combining[string, string, 
int](state.MakeCombiningState[string, string, int]("key3", &combine3{})),
@@ -369,9 +370,7 @@ func (f *mapStateFn) ProcessElement(s state.Provider, w 
string, c int) string {
 // MapStateParDo tests a DoFn that uses value state.
 func MapStateParDo(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &mapStateFn{State1: state.MakeMapState[string, 
int]("key1")}, keyed)
        passert.Equals(s, counts, "apple: 1, keys: [apple apple1]", "pear: 1, 
keys: [pear pear1]", "peach: 1, keys: [peach peach1]", "apple: 2, keys: [apple 
apple1 apple2]", "apple: 3, keys: [apple apple1 apple2 apple3]", "pear: 2, 
keys: [pear pear1 pear2]")
 }
@@ -425,9 +424,7 @@ func (f *mapStateClearFn) ProcessElement(s state.Provider, 
w string, c int) stri
 // MapStateParDoClear tests clearing and removing from a DoFn that uses map 
state.
 func MapStateParDoClear(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &mapStateClearFn{State1: 
state.MakeMapState[string, int]("key1")}, keyed)
        passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: 
[peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 
pear3]")
 }
@@ -465,9 +462,7 @@ func (f *setStateFn) ProcessElement(s state.Provider, w 
string, c int) string {
 // SetStateParDo tests a DoFn that uses set state.
 func SetStateParDo(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &setStateFn{State1: 
state.MakeSetState[string]("key1")}, keyed)
        passert.Equals(s, counts, "apple: false, keys: [apple]", "pear: false, 
keys: [pear]", "peach: false, keys: [peach]", "apple: true, keys: [apple 
apple1]", "apple: true, keys: [apple apple1]", "pear: true, keys: [pear pear1]")
 }
@@ -521,9 +516,7 @@ func (f *setStateClearFn) ProcessElement(s state.Provider, 
w string, c int) stri
 // SetStateParDoClear tests clearing and removing from a DoFn that uses set 
state.
 func SetStateParDoClear(s beam.Scope) {
        in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
-       keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
-               emit(w, 1)
-       }, in)
+       keyed := beam.ParDo(s, pairWithOne, in)
        counts := beam.ParDo(s, &setStateClearFn{State1: 
state.MakeSetState[string]("key1")}, keyed)
        passert.Equals(s, counts, "apple: [apple]", "pear: [pear]", "peach: 
[peach]", "apple: [apple1 apple2 apple3]", "apple: []", "pear: [pear1 pear2 
pear3]")
 }

Reply via email to