This is an automated email from the ASF dual-hosted git repository. martijnvisser pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 03b45844228 [FLINK-34337][Core] Sink.InitContextWrapper should implement metadataConsumer method. This closes #24249 03b45844228 is described below commit 03b4584422826d2819d571871dfef4efced19f01 Author: Jiabao Sun <jiabao....@xtransfer.cn> AuthorDate: Wed Feb 7 03:19:43 2024 +0800 [FLINK-34337][Core] Sink.InitContextWrapper should implement metadataConsumer method. This closes #24249 * Sink.InitContextWrapper should implement metadataConsumer method * Add test for InitContextWrapper --- .../org/apache/flink/api/connector/sink2/Sink.java | 6 ++ .../operators/sink/SinkWriterOperatorTestBase.java | 94 ++++++++++++++++++++++ 2 files changed, 100 insertions(+) diff --git a/flink-core/src/main/java/org/apache/flink/api/connector/sink2/Sink.java b/flink-core/src/main/java/org/apache/flink/api/connector/sink2/Sink.java index 7558fb2fa8e..49d3601c40c 100644 --- a/flink-core/src/main/java/org/apache/flink/api/connector/sink2/Sink.java +++ b/flink-core/src/main/java/org/apache/flink/api/connector/sink2/Sink.java @@ -218,5 +218,11 @@ public interface Sink<InputT> extends Serializable { public <IN> TypeSerializer<IN> createInputSerializer() { return wrapped.createInputSerializer(); } + + @Experimental + @Override + public <MetaT> Optional<Consumer<MetaT>> metadataConsumer() { + return wrapped.metadataConsumer(); + } } } diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/SinkWriterOperatorTestBase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/SinkWriterOperatorTestBase.java index c4627605f71..debe699c3a8 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/SinkWriterOperatorTestBase.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/sink/SinkWriterOperatorTestBase.java @@ -26,6 +26,8 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.api.common.typeutils.base.array.BytePrimitiveArraySerializer; import org.apache.flink.api.connector.sink2.Sink; +import org.apache.flink.api.connector.sink2.SinkWriter; +import org.apache.flink.api.connector.sink2.WriterInitContext; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.core.io.SimpleVersionedSerialization; import org.apache.flink.core.io.SimpleVersionedSerializer; @@ -55,13 +57,17 @@ import org.junit.jupiter.params.provider.ValueSource; import javax.annotation.Nullable; import java.io.IOException; +import java.lang.reflect.Proxy; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Consumer; import java.util.function.LongSupplier; import java.util.function.Supplier; @@ -425,6 +431,94 @@ abstract class SinkWriterOperatorTestBase { testHarness.close(); } + @Test + void testInitContextWrapper() throws Exception { + final AtomicReference<Sink.InitContext> initContext = new AtomicReference<>(); + final AtomicReference<WriterInitContext> originalContext = new AtomicReference<>(); + final AtomicBoolean consumed = new AtomicBoolean(false); + final Consumer<AtomicBoolean> metadataConsumer = element -> element.set(true); + + final Sink<String> sink = + new Sink<String>() { + @Override + public SinkWriter<String> createWriter(WriterInitContext context) + throws IOException { + WriterInitContext decoratedContext = + (WriterInitContext) + Proxy.newProxyInstance( + WriterInitContext.class.getClassLoader(), + new Class[] {WriterInitContext.class}, + (proxy, method, args) -> { + if (method.getName() + .equals("metadataConsumer")) { + return Optional.of(metadataConsumer); + } + return method.invoke(context, args); + }); + originalContext.set(decoratedContext); + return Sink.super.createWriter(decoratedContext); + } + + @Override + public SinkWriter<String> createWriter(InitContext context) { + initContext.set(context); + return null; + } + }; + + final int subtaskId = 1; + final int parallelism = 10; + final TypeSerializer<String> typeSerializer = StringSerializer.INSTANCE; + final JobID jobID = new JobID(); + + final MockEnvironment environment = + MockEnvironment.builder() + .setSubtaskIndex(subtaskId) + .setParallelism(parallelism) + .setMaxParallelism(parallelism) + .setJobID(jobID) + .setExecutionConfig(new ExecutionConfig().enableObjectReuse()) + .build(); + + final OneInputStreamOperatorTestHarness<String, CommittableMessage<String>> testHarness = + new OneInputStreamOperatorTestHarness<>( + new SinkWriterOperatorFactory<>(sink), typeSerializer, environment); + testHarness.open(); + + assertContextsEqual(initContext.get(), originalContext.get()); + assertThat(initContext.get().metadataConsumer()) + .isPresent() + .hasValueSatisfying( + consumer -> { + consumer.accept(consumed); + assertThat(consumed).isTrue(); + }); + + testHarness.close(); + } + + private static void assertContextsEqual( + Sink.InitContext initContext, WriterInitContext original) { + assertThat(initContext.getUserCodeClassLoader().asClassLoader()) + .isEqualTo(original.getUserCodeClassLoader().asClassLoader()); + assertThat(initContext.getMailboxExecutor()).isEqualTo(original.getMailboxExecutor()); + assertThat(initContext.getProcessingTimeService()) + .isEqualTo(original.getProcessingTimeService()); + assertThat(initContext.getTaskInfo().getIndexOfThisSubtask()) + .isEqualTo(original.getTaskInfo().getIndexOfThisSubtask()); + assertThat(initContext.getTaskInfo().getNumberOfParallelSubtasks()) + .isEqualTo(original.getTaskInfo().getNumberOfParallelSubtasks()); + assertThat(initContext.getTaskInfo().getAttemptNumber()) + .isEqualTo(original.getTaskInfo().getAttemptNumber()); + assertThat(initContext.metricGroup()).isEqualTo(original.metricGroup()); + assertThat(initContext.getRestoredCheckpointId()) + .isEqualTo(original.getRestoredCheckpointId()); + assertThat(initContext.isObjectReuseEnabled()).isEqualTo(original.isObjectReuseEnabled()); + assertThat(initContext.createInputSerializer()).isEqualTo(original.createInputSerializer()); + assertThat(initContext.getJobInfo().getJobId()).isEqualTo(original.getJobInfo().getJobId()); + assertThat(initContext.metadataConsumer()).isEqualTo(original.metadataConsumer()); + } + @SuppressWarnings("unchecked") private static void assertRestoredCommitterCommittable(Object record, String committable) { assertThat(record)