Repository: samza Updated Branches: refs/heads/master 1971d596c -> 2be7061d4
http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java index 9c79766..05fa19a 100644 --- a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java +++ b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java @@ -18,14 +18,22 @@ */ package org.apache.samza.util; +import java.lang.reflect.Field; +import java.util.Collection; import java.util.HashMap; import java.util.Map; +import org.apache.samza.SamzaException; +import org.apache.samza.config.Config; +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.task.TaskContext; import org.junit.Assert; import org.junit.Ignore; import org.junit.Test; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class TestEmbeddedTaggedRateLimiter { @@ -35,42 +43,71 @@ public class TestEmbeddedTaggedRateLimiter { final static private int TARGET_RATE_RED = 1000; final static private int TARGET_RATE_PER_TASK_RED = TARGET_RATE_RED / NUMBER_OF_TASKS; final static private int TARGET_RATE_GREEN = 2000; - final static private int TARGET_RATE_PER_TASK_GREEN = TARGET_RATE_GREEN / NUMBER_OF_TASKS; final static private int INCREMENT = 2; + final static private int TARGET_RATE = 4000; + final static private int TARGET_RATE_PER_TASK = TARGET_RATE / NUMBER_OF_TASKS; + @Test @Ignore("Flaky Test: Test fails in travis.") public void testAcquire() { - RateLimiter rateLimiter = createRateLimiter(); + RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(TARGET_RATE); + initRateLimiter(rateLimiter); - Map<String, Integer> tagToCount = new HashMap<>(); - tagToCount.put("red", 0); - tagToCount.put("green", 0); + int count = 0; + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TEST_INTERVAL) { + rateLimiter.acquire(INCREMENT); + count += INCREMENT; + } - Map<String, Integer> tagToCredits = new HashMap<>(); - tagToCredits.put("red", INCREMENT); - tagToCredits.put("green", INCREMENT); + long rate = count * 1000 / TEST_INTERVAL; + verifyRate(rate); + } + + @Test + public void testAcquireWithTimeout() { + RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(TARGET_RATE); + initRateLimiter(rateLimiter); + + boolean hasSeenZeros = false; + int count = 0; + int callCount = 0; long start = System.currentTimeMillis(); while (System.currentTimeMillis() - start < TEST_INTERVAL) { - rateLimiter.acquire(tagToCredits); - tagToCount.put("red", tagToCount.get("red") + INCREMENT); - tagToCount.put("green", tagToCount.get("green") + INCREMENT); + ++callCount; + int availableCredits = rateLimiter.acquire(INCREMENT, 20, MILLISECONDS); + if (availableCredits <= 0) { + hasSeenZeros = true; + } else { + count += INCREMENT; + } } - { - long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; - verifyRate(rate, TARGET_RATE_PER_TASK_RED); - } { - // Note: due to blocking, green is capped at red's QPS - long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; - verifyRate(rate, TARGET_RATE_PER_TASK_RED); - } + long rate = count * 1000 / TEST_INTERVAL; + verifyRate(rate); + junit.framework.Assert.assertTrue(Math.abs(callCount - TARGET_RATE_PER_TASK * TEST_INTERVAL / 1000 / INCREMENT) <= 2); + junit.framework.Assert.assertFalse(hasSeenZeros); } - @Test - public void testTryAcquire() { + @Test(expected = IllegalStateException.class) + public void testFailsWhenUninitialized() { + new EmbeddedTaggedRateLimiter(100).acquire(1); + } + + @Test(expected = IllegalArgumentException.class) + public void testFailsWhenUsingTags() { + RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(10); + initRateLimiter(rateLimiter); + Map<String, Integer> tagToCredits = new HashMap<>(); + tagToCredits.put("red", 1); + tagToCredits.put("green", 1); + rateLimiter.acquire(tagToCredits); + } + @Test + public void testAcquireTagged() { RateLimiter rateLimiter = createRateLimiter(); Map<String, Integer> tagToCount = new HashMap<>(); @@ -83,22 +120,23 @@ public class TestEmbeddedTaggedRateLimiter { long start = System.currentTimeMillis(); while (System.currentTimeMillis() - start < TEST_INTERVAL) { - Map<String, Integer> resultMap = rateLimiter.tryAcquire(tagToCredits); - tagToCount.put("red", tagToCount.get("red") + resultMap.get("red")); - tagToCount.put("green", tagToCount.get("green") + resultMap.get("green")); + rateLimiter.acquire(tagToCredits); + tagToCount.put("red", tagToCount.get("red") + INCREMENT); + tagToCount.put("green", tagToCount.get("green") + INCREMENT); } { long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; verifyRate(rate, TARGET_RATE_PER_TASK_RED); } { + // Note: due to blocking, green is capped at red's QPS long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; - verifyRate(rate, TARGET_RATE_PER_TASK_GREEN); + verifyRate(rate, TARGET_RATE_PER_TASK_RED); } } @Test - public void testAcquireWithTimeout() { + public void testAcquireWithTimeoutTagged() { RateLimiter rateLimiter = createRateLimiter(); @@ -128,7 +166,7 @@ public class TestEmbeddedTaggedRateLimiter { } @Test(expected = IllegalStateException.class) - public void testFailsWhenUninitialized() { + public void testFailsWhenUninitializedTagged() { Map<String, Integer> tagToTargetRateMap = new HashMap<>(); tagToTargetRateMap.put("red", 1000); tagToTargetRateMap.put("green", 2000); @@ -141,14 +179,14 @@ public class TestEmbeddedTaggedRateLimiter { tagToCredits.put("red", 1); tagToCredits.put("green", 1); RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToCredits); - TestEmbeddedRateLimiter.initRateLimiter(rateLimiter); + initRateLimiter(rateLimiter); rateLimiter.acquire(1); } private void verifyRate(long rate, long targetRate) { // As the actual rate would likely not be exactly the same as target rate, the calculation below - // verifies the actual rate is within 5% of the target rate per task - Assert.assertTrue(Math.abs(rate - targetRate) <= targetRate * 5 / 100); + // verifies the actual rate is within 10% of the target rate per task + Assert.assertTrue(Math.abs(rate - targetRate) <= targetRate * 10 / 100); } private RateLimiter createRateLimiter() { @@ -156,8 +194,36 @@ public class TestEmbeddedTaggedRateLimiter { tagToTargetRateMap.put("red", TARGET_RATE_RED); tagToTargetRateMap.put("green", TARGET_RATE_GREEN); RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToTargetRateMap); - TestEmbeddedRateLimiter.initRateLimiter(rateLimiter); + initRateLimiter(rateLimiter); return rateLimiter; } + private void verifyRate(long rate) { + // As the actual rate would likely not be exactly the same as target rate, the calculation below + // verifies the actual rate is within 5% of the target rate per task + junit.framework.Assert.assertTrue(Math.abs(rate - TARGET_RATE_PER_TASK) <= TARGET_RATE_PER_TASK * 5 / 100); + } + + static void initRateLimiter(RateLimiter rateLimiter) { + Config config = mock(Config.class); + TaskContext taskContext = mock(TaskContext.class); + SamzaContainerContext containerContext = mockSamzaContainerContext(); + when(taskContext.getSamzaContainerContext()).thenReturn(containerContext); + rateLimiter.init(config, taskContext); + } + + static SamzaContainerContext mockSamzaContainerContext() { + try { + Collection<String> taskNames = mock(Collection.class); + when(taskNames.size()).thenReturn(NUMBER_OF_TASKS); + SamzaContainerContext containerContext = mock(SamzaContainerContext.class); + Field taskNamesField = SamzaContainerContext.class.getDeclaredField("taskNames"); + taskNamesField.setAccessible(true); + taskNamesField.set(containerContext, taskNames); + taskNamesField.setAccessible(false); + return containerContext; + } catch (Exception ex) { + throw new SamzaException(ex); + } + } } http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java ---------------------------------------------------------------------- diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java index 4af0f1d..b494eba 100644 --- a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java +++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java @@ -25,18 +25,26 @@ import org.apache.samza.SamzaException; import org.apache.samza.config.JavaTableConfig; import org.apache.samza.config.MapConfig; import org.apache.samza.config.StorageConfig; -import org.apache.samza.storage.StorageEngine; -import org.apache.samza.table.LocalStoreBackedTableProvider; +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.table.ReadableTable; import org.apache.samza.table.Table; +import org.apache.samza.table.TableProvider; import org.apache.samza.table.TableSpec; +import org.apache.samza.task.TaskContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.google.common.base.Preconditions; + /** - * Base class for tables backed by Samza stores, see {@link LocalStoreBackedTableProvider}. + * Base class for tables backed by Samza local stores. The backing stores are + * injected during initialization of the table. Since the lifecycle + * of the underlying stores are already managed by Samza container, + * the table provider will not manage the lifecycle of the backing + * stores. */ -abstract public class BaseLocalStoreBackedTableProvider implements LocalStoreBackedTableProvider { +abstract public class BaseLocalStoreBackedTableProvider implements TableProvider { protected final Logger logger = LoggerFactory.getLogger(getClass()); @@ -44,13 +52,28 @@ abstract public class BaseLocalStoreBackedTableProvider implements LocalStoreBac protected KeyValueStore kvStore; + protected SamzaContainerContext containerContext; + + protected TaskContext taskContext; + public BaseLocalStoreBackedTableProvider(TableSpec tableSpec) { this.tableSpec = tableSpec; } @Override - public void init(StorageEngine store) { - kvStore = (KeyValueStore) store; + public void init(SamzaContainerContext containerContext, TaskContext taskContext) { + this.containerContext = containerContext; + this.taskContext = taskContext; + + Preconditions.checkNotNull(this.taskContext, "Must specify task context for local tables."); + + kvStore = (KeyValueStore) taskContext.getStore(tableSpec.getId()); + + if (kvStore == null) { + throw new SamzaException(String.format( + "Backing store for table %s was not injected by SamzaContainer", tableSpec.getId())); + } + logger.info("Initialized backing store for table " + tableSpec.getId()); } @@ -59,17 +82,9 @@ abstract public class BaseLocalStoreBackedTableProvider implements LocalStoreBac if (kvStore == null) { throw new SamzaException("Store not initialized for table " + tableSpec.getId()); } - return new LocalStoreBackedReadWriteTable(kvStore); - } - - @Override - public void start() { - logger.info("Starting table provider for table " + tableSpec.getId()); - } - - @Override - public void stop() { - logger.info("Stopping table provider for table " + tableSpec.getId()); + ReadableTable table = new LocalStoreBackedReadWriteTable(tableSpec.getId(), kvStore); + table.init(containerContext, taskContext); + return table; } protected Map<String, String> generateCommonStoreConfig(Map<String, String> config) { @@ -89,4 +104,9 @@ abstract public class BaseLocalStoreBackedTableProvider implements LocalStoreBac return storeConfig; } + + @Override + public void close() { + logger.info("Shutting down table provider for table " + tableSpec.getId()); + } } http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java ---------------------------------------------------------------------- diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java index 3149c86..4037f60 100644 --- a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java +++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java @@ -36,8 +36,8 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab * Constructs an instance of {@link LocalStoreBackedReadWriteTable} * @param kvStore the backing store */ - public LocalStoreBackedReadWriteTable(KeyValueStore kvStore) { - super(kvStore); + public LocalStoreBackedReadWriteTable(String tableId, KeyValueStore kvStore) { + super(tableId, kvStore); } @Override http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java ---------------------------------------------------------------------- diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java index fead086..5ff58ab 100644 --- a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java +++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java @@ -24,6 +24,8 @@ import java.util.stream.Collectors; import org.apache.samza.table.ReadableTable; +import com.google.common.base.Preconditions; + /** * A store backed readable table @@ -34,12 +36,16 @@ import org.apache.samza.table.ReadableTable; public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, V> { protected KeyValueStore<K, V> kvStore; + protected String tableId; /** * Constructs an instance of {@link LocalStoreBackedReadableTable} * @param kvStore the backing store */ - public LocalStoreBackedReadableTable(KeyValueStore<K, V> kvStore) { + public LocalStoreBackedReadableTable(String tableId, KeyValueStore<K, V> kvStore) { + Preconditions.checkArgument(tableId != null & !tableId.isEmpty() , "invalid tableId"); + Preconditions.checkNotNull(kvStore, "null KeyValueStore"); + this.tableId = tableId; this.kvStore = kvStore; } http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java ---------------------------------------------------------------------- diff --git a/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java b/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java index 9c95637..d30c18f 100644 --- a/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java +++ b/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java @@ -27,11 +27,13 @@ import org.apache.samza.config.JavaTableConfig; import org.apache.samza.config.StorageConfig; import org.apache.samza.storage.StorageEngine; import org.apache.samza.table.TableSpec; +import org.apache.samza.task.TaskContext; import org.junit.Before; import org.junit.Test; import junit.framework.Assert; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -55,7 +57,9 @@ public class TestLocalBaseStoreBackedTableProvider { @Test public void testInit() { StorageEngine store = mock(KeyValueStorageEngine.class); - tableProvider.init(store); + TaskContext taskContext = mock(TaskContext.class); + when(taskContext.getStore(any())).thenReturn(store); + tableProvider.init(null, taskContext); Assert.assertNotNull(tableProvider.getTable()); } http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java ---------------------------------------------------------------------- diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java b/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java index 8f7eb5d..23fa9e6 100644 --- a/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java +++ b/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java @@ -74,7 +74,7 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { Profile[] profiles = TestTableData.generateProfiles(count); int partitionCount = 4; - Map<String, String> configs = getBaseJobConfig(); + Map<String, String> configs = getBaseJobConfig(bootstrapUrl(), zkConnect()); configs.put("streams.Profile.samza.system", "test"); configs.put("streams.Profile.source", Base64Serializer.serialize(profiles)); @@ -112,7 +112,7 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { Profile[] profiles = TestTableData.generateProfiles(count); int partitionCount = 4; - Map<String, String> configs = getBaseJobConfig(); + Map<String, String> configs = getBaseJobConfig(bootstrapUrl(), zkConnect()); configs.put("streams.PageView.samza.system", "test"); configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews)); @@ -170,7 +170,7 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { Profile[] profiles = TestTableData.generateProfiles(count); int partitionCount = 4; - Map<String, String> configs = getBaseJobConfig(); + Map<String, String> configs = getBaseJobConfig(bootstrapUrl(), zkConnect()); configs.put("streams.Profile1.samza.system", "test"); configs.put("streams.Profile1.source", Base64Serializer.serialize(profiles)); @@ -239,7 +239,7 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { assertTrue(joinedPageViews2.get(0) instanceof EnrichedPageView); } - private Map<String, String> getBaseJobConfig() { + static Map<String, String> getBaseJobConfig(String bootstrapUrl, String zkConnect) { Map<String, String> configs = new HashMap<>(); configs.put("systems.test.samza.factory", ArraySystemFactory.class.getName()); @@ -251,8 +251,8 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { // For intermediate streams configs.put("systems.kafka.samza.factory", "org.apache.samza.system.kafka.KafkaSystemFactory"); - configs.put("systems.kafka.producer.bootstrap.servers", bootstrapUrl()); - configs.put("systems.kafka.consumer.zookeeper.connect", zkConnect()); + configs.put("systems.kafka.producer.bootstrap.servers", bootstrapUrl); + configs.put("systems.kafka.consumer.zookeeper.connect", zkConnect); configs.put("systems.kafka.samza.key.serde", "int"); configs.put("systems.kafka.samza.msg.serde", "json"); configs.put("systems.kafka.default.stream.replication.factor", "1"); @@ -281,7 +281,7 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { } } - private class PageViewToProfileJoinFunction implements StreamTableJoinFunction + static class PageViewToProfileJoinFunction implements StreamTableJoinFunction <Integer, KV<Integer, PageView>, KV<Integer, Profile>, EnrichedPageView> { private int count; @Override http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java ---------------------------------------------------------------------- diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java new file mode 100644 index 0000000..a260c3f --- /dev/null +++ b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.test.table; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.apache.samza.SamzaException; +import org.apache.samza.application.StreamApplication; +import org.apache.samza.config.MapConfig; +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.metrics.Counter; +import org.apache.samza.metrics.MetricsRegistry; +import org.apache.samza.metrics.Timer; +import org.apache.samza.operators.KV; +import org.apache.samza.runtime.LocalApplicationRunner; +import org.apache.samza.serializers.NoOpSerde; +import org.apache.samza.table.Table; +import org.apache.samza.table.remote.TableReadFunction; +import org.apache.samza.table.remote.TableWriteFunction; +import org.apache.samza.table.remote.RemoteReadableTable; +import org.apache.samza.table.remote.RemoteTableDescriptor; +import org.apache.samza.table.remote.RemoteReadWriteTable; +import org.apache.samza.task.TaskContext; +import org.apache.samza.test.harness.AbstractIntegrationTestHarness; +import org.apache.samza.test.util.Base64Serializer; +import org.apache.samza.util.RateLimiter; +import org.junit.Assert; +import org.junit.Test; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + + +public class TestRemoteTable extends AbstractIntegrationTestHarness { + private TableReadFunction<Integer, TestTableData.Profile> getInMemoryReader(TestTableData.Profile[] profiles) { + final Map<Integer, TestTableData.Profile> profileMap = Arrays.stream(profiles) + .collect(Collectors.toMap(p -> p.getMemberId(), Function.identity())); + TableReadFunction<Integer, TestTableData.Profile> reader = + (TableReadFunction<Integer, TestTableData.Profile>) key -> profileMap.getOrDefault(key, null); + return reader; + } + + static List<TestTableData.EnrichedPageView> writtenRecords = new LinkedList<>(); + + static class InMemoryWriteFunction implements TableWriteFunction<Integer, TestTableData.EnrichedPageView> { + private transient List<TestTableData.EnrichedPageView> records; + + // Verify serializable functionality + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + + // Write to the global list for verification + records = writtenRecords; + } + + @Override + public void put(Integer key, TestTableData.EnrichedPageView record) { + records.add(record); + } + + @Override + public void delete(Integer key) { + records.remove(key); + } + + @Override + public void deleteAll(Collection<Integer> keys) { + records.removeAll(keys); + } + } + + @Test + public void testStreamTableJoinRemoteTable() throws Exception { + List<TestTableData.PageView> received = new LinkedList<>(); + final InMemoryWriteFunction writer = new InMemoryWriteFunction(); + + int count = 10; + TestTableData.PageView[] pageViews = TestTableData.generatePageViews(count); + TestTableData.Profile[] profiles = TestTableData.generateProfiles(count); + + int partitionCount = 4; + Map<String, String> configs = TestLocalTable.getBaseJobConfig(bootstrapUrl(), zkConnect()); + + configs.put("streams.PageView.samza.system", "test"); + configs.put("streams.PageView.source", Base64Serializer.serialize(pageViews)); + configs.put("streams.PageView.partitionCount", String.valueOf(partitionCount)); + + final RateLimiter readRateLimiter = mock(RateLimiter.class); + final RateLimiter writeRateLimiter = mock(RateLimiter.class); + final LocalApplicationRunner runner = new LocalApplicationRunner(new MapConfig(configs)); + final StreamApplication app = (streamGraph, cfg) -> { + RemoteTableDescriptor<Integer, TestTableData.Profile> inputTableDesc = new RemoteTableDescriptor<>("profile-table-1"); + inputTableDesc + .withReadFunction(getInMemoryReader(profiles)) + .withRateLimiter(readRateLimiter, null, null); + + RemoteTableDescriptor<Integer, TestTableData.EnrichedPageView> outputTableDesc = new RemoteTableDescriptor<>("enriched-page-view-table-1"); + outputTableDesc + .withReadFunction(key -> null) // dummy reader + .withWriteFunction(writer) + .withRateLimiter(writeRateLimiter, null, null); + + Table<KV<Integer, TestTableData.Profile>> inputTable = streamGraph.getTable(inputTableDesc); + Table<KV<Integer, TestTableData.EnrichedPageView>> outputTable = streamGraph.getTable(outputTableDesc); + + streamGraph.getInputStream("PageView", new NoOpSerde<TestTableData.PageView>()) + .map(pv -> { + received.add(pv); + return new KV<Integer, TestTableData.PageView>(pv.getMemberId(), pv); + }) + .join(inputTable, new TestLocalTable.PageViewToProfileJoinFunction()) + .map(m -> new KV(m.getMemberId(), m)) + .sendTo(outputTable); + }; + + runner.run(app); + runner.waitForFinish(); + + int numExpected = count * partitionCount; + Assert.assertEquals(numExpected, received.size()); + Assert.assertEquals(numExpected, writtenRecords.size()); + Assert.assertTrue(writtenRecords.get(0) instanceof TestTableData.EnrichedPageView); + } + + private TaskContext createMockTaskContext() { + MetricsRegistry metricsRegistry = mock(MetricsRegistry.class); + doReturn(new Counter("")).when(metricsRegistry).newCounter(anyString(), anyString()); + doReturn(new Timer("")).when(metricsRegistry).newTimer(anyString(), anyString()); + TaskContext context = mock(TaskContext.class); + doReturn(metricsRegistry).when(context).getMetricsRegistry(); + return context; + } + + @Test(expected = SamzaException.class) + public void testCatchReaderException() { + TableReadFunction<String, ?> reader = mock(TableReadFunction.class); + doThrow(new RuntimeException("Expected test exception")).when(reader).get(anyString()); + RemoteReadableTable<String, ?> table = new RemoteReadableTable<>("table1", reader, null, null); + table.init(mock(SamzaContainerContext.class), createMockTaskContext()); + table.get("abc"); + } + + @Test(expected = SamzaException.class) + public void testCatchWriterException() { + TableReadFunction<String, String> reader = mock(TableReadFunction.class); + TableWriteFunction<String, String> writer = mock(TableWriteFunction.class); + doThrow(new RuntimeException("Expected test exception")).when(writer).put(anyString(), any()); + RemoteReadWriteTable<String, String> table = new RemoteReadWriteTable<>("table1", reader, writer, null, null, null); + table.init(mock(SamzaContainerContext.class), createMockTaskContext()); + table.put("abc", "efg"); + } +}