http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java b/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java new file mode 100644 index 0000000..b39b0d0 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/OperatorSpecTestUtils.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.operators.spec; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.samza.operators.OperatorSpecGraph; +import org.apache.samza.operators.TableImpl; +import org.apache.samza.operators.functions.TimerFunction; +import org.apache.samza.operators.functions.WatermarkFunction; +import org.apache.samza.serializers.SerializableSerde; +import org.apache.samza.system.StreamSpec; +import org.apache.samza.table.TableSpec; + +import static org.junit.Assert.*; + + +/** + * Test util methods for {@link OperatorSpec} classes + */ +public class OperatorSpecTestUtils { + private final static SerializableSerde<OperatorSpec> SPEC_SERDE = new SerializableSerde<>(); + + static OperatorSpec copyOpSpec(OperatorSpec original) { + return SPEC_SERDE.fromBytes(SPEC_SERDE.toBytes(original)); + } + + enum TestEnum { + One, Two, Three + } + + public static void assertClonedGraph(OperatorSpecGraph originalGraph, OperatorSpecGraph clonedGraph) { + assertClonedInputs(originalGraph.getInputOperators(), clonedGraph.getInputOperators()); + assertClonedOutputs(originalGraph.getOutputStreams(), clonedGraph.getOutputStreams()); + assertClonedTables(originalGraph.getTables(), clonedGraph.getTables()); + assertAllOperators(originalGraph.getAllOperatorSpecs(), clonedGraph.getAllOperatorSpecs()); + } + + private static void assertAllOperators(Collection<OperatorSpec> originalOpSpecs, Collection<OperatorSpec> clonedOpSpecs) { + assertEquals(originalOpSpecs.size(), clonedOpSpecs.size()); + List<OperatorSpec> originalList = new ArrayList<>(originalOpSpecs); + List<OperatorSpec> clonedList = new ArrayList<>(clonedOpSpecs); + Collections.sort(originalList, Comparator.comparing(OperatorSpec::getOpId)); + Collections.sort(clonedList, Comparator.comparing(OperatorSpec::getOpId)); + Iterator<OperatorSpec> oIter = originalList.iterator(); + Iterator<OperatorSpec> nIter = clonedList.iterator(); + oIter.forEachRemaining(opSpec -> assertClonedOpSpec(opSpec, nIter.next())); + } + + private static void assertClonedOpSpec(OperatorSpec oOpSpec, OperatorSpec nOpSpec) { + assertNotEquals(oOpSpec, nOpSpec); + assertEquals(oOpSpec.getOpId(), nOpSpec.getOpId()); + assertEquals(oOpSpec.getOpCode(), nOpSpec.getOpCode()); + assertTimerFnsNotEqual(oOpSpec.getTimerFn(), nOpSpec.getTimerFn()); + assertWatermarkFnNotEqual(nOpSpec.getWatermarkFn(), nOpSpec.getWatermarkFn()); + assertAllOperators(oOpSpec.getRegisteredOperatorSpecs(), nOpSpec.getRegisteredOperatorSpecs()); + } + + private static void assertWatermarkFnNotEqual(WatermarkFunction watermarkFn, WatermarkFunction watermarkFn1) { + if (watermarkFn == watermarkFn1 && watermarkFn == null) { + return; + } + assertNotEquals(watermarkFn, watermarkFn1); + } + + private static void assertTimerFnsNotEqual(TimerFunction timerFn, TimerFunction timerFn1) { + if (timerFn == timerFn1 && timerFn == null) { + return; + } + assertNotEquals(timerFn, timerFn1); + } + + private static void assertClonedTables(Map<TableSpec, TableImpl> originalTables, Map<TableSpec, TableImpl> clonedTables) { + assertEquals(originalTables.size(), clonedTables.size()); + assertEquals(originalTables.keySet(), clonedTables.keySet()); + Iterator<TableImpl> oIter = originalTables.values().iterator(); + Iterator<TableImpl> nIter = clonedTables.values().iterator(); + oIter.forEachRemaining(oTable -> assertClonedTableImpl(oTable, nIter.next())); + } + + private static void assertClonedTableImpl(TableImpl oTable, TableImpl nTable) { + assertNotEquals(oTable, nTable); + assertEquals(oTable.getTableSpec(), nTable.getTableSpec()); + } + + private static void assertClonedOutputs(Map<StreamSpec, OutputStreamImpl> originalOutputs, + Map<StreamSpec, OutputStreamImpl> clonedOutputs) { + assertEquals(originalOutputs.size(), clonedOutputs.size()); + assertEquals(originalOutputs.keySet(), clonedOutputs.keySet()); + Iterator<OutputStreamImpl> oIter = originalOutputs.values().iterator(); + Iterator<OutputStreamImpl> nIter = clonedOutputs.values().iterator(); + oIter.forEachRemaining(oOutput -> assertClonedOutputImpl(oOutput, nIter.next())); + } + + private static void assertClonedOutputImpl(OutputStreamImpl oOutput, OutputStreamImpl nOutput) { + assertNotEquals(oOutput, nOutput); + assertEquals(oOutput.isKeyed(), nOutput.isKeyed()); + assertEquals(oOutput.getSystemStream(), nOutput.getSystemStream()); + assertEquals(oOutput.getStreamSpec(), nOutput.getStreamSpec()); + } + + private static void assertClonedInputs(Map<StreamSpec, InputOperatorSpec> originalInputs, + Map<StreamSpec, InputOperatorSpec> clonedInputs) { + assertEquals(originalInputs.size(), clonedInputs.size()); + assertEquals(originalInputs.keySet(), clonedInputs.keySet()); + Iterator<InputOperatorSpec> oIter = originalInputs.values().iterator(); + Iterator<InputOperatorSpec> nIter = clonedInputs.values().iterator(); + oIter.forEachRemaining(inputOp -> assertClonedInputOp(inputOp, nIter.next())); + } + + private static void assertClonedInputOp(InputOperatorSpec originalInput, InputOperatorSpec clonedInput) { + assertNotEquals(originalInput, clonedInput); + assertEquals(originalInput.getOpId(), clonedInput.getOpId()); + assertEquals(originalInput.getOpCode(), clonedInput.getOpCode()); + assertEquals(originalInput.getStreamSpec(), clonedInput.getStreamSpec()); + assertEquals(originalInput.isKeyed(), clonedInput.isKeyed()); + } + +}
http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpec.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpec.java b/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpec.java new file mode 100644 index 0000000..cb221b0 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/TestOperatorSpec.java @@ -0,0 +1,465 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.spec; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import org.apache.samza.config.MapConfig; +import org.apache.samza.operators.KV; +import org.apache.samza.operators.TimerRegistry; +import org.apache.samza.operators.data.TestMessageEnvelope; +import org.apache.samza.operators.data.TestOutputMessageEnvelope; +import org.apache.samza.operators.functions.FilterFunction; +import org.apache.samza.operators.functions.FlatMapFunction; +import org.apache.samza.operators.functions.JoinFunction; +import org.apache.samza.operators.functions.MapFunction; +import org.apache.samza.operators.functions.SinkFunction; +import org.apache.samza.operators.functions.StreamTableJoinFunction; +import org.apache.samza.operators.functions.TimerFunction; +import org.apache.samza.operators.functions.WatermarkFunction; +import org.apache.samza.serializers.JsonSerdeV2; +import org.apache.samza.serializers.KVSerde; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.serializers.Serde; +import org.apache.samza.serializers.StringSerde; +import org.apache.samza.system.StreamSpec; +import org.apache.samza.table.TableSpec; +import org.junit.Test; +import org.mockito.internal.util.reflection.Whitebox; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + + +/** + * Test for all {@link OperatorSpec} + */ +public class TestOperatorSpec { + + private static class MapWithWatermarkFn implements MapFunction<TestMessageEnvelope, TestOutputMessageEnvelope>, WatermarkFunction<TestOutputMessageEnvelope> { + + @Override + public Collection<TestOutputMessageEnvelope> processWatermark(long watermark) { + return null; + } + + @Override + public Long getOutputWatermark() { + return null; + } + + @Override + public TestOutputMessageEnvelope apply(TestMessageEnvelope m) { + return new TestOutputMessageEnvelope(m.getKey(), m.getMessage().hashCode()); + } + } + + private static class MapWithTimerFn implements MapFunction<TestMessageEnvelope, TestOutputMessageEnvelope>, TimerFunction<String, TestOutputMessageEnvelope> { + + @Override + public TestOutputMessageEnvelope apply(TestMessageEnvelope m) { + return new TestOutputMessageEnvelope(m.getKey(), m.getMessage().hashCode()); + } + + @Override + public void registerTimer(TimerRegistry<String> timerRegistry) { + + } + + @Override + public Collection<TestOutputMessageEnvelope> onTimer(String key, long timestamp) { + return null; + } + } + + private static class MapWithEnum implements MapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> { + private OperatorSpecTestUtils.TestEnum type; + + MapWithEnum(OperatorSpecTestUtils.TestEnum type) { + this.type = type; + } + + OperatorSpecTestUtils.TestEnum getType() { + return this.type; + } + + void setType(OperatorSpecTestUtils.TestEnum type) { + this.type = type; + } + + @Override + public TestOutputMessageEnvelope apply(TestMessageEnvelope m) { + return new TestOutputMessageEnvelope(m.getKey(), m.getMessage().hashCode()); + } + } + + private static class TestJoinFunction implements JoinFunction<String, Object, Object, TestOutputMessageEnvelope> { + @Override + public TestOutputMessageEnvelope apply(Object message, Object otherMessage) { + return new TestOutputMessageEnvelope(message.toString(), message.hashCode() + otherMessage.hashCode()); + } + + @Override + public String getFirstKey(Object message) { + return message.toString(); + } + + @Override + public String getSecondKey(Object message) { + return message.toString(); + } + } + + private static class TestStreamTableJoinFunction implements StreamTableJoinFunction<String, Object, Object, TestOutputMessageEnvelope> { + @Override + public TestOutputMessageEnvelope apply(Object message, Object record) { + return new TestOutputMessageEnvelope(message.toString(), message.hashCode() + record.hashCode()); + } + + @Override + public String getMessageKey(Object message) { + return message.toString(); + } + + @Override + public String getRecordKey(Object record) { + return record.toString(); + } + } + + @Test + public void testStreamOperatorSpecWithFlatMap() { + FlatMapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> flatMap = m -> { + List<TestOutputMessageEnvelope> result = new ArrayList<>(); + result.add(new TestOutputMessageEnvelope(m.getKey(), m.getMessage().hashCode())); + return result; + }; + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createFlatMapOperatorSpec(flatMap, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + assertTrue(cloneOperatorSpec.getTransformFn() instanceof FlatMapFunction); + assertNull(streamOperatorSpec.getWatermarkFn()); + assertNull(cloneOperatorSpec.getWatermarkFn()); + assertNull(streamOperatorSpec.getTimerFn()); + assertNull(cloneOperatorSpec.getTimerFn()); + } + + @Test + public void testStreamOperatorSpecWithMap() { + MapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> mapFn = + m -> new TestOutputMessageEnvelope(m.getKey(), m.getMessage().hashCode()); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createMapOperatorSpec(mapFn, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + MapFunction userFn = (MapFunction) Whitebox.getInternalState(streamOperatorSpec, "mapFn"); + assertEquals(userFn, mapFn); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + MapFunction clonedUserFn = (MapFunction) Whitebox.getInternalState(cloneOperatorSpec, "mapFn"); + assertTrue(cloneOperatorSpec.getTransformFn() instanceof FlatMapFunction); + assertTrue(clonedUserFn instanceof MapFunction); + assertNotEquals(userFn, clonedUserFn); + assertNull(streamOperatorSpec.getWatermarkFn()); + assertNull(cloneOperatorSpec.getWatermarkFn()); + assertNull(streamOperatorSpec.getTimerFn()); + assertNull(cloneOperatorSpec.getTimerFn()); + } + + @Test + public void testStreamOperatorSpecWithFilter() { + FilterFunction<TestMessageEnvelope> filterFn = m -> m.getKey().equals("key1"); + StreamOperatorSpec<TestMessageEnvelope, TestMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createFilterOperatorSpec(filterFn, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + FilterFunction userFn = (FilterFunction) Whitebox.getInternalState(streamOperatorSpec, "filterFn"); + assertEquals(userFn, filterFn); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + FilterFunction clonedUserFn = (FilterFunction) Whitebox.getInternalState(cloneOperatorSpec, "filterFn"); + assertTrue(cloneOperatorSpec.getTransformFn() instanceof FlatMapFunction); + assertTrue(clonedUserFn instanceof FilterFunction); + assertNotEquals(userFn, clonedUserFn); + assertNull(streamOperatorSpec.getWatermarkFn()); + assertNull(cloneOperatorSpec.getWatermarkFn()); + assertNull(streamOperatorSpec.getTimerFn()); + assertNull(cloneOperatorSpec.getTimerFn()); + } + + @Test + public void testInputOperatorSpec() { + Serde<Object> objSerde = new Serde<Object>() { + + @Override + public Object fromBytes(byte[] bytes) { + return null; + } + + @Override + public byte[] toBytes(Object object) { + return new byte[0]; + } + }; + + StreamSpec mockStreamSpec = mock(StreamSpec.class); + InputOperatorSpec<String, Object> inputOperatorSpec = new InputOperatorSpec<>( + mockStreamSpec, new StringSerde("UTF-8"), objSerde, true, "op0"); + InputOperatorSpec<String, Object> inputOpCopy = (InputOperatorSpec<String, Object>) OperatorSpecTestUtils.copyOpSpec(inputOperatorSpec); + + assertNotEquals("Expected deserialized copy of operator spec should not be the same as the original operator spec", inputOperatorSpec, inputOpCopy); + assertTrue(inputOperatorSpec.isClone(inputOpCopy)); + + } + + @Test + public void testOutputOperatorSpec() { + Serde<Object> objSerde = new Serde<Object>() { + + @Override + public Object fromBytes(byte[] bytes) { + return null; + } + + @Override + public byte[] toBytes(Object object) { + return new byte[0]; + } + }; + StreamSpec mockStreamSpec = mock(StreamSpec.class); + OutputStreamImpl<KV<String, Object>> outputStrmImpl = new OutputStreamImpl<>(mockStreamSpec, new StringSerde("UTF-8"), objSerde, true); + OutputOperatorSpec<KV<String, Object>> outputOperatorSpec = new OutputOperatorSpec<KV<String, Object>>(outputStrmImpl, "op0"); + OutputOperatorSpec<KV<String, Object>> outputOpCopy = (OutputOperatorSpec<KV<String, Object>>) OperatorSpecTestUtils + .copyOpSpec(outputOperatorSpec); + assertNotEquals("Expected deserialized copy of operator spec should not be the same as the original operator spec", outputOperatorSpec, outputOpCopy); + assertTrue(outputOperatorSpec.isClone(outputOpCopy)); + } + + @Test + public void testSinkOperatorSpec() { + SinkFunction<TestMessageEnvelope> sinkFn = (m, c, tc) -> System.out.print(m.toString()); + SinkOperatorSpec<TestMessageEnvelope> sinkOpSpec = new SinkOperatorSpec<>(sinkFn, "op0"); + SinkOperatorSpec<TestMessageEnvelope> sinkOpCopy = (SinkOperatorSpec<TestMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(sinkOpSpec); + assertNotEquals("Expected deserialized copy of operator spec should not be the same as the original operator spec", sinkOpSpec, sinkOpCopy); + assertTrue(sinkOpSpec.isClone(sinkOpCopy)); + } + + @Test + public void testJoinOperatorSpec() { + + InputOperatorSpec<TestMessageEnvelope, Object> leftOpSpec = new InputOperatorSpec<>( + new StreamSpec("test-input-1", "test-input-1", "kafka"), new NoOpSerde<>(), + new NoOpSerde<>(), false, "op0"); + InputOperatorSpec<TestMessageEnvelope, Object> rightOpSpec = new InputOperatorSpec<>( + new StreamSpec("test-input-2", "test-input-2", "kafka"), new NoOpSerde<>(), + new NoOpSerde<>(), false, "op1"); + + Serde<Object> objSerde = new Serde<Object>() { + + @Override + public Object fromBytes(byte[] bytes) { + return null; + } + + @Override + public byte[] toBytes(Object object) { + return new byte[0]; + } + }; + + JoinFunction<String, Object, Object, TestOutputMessageEnvelope> joinFn = new TestJoinFunction(); + JoinOperatorSpec<String, Object, Object, TestOutputMessageEnvelope> joinOperatorSpec = + new JoinOperatorSpec<>(leftOpSpec, rightOpSpec, joinFn, new StringSerde("UTF-8"), objSerde, objSerde, 50000, "op2"); + JoinOperatorSpec<String, Object, Object, TestOutputMessageEnvelope> joinOpCopy = + (JoinOperatorSpec<String, Object, Object, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(joinOperatorSpec); + assertNotEquals("Expected deserialized copy of operator spec should not be the same as the original operator spec", joinOperatorSpec, joinOpCopy); + assertTrue(joinOperatorSpec.isClone(joinOpCopy)); + assertTrue(joinOpCopy.getLeftInputOpSpec().isClone(leftOpSpec)); + assertTrue(joinOpCopy.getRightInputOpSpec().isClone(rightOpSpec)); + } + + @Test + public void testStreamTableJoinOperatorSpec() { + StreamTableJoinFunction<String, Object, Object, TestOutputMessageEnvelope> joinFn = new TestStreamTableJoinFunction(); + + TableSpec tableSpec = new TableSpec("table-0", KVSerde.of(new StringSerde("UTF-8"), new JsonSerdeV2<>()), "my.table.provider.class", + new MapConfig(new HashMap<String, String>() { { this.put("config1", "value1"); this.put("config2", "value2"); } })); + + StreamTableJoinOperatorSpec<String, Object, Object, TestOutputMessageEnvelope> joinOperatorSpec = + new StreamTableJoinOperatorSpec<>(tableSpec, joinFn, "join-3"); + + StreamTableJoinOperatorSpec<String, Object, Object, TestOutputMessageEnvelope> joinOpSpecCopy = + (StreamTableJoinOperatorSpec<String, Object, Object, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(joinOperatorSpec); + assertNotEquals(joinOpSpecCopy, joinOperatorSpec); + assertEquals(joinOpSpecCopy.getOpId(), joinOperatorSpec.getOpId()); + assertTrue(joinOpSpecCopy.getTableSpec() != joinOperatorSpec.getTableSpec()); + assertEquals(joinOpSpecCopy.getTableSpec().getId(), joinOperatorSpec.getTableSpec().getId()); + assertEquals(joinOpSpecCopy.getTableSpec().getTableProviderFactoryClassName(), joinOperatorSpec.getTableSpec().getTableProviderFactoryClassName()); + } + + @Test + public void testSendToTableOperatorSpec() { + TableSpec tableSpec = new TableSpec("table-0", KVSerde.of(new StringSerde("UTF-8"), new JsonSerdeV2<>()), "my.table.provider.class", + new MapConfig(new HashMap<String, String>() { { this.put("config1", "value1"); this.put("config2", "value2"); } })); + SendToTableOperatorSpec<String, Integer> sendOpSpec = + new SendToTableOperatorSpec<>(tableSpec, "output-1"); + SendToTableOperatorSpec<String, Integer> sendToCopy = (SendToTableOperatorSpec<String, Integer>) OperatorSpecTestUtils + .copyOpSpec(sendOpSpec); + assertNotEquals(sendToCopy, sendOpSpec); + assertEquals(sendToCopy.getOpId(), sendOpSpec.getOpId()); + assertTrue(sendToCopy.getTableSpec() != sendOpSpec.getTableSpec() && sendToCopy.getTableSpec().equals(sendOpSpec.getTableSpec())); + } + + @Test + public void testBroadcastOperatorSpec() { + OutputStreamImpl<TestOutputMessageEnvelope> outputStream = + new OutputStreamImpl<>(new StreamSpec("output-0", "outputStream-0", "kafka"), new StringSerde("UTF-8"), new JsonSerdeV2<TestOutputMessageEnvelope>(), true); + BroadcastOperatorSpec<TestOutputMessageEnvelope> broadcastOpSpec = new BroadcastOperatorSpec<>(outputStream, "broadcast-1"); + BroadcastOperatorSpec<TestOutputMessageEnvelope> broadcastOpCopy = (BroadcastOperatorSpec<TestOutputMessageEnvelope>) OperatorSpecTestUtils + .copyOpSpec(broadcastOpSpec); + assertNotEquals(broadcastOpCopy, broadcastOpSpec); + assertEquals(broadcastOpCopy.getOpId(), broadcastOpSpec.getOpId()); + assertTrue(broadcastOpCopy.getOutputStream() != broadcastOpSpec.getOutputStream()); + assertEquals(broadcastOpCopy.getOutputStream().getSystemStream(), broadcastOpSpec.getOutputStream().getSystemStream()); + assertEquals(broadcastOpCopy.getOutputStream().isKeyed(), broadcastOpSpec.getOutputStream().isKeyed()); + } + + @Test + public void testMapStreamOperatorSpecWithWatermark() { + MapWithWatermarkFn testMapFn = new MapWithWatermarkFn(); + + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createMapOperatorSpec(testMapFn, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + assertEquals(streamOperatorSpec.getWatermarkFn(), testMapFn); + assertNotNull(cloneOperatorSpec.getWatermarkFn()); + assertNotEquals(cloneOperatorSpec.getTransformFn(), cloneOperatorSpec.getWatermarkFn()); + assertNull(streamOperatorSpec.getTimerFn()); + assertNull(cloneOperatorSpec.getTimerFn()); + } + + @Test + public void testMapStreamOperatorSpecWithTimer() { + MapWithTimerFn testMapFn = new MapWithTimerFn(); + + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createMapOperatorSpec(testMapFn, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + assertNull(streamOperatorSpec.getWatermarkFn()); + assertNull(cloneOperatorSpec.getWatermarkFn()); + assertNotEquals(cloneOperatorSpec.getTransformFn(), cloneOperatorSpec.getWatermarkFn()); + assertEquals(streamOperatorSpec.getTimerFn(), testMapFn); + assertNotNull(cloneOperatorSpec.getTimerFn()); + assertNotEquals(streamOperatorSpec.getTimerFn(), cloneOperatorSpec.getTimerFn()); + } + + @Test + public void testStreamOperatorSpecWithMapAndListInClosure() { + List<Integer> integers = new ArrayList<>(1); + integers.add(0, 100); + List<String> keys = new ArrayList<>(1); + keys.add(0, "test-1"); + MapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> mapFn = + m -> new TestOutputMessageEnvelope(keys.get(m.getKey().hashCode() % 1), integers.get(m.getMessage().hashCode() % 1)); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createMapOperatorSpec(mapFn, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + MapFunction userFn = (MapFunction) Whitebox.getInternalState(streamOperatorSpec, "mapFn"); + assertEquals(userFn, mapFn); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + MapFunction clonedUserFn = (MapFunction) Whitebox.getInternalState(cloneOperatorSpec, "mapFn"); + assertTrue(cloneOperatorSpec.getTransformFn() instanceof FlatMapFunction); + assertTrue(clonedUserFn instanceof MapFunction); + assertNotEquals(userFn, clonedUserFn); + + // verify changing the values in the original keys and integers list will change the result of the original map function + TestMessageEnvelope mockImsg = new TestMessageEnvelope("input-key-x", new String("value-x")); + assertEquals(((MapFunction) userFn).apply(mockImsg), new TestOutputMessageEnvelope("test-1", 100)); + integers.set(0, 200); + keys.set(0, "test-2"); + assertEquals(((MapFunction) userFn).apply(mockImsg), new TestOutputMessageEnvelope("test-2", 200)); + // verify that the cloned map function uses a different copy of lists and still yields the same result + assertEquals(((MapFunction) clonedUserFn).apply(mockImsg), new TestOutputMessageEnvelope("test-1", 100)); + } + + @Test + public void testStreamOperatorSpecWithMapWithFunctionReference() { + MapFunction<KV<String, Object>, Object> mapFn = KV::getValue; + StreamOperatorSpec<KV<String, Object>, Object> streamOperatorSpec = + OperatorSpecs.createMapOperatorSpec(mapFn, "op0"); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + MapFunction userFn = (MapFunction) Whitebox.getInternalState(streamOperatorSpec, "mapFn"); + assertEquals(userFn, mapFn); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + MapFunction clonedUserFn = (MapFunction) Whitebox.getInternalState(cloneOperatorSpec, "mapFn"); + assertTrue(cloneOperatorSpec.getTransformFn() instanceof FlatMapFunction); + assertTrue(clonedUserFn instanceof MapFunction); + assertNotEquals(userFn, clonedUserFn); + } + + @Test + public void testStreamOperatorSpecWithMapWithEnum() { + MapFunction<TestMessageEnvelope, TestOutputMessageEnvelope> mapFn = new MapWithEnum(OperatorSpecTestUtils.TestEnum.One); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> streamOperatorSpec = + OperatorSpecs.createMapOperatorSpec(mapFn, "op0"); + assertTrue(streamOperatorSpec instanceof MapOperatorSpec); + StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope> cloneOperatorSpec = + (StreamOperatorSpec<TestMessageEnvelope, TestOutputMessageEnvelope>) OperatorSpecTestUtils.copyOpSpec(streamOperatorSpec); + assertNotEquals(streamOperatorSpec, cloneOperatorSpec); + assertTrue(streamOperatorSpec.isClone(cloneOperatorSpec)); + MapFunction userFn = (MapFunction) Whitebox.getInternalState(streamOperatorSpec, "mapFn"); + assertEquals(userFn, mapFn); + assertNotEquals(streamOperatorSpec.getTransformFn(), cloneOperatorSpec.getTransformFn()); + MapFunction clonedUserFn = (MapFunction) Whitebox.getInternalState(cloneOperatorSpec, "mapFn"); + assertTrue(cloneOperatorSpec.getTransformFn() instanceof FlatMapFunction); + assertTrue(clonedUserFn instanceof MapWithEnum); + assertNotEquals(userFn, clonedUserFn); + // originally the types should be the same + assertTrue(((MapWithEnum) userFn).getType() == ((MapWithEnum) clonedUserFn).getType()); + // after changing the type of the cloned user function, the types are different now + ((MapWithEnum) clonedUserFn).setType(OperatorSpecTestUtils.TestEnum.Two); + assertTrue(((MapWithEnum) userFn).getType() != ((MapWithEnum) clonedUserFn).getType()); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/spec/TestPartitionByOperatorSpec.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/TestPartitionByOperatorSpec.java b/samza-core/src/test/java/org/apache/samza/operators/spec/TestPartitionByOperatorSpec.java new file mode 100644 index 0000000..00ec176 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/TestPartitionByOperatorSpec.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.operators.spec; + +import java.util.Collection; +import org.apache.samza.config.Config; +import org.apache.samza.config.JobConfig; +import org.apache.samza.operators.KV; +import org.apache.samza.operators.MessageStream; +import org.apache.samza.operators.OperatorSpecGraph; +import org.apache.samza.operators.StreamGraphSpec; +import org.apache.samza.operators.TimerRegistry; +import org.apache.samza.operators.functions.MapFunction; +import org.apache.samza.operators.functions.TimerFunction; +import org.apache.samza.operators.functions.WatermarkFunction; +import org.apache.samza.runtime.ApplicationRunner; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.system.StreamSpec; +import org.junit.Before; +import org.junit.Test; +import org.mockito.internal.util.reflection.Whitebox; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + + +/** + * Unit tests for partitionBy operator + */ +public class TestPartitionByOperatorSpec { + + private final ApplicationRunner mockRunner = mock(ApplicationRunner.class); + private final Config mockConfig = mock(Config.class); + private final String testInputId = "test-input-1"; + private final String testJobName = "testJob"; + private final String testJobId = "1"; + private final String testReparStreamName = "parByKey"; + private StreamGraphSpec graphSpec = null; + + class TimerMapFn implements MapFunction<Object, String>, TimerFunction<String, Object> { + + @Override + public String apply(Object message) { + return message.toString(); + } + + @Override + public void registerTimer(TimerRegistry<String> timerRegistry) { + + } + + @Override + public Collection<Object> onTimer(String key, long timestamp) { + return null; + } + } + + class WatermarkMapFn implements MapFunction<Object, String>, WatermarkFunction<Object> { + + @Override + public String apply(Object message) { + return message.toString(); + } + + @Override + public Collection<Object> processWatermark(long watermark) { + return null; + } + + @Override + public Long getOutputWatermark() { + return null; + } + } + + @Before + public void setup() { + when(mockConfig.get(JobConfig.JOB_NAME())).thenReturn(testJobName); + when(mockConfig.get(JobConfig.JOB_ID(), "1")).thenReturn(testJobId); + StreamSpec inputSpec1 = new StreamSpec(testInputId, testInputId, "kafka"); + when(mockRunner.getStreamSpec(testInputId)).thenReturn(inputSpec1); + String intermediateStreamName = String.format("%s-%s-partition_by-%s", testJobName, testJobId, testReparStreamName); + StreamSpec intermediateSpec1 = new StreamSpec(intermediateStreamName, intermediateStreamName, "kafka"); + when(mockRunner.getStreamSpec(intermediateStreamName)).thenReturn(intermediateSpec1); + graphSpec = new StreamGraphSpec(mockRunner, mockConfig); + } + + @Test + public void testPartitionBy() { + MessageStream inputStream = graphSpec.getInputStream(testInputId); + MapFunction<Object, String> keyFn = m -> m.toString(); + MapFunction<Object, Object> valueFn = m -> m; + MessageStream<KV<String, Object>> + reparStream = inputStream.partitionBy(keyFn, valueFn, testReparStreamName); + InputOperatorSpec inputOpSpec = (InputOperatorSpec) Whitebox.getInternalState(reparStream, "operatorSpec"); + assertEquals(inputOpSpec.getStreamSpec().getId(), String.format("%s-%s-partition_by-%s", testJobName, testJobId, testReparStreamName)); + assertTrue(inputOpSpec.getKeySerde() instanceof NoOpSerde); + assertTrue(inputOpSpec.getValueSerde() instanceof NoOpSerde); + assertTrue(inputOpSpec.isKeyed()); + assertNull(inputOpSpec.getTimerFn()); + assertNull(inputOpSpec.getWatermarkFn()); + InputOperatorSpec originInputSpec = (InputOperatorSpec) Whitebox.getInternalState(inputStream, "operatorSpec"); + assertTrue(originInputSpec.getRegisteredOperatorSpecs().toArray()[0] instanceof PartitionByOperatorSpec); + PartitionByOperatorSpec reparOpSpec = (PartitionByOperatorSpec) originInputSpec.getRegisteredOperatorSpecs().toArray()[0]; + assertEquals(reparOpSpec.getOpId(), String.format("%s-%s-partition_by-%s", testJobName, testJobId, testReparStreamName)); + assertEquals(reparOpSpec.getKeyFunction(), keyFn); + assertEquals(reparOpSpec.getValueFunction(), valueFn); + assertEquals(reparOpSpec.getOutputStream().getStreamSpec(), new StreamSpec(reparOpSpec.getOpId(), reparOpSpec.getOpId(), "kafka")); + assertNull(reparOpSpec.getTimerFn()); + assertNull(reparOpSpec.getWatermarkFn()); + } + + @Test + public void testCopy() { + MessageStream inputStream = graphSpec.getInputStream(testInputId); + inputStream.partitionBy(m -> m.toString(), m -> m, testReparStreamName); + OperatorSpecGraph specGraph = graphSpec.getOperatorSpecGraph(); + OperatorSpecGraph clonedGraph = specGraph.clone(); + OperatorSpecTestUtils.assertClonedGraph(specGraph, clonedGraph); + } + + @Test(expected = IllegalArgumentException.class) + public void testTimerFunctionAsKeyFn() { + TimerMapFn keyFn = new TimerMapFn(); + MessageStream<Object> inputStream = graphSpec.getInputStream(testInputId); + inputStream.partitionBy(keyFn, m -> m, "parByKey"); + } + + @Test(expected = IllegalArgumentException.class) + public void testWatermarkFunctionAsKeyFn() { + WatermarkMapFn keyFn = new WatermarkMapFn(); + MessageStream<Object> inputStream = graphSpec.getInputStream(testInputId); + inputStream.partitionBy(keyFn, m -> m, "parByKey"); + } + + @Test(expected = IllegalArgumentException.class) + public void testTimerFunctionAsValueFn() { + TimerMapFn valueFn = new TimerMapFn(); + MessageStream<Object> inputStream = graphSpec.getInputStream(testInputId); + inputStream.partitionBy(m -> m.toString(), valueFn, "parByKey"); + } + + @Test(expected = IllegalArgumentException.class) + public void testWatermarkFunctionAsValueFn() { + WatermarkMapFn valueFn = new WatermarkMapFn(); + MessageStream<Object> inputStream = graphSpec.getInputStream(testInputId); + inputStream.partitionBy(m -> m.toString(), valueFn, "parByKey"); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java b/samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java index 65f1dc6..0a2214b 100644 --- a/samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java +++ b/samza-core/src/test/java/org/apache/samza/operators/spec/TestWindowOperatorSpec.java @@ -19,50 +19,312 @@ package org.apache.samza.operators.spec; +import org.apache.samza.operators.TimerRegistry; +import org.apache.samza.operators.functions.TimerFunction; +import org.apache.samza.operators.functions.WatermarkFunction; import org.apache.samza.serializers.Serde; +import org.apache.samza.operators.functions.FoldLeftFunction; +import org.apache.samza.operators.functions.MapFunction; +import org.apache.samza.operators.functions.SupplierFunction; import org.apache.samza.operators.triggers.Trigger; import org.apache.samza.operators.triggers.Triggers; import org.apache.samza.operators.windows.internal.WindowInternal; import org.apache.samza.operators.windows.internal.WindowType; import org.junit.Assert; +import org.junit.Before; import org.junit.Test; import java.time.Duration; +import java.util.ArrayList; +import java.util.Collection; +import static org.junit.Assert.*; import static org.mockito.Mockito.mock; public class TestWindowOperatorSpec { + + private Trigger defaultTrigger; + private Trigger earlyTrigger; + private Trigger lateTrigger; + private FoldLeftFunction<Object, Collection> foldFn; + private SupplierFunction<Collection> supplierFunction; + private MapFunction<Object, Object> keyFn; + private MapFunction<Object, Long> timeFn; + + @Before + public void setup() { + + foldFn = (m, c) -> { + c.add(m); + return c; + }; + supplierFunction = () -> new ArrayList<>(); + keyFn = m -> m.toString(); + timeFn = m -> 123456L; + + defaultTrigger = Triggers.timeSinceFirstMessage(Duration.ofMillis(150)); + earlyTrigger = Triggers.repeat(Triggers.count(5)); + lateTrigger = null; + } + @Test public void testTriggerIntervalWithNestedTimeTriggers() { - Trigger defaultTrigger = Triggers.timeSinceFirstMessage(Duration.ofMillis(150)); - Trigger lateTrigger = Triggers.any(Triggers.count(6), Triggers.timeSinceFirstMessage(Duration.ofMillis(15))); - Trigger earlyTrigger = Triggers.repeat( - Triggers.any(Triggers.count(23), - Triggers.timeSinceFirstMessage(Duration.ofMillis(15)), - Triggers.any(Triggers.any(Triggers.count(6), - Triggers.timeSinceFirstMessage(Duration.ofMillis(15)), - Triggers.timeSinceFirstMessage(Duration.ofMillis(25)), - Triggers.timeSinceLastMessage(Duration.ofMillis(15)))))); - - WindowInternal window = new WindowInternal(defaultTrigger, null, null, null, - null, WindowType.SESSION, null, null, mock(Serde.class)); - window.setEarlyTrigger(earlyTrigger); - window.setLateTrigger(lateTrigger); + defaultTrigger = Triggers.timeSinceFirstMessage(Duration.ofMillis(150)); + lateTrigger = Triggers.any(Triggers.count(6), Triggers.timeSinceFirstMessage(Duration.ofMillis(15))); + earlyTrigger = Triggers.repeat( + Triggers.any(Triggers.count(23), + Triggers.timeSinceFirstMessage(Duration.ofMillis(15)), + Triggers.any(Triggers.any(Triggers.count(6), + Triggers.timeSinceFirstMessage(Duration.ofMillis(15)), + Triggers.timeSinceFirstMessage(Duration.ofMillis(25)), + Triggers.timeSinceLastMessage(Duration.ofMillis(15)))))); - WindowOperatorSpec spec = new WindowOperatorSpec(window, "0"); - Assert.assertEquals(spec.getDefaultTriggerMs(), 5); + WindowOperatorSpec spec = getWindowOperatorSpec("w0"); + assertEquals(spec.getDefaultTriggerMs(), 5); } @Test public void testTriggerIntervalWithSingleTimeTrigger() { - Trigger defaultTrigger = Triggers.timeSinceFirstMessage(Duration.ofMillis(150)); - Trigger earlyTrigger = Triggers.repeat(Triggers.count(5)); + WindowOperatorSpec spec = getWindowOperatorSpec("w0"); + assertEquals(spec.getDefaultTriggerMs(), 150); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalTimerFunctionAsInitializer() { + class TimedSupplierFunction implements SupplierFunction<Collection>, TimerFunction<Object, Collection> { + + @Override + public Collection get() { + return new ArrayList<>(); + } + + @Override + public void registerTimer(TimerRegistry<Object> timerRegistry) { + + } + + @Override + public Collection<Collection> onTimer(Object key, long timestamp) { + return null; + } + } + supplierFunction = new TimedSupplierFunction(); + + getWindowOperatorSpec("w0"); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalWatermarkFunctionAsInitializer() { + class WatermarkSupplierFunction implements SupplierFunction<Collection>, WatermarkFunction<Collection> { + + @Override + public Collection get() { + return new ArrayList<>(); + } + + @Override + public Collection<Collection> processWatermark(long watermark) { + return null; + } + + @Override + public Long getOutputWatermark() { + return null; + } + } + supplierFunction = new WatermarkSupplierFunction(); + + getWindowOperatorSpec("w0"); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalTimerFunctionAsKeyFn() { + class TimerMapFunction implements MapFunction<Object, Object>, TimerFunction<Object, Object> { + + @Override + public Object apply(Object message) { + return message.toString(); + } + + @Override + public void registerTimer(TimerRegistry<Object> timerRegistry) { + + } + + @Override + public Collection<Object> onTimer(Object key, long timestamp) { + return null; + } + } + keyFn = new TimerMapFunction(); + + getWindowOperatorSpec("w0"); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalWatermarkFunctionAsKeyFn() { + class WatermarkMapFunction implements MapFunction<Object, Object>, WatermarkFunction<Object> { + + @Override + public Object apply(Object message) { + return message.toString(); + } + + @Override + public Collection<Object> processWatermark(long watermark) { + return null; + } + + @Override + public Long getOutputWatermark() { + return null; + } + } + keyFn = new WatermarkMapFunction(); + + getWindowOperatorSpec("w0"); + } - WindowInternal window = new WindowInternal(defaultTrigger, null, null, null, - null, WindowType.SESSION, null, null, mock(Serde.class)); + @Test(expected = IllegalArgumentException.class) + public void testIllegalTimerFunctionAsEventTimeFn() { + class TimerMapFunction implements MapFunction<Object, Long>, TimerFunction<Object, Object> { + + @Override + public Long apply(Object message) { + return 123456L; + } + + @Override + public void registerTimer(TimerRegistry<Object> timerRegistry) { + + } + + @Override + public Collection<Object> onTimer(Object key, long timestamp) { + return null; + } + } + timeFn = new TimerMapFunction(); + + getWindowOperatorSpec("w0"); + } + + @Test(expected = IllegalArgumentException.class) + public void testIllegalWatermarkFunctionAsEventTimeFn() { + class WatermarkMapFunction implements MapFunction<Object, Long>, WatermarkFunction<Object> { + + @Override + public Long apply(Object message) { + return 123456L; + } + + @Override + public Collection<Object> processWatermark(long watermark) { + return null; + } + + @Override + public Long getOutputWatermark() { + return null; + } + } + timeFn = new WatermarkMapFunction(); + + getWindowOperatorSpec("w0"); + } + + @Test + public void testTimerFunctionAsFoldLeftFn() { + class TimerFoldLeftFunction implements FoldLeftFunction<Object, Collection>, TimerFunction<Object, Collection> { + + @Override + public Collection apply(Object message, Collection oldValue) { + oldValue.add(message); + return oldValue; + } + + @Override + public void registerTimer(TimerRegistry<Object> timerRegistry) { + + } + + @Override + public Collection<Collection> onTimer(Object key, long timestamp) { + return null; + } + } + + foldFn = new TimerFoldLeftFunction(); + WindowOperatorSpec<Object, Object, Collection> windowSpec = getWindowOperatorSpec("w0"); + assertEquals(windowSpec.getTimerFn(), foldFn); + assertNull(windowSpec.getWatermarkFn()); + } + + @Test + public void testWatermarkFunctionAsFoldLeftFn() { + class WatermarkFoldLeftFunction implements FoldLeftFunction<Object, Collection>, WatermarkFunction<Object> { + + @Override + public Collection<Object> processWatermark(long watermark) { + return null; + } + + @Override + public Long getOutputWatermark() { + return null; + } + + @Override + public Collection apply(Object message, Collection oldValue) { + oldValue.add(message); + return oldValue; + } + } + + foldFn = new WatermarkFoldLeftFunction(); + WindowOperatorSpec<Object, Object, Collection> windowSpec = getWindowOperatorSpec("w0"); + assertEquals(windowSpec.getWatermarkFn(), foldFn); + assertNull(windowSpec.getTimerFn()); + } + + @Test + public void testCopy() { + WindowInternal<Object, Object, Collection> window = new WindowInternal<Object, Object, Collection>( + defaultTrigger, supplierFunction, foldFn, keyFn, timeFn, WindowType.SESSION, null, + mock(Serde.class), mock(Serde.class)); window.setEarlyTrigger(earlyTrigger); - WindowOperatorSpec spec = new WindowOperatorSpec(window, "0"); - Assert.assertEquals(spec.getDefaultTriggerMs(), 150); + WindowOperatorSpec<Object, Object, Collection> spec = new WindowOperatorSpec<>(window, "w0"); + + WindowOperatorSpec<Object, Object, Collection> copy = + (WindowOperatorSpec<Object, Object, Collection>) OperatorSpecTestUtils.copyOpSpec(spec); + + Assert.assertNotEquals(spec, copy); + Assert.assertTrue(spec.isClone(copy)); + Assert.assertNotEquals(spec.getWindow(), copy.getWindow()); + Assert.assertNotEquals(copy.getWindow().getInitializer(), supplierFunction); + assertEquals(copy.getWindow().getInitializer().get(), supplierFunction.get()); + Assert.assertNotEquals(copy.getWindow().getFoldLeftFunction(), foldFn); + Object mockMsg = new Object(); + assertEquals(copy.getWindow().getFoldLeftFunction().apply(mockMsg, new ArrayList<>()), foldFn.apply(mockMsg, new ArrayList<>())); + Assert.assertNotEquals(copy.getWindow().getKeyExtractor(), keyFn); + assertEquals(copy.getWindow().getKeyExtractor().apply(mockMsg), keyFn.apply(mockMsg)); + Assert.assertNotEquals(copy.getWindow().getEventTimeExtractor(), timeFn); + assertEquals(copy.getWindow().getEventTimeExtractor().apply(mockMsg), timeFn.apply(mockMsg)); + assertEquals(copy.getDefaultTriggerMs(), 150); + } + + private WindowOperatorSpec getWindowOperatorSpec(String opId) { + WindowInternal<Object, Object, Collection> window = new WindowInternal<Object, Object, Collection>( + defaultTrigger, supplierFunction, foldFn, keyFn, timeFn, WindowType.SESSION, null, + mock(Serde.class), mock(Serde.class)); + if (earlyTrigger != null) { + window.setEarlyTrigger(earlyTrigger); + } + if (lateTrigger != null) { + window.setLateTrigger(lateTrigger); + } + return new WindowOperatorSpec<>(window, opId); } + } http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/runtime/TestAbstractApplicationRunner.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/runtime/TestAbstractApplicationRunner.java b/samza-core/src/test/java/org/apache/samza/runtime/TestAbstractApplicationRunner.java index ed13b5b..b8d3f15 100644 --- a/samza-core/src/test/java/org/apache/samza/runtime/TestAbstractApplicationRunner.java +++ b/samza-core/src/test/java/org/apache/samza/runtime/TestAbstractApplicationRunner.java @@ -222,10 +222,10 @@ public class TestAbstractApplicationRunner { StreamConfig.SYSTEM(), TEST_SYSTEM); AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - StreamSpec spec = runner.getStreamSpec(STREAM_ID, TEST_PHYSICAL_NAME); + StreamSpec spec = runner.getStreamSpec(STREAM_ID); assertEquals(STREAM_ID, spec.getId()); - assertEquals(TEST_PHYSICAL_NAME, spec.getPhysicalName()); + assertEquals(TEST_PHYSICAL_NAME2, spec.getPhysicalName()); assertEquals(TEST_SYSTEM, spec.getSystemName()); } @@ -233,11 +233,11 @@ public class TestAbstractApplicationRunner { @Test public void testGetStreamPhysicalNameArgSpecialCharacters() { Config config = buildStreamConfig(STREAM_ID, - StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME2, + StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME_SPECIAL_CHARS, StreamConfig.SYSTEM(), TEST_SYSTEM); AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - StreamSpec spec = runner.getStreamSpec(STREAM_ID, TEST_PHYSICAL_NAME_SPECIAL_CHARS); + StreamSpec spec = runner.getStreamSpec(STREAM_ID); assertEquals(TEST_PHYSICAL_NAME_SPECIAL_CHARS, spec.getPhysicalName()); } @@ -245,11 +245,11 @@ public class TestAbstractApplicationRunner { @Test public void testGetStreamPhysicalNameArgNull() { Config config = buildStreamConfig(STREAM_ID, - StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME2, + StreamConfig.PHYSICAL_NAME(), null, StreamConfig.SYSTEM(), TEST_SYSTEM); AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - StreamSpec spec = runner.getStreamSpec(STREAM_ID, null); + StreamSpec spec = runner.getStreamSpec(STREAM_ID); assertNull(spec.getPhysicalName()); } @@ -257,11 +257,11 @@ public class TestAbstractApplicationRunner { @Test public void testGetStreamSystemNameArgValid() { Config config = buildStreamConfig(STREAM_ID, - StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME2, // This should be ignored because of the explicit arg - StreamConfig.SYSTEM(), TEST_SYSTEM2); // This too + StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME, // This should be ignored because of the explicit arg + StreamConfig.SYSTEM(), TEST_SYSTEM); // This too AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - StreamSpec spec = runner.getStreamSpec(STREAM_ID, TEST_PHYSICAL_NAME, TEST_SYSTEM); + StreamSpec spec = runner.getStreamSpec(STREAM_ID); assertEquals(STREAM_ID, spec.getId()); assertEquals(TEST_PHYSICAL_NAME, spec.getPhysicalName()); @@ -272,33 +272,33 @@ public class TestAbstractApplicationRunner { @Test(expected = IllegalArgumentException.class) public void testGetStreamSystemNameArgInvalid() { Config config = buildStreamConfig(STREAM_ID, - StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME2, - StreamConfig.SYSTEM(), TEST_SYSTEM2); + StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME, + StreamConfig.SYSTEM(), TEST_SYSTEM_INVALID); AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - runner.getStreamSpec(STREAM_ID, TEST_PHYSICAL_NAME, TEST_SYSTEM_INVALID); + runner.getStreamSpec(STREAM_ID); } // Empty strings are NOT allowed for system name, because it's used as an identifier in the config. @Test(expected = IllegalArgumentException.class) public void testGetStreamSystemNameArgEmpty() { Config config = buildStreamConfig(STREAM_ID, - StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME2, - StreamConfig.SYSTEM(), TEST_SYSTEM2); + StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME, + StreamConfig.SYSTEM(), ""); AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - runner.getStreamSpec(STREAM_ID, TEST_PHYSICAL_NAME, ""); + runner.getStreamSpec(STREAM_ID); } // Null is not allowed IllegalArgumentException system name. @Test(expected = IllegalArgumentException.class) public void testGetStreamSystemNameArgNull() { Config config = buildStreamConfig(STREAM_ID, - StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME2, - StreamConfig.SYSTEM(), TEST_SYSTEM2); + StreamConfig.PHYSICAL_NAME(), TEST_PHYSICAL_NAME, + StreamConfig.SYSTEM(), null); AbstractApplicationRunner runner = new TestAbstractApplicationRunnerImpl(config); - runner.getStreamSpec(STREAM_ID, TEST_PHYSICAL_NAME, null); + runner.getStreamSpec(STREAM_ID); } // Special characters are NOT allowed for streamId, because it's used as an identifier in the config. http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java b/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java index 84ecc6c..595dda2 100644 --- a/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java +++ b/samza-core/src/test/java/org/apache/samza/runtime/TestLocalApplicationRunner.java @@ -29,6 +29,7 @@ import java.util.Set; import java.util.stream.Collectors; import org.apache.samza.application.StreamApplication; import org.apache.samza.config.ApplicationConfig; +import org.apache.samza.config.Config; import org.apache.samza.config.JobConfig; import org.apache.samza.config.JobCoordinatorConfig; import org.apache.samza.config.MapConfig; @@ -76,7 +77,6 @@ public class TestLocalApplicationRunner { Map<String, String> config = new HashMap<>(); LocalApplicationRunner runner = spy(new LocalApplicationRunner(new MapConfig(config))); StreamApplication app = mock(StreamApplication.class); - doNothing().when(app).init(anyObject(), anyObject()); StreamManager streamManager = mock(StreamManager.class); doReturn(streamManager).when(runner).getStreamManager(); @@ -84,7 +84,7 @@ public class TestLocalApplicationRunner { ExecutionPlan plan = mock(ExecutionPlan.class); when(plan.getIntermediateStreams()).thenReturn(Collections.singletonList(new StreamSpec("test-stream", "test-stream", "test-system"))); when(plan.getPlanAsJson()).thenReturn(""); - doReturn(plan).when(runner).getExecutionPlan(any(), any()); + doReturn(plan).when(runner).getExecutionPlan(any()); CoordinationUtilsFactory coordinationUtilsFactory = mock(CoordinationUtilsFactory.class); JobCoordinatorConfig mockJcConfig = mock(JobCoordinatorConfig.class); @@ -93,6 +93,7 @@ public class TestLocalApplicationRunner { try { runner.run(app); + runner.waitForFinish(); } catch (Throwable t) { assertNotNull(t); //no jobs exception } @@ -112,7 +113,6 @@ public class TestLocalApplicationRunner { LocalApplicationRunner runner = spy(localRunner); StreamApplication app = mock(StreamApplication.class); - doNothing().when(app).init(anyObject(), anyObject()); StreamManager streamManager = mock(StreamManager.class); doReturn(streamManager).when(runner).getStreamManager(); @@ -120,7 +120,7 @@ public class TestLocalApplicationRunner { ExecutionPlan plan = mock(ExecutionPlan.class); when(plan.getIntermediateStreams()).thenReturn(Collections.singletonList(new StreamSpec("test-stream", "test-stream", "test-system"))); when(plan.getPlanAsJson()).thenReturn(""); - doReturn(plan).when(runner).getExecutionPlan(any(), any()); + doReturn(plan).when(runner).getExecutionPlan(any()); CoordinationUtils coordinationUtils = mock(CoordinationUtils.class); CoordinationUtilsFactory coordinationUtilsFactory = mock(CoordinationUtilsFactory.class); @@ -136,6 +136,7 @@ public class TestLocalApplicationRunner { try { runner.run(app); + runner.waitForFinish(); } catch (Throwable t) { assertNotNull(t); //no jobs exception } @@ -153,7 +154,7 @@ public class TestLocalApplicationRunner { throws Exception { final Map<String, String> config = new HashMap<>(); config.put(ApplicationConfig.APP_PROCESSOR_ID_GENERATOR_CLASS, UUIDGenerator.class.getName()); - config.put(TaskConfig.TASK_CLASS(), "org.apache.samza.test.processor.IdentityStreamTask"); + config.put(TaskConfig.TASK_CLASS(), "org.apache.samza.task.IdentityStreamTask"); LocalApplicationRunner runner = new LocalApplicationRunner(new MapConfig(config)); @@ -170,7 +171,7 @@ public class TestLocalApplicationRunner { }).when(sp).start(); LocalApplicationRunner spy = spy(runner); - doReturn(sp).when(spy).createStreamProcessor(anyObject(), anyObject(), captor.capture()); + doReturn(sp).when(spy).createStreamProcessor(any(Config.class), captor.capture()); spy.runTask(); @@ -184,13 +185,12 @@ public class TestLocalApplicationRunner { config.put(ApplicationConfig.APP_PROCESSOR_ID_GENERATOR_CLASS, UUIDGenerator.class.getName()); LocalApplicationRunner runner = spy(new LocalApplicationRunner(new MapConfig(config))); StreamApplication app = mock(StreamApplication.class); - doNothing().when(app).init(anyObject(), anyObject()); ExecutionPlan plan = mock(ExecutionPlan.class); when(plan.getIntermediateStreams()).thenReturn(Collections.emptyList()); when(plan.getPlanAsJson()).thenReturn(""); when(plan.getJobConfigs()).thenReturn(Collections.singletonList(new JobConfig(new MapConfig(config)))); - doReturn(plan).when(runner).getExecutionPlan(any(), any()); + doReturn(plan).when(runner).getExecutionPlan(any()); StreamProcessor sp = mock(StreamProcessor.class); ArgumentCaptor<StreamProcessorLifecycleListener> captor = @@ -207,6 +207,7 @@ public class TestLocalApplicationRunner { doReturn(sp).when(runner).createStreamProcessor(anyObject(), anyObject(), captor.capture()); runner.run(app); + runner.waitForFinish(); assertEquals(runner.status(app), ApplicationStatus.SuccessfulFinish); } @@ -218,13 +219,12 @@ public class TestLocalApplicationRunner { config.put(ApplicationConfig.PROCESSOR_ID, "0"); LocalApplicationRunner runner = spy(new LocalApplicationRunner(new MapConfig(config))); StreamApplication app = mock(StreamApplication.class); - doNothing().when(app).init(anyObject(), anyObject()); ExecutionPlan plan = mock(ExecutionPlan.class); when(plan.getIntermediateStreams()).thenReturn(Collections.emptyList()); when(plan.getPlanAsJson()).thenReturn(""); when(plan.getJobConfigs()).thenReturn(Collections.singletonList(new JobConfig(new MapConfig(config)))); - doReturn(plan).when(runner).getExecutionPlan(any(), any()); + doReturn(plan).when(runner).getExecutionPlan(any()); StreamProcessor sp = mock(StreamProcessor.class); ArgumentCaptor<StreamProcessorLifecycleListener> captor = @@ -239,6 +239,7 @@ public class TestLocalApplicationRunner { try { runner.run(app); + runner.waitForFinish(); } catch (Throwable th) { assertNotNull(th); } http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/task/IdentityStreamTask.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/IdentityStreamTask.java b/samza-core/src/test/java/org/apache/samza/task/IdentityStreamTask.java new file mode 100644 index 0000000..1f71abd --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/task/IdentityStreamTask.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.task; + +import org.apache.samza.config.Config; +import org.apache.samza.system.IncomingMessageEnvelope; +import org.apache.samza.system.OutgoingMessageEnvelope; +import org.apache.samza.system.SystemStream; + + +public class IdentityStreamTask implements StreamTask , InitableTask { + private int processedMessageCount = 0; + private int expectedMessageCount; + private String outputTopic; + private String outputSystem; + + @Override + public void init(Config config, TaskContext taskContext) throws Exception { + this.expectedMessageCount = config.getInt("app.messageCount"); + this.outputTopic = config.get("app.outputTopic", "output"); + this.outputSystem = config.get("app.outputSystem", "test-system"); + } + + @Override + public void process( + IncomingMessageEnvelope incomingMessageEnvelope, + MessageCollector messageCollector, + TaskCoordinator taskCoordinator) throws Exception { + messageCollector.send( + new OutgoingMessageEnvelope( + new SystemStream(outputSystem, outputTopic), + incomingMessageEnvelope.getMessage())); + processedMessageCount++; + if (processedMessageCount == expectedMessageCount) { + taskCoordinator.shutdown(TaskCoordinator.RequestScope.ALL_TASKS_IN_CONTAINER); + } + } +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java b/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java index 942792f..e207772 100644 --- a/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java +++ b/samza-core/src/test/java/org/apache/samza/task/TestTaskFactoryUtil.java @@ -18,20 +18,21 @@ */ package org.apache.samza.task; +import java.lang.reflect.Field; +import java.util.concurrent.ExecutorService; import org.apache.samza.SamzaException; +import org.apache.samza.application.StreamApplication; import org.apache.samza.config.ApplicationConfig; import org.apache.samza.config.Config; import org.apache.samza.config.ConfigException; import org.apache.samza.config.MapConfig; -import org.apache.samza.application.StreamApplication; +import org.apache.samza.operators.StreamGraphSpec; import org.apache.samza.runtime.ApplicationRunner; import org.apache.samza.testUtils.TestAsyncStreamTask; import org.apache.samza.testUtils.TestStreamTask; import org.junit.Test; -import java.lang.reflect.Field; import java.util.HashMap; -import java.util.concurrent.ExecutorService; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; @@ -54,7 +55,7 @@ public class TestTaskFactoryUtil { this.put("task.class", "org.apache.samza.testUtils.TestStreamTask"); } }); - Object retFactory = TaskFactoryUtil.createTaskFactory(config, null, null); + Object retFactory = TaskFactoryUtil.createTaskFactory(config); assertTrue(retFactory instanceof StreamTaskFactory); assertTrue(((StreamTaskFactory) retFactory).createInstance() instanceof TestStreamTask); @@ -64,7 +65,7 @@ public class TestTaskFactoryUtil { } }); try { - TaskFactoryUtil.createTaskFactory(config, null, null); + TaskFactoryUtil.createTaskFactory(config); fail("Should have failed w/ no.such.class"); } catch (ConfigException cfe) { // expected @@ -80,7 +81,9 @@ public class TestTaskFactoryUtil { }); StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); assertNotNull(streamApp); - Object retFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, mockRunner); + StreamGraphSpec graph = new StreamGraphSpec(mockRunner, config); + streamApp.init(graph, config); + Object retFactory = TaskFactoryUtil.createTaskFactory(graph.getOperatorSpecGraph(), null); assertTrue(retFactory instanceof StreamTaskFactory); assertTrue(((StreamTaskFactory) retFactory).createInstance() instanceof StreamOperatorTask); @@ -184,29 +187,6 @@ public class TestTaskFactoryUtil { // expected } - config = new MapConfig(new HashMap<String, String>() { - { - this.put("task.class", "org.apache.samza.testUtils.TestStreamTask"); - this.put(ApplicationConfig.APP_CLASS, ""); - } - }); - StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); - Object retFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, mockRunner); - assertTrue(retFactory instanceof StreamTaskFactory); - assertTrue(((StreamTaskFactory) retFactory).createInstance() instanceof TestStreamTask); - - config = new MapConfig(new HashMap<String, String>() { - { - this.put("task.class", ""); - this.put(ApplicationConfig.APP_CLASS, "org.apache.samza.testUtils.InvalidStreamApplication"); - } - }); - try { - TaskFactoryUtil.createStreamApplication(config); - fail("Should have failed w/ no class not found"); - } catch (ConfigException cne) { - // expected - } } @Test @@ -216,7 +196,7 @@ public class TestTaskFactoryUtil { this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask"); } }); - Object retFactory = TaskFactoryUtil.createTaskFactory(config, null, null); + Object retFactory = TaskFactoryUtil.createTaskFactory(config); assertTrue(retFactory instanceof AsyncStreamTaskFactory); assertTrue(((AsyncStreamTaskFactory) retFactory).createInstance() instanceof TestAsyncStreamTask); @@ -226,7 +206,7 @@ public class TestTaskFactoryUtil { } }); try { - TaskFactoryUtil.createTaskFactory(config, null, null); + TaskFactoryUtil.createTaskFactory(config); fail("Should have failed w/ no.such.class"); } catch (ConfigException cfe) { // expected @@ -247,28 +227,6 @@ public class TestTaskFactoryUtil { } catch (ConfigException cfe) { // expected } - - config = new MapConfig(new HashMap<String, String>() { - { - this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask"); - this.put(ApplicationConfig.APP_CLASS, ""); - } - }); - StreamApplication streamApp = TaskFactoryUtil.createStreamApplication(config); - Object retFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, mockRunner); - assertTrue(retFactory instanceof AsyncStreamTaskFactory); - assertTrue(((AsyncStreamTaskFactory) retFactory).createInstance() instanceof TestAsyncStreamTask); - - config = new MapConfig(new HashMap<String, String>() { - { - this.put("task.class", "org.apache.samza.testUtils.TestAsyncStreamTask"); - this.put(ApplicationConfig.APP_CLASS, null); - } - }); - streamApp = TaskFactoryUtil.createStreamApplication(config); - retFactory = TaskFactoryUtil.createTaskFactory(config, streamApp, mockRunner); - assertTrue(retFactory instanceof AsyncStreamTaskFactory); - assertTrue(((AsyncStreamTaskFactory) retFactory).createInstance() instanceof TestAsyncStreamTask); } @Test http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-core/src/test/java/org/apache/samza/testUtils/InvalidStreamApplication.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/testUtils/InvalidStreamApplication.java b/samza-core/src/test/java/org/apache/samza/testUtils/InvalidStreamApplication.java deleted file mode 100644 index 5026066..0000000 --- a/samza-core/src/test/java/org/apache/samza/testUtils/InvalidStreamApplication.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.samza.testUtils; - -/** - * Test class. Invalid class to implement {@link org.apache.samza.application.StreamApplication} - */ -public class InvalidStreamApplication { -} http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaStreamSpec.java ---------------------------------------------------------------------- diff --git a/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaStreamSpec.java b/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaStreamSpec.java index 1758bf0..c00ed2d 100644 --- a/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaStreamSpec.java +++ b/samza-kafka/src/test/java/org/apache/samza/system/kafka/TestKafkaStreamSpec.java @@ -21,13 +21,14 @@ package org.apache.samza.system.kafka; import com.google.common.collect.ImmutableMap; import java.util.Map; import java.util.Properties; +import org.apache.samza.runtime.TestAbstractApplicationRunner; import org.apache.samza.system.StreamSpec; import org.junit.Test; import static org.junit.Assert.*; /** - * See also the general StreamSpec tests in {@link org.apache.samza.runtime.TestAbstractApplicationRunner} + * See also the general StreamSpec tests in {@link TestAbstractApplicationRunner} */ public class TestKafkaStreamSpec { http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlCompositeKey.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlCompositeKey.java b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlCompositeKey.java index 54c8391..4b4b8f2 100644 --- a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlCompositeKey.java +++ b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlCompositeKey.java @@ -70,6 +70,7 @@ public class SamzaSqlCompositeKey implements Serializable { * Create the SamzaSqlCompositeKey from the rel message. * @param message Represents the samza sql rel message. * @param relIdx list of keys in the form of field indices within the rel message. + * @return the composite key of the rel message */ public static SamzaSqlCompositeKey createSamzaSqlCompositeKey(SamzaSqlRelMessage message, List<Integer> relIdx) { ArrayList<Object> keyParts = new ArrayList<>(); http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java index b0c30dd..cb5c7a7 100644 --- a/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java +++ b/samza-sql/src/main/java/org/apache/samza/sql/data/SamzaSqlExecutionContext.java @@ -32,12 +32,24 @@ import org.apache.samza.sql.testutil.ReflectionUtils; import org.apache.samza.sql.udfs.ScalarUdf; -public class SamzaSqlExecutionContext { +public class SamzaSqlExecutionContext implements Cloneable { + /** + * The variables that are shared among all cloned instance of {@link SamzaSqlExecutionContext} + */ private final SamzaSqlApplicationConfig sqlConfig; private final Map<String, UdfMetadata> udfMetadata; + + /** + * The variable that are not shared among all cloned instance of {@link SamzaSqlExecutionContext} + */ private final Map<String, ScalarUdf> udfInstances = new HashMap<>(); + private SamzaSqlExecutionContext(SamzaSqlExecutionContext other) { + this.sqlConfig = other.sqlConfig; + this.udfMetadata = other.udfMetadata; + } + public SamzaSqlExecutionContext(SamzaSqlApplicationConfig config) { this.sqlConfig = config; udfMetadata = @@ -62,4 +74,10 @@ public class SamzaSqlExecutionContext { public SamzaSqlApplicationConfig getSamzaSqlApplicationConfig() { return sqlConfig; } + + @Override + public SamzaSqlExecutionContext clone() { + return new SamzaSqlExecutionContext(this); + } + } http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java b/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java index 5832b21..f33c5ca 100644 --- a/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java +++ b/samza-sql/src/main/java/org/apache/samza/sql/translator/FilterTranslator.java @@ -22,13 +22,13 @@ package org.apache.samza.sql.translator; import java.util.Arrays; import java.util.Collections; -import java.util.List; -import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.logical.LogicalFilter; -import org.apache.calcite.rex.RexNode; +import org.apache.samza.config.Config; import org.apache.samza.operators.MessageStream; +import org.apache.samza.operators.functions.FilterFunction; import org.apache.samza.sql.data.Expression; import org.apache.samza.sql.data.SamzaSqlRelMessage; +import org.apache.samza.task.TaskContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,19 +41,26 @@ class FilterTranslator { private static final Logger log = LoggerFactory.getLogger(FilterTranslator.class); - void translate(final LogicalFilter filter, final TranslatorContext context) { - MessageStream<SamzaSqlRelMessage> inputStream = context.getMessageStream(filter.getInput().getId()); - MessageStream<SamzaSqlRelMessage> outputStream = translateFilter(inputStream, filter.getInputs(), - filter.getCondition(), context); - context.registerMessageStream(filter.getId(), outputStream); - } + private static class FilterTranslatorFunction implements FilterFunction<SamzaSqlRelMessage> { + private transient Expression expr; + private transient TranslatorContext context; + private transient LogicalFilter filter; - static MessageStream<SamzaSqlRelMessage> translateFilter(MessageStream<SamzaSqlRelMessage> inputStream, - List<RelNode> inputs, RexNode condition, final TranslatorContext context) { - Expression expr = - context.getExpressionCompiler().compile(inputs, Collections.singletonList(condition)); + private final int filterId; - return inputStream.filter(message -> { + FilterTranslatorFunction(int filterId) { + this.filterId = filterId; + } + + @Override + public void init(Config config, TaskContext context) { + this.context = (TranslatorContext) context.getUserContext(); + this.filter = (LogicalFilter) this.context.getRelNode(filterId); + this.expr = this.context.getExpressionCompiler().compile(filter.getInputs(), Collections.singletonList(filter.getCondition())); + } + + @Override + public boolean apply(SamzaSqlRelMessage message) { Object[] result = new Object[1]; expr.execute(context.getExecutionContext(), context.getDataContext(), message.getSamzaSqlRelRecord().getFieldValues().toArray(), result); @@ -67,6 +74,16 @@ class FilterTranslator { log.error("return value is not boolean"); return false; } - }); + } + } + + void translate(final LogicalFilter filter, final TranslatorContext context) { + MessageStream<SamzaSqlRelMessage> inputStream = context.getMessageStream(filter.getInput().getId()); + final int filterId = filter.getId(); + + MessageStream<SamzaSqlRelMessage> outputStream = inputStream.filter(new FilterTranslatorFunction(filterId)); + + context.registerMessageStream(filterId, outputStream); + context.registerRelNode(filterId, filter); } } http://git-wip-us.apache.org/repos/asf/samza/blob/53d7f262/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java b/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java index 96179d6..216ecea 100644 --- a/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java +++ b/samza-sql/src/main/java/org/apache/samza/sql/translator/LogicalAggregateTranslator.java @@ -20,13 +20,14 @@ package org.apache.samza.sql.translator; import java.time.Duration; +import java.util.ArrayList; import java.util.List; -import java.util.function.Supplier; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.sql.SqlKind; import org.apache.samza.SamzaException; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.functions.FoldLeftFunction; +import org.apache.samza.operators.functions.SupplierFunction; import org.apache.samza.operators.windows.AccumulationMode; import org.apache.samza.operators.windows.Windows; import org.apache.samza.serializers.LongSerde; @@ -55,9 +56,11 @@ class LogicalAggregateTranslator { MessageStream<SamzaSqlRelMessage> inputStream = context.getMessageStream(aggregate.getInput().getId()); // At this point, the assumption is that only count function is supported. - Supplier<Long> initialValue = () -> (long) 0; + SupplierFunction<Long> initialValue = () -> (long) 0; FoldLeftFunction<SamzaSqlRelMessage, Long> foldCountFn = (m, c) -> c + 1; + final ArrayList<String> aggFieldNames = getAggFieldNames(aggregate); + MessageStream<SamzaSqlRelMessage> outputStream = inputStream .window(Windows.keyedTumblingWindow(m -> m, @@ -68,15 +71,20 @@ class LogicalAggregateTranslator { new LongSerde()) .setAccumulationMode(AccumulationMode.DISCARDING), "tumblingWindow_" + windowId) .map(windowPane -> { - List<String> fieldNames = windowPane.getKey().getKey().getSamzaSqlRelRecord().getFieldNames(); - List<Object> fieldValues = windowPane.getKey().getKey().getSamzaSqlRelRecord().getFieldValues(); - fieldNames.add(aggregate.getAggCallList().get(0).getName()); - fieldValues.add(windowPane.getMessage()); - return new SamzaSqlRelMessage(fieldNames, fieldValues); - }); + List<String> fieldNames = windowPane.getKey().getKey().getSamzaSqlRelRecord().getFieldNames(); + List<Object> fieldValues = windowPane.getKey().getKey().getSamzaSqlRelRecord().getFieldValues(); + fieldNames.add(aggFieldNames.get(0)); + fieldValues.add(windowPane.getMessage()); + return new SamzaSqlRelMessage(fieldNames, fieldValues); + }); context.registerMessageStream(aggregate.getId(), outputStream); } + private ArrayList<String> getAggFieldNames(LogicalAggregate aggregate) { + return aggregate.getAggCallList().stream().collect(ArrayList::new, (names, aggCall) -> names.add(aggCall.getName()), + (n1, n2) -> n1.addAll(n2)); + } + void validateAggregateFunctions(final LogicalAggregate aggregate) { if (aggregate.getAggCallList().size() != 1) { String errMsg = "Windowing is supported ONLY with one aggregate function but the number of given functions are " +