Repository: samza Updated Branches: refs/heads/master ffd04d9d6 -> 5910ea669
SAMZA-1091; Implement key-based inner join operator with no time constraints Author: Prateek Maheshwari <pmahe...@linkedin.com> Author: Prateek Maheshwari <prate...@utexas.edu> Reviewers: Jagadish Venkatraman <jagad...@apache.org>, Yi Pan <nickpa...@gmail.com> Closes #60 from prateekm/master Project: http://git-wip-us.apache.org/repos/asf/samza/repo Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/5910ea66 Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/5910ea66 Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/5910ea66 Branch: refs/heads/master Commit: 5910ea66970e1e26d0107e4d2e0ef065a59b3af0 Parents: ffd04d9 Author: Prateek Maheshwari <pmahe...@linkedin.com> Authored: Tue Mar 14 13:44:54 2017 -0700 Committer: vjagadish1989 <jvenk...@linkedin.com> Committed: Tue Mar 14 13:44:54 2017 -0700 ---------------------------------------------------------------------- .../apache/samza/operators/MessageStream.java | 7 +- .../samza/operators/MessageStreamImpl.java | 70 ++++-- .../functions/PartialJoinFunction.java | 45 ++-- .../samza/operators/impl/OperatorGraph.java | 17 +- .../samza/operators/impl/OperatorImpl.java | 27 +- .../operators/impl/PartialJoinOperatorImpl.java | 64 ++++- .../samza/operators/impl/RootOperatorImpl.java | 2 +- .../samza/operators/spec/OperatorSpecs.java | 17 +- .../operators/spec/PartialJoinOperatorSpec.java | 54 ++-- .../operators/util/InternalInMemoryStore.java | 122 +++++++++ .../apache/samza/task/StreamOperatorTask.java | 2 +- .../samza/example/NoContextStreamExample.java | 3 +- .../samza/example/OrderShipmentJoinExample.java | 7 +- .../apache/samza/example/TestJoinExample.java | 3 +- .../samza/operators/TestJoinOperator.java | 250 +++++++++++++++++++ .../samza/operators/TestMessageStreamImpl.java | 6 +- .../samza/operators/impl/TestOperatorImpls.java | 6 +- .../samza/operators/spec/TestOperatorSpecs.java | 35 +-- 18 files changed, 616 insertions(+), 121 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java ---------------------------------------------------------------------- diff --git a/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java b/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java index adeb4c8..16c5976 100644 --- a/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java +++ b/samza-api/src/main/java/org/apache/samza/operators/MessageStream.java @@ -27,6 +27,7 @@ import org.apache.samza.operators.functions.SinkFunction; import org.apache.samza.operators.windows.Window; import org.apache.samza.operators.windows.WindowPane; +import java.time.Duration; import java.util.Collection; import java.util.function.Function; @@ -109,16 +110,18 @@ public interface MessageStream<M> { /** * Joins this {@link MessageStream} with another {@link MessageStream} using the provided pairwise {@link JoinFunction}. * <p> - * We currently only support 2-way joins. + * Messages in each stream are retained (currently, in memory) for the provided {@code ttl} and join results are + * emitted as matches are found. * * @param otherStream the other {@link MessageStream} to be joined with * @param joinFn the function to join messages from this and the other {@link MessageStream} + * @param ttl the ttl for messages in each stream * @param <K> the type of join key * @param <OM> the type of messages in the other stream * @param <RM> the type of messages resulting from the {@code joinFn} * @return the joined {@link MessageStream} */ - <K, OM, RM> MessageStream<RM> join(MessageStream<OM> otherStream, JoinFunction<K, M, OM, RM> joinFn); + <K, OM, RM> MessageStream<RM> join(MessageStream<OM> otherStream, JoinFunction<K, M, OM, RM> joinFn, Duration ttl); /** * Merge all {@code otherStreams} with this {@link MessageStream}. http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java b/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java index b22f199..339df7a 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/MessageStreamImpl.java @@ -19,11 +19,6 @@ package org.apache.samza.operators; -import java.util.Collection; -import java.util.Collections; -import java.util.function.Function; -import java.util.HashSet; -import java.util.Set; import org.apache.samza.config.Config; import org.apache.samza.operators.functions.FilterFunction; import org.apache.samza.operators.functions.FlatMapFunction; @@ -33,11 +28,20 @@ import org.apache.samza.operators.functions.PartialJoinFunction; import org.apache.samza.operators.functions.SinkFunction; import org.apache.samza.operators.spec.OperatorSpec; import org.apache.samza.operators.spec.OperatorSpecs; +import org.apache.samza.operators.util.InternalInMemoryStore; import org.apache.samza.operators.windows.Window; import org.apache.samza.operators.windows.WindowPane; import org.apache.samza.operators.windows.internal.WindowInternal; +import org.apache.samza.storage.kv.KeyValueStore; import org.apache.samza.task.TaskContext; +import java.time.Duration; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.function.Function; + /** * The implementation for input/output {@link MessageStream}s to/from the operators. @@ -65,13 +69,15 @@ public class MessageStreamImpl<M> implements MessageStream<M> { this.graph = graph; } - @Override public <TM> MessageStream<TM> map(MapFunction<M, TM> mapFn) { + @Override + public <TM> MessageStream<TM> map(MapFunction<M, TM> mapFn) { OperatorSpec<TM> op = OperatorSpecs.<M, TM>createMapOperatorSpec(mapFn, this.graph, new MessageStreamImpl<>(this.graph)); this.registeredOperatorSpecs.add(op); return op.getNextStream(); } - @Override public MessageStream<M> filter(FilterFunction<M> filterFn) { + @Override + public MessageStream<M> filter(FilterFunction<M> filterFn) { OperatorSpec<M> op = OperatorSpecs.<M>createFilterOperatorSpec(filterFn, this.graph, new MessageStreamImpl<>(this.graph)); this.registeredOperatorSpecs.add(op); return op.getNextStream(); @@ -89,7 +95,8 @@ public class MessageStreamImpl<M> implements MessageStream<M> { this.registeredOperatorSpecs.add(OperatorSpecs.createSinkOperatorSpec(sinkFn, this.graph)); } - @Override public void sendTo(OutputStream<M> stream) { + @Override + public void sendTo(OutputStream<M> stream) { this.registeredOperatorSpecs.add(OperatorSpecs.createSendToOperatorSpec(stream.getSinkFunction(), this.graph, stream)); } @@ -101,13 +108,16 @@ public class MessageStreamImpl<M> implements MessageStream<M> { return wndOp.getNextStream(); } - @Override public <K, OM, RM> MessageStream<RM> join(MessageStream<OM> otherStream, JoinFunction<K, M, OM, RM> joinFn) { + @Override + public <K, JM, RM> MessageStream<RM> join(MessageStream<JM> otherStream, JoinFunction<K, M, JM, RM> joinFn, Duration ttl) { MessageStreamImpl<RM> outputStream = new MessageStreamImpl<>(this.graph); - PartialJoinFunction<K, M, OM, RM> parJoin1 = new PartialJoinFunction<K, M, OM, RM>() { + PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn = new PartialJoinFunction<K, M, JM, RM>() { + private KeyValueStore<K, PartialJoinMessage<M>> thisStreamState; + @Override - public RM apply(M m1, OM om) { - return joinFn.apply(m1, om); + public RM apply(M m, JM jm) { + return joinFn.apply(m, jm); } @Override @@ -116,38 +126,49 @@ public class MessageStreamImpl<M> implements MessageStream<M> { } @Override - public K getOtherKey(OM message) { - return joinFn.getSecondKey(message); + public KeyValueStore<K, PartialJoinMessage<M>> getState() { + return thisStreamState; } @Override public void init(Config config, TaskContext context) { + // joinFn#init() must only be called once, so we do it in this partial join function's #init. joinFn.init(config, context); + + thisStreamState = new InternalInMemoryStore<>(); } }; - PartialJoinFunction<K, OM, M, RM> parJoin2 = new PartialJoinFunction<K, OM, M, RM>() { + PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn = new PartialJoinFunction<K, JM, M, RM>() { + private KeyValueStore<K, PartialJoinMessage<JM>> otherStreamState; + @Override - public RM apply(OM m1, M m) { - return joinFn.apply(m, m1); + public RM apply(JM om, M m) { + return joinFn.apply(m, om); } @Override - public K getKey(OM message) { + public K getKey(JM message) { return joinFn.getSecondKey(message); } @Override - public K getOtherKey(M message) { - return joinFn.getFirstKey(message); + public KeyValueStore<K, PartialJoinMessage<JM>> getState() { + return otherStreamState; + } + + @Override + public void init(Config config, TaskContext taskContext) { + otherStreamState = new InternalInMemoryStore<>(); } }; - // TODO: need to add default store functions for the two partial join functions + this.registeredOperatorSpecs.add(OperatorSpecs.<K, M, JM, RM>createPartialJoinOperatorSpec( + thisPartialJoinFn, otherPartialJoinFn, ttl.toMillis(), this.graph, outputStream)); + + ((MessageStreamImpl<JM>) otherStream).registeredOperatorSpecs.add(OperatorSpecs.<K, JM, M, RM>createPartialJoinOperatorSpec( + otherPartialJoinFn, thisPartialJoinFn, ttl.toMillis(), this.graph, outputStream)); - ((MessageStreamImpl<OM>) otherStream).registeredOperatorSpecs.add( - OperatorSpecs.<OM, K, M, RM>createPartialJoinOperatorSpec(parJoin2, this.graph, outputStream)); - this.registeredOperatorSpecs.add(OperatorSpecs.<M, K, OM, RM>createPartialJoinOperatorSpec(parJoin1, this.graph, outputStream)); return outputStream; } @@ -169,6 +190,7 @@ public class MessageStreamImpl<M> implements MessageStream<M> { this.registeredOperatorSpecs.add(OperatorSpecs.createPartitionOperatorSpec(outputStream.getSinkFunction(), outputStream, opId)); return intStream; } + /** * Gets the operator specs registered to consume the output of this {@link MessageStream}. This is an internal API and * should not be exposed to users. http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/functions/PartialJoinFunction.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/functions/PartialJoinFunction.java b/samza-core/src/main/java/org/apache/samza/operators/functions/PartialJoinFunction.java index 809a70a..a961830 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/functions/PartialJoinFunction.java +++ b/samza-core/src/main/java/org/apache/samza/operators/functions/PartialJoinFunction.java @@ -18,39 +18,52 @@ */ package org.apache.samza.operators.functions; -import org.apache.samza.annotation.InterfaceStability; - +import org.apache.samza.storage.kv.KeyValueStore; /** - * This defines the interface function a two-way join functions that takes input messages from two input - * {@link org.apache.samza.operators.MessageStream}s and merge them into a single output joined message in the join output + * An internal function that maintains state and join logic for one side of a two-way join. */ -@InterfaceStability.Unstable -public interface PartialJoinFunction<K, M, OM, RM> extends InitableFunction { +public interface PartialJoinFunction<K, M, JM, RM> extends InitableFunction { /** - * Method to perform join method on the two input messages + * Joins a message in this stream with a message from another stream. * - * @param m1 message from the first input stream - * @param om message from the second input stream + * @param m message from this input stream + * @param jm message from the other input stream * @return the joined message in the output stream */ - RM apply(M m1, OM om); + RM apply(M m, JM jm); /** - * Method to get the key from the input message + * Gets the key for the input message. * - * @param message the input message from the first strean + * @param message the input message from the first stream * @return the join key in the {@code message} */ K getKey(M message); /** - * Method to get the key from the input message in the other stream + * Gets the state associated with this stream. * - * @param message the input message from the other stream - * @return the join key in the {@code message} + * @return the key value store containing the state for this stream */ - K getOtherKey(OM message); + KeyValueStore<K, PartialJoinMessage<M>> getState(); + + class PartialJoinMessage<M> { + private final M message; + private final long receivedTimeMs; + + public PartialJoinMessage(M message, long receivedTimeMs) { + this.message = message; + this.receivedTimeMs = receivedTimeMs; + } + + public M getMessage() { + return message; + } + public long getReceivedTimeMs() { + return receivedTimeMs; + } + } } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorGraph.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorGraph.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorGraph.java index 66336f8..3efd5f5 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorGraph.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorGraph.java @@ -29,6 +29,7 @@ import org.apache.samza.system.SystemStream; import org.apache.samza.task.TaskContext; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -65,17 +66,25 @@ public class OperatorGraph { } /** - * Method to get the corresponding {@link RootOperatorImpl} + * Get the {@link RootOperatorImpl} corresponding to the provided {@code ss}. * * @param ss input {@link SystemStream} - * @param <M> the type of input message - * @return the {@link OperatorImpl} that starts processing the input message + * @return the {@link RootOperatorImpl} that starts processing the input message */ - public <M> OperatorImpl<M, M> get(SystemStream ss) { + public RootOperatorImpl get(SystemStream ss) { return this.operatorGraph.get(ss); } /** + * Get all {@link RootOperatorImpl}s for the graph. + * + * @return an immutable view of all {@link RootOperatorImpl}s for the graph + */ + public Collection<RootOperatorImpl> getAll() { + return Collections.unmodifiableCollection(this.operatorGraph.values()); + } + + /** * Traverses the DAG of {@link OperatorSpec}s starting from the provided {@link MessageStreamImpl}, * creates the corresponding DAG of {@link OperatorImpl}s, and returns its root {@link RootOperatorImpl} node. * http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java index abb1fa9..9983307 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/OperatorImpl.java @@ -43,7 +43,7 @@ public abstract class OperatorImpl<M, RM> { /** * Perform the transformation required for this operator and call the downstream operators. * - * Must call {@link #propagateResult} to propage the output to registered downstream operators correctly. + * Must call {@link #propagateResult} to propagate the output to registered downstream operators correctly. * * @param message the input message * @param collector the {@link MessageCollector} in the context @@ -52,6 +52,19 @@ public abstract class OperatorImpl<M, RM> { public abstract void onNext(M message, MessageCollector collector, TaskCoordinator coordinator); /** + * Perform the actions required on a timer tick and call the downstream operators. + * + * Overriding implementations must call {@link #propagateTimer} to propagate the timer tick to registered + * downstream operators correctly. + * + * @param collector the {@link MessageCollector} in the context + * @param coordinator the {@link TaskCoordinator} in the context + */ + public void onTimer(MessageCollector collector, TaskCoordinator coordinator) { + propagateTimer(collector, coordinator); + } + + /** * Helper method to propagate the output of this operator to all registered downstream operators. * * This method <b>must</b> be called from {@link #onNext} to propagate the operator output correctly. @@ -64,4 +77,16 @@ public abstract class OperatorImpl<M, RM> { nextOperators.forEach(sub -> sub.onNext(outputMessage, collector, coordinator)); } + /** + * Helper method to propagate the timer tick to all registered downstream operators. + * + * This method <b>must</b> be called from {@link #onTimer} to propagate the timer tick correctly. + * + * @param collector the {@link MessageCollector} in the context + * @param coordinator the {@link TaskCoordinator} in the context + */ + void propagateTimer(MessageCollector collector, TaskCoordinator coordinator) { + nextOperators.forEach(sub -> sub.onTimer(collector, coordinator)); + } + } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java index c8515e1..f704f3f 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/PartialJoinOperatorImpl.java @@ -20,28 +20,80 @@ package org.apache.samza.operators.impl; import org.apache.samza.config.Config; import org.apache.samza.operators.MessageStreamImpl; +import org.apache.samza.operators.functions.PartialJoinFunction; +import org.apache.samza.operators.functions.PartialJoinFunction.PartialJoinMessage; import org.apache.samza.operators.spec.PartialJoinOperatorSpec; +import org.apache.samza.storage.kv.Entry; +import org.apache.samza.storage.kv.KeyValueIterator; +import org.apache.samza.storage.kv.KeyValueStore; import org.apache.samza.task.MessageCollector; import org.apache.samza.task.TaskContext; import org.apache.samza.task.TaskCoordinator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.ArrayList; +import java.util.List; /** - * Implementation of a {@link PartialJoinOperatorSpec}. This class implements function - * that only takes in one input stream among all inputs to the join and generate the join output. + * Implementation of a {@link PartialJoinOperatorSpec} that joins messages of type {@code M} in this stream + * with buffered messages of type {@code JM} in the other stream. * * @param <M> type of messages in the input stream * @param <JM> type of messages in the stream to join with * @param <RM> type of messages in the joined stream */ -class PartialJoinOperatorImpl<M, K, JM, RM> extends OperatorImpl<M, RM> { +class PartialJoinOperatorImpl<K, M, JM, RM> extends OperatorImpl<M, RM> { - PartialJoinOperatorImpl(PartialJoinOperatorSpec<M, K, JM, RM> joinOp, MessageStreamImpl<M> source, Config config, TaskContext context) { - // TODO: implement PartialJoinOperatorImpl constructor + private static final Logger LOGGER = LoggerFactory.getLogger(PartialJoinOperatorImpl.class); + + private final PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn; + private final PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn; + private final long ttlMs; + private final int opId; + + PartialJoinOperatorImpl(PartialJoinOperatorSpec<K, M, JM, RM> partialJoinOperatorSpec, MessageStreamImpl<M> source, + Config config, TaskContext context) { + this.thisPartialJoinFn = partialJoinOperatorSpec.getThisPartialJoinFn(); + this.otherPartialJoinFn = partialJoinOperatorSpec.getOtherPartialJoinFn(); + this.ttlMs = partialJoinOperatorSpec.getTtlMs(); + this.opId = partialJoinOperatorSpec.getOpId(); } @Override public void onNext(M message, MessageCollector collector, TaskCoordinator coordinator) { - // TODO: implement PartialJoinOperatorImpl processing logic + K key = thisPartialJoinFn.getKey(message); + thisPartialJoinFn.getState().put(key, new PartialJoinMessage<>(message, System.currentTimeMillis())); + PartialJoinMessage<JM> otherMessage = otherPartialJoinFn.getState().get(key); + long now = System.currentTimeMillis(); + if (otherMessage != null && otherMessage.getReceivedTimeMs() > now - ttlMs) { + RM joinResult = thisPartialJoinFn.apply(message, otherMessage.getMessage()); + this.propagateResult(joinResult, collector, coordinator); + } + } + + @Override + public void onTimer(MessageCollector collector, TaskCoordinator coordinator) { + long now = System.currentTimeMillis(); + + KeyValueStore<K, PartialJoinMessage<M>> thisState = thisPartialJoinFn.getState(); + KeyValueIterator<K, PartialJoinMessage<M>> iterator = thisState.all(); + List<K> keysToRemove = new ArrayList<>(); + + while (iterator.hasNext()) { + Entry<K, PartialJoinMessage<M>> entry = iterator.next(); + if (entry.getValue().getReceivedTimeMs() < now - ttlMs) { + keysToRemove.add(entry.getKey()); + } else { + break; + } + } + + iterator.close(); + thisState.deleteAll(keysToRemove); + + LOGGER.info("Operator ID {} onTimer self time: {} ms", opId, System.currentTimeMillis() - now); + this.propagateTimer(collector, coordinator); } + } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/impl/RootOperatorImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/impl/RootOperatorImpl.java b/samza-core/src/main/java/org/apache/samza/operators/impl/RootOperatorImpl.java index 4b30a5d..eb9b5e2 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/impl/RootOperatorImpl.java +++ b/samza-core/src/main/java/org/apache/samza/operators/impl/RootOperatorImpl.java @@ -26,7 +26,7 @@ import org.apache.samza.task.TaskCoordinator; * A no-op operator implementation that forwards incoming messages to all of its subscribers. * @param <M> type of incoming messages */ -final class RootOperatorImpl<M> extends OperatorImpl<M, M> { +public final class RootOperatorImpl<M> extends OperatorImpl<M, M> { @Override public void onNext(M message, MessageCollector collector, TaskCoordinator coordinator) { http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java index ae82f9d..a0c7820 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/OperatorSpecs.java @@ -175,18 +175,21 @@ public class OperatorSpecs { /** * Creates a {@link PartialJoinOperatorSpec}. * - * @param partialJoinFn the join function + * @param thisPartialJoinFn the partial join function for this message stream + * @param otherPartialJoinFn the partial join function for the other message stream + * @param ttlMs the ttl in ms for retaining messages in each stream * @param graph the {@link StreamGraphImpl} object * @param joinOutput the output {@link MessageStreamImpl} - * @param <M> type of input message - * @param <K> type of join key + * @param <K> the type of join key + * @param <M> the type of input message * @param <JM> the type of message in the other join stream - * @param <OM> the type of message in the join output + * @param <RM> the type of message in the join output * @return the {@link PartialJoinOperatorSpec} */ - public static <M, K, JM, OM> PartialJoinOperatorSpec<M, K, JM, OM> createPartialJoinOperatorSpec( - PartialJoinFunction<K, M, JM, OM> partialJoinFn, StreamGraphImpl graph, MessageStreamImpl<OM> joinOutput) { - return new PartialJoinOperatorSpec<>(partialJoinFn, joinOutput, graph.getNextOpId()); + public static <K, M, JM, RM> PartialJoinOperatorSpec<K, M, JM, RM> createPartialJoinOperatorSpec( + PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn, PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn, + long ttlMs, StreamGraphImpl graph, MessageStreamImpl<RM> joinOutput) { + return new PartialJoinOperatorSpec<K, M, JM, RM>(thisPartialJoinFn, otherPartialJoinFn, ttlMs, joinOutput, graph.getNextOpId()); } /** http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/spec/PartialJoinOperatorSpec.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/spec/PartialJoinOperatorSpec.java b/samza-core/src/main/java/org/apache/samza/operators/spec/PartialJoinOperatorSpec.java index e057c2b..669895f 100644 --- a/samza-core/src/main/java/org/apache/samza/operators/spec/PartialJoinOperatorSpec.java +++ b/samza-core/src/main/java/org/apache/samza/operators/spec/PartialJoinOperatorSpec.java @@ -28,38 +28,38 @@ import org.apache.samza.task.TaskContext; * Spec for the partial join operator that takes messages from one input stream, joins with buffered * messages from another stream, and produces join results to an output {@link MessageStreamImpl}. * - * @param <M> the type of input message * @param <K> the type of join key + * @param <M> the type of input message * @param <JM> the type of message in the other join stream * @param <RM> the type of message in the join output stream */ -public class PartialJoinOperatorSpec<M, K, JM, RM> implements OperatorSpec<RM> { - - private final MessageStreamImpl<RM> joinOutput; +public class PartialJoinOperatorSpec<K, M, JM, RM> implements OperatorSpec<RM> { - /** - * The transformation function of {@link PartialJoinOperatorSpec} that takes an input message of - * type {@code M}, joins with a stream of buffered messages of type {@code JM} from another stream, - * and generates a joined result message of type {@code RM}. - */ - private final PartialJoinFunction<K, M, JM, RM> transformFn; - - /** - * The unique ID for this operator. - */ + private final PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn; + private final PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn; + private final long ttlMs; + private final MessageStreamImpl<RM> joinOutput; private final int opId; /** * Default constructor for a {@link PartialJoinOperatorSpec}. * - * @param partialJoinFn partial join function that take type {@code M} of input message and join - * w/ type {@code JM} of buffered message from another stream + * @param thisPartialJoinFn partial join function that provides state and the join logic for input messages of + * type {@code M} in this stream + * @param otherPartialJoinFn partial join function that provides state for input messages of type {@code JM} + * in the other stream + * @param ttlMs the ttl in ms for retaining messages in each stream * @param joinOutput the output {@link MessageStreamImpl} of the join results + * @param opId the unique ID for this operator */ - PartialJoinOperatorSpec(PartialJoinFunction<K, M, JM, RM> partialJoinFn, MessageStreamImpl<RM> joinOutput, int opId) { + PartialJoinOperatorSpec(PartialJoinFunction<K, M, JM, RM> thisPartialJoinFn, + PartialJoinFunction<K, JM, M, RM> otherPartialJoinFn, long ttlMs, + MessageStreamImpl<RM> joinOutput, int opId) { + this.thisPartialJoinFn = thisPartialJoinFn; + this.otherPartialJoinFn = otherPartialJoinFn; + this.ttlMs = ttlMs; this.joinOutput = joinOutput; - this.transformFn = partialJoinFn; this.opId = opId; } @@ -68,8 +68,16 @@ public class PartialJoinOperatorSpec<M, K, JM, RM> implements OperatorSpec<RM> { return this.joinOutput; } - public PartialJoinFunction<K, M, JM, RM> getTransformFn() { - return this.transformFn; + public PartialJoinFunction<K, M, JM, RM> getThisPartialJoinFn() { + return this.thisPartialJoinFn; + } + + public PartialJoinFunction<K, JM, M, RM> getOtherPartialJoinFn() { + return this.otherPartialJoinFn; + } + + public long getTtlMs() { + return ttlMs; } public OperatorSpec.OpCode getOpCode() { @@ -80,7 +88,9 @@ public class PartialJoinOperatorSpec<M, K, JM, RM> implements OperatorSpec<RM> { return this.opId; } - @Override public void init(Config config, TaskContext context) { - this.transformFn.init(config, context); + @Override + public void init(Config config, TaskContext context) { + this.thisPartialJoinFn.init(config, context); } + } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/operators/util/InternalInMemoryStore.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/operators/util/InternalInMemoryStore.java b/samza-core/src/main/java/org/apache/samza/operators/util/InternalInMemoryStore.java new file mode 100644 index 0000000..e5dab80 --- /dev/null +++ b/samza-core/src/main/java/org/apache/samza/operators/util/InternalInMemoryStore.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.util; + +import org.apache.samza.storage.kv.Entry; +import org.apache.samza.storage.kv.KeyValueIterator; +import org.apache.samza.storage.kv.KeyValueStore; + +import java.util.HashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * Implements a {@link KeyValueStore} using an in-memory Java Map. + * @param <K> the type of the key in the store + * @param <V> the type of the value in the store + * + * TODO HIGH prateekm: Remove when we switch to an persistent implementation for KeyValueStore API. + */ +public class InternalInMemoryStore<K, V> implements KeyValueStore<K, V> { + + final Map<K, V> map = new LinkedHashMap<>(); + + @Override + public V get(K key) { + return map.get(key); + } + + @Override + public Map<K, V> getAll(List<K> keys) { + Map<K, V> values = new HashMap<>(); + for (K key: keys) { + values.put(key, map.get(key)); + } + return values; + } + + @Override + public void put(K key, V value) { + map.put(key, value); + } + + @Override + public void putAll(List<Entry<K, V>> entries) { + for (Entry<K, V> entry: entries) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public void delete(K key) { + map.remove(key); + } + + @Override + public void deleteAll(List<K> keys) { + for (K key : keys) { + delete(key); + } + } + + @Override + public KeyValueIterator<K, V> range(K from, K to) { + throw new RuntimeException("not implemented."); + } + + @Override + public KeyValueIterator<K, V> all() { + final Iterator<Map.Entry<K, V>> iterator = map.entrySet().iterator(); + return new KeyValueIterator<K, V>() { + @Override + public void close() { + //not applicable + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public Entry<K, V> next() { + Map.Entry<K, V> kv = iterator.next(); + return new Entry<>(kv.getKey(), kv.getValue()); + } + + @Override + public void remove() { + iterator.remove(); + } + + }; + } + + @Override + public void close() { + //not applicable + } + + @Override + public void flush() { + //not applicable + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java index c697b62..fa636ec 100644 --- a/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java +++ b/samza-core/src/main/java/org/apache/samza/task/StreamOperatorTask.java @@ -104,7 +104,7 @@ public final class StreamOperatorTask implements StreamTask, InitableTask, Windo @Override public final void window(MessageCollector collector, TaskCoordinator coordinator) throws Exception { - // TODO: invoke timer based triggers + this.operatorGraph.getAll().forEach(r -> r.onTimer(collector, coordinator)); } @Override http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/example/NoContextStreamExample.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/example/NoContextStreamExample.java b/samza-core/src/test/java/org/apache/samza/example/NoContextStreamExample.java index ef58c8b..371294b 100644 --- a/samza-core/src/test/java/org/apache/samza/example/NoContextStreamExample.java +++ b/samza-core/src/test/java/org/apache/samza/example/NoContextStreamExample.java @@ -32,6 +32,7 @@ import org.apache.samza.system.StreamSpec; import org.apache.samza.system.SystemStreamPartition; import org.apache.samza.util.CommandLine; +import java.time.Duration; import java.util.ArrayList; import java.util.List; @@ -111,7 +112,7 @@ public class NoContextStreamExample implements StreamGraphBuilder { new StringSerde("UTF-8"), new JsonSerde<>()); inputSource1.map(this::getInputMessage). - join(inputSource2.map(this::getInputMessage), new MyJoinFunction()). + join(inputSource2.map(this::getInputMessage), new MyJoinFunction(), Duration.ofMinutes(1)). sendTo(outStream); } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/example/OrderShipmentJoinExample.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/example/OrderShipmentJoinExample.java b/samza-core/src/test/java/org/apache/samza/example/OrderShipmentJoinExample.java index 0e60c2f..861120d 100644 --- a/samza-core/src/test/java/org/apache/samza/example/OrderShipmentJoinExample.java +++ b/samza-core/src/test/java/org/apache/samza/example/OrderShipmentJoinExample.java @@ -18,11 +18,11 @@ */ package org.apache.samza.example; +import org.apache.samza.config.Config; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.OutputStream; -import org.apache.samza.operators.StreamGraphBuilder; -import org.apache.samza.config.Config; import org.apache.samza.operators.StreamGraph; +import org.apache.samza.operators.StreamGraphBuilder; import org.apache.samza.operators.data.MessageEnvelope; import org.apache.samza.operators.functions.JoinFunction; import org.apache.samza.serializers.JsonSerde; @@ -31,6 +31,7 @@ import org.apache.samza.runtime.ApplicationRunner; import org.apache.samza.system.StreamSpec; import org.apache.samza.util.CommandLine; +import java.time.Duration; /** * Simple 2-way stream-to-stream join example @@ -55,7 +56,7 @@ public class OrderShipmentJoinExample implements StreamGraphBuilder { MessageStream<ShipmentRecord> shipments = graph.createInStream(input2, new StringSerde("UTF-8"), new JsonSerde<>()); OutputStream<FulFilledOrderRecord> fulfilledOrders = graph.createOutStream(output, new StringSerde("UTF-8"), new JsonSerde<>()); - orders.join(shipments, new MyJoinFunction()).sendTo(fulfilledOrders); + orders.join(shipments, new MyJoinFunction(), Duration.ofMinutes(1)).sendTo(fulfilledOrders); } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/example/TestJoinExample.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/example/TestJoinExample.java b/samza-core/src/test/java/org/apache/samza/example/TestJoinExample.java index f956972..6c9f8c2 100644 --- a/samza-core/src/test/java/org/apache/samza/example/TestJoinExample.java +++ b/samza-core/src/test/java/org/apache/samza/example/TestJoinExample.java @@ -32,6 +32,7 @@ import org.apache.samza.system.StreamSpec; import org.apache.samza.system.SystemStream; import org.apache.samza.system.SystemStreamPartition; +import java.time.Duration; import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -70,7 +71,7 @@ public class TestJoinExample extends TestExampleBase { if (joinOutput == null) { joinOutput = newSource; } else { - joinOutput = joinOutput.join(newSource, new MyJoinFunction()); + joinOutput = joinOutput.join(newSource, new MyJoinFunction(), Duration.ofMinutes(1)); } } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java b/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java new file mode 100644 index 0000000..ecd01e7 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/TestJoinOperator.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators; + +import com.google.common.collect.ImmutableSet; +import org.apache.samza.Partition; +import org.apache.samza.config.Config; +import org.apache.samza.operators.data.MessageEnvelope; +import org.apache.samza.operators.functions.JoinFunction; +import org.apache.samza.system.IncomingMessageEnvelope; +import org.apache.samza.system.StreamSpec; +import org.apache.samza.system.SystemStreamPartition; +import org.apache.samza.task.MessageCollector; +import org.apache.samza.task.StreamOperatorTask; +import org.apache.samza.task.TaskContext; +import org.apache.samza.task.TaskCoordinator; +import org.junit.Before; +import org.junit.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class TestJoinOperator { + private final MessageCollector messageCollector = mock(MessageCollector.class); + private final TaskCoordinator taskCoordinator = mock(TaskCoordinator.class); + private final Set<Integer> numbers = ImmutableSet.of(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + + private StreamOperatorTask sot; + private List<Integer> output = new ArrayList<>(); + + @Before + public void setup() throws Exception { + output.clear(); + + TaskContext taskContext = mock(TaskContext.class); + when(taskContext.getSystemStreamPartitions()).thenReturn(ImmutableSet + .of(new SystemStreamPartition("insystem", "instream", new Partition(0)), + new SystemStreamPartition("insystem2", "instream2", new Partition(0)))); + Config config = mock(Config.class); + + StreamGraphBuilder sgb = new TestStreamGraphBuilder(); + sot = new StreamOperatorTask(sgb); + sot.init(config, taskContext); + } + + @Test + public void join() { + // push messages to first stream + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to second stream with same keys + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + + int outputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(outputSum, 110); + } + + @Test + public void joinReverse() { + // push messages to second stream + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to first stream with same keys + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + + int outputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(outputSum, 110); + } + + @Test + public void joinNoMatch() { + // push messages to first stream + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to second stream with different keys + numbers.forEach(n -> sot.process(new SecondStreamIME(n + 100, n), messageCollector, taskCoordinator)); + + assertTrue(output.isEmpty()); + } + + @Test + public void joinNoMatchReverse() { + // push messages to second stream + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to first stream with different keys + numbers.forEach(n -> sot.process(new FirstStreamIME(n + 100, n), messageCollector, taskCoordinator)); + + assertTrue(output.isEmpty()); + } + + @Test + public void joinRetainsLatestMessageForKey() { + // push messages to first stream + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to first stream again with same keys but different values + numbers.forEach(n -> sot.process(new FirstStreamIME(n, 2 * n), messageCollector, taskCoordinator)); + // push messages to second stream with same key + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + + int outputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(outputSum, 165); // should use latest messages in the first stream + } + + @Test + public void joinRetainsLatestMessageForKeyReverse() { + // push messages to second stream + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to second stream again with same keys but different values + numbers.forEach(n -> sot.process(new SecondStreamIME(n, 2 * n), messageCollector, taskCoordinator)); + // push messages to first stream with same key + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + + int outputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(outputSum, 165); // should use latest messages in the second stream + } + + @Test + public void joinRetainsMatchedMessages() { + // push messages to first stream + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to second stream with same key + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + + int outputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(outputSum, 110); + + output.clear(); + + // push messages to first stream with same keys once again. + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + int newOutputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(newOutputSum, 110); // should produce the same output as before + } + + @Test + public void joinRetainsMatchedMessagesReverse() { + // push messages to first stream + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + // push messages to second stream with same key + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + + int outputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(outputSum, 110); + + output.clear(); + + // push messages to second stream with same keys once again. + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + int newOutputSum = output.stream().reduce(0, (s, m) -> s + m); + assertEquals(newOutputSum, 110); // should produce the same output as before + } + + @Test + public void joinRemovesExpiredMessages() throws Exception { + // push messages to first stream + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + + Thread.sleep(100); // 10 * ttl for join + sot.window(messageCollector, taskCoordinator); // should expire first stream messages + + // push messages to second stream with same key + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + + assertTrue(output.isEmpty()); + } + + + @Test + public void joinRemovesExpiredMessagesReverse() throws Exception { + // push messages to second stream + numbers.forEach(n -> sot.process(new SecondStreamIME(n, n), messageCollector, taskCoordinator)); + + Thread.sleep(100); // 10 * ttl for join + sot.window(messageCollector, taskCoordinator); // should expire second stream messages + + // push messages to first stream with same key + numbers.forEach(n -> sot.process(new FirstStreamIME(n, n), messageCollector, taskCoordinator)); + + assertTrue(output.isEmpty()); + } + + private class TestStreamGraphBuilder implements StreamGraphBuilder { + StreamSpec inStreamSpec = new StreamSpec("instream", "instream", "insystem"); + StreamSpec inStreamSpec2 = new StreamSpec("instream2", "instream2", "insystem2"); + + @Override + public void init(StreamGraph graph, Config config) { + MessageStream<MessageEnvelope<Integer, Integer>> inStream = graph.createInStream(inStreamSpec, null, null); + MessageStream<MessageEnvelope<Integer, Integer>> inStream2 = graph.createInStream(inStreamSpec2, null, null); + + inStream + .join(inStream2, new TestJoinFunction(), Duration.ofMillis(10)) + .map(m -> { + output.add(m); + return m; + }); + } + } + + private class TestJoinFunction + implements JoinFunction<Integer, MessageEnvelope<Integer, Integer>, MessageEnvelope<Integer, Integer>, Integer> { + @Override + public Integer apply(MessageEnvelope<Integer, Integer> message, + MessageEnvelope<Integer, Integer> otherMessage) { + return message.getMessage() + otherMessage.getMessage(); + } + + @Override + public Integer getFirstKey(MessageEnvelope<Integer, Integer> message) { + return message.getKey(); + } + + @Override + public Integer getSecondKey(MessageEnvelope<Integer, Integer> message) { + return message.getKey(); + } + } + + private class FirstStreamIME extends IncomingMessageEnvelope { + FirstStreamIME(Integer key, Integer message) { + super(new SystemStreamPartition("insystem", "instream", new Partition(0)), "1", key, message); + } + } + + private class SecondStreamIME extends IncomingMessageEnvelope { + SecondStreamIME(Integer key, Integer message) { + super(new SystemStreamPartition("insystem2", "instream2", new Partition(0)), "1", key, message); + } + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java b/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java index c22bd95..1d8afd4 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java +++ b/samza-core/src/test/java/org/apache/samza/operators/TestMessageStreamImpl.java @@ -162,7 +162,7 @@ public class TestMessageStreamImpl { } }; - MessageStream<TestOutputMessageEnvelope> joinOutput = source1.join(source2, joiner); + MessageStream<TestOutputMessageEnvelope> joinOutput = source1.join(source2, joiner, Duration.ofMinutes(1)); Collection<OperatorSpec> subs = source1.getRegisteredOperatorSpecs(); assertEquals(subs.size(), 1); OperatorSpec<TestMessageEnvelope> joinOp1 = subs.iterator().next(); @@ -175,10 +175,10 @@ public class TestMessageStreamImpl { assertEquals(((PartialJoinOperatorSpec) joinOp2).getNextStream(), joinOutput); TestMessageEnvelope joinMsg1 = new TestMessageEnvelope("test-join-1", "join-msg-001", 11111L); TestMessageEnvelope joinMsg2 = new TestMessageEnvelope("test-join-2", "join-msg-002", 22222L); - TestOutputMessageEnvelope xOut = (TestOutputMessageEnvelope) ((PartialJoinOperatorSpec) joinOp1).getTransformFn().apply(joinMsg1, joinMsg2); + TestOutputMessageEnvelope xOut = (TestOutputMessageEnvelope) ((PartialJoinOperatorSpec) joinOp1).getThisPartialJoinFn().apply(joinMsg1, joinMsg2); assertEquals(xOut.getKey(), "test-join-1"); assertEquals(xOut.getMessage(), Integer.valueOf(24)); - xOut = (TestOutputMessageEnvelope) ((PartialJoinOperatorSpec) joinOp2).getTransformFn().apply(joinMsg2, joinMsg1); + xOut = (TestOutputMessageEnvelope) ((PartialJoinOperatorSpec) joinOp2).getThisPartialJoinFn().apply(joinMsg2, joinMsg1); assertEquals(xOut.getKey(), "test-join-1"); assertEquals(xOut.getMessage(), Integer.valueOf(24)); } http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpls.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpls.java b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpls.java index 02637a3..088cb00 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpls.java +++ b/samza-core/src/test/java/org/apache/samza/operators/impl/TestOperatorImpls.java @@ -113,10 +113,8 @@ public class TestOperatorImpls { assertEquals(sinkFn, sinkFnField.get(opImpl)); // get join operator - PartialJoinOperatorSpec<TestMessageEnvelope, String, TestMessageEnvelope, TestOutputMessageEnvelope> joinOp = mock(PartialJoinOperatorSpec.class); - TestOutputMessageEnvelope mockOutput = mock(TestOutputMessageEnvelope.class); + PartialJoinOperatorSpec<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope> joinOp = mock(PartialJoinOperatorSpec.class); PartialJoinFunction<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope> joinFn = mock(PartialJoinFunction.class); - when(joinOp.getTransformFn()).thenReturn(joinFn); opImpl = (OperatorImpl<TestMessageEnvelope, ? extends MessageEnvelope>) createOpMethod.invoke(opGraph, mockStream, joinOp, mockConfig, mockContext); assertTrue(opImpl instanceof PartialJoinOperatorImpl); } @@ -207,7 +205,7 @@ public class TestOperatorImpls { public String getSecondKey(TestMessageEnvelope message) { return message.getKey(); } - }) + }, Duration.ofMinutes(1)) .map(m -> m); OperatorGraph opGraph = new OperatorGraph(); // now, we create chained operators from each input sources http://git-wip-us.apache.org/repos/asf/samza/blob/5910ea66/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java b/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java index 31257a4..ae3d151 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpecs.java @@ -22,6 +22,7 @@ import org.apache.samza.operators.MessageStreamImpl; import org.apache.samza.operators.StreamGraphImpl; import org.apache.samza.operators.TestMessageEnvelope; import org.apache.samza.operators.TestMessageStreamImplUtil; +import org.apache.samza.operators.TestOutputMessageEnvelope; import org.apache.samza.operators.data.MessageEnvelope; import org.apache.samza.operators.functions.FlatMapFunction; import org.apache.samza.operators.functions.PartialJoinFunction; @@ -83,33 +84,17 @@ public class TestOperatorSpecs { @Test public void testGetPartialJoinOperator() { - PartialJoinFunction<Object, MessageEnvelope<Object, ?>, MessageEnvelope<Object, ?>, TestMessageEnvelope> merger = - new PartialJoinFunction<Object, MessageEnvelope<Object, ?>, MessageEnvelope<Object, ?>, TestMessageEnvelope>() { - @Override - public TestMessageEnvelope apply(MessageEnvelope<Object, ?> m1, MessageEnvelope<Object, ?> m2) { - return new TestMessageEnvelope(m1.getKey().toString(), m2.getMessage().toString(), System.nanoTime()); - } - - @Override - public Object getKey(MessageEnvelope<Object, ?> message) { - return message.getKey(); - } - - @Override - public Object getOtherKey(MessageEnvelope<Object, ?> message) { - return message.getKey(); - } - }; - + PartialJoinFunction<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope> thisPartialJoinFn = mock(PartialJoinFunction.class); + PartialJoinFunction<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope> otherPartialJoinFn = mock(PartialJoinFunction.class); StreamGraphImpl mockGraph = mock(StreamGraphImpl.class); - MessageStreamImpl<TestMessageEnvelope> joinOutput = TestMessageStreamImplUtil.<TestMessageEnvelope>getMessageStreamImpl(mockGraph); - PartialJoinOperatorSpec<MessageEnvelope<Object, ?>, Object, MessageEnvelope<Object, ?>, TestMessageEnvelope> partialJoin = - OperatorSpecs.createPartialJoinOperatorSpec(merger, mockGraph, joinOutput); + MessageStreamImpl<TestOutputMessageEnvelope> joinOutput = TestMessageStreamImplUtil.getMessageStreamImpl(mockGraph); + + PartialJoinOperatorSpec<String, TestMessageEnvelope, TestMessageEnvelope, TestOutputMessageEnvelope> partialJoinSpec = + OperatorSpecs.createPartialJoinOperatorSpec(thisPartialJoinFn, otherPartialJoinFn, 1000 * 60, mockGraph, joinOutput); - assertEquals(partialJoin.getNextStream(), joinOutput); - MessageEnvelope<Object, Object> m = mock(MessageEnvelope.class); - MessageEnvelope<Object, Object> s = mock(MessageEnvelope.class); - assertEquals(partialJoin.getTransformFn(), merger); + assertEquals(partialJoinSpec.getNextStream(), joinOutput); + assertEquals(partialJoinSpec.getThisPartialJoinFn(), thisPartialJoinFn); + assertEquals(partialJoinSpec.getOtherPartialJoinFn(), otherPartialJoinFn); } @Test