This is an automated email from the ASF dual-hosted git repository.

yhu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c8c674e1c04 Add support for dynamic write in `MqttIO` (#32470)
c8c674e1c04 is described below

commit c8c674e1c0462fa3f70c6af75d07d1ee4f5f664e
Author: twosom <[email protected]>
AuthorDate: Wed Oct 2 00:31:10 2024 +0900

    Add support for dynamic write in `MqttIO` (#32470)
    
    * add support for dynamic write in MqttIO
    
    * Update CHANGES.md
    
    * add some assertions in testDynamicWrite
    
    * remove whitespace in CHANGES.md
    
    * refactor duplicated Write transform
    
    * change WriteFn to use Write spec
---
 CHANGES.md                                         |   2 +
 .../java/org/apache/beam/sdk/io/mqtt/MqttIO.java   | 139 ++++++++++---
 .../org/apache/beam/sdk/io/mqtt/MqttIOTest.java    | 220 +++++++++++++++++++++
 3 files changed, 332 insertions(+), 29 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 728c8247acc..d92639d626b 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -70,6 +70,8 @@
 * Prism
   * Prism now supports Bundle Finalization. ([#32425](https://github.com/apache/beam/pull/32425))
 * Significantly improved performance of Kafka IO reads that enable [commitOffsetsInFinalize](https://beam.apache.org/releases/javadoc/current/org/apache/beam/sdk/io/kafka/KafkaIO.Read.html#commitOffsetsInFinalize--) by removing the data reshuffle from SDF implementation.  ([#31682](https://github.com/apache/beam/pull/31682)).
+* Added support for dynamic writing in MqttIO (Java) ([#19376](https://github.com/apache/beam/issues/19376))
+* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
 
 ## Breaking Changes
 
diff --git a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
index 0e584d564b5..e1868e2c846 100644
--- a/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
+++ b/sdks/java/io/mqtt/src/main/java/org/apache/beam/sdk/io/mqtt/MqttIO.java
@@ -39,6 +39,8 @@ import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.transforms.DoFn;
 import org.apache.beam.sdk.transforms.PTransform;
 import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.SerializableFunction;
+import org.apache.beam.sdk.transforms.SerializableFunctions;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
@@ -99,6 +101,26 @@ import org.slf4j.LoggerFactory;
  *       "my_topic"))
  *
  * }</pre>
+ *
+ * <h3>Dynamic Writing to a MQTT Broker</h3>
+ *
+ * <p>MqttIO also supports dynamic writing to multiple topics based on the data. You can specify a
+ * function to determine the target topic for each message. The following example demonstrates how
+ * to configure dynamic topic writing:
+ *
+ * <pre>{@code
+ * pipeline
+ *   .apply(...)  // Provide PCollection<InputT>
+ *   .apply(
+ *     MqttIO.<InputT>dynamicWrite()
+ *       .withConnectionConfiguration(
+ *         MqttIO.ConnectionConfiguration.create("tcp://host:11883"))
+ *       .withTopicFn(<Function to determine the topic dynamically>)
+ *       .withPayloadFn(<Function to extract the payload>));
+ * }</pre>
+ *
+ * <p>This dynamic writing capability allows for more flexible MQTT message routing based on the
+ * message content, enabling scenarios where messages are directed to different topics.
  */
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
@@ -115,8 +137,16 @@ public class MqttIO {
         .build();
   }
 
-  public static Write write() {
-    return new AutoValue_MqttIO_Write.Builder().setRetained(false).build();
+  public static Write<byte[]> write() {
+    return new AutoValue_MqttIO_Write.Builder<byte[]>()
+        .setRetained(false)
+        .setPayloadFn(SerializableFunctions.identity())
+        .setDynamic(false)
+        .build();
+  }
+
+  public static <InputT> Write<InputT> dynamicWrite() {
+    return new AutoValue_MqttIO_Write.Builder<InputT>().setRetained(false).setDynamic(true).build();
   }
 
   private MqttIO() {}
@@ -127,7 +157,7 @@ public class MqttIO {
 
     abstract String getServerUri();
 
-    abstract String getTopic();
+    abstract @Nullable String getTopic();
 
     abstract @Nullable String getClientId();
 
@@ -169,6 +199,11 @@ public class MqttIO {
           .build();
     }
 
+    public static ConnectionConfiguration create(String serverUri) {
+      checkArgument(serverUri != null, "serverUri can not be null");
+      return new AutoValue_MqttIO_ConnectionConfiguration.Builder().setServerUri(serverUri).build();
+    }
+
     /** Set up the MQTT broker URI. */
     public ConnectionConfiguration withServerUri(String serverUri) {
       checkArgument(serverUri != null, "serverUri can not be null");
@@ -199,7 +234,7 @@ public class MqttIO {
 
     private void populateDisplayData(DisplayData.Builder builder) {
       builder.add(DisplayData.item("serverUri", getServerUri()));
-      builder.add(DisplayData.item("topic", getTopic()));
+      builder.addIfNotNull(DisplayData.item("topic", getTopic()));
       builder.addIfNotNull(DisplayData.item("clientId", getClientId()));
       builder.addIfNotNull(DisplayData.item("username", getUsername()));
     }
@@ -278,6 +313,9 @@ public class MqttIO {
 
     @Override
     public PCollection<byte[]> expand(PBegin input) {
+      checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
+      checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");
+
       org.apache.beam.sdk.io.Read.Unbounded<byte[]> unbounded =
           org.apache.beam.sdk.io.Read.from(new UnboundedMqttSource(this));
 
@@ -505,29 +543,50 @@ public class MqttIO {
 
   /** A {@link PTransform} to write and send a message to a MQTT server. */
   @AutoValue
-  public abstract static class Write extends PTransform<PCollection<byte[]>, PDone> {
-
+  public abstract static class Write<InputT> extends PTransform<PCollection<InputT>, PDone> {
     abstract @Nullable ConnectionConfiguration connectionConfiguration();
 
+    abstract @Nullable SerializableFunction<InputT, String> topicFn();
+
+    abstract @Nullable SerializableFunction<InputT, byte[]> payloadFn();
+
+    abstract boolean dynamic();
+
     abstract boolean retained();
 
-    abstract Builder builder();
+    abstract Builder<InputT> builder();
 
     @AutoValue.Builder
-    abstract static class Builder {
-      abstract Builder setConnectionConfiguration(ConnectionConfiguration configuration);
+    abstract static class Builder<InputT> {
+      abstract Builder<InputT> setConnectionConfiguration(ConnectionConfiguration configuration);
+
+      abstract Builder<InputT> setRetained(boolean retained);
+
+      abstract Builder<InputT> setTopicFn(SerializableFunction<InputT, String> topicFn);
 
-      abstract Builder setRetained(boolean retained);
+      abstract Builder<InputT> setPayloadFn(SerializableFunction<InputT, byte[]> payloadFn);
 
-      abstract Write build();
+      abstract Builder<InputT> setDynamic(boolean dynamic);
+
+      abstract Write<InputT> build();
     }
 
     /** Define MQTT connection configuration used to connect to the MQTT broker. */
-    public Write withConnectionConfiguration(ConnectionConfiguration configuration) {
+    public Write<InputT> withConnectionConfiguration(ConnectionConfiguration configuration) {
       checkArgument(configuration != null, "configuration can not be null");
       return builder().setConnectionConfiguration(configuration).build();
     }
 
+    public Write<InputT> withTopicFn(SerializableFunction<InputT, String> topicFn) {
+      checkArgument(dynamic(), "withTopicFn can not use in non-dynamic write");
+      return builder().setTopicFn(topicFn).build();
+    }
+
+    public Write<InputT> withPayloadFn(SerializableFunction<InputT, byte[]> payloadFn) {
+      checkArgument(dynamic(), "withPayloadFn can not use in non-dynamic write");
+      return builder().setPayloadFn(payloadFn).build();
+    }
+
     /**
      * Whether or not the publish message should be retained by the messaging engine. Sending a
      * message with the retained set to {@code false} will clear the retained message from the
@@ -538,54 +597,76 @@ public class MqttIO {
      * @param retained Whether or not the messaging engine should retain the message.
      * @return The {@link Write} {@link PTransform} with the corresponding retained configuration.
      */
-    public Write withRetained(boolean retained) {
+    public Write<InputT> withRetained(boolean retained) {
       return builder().setRetained(retained).build();
     }
 
-    @Override
-    public PDone expand(PCollection<byte[]> input) {
-      input.apply(ParDo.of(new WriteFn(this)));
-      return PDone.in(input.getPipeline());
-    }
-
     @Override
     public void populateDisplayData(DisplayData.Builder builder) {
       connectionConfiguration().populateDisplayData(builder);
       builder.add(DisplayData.item("retained", retained()));
     }
 
-    private static class WriteFn extends DoFn<byte[], Void> {
+    @Override
+    public PDone expand(PCollection<InputT> input) {
+      checkArgument(connectionConfiguration() != null, "connectionConfiguration can not be null");
+      if (dynamic()) {
+        checkArgument(
+            connectionConfiguration().getTopic() == null, "DynamicWrite can not have static topic");
+        checkArgument(topicFn() != null, "topicFn can not be null");
+      } else {
+        checkArgument(connectionConfiguration().getTopic() != null, "topic can not be null");
+      }
+      checkArgument(payloadFn() != null, "payloadFn can not be null");
+
+      input.apply(ParDo.of(new WriteFn<>(this)));
+      return PDone.in(input.getPipeline());
+    }
+
+    private static class WriteFn<InputT> extends DoFn<InputT, Void> {
 
-      private final Write spec;
+      private final Write<InputT> spec;
+      private final SerializableFunction<InputT, String> topicFn;
+      private final SerializableFunction<InputT, byte[]> payloadFn;
+      private final boolean retained;
 
       private transient MQTT client;
       private transient BlockingConnection connection;
 
-      public WriteFn(Write spec) {
+      public WriteFn(Write<InputT> spec) {
         this.spec = spec;
+        if (spec.dynamic()) {
+          this.topicFn = spec.topicFn();
+        } else {
+          String topic = spec.connectionConfiguration().getTopic();
+          this.topicFn = ignore -> topic;
+        }
+        this.payloadFn = spec.payloadFn();
+        this.retained = spec.retained();
       }
 
       @Setup
       public void createMqttClient() throws Exception {
         LOG.debug("Starting MQTT writer");
-        client = spec.connectionConfiguration().createClient();
+        this.client = this.spec.connectionConfiguration().createClient();
         LOG.debug("MQTT writer client ID is {}", client.getClientId());
-        connection = createConnection(client);
+        this.connection = createConnection(client);
       }
 
       @ProcessElement
       public void processElement(ProcessContext context) throws Exception {
-        byte[] payload = context.element();
+        InputT element = context.element();
+        byte[] payload = this.payloadFn.apply(element);
+        String topic = this.topicFn.apply(element);
         LOG.debug("Sending message {}", new String(payload, StandardCharsets.UTF_8));
-        connection.publish(
-            spec.connectionConfiguration().getTopic(), payload, QoS.AT_LEAST_ONCE, false);
+        this.connection.publish(topic, payload, QoS.AT_LEAST_ONCE, this.retained);
       }
 
       @Teardown
       public void closeMqttClient() throws Exception {
-        if (connection != null) {
+        if (this.connection != null) {
           LOG.debug("Disconnecting MQTT connection (client ID {})", client.getClientId());
-          connection.disconnect();
+          this.connection.disconnect();
         }
       }
     }
diff --git a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
index 7d60d6d6578..8dfa7838d66 100644
--- a/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
+++ b/sdks/java/io/mqtt/src/test/java/org/apache/beam/sdk/io/mqtt/MqttIOTest.java
@@ -18,6 +18,7 @@
 package org.apache.beam.sdk.io.mqtt;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.assertTrue;
 
 import java.io.ByteArrayInputStream;
@@ -26,16 +27,25 @@ import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
+import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.ConcurrentSkipListSet;
 import org.apache.activemq.broker.BrokerService;
 import org.apache.activemq.broker.Connection;
+import org.apache.beam.sdk.coders.ByteArrayCoder;
+import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.io.common.NetworkTestHelper;
 import org.apache.beam.sdk.io.mqtt.MqttIO.Read;
 import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.PBegin;
 import org.apache.beam.sdk.values.PCollection;
 import org.fusesource.hawtbuf.Buffer;
 import org.fusesource.mqtt.client.BlockingConnection;
@@ -266,6 +276,216 @@ public class MqttIOTest {
     }
   }
 
+  @Test
+  public void testDynamicWrite() throws Exception {
+    final int numberOfTopic1Count = 100;
+    final int numberOfTopic2Count = 100;
+    final int numberOfTestMessages = numberOfTopic1Count + numberOfTopic2Count;
+
+    MQTT client = new MQTT();
+    client.setHost("tcp://localhost:" + port);
+    final BlockingConnection connection = client.blockingConnection();
+    connection.connect();
+    final String writeTopic1 = "WRITE_TOPIC_1";
+    final String writeTopic2 = "WRITE_TOPIC_2";
+    connection.subscribe(
+        new Topic[] {
+          new Topic(Buffer.utf8(writeTopic1), QoS.EXACTLY_ONCE),
+          new Topic(Buffer.utf8(writeTopic2), QoS.EXACTLY_ONCE)
+        });
+
+    final Map<String, List<String>> messageMap = new ConcurrentSkipListMap<>();
+    final Thread subscriber =
+        new Thread(
+            () -> {
+              try {
+                for (int i = 0; i < numberOfTestMessages; i++) {
+                  Message message = connection.receive();
+                  List<String> messages = messageMap.get(message.getTopic());
+                  if (messages == null) {
+                    messages = new ArrayList<>();
+                  }
+                  messages.add(new String(message.getPayload(), StandardCharsets.UTF_8));
+                  messageMap.put(message.getTopic(), messages);
+                  message.ack();
+                }
+              } catch (Exception e) {
+                LOG.error("Can't receive message", e);
+              }
+            });
+
+    subscriber.start();
+
+    ArrayList<KV<String, byte[]>> data = new ArrayList<>();
+    for (int i = 0; i < numberOfTopic1Count; i++) {
+      data.add(KV.of(writeTopic1, ("Test" + i).getBytes(StandardCharsets.UTF_8)));
+    }
+
+    for (int i = 0; i < numberOfTopic2Count; i++) {
+      data.add(KV.of(writeTopic2, ("Test" + i).getBytes(StandardCharsets.UTF_8)));
+    }
+
+    pipeline
+        .apply(Create.of(data))
+        .setCoder(KvCoder.of(StringUtf8Coder.of(), ByteArrayCoder.of()))
+        .apply(
+            MqttIO.<KV<String, byte[]>>dynamicWrite()
+                .withConnectionConfiguration(
+                    MqttIO.ConnectionConfiguration.create("tcp://localhost:" + port)
+                        .withClientId("READ_PIPELINE"))
+                .withTopicFn(input -> input.getKey())
+                .withPayloadFn(input -> input.getValue()));
+
+    pipeline.run();
+    subscriber.join();
+
+    connection.disconnect();
+
+    assertEquals(
+        numberOfTestMessages, messageMap.values().stream().mapToLong(Collection::size).sum());
+
+    assertEquals(2, messageMap.keySet().size());
+    assertTrue(messageMap.containsKey(writeTopic1));
+    assertTrue(messageMap.containsKey(writeTopic2));
+    for (Map.Entry<String, List<String>> entry : messageMap.entrySet()) {
+      final List<String> messages = entry.getValue();
+      messages.forEach(message -> assertTrue(message.contains("Test")));
+      if (entry.getKey().equals(writeTopic1)) {
+        assertEquals(numberOfTopic1Count, messages.size());
+      } else {
+        assertEquals(numberOfTopic2Count, messages.size());
+      }
+    }
+  }
+
+  @Test
+  public void testReadHaveNoConnectionConfiguration() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class, () -> MqttIO.read().expand(PBegin.in(pipeline)));
+
+    assertEquals("connectionConfiguration can not be null", exception.getMessage());
+  }
+
+  @Test
+  public void testReadHaveNoTopic() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                MqttIO.read()
+                    .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+                    .expand(PBegin.in(pipeline)));
+
+    assertEquals("topic can not be null", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testWriteHaveNoConnectionConfiguration() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> MqttIO.write().expand(pipeline.apply(Create.of(new byte[] {}))));
+
+    assertEquals("connectionConfiguration can not be null", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testWriteHaveNoTopic() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                MqttIO.write()
+                    .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+                    .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+    assertEquals("topic can not be null", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testDynamicWriteHaveNoConnectionConfiguration() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () -> MqttIO.dynamicWrite().expand(pipeline.apply(Create.of(new byte[] {}))));
+
+    assertEquals("connectionConfiguration can not be null", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testDynamicWriteHaveNoTopicFn() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                MqttIO.dynamicWrite()
+                    .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+                    .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+    assertEquals("topicFn can not be null", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testDynamicWriteHaveNoPayloadFn() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                MqttIO.dynamicWrite()
+                    .withConnectionConfiguration(MqttIO.ConnectionConfiguration.create("serverUri"))
+                    .withTopicFn(input -> "topic")
+                    .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+    assertEquals("payloadFn can not be null", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testDynamicWriteHaveStaticTopic() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class,
+            () ->
+                MqttIO.dynamicWrite()
+                    .withConnectionConfiguration(
+                        MqttIO.ConnectionConfiguration.create("serverUri", "topic"))
+                    .expand(pipeline.apply(Create.of(new byte[] {}))));
+
+    assertEquals("DynamicWrite can not have static topic", exception.getMessage());
+
+    pipeline.run();
+  }
+
+  @Test
+  public void testWriteWithTopicFn() {
+    IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class, () -> MqttIO.write().withTopicFn(e -> "some topic"));
+
+    assertEquals("withTopicFn can not use in non-dynamic write", exception.getMessage());
+  }
+
+  @Test
+  public void testWriteWithPayloadFn() {
+    final IllegalArgumentException exception =
+        assertThrows(
+            IllegalArgumentException.class, () -> MqttIO.write().withPayloadFn(e -> new byte[] {}));
+
+    assertEquals("withPayloadFn can not use in non-dynamic write", exception.getMessage());
+  }
+
   @Test
   public void testReadObject() throws Exception {
     ByteArrayOutputStream bos = new ByteArrayOutputStream();

Reply via email to