This is an automated email from the ASF dual-hosted git repository.
kenn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 630f32ada54 [Drain] Expose drain to dofn processElement and onTimer
(#37825)
630f32ada54 is described below
commit 630f32ada545fd22812c428e5543e4a89c8075ca
Author: Radosław Stankiewicz <[email protected]>
AuthorDate: Thu Mar 12 15:36:23 2026 +0100
[Drain] Expose drain to dofn processElement and onTimer (#37825)
---
.../apache/beam/runners/core/SimpleDoFnRunner.java | 15 ++++++++++
.../reflect/ByteBuddyDoFnInvokerFactory.java | 11 +++++++
.../beam/sdk/transforms/reflect/DoFnInvoker.java | 15 ++++++++++
.../beam/sdk/transforms/reflect/DoFnSignature.java | 26 +++++++++++++++++
.../sdk/transforms/reflect/DoFnSignatures.java | 10 +++++++
.../construction/SplittableParDoNaiveBounded.java | 5 ++++
.../sdk/transforms/reflect/DoFnSignaturesTest.java | 34 ++++++++++++++++++++--
.../apache/beam/fn/harness/FnApiDoFnRunner.java | 10 +++++++
8 files changed, 124 insertions(+), 2 deletions(-)
diff --git
a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
index a255467fc59..74f5a4d0900 100644
---
a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
+++
b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java
@@ -555,6 +555,11 @@ public class SimpleDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Out
return timestamp();
}
+ /** Supplies the {@link CausedByDrain} marker carried by the element currently being processed. */
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ // The marker travels with the windowed element itself.
+ return elem.causedByDrain();
+ }
+
@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException(
@@ -831,6 +836,11 @@ public class SimpleDoFnRunner<InputT, OutputT> implements
DoFnRunner<InputT, Out
return timestamp();
}
+ /** Supplies the {@link CausedByDrain} value held by this timer argument provider. */
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ final CausedByDrain timerDrainMarker = causedByDrain;
+ return timerDrainMarker;
+ }
+
@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
return timerId;
@@ -1119,6 +1129,11 @@ public class SimpleDoFnRunner<InputT, OutputT>
implements DoFnRunner<InputT, Out
return timestamp;
}
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ throw new UnsupportedOperationException("CausedByDrain parameters are
not supported.");
+ }
+
@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException("Timer parameters are not
supported.");
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
index 780eb0075db..54d630d92fe 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java
@@ -76,6 +76,7 @@ import
org.apache.beam.sdk.transforms.DoFn.TruncateRestriction;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.Cases;
+import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.CausedByDrainParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OnTimerContextParameter;
@@ -126,6 +127,7 @@ class ByteBuddyDoFnInvokerFactory implements
DoFnInvokerFactory {
public static final String ELEMENT_PARAMETER_METHOD = "element";
public static final String SCHEMA_ELEMENT_PARAMETER_METHOD = "schemaElement";
public static final String TIMESTAMP_PARAMETER_METHOD = "timestamp";
+ // Name of the ArgumentProvider method invoked to obtain a CausedByDrain argument.
+ public static final String CAUSED_BY_DRAIN_PARAMETER_METHOD = "causedByDrain";
public static final String BUNDLE_FINALIZER_PARAMETER_METHOD =
"bundleFinalizer";
public static final String OUTPUT_ROW_RECEIVER_METHOD = "outputRowReceiver";
public static final String TIME_DOMAIN_PARAMETER_METHOD = "timeDomain";
@@ -1100,6 +1102,15 @@ class ByteBuddyDoFnInvokerFactory implements
DoFnInvokerFactory {
TIMESTAMP_PARAMETER_METHOD, DoFn.class)));
}
+ /**
+ * Loads a CausedByDrain argument by calling the extra-context method named by {@link
+ * #CAUSED_BY_DRAIN_PARAMETER_METHOD}, using the same helper as the bundle finalizer case.
+ */
+ @Override
+ public StackManipulation dispatch(CausedByDrainParameter p) {
+ return simpleExtraContextParameter(CAUSED_BY_DRAIN_PARAMETER_METHOD);
+ }
+
@Override
public StackManipulation dispatch(BundleFinalizerParameter p) {
return
simpleExtraContextParameter(BUNDLE_FINALIZER_PARAMETER_METHOD);
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
index 0079435700c..a615761292a 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java
@@ -41,6 +41,7 @@ import
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker.Truncate
import org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
+import org.apache.beam.sdk.values.CausedByDrain;
import org.apache.beam.sdk.values.Row;
import
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting;
import org.checkerframework.checker.nullness.qual.Nullable;
@@ -217,6 +218,9 @@ public interface DoFnInvoker<InputT, OutputT> {
/** Provide a reference to the input element timestamp. */
Instant timestamp(DoFn<InputT, OutputT> doFn);
+ /**
+ * Provide a reference to the {@link CausedByDrain} marker of the input element or timer
+ * currently being processed.
+ */
+ CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn);
+
/** Provide a reference to the time domain for a timer firing. */
TimeDomain timeDomain(DoFn<InputT, OutputT> doFn);
@@ -325,6 +329,12 @@ public interface DoFnInvoker<InputT, OutputT> {
String.format("Timestamp unsupported in %s", getErrorContext()));
}
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ throw new UnsupportedOperationException(
+ String.format("CausedByDrain unsupported in %s", getErrorContext()));
+ }
+
@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException(
@@ -514,6 +524,11 @@ public interface DoFnInvoker<InputT, OutputT> {
return delegate.timestamp(doFn);
}
+ /** Forwards the request unchanged to the wrapped {@code delegate}. */
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ final CausedByDrain forwarded = delegate.causedByDrain(doFn);
+ return forwarded;
+ }
+
@Override
public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
return delegate.timeDomain(doFn);
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
index 8f254642f08..af0353c902a 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java
@@ -342,6 +342,8 @@ public abstract class DoFnSignature {
return cases.dispatch((TimerIdParameter) this);
} else if (this instanceof BundleFinalizerParameter) {
return cases.dispatch((BundleFinalizerParameter) this);
+ } else if (this instanceof CausedByDrainParameter) {
+ return cases.dispatch((CausedByDrainParameter) this);
} else if (this instanceof KeyParameter) {
return cases.dispatch((KeyParameter) this);
} else {
@@ -400,6 +402,8 @@ public abstract class DoFnSignature {
ResultT dispatch(BundleFinalizerParameter p);
+ /** Visits a {@link CausedByDrainParameter}. */
+ ResultT dispatch(CausedByDrainParameter p);
+
ResultT dispatch(KeyParameter p);
/** A base class for a visitor with a default method for cases it is not
interested in. */
@@ -497,6 +501,11 @@ public abstract class DoFnSignature {
return dispatchDefault(p);
}
+ /** CausedByDrain parameters are handled by the {@code dispatchDefault} fallback. */
+ @Override
+ public ResultT dispatch(CausedByDrainParameter p) {
+ final ResultT fallback = dispatchDefault(p);
+ return fallback;
+ }
+
@Override
public ResultT dispatch(StateParameter p) {
return dispatchDefault(p);
@@ -552,6 +561,8 @@ public abstract class DoFnSignature {
new AutoValue_DoFnSignature_Parameter_PipelineOptionsParameter();
private static final BundleFinalizerParameter BUNDLE_FINALIZER_PARAMETER =
new AutoValue_DoFnSignature_Parameter_BundleFinalizerParameter();
+ // CausedByDrainParameter declares no properties, so a single shared instance is sufficient.
+ private static final CausedByDrainParameter CAUSED_BY_DRAIN_PARAMETER =
+ new AutoValue_DoFnSignature_Parameter_CausedByDrainParameter();
private static final OnWindowExpirationContextParameter
ON_WINDOW_EXPIRATION_CONTEXT_PARAMETER =
new
AutoValue_DoFnSignature_Parameter_OnWindowExpirationContextParameter();
@@ -575,6 +586,11 @@ public abstract class DoFnSignature {
return BUNDLE_FINALIZER_PARAMETER;
}
+ /**
+ * Returns the descriptor used for {@code CausedByDrain} method parameters.
+ *
+ * <p>The same shared instance is returned on every call.
+ */
+ public static CausedByDrainParameter causedByDrainParameter() {
+ return CAUSED_BY_DRAIN_PARAMETER;
+ }
+
public static ElementParameter elementParameter(TypeDescriptor<?>
elementT) {
return new AutoValue_DoFnSignature_Parameter_ElementParameter(elementT);
}
@@ -727,6 +743,16 @@ public abstract class DoFnSignature {
BundleFinalizerParameter() {}
}
+ /**
+ * Describes a {@link Parameter} whose declared type is {@link
+ * org.apache.beam.sdk.values.CausedByDrain}.
+ *
+ * <p>The descriptor has no properties, so every instance is equal to every other.
+ */
+ @AutoValue
+ public abstract static class CausedByDrainParameter extends Parameter {
+ CausedByDrainParameter() {}
+ }
+
/**
* Descriptor for a {@link Parameter} of type {@link DoFn.Element}.
*
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
index c39edccd58f..3dcf7ff1f9d 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java
@@ -91,6 +91,7 @@ import
org.apache.beam.sdk.transforms.splittabledofn.WatermarkEstimator;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.common.ReflectHelpers;
+import org.apache.beam.sdk.values.CausedByDrain;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
@@ -139,6 +140,7 @@ public class DoFnSignatures {
Parameter.StateParameter.class,
Parameter.SideInputParameter.class,
Parameter.TimerFamilyParameter.class,
+ Parameter.CausedByDrainParameter.class,
Parameter.BundleFinalizerParameter.class);
private static final ImmutableList<Class<? extends Parameter>>
@@ -155,6 +157,7 @@ public class DoFnSignatures {
Parameter.RestrictionTrackerParameter.class,
Parameter.WatermarkEstimatorParameter.class,
Parameter.SideInputParameter.class,
+ Parameter.CausedByDrainParameter.class,
Parameter.BundleFinalizerParameter.class);
private static final ImmutableList<Class<? extends Parameter>>
ALLOWED_SETUP_PARAMETERS =
@@ -185,6 +188,7 @@ public class DoFnSignatures {
Parameter.StateParameter.class,
Parameter.TimerFamilyParameter.class,
Parameter.TimerIdParameter.class,
+ Parameter.CausedByDrainParameter.class,
Parameter.KeyParameter.class);
private static final ImmutableList<Class<? extends Parameter>>
@@ -201,6 +205,7 @@ public class DoFnSignatures {
Parameter.StateParameter.class,
Parameter.TimerFamilyParameter.class,
Parameter.TimerIdParameter.class,
+ Parameter.CausedByDrainParameter.class,
Parameter.KeyParameter.class);
private static final Collection<Class<? extends Parameter>>
@@ -1357,6 +1362,11 @@ public class DoFnSignatures {
return Parameter.keyT(paramT);
} else if (rawType.equals(TimeDomain.class)) {
return Parameter.timeDomainParameter();
+ } else if (CausedByDrain.class.isAssignableFrom(rawType)) {
+ methodErrors.checkArgument(
+ rawType.equals(CausedByDrain.class),
+ "CausedByDrain argument must have type
org.apache.beam.sdk.values.CausedByDrain.");
+ return Parameter.causedByDrainParameter();
} else if (hasAnnotation(DoFn.SideInput.class, param.getAnnotations())) {
String sideInputId = getSideInputId(param.getAnnotations());
paramErrors.checkArgument(
diff --git
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
index a22d3378cfd..6d058b3b6ad 100644
---
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
+++
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/SplittableParDoNaiveBounded.java
@@ -543,6 +543,11 @@ public class SplittableParDoNaiveBounded {
return outerContext.timestamp();
}
+ /** Exposes the CausedByDrain marker reported by {@code outerContext}. */
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ final CausedByDrain outerMarker = outerContext.causedByDrain();
+ return outerMarker;
+ }
+
@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException();
diff --git
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
index de4a622e03d..3369e18519b 100644
---
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
+++
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java
@@ -56,6 +56,7 @@ import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Sum;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.BundleFinalizerParameter;
+import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.CausedByDrainParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ElementParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.FinishBundleContextParameter;
import
org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter;
@@ -78,6 +79,7 @@ import
org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
+import org.apache.beam.sdk.values.CausedByDrain;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.sdk.values.TypeDescriptor;
@@ -130,10 +132,11 @@ public class DoFnSignaturesTest {
PipelineOptions options,
@SideInput("tag1") String input1,
@SideInput("tag2") Integer input2,
- BundleFinalizer bundleFinalizer) {}
+ BundleFinalizer bundleFinalizer,
+ CausedByDrain causedByDrain) {}
}.getClass());
- assertThat(sig.processElement().extraParameters().size(), equalTo(9));
+ assertThat(sig.processElement().extraParameters().size(), equalTo(10));
assertThat(sig.processElement().extraParameters().get(0),
instanceOf(ElementParameter.class));
assertThat(sig.processElement().extraParameters().get(1),
instanceOf(TimestampParameter.class));
assertThat(sig.processElement().extraParameters().get(2),
instanceOf(WindowParameter.class));
@@ -146,6 +149,8 @@ public class DoFnSignaturesTest {
assertThat(sig.processElement().extraParameters().get(7),
instanceOf(SideInputParameter.class));
assertThat(
sig.processElement().extraParameters().get(8),
instanceOf(BundleFinalizerParameter.class));
+ assertThat(
+ sig.processElement().extraParameters().get(9),
instanceOf(CausedByDrainParameter.class));
}
@Test
@@ -585,6 +590,31 @@ public class DoFnSignaturesTest {
instanceOf(WindowParameter.class));
}
+ /** An {@code @OnTimer} method may declare a {@link CausedByDrain} argument. */
+ @Test
+ public void testCausedByDrainOnTimer() throws Exception {
+ final String timerId = "some-timer-id";
+ final String timerDeclarationId = TimerDeclaration.PREFIX + timerId;
+
+ DoFnSignature sig =
+ DoFnSignatures.getSignature(
+ new DoFn<String, String>() {
+
+ @TimerId(timerId)
+ private final TimerSpec myfield1 = TimerSpecs.timer(TimeDomain.EVENT_TIME);
+
+ @ProcessElement
+ public void process(ProcessContext c) {}
+
+ @OnTimer(timerId)
+ public void onTimer(CausedByDrain causedByDrain) {}
+ }.getClass());
+
+ // The CausedByDrain argument must be the one and only extra parameter of the timer method.
+ DoFnSignature.OnTimerMethod onTimerMethod = sig.onTimerMethods().get(timerDeclarationId);
+ assertThat(onTimerMethod.extraParameters().size(), equalTo(1));
+ assertThat(onTimerMethod.extraParameters().get(0), instanceOf(CausedByDrainParameter.class));
+ }
+
@Test
public void testAllParamsOnTimer() throws Exception {
final String timerId = "some-timer-id";
diff --git
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
index 3893c0f405e..1dfa336e35f 100644
---
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
+++
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java
@@ -1804,6 +1804,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
outputTo(consumer, WindowedValues.of(output, timestamp, windows,
paneInfo));
}
+ /** Reports the CausedByDrain marker of {@code currentElement}. */
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ final CausedByDrain elementMarker = currentElement.causedByDrain();
+ return elementMarker;
+ }
+
@Override
public State state(String stateId, boolean alwaysFetched) {
StateDeclaration stateDeclaration =
doFnSignature.stateDeclarations().get(stateId);
@@ -1946,6 +1951,11 @@ public class FnApiDoFnRunner<InputT, RestrictionT,
PositionT, WatermarkEstimator
public CausedByDrain causedByDrain() {
return currentElement.causedByDrain();
}
+
+ /**
+ * Argument-provider form of the no-arg {@code causedByDrain()} accessor in this class.
+ * It delegates so that both accessors always report the same value.
+ */
+ @Override
+ public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
+ return causedByDrain();
+ }
}
/** Provides base arguments for a {@link DoFnInvoker} for a non-window
observing method. */