This is an automated email from the ASF dual-hosted git repository.

anton pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new 9d131d4  [BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
     new cd2ab9e  Merge pull request #9298 from riazela/KafkaRateEstimation2
9d131d4 is described below

commit 9d131d490dfa1b4838d0303a3f17f36202c0874b
Author: Alireza Samadian <alireza4...@gmail.com>
AuthorDate: Tue Aug 6 16:56:03 2019 -0700

    [BEAM-7896] Implementing RateEstimation for KafkaTable with Unit and Integration Tests
---
 sdks/java/extensions/sql/build.gradle              |   1 +
 .../sql/meta/provider/kafka/BeamKafkaTable.java    | 147 +++++++++--
 .../meta/provider/kafka/BeamKafkaCSVTableTest.java | 118 ++++++++-
 .../sql/meta/provider/kafka/KafkaCSVTableIT.java   | 292 +++++++++++++++++++++
 .../sql/meta/provider/kafka/KafkaCSVTestTable.java | 197 ++++++++++++++
 .../sql/meta/provider/kafka/KafkaTestRecord.java   |  39 +++
 6 files changed, 777 insertions(+), 17 deletions(-)

diff --git a/sdks/java/extensions/sql/build.gradle b/sdks/java/extensions/sql/build.gradle
index b4a7079..fe07bfe 100644
--- a/sdks/java/extensions/sql/build.gradle
+++ b/sdks/java/extensions/sql/build.gradle
@@ -203,6 +203,7 @@ task integrationTest(type: Test) {
   systemProperty "beamTestPipelineOptions", JsonOutput.toJson(pipelineOptions)
 
   include '**/*IT.class'
+  exclude '**/KafkaCSVTableIT.java'
   maxParallelForks 4
   classpath = project(":sdks:java:extensions:sql")
           .sourceSets
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
index 0e1dab3..11c12f6 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaTable.java
@@ -19,9 +19,13 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
 
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
 
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Properties;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
 import org.apache.beam.sdk.extensions.sql.impl.schema.BaseBeamTable;
@@ -34,9 +38,15 @@ import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.POutput;
 import org.apache.beam.sdk.values.Row;
+import org.apache.kafka.clients.consumer.Consumer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.ConsumerRecords;
+import org.apache.kafka.clients.consumer.KafkaConsumer;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.ByteArrayDeserializer;
 import org.apache.kafka.common.serialization.ByteArraySerializer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 /**
  * {@code BeamKafkaTable} represents a Kafka topic, as source or target. Need to extend to convert
@@ -47,6 +57,10 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
   private List<String> topics;
   private List<TopicPartition> topicPartitions;
   private Map<String, Object> configUpdates;
+  private BeamTableStatistics rowCountStatistics = null;
+  private static final Logger LOGGER = LoggerFactory.getLogger(BeamKafkaTable.class);
+  // The number of records read from each partition when the rate is estimated.
+  protected int numberOfRecordsForRate = 50;
 
   protected BeamKafkaTable(Schema beamSchema) {
     super(beamSchema);
@@ -84,7 +98,14 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
 
   @Override
   public PCollection<Row> buildIOReader(PBegin begin) {
-    KafkaIO.Read<byte[], byte[]> kafkaRead = null;
+    return begin
+        .apply("read", createKafkaRead().withoutMetadata())
+        .apply("in_format", getPTransformForInput())
+        .setRowSchema(getSchema());
+  }
+
+  KafkaIO.Read<byte[], byte[]> createKafkaRead() {
+    KafkaIO.Read<byte[], byte[]> kafkaRead;
     if (topics != null) {
       kafkaRead =
           KafkaIO.<byte[], byte[]>read()
@@ -104,28 +125,25 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
     } else {
       throw new IllegalArgumentException("Either topics or topicPartitions must be configured.");
     }
-
-    return begin
-        .apply("read", kafkaRead.withoutMetadata())
-        .apply("in_format", getPTransformForInput())
-        .setRowSchema(getSchema());
+    return kafkaRead;
   }
 
   @Override
   public POutput buildIOWriter(PCollection<Row> input) {
     checkArgument(
         topics != null && topics.size() == 1, "Only one topic is acceptable as output.");
-    assert topics != null;
     return input
         .apply("out_reformat", getPTransformForOutput())
-        .apply(
-            "persistent",
-            KafkaIO.<byte[], byte[]>write()
-                .withBootstrapServers(bootstrapServers)
-                .withTopic(topics.get(0))
-                .withKeySerializer(ByteArraySerializer.class)
-                .withValueSerializer(ByteArraySerializer.class));
+        .apply("persistent", createKafkaWrite());
+  }
+
+  private KafkaIO.Write<byte[], byte[]> createKafkaWrite() {
+    return KafkaIO.<byte[], byte[]>write()
+        .withBootstrapServers(bootstrapServers)
+        .withTopic(topics.get(0))
+        .withKeySerializer(ByteArraySerializer.class)
+        .withValueSerializer(ByteArraySerializer.class);
   }
 
   public String getBootstrapServers() {
@@ -138,6 +156,105 @@ public abstract class BeamKafkaTable extends BaseBeamTable {
 
   @Override
   public BeamTableStatistics getTableStatistics(PipelineOptions options) {
-    return BeamTableStatistics.UNBOUNDED_UNKNOWN;
+    if (rowCountStatistics == null) {
+      try {
+        rowCountStatistics =
+            BeamTableStatistics.createUnboundedTableStatistics(
+                this.computeRate(numberOfRecordsForRate));
+      } catch (Exception e) {
+        LOGGER.warn("Could not get the row count for the topics " + getTopics(), e);
+        rowCountStatistics = BeamTableStatistics.UNBOUNDED_UNKNOWN;
+      }
+    }
+
+    return rowCountStatistics;
+  }
+
+  /**
+   * Returns an estimate of the rate of this table (in records per second), computed from the
+   * last numberOfRecords tuples in each partition.
+   */
+  double computeRate(int numberOfRecords) throws NoEstimationException {
+    Properties props = new Properties();
+
+    props.put("bootstrap.servers", bootstrapServers);
+    props.put("session.timeout.ms", "30000");
+    props.put("key.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+    props.put("value.deserializer", "org.apache.kafka.common.serialization.ByteArrayDeserializer");
+
+    KafkaConsumer<byte[], byte[]> consumer = new KafkaConsumer<>(props);
+
+    return computeRate(consumer, numberOfRecords);
+  }
+
+  <T> double computeRate(Consumer<T, T> consumer, int numberOfRecordsToCheck)
+      throws NoEstimationException {
+
+    Stream<TopicPartition> c =
+        getTopics().stream()
+            .map(consumer::partitionsFor)
+            .flatMap(Collection::stream)
+            .map(parInf -> new TopicPartition(parInf.topic(), parInf.partition()));
+    List<TopicPartition> topicPartitions = c.collect(Collectors.toList());
+
+    consumer.assign(topicPartitions);
+    // This returns the current end offset of every partition assigned to the consumer (i.e. the
+    // offset of the last record in each of those partitions). Note that each topic can have
+    // multiple partitions. Since the consumer is not assigned to any consumer group, changing
+    // the offset or consuming messages has no effect on other consumers (or on the data that our
+    // table is receiving).
+    Map<TopicPartition, Long> offsets = consumer.endOffsets(topicPartitions);
+    long nParsSeen = 0;
+    for (TopicPartition par : topicPartitions) {
+      long offset = offsets.get(par);
+      nParsSeen = (offset == 0) ? nParsSeen : nParsSeen + 1;
+      consumer.seek(par, Math.max(0L, offset - numberOfRecordsToCheck));
+    }
+
+    if (nParsSeen == 0) {
+      throw new NoEstimationException("There is no partition with messages in it.");
+    }
+
+    ConsumerRecords<T, T> records = consumer.poll(1000);
+
+    // Kafka guarantees that messages are delivered in the order they arrive at each partition.
+    // Therefore the first message seen from each partition is the earliest message that arrived
+    // at it. We take the first message of every partition, treat the latest of these as the
+    // starting point, and discard all messages that arrived before it from the rate estimation.
+    Map<Integer, Long> minTimeStamps = new HashMap<>();
+    long maxMinTimeStamp = 0;
+    for (ConsumerRecord<T, T> record : records) {
+      if (!minTimeStamps.containsKey(record.partition())) {
+        minTimeStamps.put(record.partition(), record.timestamp());
+
+        nParsSeen--;
+        maxMinTimeStamp = Math.max(record.timestamp(), maxMinTimeStamp);
+        if (nParsSeen == 0) {
+          break;
+        }
+      }
+    }
+
+    int numberOfRecords = 0;
+    long maxTimeStamp = 0;
+    for (ConsumerRecord<T, T> record : records) {
+      maxTimeStamp = Math.max(maxTimeStamp, record.timestamp());
+      numberOfRecords =
+          record.timestamp() > maxMinTimeStamp ? numberOfRecords + 1 : numberOfRecords;
+    }
+
+    if (maxTimeStamp == maxMinTimeStamp) {
+      throw new NoEstimationException("Arrival times of all records are the same.");
+    }
+
+    return (numberOfRecords * 1000.) / ((double) maxTimeStamp - maxMinTimeStamp);
+  }
+
+  /** Thrown when the rate for a Kafka table cannot be estimated. */
+  static class NoEstimationException extends Exception {
+    NoEstimationException(String message) {
+      super(message);
+    }
+  }
 }
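[Editor's sketch] To make the estimation above easier to follow outside of the diff, here is a condensed, self-contained restatement of the same computation that computeRate performs once the per-partition timestamps have been fetched. RateEstimationSketch, estimateRate, and the sample data are names invented for this sketch, not part of the patch.

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    /** Illustrative restatement of the estimation in BeamKafkaTable.computeRate(). */
    public class RateEstimationSketch {

      /**
       * Estimates records per second from the timestamps (millis, in arrival order) of
       * the last few records of each partition.
       */
      static double estimateRate(Map<Integer, List<Long>> timestampsPerPartition) {
        // Step 1: the first timestamp sampled from a partition is the earliest point at
        // which that partition is observed. The latest of these first timestamps is the
        // first point at which *every* partition is observed, so it is the starting point.
        long maxMinTimestamp = Long.MIN_VALUE;
        for (List<Long> timestamps : timestampsPerPartition.values()) {
          if (!timestamps.isEmpty()) {
            maxMinTimestamp = Math.max(maxMinTimestamp, timestamps.get(0));
          }
        }

        // Step 2: count records newer than the starting point and find the newest timestamp.
        int count = 0;
        long maxTimestamp = maxMinTimestamp;
        for (List<Long> timestamps : timestampsPerPartition.values()) {
          for (long t : timestamps) {
            maxTimestamp = Math.max(maxTimestamp, t);
            if (t > maxMinTimestamp) {
              count++;
            }
          }
        }

        if (maxTimestamp == maxMinTimestamp) {
          throw new IllegalStateException("Arrival times of all records are the same.");
        }

        // rate = records counted / elapsed seconds between the two boundary timestamps
        return (count * 1000.0) / (maxTimestamp - maxMinTimestamp);
      }

      public static void main(String[] args) {
        // Two partitions, one record every 500 ms overall => ~2 records/sec, which is
        // the value testOrderedArrivalMultiplePartitionsRate below asserts.
        Map<Integer, List<Long>> samples = new HashMap<>();
        samples.put(0, Arrays.asList(0L, 1000L, 2000L, 3000L));
        samples.put(1, Arrays.asList(500L, 1500L, 2500L, 3500L));
        System.out.println(estimateRate(samples)); // prints 2.0
      }
    }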
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
index 710a1a5..c407ff4 100644
--- a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/BeamKafkaCSVTableTest.java
@@ -20,7 +20,13 @@ package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
 import static java.nio.charset.StandardCharsets.UTF_8;
 
 import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.beam.sdk.extensions.sql.BeamSqlTable;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
 import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
+import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestTableUtils;
 import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
@@ -30,11 +36,13 @@ import org.apache.beam.sdk.transforms.ParDo;
 import org.apache.beam.sdk.values.KV;
 import org.apache.beam.sdk.values.PCollection;
 import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.calcite.adapter.java.JavaTypeFactory;
 import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
 import org.apache.calcite.rel.type.RelDataTypeSystem;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.commons.csv.CSVFormat;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 
@@ -46,8 +54,101 @@ public class BeamKafkaCSVTableTest {
 
   private static final Row ROW2 = Row.withSchema(genSchema()).addValues(2L, 2, 2.0).build();
 
+  private static Map<String, BeamSqlTable> tables = new HashMap<>();
+  protected static BeamSqlEnv env = BeamSqlEnv.readOnly("test", tables);
+
+  @Test
+  public void testOrderedArrivalSinglePartitionRate() {
+    KafkaCSVTestTable table = getTable(1);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("key1", i + ",1,2", "topic1", 500 * i));
+    }
+
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(2d, stats.getRate(), 0.001);
+  }
+
+  @Test
+  public void testOrderedArrivalMultiplePartitionsRate() {
+    KafkaCSVTestTable table = getTable(3);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("key" + i, i + ",1,2", "topic1", 500 * i));
+    }
+
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(2d, stats.getRate(), 0.001);
+  }
+
+  @Test
+  public void testOnePartitionAheadRate() {
+    KafkaCSVTestTable table = getTable(3);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 1000 * i));
+      table.addRecord(KafkaTestRecord.create("2", i + ",1,2", "topic1", 500 * i));
+    }
+
+    table.setNumberOfRecordsForRate(20);
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(1d, stats.getRate(), 0.001);
+  }
+
+  @Test
+  public void testLateRecords() {
+    KafkaCSVTestTable table = getTable(3);
+
+    table.addRecord(KafkaTestRecord.create("1", 132 + ",1,2", "topic1", 1000));
+    for (int i = 0; i < 98; i++) {
+      table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 500));
+    }
+    table.addRecord(KafkaTestRecord.create("1", 133 + ",1,2", "topic1", 2000));
+
+    table.setNumberOfRecordsForRate(200);
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertEquals(1d, stats.getRate(), 0.001);
+  }
 
   @Test
-  public void testCsvRecorderDecoder() throws Exception {
+  public void testAllLate() {
+    KafkaCSVTestTable table = getTable(3);
+
+    table.addRecord(KafkaTestRecord.create("1", 132 + ",1,2", "topic1", 1000));
+    for (int i = 0; i < 98; i++) {
+      table.addRecord(KafkaTestRecord.create("1", i + ",1,2", "topic1", 500));
+    }
+
+    table.setNumberOfRecordsForRate(200);
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertTrue(stats.isUnknown());
+  }
+
+  @Test
+  public void testEmptyPartitionsRate() {
+    KafkaCSVTestTable table = getTable(3);
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertTrue(stats.isUnknown());
+  }
+
+  @Test
+  public void allTheRecordsSameTimeRate() {
+    KafkaCSVTestTable table = getTable(3);
+    for (int i = 0; i < 100; i++) {
+      table.addRecord(KafkaTestRecord.create("key" + i, i + ",1,2", "topic1", 1000));
+    }
+    BeamTableStatistics stats = table.getTableStatistics(null);
+    Assert.assertTrue(stats.isUnknown());
+  }
+
+  private static class PrintDoFn extends DoFn<Row, Row> {
+
+    @ProcessElement
+    public void process(ProcessContext c) {
+      System.out.println("we are here");
+      System.out.println(c.element().getValues());
+    }
+  }
+
+  @Test
+  public void testCsvRecorderDecoder() {
     PCollection<Row> result =
         pipeline
             .apply(Create.of("1,\"1\",1.0", "2,2,2.0"))
@@ -60,7 +161,7 @@ public class BeamKafkaCSVTableTest {
   }
 
   @Test
-  public void testCsvRecorderEncoder() throws Exception {
+  public void testCsvRecorderEncoder() {
     PCollection<Row> result =
         pipeline
             .apply(Create.of(ROW1, ROW2))
@@ -90,4 +191,17 @@ public class BeamKafkaCSVTableTest {
       ctx.output(KV.of(new byte[] {}, ctx.element().getBytes(UTF_8)));
     }
   }
+
+  private KafkaCSVTestTable getTable(int numberOfPartitions) {
+    return new KafkaCSVTestTable(
+        TestTableUtils.buildBeamSqlSchema(
+            Schema.FieldType.INT32,
+            "order_id",
+            Schema.FieldType.INT32,
+            "site_id",
+            Schema.FieldType.INT32,
+            "price"),
+        ImmutableList.of("topic1", "topic2"),
+        numberOfPartitions);
+  }
 }
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
new file mode 100644
index 0000000..201a1df
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTableIT.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import static org.apache.beam.sdk.schemas.Schema.FieldType.INT32;
+import static org.apache.beam.sdk.schemas.Schema.toSchema;
+
+import com.alibaba.fastjson.JSON;
+import java.util.Map;
+import java.util.Properties;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.direct.DirectOptions;
+import org.apache.beam.sdk.extensions.sql.impl.BeamSqlEnv;
+import org.apache.beam.sdk.extensions.sql.impl.rel.BeamSqlRelUtils;
+import org.apache.beam.sdk.extensions.sql.meta.Table;
+import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
+import org.apache.beam.sdk.options.Default;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.state.BagState;
+import org.apache.beam.sdk.state.StateSpec;
+import org.apache.beam.sdk.state.StateSpecs;
+import org.apache.beam.sdk.state.ValueState;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SimpleFunction;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p21p0.com.google.common.base.MoreObjects;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
+import org.apache.kafka.clients.producer.KafkaProducer;
+import org.apache.kafka.clients.producer.Producer;
+import org.apache.kafka.clients.producer.ProducerRecord;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+
+/**
+ * This is an integration test for KafkaCSVTable. A Kafka server must be running, and its address
+ * must be passed to the test via the KafkaOptions pipeline options.
+ * (https://issues.apache.org/jira/projects/BEAM/issues/BEAM-7523)
+ */
+public class KafkaCSVTableIT {
+
+  @Rule public transient TestPipeline pipeline = TestPipeline.create();
+
+  private static final Schema TEST_TABLE_SCHEMA =
+      Schema.builder()
+          .addNullableField("order_id", Schema.FieldType.INT32)
+          .addNullableField("member_id", Schema.FieldType.INT32)
+          .addNullableField("item_name", Schema.FieldType.INT32)
+          .build();
+
+  @BeforeClass
+  public static void prepare() {
+    PipelineOptionsFactory.register(KafkaOptions.class);
+  }
+
+  @Test
+  @SuppressWarnings("FutureReturnValueIgnored")
+  public void testFake2() throws BeamKafkaTable.NoEstimationException {
+    KafkaOptions kafkaOptions = pipeline.getOptions().as(KafkaOptions.class);
+    Table table =
+        Table.builder()
+            .name("kafka_table")
+            .comment("kafka" + " table")
+            .location("")
+            .schema(
+                Stream.of(
+                        Schema.Field.nullable("order_id", INT32),
+                        Schema.Field.nullable("member_id", INT32),
+                        Schema.Field.nullable("item_name", INT32))
+                    .collect(toSchema()))
+            .type("kafka")
+            .properties(JSON.parseObject(getKafkaPropertiesString(kafkaOptions)))
+            .build();
+    BeamKafkaTable kafkaTable = (BeamKafkaTable) new KafkaTableProvider().buildBeamSqlTable(table);
+    produceSomeRecordsWithDelay(100, 20);
+    double rate1 = kafkaTable.computeRate(20);
+    produceSomeRecordsWithDelay(100, 10);
+    double rate2 = kafkaTable.computeRate(20);
+    Assert.assertTrue(rate2 > rate1);
+  }
+
+  private String getKafkaPropertiesString(KafkaOptions kafkaOptions) {
+    return "{ \"bootstrap.servers\" : \""
+        + kafkaOptions.getKafkaBootstrapServerAddress()
+        + "\",\"topics\":[\""
+        + kafkaOptions.getKafkaTopic()
+        + "\"] }";
+  }
+
+  static final transient Map<Long, Boolean> FLAG = new ConcurrentHashMap<>();
+
+  @Test
+  public void testFake() throws InterruptedException {
+    KafkaOptions kafkaOptions = pipeline.getOptions().as(KafkaOptions.class);
+    pipeline.getOptions().as(DirectOptions.class).setBlockOnRun(false);
+    String createTableString =
+        "CREATE EXTERNAL TABLE kafka_table(\n"
+            + "order_id INTEGER, \n"
+            + "member_id INTEGER, \n"
+            + "item_name INTEGER \n"
+            + ") \n"
+            + "TYPE 'kafka' \n"
+            + "LOCATION '"
+            + "'\n"
+            + "TBLPROPERTIES '"
+            + getKafkaPropertiesString(kafkaOptions)
+            + "'";
+    TableProvider tb = new KafkaTableProvider();
+    BeamSqlEnv env = BeamSqlEnv.inMemory(tb);
+
+    env.executeDdl(createTableString);
+
+    PCollection<Row> queryOutput =
+        BeamSqlRelUtils.toPCollection(pipeline, env.parseQuery("SELECT * FROM kafka_table"));
+
+    queryOutput
+        .apply(ParDo.of(new FakeKvPair()))
+        .apply(
+            "waitForSuccess",
+            ParDo.of(
+                new StreamAssertEqual(
+                    ImmutableSet.of(
+                        row(TEST_TABLE_SCHEMA, 0, 1, 0),
+                        row(TEST_TABLE_SCHEMA, 1, 2, 1),
+                        row(TEST_TABLE_SCHEMA, 2, 3, 2)))));
+    queryOutput.apply(logRecords(""));
+    pipeline.run();
+    TimeUnit.MILLISECONDS.sleep(3000);
+    produceSomeRecords(3);
+
+    for (int i = 0; i < 200; i++) {
+      if (FLAG.getOrDefault(pipeline.getOptions().getOptionsId(), false)) {
+        return;
+      }
+      TimeUnit.MILLISECONDS.sleep(60);
+    }
+    Assert.fail();
+  }
+
+  private static MapElements<Row, Void> logRecords(String suffix) {
+    return MapElements.via(
+        new SimpleFunction<Row, Void>() {
+          @Override
+          public @Nullable Void apply(Row input) {
+            System.out.println(input.getValues() + suffix);
+            return null;
+          }
+        });
+  }
+
+  /** Wraps each Row in a KV, because a stateful DoFn requires a KV input. */
+  public static class FakeKvPair extends DoFn<Row, KV<String, Row>> {
+    @ProcessElement
+    public void processElement(ProcessContext c) {
+      c.output(KV.of("fake_key", c.element()));
+    }
+  }
+
+  /** This DoFn will set a flag if all the elements are seen. */
+  public static class StreamAssertEqual extends DoFn<KV<String, Row>, Void> {
+    private final Set<Row> expected;
+
+    StreamAssertEqual(Set<Row> expected) {
+      super();
+      this.expected = expected;
+    }
+
+    @DoFn.StateId("seenValues")
+    private final StateSpec<BagState<Row>> seenRows = StateSpecs.bag();
+
+    @StateId("count")
+    private final StateSpec<ValueState<Integer>> countState = StateSpecs.value();
+
+    @ProcessElement
+    public void process(
+        ProcessContext context,
+        @StateId("seenValues") BagState<Row> seenValues,
+        @StateId("count") ValueState<Integer> countState) {
+      // Note: this read-modify-write of the count state is not safe if elements are
+      // processed in parallel.
+      int count = MoreObjects.firstNonNull(countState.read(), 0);
+      count = count + 1;
+      countState.write(count);
+      seenValues.add(context.element().getValue());
+
+      if (count >= expected.size()) {
+        if (StreamSupport.stream(seenValues.read().spliterator(), false)
+            .collect(Collectors.toSet())
+            .containsAll(expected)) {
+          System.out.println("in second if");
+          FLAG.put(context.getPipelineOptions().getOptionsId(), true);
+        }
+      }
+    }
+  }
+
+  private Row row(Schema schema, Object... values) {
+    return Row.withSchema(schema).addValues(values).build();
+  }
+
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private void produceSomeRecords(int num) {
+    Producer<String, String> producer = new KafkaProducer<String, String>(producerProps());
+    String topicName = pipeline.getOptions().as(KafkaOptions.class).getKafkaTopic();
+    for (int i = 0; i < num; i++) {
+      producer.send(
+          new ProducerRecord<String, String>(
+              topicName, "k" + i, i + "," + ((i % 3) + 1) + "," + i));
+    }
+    producer.flush();
+    producer.close();
+  }
+
+  @SuppressWarnings("FutureReturnValueIgnored")
+  private void produceSomeRecordsWithDelay(int num, int delayMillis) {
+    Producer<String, String> producer = new KafkaProducer<String, String>(producerProps());
+    String topicName = pipeline.getOptions().as(KafkaOptions.class).getKafkaTopic();
+    for (int i = 0; i < num; i++) {
+      producer.send(
+          new ProducerRecord<String, String>(
+              topicName, "k" + i, i + "," + ((i % 3) + 1) + "," + i));
+      try {
+        TimeUnit.MILLISECONDS.sleep(delayMillis);
+      } catch (InterruptedException e) {
+        throw new RuntimeException("Interrupted while waiting between records", e);
+      }
+    }
+    producer.flush();
+    producer.close();
+  }
+
+  private Properties producerProps() {
+    KafkaOptions options = pipeline.getOptions().as(KafkaOptions.class);
+    Properties props = new Properties();
+    props.put("bootstrap.servers", options.getKafkaBootstrapServerAddress());
+    props.put("key.serializer", "org.apache.kafka.common.serialization.StringSerializer");
+    props.put("value.serializer", "org.apache.kafka.common.serialization.StringSerializer");
+    props.put("buffer.memory", 33554432);
+    props.put("acks", "all");
+    props.put("request.required.acks", "1");
+    props.put("retries", 0);
+    props.put("linger.ms", 1);
+    return props;
+  }
+
+  /** Pipeline options specific to this test. */
+  public interface KafkaOptions extends PipelineOptions {
+
+    @Description("Kafka server address")
+    @Validation.Required
+    @Default.String("localhost:9092")
+    String getKafkaBootstrapServerAddress();
+
+    void setKafkaBootstrapServerAddress(String address);
+
+    @Description("Kafka topic")
+    @Validation.Required
+    @Default.String("test")
+    String getKafkaTopic();
+
+    void setKafkaTopic(String topic);
+  }
+}
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java
new file mode 100644
index 0000000..749adea
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaCSVTestTable.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.io.kafka.KafkaIO;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Uninterruptibles;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.apache.kafka.clients.consumer.MockConsumer;
+import org.apache.kafka.clients.consumer.OffsetAndTimestamp;
+import org.apache.kafka.clients.consumer.OffsetResetStrategy;
+import org.apache.kafka.common.KafkaException;
+import org.apache.kafka.common.PartitionInfo;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.common.record.TimestampType;
+
+/** A mock version of the Kafka CSV table, backed by a MockConsumer. */
+public class KafkaCSVTestTable extends BeamKafkaCSVTable {
+  private int partitionsPerTopic;
+  private List<KafkaTestRecord> records;
+  private static final String TIMESTAMP_TYPE_CONFIG = "test.timestamp.type";
+
+  public KafkaCSVTestTable(Schema beamSchema, List<String> topics, int partitionsPerTopic) {
+    super(beamSchema, "server:123", topics);
+    this.partitionsPerTopic = partitionsPerTopic;
+    this.records = new ArrayList<>();
+  }
+
+  @Override
+  KafkaIO.Read<byte[], byte[]> createKafkaRead() {
+    return super.createKafkaRead().withConsumerFactoryFn(this::mkMockConsumer);
+  }
+
+  public void addRecord(KafkaTestRecord record) {
+    records.add(record);
+  }
+
+  @Override
+  double computeRate(int numberOfRecords) throws NoEstimationException {
+    return super.computeRate(mkMockConsumer(new HashMap<>()), numberOfRecords);
+  }
+
+  public void setNumberOfRecordsForRate(int numberOfRecordsForRate) {
+    this.numberOfRecordsForRate = numberOfRecordsForRate;
+  }
+
+  private MockConsumer<byte[], byte[]> mkMockConsumer(Map<String, Object> config) {
+    OffsetResetStrategy offsetResetStrategy = OffsetResetStrategy.EARLIEST;
+    final Map<TopicPartition, List<ConsumerRecord<byte[], byte[]>>> kafkaRecords = new HashMap<>();
+    Map<String, List<PartitionInfo>> partitionInfoMap = new HashMap<>();
+    Map<String, List<TopicPartition>> partitionMap = new HashMap<>();
+
+    // Create topic partitions.
+    for (String topic : this.getTopics()) {
+      List<PartitionInfo> partIds = new ArrayList<>(partitionsPerTopic);
+      List<TopicPartition> topicPartitions = new ArrayList<>(partitionsPerTopic);
+      for (int i = 0; i < partitionsPerTopic; i++) {
+        TopicPartition tp = new TopicPartition(topic, i);
+        topicPartitions.add(tp);
+        partIds.add(new PartitionInfo(topic, i, null, null, null));
+        kafkaRecords.put(tp, new ArrayList<>());
+      }
+      partitionInfoMap.put(topic, partIds);
+      partitionMap.put(topic, topicPartitions);
+    }
+
+    TimestampType timestampType =
+        TimestampType.forName(
+            (String)
+                config.getOrDefault(
+                    TIMESTAMP_TYPE_CONFIG, TimestampType.LOG_APPEND_TIME.toString()));
+
+    for (KafkaTestRecord record : this.records) {
+      // Use floorMod so that a negative hashCode cannot produce a negative partition index.
+      int partitionIndex = Math.floorMod(record.getKey().hashCode(), partitionsPerTopic);
+      TopicPartition tp = partitionMap.get(record.getTopic()).get(partitionIndex);
+      byte[] key = record.getKey().getBytes(UTF_8);
+      byte[] value = record.getValue().getBytes(UTF_8);
+      kafkaRecords
+          .get(tp)
+          .add(
+              new ConsumerRecord<>(
+                  tp.topic(),
+                  tp.partition(),
+                  kafkaRecords.get(tp).size(),
+                  record.getTimeStamp(),
+                  timestampType,
+                  0,
+                  key.length,
+                  value.length,
+                  key,
+                  value));
+    }
+
+    // This is updated when the reader assigns partitions.
+    final AtomicReference<List<TopicPartition>> assignedPartitions =
+        new AtomicReference<>(Collections.<TopicPartition>emptyList());
+    final MockConsumer<byte[], byte[]> consumer =
+        new MockConsumer<byte[], byte[]>(offsetResetStrategy) {
+          @Override
+          public synchronized void assign(final Collection<TopicPartition> assigned) {
+            Collection<TopicPartition> realPartitions =
+                assigned.stream()
+                    .map(part -> partitionMap.get(part.topic()).get(part.partition()))
+                    .collect(Collectors.toList());
+            super.assign(realPartitions);
+            assignedPartitions.set(ImmutableList.copyOf(realPartitions));
+            for (TopicPartition tp : realPartitions) {
+              updateBeginningOffsets(ImmutableMap.of(tp, 0L));
+              updateEndOffsets(ImmutableMap.of(tp, (long) kafkaRecords.get(tp).size()));
+            }
+          }
+
+          // Override offsetsForTimes() in order to look up offsets by timestamp.
+          @Override
+          public synchronized Map<TopicPartition, OffsetAndTimestamp> offsetsForTimes(
+              Map<TopicPartition, Long> timestampsToSearch) {
+            return timestampsToSearch.entrySet().stream()
+                .map(
+                    e -> {
+                      // In test scope, the requested timestamp is treated as an offset.
+                      long maxOffset = kafkaRecords.get(e.getKey()).size();
+                      long offset = e.getValue();
+                      OffsetAndTimestamp value =
+                          (offset >= maxOffset) ? null : new OffsetAndTimestamp(offset, offset);
+                      return new AbstractMap.SimpleEntry<>(e.getKey(), value);
+                    })
+                .collect(
+                    Collectors.toMap(
+                        AbstractMap.SimpleEntry::getKey, AbstractMap.SimpleEntry::getValue));
+          }
+        };
+
+    for (String topic : getTopics()) {
+      consumer.updatePartitions(topic, partitionInfoMap.get(topic));
+    }
+
+    Runnable recordEnqueueTask =
+        new Runnable() {
+          @Override
+          public void run() {
+            // Add all the records with offset >= the current partition position.
+            int recordsAdded = 0;
+            for (TopicPartition tp : assignedPartitions.get()) {
+              long curPos = consumer.position(tp);
+              for (ConsumerRecord<byte[], byte[]> r : kafkaRecords.get(tp)) {
+                if (r.offset() >= curPos) {
+                  consumer.addRecord(r);
+                  recordsAdded++;
+                }
+              }
+            }
+            if (recordsAdded == 0) {
+              if (config.get("inject.error.at.eof") != null) {
+                consumer.setException(new KafkaException("Injected error in consumer.poll()"));
+              }
+              // MockConsumer.poll(timeout) does not actually wait even when there aren't any
+              // records. Add a small wait here in order to avoid busy looping in the reader.
+              Uninterruptibles.sleepUninterruptibly(10, TimeUnit.MILLISECONDS);
+            }
+            consumer.schedulePollTask(this);
+          }
+        };
+
+    consumer.schedulePollTask(recordEnqueueTask);
+
+    return consumer;
+  }
+}
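[Editor's sketch] The factory above builds its fake partitions on top of kafka-clients' MockConsumer, which hands records to poll() without a broker. Below is a minimal sketch of that API in isolation, using only calls that also appear in the patch; the topic name, key, and value literals are placeholders invented for this sketch.

    import static java.nio.charset.StandardCharsets.UTF_8;

    import java.util.Collections;
    import org.apache.kafka.clients.consumer.ConsumerRecord;
    import org.apache.kafka.clients.consumer.ConsumerRecords;
    import org.apache.kafka.clients.consumer.MockConsumer;
    import org.apache.kafka.clients.consumer.OffsetResetStrategy;
    import org.apache.kafka.common.TopicPartition;

    public class MockConsumerSketch {
      public static void main(String[] args) {
        MockConsumer<byte[], byte[]> consumer = new MockConsumer<>(OffsetResetStrategy.EARLIEST);
        TopicPartition tp = new TopicPartition("topic1", 0);

        // Assign the partition and tell the mock where it begins; without beginning
        // offsets, poll() cannot establish a position.
        consumer.assign(Collections.singletonList(tp));
        consumer.updateBeginningOffsets(Collections.singletonMap(tp, 0L));

        // Records are handed to the mock directly instead of coming from a broker.
        consumer.addRecord(
            new ConsumerRecord<>("topic1", 0, 0L, "k".getBytes(UTF_8), "v".getBytes(UTF_8)));

        ConsumerRecords<byte[], byte[]> records = consumer.poll(100);
        System.out.println(records.count()); // prints 1
      }
    }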
diff --git a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
new file mode 100644
index 0000000..015ac8b
--- /dev/null
+++ b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/meta/provider/kafka/KafkaTestRecord.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.extensions.sql.meta.provider.kafka;
+
+import com.google.auto.value.AutoValue;
+import java.io.Serializable;
+
+/** This class exists because Kafka ConsumerRecords are not serializable. */
+@AutoValue
+public abstract class KafkaTestRecord implements Serializable {
+
+  public abstract String getKey();
+
+  public abstract String getValue();
+
+  public abstract String getTopic();
+
+  public abstract long getTimeStamp();
+
+  public static KafkaTestRecord create(
+      String newKey, String newValue, String newTopic, long newTimeStamp) {
+    return new AutoValue_KafkaTestRecord(newKey, newValue, newTopic, newTimeStamp);
+  }
+}