This is an automated email from the ASF dual-hosted git repository. boyuanz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new e8eed0b [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped. new d2e1f69 Merge pull request #13710 from [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped e8eed0b is described below commit e8eed0bf70c334fe59327f0d70453302935410ee Author: Boyuan Zhang <boyu...@google.com> AuthorDate: Fri Jan 8 13:43:06 2021 -0800 [BEAM-11325] ReadFromKafkaDoFn should stop reading when topic/partition is removed or marked as stopped. --- .../java/org/apache/beam/sdk/io/kafka/KafkaIO.java | 36 ++- .../beam/sdk/io/kafka/ReadFromKafkaDoFn.java | 43 ++- .../beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java | 344 +++++++++++++++++++++ 3 files changed, 421 insertions(+), 2 deletions(-) diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index 60608e0..93759e6 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -141,6 +141,11 @@ import org.slf4j.LoggerFactory; * // offset consumed by the pipeline can be committed back. * .commitOffsetsInFinalize() * + * // Specify a serializable function which determines whether to stop reading from the given + * // TopicPartition at runtime. Note that only {@link ReadFromKafkaDoFn} respects this + * // signal. + * .withCheckStopReadingFn(new SerializableFunction<TopicPartition, Boolean>() {}) + * * // finally, if you don't need Kafka metadata, you can drop it. * .withoutMetadata() // PCollection<KV<Long, String>> * ) @@ -514,6 +519,8 @@ public class KafkaIO { abstract @Nullable DeserializerProvider getValueDeserializerProvider(); + abstract @Nullable SerializableFunction<TopicPartition, Boolean> getCheckStopReadingFn(); + abstract Builder<K, V> toBuilder(); @Experimental(Kind.PORTABILITY) @@ -553,6 +560,9 @@ public class KafkaIO { abstract Builder<K, V> setValueDeserializerProvider( DeserializerProvider deserializerProvider); + abstract Builder<K, V> setCheckStopReadingFn( + SerializableFunction<TopicPartition, Boolean> checkStopReadingFn); + abstract Read<K, V> build(); @Override @@ -998,6 +1008,15 @@ public class KafkaIO { return toBuilder().setConsumerConfig(config).build(); } + /** + * A custom {@link SerializableFunction} that determines whether the {@link ReadFromKafkaDoFn} + * should stop reading from the given {@link TopicPartition}. + */ + public Read<K, V> withCheckStopReadingFn( + SerializableFunction<TopicPartition, Boolean> checkStopReadingFn) { + return toBuilder().setCheckStopReadingFn(checkStopReadingFn).build(); + } + + /** Returns a {@link PTransform} for PCollection of {@link KV}, dropping Kafka metadata.
*/ public PTransform<PBegin, PCollection<KV<K, V>>> withoutMetadata() { return new TypedWithoutMetadata<>(this); @@ -1080,7 +1099,8 @@ public class KafkaIO { .withKeyDeserializerProvider(getKeyDeserializerProvider()) .withValueDeserializerProvider(getValueDeserializerProvider()) .withManualWatermarkEstimator() - .withTimestampPolicyFactory(getTimestampPolicyFactory()); + .withTimestampPolicyFactory(getTimestampPolicyFactory()) + .withCheckStopReadingFn(getCheckStopReadingFn()); if (isCommitOffsetsInFinalizeEnabled()) { readTransform = readTransform.commitOffsets(); } @@ -1267,6 +1287,8 @@ public class KafkaIO { abstract SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> getConsumerFactoryFn(); + abstract @Nullable SerializableFunction<TopicPartition, Boolean> getCheckStopReadingFn(); + abstract @Nullable SerializableFunction<KafkaRecord<K, V>, Instant> getExtractOutputTimestampFn(); @@ -1289,6 +1311,9 @@ public class KafkaIO { abstract ReadSourceDescriptors.Builder<K, V> setConsumerFactoryFn( SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn); + abstract ReadSourceDescriptors.Builder<K, V> setCheckStopReadingFn( + SerializableFunction<TopicPartition, Boolean> checkStopReadingFn); + abstract ReadSourceDescriptors.Builder<K, V> setKeyDeserializerProvider( DeserializerProvider deserializerProvider); @@ -1403,6 +1428,15 @@ public class KafkaIO { } /** + * A custom {@link SerializableFunction} that determines whether the {@link ReadFromKafkaDoFn} + * should stop reading from the given {@link TopicPartition}. + */ + public ReadSourceDescriptors<K, V> withCheckStopReadingFn( + SerializableFunction<TopicPartition, Boolean> checkStopReadingFn) { + return toBuilder().setCheckStopReadingFn(checkStopReadingFn).build(); + } + + /** * Updates configuration for the main consumer. This method merges updates from the provided map * with any prior updates using {@link KafkaIOUtils#DEFAULT_CONSUMER_PROPERTIES} as the starting * configuration. diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java index 08a590b..d12332a 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFn.java @@ -20,8 +20,11 @@ package org.apache.beam.sdk.io.kafka; import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; @@ -52,6 +55,7 @@ import org.apache.kafka.clients.consumer.Consumer; import org.apache.kafka.clients.consumer.ConsumerConfig; import org.apache.kafka.clients.consumer.ConsumerRecord; import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.serialization.Deserializer; import org.joda.time.Duration; @@ -116,6 +120,23 @@ import org.slf4j.LoggerFactory; * extractTimestampFn} and {@link * ReadSourceDescriptors#withMonotonicallyIncreasingWatermarkEstimator()} as the {@link * WatermarkEstimator}. 
+ * + * <h4>Stop Reading from Removed {@link TopicPartition}</h4> + * + * {@link ReadFromKafkaDoFn} will stop reading from any removed {@link TopicPartition} automatically + * by querying Kafka {@link Consumer} APIs. Note that reading may not stop as soon as the + * {@link TopicPartition} is removed. For example, the removal could happen at the same time as + * {@link ReadFromKafkaDoFn} performs a {@link Consumer#poll(java.time.Duration)}. In that + * case, the {@link ReadFromKafkaDoFn} will still output the fetched records. + * + * <h4>Stop Reading from Stopped {@link TopicPartition}</h4> + * + * {@link ReadFromKafkaDoFn} will also stop reading from a given {@link TopicPartition} when + * instructed to by {@link ReadFromKafkaDoFn#checkStopReadingFn}. {@link + * ReadFromKafkaDoFn#checkStopReadingFn} is a user-provided callback that determines whether to + * stop reading from the given {@link TopicPartition}. As with removed {@link TopicPartition}s, + * reading may not stop immediately. */ @UnboundedPerElement @SuppressWarnings({ @@ -134,12 +155,15 @@ class ReadFromKafkaDoFn<K, V> this.extractOutputTimestampFn = transform.getExtractOutputTimestampFn(); this.createWatermarkEstimatorFn = transform.getCreateWatermarkEstimatorFn(); this.timestampPolicyFactory = transform.getTimestampPolicyFactory(); + this.checkStopReadingFn = transform.getCheckStopReadingFn(); } private static final Logger LOG = LoggerFactory.getLogger(ReadFromKafkaDoFn.class); private final Map<String, Object> offsetConsumerConfig; + private final SerializableFunction<TopicPartition, Boolean> checkStopReadingFn; + private final SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>> consumerFactoryFn; private final SerializableFunction<KafkaRecord<K, V>, Instant> extractOutputTimestampFn; @@ -275,7 +299,11 @@ class ReadFromKafkaDoFn<K, V> RestrictionTracker<OffsetRange, Long> tracker, WatermarkEstimator watermarkEstimator, OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<K, V>>> receiver) { - // If there is no future work, resume with max timeout and move to the next element. + // Stop processing the current TopicPartition when the user-provided checkStopReadingFn says so. + if (checkStopReadingFn != null + && checkStopReadingFn.apply(kafkaSourceDescriptor.getTopicPartition())) { + return ProcessContinuation.stop(); + } Map<String, Object> updatedConsumerConfig = overrideBootstrapServersConfig(consumerConfig, kafkaSourceDescriptor); // If there is a timestampPolicyFactory, create the TimestampPolicy for current @@ -288,6 +316,19 @@ class ReadFromKafkaDoFn<K, V> Optional.ofNullable(watermarkEstimator.currentWatermark())); } try (Consumer<byte[], byte[]> consumer = consumerFactoryFn.apply(updatedConsumerConfig)) { + // Check whether the current TopicPartition is still available to read.
+ Set<TopicPartition> existingTopicPartitions = new HashSet<>(); + for (List<PartitionInfo> topicPartitionList : consumer.listTopics().values()) { + topicPartitionList.forEach( + partitionInfo -> { + existingTopicPartitions.add( + new TopicPartition(partitionInfo.topic(), partitionInfo.partition())); + }); + } + if (!existingTopicPartitions.contains(kafkaSourceDescriptor.getTopicPartition())) { + return ProcessContinuation.stop(); + } + consumerSpEL.evaluateAssign( consumer, ImmutableList.of(kafkaSourceDescriptor.getTopicPartition())); long startOffset = tracker.currentRestriction().getFrom(); diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java new file mode 100644 index 0000000..62c28b2 --- /dev/null +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/ReadFromKafkaDoFnTest.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.io.kafka; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +import org.apache.beam.sdk.io.kafka.KafkaIO.ReadSourceDescriptors; +import org.apache.beam.sdk.io.range.OffsetRange; +import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.ProcessContinuation; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.splittabledofn.OffsetRangeTracker; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.kafka.clients.consumer.Consumer; +import org.apache.kafka.clients.consumer.ConsumerRecord; +import org.apache.kafka.clients.consumer.ConsumerRecords; +import org.apache.kafka.clients.consumer.MockConsumer; +import org.apache.kafka.clients.consumer.OffsetAndTimestamp; +import org.apache.kafka.clients.consumer.OffsetResetStrategy; +import org.apache.kafka.common.PartitionInfo; +import org.apache.kafka.common.TopicPartition; +import org.apache.kafka.common.header.internals.RecordHeaders; +import org.apache.kafka.common.serialization.StringDeserializer; +import org.checkerframework.checker.initialization.qual.Initialized; +import org.checkerframework.checker.nullness.qual.NonNull; +import org.checkerframework.checker.nullness.qual.UnknownKeyFor; +import org.joda.time.Instant; +import org.junit.Before;
+import org.junit.Test; + +@SuppressWarnings({ + "rawtypes", // TODO(https://issues.apache.org/jira/browse/BEAM-10556) + "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-10402) +}) +public class ReadFromKafkaDoFnTest { + + private final TopicPartition topicPartition = new TopicPartition("topic", 0); + + private final SimpleMockKafkaConsumer consumer = + new SimpleMockKafkaConsumer(OffsetResetStrategy.NONE, topicPartition); + + private final ReadFromKafkaDoFn<String, String> dofnInstance = + new ReadFromKafkaDoFn(makeReadSourceDescriptor(consumer)); + + private ReadSourceDescriptors<String, String> makeReadSourceDescriptor( + Consumer kafkaMockConsumer) { + return ReadSourceDescriptors.<String, String>read() + .withKeyDeserializer(StringDeserializer.class) + .withValueDeserializer(StringDeserializer.class) + .withConsumerFactoryFn( + new SerializableFunction<Map<String, Object>, Consumer<byte[], byte[]>>() { + @Override + public Consumer<byte[], byte[]> apply(Map<String, Object> input) { + return kafkaMockConsumer; + } + }) + .withBootstrapServers("bootstrap_server"); + } + + private static class SimpleMockKafkaConsumer extends MockConsumer<byte[], byte[]> { + + private final TopicPartition topicPartition; + private boolean isRemoved = false; + private long currentPos = 0L; + private long startOffset = 0L; + private long startOffsetForTime = 0L; + private long numOfRecordsPerPoll; + + public SimpleMockKafkaConsumer( + OffsetResetStrategy offsetResetStrategy, TopicPartition topicPartition) { + super(offsetResetStrategy); + this.topicPartition = topicPartition; + } + + public void reset() { + this.isRemoved = false; + this.currentPos = 0L; + this.startOffset = 0L; + this.startOffsetForTime = 0L; + this.numOfRecordsPerPoll = 0L; + } + + public void setRemoved() { + this.isRemoved = true; + } + + public void setNumOfRecordsPerPoll(long num) { + this.numOfRecordsPerPoll = num; + } + + public void setCurrentPos(long pos) { + this.currentPos = pos; + } + + public void setStartOffsetForTime(long pos) { + this.startOffsetForTime = pos; + } + + @Override + public synchronized Map<String, List<PartitionInfo>> listTopics() { + if (this.isRemoved) { + return ImmutableMap.of(); + } + return ImmutableMap.of( + topicPartition.topic(), + ImmutableList.of( + new PartitionInfo( + topicPartition.topic(), topicPartition.partition(), null, null, null))); + } + + @Override + public synchronized void assign(Collection<TopicPartition> partitions) { + assertTrue(Iterables.getOnlyElement(partitions).equals(this.topicPartition)); + } + + @Override + public synchronized void seek(TopicPartition partition, long offset) { + assertTrue(partition.equals(this.topicPartition)); + this.startOffset = offset; + } + + @Override + public synchronized ConsumerRecords<byte[], byte[]> poll(long timeout) { + if (topicPartition == null) { + return ConsumerRecords.empty(); + } + String key = "key"; + String value = "value"; + List<ConsumerRecord<byte[], byte[]>> records = new ArrayList<>(); + for (long i = 0; i <= numOfRecordsPerPoll; i++) { + records.add( + new ConsumerRecord<byte[], byte[]>( + topicPartition.topic(), + topicPartition.partition(), + startOffset + i, + key.getBytes(Charsets.UTF_8), + value.getBytes(Charsets.UTF_8))); + } + if (records.isEmpty()) { + return ConsumerRecords.empty(); + } + return new ConsumerRecords(ImmutableMap.of(topicPartition, records)); + } + + @Override + public synchronized Map<TopicPartition, OffsetAndTimestamp>
offsetsForTimes( + Map<TopicPartition, Long> timestampsToSearch) { + assertTrue( + Iterables.getOnlyElement( + timestampsToSearch.keySet().stream().collect(Collectors.toList())) + .equals(this.topicPartition)); + return ImmutableMap.of( + topicPartition, + new OffsetAndTimestamp( + this.startOffsetForTime, Iterables.getOnlyElement(timestampsToSearch.values()))); + } + + @Override + public synchronized long position(TopicPartition partition) { + assertTrue(partition.equals(this.topicPartition)); + return this.currentPos; + } + } + + private static class MockOutputReceiver + implements OutputReceiver<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> { + + private final List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> records = + new ArrayList<>(); + + @Override + public void output(KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output) {} + + @Override + public void outputWithTimestamp( + KV<KafkaSourceDescriptor, KafkaRecord<String, String>> output, + @UnknownKeyFor @NonNull @Initialized Instant timestamp) { + records.add(output); + } + + public List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> getOutputs() { + return this.records; + } + } + + private List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> createExpectedRecords( + KafkaSourceDescriptor descriptor, + long startOffset, + int numRecords, + String key, + String value) { + List<KV<KafkaSourceDescriptor, KafkaRecord<String, String>>> records = new ArrayList<>(); + for (int i = 0; i < numRecords; i++) { + records.add( + KV.of( + descriptor, + new KafkaRecord<String, String>( + topicPartition.topic(), + topicPartition.partition(), + startOffset + i, + -1L, + KafkaTimestampType.NO_TIMESTAMP_TYPE, + new RecordHeaders(), + KV.of(key, value)))); + } + return records; + } + + @Before + public void setUp() throws Exception { + dofnInstance.setup(); + consumer.reset(); + } + + @Test + public void testInitialRestrictionWhenHasStartOffset() throws Exception { + long expectedStartOffset = 10L; + consumer.setStartOffsetForTime(15L); + consumer.setCurrentPos(5L); + OffsetRange result = + dofnInstance.initialRestriction( + KafkaSourceDescriptor.of( + topicPartition, expectedStartOffset, Instant.now(), ImmutableList.of())); + assertEquals(new OffsetRange(expectedStartOffset, Long.MAX_VALUE), result); + } + + @Test + public void testInitialRestrictionWhenHasStartTime() throws Exception { + long expectedStartOffset = 10L; + consumer.setStartOffsetForTime(expectedStartOffset); + consumer.setCurrentPos(5L); + OffsetRange result = + dofnInstance.initialRestriction( + KafkaSourceDescriptor.of(topicPartition, null, Instant.now(), ImmutableList.of())); + assertEquals(new OffsetRange(expectedStartOffset, Long.MAX_VALUE), result); + } + + @Test + public void testInitialRestrictionWithConsumerPosition() throws Exception { + long expectedStartOffset = 5L; + consumer.setCurrentPos(5L); + OffsetRange result = + dofnInstance.initialRestriction( + KafkaSourceDescriptor.of(topicPartition, null, null, ImmutableList.of())); + assertEquals(new OffsetRange(expectedStartOffset, Long.MAX_VALUE), result); + } + + @Test + public void testProcessElement() throws Exception { + MockOutputReceiver receiver = new MockOutputReceiver(); + consumer.setNumOfRecordsPerPoll(3L); + long startOffset = 5L; + OffsetRangeTracker tracker = + new OffsetRangeTracker(new OffsetRange(startOffset, startOffset + 3)); + KafkaSourceDescriptor descriptor = KafkaSourceDescriptor.of(topicPartition, null, null, null); + ProcessContinuation result = + 
dofnInstance.processElement(descriptor, tracker, null, (OutputReceiver) receiver); + assertEquals(ProcessContinuation.stop(), result); + assertEquals( + createExpectedRecords(descriptor, startOffset, 3, "key", "value"), receiver.getOutputs()); + } + + @Test + public void testProcessElementWithEmptyPoll() throws Exception { + MockOutputReceiver receiver = new MockOutputReceiver(); + consumer.setNumOfRecordsPerPoll(-1); + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE)); + ProcessContinuation result = + dofnInstance.processElement( + KafkaSourceDescriptor.of(topicPartition, null, null, null), + tracker, + null, + (OutputReceiver) receiver); + assertEquals(ProcessContinuation.resume(), result); + assertTrue(receiver.getOutputs().isEmpty()); + } + + @Test + public void testProcessElementWhenTopicPartitionIsRemoved() throws Exception { + MockOutputReceiver receiver = new MockOutputReceiver(); + consumer.setRemoved(); + consumer.setNumOfRecordsPerPoll(10); + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE)); + ProcessContinuation result = + dofnInstance.processElement( + KafkaSourceDescriptor.of(topicPartition, null, null, null), + tracker, + null, + (OutputReceiver) receiver); + assertEquals(ProcessContinuation.stop(), result); + } + + @Test + public void testProcessElementWhenTopicPartitionIsStopped() throws Exception { + MockOutputReceiver receiver = new MockOutputReceiver(); + ReadFromKafkaDoFn<String, String> instance = + new ReadFromKafkaDoFn( + makeReadSourceDescriptor(consumer) + .toBuilder() + .setCheckStopReadingFn( + new SerializableFunction<TopicPartition, Boolean>() { + @Override + public Boolean apply(TopicPartition input) { + assertTrue(input.equals(topicPartition)); + return true; + } + }) + .build()); + consumer.setNumOfRecordsPerPoll(10); + OffsetRangeTracker tracker = new OffsetRangeTracker(new OffsetRange(0L, Long.MAX_VALUE)); + ProcessContinuation result = + instance.processElement( + KafkaSourceDescriptor.of(topicPartition, null, null, null), + tracker, + null, + (OutputReceiver) receiver); + assertEquals(ProcessContinuation.stop(), result); + } +}
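
For reference, here is a minimal sketch of how a pipeline might wire in the new withCheckStopReadingFn hook added by this commit. The broker address, topic name, and stopped-partition set below are hypothetical, and per the Javadoc above only the ReadFromKafkaDoFn-based (splittable DoFn) expansion honors the signal:

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.kafka.KafkaIO;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.kafka.common.serialization.LongDeserializer;
import org.apache.kafka.common.serialization.StringDeserializer;

public class StopReadingExample {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create(PipelineOptionsFactory.fromArgs(args).create());

    // Hypothetical stop signal: stop reading partitions 3 and 4. HashSet is
    // serializable, so the capturing lambda below serializes with the pipeline.
    Set<Integer> stoppedPartitions = new HashSet<>(Arrays.asList(3, 4));

    pipeline.apply(
        KafkaIO.<Long, String>read()
            .withBootstrapServers("broker_1:9092") // hypothetical broker
            .withTopic("my_topic") // hypothetical topic
            .withKeyDeserializer(LongDeserializer.class)
            .withValueDeserializer(StringDeserializer.class)
            // Consulted while reading each TopicPartition; returning true makes
            // ReadFromKafkaDoFn stop reading from that partition.
            .withCheckStopReadingFn(
                topicPartition -> stoppedPartitions.contains(topicPartition.partition()))
            .withoutMetadata());

    pipeline.run().waitUntilFinish();
  }
}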
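
Where the stop signal carries state, a named SerializableFunction can be clearer than a lambda, since the callback is serialized into the pipeline graph. A sketch, again with a hypothetical deny list:

import java.util.HashSet;
import java.util.Set;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.kafka.common.TopicPartition;

/** Hypothetical stop signal: stop reading any partition whose id is on a fixed deny list. */
public class StopOnDenyList implements SerializableFunction<TopicPartition, Boolean> {

  // Copied into a HashSet so the function serializes regardless of the caller's Set type.
  private final Set<Integer> deniedPartitions;

  public StopOnDenyList(Set<Integer> deniedPartitions) {
    this.deniedPartitions = new HashSet<>(deniedPartitions);
  }

  @Override
  public Boolean apply(TopicPartition topicPartition) {
    return deniedPartitions.contains(topicPartition.partition());
  }
}

An instance would then be passed as .withCheckStopReadingFn(new StopOnDenyList(...)) on either KafkaIO.Read or KafkaIO.ReadSourceDescriptors, both of which gain the setter in this commit.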