This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 29ea6e0eb8a [Go SDK]: Allow SDF methods to have context param and error return value (#25437)
29ea6e0eb8a is described below

commit 29ea6e0eb8a2f2cf571d1a27799d125ca051008b
Author: Johanna Öjeling <51084516+johannaojel...@users.noreply.github.com>
AuthorDate: Thu Feb 16 21:02:08 2023 +0100

    [Go SDK]: Allow SDF methods to have context param and error return value (#25437)
    
    * Allow context param and error return value in SDF validation
    
    * Use context param and error return value in SDF method invocation
    
    * Run go fmt
    
    * Clean up error messages from "fn reflect.methodValueCall"
    
    * Validate return value count in a more correct way
---
 sdks/go/pkg/beam/core/graph/fn.go                  | 108 ++++---
 sdks/go/pkg/beam/core/graph/fn_test.go             |  87 ++++++
 sdks/go/pkg/beam/core/runtime/exec/datasource.go   |   8 +-
 .../pkg/beam/core/runtime/exec/datasource_test.go  |  12 +-
 sdks/go/pkg/beam/core/runtime/exec/sdf.go          | 134 +++++---
 sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go | 339 +++++++++------------
 .../beam/core/runtime/exec/sdf_invokers_arity.go   | 336 ++++++++++++++++++++
 .../beam/core/runtime/exec/sdf_invokers_arity.tmpl | 246 +++++++++++++++
 .../beam/core/runtime/exec/sdf_invokers_test.go    | 250 +++++++++++++--
 sdks/go/pkg/beam/core/runtime/exec/sdf_test.go     |  17 +-
 .../beam/runners/prism/internal/config/config.go   |   2 +-
 .../beam/runners/prism/internal/urns/urns_test.go  |   2 +-
 12 files changed, 1227 insertions(+), 314 deletions(-)

diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go
index 907af4045b5..25b846370fb 100644
--- a/sdks/go/pkg/beam/core/graph/fn.go
+++ b/sdks/go/pkg/beam/core/graph/fn.go
@@ -866,14 +866,15 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) 
error {
        // CreateInitialRestriction.
        if numMainIn == MainUnknown {
                initialRestFn := fn.methods[createInitialRestrictionName]
-               paramNum := len(initialRestFn.Param)
+               paramNum := len(initialRestFn.Params(funcx.FnValue))
+
                switch paramNum {
                case int(MainSingle), int(MainKv):
                        num = paramNum
                default: // Can't infer because method has invalid # of main 
inputs.
-                       err := errors.Errorf("invalid number of params in 
method %v. got: %v, want: %v or %v",
+                       err := errors.Errorf("invalid number of main input 
params in method %v. got: %v, want: %v or %v",
                                createInitialRestrictionName, paramNum, 
int(MainSingle), int(MainKv))
-                       return errors.SetTopLevelMsgf(err, "Invalid number of 
parameters in method %v. "+
+                       return errors.SetTopLevelMsgf(err, "Invalid number of 
main input parameters in method %v. "+
                                "Got: %v, Want: %v or %v. Check that the 
signature conforms to the expected signature for %v, "+
                                "and that elements in SDF method parameters 
match elements in %v.",
                                createInitialRestrictionName, paramNum, 
int(MainSingle), int(MainKv), createInitialRestrictionName, processElementName)
@@ -894,7 +895,7 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) 
error {
 // in each SDF method in the given Fn, and returns an error if a method has an
 // invalid/unexpected number.
 func validateSdfSigNumbers(fn *Fn, num int) error {
-       paramNums := map[string]int{
+       reqParamNums := map[string]int{
                createInitialRestrictionName: num,
                splitRestrictionName:         num + 1,
                restrictionSizeName:          num + 1,
@@ -904,32 +905,52 @@ func validateSdfSigNumbers(fn *Fn, num int) error {
        optionalSdfs := map[string]bool{
                truncateRestrictionName: true,
        }
-       returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF 
methods.
+       reqReturnNum := 1
 
        for _, name := range sdfNames {
                method, ok := fn.methods[name]
                if !ok && optionalSdfs[name] {
                        continue
                }
-               if len(method.Param) != paramNums[name] {
-                       err := errors.Errorf("unexpected number of params in 
method %v. got: %v, want: %v",
-                               name, len(method.Param), paramNums[name])
+
+               reqParamNum := reqParamNums[name]
+               if !sdfHasValidParamNum(method.Param, reqParamNum) {
+                       err := errors.Errorf("unexpected number of params in 
method %v. got: %v, want: %v or optionally %v "+
+                               "if first param is of type context.Context", 
name, len(method.Param), reqParamNum, reqParamNum+1)
                        return errors.SetTopLevelMsgf(err, "Unexpected number 
of parameters in method %v. "+
-                               "Got: %v, Want: %v. Check that the signature 
conforms to the expected signature for %v, "+
-                               "and that elements in SDF method parameters 
match elements in %v.",
-                               name, len(method.Param), paramNums[name], name, 
processElementName)
+                               "Got: %v, Want: %v or optionally %v if first 
param is of type context.Context. "+
+                               "Check that the signature conforms to the 
expected signature for %v, and that elements in SDF method "+
+                               "parameters match elements in %v.", name, 
len(method.Param), reqParamNum, reqParamNum+1,
+                               name, processElementName)
                }
-               if len(method.Ret) != returnNum {
-                       err := errors.Errorf("unexpected number of returns in 
method %v. got: %v, want: %v",
-                               name, len(method.Ret), returnNum)
+               if !sdfHasValidReturnNum(method.Ret, reqReturnNum) {
+                       err := errors.Errorf("unexpected number of returns in 
method %v. got: %v, want: %v or optionally %v "+
+                               "if last value is of type error", name, 
len(method.Ret), reqReturnNum, reqReturnNum+1)
                        return errors.SetTopLevelMsgf(err, "Unexpected number 
of return values in method %v. "+
-                               "Got: %v, Want: %v. Check that the signature 
conforms to the expected signature for %v.",
-                               name, len(method.Ret), returnNum, name)
+                               "Got: %v, Want: %v or optionally %v if last 
value is of type error. "+
+                               "Check that the signature conforms to the 
expected signature for %v.",
+                               name, len(method.Ret), reqReturnNum, 
reqReturnNum+1, name)
                }
        }
        return nil
 }
 
+func sdfHasValidParamNum(params []funcx.FnParam, requiredNum int) bool {
+       if len(params) == requiredNum {
+               return true
+       }
+
+       return len(params) == requiredNum+1 && params[0].Kind == funcx.FnContext
+}
+
+func sdfHasValidReturnNum(returns []funcx.ReturnParam, requiredNum int) bool {
+       if len(returns) == requiredNum {
+               return true
+       }
+
+       return len(returns) == requiredNum+1 && returns[len(returns)-1].Kind == 
funcx.RetError
+}
+
 // validateSdfSigTypes validates the types of the parameters and return values
 // in each SDF method in the given Fn, and returns an error if a method has an
 // invalid/mismatched type. Assumes that the number of parameters and return
@@ -940,22 +961,25 @@ func validateSdfSigTypes(fn *Fn, num int) error {
 
        for _, name := range requiredSdfNames {
                method := fn.methods[name]
+               startIdx := sdfRequiredParamStartIndex(method)
+
                switch name {
                case createInitialRestrictionName:
-                       if err := validateSdfElementT(fn, 
createInitialRestrictionName, method, num, 0); err != nil {
+                       if err := validateSdfElementT(fn, 
createInitialRestrictionName, method, num, startIdx); err != nil {
                                return err
                        }
                case splitRestrictionName:
-                       if err := validateSdfElementT(fn, splitRestrictionName, 
method, num, 0); err != nil {
+                       if err := validateSdfElementT(fn, splitRestrictionName, 
method, num, startIdx); err != nil {
                                return err
                        }
-                       if method.Param[num].T != restrictionT {
+                       idx := num + startIdx
+                       if method.Param[idx].T != restrictionT {
                                err := errors.Errorf("mismatched restriction 
type in method %v, param %v. got: %v, want: %v",
-                                       splitRestrictionName, num, 
method.Param[num].T, restrictionT)
+                                       splitRestrictionName, idx, 
method.Param[idx].T, restrictionT)
                                return errors.SetTopLevelMsgf(err, "Mismatched 
restriction type in method %v, "+
                                        "parameter at index %v. Got: %v, Want: 
%v (from method %v). "+
                                        "Ensure that all restrictions in an SDF 
are the same type.",
-                                       splitRestrictionName, num, 
method.Param[num].T, restrictionT, createInitialRestrictionName)
+                                       splitRestrictionName, idx, 
method.Param[idx].T, restrictionT, createInitialRestrictionName)
                        }
                        if method.Ret[0].T.Kind() != reflect.Slice ||
                                method.Ret[0].T.Elem() != restrictionT {
@@ -967,16 +991,17 @@ func validateSdfSigTypes(fn *Fn, num int) error {
                                        splitRestrictionName, 0, 
method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, 
splitRestrictionName)
                        }
                case restrictionSizeName:
-                       if err := validateSdfElementT(fn, restrictionSizeName, 
method, num, 0); err != nil {
+                       if err := validateSdfElementT(fn, restrictionSizeName, 
method, num, startIdx); err != nil {
                                return err
                        }
-                       if method.Param[num].T != restrictionT {
+                       idx := num + startIdx
+                       if method.Param[idx].T != restrictionT {
                                err := errors.Errorf("mismatched restriction 
type in method %v, param %v. got: %v, want: %v",
-                                       restrictionSizeName, num, 
method.Param[num].T, restrictionT)
+                                       restrictionSizeName, idx, 
method.Param[idx].T, restrictionT)
                                return errors.SetTopLevelMsgf(err, "Mismatched 
restriction type in method %v, "+
                                        "parameter at index %v. Got: %v, Want: 
%v (from method %v). "+
                                        "Ensure that all restrictions in an SDF 
are the same type.",
-                                       restrictionSizeName, num, 
method.Param[num].T, restrictionT, createInitialRestrictionName)
+                                       restrictionSizeName, idx, 
method.Param[idx].T, restrictionT, createInitialRestrictionName)
                        }
                        if method.Ret[0].T != reflectx.Float64 {
                                err := errors.Errorf("invalid output type in 
method %v, return %v. got: %v, want: %v",
@@ -986,13 +1011,13 @@ func validateSdfSigTypes(fn *Fn, num int) error {
                                        restrictionSizeName, 0, 
method.Ret[0].T, reflectx.Float64)
                        }
                case createTrackerName:
-                       if method.Param[0].T != restrictionT {
+                       if method.Param[startIdx].T != restrictionT {
                                err := errors.Errorf("mismatched restriction 
type in method %v, param %v. got: %v, want: %v",
-                                       createTrackerName, 0, 
method.Param[0].T, restrictionT)
+                                       createTrackerName, startIdx, 
method.Param[startIdx].T, restrictionT)
                                return errors.SetTopLevelMsgf(err, "Mismatched 
restriction type in method %v, "+
                                        "parameter at index %v. Got: %v, Want: 
%v (from method %v). "+
                                        "Ensure that all restrictions in an SDF 
are the same type.",
-                                       createTrackerName, 0, 
method.Param[0].T, restrictionT, createInitialRestrictionName)
+                                       createTrackerName, startIdx, 
method.Param[startIdx].T, restrictionT, createInitialRestrictionName)
                        }
                        if !method.Ret[0].T.Implements(rTrackerT) {
                                err := errors.Errorf("invalid output type in 
method %v, return %v: %v does not implement sdf.RTracker",
@@ -1020,15 +1045,18 @@ func validateSdfSigTypes(fn *Fn, num int) error {
                if !ok {
                        continue
                }
+
+               startIdx := sdfRequiredParamStartIndex(method)
+
                switch name {
                case truncateRestrictionName:
-                       if method.Param[0].T != rTrackerImplT {
+                       if method.Param[startIdx].T != rTrackerImplT {
                                err := errors.Errorf("mismatched restriction 
tracker type in method %v, param %v. got: %v, want: %v",
-                                       truncateRestrictionName, 0, 
method.Param[0].T, rTrackerImplT)
+                                       truncateRestrictionName, startIdx, 
method.Param[startIdx].T, rTrackerImplT)
                                return errors.SetTopLevelMsgf(err, "Mismatched 
restriction tracker type in method %v, "+
                                        "parameter at index %v. Got: %v, Want: 
%v (from method %v). "+
                                        "Ensure that restriction tracker is the 
first parameter.",
-                                       truncateRestrictionName, 0, 
method.Param[0].T, rTrackerImplT, createTrackerName)
+                                       truncateRestrictionName, startIdx, 
method.Param[startIdx].T, rTrackerImplT, createTrackerName)
                        }
                        if method.Ret[0].T != restrictionT {
                                err := errors.Errorf("invalid output type in 
method %v, return %v. got: %v, want: %v",
@@ -1052,6 +1080,14 @@ func validateSdfSigTypes(fn *Fn, num int) error {
        return nil
 }
 
+func sdfRequiredParamStartIndex(method *funcx.Fn) int {
+       if ctxIndex, ok := method.Context(); ok {
+               return ctxIndex + 1
+       }
+
+       return 0
+}
+
 // validateSdfElementT validates that element types in an SDF method are
 // consistent with the ProcessElement method. This method assumes that the
 // first 'num' parameters starting with startIndex are the elements.
@@ -1062,13 +1098,14 @@ func validateSdfElementT(fn *Fn, name string, method 
*funcx.Fn, num int, startIn
        pos, _, _ := processFn.Inputs()
 
        for i := 0; i < num; i++ {
-               if method.Param[i+startIndex].T != processFn.Param[pos+i].T {
+               idx := i + startIndex
+               if method.Param[idx].T != processFn.Param[pos+i].T {
                        err := errors.Errorf("mismatched element type in method 
%v, param %v. got: %v, want: %v",
-                               name, i, method.Param[i].T, 
processFn.Param[pos+i].T)
+                               name, idx, method.Param[idx].T, 
processFn.Param[pos+i].T)
                        return errors.SetTopLevelMsgf(err, "Mismatched element 
type in method %v, "+
                                "parameter at index %v. Got: %v, Want: %v (from 
method %v). "+
                                "Ensure that element parameters in SDF methods 
have consistent types with element parameters in %v.",
-                               name, i, method.Param[i].T, 
processFn.Param[pos+i].T, processElementName, processElementName)
+                               name, idx, method.Param[idx].T, 
processFn.Param[pos+i].T, processElementName, processElementName)
                }
        }
        return nil
@@ -1178,7 +1215,8 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) 
error {
        // CreateInitialRestriction.
        if numMainIn == int(MainUnknown) {
                initialRestFn := fn.methods[createInitialRestrictionName]
-               paramNum := len(initialRestFn.Param)
+               paramNum := len(initialRestFn.Params(funcx.FnValue))
+
                switch paramNum {
                case int(MainSingle), int(MainKv):
                        numMainIn = paramNum
diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go
index d2f88a8a5ce..cf44761d4f3 100644
--- a/sdks/go/pkg/beam/core/graph/fn_test.go
+++ b/sdks/go/pkg/beam/core/graph/fn_test.go
@@ -190,6 +190,9 @@ func TestNewDoFnSdf(t *testing.T) {
                }{
                        {dfn: &GoodSdf{}, main: MainSingle},
                        {dfn: &GoodSdfKv{}, main: MainKv},
+                       {dfn: &GoodSdfWContext{}, main: MainSingle},
+                       {dfn: &GoodSdfKvWContext{}, main: MainKv},
+                       {dfn: &GoodSdfWErr{}, main: MainSingle},
                        {dfn: &GoodIgnoreOtherExportedMethods{}, main: 
MainSingle},
                }
 
@@ -987,6 +990,90 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, 
int) RestT {
        return RestT{}
 }
 
+type GoodSdfWContext struct {
+       *GoodDoFn
+}
+
+func (fn *GoodSdfWContext) CreateInitialRestriction(context.Context, int) 
RestT {
+       return RestT{}
+}
+
+func (fn *GoodSdfWContext) SplitRestriction(context.Context, int, RestT) 
[]RestT {
+       return []RestT{}
+}
+
+func (fn *GoodSdfWContext) RestrictionSize(context.Context, int, RestT) 
float64 {
+       return 0
+}
+
+func (fn *GoodSdfWContext) CreateTracker(context.Context, RestT) *RTrackerT {
+       return &RTrackerT{}
+}
+
+func (fn *GoodSdfWContext) ProcessElement(context.Context, *RTrackerT, int) 
(int, sdf.ProcessContinuation) {
+       return 0, sdf.StopProcessing()
+}
+
+func (fn *GoodSdfWContext) TruncateRestriction(context.Context, *RTrackerT, 
int) RestT {
+       return RestT{}
+}
+
+type GoodSdfKvWContext struct {
+       *GoodDoFnKv
+}
+
+func (fn *GoodSdfKvWContext) CreateInitialRestriction(context.Context, int, 
int) RestT {
+       return RestT{}
+}
+
+func (fn *GoodSdfKvWContext) SplitRestriction(context.Context, int, int, 
RestT) []RestT {
+       return []RestT{}
+}
+
+func (fn *GoodSdfKvWContext) RestrictionSize(context.Context, int, int, RestT) 
float64 {
+       return 0
+}
+
+func (fn *GoodSdfKvWContext) CreateTracker(context.Context, RestT) *RTrackerT {
+       return &RTrackerT{}
+}
+
+func (fn *GoodSdfKvWContext) ProcessElement(context.Context, *RTrackerT, int, 
int) (int, sdf.ProcessContinuation) {
+       return 0, sdf.StopProcessing()
+}
+
+func (fn *GoodSdfKvWContext) TruncateRestriction(context.Context, *RTrackerT, 
int, int) RestT {
+       return RestT{}
+}
+
+type GoodSdfWErr struct {
+       *GoodDoFn
+}
+
+func (fn *GoodSdfWErr) CreateInitialRestriction(int) (RestT, error) {
+       return RestT{}, nil
+}
+
+func (fn *GoodSdfWErr) SplitRestriction(int, RestT) ([]RestT, error) {
+       return []RestT{}, nil
+}
+
+func (fn *GoodSdfWErr) RestrictionSize(int, RestT) (float64, error) {
+       return 0, nil
+}
+
+func (fn *GoodSdfWErr) CreateTracker(RestT) (*RTrackerT, error) {
+       return &RTrackerT{}, nil
+}
+
+func (fn *GoodSdfWErr) ProcessElement(*RTrackerT, int) (int, 
sdf.ProcessContinuation, error) {
+       return 0, sdf.StopProcessing(), nil
+}
+
+func (fn *GoodSdfWErr) TruncateRestriction(*RTrackerT, int) (RestT, error) {
+       return RestT{}, nil
+}
+
 type GoodIgnoreOtherExportedMethods struct {
        *GoodSdf
 }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
index 9c4de0564c8..a6347fc8d0e 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go
@@ -196,7 +196,7 @@ func (n *DataSource) Process(ctx context.Context) 
([]*Checkpoint, error) {
                // Check if there's a continuation and return residuals
                // Needs to be done immeadiately after processing to not lose 
the element.
                if c := n.getProcessContinuation(); c != nil {
-                       cp, err := n.checkpointThis(c)
+                       cp, err := n.checkpointThis(ctx, c)
                        if err != nil {
                                // Errors during checkpointing should fail a 
bundle.
                                return nil, err
@@ -422,7 +422,7 @@ type Checkpoint struct {
 // splittable or has not returned a resuming continuation, the function 
returns an empty
 // SplitResult, a negative resumption time, and a false boolean to indicate 
that no split
 // occurred.
-func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, 
error) {
+func (n *DataSource) checkpointThis(ctx context.Context, pc 
sdf.ProcessContinuation) (*Checkpoint, error) {
        n.mu.Lock()
        defer n.mu.Unlock()
 
@@ -435,7 +435,7 @@ func (n *DataSource) checkpointThis(pc 
sdf.ProcessContinuation) (*Checkpoint, er
        ow := su.GetOutputWatermark()
 
        // Checkpointing is functionally a split at fraction 0.0
-       rs, err := su.Checkpoint()
+       rs, err := su.Checkpoint(ctx)
        if err != nil {
                return nil, err
        }
@@ -530,7 +530,7 @@ func (n *DataSource) Split(ctx context.Context, splits 
[]int64, frac float64, bu
        // Get the output watermark before splitting to avoid accidentally 
overestimating
        ow := su.GetOutputWatermark()
        // Otherwise, perform a sub-element split.
-       ps, rs, err := su.Split(fr)
+       ps, rs, err := su.Split(ctx, fr)
        if err != nil {
                return SplitResult{}, err
        }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
index 64a37739b24..2da3284f016 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go
@@ -631,7 +631,7 @@ type TestSplittableUnit struct {
 
 // Split checks the input fraction for correctness, but otherwise always 
returns
 // a successful split. The split elements are just copies of the original.
-func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, 
error) {
+func (n *TestSplittableUnit) Split(_ context.Context, f float64) 
([]*FullValue, []*FullValue, error) {
        if f > 1.0 || f < 0.0 {
                return nil, nil, errors.Errorf("Error")
        }
@@ -639,8 +639,8 @@ func (n *TestSplittableUnit) Split(f float64) 
([]*FullValue, []*FullValue, error
 }
 
 // Checkpoint routes through the Split() function to satisfy the interface.
-func (n *TestSplittableUnit) Checkpoint() ([]*FullValue, error) {
-       _, r, err := n.Split(0.0)
+func (n *TestSplittableUnit) Checkpoint(ctx context.Context) ([]*FullValue, 
error) {
+       _, r, err := n.Split(ctx, 0.0)
        return r, err
 }
 
@@ -876,13 +876,13 @@ func TestSplitHelper(t *testing.T) {
 
 func TestCheckpointing(t *testing.T) {
        t.Run("nil", func(t *testing.T) {
-               cps, err := (&DataSource{}).checkpointThis(nil)
+               cps, err := 
(&DataSource{}).checkpointThis(context.Background(), nil)
                if err != nil {
                        t.Fatalf("checkpointThis() = %v, %v", cps, err)
                }
        })
        t.Run("Stop", func(t *testing.T) {
-               cps, err := (&DataSource{}).checkpointThis(sdf.StopProcessing())
+               cps, err := 
(&DataSource{}).checkpointThis(context.Background(), sdf.StopProcessing())
                if err != nil {
                        t.Fatalf("checkpointThis() = %v, %v", cps, err)
                }
@@ -899,7 +899,7 @@ func TestCheckpointing(t *testing.T) {
                                },
                        },
                }
-               cp, err := 
root.checkpointThis(sdf.ResumeProcessingIn(time.Second * 13))
+               cp, err := root.checkpointThis(context.Background(), 
sdf.ResumeProcessingIn(time.Second*13))
                if err != nil {
                        t.Fatalf("checkpointThis() = %v, %v, want nil", cp, err)
                }
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index e22496eae6e..1dd3e35dc4d 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -90,7 +90,11 @@ func (n *PairWithRestriction) StartBundle(ctx 
context.Context, id string, data D
 //       Timestamps
 //     }
 func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm 
*FullValue, values ...ReStream) error {
-       rest := n.inv.Invoke(elm)
+       rest, err := n.inv.Invoke(ctx, elm)
+       if err != nil {
+               return err
+       }
+
        output := FullValue{Elm: elm, Elm2: &FullValue{Elm: rest, Elm2: 
n.iwesInv.Invoke(rest, elm)}, Timestamp: elm.Timestamp, Windows: elm.Windows}
 
        return n.Out.ProcessElement(ctx, &output, values...)
@@ -195,10 +199,17 @@ func (n *SplitAndSizeRestrictions) ProcessElement(ctx 
context.Context, elm *Full
        // the element may not be wrapped in a *FullValue
        mainElm := convertIfNeeded(elm.Elm, &FullValue{})
 
-       splitRests := n.splitInv.Invoke(mainElm, rest)
+       splitRests, err := n.splitInv.Invoke(ctx, mainElm, rest)
+       if err != nil {
+               return err
+       }
 
        for _, splitRest := range splitRests {
-               size := n.sizeInv.Invoke(mainElm, splitRest)
+               size, err := n.sizeInv.Invoke(ctx, mainElm, splitRest)
+               if err != nil {
+                       return err
+               }
+
                if size < 0 {
                        err := errors.Errorf("size returned expected to be 
non-negative but received %v.", size)
                        return errors.WithContextf(err, "%v", n)
@@ -325,13 +336,25 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx 
context.Context, elm *Full
                inp = e
        }
        rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-       rt := n.ctInv.Invoke(rest)
-       newRest := n.truncateInv.Invoke(rt, mainElm)
+
+       rt, err := n.ctInv.Invoke(ctx, rest)
+       if err != nil {
+               return err
+       }
+
+       newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm)
+       if err != nil {
+               return err
+       }
        if newRest == nil {
                // do not propagate discarded restrictions.
                return nil
        }
-       size := n.sizeInv.Invoke(mainElm, newRest)
+
+       size, err := n.sizeInv.Invoke(ctx, mainElm, newRest)
+       if err != nil {
+               return err
+       }
 
        output := &FullValue{}
        output.Timestamp = elm.Timestamp
@@ -476,7 +499,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
StartBundle(ctx context.Context, i
 // and processes each element using the underlying ParDo and adding the
 // restriction tracker to the normal invocation. Sizing information is present
 // but currently ignored. Output is forwarded to the underlying ParDo's 
outputs.
-func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ 
context.Context, elm *FullValue, values ...ReStream) error {
+func (n *ProcessSizedElementsAndRestrictions) ProcessElement(ctx 
context.Context, elm *FullValue, values ...ReStream) error {
        if n.PDo.status != Active {
                err := errors.Errorf("invalid status %v, want Active", 
n.PDo.status)
                return errors.WithContextf(err, "%v", n)
@@ -520,7 +543,12 @@ func (n *ProcessSizedElementsAndRestrictions) 
ProcessElement(_ context.Context,
                // If windows don't need to be exploded (i.e. aren't observed), 
treat
                // all windows as one as an optimization.
                rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-               rt := n.ctInv.Invoke(rest)
+
+               rt, err := n.ctInv.Invoke(ctx, rest)
+               if err != nil {
+                       return err
+               }
+
                mainIn.RTracker = rt
 
                n.numW = 1 // Even if there's more than one window, treat them 
as one.
@@ -542,7 +570,12 @@ func (n *ProcessSizedElementsAndRestrictions) 
ProcessElement(_ context.Context,
 
                for i := 0; i < n.numW; i++ {
                        rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-                       rt := n.ctInv.Invoke(rest)
+
+                       rt, err := n.ctInv.Invoke(ctx, rest)
+                       if err != nil {
+                               return err
+                       }
+
                        key := &mainIn.Key
                        w := elm.Windows[i]
                        wElm := FullValue{Elm: key.Elm, Elm2: key.Elm2, 
Timestamp: key.Timestamp, Windows: []typex.Window{w}}
@@ -552,7 +585,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
ProcessElement(_ context.Context,
                        n.elm = elm
                        n.SU <- n
                        // TODO(BEAM-11104): Remove placeholder for 
ProcessContinuation return.
-                       _, err := n.PDo.processSingleWindow(&MainInput{Key: 
wElm, Values: mainIn.Values, RTracker: rt})
+                       _, err = n.PDo.processSingleWindow(&MainInput{Key: 
wElm, Values: mainIn.Values, RTracker: rt})
                        if err != nil {
                                <-n.SU
                                return n.PDo.fail(err)
@@ -596,13 +629,13 @@ type SplittableUnit interface {
        //
        // More than one primary/residual can happen if the split result cannot 
be
        // fully represented in just one.
-       Split(fraction float64) (primaries, residuals []*FullValue, err error)
+       Split(ctx context.Context, fraction float64) (primaries, residuals 
[]*FullValue, err error)
 
        // Checkpoint performs a split at fraction 0.0 of an element that has 
stopped
        // processing and has work that needs to be resumed later. This 
function will
        // check that the produced primary restriction from the split represents
        // completed work to avoid data loss and will error if work remains.
-       Checkpoint() (residuals []*FullValue, err error)
+       Checkpoint(ctx context.Context) (residuals []*FullValue, err error)
 
        // GetProgress returns the fraction of progress the current element has
        // made in processing. (ex. 0.0 means no progress, and 1.0 means fully
@@ -631,7 +664,7 @@ type SplittableUnit interface {
 // windows need to be taken into account. For implementation details on when
 // each case occurs and the implementation details, see the documentation for
 // the singleWindowSplit and multiWindowSplit methods.
-func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, 
[]*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) Split(ctx context.Context, f 
float64) ([]*FullValue, []*FullValue, error) {
        // Get the watermark state immediately so that we don't overestimate 
our current watermark.
        var pWeState any
        var rWeState any
@@ -658,7 +691,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f 
float64) ([]*FullValue, []
        // Split behavior differs depending on whether this is a 
window-observing
        // DoFn or not.
        if len(n.elm.Windows) > 1 {
-               p, r, err := n.multiWindowSplit(f, pWeState, rWeState)
+               p, r, err := n.multiWindowSplit(ctx, f, pWeState, rWeState)
                if err != nil {
                        return nil, nil, addContext(err)
                }
@@ -666,7 +699,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f 
float64) ([]*FullValue, []
        }
 
        // Not window-observing, or window-observing but only one window.
-       p, r, err := n.singleWindowSplit(f, pWeState, rWeState)
+       p, r, err := n.singleWindowSplit(ctx, f, pWeState, rWeState)
        if err != nil {
                return nil, nil, addContext(err)
        }
@@ -677,11 +710,11 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f 
float64) ([]*FullValue, []
 // later by the runner. This is done iff the underlying Splittable DoFn 
returns a resuming
 // ProcessContinuation. If the split occurs and the primary restriction is 
marked as done
 // my the RTracker, the Checkpoint fails as this is a potential data-loss case.
-func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, 
error) {
+func (n *ProcessSizedElementsAndRestrictions) Checkpoint(ctx context.Context) 
([]*FullValue, error) {
        addContext := func(err error) error {
                return errors.WithContext(err, "Attempting checkpoint in 
ProcessSizedElementsAndRestrictions")
        }
-       _, r, err := n.Split(0.0)
+       _, r, err := n.Split(ctx, 0.0)
 
        if err != nil {
                return nil, addContext(err)
@@ -699,7 +732,7 @@ func (n *ProcessSizedElementsAndRestrictions) Checkpoint() 
([]*FullValue, error)
 // behavior is identical). A single restriction split will occur and all 
windows
 // present in the unsplit element will be present in both the resulting primary
 // and residual.
-func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, 
pWeState, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(ctx 
context.Context, f float64, pWeState, rWeState any) ([]*FullValue, 
[]*FullValue, error) {
        if n.rt.IsDone() { // Not an error, but not splittable.
                return []*FullValue{}, []*FullValue{}, nil
        }
@@ -714,14 +747,14 @@ func (n *ProcessSizedElementsAndRestrictions) 
singleWindowSplit(f float64, pWeSt
 
        var primaryResult []*FullValue
        if p != nil {
-               pfv, err := n.newSplitResult(p, n.elm.Windows, pWeState)
+               pfv, err := n.newSplitResult(ctx, p, n.elm.Windows, pWeState)
                if err != nil {
                        return nil, nil, err
                }
                primaryResult = append(primaryResult, pfv)
        }
 
-       rfv, err := n.newSplitResult(r, n.elm.Windows, rWeState)
+       rfv, err := n.newSplitResult(ctx, r, n.elm.Windows, rWeState)
        if err != nil {
                return nil, nil, err
        }
@@ -752,7 +785,7 @@ func (n *ProcessSizedElementsAndRestrictions) 
singleWindowSplit(f float64, pWeSt
 //
 // This method also updates the current number of windows (n.numW) so that
 // windows in the residual will no longer be processed.
-func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, 
pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(ctx 
context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, 
[]*FullValue, error) {
        // Get the split point in window range, to see what window it falls in.
        done, rem := n.rt.GetProgress()
        cwp := done / (done + rem)                      // Progress in current 
window.
@@ -765,25 +798,25 @@ func (n *ProcessSizedElementsAndRestrictions) 
multiWindowSplit(f float64, pWeSta
                if n.rt.IsDone() {
                        // Current RTracker is done so we can't split within 
the window, so
                        // split at window boundary instead.
-                       return n.windowBoundarySplit(n.currW+1, pWeState, 
rWeState)
+                       return n.windowBoundarySplit(ctx, n.currW+1, pWeState, 
rWeState)
                }
 
                // Get the fraction of remaining work in the current window to 
split at.
                cwsp := wsp - float64(n.currW) // Split point in current window.
                rf := (cwsp - cwp) / (1 - cwp) // Fraction of work in RTracker 
to split at.
 
-               return n.currentWindowSplit(rf, pWeState, rWeState)
+               return n.currentWindowSplit(ctx, rf, pWeState, rWeState)
        } else {
                // Split at nearest window boundary to split point.
                wb := math.Round(wsp)
-               return n.windowBoundarySplit(int(wb), pWeState, rWeState)
+               return n.windowBoundarySplit(ctx, int(wb), pWeState, rWeState)
        }
 }
 
 // currentWindowSplit performs an appropriate split at the given fraction of
 // remaining work in the current window. Also updates numW to stop after the
 // current window.
-func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, 
pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(ctx 
context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, 
[]*FullValue, error) {
        p, r, err := n.rt.TrySplit(f)
        if err != nil {
                return nil, nil, err
@@ -791,18 +824,18 @@ func (n *ProcessSizedElementsAndRestrictions) 
currentWindowSplit(f float64, pWeS
        if r == nil {
                // If r is nil then the split failed/returned an empty 
residual, but
                // we can still split at a window boundary.
-               return n.windowBoundarySplit(n.currW+1, pWeState, rWeState)
+               return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState)
        }
 
        // Split of currently processing restriction in a single window.
        ps := make([]*FullValue, 1)
-       newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1], 
pWeState)
+       newP, err := n.newSplitResult(ctx, p, n.elm.Windows[n.currW:n.currW+1], 
pWeState)
        if err != nil {
                return nil, nil, err
        }
        ps[0] = newP
        rs := make([]*FullValue, 1)
-       newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1], 
rWeState)
+       newR, err := n.newSplitResult(ctx, r, n.elm.Windows[n.currW:n.currW+1], 
rWeState)
        if err != nil {
                return nil, nil, err
        }
@@ -810,14 +843,14 @@ func (n *ProcessSizedElementsAndRestrictions) 
currentWindowSplit(f float64, pWeS
        // Window boundary split surrounding the split restriction above.
        full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
        if 0 < n.currW {
-               newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW], 
pWeState)
+               newP, err := n.newSplitResult(ctx, full, 
n.elm.Windows[0:n.currW], pWeState)
                if err != nil {
                        return nil, nil, err
                }
                ps = append(ps, newP)
        }
        if n.currW+1 < n.numW {
-               newR, err := n.newSplitResult(full, 
n.elm.Windows[n.currW+1:n.numW], rWeState)
+               newR, err := n.newSplitResult(ctx, full, 
n.elm.Windows[n.currW+1:n.numW], rWeState)
                if err != nil {
                        return nil, nil, err
                }
@@ -830,17 +863,17 @@ func (n *ProcessSizedElementsAndRestrictions) 
currentWindowSplit(f float64, pWeS
 // windowBoundarySplit performs an appropriate split at a window boundary. The
 // split point taken should be the index of the first window in the residual.
 // Also updates numW to stop at the split point.
-func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, 
pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(ctx 
context.Context, splitPt int, pWeState any, rWeState any) ([]*FullValue, 
[]*FullValue, error) {
        // If this is at the boundary of the last window, split is a no-op.
        if splitPt == n.numW {
                return []*FullValue{}, []*FullValue{}, nil
        }
        full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
-       pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt], pWeState)
+       pFv, err := n.newSplitResult(ctx, full, n.elm.Windows[0:splitPt], 
pWeState)
        if err != nil {
                return nil, nil, err
        }
-       rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW], 
rWeState)
+       rFv, err := n.newSplitResult(ctx, full, n.elm.Windows[splitPt:n.numW], 
rWeState)
        if err != nil {
                return nil, nil, err
        }
@@ -852,18 +885,27 @@ func (n *ProcessSizedElementsAndRestrictions) 
windowBoundarySplit(splitPt int, p
 // element restriction pair based on the currently processing element, but with
 // a modified restriction and windows. Intended for creating primaries and
 // residuals to return as split results.
-func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest any, w 
[]typex.Window, weState any) (*FullValue, error) {
+func (n *ProcessSizedElementsAndRestrictions) newSplitResult(ctx 
context.Context, rest any, w []typex.Window, weState any) (*FullValue, error) {
        var size float64
+       var err error
        elm := n.elm.Elm.(*FullValue).Elm
        if fv, ok := elm.(*FullValue); ok {
-               size = n.sizeInv.Invoke(fv, rest)
+               size, err = n.sizeInv.Invoke(ctx, fv, rest)
+               if err != nil {
+                       return nil, err
+               }
+
                if size < 0 {
                        err := errors.Errorf("size returned expected to be 
non-negative but received %v.", size)
                        return nil, errors.WithContextf(err, "%v", n)
                }
        } else {
                fv := &FullValue{Elm: elm}
-               size = n.sizeInv.Invoke(fv, rest)
+               size, err = n.sizeInv.Invoke(ctx, fv, rest)
+               if err != nil {
+                       return nil, err
+               }
+
                if size < 0 {
                        err := errors.Errorf("size returned expected to be 
non-negative but received %v.", size)
                        return nil, errors.WithContextf(err, "%v", n)
@@ -973,21 +1015,33 @@ func (n *SdfFallback) StartBundle(ctx context.Context, 
id string, data DataConte
 // restrictions, and then creating restriction trackers and processing each
 // restriction with the underlying ParDo. This executor skips the sizing step
 // because sizing information is unnecessary for unexpanded SDFs.
-func (n *SdfFallback) ProcessElement(_ context.Context, elm *FullValue, values 
...ReStream) error {
+func (n *SdfFallback) ProcessElement(ctx context.Context, elm *FullValue, 
values ...ReStream) error {
        if n.PDo.status != Active {
                err := errors.Errorf("invalid status %v, want Active", 
n.PDo.status)
                return errors.WithContextf(err, "%v", n)
        }
 
-       rest := n.initRestInv.Invoke(elm)
-       splitRests := n.splitInv.Invoke(elm, rest)
+       rest, err := n.initRestInv.Invoke(ctx, elm)
+       if err != nil {
+               return err
+       }
+
+       splitRests, err := n.splitInv.Invoke(ctx, elm, rest)
+       if err != nil {
+               return err
+       }
+
        if len(splitRests) == 0 {
                err := errors.Errorf("initial splitting returned 0 
restrictions.")
                return errors.WithContextf(err, "%v", n)
        }
 
        for _, splitRest := range splitRests {
-               rt := n.trackerInv.Invoke(splitRest)
+               rt, err := n.trackerInv.Invoke(ctx, splitRest)
+               if err != nil {
+                       return err
+               }
+
                mainIn := &MainInput{
                        Key:      *elm,
                        Values:   values,
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
index 5d58f198a64..2dd894ed08c 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go
@@ -16,6 +16,7 @@
 package exec
 
 import (
+       "context"
        "reflect"
 
        "github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
@@ -24,6 +25,9 @@ import (
        "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
 )
 
+//go:generate specialize --input=sdf_invokers_arity.tmpl
+//go:generate gofmt -w sdf_invokers_arity.go
+
 // This file contains invokers for SDF methods. These invokers are based off
 // exec.invoker which is used for regular DoFns. Since exec.invoker is
 // specialized for DoFns it cannot be used for SDF methods. Instead, these
@@ -39,9 +43,10 @@ import (
 
 // cirInvoker is an invoker for CreateInitialRestriction.
 type cirInvoker struct {
-       fn   *funcx.Fn
-       args []any // Cache to avoid allocating new slices per-element.
-       call func(elms *FullValue) (rest any)
+       fn     *funcx.Fn
+       args   []any // Cache to avoid allocating new slices per-element.
+       ctxIdx int
+       call   func() (rest any, err error)
 }
 
 func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) {
@@ -49,51 +54,34 @@ func newCreateInitialRestrictionInvoker(fn *funcx.Fn) 
(*cirInvoker, error) {
                fn:   fn,
                args: make([]any, len(fn.Param)),
        }
+
+       var ok bool
+       if n.ctxIdx, ok = fn.Context(); !ok {
+               n.ctxIdx = -1
+       }
+
        if err := n.initCallFn(); err != nil {
                return nil, errors.WithContext(err, "sdf 
CreateInitialRestriction invoker")
        }
        return n, nil
 }
 
-func (n *cirInvoker) initCallFn() error {
-       // Expects a signature of the form:
-       // (key?, value) restriction
-       // TODO(BEAM-9643): Link to full documentation.
-       switch fnT := n.fn.Fn.(type) {
-       case reflectx.Func1x1:
-               n.call = func(elms *FullValue) any {
-                       return fnT.Call1x1(elms.Elm)
-               }
-       case reflectx.Func2x1:
-               n.call = func(elms *FullValue) any {
-                       return fnT.Call2x1(elms.Elm, elms.Elm2)
-               }
-       default:
-               switch len(n.fn.Param) {
-               case 1:
-                       n.call = func(elms *FullValue) any {
-                               n.args[0] = elms.Elm
-                               return n.fn.Fn.Call(n.args)[0]
-                       }
-               case 2:
-                       n.call = func(elms *FullValue) any {
-                               n.args[0] = elms.Elm
-                               n.args[1] = elms.Elm2
-                               return n.fn.Fn.Call(n.args)[0]
-                       }
-               default:
-                       return errors.Errorf("CreateInitialRestriction fn %v 
has unexpected number of parameters: %v",
-                               n.fn.Fn.Name(), len(n.fn.Param))
-               }
+// Invoke calls CreateInitialRestriction with the given FullValue as the 
element
+// and returns the resulting restriction.
+func (n *cirInvoker) Invoke(ctx context.Context, elms *FullValue) (rest any, 
err error) {
+       if n.ctxIdx >= 0 {
+               n.args[n.ctxIdx] = ctx
        }
 
-       return nil
-}
+       i := n.ctxIdx + 1
+       n.args[i] = elms.Elm
 
-// Invoke calls CreateInitialRestriction with the given FullValue as the 
element
-// and returns the resulting restriction.
-func (n *cirInvoker) Invoke(elms *FullValue) (rest any) {
-       return n.call(elms)
+       if elms.Elm2 != nil {
+               i++
+               n.args[i] = elms.Elm2
+       }
+
+       return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -106,9 +94,10 @@ func (n *cirInvoker) Reset() {
 
 // srInvoker is an invoker for SplitRestriction.
 type srInvoker struct {
-       fn   *funcx.Fn
-       args []any // Cache to avoid allocating new slices per-element.
-       call func(elms *FullValue, rest any) (splits any)
+       fn     *funcx.Fn
+       args   []any // Cache to avoid allocating new slices per-element.
+       ctxIdx int
+       call   func() (splits any, err error)
 }
 
 func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) {
@@ -116,52 +105,40 @@ func newSplitRestrictionInvoker(fn *funcx.Fn) 
(*srInvoker, error) {
                fn:   fn,
                args: make([]any, len(fn.Param)),
        }
+
+       var ok bool
+       if n.ctxIdx, ok = fn.Context(); !ok {
+               n.ctxIdx = -1
+       }
+
        if err := n.initCallFn(); err != nil {
                return nil, errors.WithContext(err, "sdf SplitRestriction 
invoker")
        }
        return n, nil
 }
 
-func (n *srInvoker) initCallFn() error {
-       // Expects a signature of the form:
-       // (key?, value, restriction) []restriction
-       // TODO(BEAM-9643): Link to full documentation.
-       switch fnT := n.fn.Fn.(type) {
-       case reflectx.Func2x1:
-               n.call = func(elms *FullValue, rest any) any {
-                       return fnT.Call2x1(elms.Elm, rest)
-               }
-       case reflectx.Func3x1:
-               n.call = func(elms *FullValue, rest any) any {
-                       return fnT.Call3x1(elms.Elm, elms.Elm2, rest)
-               }
-       default:
-               switch len(n.fn.Param) {
-               case 2:
-                       n.call = func(elms *FullValue, rest any) any {
-                               n.args[0] = elms.Elm
-                               n.args[1] = rest
-                               return n.fn.Fn.Call(n.args)[0]
-                       }
-               case 3:
-                       n.call = func(elms *FullValue, rest any) any {
-                               n.args[0] = elms.Elm
-                               n.args[1] = elms.Elm2
-                               n.args[2] = rest
-                               return n.fn.Fn.Call(n.args)[0]
-                       }
-               default:
-                       return errors.Errorf("SplitRestriction fn %v has 
unexpected number of parameters: %v",
-                               n.fn.Fn.Name(), len(n.fn.Param))
-               }
-       }
-       return nil
-}
-
 // Invoke calls SplitRestriction given a FullValue containing an element and
 // the associated restriction, and returns a slice of split restrictions.
-func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) {
-       ret := n.call(elms, rest)
+func (n *srInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) 
(splits []any, err error) {
+       if n.ctxIdx >= 0 {
+               n.args[n.ctxIdx] = ctx
+       }
+
+       i := n.ctxIdx + 1
+       n.args[i] = elms.Elm
+
+       if elms.Elm2 != nil {
+               i++
+               n.args[i] = elms.Elm2
+       }
+
+       i++
+       n.args[i] = rest
+
+       ret, err := n.call()
+       if err != nil {
+               return nil, err
+       }
 
        // Return value is an any, but we need to convert it to a []any.
        val := reflect.ValueOf(ret)
@@ -169,7 +146,7 @@ func (n *srInvoker) Invoke(elms *FullValue, rest any) 
(splits []any) {
        for i := 0; i < val.Len(); i++ {
                s = append(s, val.Index(i).Interface())
        }
-       return s
+       return s, nil
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -182,9 +159,10 @@ func (n *srInvoker) Reset() {
 
 // rsInvoker is an invoker for RestrictionSize.
 type rsInvoker struct {
-       fn   *funcx.Fn
-       args []any // Cache to avoid allocating new slices per-element.
-       call func(elms *FullValue, rest any) (size float64)
+       fn     *funcx.Fn
+       args   []any // Cache to avoid allocating new slices per-element.
+       ctxIdx int
+       call   func() (size float64, err error)
 }
 
 func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) {
@@ -192,52 +170,37 @@ func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, 
error) {
                fn:   fn,
                args: make([]any, len(fn.Param)),
        }
+
+       var ok bool
+       if n.ctxIdx, ok = fn.Context(); !ok {
+               n.ctxIdx = -1
+       }
+
        if err := n.initCallFn(); err != nil {
                return nil, errors.WithContext(err, "sdf RestrictionSize 
invoker")
        }
        return n, nil
 }
 
-func (n *rsInvoker) initCallFn() error {
-       // Expects a signature of the form:
-       // (key?, value, restriction) float64
-       // TODO(BEAM-9643): Link to full documentation.
-       switch fnT := n.fn.Fn.(type) {
-       case reflectx.Func2x1:
-               n.call = func(elms *FullValue, rest any) float64 {
-                       return fnT.Call2x1(elms.Elm, rest).(float64)
-               }
-       case reflectx.Func3x1:
-               n.call = func(elms *FullValue, rest any) float64 {
-                       return fnT.Call3x1(elms.Elm, elms.Elm2, rest).(float64)
-               }
-       default:
-               switch len(n.fn.Param) {
-               case 2:
-                       n.call = func(elms *FullValue, rest any) float64 {
-                               n.args[0] = elms.Elm
-                               n.args[1] = rest
-                               return n.fn.Fn.Call(n.args)[0].(float64)
-                       }
-               case 3:
-                       n.call = func(elms *FullValue, rest any) float64 {
-                               n.args[0] = elms.Elm
-                               n.args[1] = elms.Elm2
-                               n.args[2] = rest
-                               return n.fn.Fn.Call(n.args)[0].(float64)
-                       }
-               default:
-                       return errors.Errorf("RestrictionSize fn %v has 
unexpected number of parameters: %v",
-                               n.fn.Fn.Name(), len(n.fn.Param))
-               }
-       }
-       return nil
-}
-
 // Invoke calls RestrictionSize given a FullValue containing an element and
 // the associated restriction, and returns a size.
-func (n *rsInvoker) Invoke(elms *FullValue, rest any) (size float64) {
-       return n.call(elms, rest)
+func (n *rsInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) 
(size float64, err error) {
+       if n.ctxIdx >= 0 {
+               n.args[n.ctxIdx] = ctx
+       }
+
+       i := n.ctxIdx + 1
+       n.args[i] = elms.Elm
+
+       if elms.Elm2 != nil {
+               i++
+               n.args[i] = elms.Elm2
+       }
+
+       i++
+       n.args[i] = rest
+
+       return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -250,9 +213,10 @@ func (n *rsInvoker) Reset() {
 
 // ctInvoker is an invoker for CreateTracker.
 type ctInvoker struct {
-       fn   *funcx.Fn
-       args []any // Cache to avoid allocating new slices per-element.
-       call func(rest any) sdf.RTracker
+       fn     *funcx.Fn
+       args   []any // Cache to avoid allocating new slices per-element.
+       ctxIdx int
+       call   func() (rt sdf.RTracker, err error)
 }
 
 func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) {
@@ -260,37 +224,27 @@ func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, 
error) {
                fn:   fn,
                args: make([]any, len(fn.Param)),
        }
+
+       var ok bool
+       if n.ctxIdx, ok = fn.Context(); !ok {
+               n.ctxIdx = -1
+       }
+
        if err := n.initCallFn(); err != nil {
                return nil, errors.WithContext(err, "sdf CreateTracker invoker")
        }
        return n, nil
 }
 
-func (n *ctInvoker) initCallFn() error {
-       // Expects a signature of the form:
-       // (restriction) sdf.RTracker
-       // TODO(BEAM-9643): Link to full documentation.
-       switch fnT := n.fn.Fn.(type) {
-       case reflectx.Func1x1:
-               n.call = func(rest any) sdf.RTracker {
-                       return fnT.Call1x1(rest).(sdf.RTracker)
-               }
-       default:
-               if len(n.fn.Param) != 1 {
-                       return errors.Errorf("CreateTracker fn %v has 
unexpected number of parameters: %v",
-                               n.fn.Fn.Name(), len(n.fn.Param))
-               }
-               n.call = func(rest any) sdf.RTracker {
-                       n.args[0] = rest
-                       return n.fn.Fn.Call(n.args)[0].(sdf.RTracker)
-               }
+// Invoke calls CreateTracker given a restriction and returns an sdf.RTracker.
+func (n *ctInvoker) Invoke(ctx context.Context, rest any) (sdf.RTracker, 
error) {
+       if n.ctxIdx >= 0 {
+               n.args[n.ctxIdx] = ctx
        }
-       return nil
-}
 
-// Invoke calls CreateTracker given a restriction and returns an sdf.RTracker.
-func (n *ctInvoker) Invoke(rest any) sdf.RTracker {
-       return n.call(rest)
+       n.args[n.ctxIdx+1] = rest
+
+       return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -303,9 +257,10 @@ func (n *ctInvoker) Reset() {
 
 // trInvoker is an invoker for TruncateRestriction.
 type trInvoker struct {
-       fn   *funcx.Fn
-       args []any
-       call func(rest any, elms *FullValue) (pair any)
+       fn     *funcx.Fn
+       args   []any
+       ctxIdx int
+       call   func() (rest any, err error)
 }
 
 func defaultTruncateRestriction(restTracker any) (newRest any) {
@@ -320,6 +275,12 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) 
(*trInvoker, error) {
                fn:   fn,
                args: make([]any, len(fn.Param)),
        }
+
+       var ok bool
+       if n.ctxIdx, ok = fn.Context(); !ok {
+               n.ctxIdx = -1
+       }
+
        if err := n.initCallFn(); err != nil {
                return nil, errors.WithContext(err, "sdf TruncateRestriction 
invoker")
        }
@@ -327,53 +288,39 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) 
(*trInvoker, error) {
 }
 
 func newDefaultTruncateRestrictionInvoker() (*trInvoker, error) {
-       n := &trInvoker{}
-       n.call = func(rest any, elms *FullValue) any {
-               return defaultTruncateRestriction(rest)
+       n := &trInvoker{
+               args: make([]any, 1),
        }
-       return n, nil
-}
-
-func (n *trInvoker) initCallFn() error {
-       // Expects a signature of the form:
-       // (key?, value, restriction) []restriction
-       // TODO(BEAM-9643): Link to full documentation.
-       switch fnT := n.fn.Fn.(type) {
-       case reflectx.Func2x1:
-               n.call = func(rest any, elms *FullValue) any {
-                       return fnT.Call2x1(rest, elms.Elm)
-               }
-       case reflectx.Func3x1:
-               n.call = func(rest any, elms *FullValue) any {
-                       return fnT.Call3x1(rest, elms.Elm, elms.Elm2)
-               }
-       default:
-               switch len(n.fn.Param) {
-               case 2:
-                       n.call = func(rest any, elms *FullValue) any {
-                               n.args[0] = rest
-                               n.args[1] = elms.Elm
-                               return n.fn.Fn.Call(n.args)[0]
-                       }
-               case 3:
-                       n.call = func(rest any, elms *FullValue) any {
-                               n.args[0] = rest
-                               n.args[1] = elms.Elm
-                               n.args[2] = elms.Elm2
-                               return n.fn.Fn.Call(n.args)[0]
-                       }
-               default:
-                       return errors.Errorf("TruncateRestriction fn %v has 
unexpected number of parameters: %v",
-                               n.fn.Fn.Name(), len(n.fn.Param))
-               }
+       n.call = func() (any, error) {
+               return defaultTruncateRestriction(n.args[0]), nil
        }
-       return nil
+       return n, nil
 }
 
 // Invoke calls TruncateRestriction given a FullValue containing an element and
 // the associated restriction tracker, and returns a truncated restriction.
-func (n *trInvoker) Invoke(rt any, elms *FullValue) (rest any) {
-       return n.call(rt, elms)
+func (n *trInvoker) Invoke(ctx context.Context, rt any, elms *FullValue) (rest 
any, err error) {
+       if n.fn == nil {
+               n.args[0] = rt
+               return n.call()
+       }
+
+       if n.ctxIdx >= 0 {
+               n.args[n.ctxIdx] = ctx
+       }
+
+       i := n.ctxIdx + 1
+       n.args[i] = rt
+
+       i++
+       n.args[i] = elms.Elm
+
+       if elms.Elm2 != nil {
+               i++
+               n.args[i] = elms.Elm2
+       }
+
+       return n.call()
 }
 
 // Reset zeroes argument entries in the cached slice to allow values to be
@@ -589,3 +536,11 @@ func (n *wesInvoker) Reset() {
                n.args[i] = nil
        }
 }
+
+func asError(val any) error {
+       if val != nil {
+               return val.(error)
+       }
+
+       return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go
new file mode 100644
index 00000000000..cdefa711603
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go
@@ -0,0 +1,336 @@
+// File generated by specialize. Do not edit.
+
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT.
+
+package exec
+
+import (
+       "fmt"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+)
+
+func (n *cirInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, key?, value) (restriction, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+
+       case reflectx.Func1x1:
+               n.call = func() (rest any, err error) {
+                       r0 := fnT.Call1x1(n.args[0])
+                       return r0, nil
+               }
+
+       case reflectx.Func2x1:
+               n.call = func() (rest any, err error) {
+                       r0 := fnT.Call2x1(n.args[0], n.args[1])
+                       return r0, nil
+               }
+
+       case reflectx.Func3x1:
+               n.call = func() (rest any, err error) {
+                       r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+                       return r0, nil
+               }
+
+       case reflectx.Func1x2:
+               n.call = func() (rest any, err error) {
+                       r0, r1 := fnT.Call1x2(n.args[0])
+                       return r0, asError(r1)
+               }
+
+       case reflectx.Func2x2:
+               n.call = func() (rest any, err error) {
+                       r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+                       return r0, asError(r1)
+               }
+
+       case reflectx.Func3x2:
+               n.call = func() (rest any, err error) {
+                       r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+                       return r0, asError(r1)
+               }
+
+       default:
+               if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 {
+                       return errors.Errorf("CreateInitialRestriction has 
unexpected number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (rest any, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0], nil
+                       case 2:
+                               return ret[0], asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("CreateInitialRestriction has 
unexpected number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *srInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, key?, value, restriction) ([]restriction, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+
+       case reflectx.Func2x1:
+               n.call = func() (splits any, err error) {
+                       r0 := fnT.Call2x1(n.args[0], n.args[1])
+                       return r0, nil
+               }
+
+       case reflectx.Func3x1:
+               n.call = func() (splits any, err error) {
+                       r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+                       return r0, nil
+               }
+
+       case reflectx.Func4x1:
+               n.call = func() (splits any, err error) {
+                       r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], 
n.args[3])
+                       return r0, nil
+               }
+
+       case reflectx.Func2x2:
+               n.call = func() (splits any, err error) {
+                       r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+                       return r0, asError(r1)
+               }
+
+       case reflectx.Func3x2:
+               n.call = func() (splits any, err error) {
+                       r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+                       return r0, asError(r1)
+               }
+
+       case reflectx.Func4x2:
+               n.call = func() (splits any, err error) {
+                       r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], 
n.args[3])
+                       return r0, asError(r1)
+               }
+
+       default:
+               if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+                       return errors.Errorf("SplitRestriction has unexpected 
number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (splits any, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0], nil
+                       case 2:
+                               return ret[0], asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("SplitRestriction has unexpected 
number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *rsInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, key?, value, restriction) (float64, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+
+       case reflectx.Func2x1:
+               n.call = func() (size float64, err error) {
+                       r0 := fnT.Call2x1(n.args[0], n.args[1])
+                       return r0.(float64), nil
+               }
+
+       case reflectx.Func3x1:
+               n.call = func() (size float64, err error) {
+                       r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+                       return r0.(float64), nil
+               }
+
+       case reflectx.Func4x1:
+               n.call = func() (size float64, err error) {
+                       r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], 
n.args[3])
+                       return r0.(float64), nil
+               }
+
+       case reflectx.Func2x2:
+               n.call = func() (size float64, err error) {
+                       r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+                       return r0.(float64), asError(r1)
+               }
+
+       case reflectx.Func3x2:
+               n.call = func() (size float64, err error) {
+                       r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+                       return r0.(float64), asError(r1)
+               }
+
+       case reflectx.Func4x2:
+               n.call = func() (size float64, err error) {
+                       r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], 
n.args[3])
+                       return r0.(float64), asError(r1)
+               }
+
+       default:
+               if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+                       return errors.Errorf("RestrictionSize has unexpected 
number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (size float64, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0].(float64), nil
+                       case 2:
+                               return ret[0].(float64), asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("RestrictionSize has unexpected 
number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *ctInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, restriction) (sdf.RTracker, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+
+       case reflectx.Func1x1:
+               n.call = func() (rt sdf.RTracker, err error) {
+                       r0 := fnT.Call1x1(n.args[0])
+                       return r0.(sdf.RTracker), nil
+               }
+
+       case reflectx.Func2x1:
+               n.call = func() (rt sdf.RTracker, err error) {
+                       r0 := fnT.Call2x1(n.args[0], n.args[1])
+                       return r0.(sdf.RTracker), nil
+               }
+
+       case reflectx.Func1x2:
+               n.call = func() (rt sdf.RTracker, err error) {
+                       r0, r1 := fnT.Call1x2(n.args[0])
+                       return r0.(sdf.RTracker), asError(r1)
+               }
+
+       case reflectx.Func2x2:
+               n.call = func() (rt sdf.RTracker, err error) {
+                       r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+                       return r0.(sdf.RTracker), asError(r1)
+               }
+
+       default:
+               if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 {
+                       return errors.Errorf("CreateTracker has unexpected 
number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (rt sdf.RTracker, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0].(sdf.RTracker), nil
+                       case 2:
+                               return ret[0].(sdf.RTracker), asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("CreateTracker has unexpected number 
of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *trInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, sdf.RTracker, key?, value) (restriction, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+
+       case reflectx.Func2x1:
+               n.call = func() (rest any, err error) {
+                       r0 := fnT.Call2x1(n.args[0], n.args[1])
+                       return r0, nil
+               }
+
+       case reflectx.Func3x1:
+               n.call = func() (rest any, err error) {
+                       r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2])
+                       return r0, nil
+               }
+
+       case reflectx.Func4x1:
+               n.call = func() (rest any, err error) {
+                       r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], 
n.args[3])
+                       return r0, nil
+               }
+
+       case reflectx.Func2x2:
+               n.call = func() (rest any, err error) {
+                       r0, r1 := fnT.Call2x2(n.args[0], n.args[1])
+                       return r0, asError(r1)
+               }
+
+       case reflectx.Func3x2:
+               n.call = func() (rest any, err error) {
+                       r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2])
+                       return r0, asError(r1)
+               }
+
+       case reflectx.Func4x2:
+               n.call = func() (rest any, err error) {
+                       r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], 
n.args[3])
+                       return r0, asError(r1)
+               }
+
+       default:
+               if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+                       return errors.Errorf("TruncateRestriction has 
unexpected number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (rest any, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0], nil
+                       case 2:
+                               return ret[0], asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("TruncateRestriction has unexpected 
number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl
new file mode 100644
index 00000000000..7df994be0c7
--- /dev/null
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl
@@ -0,0 +1,246 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT.
+
+package exec
+
+import (
+       "fmt"
+
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
+       "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
+)
+
+func (n *cirInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, key?, value) (restriction, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 4}}
+    {{if gt $out 0}}
+    {{if gt $in 0}}
+       case reflectx.Func{{$in}}x{{$out}}:
+               n.call = func() (rest any, err error) {
+                       {{mktuplef $out "r%v"}} := 
fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+                       {{- if eq $out 1}}
+                       return r0, nil
+                       {{- else}}
+                       return r0, asError(r1)
+                       {{- end}}
+               }
+       {{end}}
+       {{end}}
+{{end}}
+{{end}}
+       default:
+               if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 {
+                       return errors.Errorf("CreateInitialRestriction has 
unexpected number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (rest any, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0], nil
+                       case 2:
+                               return ret[0], asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("CreateInitialRestriction has 
unexpected number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *srInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, key?, value, restriction) ([]restriction, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 5}}
+    {{if gt $out 0}}
+    {{if gt $in 1}}
+       case reflectx.Func{{$in}}x{{$out}}:
+               n.call = func() (splits any, err error) {
+                       {{mktuplef $out "r%v"}} := 
fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+                       {{- if eq $out 1}}
+                       return r0, nil
+                       {{- else}}
+                       return r0, asError(r1)
+                       {{- end}}
+               }
+       {{end}}
+       {{end}}
+{{end}}
+{{end}}
+       default:
+               if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+                       return errors.Errorf("SplitRestriction has unexpected 
number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (splits any, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0], nil
+                       case 2:
+                               return ret[0], asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("SplitRestriction has unexpected 
number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *rsInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, key?, value, restriction) (float64, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 5}}
+    {{if gt $out 0}}
+    {{if gt $in 1}}
+       case reflectx.Func{{$in}}x{{$out}}:
+               n.call = func() (size float64, err error) {
+                       {{mktuplef $out "r%v"}} := 
fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+                       {{- if eq $out 1}}
+                       return r0.(float64), nil
+                       {{- else}}
+                       return r0.(float64), asError(r1)
+                       {{- end}}
+               }
+       {{end}}
+       {{end}}
+{{end}}
+{{end}}
+       default:
+               if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+                       return errors.Errorf("RestrictionSize has unexpected 
number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (size float64, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0].(float64), nil
+                       case 2:
+                               return ret[0].(float64), asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("RestrictionSize has unexpected 
number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *ctInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, restriction) (sdf.RTracker, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 3}}
+    {{if gt $out 0}}
+    {{if gt $in 0}}
+       case reflectx.Func{{$in}}x{{$out}}:
+               n.call = func() (rt sdf.RTracker, err error) {
+                       {{mktuplef $out "r%v"}} := 
fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+                       {{- if eq $out 1}}
+                       return r0.(sdf.RTracker), nil
+                       {{- else}}
+                       return r0.(sdf.RTracker), asError(r1)
+                       {{- end}}
+               }
+       {{end}}
+       {{end}}
+{{end}}
+{{end}}
+       default:
+               if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 {
+                       return errors.Errorf("CreateTracker has unexpected 
number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (rt sdf.RTracker, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0].(sdf.RTracker), nil
+                       case 2:
+                               return ret[0].(sdf.RTracker), asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("CreateTracker has unexpected number 
of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
+
+func (n *trInvoker) initCallFn() error {
+       // Expects a signature of the form:
+       // (context.Context?, sdf.RTracker, key?, value) (restriction, error?)
+       // TODO(BEAM-9643): Link to full documentation.
+       switch fnT := n.fn.Fn.(type) {
+{{range $out := upto 3}}
+{{range $in := upto 5}}
+    {{if gt $out 0}}
+    {{if gt $in 1}}
+       case reflectx.Func{{$in}}x{{$out}}:
+               n.call = func() (rest any, err error) {
+                       {{mktuplef $out "r%v"}} := 
fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}})
+                       {{- if eq $out 1}}
+                       return r0, nil
+                       {{- else}}
+                       return r0, asError(r1)
+                       {{- end}}
+               }
+       {{end}}
+       {{end}}
+{{end}}
+{{end}}
+       default:
+               if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 {
+                       return errors.Errorf("TruncateRestriction has 
unexpected number of parameters: %v", len(n.fn.Param))
+               }
+
+               n.call = func() (rest any, err error) {
+                       ret := n.fn.Fn.Call(n.args)
+
+                       switch len(ret) {
+                       case 1:
+                               return ret[0], nil
+                       case 2:
+                               return ret[0], asError(ret[1])
+                       }
+
+                       panic(fmt.Sprintf("TruncateRestriction has unexpected 
number of return values: %v", len(ret)))
+               }
+       }
+
+       return nil
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
index e308071aaf9..edef16e51d7 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go
@@ -16,6 +16,8 @@
 package exec
 
 import (
+       "context"
+       "errors"
        "testing"
        "time"
 
@@ -48,13 +50,59 @@ func TestInvokes(t *testing.T) {
        }
        statefulWeFn := (*graph.SplittableDoFn)(dfn)
 
+       initialRestErrDfn, err := graph.NewDoFn(
+               &VetCreateInitialRestrictionErrSdf{},
+               graph.NumMainInputs(graph.MainSingle),
+       )
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+       initialRestErrSdf := (*graph.SplittableDoFn)(initialRestErrDfn)
+
+       splitRestErrDfn, err := graph.NewDoFn(
+               &VetSplitRestrictionErrSdf{},
+               graph.NumMainInputs(graph.MainSingle),
+       )
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+       splitRestErrSdf := (*graph.SplittableDoFn)(splitRestErrDfn)
+
+       restSizeErrDfn, err := graph.NewDoFn(
+               &VetRestrictionSizeErrSdf{},
+               graph.NumMainInputs(graph.MainSingle),
+       )
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+       restSizeErrSdf := (*graph.SplittableDoFn)(restSizeErrDfn)
+
+       trackerErrDfn, err := graph.NewDoFn(
+               &VetCreateTrackerErrSdf{},
+               graph.NumMainInputs(graph.MainSingle),
+       )
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+       trackerErrSdf := (*graph.SplittableDoFn)(trackerErrDfn)
+
+       truncateRestErrDfn, err := graph.NewDoFn(
+               &VetTruncateRestrictionErrSdf{},
+               graph.NumMainInputs(graph.MainSingle),
+       )
+       if err != nil {
+               t.Fatalf("invalid function: %v", err)
+       }
+       truncateRestErrSdf := (*graph.SplittableDoFn)(truncateRestErrDfn)
+
        // Tests.
        t.Run("CreateInitialRestriction Invoker (cirInvoker)", func(t 
*testing.T) {
                tests := []struct {
-                       name string
-                       sdf  *graph.SplittableDoFn
-                       elms *FullValue
-                       want *VetRestriction
+                       name    string
+                       sdf     *graph.SplittableDoFn
+                       elms    *FullValue
+                       want    *VetRestriction
+                       wantErr bool
                }{
                        {
                                name: "SingleElem",
@@ -68,6 +116,12 @@ func TestInvokes(t *testing.T) {
                                elms: &FullValue{Elm: 1, Elm2: 2},
                                want: &VetRestriction{ID: "KvSdf", CreateRest: 
true, Key: 1, Val: 2},
                        },
+                       {
+                               name:    "Error",
+                               sdf:     initialRestErrSdf,
+                               elms:    &FullValue{Elm: 1},
+                               wantErr: true,
+                       },
                }
                for _, test := range tests {
                        test := test
@@ -77,7 +131,15 @@ func TestInvokes(t *testing.T) {
                                if err != nil {
                                        
t.Fatalf("newCreateInitialRestrictionInvoker failed: %v", err)
                                }
-                               got := invoker.Invoke(test.elms)
+
+                               got, err := 
invoker.Invoke(context.Background(), test.elms)
+                               if (err != nil) != test.wantErr {
+                                       t.Fatalf("Invoke(%v) error = %v, 
wantErr %v", test.elms, err, test.wantErr)
+                               }
+                               if test.wantErr {
+                                       return
+                               }
+
                                if !cmp.Equal(got, test.want) {
                                        t.Errorf("Invoke(%v) has incorrect 
output: got: %v, want: %v",
                                                test.elms, got, test.want)
@@ -94,11 +156,12 @@ func TestInvokes(t *testing.T) {
 
        t.Run("SplitRestriction Invoker (srInvoker)", func(t *testing.T) {
                tests := []struct {
-                       name string
-                       sdf  *graph.SplittableDoFn
-                       elms *FullValue
-                       rest *VetRestriction
-                       want []any
+                       name    string
+                       sdf     *graph.SplittableDoFn
+                       elms    *FullValue
+                       rest    *VetRestriction
+                       want    []any
+                       wantErr bool
                }{
                        {
                                name: "SingleElem",
@@ -119,6 +182,13 @@ func TestInvokes(t *testing.T) {
                                        &VetRestriction{ID: "KvSdf.2", 
SplitRest: true, Key: 1, Val: 2},
                                },
                        },
+                       {
+                               name:    "Error",
+                               sdf:     splitRestErrSdf,
+                               elms:    &FullValue{Elm: 1},
+                               rest:    &VetRestriction{ID: "Sdf"},
+                               wantErr: true,
+                       },
                }
                for _, test := range tests {
                        test := test
@@ -128,8 +198,16 @@ func TestInvokes(t *testing.T) {
                                if err != nil {
                                        t.Fatalf("newSplitRestrictionInvoker 
failed: %v", err)
                                }
+
                                rest := *test.rest // Create a copy because our 
test SDF edits the restriction.
-                               got := invoker.Invoke(test.elms, &rest)
+                               got, err := 
invoker.Invoke(context.Background(), test.elms, &rest)
+                               if (err != nil) != test.wantErr {
+                                       t.Fatalf("Invoke(%v, %v) error = %v, 
wantErr %v", test.elms, test.rest, err, test.wantErr)
+                               }
+                               if test.wantErr {
+                                       return
+                               }
+
                                if !cmp.Equal(got, test.want) {
                                        t.Errorf("Invoke(%v, %v) has incorrect 
output: got: %v, want: %v",
                                                test.elms, test.rest, got, 
test.want)
@@ -152,6 +230,7 @@ func TestInvokes(t *testing.T) {
                        rest     *VetRestriction
                        want     float64
                        restWant *VetRestriction
+                       wantErr  bool
                }{
                        {
                                name:     "SingleElem",
@@ -168,6 +247,13 @@ func TestInvokes(t *testing.T) {
                                want:     3,
                                restWant: &VetRestriction{ID: "KvSdf", 
RestSize: true, Key: 1, Val: 2},
                        },
+                       {
+                               name:    "Error",
+                               sdf:     restSizeErrSdf,
+                               elms:    &FullValue{Elm: 1},
+                               rest:    &VetRestriction{ID: "Sdf"},
+                               wantErr: true,
+                       },
                }
                for _, test := range tests {
                        test := test
@@ -178,7 +264,15 @@ func TestInvokes(t *testing.T) {
                                        t.Fatalf("newRestrictionSizeInvoker 
failed: %v", err)
                                }
                                rest := *test.rest // Create a copy because our 
test SDF edits the restriction.
-                               got := invoker.Invoke(test.elms, &rest)
+
+                               got, err := 
invoker.Invoke(context.Background(), test.elms, &rest)
+                               if (err != nil) != test.wantErr {
+                                       t.Fatalf("Invoke(%v, %v) error = %v, 
wantErr %v", test.elms, test.rest, err, test.wantErr)
+                               }
+                               if test.wantErr {
+                                       return
+                               }
+
                                if !cmp.Equal(got, test.want) {
                                        t.Errorf("Invoke(%v, %v) has incorrect 
output: got: %v, want: %v",
                                                test.elms, test.rest, got, 
test.want)
@@ -199,10 +293,11 @@ func TestInvokes(t *testing.T) {
 
        t.Run("CreateTracker Invoker (ctInvoker)", func(t *testing.T) {
                tests := []struct {
-                       name string
-                       sdf  *graph.SplittableDoFn
-                       rest *VetRestriction
-                       want *VetRTracker
+                       name    string
+                       sdf     *graph.SplittableDoFn
+                       rest    *VetRestriction
+                       want    *VetRTracker
+                       wantErr bool
                }{
                        {
                                name: "SingleElem",
@@ -215,6 +310,12 @@ func TestInvokes(t *testing.T) {
                                rest: &VetRestriction{ID: "KvSdf"},
                                want: &VetRTracker{&VetRestriction{ID: "KvSdf", 
CreateTracker: true}},
                        },
+                       {
+                               name:    "Error",
+                               sdf:     trackerErrSdf,
+                               rest:    &VetRestriction{ID: "Sdf"},
+                               wantErr: true,
+                       },
                }
                for _, test := range tests {
                        test := test
@@ -224,7 +325,15 @@ func TestInvokes(t *testing.T) {
                                if err != nil {
                                        t.Fatalf("newCreateTrackerInvoker 
failed: %v", err)
                                }
-                               got := invoker.Invoke(test.rest)
+
+                               got, err := 
invoker.Invoke(context.Background(), test.rest)
+                               if (err != nil) != test.wantErr {
+                                       t.Fatalf("Invoke(%v) error = %v, 
wantErr %v", test.rest, err, test.wantErr)
+                               }
+                               if test.wantErr {
+                                       return
+                               }
+
                                if !cmp.Equal(got, test.want) {
                                        t.Errorf("Invoke(%v) has incorrect 
output: got: %v, want: %v",
                                                test.rest, got, test.want)
@@ -310,11 +419,12 @@ func TestInvokes(t *testing.T) {
 
        t.Run("TruncateRestriction Invoker (trInvoker)", func(t *testing.T) {
                tests := []struct {
-                       name string
-                       sdf  *graph.SplittableDoFn
-                       elms *FullValue
-                       rest *VetRestriction
-                       want any
+                       name    string
+                       sdf     *graph.SplittableDoFn
+                       elms    *FullValue
+                       rest    *VetRestriction
+                       want    any
+                       wantErr bool
                }{
                        {
                                name: "SingleElem",
@@ -329,6 +439,13 @@ func TestInvokes(t *testing.T) {
                                rest: &VetRestriction{ID: "KvSdf"},
                                want: &VetRestriction{ID: "KvSdf", 
CreateTracker: true, TruncateRest: true, RestSize: true, Key: 1, Val: 2},
                        },
+                       {
+                               name:    "Error",
+                               sdf:     truncateRestErrSdf,
+                               elms:    &FullValue{Elm: 1},
+                               rest:    &VetRestriction{ID: "Sdf"},
+                               wantErr: true,
+                       },
                }
                for _, test := range tests {
                        test := test
@@ -336,24 +453,38 @@ func TestInvokes(t *testing.T) {
                        ctFn := test.sdf.CreateTrackerFn()
                        rsFn := test.sdf.RestrictionSizeFn()
                        t.Run(test.name, func(t *testing.T) {
+                               ctx := context.Background()
                                rest := test.rest // Create a copy because our 
test SDF edits the restriction.
                                ctInvoker, err := newCreateTrackerInvoker(ctFn)
                                if err != nil {
                                        t.Fatalf("newCreateTrackerInvoker 
failed: %v", err)
                                }
-                               rt := ctInvoker.Invoke(rest)
+                               rt, err := ctInvoker.Invoke(ctx, rest)
+                               if err != nil {
+                                       t.Fatalf("ctInvoker.Invoke(%v) failed: 
%v", rest, err)
+                               }
 
                                trInvoker, err := 
newTruncateRestrictionInvoker(fn)
                                if err != nil {
                                        t.Fatalf("newTruncateRestrictionInvoker 
failed: %v", err)
                                }
-                               trRest := trInvoker.Invoke(rt, test.elms)
+
+                               trRest, err := trInvoker.Invoke(ctx, rt, 
test.elms)
+                               if (err != nil) != test.wantErr {
+                                       t.Fatalf("trInvoker.Invoke(%v, %v) = 
%v, wantErr %v", rt, test.elms, err, test.wantErr)
+                               }
+                               if test.wantErr {
+                                       return
+                               }
 
                                rsInvoker, err := 
newRestrictionSizeInvoker(rsFn)
                                if err != nil {
                                        t.Fatalf("newRestrictionSizeInvoker 
failed: %v", err)
                                }
-                               _ = rsInvoker.Invoke(test.elms, trRest)
+                               if _, err := rsInvoker.Invoke(ctx, test.elms, 
trRest); err != nil {
+                                       t.Fatalf("rsInvoker.Invoke(%v, %v) 
failed: %v", test.elms, trRest, err)
+                               }
+
                                if !cmp.Equal(trRest, test.want) {
                                        t.Errorf("Invoke(%v, %v) has incorrect 
output: got: %v, want: %v",
                                                test.elms, test.rest, trRest, 
test.want)
@@ -411,24 +542,35 @@ func TestInvokes(t *testing.T) {
                        ctFn := test.sdf.CreateTrackerFn()
                        rsFn := test.sdf.RestrictionSizeFn()
                        t.Run(test.name, func(t *testing.T) {
+                               ctx := context.Background()
                                rest := test.rest // Create a copy because our 
test SDF edits the restriction.
                                ctInvoker, err := newCreateTrackerInvoker(ctFn)
                                if err != nil {
                                        t.Fatalf("newCreateTrackerInvoker 
failed: %v", err)
                                }
-                               rt := ctInvoker.Invoke(rest)
+                               rt, err := ctInvoker.Invoke(ctx, rest)
+                               if err != nil {
+                                       t.Fatalf("ctInvoker.Invoke(%v) failed: 
%v", rest, err)
+                               }
 
                                trInvoker, err := 
newDefaultTruncateRestrictionInvoker()
                                if err != nil {
                                        t.Fatalf("newTruncateRestrictionInvoker 
failed: %v", err)
                                }
-                               trRest := trInvoker.Invoke(rt, test.elms)
+                               trRest, err := trInvoker.Invoke(ctx, rt, 
test.elms)
+                               if err != nil {
+                                       t.Fatalf("trInvoker.Invoke(%v, %v) 
failed: %v", rt, test.elms, err)
+                               }
+
                                if trRest != nil {
                                        rsInvoker, err := 
newRestrictionSizeInvoker(rsFn)
                                        if err != nil {
                                                
t.Fatalf("newRestrictionSizeInvoker failed: %v", err)
                                        }
-                                       _ = rsInvoker.Invoke(test.elms, trRest)
+                                       if _, err := rsInvoker.Invoke(ctx, 
test.elms, trRest); err != nil {
+                                               t.Fatalf("rsInvoker.Invoke(%v, 
%v) failed: %v", test.elms, trRest, err)
+                                       }
+
                                        if !cmp.Equal(trRest, test.want) {
                                                t.Errorf("Invoke(%v, %v) has 
incorrect output: got: %v, want: %v",
                                                        test.elms, test.rest, 
trRest, test.want)
@@ -732,3 +874,55 @@ func (fn *VetEmptyInitialSplitSdf) ProcessElement(rt 
*VetRTracker, i int, emit f
        emit(rest)
        return sdf.ResumeProcessingIn(1 * time.Second)
 }
+
+var errSdf = errors.New("SDF error")
+
+// VetCreateInitialRestrictionErrSdf is an SDF with a CreateInitialRestriction 
method
+// that returns a non-nil error.
+type VetCreateInitialRestrictionErrSdf struct {
+       VetSdf
+}
+
+func (fn *VetCreateInitialRestrictionErrSdf) CreateInitialRestriction(i int) 
(*VetRestriction, error) {
+       return nil, errSdf
+}
+
+// VetSplitRestrictionErrSdf is an SDF with a SplitRestriction method
+// that returns a non-nil error.
+type VetSplitRestrictionErrSdf struct {
+       VetSdf
+}
+
+func (fn *VetSplitRestrictionErrSdf) SplitRestriction(int, *VetRestriction) 
([]*VetRestriction, error) {
+       return nil, errSdf
+}
+
+// VetRestrictionSizeErrSdf is an SDF with a RestrictionSize method
+// that returns a non-nil error.
+type VetRestrictionSizeErrSdf struct {
+       VetSdf
+}
+
+func (fn *VetRestrictionSizeErrSdf) RestrictionSize(int, *VetRestriction) 
(float64, error) {
+       return -1, errSdf
+}
+
+// VetCreateTrackerErrSdf is an SDF with a CreateTracker method
+// that returns a non-nil error.
+type VetCreateTrackerErrSdf struct {
+       VetSdf
+}
+
+func (fn *VetCreateTrackerErrSdf) CreateTracker(*VetRestriction) 
(*VetRTracker, error) {
+       return nil, errSdf
+}
+
+// VetTruncateRestrictionErrSdf is an SDF with a TruncateRestriction method
+// that returns a non-nil error.
+type VetTruncateRestrictionErrSdf struct {
+       VetSdf
+}
+
+func (fn *VetTruncateRestrictionErrSdf) TruncateRestriction(*VetRTracker, int) 
(*VetRestriction, error) {
+       return nil, errSdf
+}
diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go 
b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
index 414f28553a8..a0380796e86 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go
@@ -1114,10 +1114,11 @@ func TestAsSplittableUnit(t *testing.T) {
 
                                // Call from SplittableUnit and check results.
                                su := SplittableUnit(node)
-                               if err := node.Up(context.Background()); err != 
nil {
+                               ctx := context.Background()
+                               if err := node.Up(ctx); err != nil {
                                        
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
                                }
-                               gotPrimaries, gotResiduals, err := 
su.Split(test.frac)
+                               gotPrimaries, gotResiduals, err := 
su.Split(ctx, test.frac)
                                if err != nil {
                                        t.Fatalf("SplittableUnit.Split(%v) 
failed: %v", test.frac, err)
                                }
@@ -1184,10 +1185,11 @@ func TestAsSplittableUnit(t *testing.T) {
 
                                // Call from SplittableUnit and check results.
                                su := SplittableUnit(node)
-                               if err := node.Up(context.Background()); err != 
nil {
+                               ctx := context.Background()
+                               if err := node.Up(ctx); err != nil {
                                        
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
                                }
-                               _, _, err := su.Split(0.5)
+                               _, _, err := su.Split(ctx, 0.5)
                                if err == nil {
                                        t.Errorf("SplittableUnit.Split(%v) was 
expected to fail.", test.in)
                                }
@@ -1251,10 +1253,11 @@ func TestAsSplittableUnit(t *testing.T) {
                                node.currW = 0
                                // Call from SplittableUnit and check results.
                                su := SplittableUnit(node)
-                               if err := node.Up(context.Background()); err != 
nil {
+                               ctx := context.Background()
+                               if err := node.Up(ctx); err != nil {
                                        
t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err)
                                }
-                               gotResiduals, err := su.Checkpoint()
+                               gotResiduals, err := su.Checkpoint(ctx)
 
                                if err != nil {
                                        t.Fatalf("SplittableUnit.Checkpoint() 
returned error, got %v", err)
@@ -1401,7 +1404,7 @@ func TestMultiWindowProcessing(t *testing.T) {
        // Split should hit window boundary between 2 and 3. We don't need to 
check
        // the split result here, just the effects it has on currW and numW.
        frac := 0.5
-       if _, _, err := su.Split(frac); err != nil {
+       if _, _, err := su.Split(context.Background(), frac); err != nil {
                t.Errorf("Split(%v) failed with error: %v", frac, err)
        }
        if got, want := node.currW, blockW; got != want {
diff --git a/sdks/go/pkg/beam/runners/prism/internal/config/config.go 
b/sdks/go/pkg/beam/runners/prism/internal/config/config.go
index fc2b68d092f..9c3bdd012bc 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/config/config.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/config/config.go
@@ -244,4 +244,4 @@ func (r *HandlerRegistry) GetVariant(name string) *Variant {
                return nil
        }
        return &Variant{parent: r, name: name, handlers: vs.Handlers}
-}
\ No newline at end of file
+}
diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go 
b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
index 3c7cae97397..7b553f6ad65 100644
--- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
+++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go
@@ -33,4 +33,4 @@ func Test_toUrn(t *testing.T) {
        if got := quickUrn(pipepb.StandardPTransforms_PAR_DO); got != want {
                t.Errorf("quickUrn(\"pipepb.StandardPTransforms_PAR_DO\") = %v, 
want %v", got, want)
        }
-}
\ No newline at end of file
+}

Reply via email to