This is an automated email from the ASF dual-hosted git repository.

lostluck pushed a commit to branch drainFix
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 2acb711d735503ec2bc7bbfefb10e75615547bbb
Author: lostluck <13907733+lostl...@users.noreply.github.com>
AuthorDate: Mon Mar 20 15:14:36 2023 -0700

    Ensure truncate element is wrapped in *FullValue
---
 sdks/go/pkg/beam/core/runtime/exec/sdf.go    | 27 +++++++++++----------------
 sdks/go/test/integration/primitives/drain.go | 24 +++++++++++++-----------
 2 files changed, 24 insertions(+), 27 deletions(-)

diff --git a/sdks/go/pkg/beam/core/runtime/exec/sdf.go b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
index 1dd3e35dc4d..6482bfb3a6a 100644
--- a/sdks/go/pkg/beam/core/runtime/exec/sdf.go
+++ b/sdks/go/pkg/beam/core/runtime/exec/sdf.go
@@ -297,10 +297,10 @@ func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, d
 // Input Diagram:
 //
 //       *FullValue {
-//         Elm: *FullValue {
-//           Elm:  *FullValue (original input)
+//         Elm: *FullValue {  -- mainElm
+//           Elm:  *FullValue (original input)  -- inp
 //           Elm2: *FullValue {
-//                    Elm: Restriction
+//                    Elm: Restriction  -- rest
 //                    Elm2: Watermark estimator state
 //           }
 //         }
@@ -325,24 +325,19 @@ func (n *TruncateSizedRestriction) StartBundle(ctx context.Context, id string, d
 //        }
 func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm 
*FullValue, values ...ReStream) error {
        mainElm := elm.Elm.(*FullValue)
-       inp := mainElm.Elm
-       // For the main element, the way we fill it out depends on whether the 
input element
-       // is a KV or single-element. Single-elements might have been lifted 
out of
-       // their FullValue if they were decoded, so we need to have a case for 
that.
-       // TODO(https://github.com/apache/beam/issues/20196): Optimize this so 
it's decided in exec/translate.go
-       // instead of checking per-element.
-       if e, ok := mainElm.Elm.(*FullValue); ok {
-               mainElm = e
-               inp = e
-       }
-       rest := elm.Elm.(*FullValue).Elm2.(*FullValue).Elm
+
+       // If receiving directly from a datasource,
+       // the element may not be wrapped in a *FullValue
+       inp := convertIfNeeded(mainElm.Elm, &FullValue{})
+
+       rest := mainElm.Elm2.(*FullValue).Elm
 
        rt, err := n.ctInv.Invoke(ctx, rest)
        if err != nil {
                return err
        }
 
-       newRest, err := n.truncateInv.Invoke(ctx, rt, mainElm)
+       newRest, err := n.truncateInv.Invoke(ctx, rt, inp)
        if err != nil {
                return err
        }
@@ -351,7 +346,7 @@ func (n *TruncateSizedRestriction) ProcessElement(ctx context.Context, elm *Full
                return nil
        }
 
-       size, err := n.sizeInv.Invoke(ctx, mainElm, newRest)
+       size, err := n.sizeInv.Invoke(ctx, inp, newRest)
        if err != nil {
                return err
        }
diff --git a/sdks/go/test/integration/primitives/drain.go b/sdks/go/test/integration/primitives/drain.go
index 2e861f54615..d116dfa8bd3 100644
--- a/sdks/go/test/integration/primitives/drain.go
+++ b/sdks/go/test/integration/primitives/drain.go
@@ -28,7 +28,7 @@ import (
 )
 
 func init() {
-       register.DoFn3x1[*sdf.LockRTracker, []byte, func(int64), 
sdf.ProcessContinuation](&TruncateFn{})
+       register.DoFn4x1[context.Context, *sdf.LockRTracker, []byte, 
func(int64), sdf.ProcessContinuation](&TruncateFn{})
 
        register.Emitter1[int64]()
 }
@@ -83,9 +83,14 @@ func (fn *TruncateFn) SplitRestriction(_ []byte, rest offsetrange.Restriction) [
 }
 
 // TruncateRestriction truncates the restriction during drain.
-func (fn *TruncateFn) TruncateRestriction(rt *sdf.LockRTracker, _ []byte) 
offsetrange.Restriction {
-       start := rt.GetRestriction().(offsetrange.Restriction).Start
+func (fn *TruncateFn) TruncateRestriction(ctx context.Context, rt 
*sdf.LockRTracker, _ []byte) offsetrange.Restriction {
+       rest := rt.GetRestriction().(offsetrange.Restriction)
+       start := rest.Start
        newEnd := start + 20
+
+       done, remaining := rt.GetProgress()
+       log.Infof(ctx, "Draining at: done %v, remaining %v, start %v, end %v, 
newEnd %v", done, remaining, start, rest.End, newEnd)
+
        return offsetrange.Restriction{
                Start: start,
                End:   newEnd,
@@ -93,29 +98,26 @@ func (fn *TruncateFn) TruncateRestriction(rt *sdf.LockRTracker, _ []byte) offset
 }
 
 // ProcessElement continually gets the start position of the restriction and 
emits the element as it is.
-func (fn *TruncateFn) ProcessElement(rt *sdf.LockRTracker, _ []byte, emit 
func(int64)) sdf.ProcessContinuation {
+func (fn *TruncateFn) ProcessElement(ctx context.Context, rt 
*sdf.LockRTracker, _ []byte, emit func(int64)) sdf.ProcessContinuation {
        position := rt.GetRestriction().(offsetrange.Restriction).Start
-       counter := 0
        for {
                if rt.TryClaim(position) {
+                       log.Infof(ctx, "Claimed position: %v", position)
                        // Successful claim, emit the value and move on.
                        emit(position)
                        position++
-                       counter++
                } else if rt.GetError() != nil || rt.IsDone() {
                        // Stop processing on error or completion
                        if err := rt.GetError(); err != nil {
-                               log.Errorf(context.Background(), "error in 
restriction tracker, got %v", err)
+                               log.Errorf(ctx, "error in restriction tracker, 
got %v", err)
                        }
+                       log.Infof(ctx, "Restriction done at position %v.", 
position)
                        return sdf.StopProcessing()
                } else {
+                       log.Infof(ctx, "Checkpointed at position %v, resuming 
later.", position)
                        // Resume later.
                        return sdf.ResumeProcessingIn(5 * time.Second)
                }
-
-               if counter >= 10 {
-                       return sdf.ResumeProcessingIn(1 * time.Second)
-               }
                time.Sleep(1 * time.Second)
        }
 }

Reply via email to