This is an automated email from the ASF dual-hosted git repository. lostluck pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 29ea6e0eb8a [Go SDK]: Allow SDF methods to have context param and error return value (#25437) 29ea6e0eb8a is described below commit 29ea6e0eb8a2f2cf571d1a27799d125ca051008b Author: Johanna Öjeling <51084516+johannaojel...@users.noreply.github.com> AuthorDate: Thu Feb 16 21:02:08 2023 +0100 [Go SDK]: Allow SDF methods to have context param and error return value (#25437) * Allow context param and error return value in SDF validation * Use context param and error return value in SDF method invocation * Run go fmt * Clean up error messages from "fn reflect.methodValueCall" * Validate return value count in a more correct way --- sdks/go/pkg/beam/core/graph/fn.go | 108 ++++--- sdks/go/pkg/beam/core/graph/fn_test.go | 87 ++++++ sdks/go/pkg/beam/core/runtime/exec/datasource.go | 8 +- .../pkg/beam/core/runtime/exec/datasource_test.go | 12 +- sdks/go/pkg/beam/core/runtime/exec/sdf.go | 134 +++++--- sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go | 339 +++++++++------------ .../beam/core/runtime/exec/sdf_invokers_arity.go | 336 ++++++++++++++++++++ .../beam/core/runtime/exec/sdf_invokers_arity.tmpl | 246 +++++++++++++++ .../beam/core/runtime/exec/sdf_invokers_test.go | 250 +++++++++++++-- sdks/go/pkg/beam/core/runtime/exec/sdf_test.go | 17 +- .../beam/runners/prism/internal/config/config.go | 2 +- .../beam/runners/prism/internal/urns/urns_test.go | 2 +- 12 files changed, 1227 insertions(+), 314 deletions(-) diff --git a/sdks/go/pkg/beam/core/graph/fn.go b/sdks/go/pkg/beam/core/graph/fn.go index 907af4045b5..25b846370fb 100644 --- a/sdks/go/pkg/beam/core/graph/fn.go +++ b/sdks/go/pkg/beam/core/graph/fn.go @@ -866,14 +866,15 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error { // CreateInitialRestriction. if numMainIn == MainUnknown { initialRestFn := fn.methods[createInitialRestrictionName] - paramNum := len(initialRestFn.Param) + paramNum := len(initialRestFn.Params(funcx.FnValue)) + switch paramNum { case int(MainSingle), int(MainKv): num = paramNum default: // Can't infer because method has invalid # of main inputs. - err := errors.Errorf("invalid number of params in method %v. got: %v, want: %v or %v", + err := errors.Errorf("invalid number of main input params in method %v. got: %v, want: %v or %v", createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv)) - return errors.SetTopLevelMsgf(err, "Invalid number of parameters in method %v. "+ + return errors.SetTopLevelMsgf(err, "Invalid number of main input parameters in method %v. "+ "Got: %v, Want: %v or %v. Check that the signature conforms to the expected signature for %v, "+ "and that elements in SDF method parameters match elements in %v.", createInitialRestrictionName, paramNum, int(MainSingle), int(MainKv), createInitialRestrictionName, processElementName) @@ -894,7 +895,7 @@ func validateSdfSignatures(fn *Fn, numMainIn mainInputs) error { // in each SDF method in the given Fn, and returns an error if a method has an // invalid/unexpected number. func validateSdfSigNumbers(fn *Fn, num int) error { - paramNums := map[string]int{ + reqParamNums := map[string]int{ createInitialRestrictionName: num, splitRestrictionName: num + 1, restrictionSizeName: num + 1, @@ -904,32 +905,52 @@ func validateSdfSigNumbers(fn *Fn, num int) error { optionalSdfs := map[string]bool{ truncateRestrictionName: true, } - returnNum := 1 // TODO(BEAM-3301): Enable optional error params in SDF methods. + reqReturnNum := 1 for _, name := range sdfNames { method, ok := fn.methods[name] if !ok && optionalSdfs[name] { continue } - if len(method.Param) != paramNums[name] { - err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v", - name, len(method.Param), paramNums[name]) + + reqParamNum := reqParamNums[name] + if !sdfHasValidParamNum(method.Param, reqParamNum) { + err := errors.Errorf("unexpected number of params in method %v. got: %v, want: %v or optionally %v "+ + "if first param is of type context.Context", name, len(method.Param), reqParamNum, reqParamNum+1) return errors.SetTopLevelMsgf(err, "Unexpected number of parameters in method %v. "+ - "Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v, "+ - "and that elements in SDF method parameters match elements in %v.", - name, len(method.Param), paramNums[name], name, processElementName) + "Got: %v, Want: %v or optionally %v if first param is of type context.Context. "+ + "Check that the signature conforms to the expected signature for %v, and that elements in SDF method "+ + "parameters match elements in %v.", name, len(method.Param), reqParamNum, reqParamNum+1, + name, processElementName) } - if len(method.Ret) != returnNum { - err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v", - name, len(method.Ret), returnNum) + if !sdfHasValidReturnNum(method.Ret, reqReturnNum) { + err := errors.Errorf("unexpected number of returns in method %v. got: %v, want: %v or optionally %v "+ + "if last value is of type error", name, len(method.Ret), reqReturnNum, reqReturnNum+1) return errors.SetTopLevelMsgf(err, "Unexpected number of return values in method %v. "+ - "Got: %v, Want: %v. Check that the signature conforms to the expected signature for %v.", - name, len(method.Ret), returnNum, name) + "Got: %v, Want: %v or optionally %v if last value is of type error. "+ + "Check that the signature conforms to the expected signature for %v.", + name, len(method.Ret), reqReturnNum, reqReturnNum+1, name) } } return nil } +func sdfHasValidParamNum(params []funcx.FnParam, requiredNum int) bool { + if len(params) == requiredNum { + return true + } + + return len(params) == requiredNum+1 && params[0].Kind == funcx.FnContext +} + +func sdfHasValidReturnNum(returns []funcx.ReturnParam, requiredNum int) bool { + if len(returns) == requiredNum { + return true + } + + return len(returns) == requiredNum+1 && returns[len(returns)-1].Kind == funcx.RetError +} + // validateSdfSigTypes validates the types of the parameters and return values // in each SDF method in the given Fn, and returns an error if a method has an // invalid/mismatched type. Assumes that the number of parameters and return @@ -940,22 +961,25 @@ func validateSdfSigTypes(fn *Fn, num int) error { for _, name := range requiredSdfNames { method := fn.methods[name] + startIdx := sdfRequiredParamStartIndex(method) + switch name { case createInitialRestrictionName: - if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, 0); err != nil { + if err := validateSdfElementT(fn, createInitialRestrictionName, method, num, startIdx); err != nil { return err } case splitRestrictionName: - if err := validateSdfElementT(fn, splitRestrictionName, method, num, 0); err != nil { + if err := validateSdfElementT(fn, splitRestrictionName, method, num, startIdx); err != nil { return err } - if method.Param[num].T != restrictionT { + idx := num + startIdx + if method.Param[idx].T != restrictionT { err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v", - splitRestrictionName, num, method.Param[num].T, restrictionT) + splitRestrictionName, idx, method.Param[idx].T, restrictionT) return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+ "parameter at index %v. Got: %v, Want: %v (from method %v). "+ "Ensure that all restrictions in an SDF are the same type.", - splitRestrictionName, num, method.Param[num].T, restrictionT, createInitialRestrictionName) + splitRestrictionName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName) } if method.Ret[0].T.Kind() != reflect.Slice || method.Ret[0].T.Elem() != restrictionT { @@ -967,16 +991,17 @@ func validateSdfSigTypes(fn *Fn, num int) error { splitRestrictionName, 0, method.Ret[0].T, reflect.SliceOf(restrictionT), createInitialRestrictionName, splitRestrictionName) } case restrictionSizeName: - if err := validateSdfElementT(fn, restrictionSizeName, method, num, 0); err != nil { + if err := validateSdfElementT(fn, restrictionSizeName, method, num, startIdx); err != nil { return err } - if method.Param[num].T != restrictionT { + idx := num + startIdx + if method.Param[idx].T != restrictionT { err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v", - restrictionSizeName, num, method.Param[num].T, restrictionT) + restrictionSizeName, idx, method.Param[idx].T, restrictionT) return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+ "parameter at index %v. Got: %v, Want: %v (from method %v). "+ "Ensure that all restrictions in an SDF are the same type.", - restrictionSizeName, num, method.Param[num].T, restrictionT, createInitialRestrictionName) + restrictionSizeName, idx, method.Param[idx].T, restrictionT, createInitialRestrictionName) } if method.Ret[0].T != reflectx.Float64 { err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v", @@ -986,13 +1011,13 @@ func validateSdfSigTypes(fn *Fn, num int) error { restrictionSizeName, 0, method.Ret[0].T, reflectx.Float64) } case createTrackerName: - if method.Param[0].T != restrictionT { + if method.Param[startIdx].T != restrictionT { err := errors.Errorf("mismatched restriction type in method %v, param %v. got: %v, want: %v", - createTrackerName, 0, method.Param[0].T, restrictionT) + createTrackerName, startIdx, method.Param[startIdx].T, restrictionT) return errors.SetTopLevelMsgf(err, "Mismatched restriction type in method %v, "+ "parameter at index %v. Got: %v, Want: %v (from method %v). "+ "Ensure that all restrictions in an SDF are the same type.", - createTrackerName, 0, method.Param[0].T, restrictionT, createInitialRestrictionName) + createTrackerName, startIdx, method.Param[startIdx].T, restrictionT, createInitialRestrictionName) } if !method.Ret[0].T.Implements(rTrackerT) { err := errors.Errorf("invalid output type in method %v, return %v: %v does not implement sdf.RTracker", @@ -1020,15 +1045,18 @@ func validateSdfSigTypes(fn *Fn, num int) error { if !ok { continue } + + startIdx := sdfRequiredParamStartIndex(method) + switch name { case truncateRestrictionName: - if method.Param[0].T != rTrackerImplT { + if method.Param[startIdx].T != rTrackerImplT { err := errors.Errorf("mismatched restriction tracker type in method %v, param %v. got: %v, want: %v", - truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT) + truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT) return errors.SetTopLevelMsgf(err, "Mismatched restriction tracker type in method %v, "+ "parameter at index %v. Got: %v, Want: %v (from method %v). "+ "Ensure that restriction tracker is the first parameter.", - truncateRestrictionName, 0, method.Param[0].T, rTrackerImplT, createTrackerName) + truncateRestrictionName, startIdx, method.Param[startIdx].T, rTrackerImplT, createTrackerName) } if method.Ret[0].T != restrictionT { err := errors.Errorf("invalid output type in method %v, return %v. got: %v, want: %v", @@ -1052,6 +1080,14 @@ func validateSdfSigTypes(fn *Fn, num int) error { return nil } +func sdfRequiredParamStartIndex(method *funcx.Fn) int { + if ctxIndex, ok := method.Context(); ok { + return ctxIndex + 1 + } + + return 0 +} + // validateSdfElementT validates that element types in an SDF method are // consistent with the ProcessElement method. This method assumes that the // first 'num' parameters starting with startIndex are the elements. @@ -1062,13 +1098,14 @@ func validateSdfElementT(fn *Fn, name string, method *funcx.Fn, num int, startIn pos, _, _ := processFn.Inputs() for i := 0; i < num; i++ { - if method.Param[i+startIndex].T != processFn.Param[pos+i].T { + idx := i + startIndex + if method.Param[idx].T != processFn.Param[pos+i].T { err := errors.Errorf("mismatched element type in method %v, param %v. got: %v, want: %v", - name, i, method.Param[i].T, processFn.Param[pos+i].T) + name, idx, method.Param[idx].T, processFn.Param[pos+i].T) return errors.SetTopLevelMsgf(err, "Mismatched element type in method %v, "+ "parameter at index %v. Got: %v, Want: %v (from method %v). "+ "Ensure that element parameters in SDF methods have consistent types with element parameters in %v.", - name, i, method.Param[i].T, processFn.Param[pos+i].T, processElementName, processElementName) + name, idx, method.Param[idx].T, processFn.Param[pos+i].T, processElementName, processElementName) } } return nil @@ -1178,7 +1215,8 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error { // CreateInitialRestriction. if numMainIn == int(MainUnknown) { initialRestFn := fn.methods[createInitialRestrictionName] - paramNum := len(initialRestFn.Param) + paramNum := len(initialRestFn.Params(funcx.FnValue)) + switch paramNum { case int(MainSingle), int(MainKv): numMainIn = paramNum diff --git a/sdks/go/pkg/beam/core/graph/fn_test.go b/sdks/go/pkg/beam/core/graph/fn_test.go index d2f88a8a5ce..cf44761d4f3 100644 --- a/sdks/go/pkg/beam/core/graph/fn_test.go +++ b/sdks/go/pkg/beam/core/graph/fn_test.go @@ -190,6 +190,9 @@ func TestNewDoFnSdf(t *testing.T) { }{ {dfn: &GoodSdf{}, main: MainSingle}, {dfn: &GoodSdfKv{}, main: MainKv}, + {dfn: &GoodSdfWContext{}, main: MainSingle}, + {dfn: &GoodSdfKvWContext{}, main: MainKv}, + {dfn: &GoodSdfWErr{}, main: MainSingle}, {dfn: &GoodIgnoreOtherExportedMethods{}, main: MainSingle}, } @@ -987,6 +990,90 @@ func (fn *GoodSdfKv) TruncateRestriction(*RTrackerT, int, int) RestT { return RestT{} } +type GoodSdfWContext struct { + *GoodDoFn +} + +func (fn *GoodSdfWContext) CreateInitialRestriction(context.Context, int) RestT { + return RestT{} +} + +func (fn *GoodSdfWContext) SplitRestriction(context.Context, int, RestT) []RestT { + return []RestT{} +} + +func (fn *GoodSdfWContext) RestrictionSize(context.Context, int, RestT) float64 { + return 0 +} + +func (fn *GoodSdfWContext) CreateTracker(context.Context, RestT) *RTrackerT { + return &RTrackerT{} +} + +func (fn *GoodSdfWContext) ProcessElement(context.Context, *RTrackerT, int) (int, sdf.ProcessContinuation) { + return 0, sdf.StopProcessing() +} + +func (fn *GoodSdfWContext) TruncateRestriction(context.Context, *RTrackerT, int) RestT { + return RestT{} +} + +type GoodSdfKvWContext struct { + *GoodDoFnKv +} + +func (fn *GoodSdfKvWContext) CreateInitialRestriction(context.Context, int, int) RestT { + return RestT{} +} + +func (fn *GoodSdfKvWContext) SplitRestriction(context.Context, int, int, RestT) []RestT { + return []RestT{} +} + +func (fn *GoodSdfKvWContext) RestrictionSize(context.Context, int, int, RestT) float64 { + return 0 +} + +func (fn *GoodSdfKvWContext) CreateTracker(context.Context, RestT) *RTrackerT { + return &RTrackerT{} +} + +func (fn *GoodSdfKvWContext) ProcessElement(context.Context, *RTrackerT, int, int) (int, sdf.ProcessContinuation) { + return 0, sdf.StopProcessing() +} + +func (fn *GoodSdfKvWContext) TruncateRestriction(context.Context, *RTrackerT, int, int) RestT { + return RestT{} +} + +type GoodSdfWErr struct { + *GoodDoFn +} + +func (fn *GoodSdfWErr) CreateInitialRestriction(int) (RestT, error) { + return RestT{}, nil +} + +func (fn *GoodSdfWErr) SplitRestriction(int, RestT) ([]RestT, error) { + return []RestT{}, nil +} + +func (fn *GoodSdfWErr) RestrictionSize(int, RestT) (float64, error) { + return 0, nil +} + +func (fn *GoodSdfWErr) CreateTracker(RestT) (*RTrackerT, error) { + return &RTrackerT{}, nil +} + +func (fn *GoodSdfWErr) ProcessElement(*RTrackerT, int) (int, sdf.ProcessContinuation, error) { + return 0, sdf.StopProcessing(), nil +} + +func (fn *GoodSdfWErr) TruncateRestriction(*RTrackerT, int) (RestT, error) { + return RestT{}, nil +} + type GoodIgnoreOtherExportedMethods struct { *GoodSdf } diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource.go b/sdks/go/pkg/beam/core/runtime/exec/datasource.go index 9c4de0564c8..a6347fc8d0e 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource.go @@ -196,7 +196,7 @@ func (n *DataSource) Process(ctx context.Context) ([]*Checkpoint, error) { // Check if there's a continuation and return residuals // Needs to be done immeadiately after processing to not lose the element. if c := n.getProcessContinuation(); c != nil { - cp, err := n.checkpointThis(c) + cp, err := n.checkpointThis(ctx, c) if err != nil { // Errors during checkpointing should fail a bundle. return nil, err @@ -422,7 +422,7 @@ type Checkpoint struct { // splittable or has not returned a resuming continuation, the function returns an empty // SplitResult, a negative resumption time, and a false boolean to indicate that no split // occurred. -func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, error) { +func (n *DataSource) checkpointThis(ctx context.Context, pc sdf.ProcessContinuation) (*Checkpoint, error) { n.mu.Lock() defer n.mu.Unlock() @@ -435,7 +435,7 @@ func (n *DataSource) checkpointThis(pc sdf.ProcessContinuation) (*Checkpoint, er ow := su.GetOutputWatermark() // Checkpointing is functionally a split at fraction 0.0 - rs, err := su.Checkpoint() + rs, err := su.Checkpoint(ctx) if err != nil { return nil, err } @@ -530,7 +530,7 @@ func (n *DataSource) Split(ctx context.Context, splits []int64, frac float64, bu // Get the output watermark before splitting to avoid accidentally overestimating ow := su.GetOutputWatermark() // Otherwise, perform a sub-element split. - ps, rs, err := su.Split(fr) + ps, rs, err := su.Split(ctx, fr) if err != nil { return SplitResult{}, err } diff --git a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go index 64a37739b24..2da3284f016 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/datasource_test.go @@ -631,7 +631,7 @@ type TestSplittableUnit struct { // Split checks the input fraction for correctness, but otherwise always returns // a successful split. The split elements are just copies of the original. -func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, error) { +func (n *TestSplittableUnit) Split(_ context.Context, f float64) ([]*FullValue, []*FullValue, error) { if f > 1.0 || f < 0.0 { return nil, nil, errors.Errorf("Error") } @@ -639,8 +639,8 @@ func (n *TestSplittableUnit) Split(f float64) ([]*FullValue, []*FullValue, error } // Checkpoint routes through the Split() function to satisfy the interface. -func (n *TestSplittableUnit) Checkpoint() ([]*FullValue, error) { - _, r, err := n.Split(0.0) +func (n *TestSplittableUnit) Checkpoint(ctx context.Context) ([]*FullValue, error) { + _, r, err := n.Split(ctx, 0.0) return r, err } @@ -876,13 +876,13 @@ func TestSplitHelper(t *testing.T) { func TestCheckpointing(t *testing.T) { t.Run("nil", func(t *testing.T) { - cps, err := (&DataSource{}).checkpointThis(nil) + cps, err := (&DataSource{}).checkpointThis(context.Background(), nil) if err != nil { t.Fatalf("checkpointThis() = %v, %v", cps, err) } }) t.Run("Stop", func(t *testing.T) { - cps, err := (&DataSource{}).checkpointThis(sdf.StopProcessing()) + cps, err := (&DataSource{}).checkpointThis(context.Background(), sdf.StopProcessing()) if err != nil { t.Fatalf("checkpointThis() = %v, %v", cps, err) } @@ -899,7 +899,7 @@ func TestCheckpointing(t *testing.T) { }, }, } - cp, err := root.checkpointThis(sdf.ResumeProcessingIn(time.Second * 13)) + cp, err := root.checkpointThis(context.Background(), sdf.ResumeProcessingIn(time.Second*13)) if err != nil { t.Fatalf("checkpointThis() = %v, %v, want nil", cp, err) } diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go index e22496eae6e..1dd3e35dc4d 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go @@ -90,7 +90,11 @@ func (n *PairWithRestriction) StartBundle(ctx context.Context, id string, data D // Timestamps // } func (n *PairWithRestriction) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error { - rest := n.inv.Invoke(elm) + rest, err := n.inv.Invoke(ctx, elm) + if err != nil { + return err + } + output := FullValue{Elm: elm, Elm2: &FullValue{Elm: rest, Elm2: n.iwesInv.Invoke(rest, elm)}, Timestamp: elm.Timestamp, Windows: elm.Windows} return n.Out.ProcessElement(ctx, &output, values...) @@ -195,10 +199,17 @@ func (n *SplitAndSizeRestrictions) ProcessElement(ctx context.Context, elm *Full // the element may not be wrapped in a *FullValue mainElm := convertIfNeeded(elm.Elm, &FullValue{}) - splitRests := n.splitInv.Invoke(mainElm, rest) + splitRests, err := n.splitInv.Invoke(ctx, mainElm, rest) + if err != nil { + return err + } for _, splitRest := range splitRests { - size := n.sizeInv.Invoke(mainElm, splitRest) + size, err := n.sizeInv.Invoke(ctx, mainElm, splitRest) + if err != nil { + return err + } + if size < 0 { err := errors.Errorf("size returned expected to be non-negative but received %v.", size) return errors.WithContextf(err, "%v", n) @@ -325,13 +336,25 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *Full inp = e } rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm - rt := n.ctInv.Invoke(rest) - newRest := n.truncateInv.Invoke(rt, mainElm) + + rt, err := n.ctInv.Invoke(ctx, rest) + if err != nil { + return err + } + + newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm) + if err != nil { + return err + } if newRest == nil { // do not propagate discarded restrictions. return nil } - size := n.sizeInv.Invoke(mainElm, newRest) + + size, err := n.sizeInv.Invoke(ctx, mainElm, newRest) + if err != nil { + return err + } output := &FullValue{} output.Timestamp = elm.Timestamp @@ -476,7 +499,7 @@ func (n *ProcessSizedElementsAndRestrictions) StartBundle(ctx context.Context, i // and processes each element using the underlying ParDo and adding the // restriction tracker to the normal invocation. Sizing information is present // but currently ignored. Output is forwarded to the underlying ParDo's outputs. -func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context, elm *FullValue, values ...ReStream) error { +func (n *ProcessSizedElementsAndRestrictions) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error { if n.PDo.status != Active { err := errors.Errorf("invalid status %v, want Active", n.PDo.status) return errors.WithContextf(err, "%v", n) @@ -520,7 +543,12 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context, // If windows don't need to be exploded (i.e. aren't observed), treat // all windows as one as an optimization. rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm - rt := n.ctInv.Invoke(rest) + + rt, err := n.ctInv.Invoke(ctx, rest) + if err != nil { + return err + } + mainIn.RTracker = rt n.numW = 1 // Even if there's more than one window, treat them as one. @@ -542,7 +570,12 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context, for i := 0; i < n.numW; i++ { rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm - rt := n.ctInv.Invoke(rest) + + rt, err := n.ctInv.Invoke(ctx, rest) + if err != nil { + return err + } + key := &mainIn.Key w := elm.Windows[i] wElm := FullValue{Elm: key.Elm, Elm2: key.Elm2, Timestamp: key.Timestamp, Windows: []typex.Window{w}} @@ -552,7 +585,7 @@ func (n *ProcessSizedElementsAndRestrictions) ProcessElement(_ context.Context, n.elm = elm n.SU <- n // TODO(BEAM-11104): Remove placeholder for ProcessContinuation return. - _, err := n.PDo.processSingleWindow(&MainInput{Key: wElm, Values: mainIn.Values, RTracker: rt}) + _, err = n.PDo.processSingleWindow(&MainInput{Key: wElm, Values: mainIn.Values, RTracker: rt}) if err != nil { <-n.SU return n.PDo.fail(err) @@ -596,13 +629,13 @@ type SplittableUnit interface { // // More than one primary/residual can happen if the split result cannot be // fully represented in just one. - Split(fraction float64) (primaries, residuals []*FullValue, err error) + Split(ctx context.Context, fraction float64) (primaries, residuals []*FullValue, err error) // Checkpoint performs a split at fraction 0.0 of an element that has stopped // processing and has work that needs to be resumed later. This function will // check that the produced primary restriction from the split represents // completed work to avoid data loss and will error if work remains. - Checkpoint() (residuals []*FullValue, err error) + Checkpoint(ctx context.Context) (residuals []*FullValue, err error) // GetProgress returns the fraction of progress the current element has // made in processing. (ex. 0.0 means no progress, and 1.0 means fully @@ -631,7 +664,7 @@ type SplittableUnit interface { // windows need to be taken into account. For implementation details on when // each case occurs and the implementation details, see the documentation for // the singleWindowSplit and multiWindowSplit methods. -func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, []*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) Split(ctx context.Context, f float64) ([]*FullValue, []*FullValue, error) { // Get the watermark state immediately so that we don't overestimate our current watermark. var pWeState any var rWeState any @@ -658,7 +691,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, [] // Split behavior differs depending on whether this is a window-observing // DoFn or not. if len(n.elm.Windows) > 1 { - p, r, err := n.multiWindowSplit(f, pWeState, rWeState) + p, r, err := n.multiWindowSplit(ctx, f, pWeState, rWeState) if err != nil { return nil, nil, addContext(err) } @@ -666,7 +699,7 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, [] } // Not window-observing, or window-observing but only one window. - p, r, err := n.singleWindowSplit(f, pWeState, rWeState) + p, r, err := n.singleWindowSplit(ctx, f, pWeState, rWeState) if err != nil { return nil, nil, addContext(err) } @@ -677,11 +710,11 @@ func (n *ProcessSizedElementsAndRestrictions) Split(f float64) ([]*FullValue, [] // later by the runner. This is done iff the underlying Splittable DoFn returns a resuming // ProcessContinuation. If the split occurs and the primary restriction is marked as done // my the RTracker, the Checkpoint fails as this is a potential data-loss case. -func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) Checkpoint(ctx context.Context) ([]*FullValue, error) { addContext := func(err error) error { return errors.WithContext(err, "Attempting checkpoint in ProcessSizedElementsAndRestrictions") } - _, r, err := n.Split(0.0) + _, r, err := n.Split(ctx, 0.0) if err != nil { return nil, addContext(err) @@ -699,7 +732,7 @@ func (n *ProcessSizedElementsAndRestrictions) Checkpoint() ([]*FullValue, error) // behavior is identical). A single restriction split will occur and all windows // present in the unsplit element will be present in both the resulting primary // and residual. -func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeState, rWeState any) ([]*FullValue, []*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(ctx context.Context, f float64, pWeState, rWeState any) ([]*FullValue, []*FullValue, error) { if n.rt.IsDone() { // Not an error, but not splittable. return []*FullValue{}, []*FullValue{}, nil } @@ -714,14 +747,14 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeSt var primaryResult []*FullValue if p != nil { - pfv, err := n.newSplitResult(p, n.elm.Windows, pWeState) + pfv, err := n.newSplitResult(ctx, p, n.elm.Windows, pWeState) if err != nil { return nil, nil, err } primaryResult = append(primaryResult, pfv) } - rfv, err := n.newSplitResult(r, n.elm.Windows, rWeState) + rfv, err := n.newSplitResult(ctx, r, n.elm.Windows, rWeState) if err != nil { return nil, nil, err } @@ -752,7 +785,7 @@ func (n *ProcessSizedElementsAndRestrictions) singleWindowSplit(f float64, pWeSt // // This method also updates the current number of windows (n.numW) so that // windows in the residual will no longer be processed. -func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(ctx context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) { // Get the split point in window range, to see what window it falls in. done, rem := n.rt.GetProgress() cwp := done / (done + rem) // Progress in current window. @@ -765,25 +798,25 @@ func (n *ProcessSizedElementsAndRestrictions) multiWindowSplit(f float64, pWeSta if n.rt.IsDone() { // Current RTracker is done so we can't split within the window, so // split at window boundary instead. - return n.windowBoundarySplit(n.currW+1, pWeState, rWeState) + return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState) } // Get the fraction of remaining work in the current window to split at. cwsp := wsp - float64(n.currW) // Split point in current window. rf := (cwsp - cwp) / (1 - cwp) // Fraction of work in RTracker to split at. - return n.currentWindowSplit(rf, pWeState, rWeState) + return n.currentWindowSplit(ctx, rf, pWeState, rWeState) } else { // Split at nearest window boundary to split point. wb := math.Round(wsp) - return n.windowBoundarySplit(int(wb), pWeState, rWeState) + return n.windowBoundarySplit(ctx, int(wb), pWeState, rWeState) } } // currentWindowSplit performs an appropriate split at the given fraction of // remaining work in the current window. Also updates numW to stop after the // current window. -func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(ctx context.Context, f float64, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) { p, r, err := n.rt.TrySplit(f) if err != nil { return nil, nil, err @@ -791,18 +824,18 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS if r == nil { // If r is nil then the split failed/returned an empty residual, but // we can still split at a window boundary. - return n.windowBoundarySplit(n.currW+1, pWeState, rWeState) + return n.windowBoundarySplit(ctx, n.currW+1, pWeState, rWeState) } // Split of currently processing restriction in a single window. ps := make([]*FullValue, 1) - newP, err := n.newSplitResult(p, n.elm.Windows[n.currW:n.currW+1], pWeState) + newP, err := n.newSplitResult(ctx, p, n.elm.Windows[n.currW:n.currW+1], pWeState) if err != nil { return nil, nil, err } ps[0] = newP rs := make([]*FullValue, 1) - newR, err := n.newSplitResult(r, n.elm.Windows[n.currW:n.currW+1], rWeState) + newR, err := n.newSplitResult(ctx, r, n.elm.Windows[n.currW:n.currW+1], rWeState) if err != nil { return nil, nil, err } @@ -810,14 +843,14 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS // Window boundary split surrounding the split restriction above. full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm if 0 < n.currW { - newP, err := n.newSplitResult(full, n.elm.Windows[0:n.currW], pWeState) + newP, err := n.newSplitResult(ctx, full, n.elm.Windows[0:n.currW], pWeState) if err != nil { return nil, nil, err } ps = append(ps, newP) } if n.currW+1 < n.numW { - newR, err := n.newSplitResult(full, n.elm.Windows[n.currW+1:n.numW], rWeState) + newR, err := n.newSplitResult(ctx, full, n.elm.Windows[n.currW+1:n.numW], rWeState) if err != nil { return nil, nil, err } @@ -830,17 +863,17 @@ func (n *ProcessSizedElementsAndRestrictions) currentWindowSplit(f float64, pWeS // windowBoundarySplit performs an appropriate split at a window boundary. The // split point taken should be the index of the first window in the residual. // Also updates numW to stop at the split point. -func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(ctx context.Context, splitPt int, pWeState any, rWeState any) ([]*FullValue, []*FullValue, error) { // If this is at the boundary of the last window, split is a no-op. if splitPt == n.numW { return []*FullValue{}, []*FullValue{}, nil } full := n.elm.Elm.(*FullValue).Elm2.(*FullValue).Elm - pFv, err := n.newSplitResult(full, n.elm.Windows[0:splitPt], pWeState) + pFv, err := n.newSplitResult(ctx, full, n.elm.Windows[0:splitPt], pWeState) if err != nil { return nil, nil, err } - rFv, err := n.newSplitResult(full, n.elm.Windows[splitPt:n.numW], rWeState) + rFv, err := n.newSplitResult(ctx, full, n.elm.Windows[splitPt:n.numW], rWeState) if err != nil { return nil, nil, err } @@ -852,18 +885,27 @@ func (n *ProcessSizedElementsAndRestrictions) windowBoundarySplit(splitPt int, p // element restriction pair based on the currently processing element, but with // a modified restriction and windows. Intended for creating primaries and // residuals to return as split results. -func (n *ProcessSizedElementsAndRestrictions) newSplitResult(rest any, w []typex.Window, weState any) (*FullValue, error) { +func (n *ProcessSizedElementsAndRestrictions) newSplitResult(ctx context.Context, rest any, w []typex.Window, weState any) (*FullValue, error) { var size float64 + var err error elm := n.elm.Elm.(*FullValue).Elm if fv, ok := elm.(*FullValue); ok { - size = n.sizeInv.Invoke(fv, rest) + size, err = n.sizeInv.Invoke(ctx, fv, rest) + if err != nil { + return nil, err + } + if size < 0 { err := errors.Errorf("size returned expected to be non-negative but received %v.", size) return nil, errors.WithContextf(err, "%v", n) } } else { fv := &FullValue{Elm: elm} - size = n.sizeInv.Invoke(fv, rest) + size, err = n.sizeInv.Invoke(ctx, fv, rest) + if err != nil { + return nil, err + } + if size < 0 { err := errors.Errorf("size returned expected to be non-negative but received %v.", size) return nil, errors.WithContextf(err, "%v", n) @@ -973,21 +1015,33 @@ func (n *SdfFallback) StartBundle(ctx context.Context, id string, data DataConte // restrictions, and then creating restriction trackers and processing each // restriction with the underlying ParDo. This executor skips the sizing step // because sizing information is unnecessary for unexpanded SDFs. -func (n *SdfFallback) ProcessElement(_ context.Context, elm *FullValue, values ...ReStream) error { +func (n *SdfFallback) ProcessElement(ctx context.Context, elm *FullValue, values ...ReStream) error { if n.PDo.status != Active { err := errors.Errorf("invalid status %v, want Active", n.PDo.status) return errors.WithContextf(err, "%v", n) } - rest := n.initRestInv.Invoke(elm) - splitRests := n.splitInv.Invoke(elm, rest) + rest, err := n.initRestInv.Invoke(ctx, elm) + if err != nil { + return err + } + + splitRests, err := n.splitInv.Invoke(ctx, elm, rest) + if err != nil { + return err + } + if len(splitRests) == 0 { err := errors.Errorf("initial splitting returned 0 restrictions.") return errors.WithContextf(err, "%v", n) } for _, splitRest := range splitRests { - rt := n.trackerInv.Invoke(splitRest) + rt, err := n.trackerInv.Invoke(ctx, splitRest) + if err != nil { + return err + } + mainIn := &MainInput{ Key: *elm, Values: values, diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go index 5d58f198a64..2dd894ed08c 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers.go @@ -16,6 +16,7 @@ package exec import ( + "context" "reflect" "github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx" @@ -24,6 +25,9 @@ import ( "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" ) +//go:generate specialize --input=sdf_invokers_arity.tmpl +//go:generate gofmt -w sdf_invokers_arity.go + // This file contains invokers for SDF methods. These invokers are based off // exec.invoker which is used for regular DoFns. Since exec.invoker is // specialized for DoFns it cannot be used for SDF methods. Instead, these @@ -39,9 +43,10 @@ import ( // cirInvoker is an invoker for CreateInitialRestriction. type cirInvoker struct { - fn *funcx.Fn - args []any // Cache to avoid allocating new slices per-element. - call func(elms *FullValue) (rest any) + fn *funcx.Fn + args []any // Cache to avoid allocating new slices per-element. + ctxIdx int + call func() (rest any, err error) } func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) { @@ -49,51 +54,34 @@ func newCreateInitialRestrictionInvoker(fn *funcx.Fn) (*cirInvoker, error) { fn: fn, args: make([]any, len(fn.Param)), } + + var ok bool + if n.ctxIdx, ok = fn.Context(); !ok { + n.ctxIdx = -1 + } + if err := n.initCallFn(); err != nil { return nil, errors.WithContext(err, "sdf CreateInitialRestriction invoker") } return n, nil } -func (n *cirInvoker) initCallFn() error { - // Expects a signature of the form: - // (key?, value) restriction - // TODO(BEAM-9643): Link to full documentation. - switch fnT := n.fn.Fn.(type) { - case reflectx.Func1x1: - n.call = func(elms *FullValue) any { - return fnT.Call1x1(elms.Elm) - } - case reflectx.Func2x1: - n.call = func(elms *FullValue) any { - return fnT.Call2x1(elms.Elm, elms.Elm2) - } - default: - switch len(n.fn.Param) { - case 1: - n.call = func(elms *FullValue) any { - n.args[0] = elms.Elm - return n.fn.Fn.Call(n.args)[0] - } - case 2: - n.call = func(elms *FullValue) any { - n.args[0] = elms.Elm - n.args[1] = elms.Elm2 - return n.fn.Fn.Call(n.args)[0] - } - default: - return errors.Errorf("CreateInitialRestriction fn %v has unexpected number of parameters: %v", - n.fn.Fn.Name(), len(n.fn.Param)) - } +// Invoke calls CreateInitialRestriction with the given FullValue as the element +// and returns the resulting restriction. +func (n *cirInvoker) Invoke(ctx context.Context, elms *FullValue) (rest any, err error) { + if n.ctxIdx >= 0 { + n.args[n.ctxIdx] = ctx } - return nil -} + i := n.ctxIdx + 1 + n.args[i] = elms.Elm -// Invoke calls CreateInitialRestriction with the given FullValue as the element -// and returns the resulting restriction. -func (n *cirInvoker) Invoke(elms *FullValue) (rest any) { - return n.call(elms) + if elms.Elm2 != nil { + i++ + n.args[i] = elms.Elm2 + } + + return n.call() } // Reset zeroes argument entries in the cached slice to allow values to be @@ -106,9 +94,10 @@ func (n *cirInvoker) Reset() { // srInvoker is an invoker for SplitRestriction. type srInvoker struct { - fn *funcx.Fn - args []any // Cache to avoid allocating new slices per-element. - call func(elms *FullValue, rest any) (splits any) + fn *funcx.Fn + args []any // Cache to avoid allocating new slices per-element. + ctxIdx int + call func() (splits any, err error) } func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) { @@ -116,52 +105,40 @@ func newSplitRestrictionInvoker(fn *funcx.Fn) (*srInvoker, error) { fn: fn, args: make([]any, len(fn.Param)), } + + var ok bool + if n.ctxIdx, ok = fn.Context(); !ok { + n.ctxIdx = -1 + } + if err := n.initCallFn(); err != nil { return nil, errors.WithContext(err, "sdf SplitRestriction invoker") } return n, nil } -func (n *srInvoker) initCallFn() error { - // Expects a signature of the form: - // (key?, value, restriction) []restriction - // TODO(BEAM-9643): Link to full documentation. - switch fnT := n.fn.Fn.(type) { - case reflectx.Func2x1: - n.call = func(elms *FullValue, rest any) any { - return fnT.Call2x1(elms.Elm, rest) - } - case reflectx.Func3x1: - n.call = func(elms *FullValue, rest any) any { - return fnT.Call3x1(elms.Elm, elms.Elm2, rest) - } - default: - switch len(n.fn.Param) { - case 2: - n.call = func(elms *FullValue, rest any) any { - n.args[0] = elms.Elm - n.args[1] = rest - return n.fn.Fn.Call(n.args)[0] - } - case 3: - n.call = func(elms *FullValue, rest any) any { - n.args[0] = elms.Elm - n.args[1] = elms.Elm2 - n.args[2] = rest - return n.fn.Fn.Call(n.args)[0] - } - default: - return errors.Errorf("SplitRestriction fn %v has unexpected number of parameters: %v", - n.fn.Fn.Name(), len(n.fn.Param)) - } - } - return nil -} - // Invoke calls SplitRestriction given a FullValue containing an element and // the associated restriction, and returns a slice of split restrictions. -func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) { - ret := n.call(elms, rest) +func (n *srInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) (splits []any, err error) { + if n.ctxIdx >= 0 { + n.args[n.ctxIdx] = ctx + } + + i := n.ctxIdx + 1 + n.args[i] = elms.Elm + + if elms.Elm2 != nil { + i++ + n.args[i] = elms.Elm2 + } + + i++ + n.args[i] = rest + + ret, err := n.call() + if err != nil { + return nil, err + } // Return value is an any, but we need to convert it to a []any. val := reflect.ValueOf(ret) @@ -169,7 +146,7 @@ func (n *srInvoker) Invoke(elms *FullValue, rest any) (splits []any) { for i := 0; i < val.Len(); i++ { s = append(s, val.Index(i).Interface()) } - return s + return s, nil } // Reset zeroes argument entries in the cached slice to allow values to be @@ -182,9 +159,10 @@ func (n *srInvoker) Reset() { // rsInvoker is an invoker for RestrictionSize. type rsInvoker struct { - fn *funcx.Fn - args []any // Cache to avoid allocating new slices per-element. - call func(elms *FullValue, rest any) (size float64) + fn *funcx.Fn + args []any // Cache to avoid allocating new slices per-element. + ctxIdx int + call func() (size float64, err error) } func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) { @@ -192,52 +170,37 @@ func newRestrictionSizeInvoker(fn *funcx.Fn) (*rsInvoker, error) { fn: fn, args: make([]any, len(fn.Param)), } + + var ok bool + if n.ctxIdx, ok = fn.Context(); !ok { + n.ctxIdx = -1 + } + if err := n.initCallFn(); err != nil { return nil, errors.WithContext(err, "sdf RestrictionSize invoker") } return n, nil } -func (n *rsInvoker) initCallFn() error { - // Expects a signature of the form: - // (key?, value, restriction) float64 - // TODO(BEAM-9643): Link to full documentation. - switch fnT := n.fn.Fn.(type) { - case reflectx.Func2x1: - n.call = func(elms *FullValue, rest any) float64 { - return fnT.Call2x1(elms.Elm, rest).(float64) - } - case reflectx.Func3x1: - n.call = func(elms *FullValue, rest any) float64 { - return fnT.Call3x1(elms.Elm, elms.Elm2, rest).(float64) - } - default: - switch len(n.fn.Param) { - case 2: - n.call = func(elms *FullValue, rest any) float64 { - n.args[0] = elms.Elm - n.args[1] = rest - return n.fn.Fn.Call(n.args)[0].(float64) - } - case 3: - n.call = func(elms *FullValue, rest any) float64 { - n.args[0] = elms.Elm - n.args[1] = elms.Elm2 - n.args[2] = rest - return n.fn.Fn.Call(n.args)[0].(float64) - } - default: - return errors.Errorf("RestrictionSize fn %v has unexpected number of parameters: %v", - n.fn.Fn.Name(), len(n.fn.Param)) - } - } - return nil -} - // Invoke calls RestrictionSize given a FullValue containing an element and // the associated restriction, and returns a size. -func (n *rsInvoker) Invoke(elms *FullValue, rest any) (size float64) { - return n.call(elms, rest) +func (n *rsInvoker) Invoke(ctx context.Context, elms *FullValue, rest any) (size float64, err error) { + if n.ctxIdx >= 0 { + n.args[n.ctxIdx] = ctx + } + + i := n.ctxIdx + 1 + n.args[i] = elms.Elm + + if elms.Elm2 != nil { + i++ + n.args[i] = elms.Elm2 + } + + i++ + n.args[i] = rest + + return n.call() } // Reset zeroes argument entries in the cached slice to allow values to be @@ -250,9 +213,10 @@ func (n *rsInvoker) Reset() { // ctInvoker is an invoker for CreateTracker. type ctInvoker struct { - fn *funcx.Fn - args []any // Cache to avoid allocating new slices per-element. - call func(rest any) sdf.RTracker + fn *funcx.Fn + args []any // Cache to avoid allocating new slices per-element. + ctxIdx int + call func() (rt sdf.RTracker, err error) } func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) { @@ -260,37 +224,27 @@ func newCreateTrackerInvoker(fn *funcx.Fn) (*ctInvoker, error) { fn: fn, args: make([]any, len(fn.Param)), } + + var ok bool + if n.ctxIdx, ok = fn.Context(); !ok { + n.ctxIdx = -1 + } + if err := n.initCallFn(); err != nil { return nil, errors.WithContext(err, "sdf CreateTracker invoker") } return n, nil } -func (n *ctInvoker) initCallFn() error { - // Expects a signature of the form: - // (restriction) sdf.RTracker - // TODO(BEAM-9643): Link to full documentation. - switch fnT := n.fn.Fn.(type) { - case reflectx.Func1x1: - n.call = func(rest any) sdf.RTracker { - return fnT.Call1x1(rest).(sdf.RTracker) - } - default: - if len(n.fn.Param) != 1 { - return errors.Errorf("CreateTracker fn %v has unexpected number of parameters: %v", - n.fn.Fn.Name(), len(n.fn.Param)) - } - n.call = func(rest any) sdf.RTracker { - n.args[0] = rest - return n.fn.Fn.Call(n.args)[0].(sdf.RTracker) - } +// Invoke calls CreateTracker given a restriction and returns an sdf.RTracker. +func (n *ctInvoker) Invoke(ctx context.Context, rest any) (sdf.RTracker, error) { + if n.ctxIdx >= 0 { + n.args[n.ctxIdx] = ctx } - return nil -} -// Invoke calls CreateTracker given a restriction and returns an sdf.RTracker. -func (n *ctInvoker) Invoke(rest any) sdf.RTracker { - return n.call(rest) + n.args[n.ctxIdx+1] = rest + + return n.call() } // Reset zeroes argument entries in the cached slice to allow values to be @@ -303,9 +257,10 @@ func (n *ctInvoker) Reset() { // trInvoker is an invoker for TruncateRestriction. type trInvoker struct { - fn *funcx.Fn - args []any - call func(rest any, elms *FullValue) (pair any) + fn *funcx.Fn + args []any + ctxIdx int + call func() (rest any, err error) } func defaultTruncateRestriction(restTracker any) (newRest any) { @@ -320,6 +275,12 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) { fn: fn, args: make([]any, len(fn.Param)), } + + var ok bool + if n.ctxIdx, ok = fn.Context(); !ok { + n.ctxIdx = -1 + } + if err := n.initCallFn(); err != nil { return nil, errors.WithContext(err, "sdf TruncateRestriction invoker") } @@ -327,53 +288,39 @@ func newTruncateRestrictionInvoker(fn *funcx.Fn) (*trInvoker, error) { } func newDefaultTruncateRestrictionInvoker() (*trInvoker, error) { - n := &trInvoker{} - n.call = func(rest any, elms *FullValue) any { - return defaultTruncateRestriction(rest) + n := &trInvoker{ + args: make([]any, 1), } - return n, nil -} - -func (n *trInvoker) initCallFn() error { - // Expects a signature of the form: - // (key?, value, restriction) []restriction - // TODO(BEAM-9643): Link to full documentation. - switch fnT := n.fn.Fn.(type) { - case reflectx.Func2x1: - n.call = func(rest any, elms *FullValue) any { - return fnT.Call2x1(rest, elms.Elm) - } - case reflectx.Func3x1: - n.call = func(rest any, elms *FullValue) any { - return fnT.Call3x1(rest, elms.Elm, elms.Elm2) - } - default: - switch len(n.fn.Param) { - case 2: - n.call = func(rest any, elms *FullValue) any { - n.args[0] = rest - n.args[1] = elms.Elm - return n.fn.Fn.Call(n.args)[0] - } - case 3: - n.call = func(rest any, elms *FullValue) any { - n.args[0] = rest - n.args[1] = elms.Elm - n.args[2] = elms.Elm2 - return n.fn.Fn.Call(n.args)[0] - } - default: - return errors.Errorf("TruncateRestriction fn %v has unexpected number of parameters: %v", - n.fn.Fn.Name(), len(n.fn.Param)) - } + n.call = func() (any, error) { + return defaultTruncateRestriction(n.args[0]), nil } - return nil + return n, nil } // Invoke calls TruncateRestriction given a FullValue containing an element and // the associated restriction tracker, and returns a truncated restriction. -func (n *trInvoker) Invoke(rt any, elms *FullValue) (rest any) { - return n.call(rt, elms) +func (n *trInvoker) Invoke(ctx context.Context, rt any, elms *FullValue) (rest any, err error) { + if n.fn == nil { + n.args[0] = rt + return n.call() + } + + if n.ctxIdx >= 0 { + n.args[n.ctxIdx] = ctx + } + + i := n.ctxIdx + 1 + n.args[i] = rt + + i++ + n.args[i] = elms.Elm + + if elms.Elm2 != nil { + i++ + n.args[i] = elms.Elm2 + } + + return n.call() } // Reset zeroes argument entries in the cached slice to allow values to be @@ -589,3 +536,11 @@ func (n *wesInvoker) Reset() { n.args[i] = nil } } + +func asError(val any) error { + if val != nil { + return val.(error) + } + + return nil +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go new file mode 100644 index 00000000000..cdefa711603 --- /dev/null +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.go @@ -0,0 +1,336 @@ +// File generated by specialize. Do not edit. + +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT. + +package exec + +import ( + "fmt" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" +) + +func (n *cirInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, key?, value) (restriction, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { + + case reflectx.Func1x1: + n.call = func() (rest any, err error) { + r0 := fnT.Call1x1(n.args[0]) + return r0, nil + } + + case reflectx.Func2x1: + n.call = func() (rest any, err error) { + r0 := fnT.Call2x1(n.args[0], n.args[1]) + return r0, nil + } + + case reflectx.Func3x1: + n.call = func() (rest any, err error) { + r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2]) + return r0, nil + } + + case reflectx.Func1x2: + n.call = func() (rest any, err error) { + r0, r1 := fnT.Call1x2(n.args[0]) + return r0, asError(r1) + } + + case reflectx.Func2x2: + n.call = func() (rest any, err error) { + r0, r1 := fnT.Call2x2(n.args[0], n.args[1]) + return r0, asError(r1) + } + + case reflectx.Func3x2: + n.call = func() (rest any, err error) { + r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2]) + return r0, asError(r1) + } + + default: + if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 { + return errors.Errorf("CreateInitialRestriction has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (rest any, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0], nil + case 2: + return ret[0], asError(ret[1]) + } + + panic(fmt.Sprintf("CreateInitialRestriction has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *srInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, key?, value, restriction) ([]restriction, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { + + case reflectx.Func2x1: + n.call = func() (splits any, err error) { + r0 := fnT.Call2x1(n.args[0], n.args[1]) + return r0, nil + } + + case reflectx.Func3x1: + n.call = func() (splits any, err error) { + r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2]) + return r0, nil + } + + case reflectx.Func4x1: + n.call = func() (splits any, err error) { + r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3]) + return r0, nil + } + + case reflectx.Func2x2: + n.call = func() (splits any, err error) { + r0, r1 := fnT.Call2x2(n.args[0], n.args[1]) + return r0, asError(r1) + } + + case reflectx.Func3x2: + n.call = func() (splits any, err error) { + r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2]) + return r0, asError(r1) + } + + case reflectx.Func4x2: + n.call = func() (splits any, err error) { + r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3]) + return r0, asError(r1) + } + + default: + if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 { + return errors.Errorf("SplitRestriction has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (splits any, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0], nil + case 2: + return ret[0], asError(ret[1]) + } + + panic(fmt.Sprintf("SplitRestriction has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *rsInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, key?, value, restriction) (float64, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { + + case reflectx.Func2x1: + n.call = func() (size float64, err error) { + r0 := fnT.Call2x1(n.args[0], n.args[1]) + return r0.(float64), nil + } + + case reflectx.Func3x1: + n.call = func() (size float64, err error) { + r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2]) + return r0.(float64), nil + } + + case reflectx.Func4x1: + n.call = func() (size float64, err error) { + r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3]) + return r0.(float64), nil + } + + case reflectx.Func2x2: + n.call = func() (size float64, err error) { + r0, r1 := fnT.Call2x2(n.args[0], n.args[1]) + return r0.(float64), asError(r1) + } + + case reflectx.Func3x2: + n.call = func() (size float64, err error) { + r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2]) + return r0.(float64), asError(r1) + } + + case reflectx.Func4x2: + n.call = func() (size float64, err error) { + r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3]) + return r0.(float64), asError(r1) + } + + default: + if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 { + return errors.Errorf("RestrictionSize has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (size float64, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0].(float64), nil + case 2: + return ret[0].(float64), asError(ret[1]) + } + + panic(fmt.Sprintf("RestrictionSize has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *ctInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, restriction) (sdf.RTracker, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { + + case reflectx.Func1x1: + n.call = func() (rt sdf.RTracker, err error) { + r0 := fnT.Call1x1(n.args[0]) + return r0.(sdf.RTracker), nil + } + + case reflectx.Func2x1: + n.call = func() (rt sdf.RTracker, err error) { + r0 := fnT.Call2x1(n.args[0], n.args[1]) + return r0.(sdf.RTracker), nil + } + + case reflectx.Func1x2: + n.call = func() (rt sdf.RTracker, err error) { + r0, r1 := fnT.Call1x2(n.args[0]) + return r0.(sdf.RTracker), asError(r1) + } + + case reflectx.Func2x2: + n.call = func() (rt sdf.RTracker, err error) { + r0, r1 := fnT.Call2x2(n.args[0], n.args[1]) + return r0.(sdf.RTracker), asError(r1) + } + + default: + if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 { + return errors.Errorf("CreateTracker has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (rt sdf.RTracker, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0].(sdf.RTracker), nil + case 2: + return ret[0].(sdf.RTracker), asError(ret[1]) + } + + panic(fmt.Sprintf("CreateTracker has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *trInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, sdf.RTracker, key?, value) (restriction, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { + + case reflectx.Func2x1: + n.call = func() (rest any, err error) { + r0 := fnT.Call2x1(n.args[0], n.args[1]) + return r0, nil + } + + case reflectx.Func3x1: + n.call = func() (rest any, err error) { + r0 := fnT.Call3x1(n.args[0], n.args[1], n.args[2]) + return r0, nil + } + + case reflectx.Func4x1: + n.call = func() (rest any, err error) { + r0 := fnT.Call4x1(n.args[0], n.args[1], n.args[2], n.args[3]) + return r0, nil + } + + case reflectx.Func2x2: + n.call = func() (rest any, err error) { + r0, r1 := fnT.Call2x2(n.args[0], n.args[1]) + return r0, asError(r1) + } + + case reflectx.Func3x2: + n.call = func() (rest any, err error) { + r0, r1 := fnT.Call3x2(n.args[0], n.args[1], n.args[2]) + return r0, asError(r1) + } + + case reflectx.Func4x2: + n.call = func() (rest any, err error) { + r0, r1 := fnT.Call4x2(n.args[0], n.args[1], n.args[2], n.args[3]) + return r0, asError(r1) + } + + default: + if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 { + return errors.Errorf("TruncateRestriction has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (rest any, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0], nil + case 2: + return ret[0], asError(ret[1]) + } + + panic(fmt.Sprintf("TruncateRestriction has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl new file mode 100644 index 00000000000..7df994be0c7 --- /dev/null +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_arity.tmpl @@ -0,0 +1,246 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Code generated from sdf_invokers_arity.tmpl. DO NOT EDIT. + +package exec + +import ( + "fmt" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf" + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" +) + +func (n *cirInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, key?, value) (restriction, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { +{{range $out := upto 3}} +{{range $in := upto 4}} + {{if gt $out 0}} + {{if gt $in 0}} + case reflectx.Func{{$in}}x{{$out}}: + n.call = func() (rest any, err error) { + {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}}) + {{- if eq $out 1}} + return r0, nil + {{- else}} + return r0, asError(r1) + {{- end}} + } + {{end}} + {{end}} +{{end}} +{{end}} + default: + if len(n.fn.Param) < 1 || len(n.fn.Param) > 3 { + return errors.Errorf("CreateInitialRestriction has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (rest any, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0], nil + case 2: + return ret[0], asError(ret[1]) + } + + panic(fmt.Sprintf("CreateInitialRestriction has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *srInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, key?, value, restriction) ([]restriction, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { +{{range $out := upto 3}} +{{range $in := upto 5}} + {{if gt $out 0}} + {{if gt $in 1}} + case reflectx.Func{{$in}}x{{$out}}: + n.call = func() (splits any, err error) { + {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}}) + {{- if eq $out 1}} + return r0, nil + {{- else}} + return r0, asError(r1) + {{- end}} + } + {{end}} + {{end}} +{{end}} +{{end}} + default: + if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 { + return errors.Errorf("SplitRestriction has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (splits any, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0], nil + case 2: + return ret[0], asError(ret[1]) + } + + panic(fmt.Sprintf("SplitRestriction has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *rsInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, key?, value, restriction) (float64, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { +{{range $out := upto 3}} +{{range $in := upto 5}} + {{if gt $out 0}} + {{if gt $in 1}} + case reflectx.Func{{$in}}x{{$out}}: + n.call = func() (size float64, err error) { + {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}}) + {{- if eq $out 1}} + return r0.(float64), nil + {{- else}} + return r0.(float64), asError(r1) + {{- end}} + } + {{end}} + {{end}} +{{end}} +{{end}} + default: + if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 { + return errors.Errorf("RestrictionSize has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (size float64, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0].(float64), nil + case 2: + return ret[0].(float64), asError(ret[1]) + } + + panic(fmt.Sprintf("RestrictionSize has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *ctInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, restriction) (sdf.RTracker, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { +{{range $out := upto 3}} +{{range $in := upto 3}} + {{if gt $out 0}} + {{if gt $in 0}} + case reflectx.Func{{$in}}x{{$out}}: + n.call = func() (rt sdf.RTracker, err error) { + {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}}) + {{- if eq $out 1}} + return r0.(sdf.RTracker), nil + {{- else}} + return r0.(sdf.RTracker), asError(r1) + {{- end}} + } + {{end}} + {{end}} +{{end}} +{{end}} + default: + if len(n.fn.Param) < 1 || len(n.fn.Param) > 2 { + return errors.Errorf("CreateTracker has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (rt sdf.RTracker, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0].(sdf.RTracker), nil + case 2: + return ret[0].(sdf.RTracker), asError(ret[1]) + } + + panic(fmt.Sprintf("CreateTracker has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} + +func (n *trInvoker) initCallFn() error { + // Expects a signature of the form: + // (context.Context?, sdf.RTracker, key?, value) (restriction, error?) + // TODO(BEAM-9643): Link to full documentation. + switch fnT := n.fn.Fn.(type) { +{{range $out := upto 3}} +{{range $in := upto 5}} + {{if gt $out 0}} + {{if gt $in 1}} + case reflectx.Func{{$in}}x{{$out}}: + n.call = func() (rest any, err error) { + {{mktuplef $out "r%v"}} := fnT.Call{{$in}}x{{$out}}({{mktuplef $in "n.args[%v]"}}) + {{- if eq $out 1}} + return r0, nil + {{- else}} + return r0, asError(r1) + {{- end}} + } + {{end}} + {{end}} +{{end}} +{{end}} + default: + if len(n.fn.Param) < 2 || len(n.fn.Param) > 4 { + return errors.Errorf("TruncateRestriction has unexpected number of parameters: %v", len(n.fn.Param)) + } + + n.call = func() (rest any, err error) { + ret := n.fn.Fn.Call(n.args) + + switch len(ret) { + case 1: + return ret[0], nil + case 2: + return ret[0], asError(ret[1]) + } + + panic(fmt.Sprintf("TruncateRestriction has unexpected number of return values: %v", len(ret))) + } + } + + return nil +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go index e308071aaf9..edef16e51d7 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_invokers_test.go @@ -16,6 +16,8 @@ package exec import ( + "context" + "errors" "testing" "time" @@ -48,13 +50,59 @@ func TestInvokes(t *testing.T) { } statefulWeFn := (*graph.SplittableDoFn)(dfn) + initialRestErrDfn, err := graph.NewDoFn( + &VetCreateInitialRestrictionErrSdf{}, + graph.NumMainInputs(graph.MainSingle), + ) + if err != nil { + t.Fatalf("invalid function: %v", err) + } + initialRestErrSdf := (*graph.SplittableDoFn)(initialRestErrDfn) + + splitRestErrDfn, err := graph.NewDoFn( + &VetSplitRestrictionErrSdf{}, + graph.NumMainInputs(graph.MainSingle), + ) + if err != nil { + t.Fatalf("invalid function: %v", err) + } + splitRestErrSdf := (*graph.SplittableDoFn)(splitRestErrDfn) + + restSizeErrDfn, err := graph.NewDoFn( + &VetRestrictionSizeErrSdf{}, + graph.NumMainInputs(graph.MainSingle), + ) + if err != nil { + t.Fatalf("invalid function: %v", err) + } + restSizeErrSdf := (*graph.SplittableDoFn)(restSizeErrDfn) + + trackerErrDfn, err := graph.NewDoFn( + &VetCreateTrackerErrSdf{}, + graph.NumMainInputs(graph.MainSingle), + ) + if err != nil { + t.Fatalf("invalid function: %v", err) + } + trackerErrSdf := (*graph.SplittableDoFn)(trackerErrDfn) + + truncateRestErrDfn, err := graph.NewDoFn( + &VetTruncateRestrictionErrSdf{}, + graph.NumMainInputs(graph.MainSingle), + ) + if err != nil { + t.Fatalf("invalid function: %v", err) + } + truncateRestErrSdf := (*graph.SplittableDoFn)(truncateRestErrDfn) + // Tests. t.Run("CreateInitialRestriction Invoker (cirInvoker)", func(t *testing.T) { tests := []struct { - name string - sdf *graph.SplittableDoFn - elms *FullValue - want *VetRestriction + name string + sdf *graph.SplittableDoFn + elms *FullValue + want *VetRestriction + wantErr bool }{ { name: "SingleElem", @@ -68,6 +116,12 @@ func TestInvokes(t *testing.T) { elms: &FullValue{Elm: 1, Elm2: 2}, want: &VetRestriction{ID: "KvSdf", CreateRest: true, Key: 1, Val: 2}, }, + { + name: "Error", + sdf: initialRestErrSdf, + elms: &FullValue{Elm: 1}, + wantErr: true, + }, } for _, test := range tests { test := test @@ -77,7 +131,15 @@ func TestInvokes(t *testing.T) { if err != nil { t.Fatalf("newCreateInitialRestrictionInvoker failed: %v", err) } - got := invoker.Invoke(test.elms) + + got, err := invoker.Invoke(context.Background(), test.elms) + if (err != nil) != test.wantErr { + t.Fatalf("Invoke(%v) error = %v, wantErr %v", test.elms, err, test.wantErr) + } + if test.wantErr { + return + } + if !cmp.Equal(got, test.want) { t.Errorf("Invoke(%v) has incorrect output: got: %v, want: %v", test.elms, got, test.want) @@ -94,11 +156,12 @@ func TestInvokes(t *testing.T) { t.Run("SplitRestriction Invoker (srInvoker)", func(t *testing.T) { tests := []struct { - name string - sdf *graph.SplittableDoFn - elms *FullValue - rest *VetRestriction - want []any + name string + sdf *graph.SplittableDoFn + elms *FullValue + rest *VetRestriction + want []any + wantErr bool }{ { name: "SingleElem", @@ -119,6 +182,13 @@ func TestInvokes(t *testing.T) { &VetRestriction{ID: "KvSdf.2", SplitRest: true, Key: 1, Val: 2}, }, }, + { + name: "Error", + sdf: splitRestErrSdf, + elms: &FullValue{Elm: 1}, + rest: &VetRestriction{ID: "Sdf"}, + wantErr: true, + }, } for _, test := range tests { test := test @@ -128,8 +198,16 @@ func TestInvokes(t *testing.T) { if err != nil { t.Fatalf("newSplitRestrictionInvoker failed: %v", err) } + rest := *test.rest // Create a copy because our test SDF edits the restriction. - got := invoker.Invoke(test.elms, &rest) + got, err := invoker.Invoke(context.Background(), test.elms, &rest) + if (err != nil) != test.wantErr { + t.Fatalf("Invoke(%v, %v) error = %v, wantErr %v", test.elms, test.rest, err, test.wantErr) + } + if test.wantErr { + return + } + if !cmp.Equal(got, test.want) { t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v", test.elms, test.rest, got, test.want) @@ -152,6 +230,7 @@ func TestInvokes(t *testing.T) { rest *VetRestriction want float64 restWant *VetRestriction + wantErr bool }{ { name: "SingleElem", @@ -168,6 +247,13 @@ func TestInvokes(t *testing.T) { want: 3, restWant: &VetRestriction{ID: "KvSdf", RestSize: true, Key: 1, Val: 2}, }, + { + name: "Error", + sdf: restSizeErrSdf, + elms: &FullValue{Elm: 1}, + rest: &VetRestriction{ID: "Sdf"}, + wantErr: true, + }, } for _, test := range tests { test := test @@ -178,7 +264,15 @@ func TestInvokes(t *testing.T) { t.Fatalf("newRestrictionSizeInvoker failed: %v", err) } rest := *test.rest // Create a copy because our test SDF edits the restriction. - got := invoker.Invoke(test.elms, &rest) + + got, err := invoker.Invoke(context.Background(), test.elms, &rest) + if (err != nil) != test.wantErr { + t.Fatalf("Invoke(%v, %v) error = %v, wantErr %v", test.elms, test.rest, err, test.wantErr) + } + if test.wantErr { + return + } + if !cmp.Equal(got, test.want) { t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v", test.elms, test.rest, got, test.want) @@ -199,10 +293,11 @@ func TestInvokes(t *testing.T) { t.Run("CreateTracker Invoker (ctInvoker)", func(t *testing.T) { tests := []struct { - name string - sdf *graph.SplittableDoFn - rest *VetRestriction - want *VetRTracker + name string + sdf *graph.SplittableDoFn + rest *VetRestriction + want *VetRTracker + wantErr bool }{ { name: "SingleElem", @@ -215,6 +310,12 @@ func TestInvokes(t *testing.T) { rest: &VetRestriction{ID: "KvSdf"}, want: &VetRTracker{&VetRestriction{ID: "KvSdf", CreateTracker: true}}, }, + { + name: "Error", + sdf: trackerErrSdf, + rest: &VetRestriction{ID: "Sdf"}, + wantErr: true, + }, } for _, test := range tests { test := test @@ -224,7 +325,15 @@ func TestInvokes(t *testing.T) { if err != nil { t.Fatalf("newCreateTrackerInvoker failed: %v", err) } - got := invoker.Invoke(test.rest) + + got, err := invoker.Invoke(context.Background(), test.rest) + if (err != nil) != test.wantErr { + t.Fatalf("Invoke(%v) error = %v, wantErr %v", test.rest, err, test.wantErr) + } + if test.wantErr { + return + } + if !cmp.Equal(got, test.want) { t.Errorf("Invoke(%v) has incorrect output: got: %v, want: %v", test.rest, got, test.want) @@ -310,11 +419,12 @@ func TestInvokes(t *testing.T) { t.Run("TruncateRestriction Invoker (trInvoker)", func(t *testing.T) { tests := []struct { - name string - sdf *graph.SplittableDoFn - elms *FullValue - rest *VetRestriction - want any + name string + sdf *graph.SplittableDoFn + elms *FullValue + rest *VetRestriction + want any + wantErr bool }{ { name: "SingleElem", @@ -329,6 +439,13 @@ func TestInvokes(t *testing.T) { rest: &VetRestriction{ID: "KvSdf"}, want: &VetRestriction{ID: "KvSdf", CreateTracker: true, TruncateRest: true, RestSize: true, Key: 1, Val: 2}, }, + { + name: "Error", + sdf: truncateRestErrSdf, + elms: &FullValue{Elm: 1}, + rest: &VetRestriction{ID: "Sdf"}, + wantErr: true, + }, } for _, test := range tests { test := test @@ -336,24 +453,38 @@ func TestInvokes(t *testing.T) { ctFn := test.sdf.CreateTrackerFn() rsFn := test.sdf.RestrictionSizeFn() t.Run(test.name, func(t *testing.T) { + ctx := context.Background() rest := test.rest // Create a copy because our test SDF edits the restriction. ctInvoker, err := newCreateTrackerInvoker(ctFn) if err != nil { t.Fatalf("newCreateTrackerInvoker failed: %v", err) } - rt := ctInvoker.Invoke(rest) + rt, err := ctInvoker.Invoke(ctx, rest) + if err != nil { + t.Fatalf("ctInvoker.Invoke(%v) failed: %v", rest, err) + } trInvoker, err := newTruncateRestrictionInvoker(fn) if err != nil { t.Fatalf("newTruncateRestrictionInvoker failed: %v", err) } - trRest := trInvoker.Invoke(rt, test.elms) + + trRest, err := trInvoker.Invoke(ctx, rt, test.elms) + if (err != nil) != test.wantErr { + t.Fatalf("trInvoker.Invoke(%v, %v) = %v, wantErr %v", rt, test.elms, err, test.wantErr) + } + if test.wantErr { + return + } rsInvoker, err := newRestrictionSizeInvoker(rsFn) if err != nil { t.Fatalf("newRestrictionSizeInvoker failed: %v", err) } - _ = rsInvoker.Invoke(test.elms, trRest) + if _, err := rsInvoker.Invoke(ctx, test.elms, trRest); err != nil { + t.Fatalf("rsInvoker.Invoke(%v, %v) failed: %v", test.elms, trRest, err) + } + if !cmp.Equal(trRest, test.want) { t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v", test.elms, test.rest, trRest, test.want) @@ -411,24 +542,35 @@ func TestInvokes(t *testing.T) { ctFn := test.sdf.CreateTrackerFn() rsFn := test.sdf.RestrictionSizeFn() t.Run(test.name, func(t *testing.T) { + ctx := context.Background() rest := test.rest // Create a copy because our test SDF edits the restriction. ctInvoker, err := newCreateTrackerInvoker(ctFn) if err != nil { t.Fatalf("newCreateTrackerInvoker failed: %v", err) } - rt := ctInvoker.Invoke(rest) + rt, err := ctInvoker.Invoke(ctx, rest) + if err != nil { + t.Fatalf("ctInvoker.Invoke(%v) failed: %v", rest, err) + } trInvoker, err := newDefaultTruncateRestrictionInvoker() if err != nil { t.Fatalf("newTruncateRestrictionInvoker failed: %v", err) } - trRest := trInvoker.Invoke(rt, test.elms) + trRest, err := trInvoker.Invoke(ctx, rt, test.elms) + if err != nil { + t.Fatalf("trInvoker.Invoke(%v, %v) failed: %v", rt, test.elms, err) + } + if trRest != nil { rsInvoker, err := newRestrictionSizeInvoker(rsFn) if err != nil { t.Fatalf("newRestrictionSizeInvoker failed: %v", err) } - _ = rsInvoker.Invoke(test.elms, trRest) + if _, err := rsInvoker.Invoke(ctx, test.elms, trRest); err != nil { + t.Fatalf("rsInvoker.Invoke(%v, %v) failed: %v", test.elms, trRest, err) + } + if !cmp.Equal(trRest, test.want) { t.Errorf("Invoke(%v, %v) has incorrect output: got: %v, want: %v", test.elms, test.rest, trRest, test.want) @@ -732,3 +874,55 @@ func (fn *VetEmptyInitialSplitSdf) ProcessElement(rt *VetRTracker, i int, emit f emit(rest) return sdf.ResumeProcessingIn(1 * time.Second) } + +var errSdf = errors.New("SDF error") + +// VetCreateInitialRestrictionErrSdf is an SDF with a CreateInitialRestriction method +// that returns a non-nil error. +type VetCreateInitialRestrictionErrSdf struct { + VetSdf +} + +func (fn *VetCreateInitialRestrictionErrSdf) CreateInitialRestriction(i int) (*VetRestriction, error) { + return nil, errSdf +} + +// VetSplitRestrictionErrSdf is an SDF with a SplitRestriction method +// that returns a non-nil error. +type VetSplitRestrictionErrSdf struct { + VetSdf +} + +func (fn *VetSplitRestrictionErrSdf) SplitRestriction(int, *VetRestriction) ([]*VetRestriction, error) { + return nil, errSdf +} + +// VetRestrictionSizeErrSdf is an SDF with a RestrictionSize method +// that returns a non-nil error. +type VetRestrictionSizeErrSdf struct { + VetSdf +} + +func (fn *VetRestrictionSizeErrSdf) RestrictionSize(int, *VetRestriction) (float64, error) { + return -1, errSdf +} + +// VetCreateTrackerErrSdf is an SDF with a CreateTracker method +// that returns a non-nil error. +type VetCreateTrackerErrSdf struct { + VetSdf +} + +func (fn *VetCreateTrackerErrSdf) CreateTracker(*VetRestriction) (*VetRTracker, error) { + return nil, errSdf +} + +// VetTruncateRestrictionErrSdf is an SDF with a TruncateRestriction method +// that returns a non-nil error. +type VetTruncateRestrictionErrSdf struct { + VetSdf +} + +func (fn *VetTruncateRestrictionErrSdf) TruncateRestriction(*VetRTracker, int) (*VetRestriction, error) { + return nil, errSdf +} diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go index 414f28553a8..a0380796e86 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/sdf_test.go @@ -1114,10 +1114,11 @@ func TestAsSplittableUnit(t *testing.T) { // Call from SplittableUnit and check results. su := SplittableUnit(node) - if err := node.Up(context.Background()); err != nil { + ctx := context.Background() + if err := node.Up(ctx); err != nil { t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err) } - gotPrimaries, gotResiduals, err := su.Split(test.frac) + gotPrimaries, gotResiduals, err := su.Split(ctx, test.frac) if err != nil { t.Fatalf("SplittableUnit.Split(%v) failed: %v", test.frac, err) } @@ -1184,10 +1185,11 @@ func TestAsSplittableUnit(t *testing.T) { // Call from SplittableUnit and check results. su := SplittableUnit(node) - if err := node.Up(context.Background()); err != nil { + ctx := context.Background() + if err := node.Up(ctx); err != nil { t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err) } - _, _, err := su.Split(0.5) + _, _, err := su.Split(ctx, 0.5) if err == nil { t.Errorf("SplittableUnit.Split(%v) was expected to fail.", test.in) } @@ -1251,10 +1253,11 @@ func TestAsSplittableUnit(t *testing.T) { node.currW = 0 // Call from SplittableUnit and check results. su := SplittableUnit(node) - if err := node.Up(context.Background()); err != nil { + ctx := context.Background() + if err := node.Up(ctx); err != nil { t.Fatalf("ProcessSizedElementsAndRestrictions.Up() failed: %v", err) } - gotResiduals, err := su.Checkpoint() + gotResiduals, err := su.Checkpoint(ctx) if err != nil { t.Fatalf("SplittableUnit.Checkpoint() returned error, got %v", err) @@ -1401,7 +1404,7 @@ func TestMultiWindowProcessing(t *testing.T) { // Split should hit window boundary between 2 and 3. We don't need to check // the split result here, just the effects it has on currW and numW. frac := 0.5 - if _, _, err := su.Split(frac); err != nil { + if _, _, err := su.Split(context.Background(), frac); err != nil { t.Errorf("Split(%v) failed with error: %v", frac, err) } if got, want := node.currW, blockW; got != want { diff --git a/sdks/go/pkg/beam/runners/prism/internal/config/config.go b/sdks/go/pkg/beam/runners/prism/internal/config/config.go index fc2b68d092f..9c3bdd012bc 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/config/config.go +++ b/sdks/go/pkg/beam/runners/prism/internal/config/config.go @@ -244,4 +244,4 @@ func (r *HandlerRegistry) GetVariant(name string) *Variant { return nil } return &Variant{parent: r, name: name, handlers: vs.Handlers} -} \ No newline at end of file +} diff --git a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go index 3c7cae97397..7b553f6ad65 100644 --- a/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go +++ b/sdks/go/pkg/beam/runners/prism/internal/urns/urns_test.go @@ -33,4 +33,4 @@ func Test_toUrn(t *testing.T) { if got := quickUrn(pipepb.StandardPTransforms_PAR_DO); got != want { t.Errorf("quickUrn(\"pipepb.StandardPTransforms_PAR_DO\") = %v, want %v", got, want) } -} \ No newline at end of file +}