Repository: samza
Updated Branches:
  refs/heads/master 1971d596c -> 2be7061d4


http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
----------------------------------------------------------------------
diff --git 
a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
 
b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
index 9c79766..05fa19a 100644
--- 
a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
+++ 
b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
@@ -18,14 +18,22 @@
  */
 package org.apache.samza.util;
 
+import java.lang.reflect.Field;
+import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
 
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.task.TaskContext;
 import org.junit.Assert;
 import org.junit.Ignore;
 import org.junit.Test;
 
 import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 
 public class TestEmbeddedTaggedRateLimiter {
@@ -35,42 +43,71 @@ public class TestEmbeddedTaggedRateLimiter {
   final static private int TARGET_RATE_RED = 1000;
   final static private int TARGET_RATE_PER_TASK_RED = TARGET_RATE_RED / 
NUMBER_OF_TASKS;
   final static private int TARGET_RATE_GREEN = 2000;
-  final static private int TARGET_RATE_PER_TASK_GREEN = TARGET_RATE_GREEN / 
NUMBER_OF_TASKS;
   final static private int INCREMENT = 2;
 
+  final static private int TARGET_RATE = 4000;
+  final static private int TARGET_RATE_PER_TASK = TARGET_RATE / 
NUMBER_OF_TASKS;
+
   @Test
   @Ignore("Flaky Test: Test fails in travis.")
   public void testAcquire() {
-    RateLimiter rateLimiter = createRateLimiter();
+    RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(TARGET_RATE);
+    initRateLimiter(rateLimiter);
 
-    Map<String, Integer> tagToCount = new HashMap<>();
-    tagToCount.put("red", 0);
-    tagToCount.put("green", 0);
+    int count = 0;
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      rateLimiter.acquire(INCREMENT);
+      count += INCREMENT;
+    }
 
-    Map<String, Integer> tagToCredits = new HashMap<>();
-    tagToCredits.put("red", INCREMENT);
-    tagToCredits.put("green", INCREMENT);
+    long rate = count * 1000 / TEST_INTERVAL;
+    verifyRate(rate);
+  }
+
+  @Test
+  public void testAcquireWithTimeout() {
+    RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(TARGET_RATE);
+    initRateLimiter(rateLimiter);
+
+    boolean hasSeenZeros = false;
 
+    int count = 0;
+    int callCount = 0;
     long start = System.currentTimeMillis();
     while (System.currentTimeMillis() - start < TEST_INTERVAL) {
-      rateLimiter.acquire(tagToCredits);
-      tagToCount.put("red", tagToCount.get("red") + INCREMENT);
-      tagToCount.put("green", tagToCount.get("green") + INCREMENT);
+      ++callCount;
+      int availableCredits = rateLimiter.acquire(INCREMENT, 20, MILLISECONDS);
+      if (availableCredits <= 0) {
+        hasSeenZeros = true;
+      } else {
+        count += INCREMENT;
+      }
     }
 
-    {
-      long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL;
-      verifyRate(rate, TARGET_RATE_PER_TASK_RED);
-    } {
-      // Note: due to blocking, green is capped at red's QPS
-      long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL;
-      verifyRate(rate, TARGET_RATE_PER_TASK_RED);
-    }
+    long rate = count * 1000 / TEST_INTERVAL;
+    verifyRate(rate);
+    junit.framework.Assert.assertTrue(Math.abs(callCount - 
TARGET_RATE_PER_TASK * TEST_INTERVAL / 1000 / INCREMENT) <= 2);
+    junit.framework.Assert.assertFalse(hasSeenZeros);
   }
 
-  @Test
-  public void testTryAcquire() {
+  /** Expects IllegalStateException when acquire is called on a limiter never passed to initRateLimiter. */
+  @Test(expected = IllegalStateException.class)
+  public void testFailsWhenUninitialized() {
+    // No initRateLimiter() call here on purpose: the limiter is used straight after construction.
+    RateLimiter uninitialized = new EmbeddedTaggedRateLimiter(100);
+    uninitialized.acquire(1);
+  }
+
+  /** Expects IllegalArgumentException when a limiter built from a single rate is asked for tagged credits. */
+  @Test(expected = IllegalArgumentException.class)
+  public void testFailsWhenUsingTags() {
+    Map<String, Integer> creditsByTag = new HashMap<>();
+    creditsByTag.put("red", 1);
+    creditsByTag.put("green", 1);
+
+    RateLimiter untagged = new EmbeddedTaggedRateLimiter(10);
+    initRateLimiter(untagged);
+    untagged.acquire(creditsByTag);
+  }
 
+  @Test
+  public void testAcquireTagged() {
     RateLimiter rateLimiter = createRateLimiter();
 
     Map<String, Integer> tagToCount = new HashMap<>();
@@ -83,22 +120,23 @@ public class TestEmbeddedTaggedRateLimiter {
 
     long start = System.currentTimeMillis();
     while (System.currentTimeMillis() - start < TEST_INTERVAL) {
-      Map<String, Integer> resultMap = rateLimiter.tryAcquire(tagToCredits);
-      tagToCount.put("red", tagToCount.get("red") + resultMap.get("red"));
-      tagToCount.put("green", tagToCount.get("green") + 
resultMap.get("green"));
+      rateLimiter.acquire(tagToCredits);
+      tagToCount.put("red", tagToCount.get("red") + INCREMENT);
+      tagToCount.put("green", tagToCount.get("green") + INCREMENT);
     }
 
     {
       long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL;
       verifyRate(rate, TARGET_RATE_PER_TASK_RED);
     } {
+      // Note: due to blocking, green is capped at red's QPS
       long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL;
-      verifyRate(rate, TARGET_RATE_PER_TASK_GREEN);
+      verifyRate(rate, TARGET_RATE_PER_TASK_RED);
     }
   }
 
   @Test
-  public void testAcquireWithTimeout() {
+  public void testAcquireWithTimeoutTagged() {
 
     RateLimiter rateLimiter = createRateLimiter();
 
@@ -128,7 +166,7 @@ public class TestEmbeddedTaggedRateLimiter {
   }
 
   @Test(expected = IllegalStateException.class)
-  public void testFailsWhenUninitialized() {
+  public void testFailsWhenUninitializedTagged() {
     Map<String, Integer> tagToTargetRateMap = new HashMap<>();
     tagToTargetRateMap.put("red", 1000);
     tagToTargetRateMap.put("green", 2000);
@@ -141,14 +179,14 @@ public class TestEmbeddedTaggedRateLimiter {
     tagToCredits.put("red", 1);
     tagToCredits.put("green", 1);
     RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToCredits);
-    TestEmbeddedRateLimiter.initRateLimiter(rateLimiter);
+    initRateLimiter(rateLimiter);
     rateLimiter.acquire(1);
   }
 
   private void verifyRate(long rate, long targetRate) {
     // As the actual rate would likely not be exactly the same as target rate, 
the calculation below
-    // verifies the actual rate is within 5% of the target rate per task
-    Assert.assertTrue(Math.abs(rate - targetRate) <= targetRate * 5 / 100);
+    // verifies the actual rate is within 10% of the target rate per task
+    long allowedDeviation = targetRate * 10 / 100;
+    long actualDeviation = Math.abs(rate - targetRate);
+    Assert.assertTrue(actualDeviation <= allowedDeviation);
   }
 
   private RateLimiter createRateLimiter() {
@@ -156,8 +194,36 @@ public class TestEmbeddedTaggedRateLimiter {
     tagToTargetRateMap.put("red", TARGET_RATE_RED);
     tagToTargetRateMap.put("green", TARGET_RATE_GREEN);
     RateLimiter rateLimiter = new 
EmbeddedTaggedRateLimiter(tagToTargetRateMap);
-    TestEmbeddedRateLimiter.initRateLimiter(rateLimiter);
+    initRateLimiter(rateLimiter);
     return rateLimiter;
   }
 
+  /**
+   * Asserts that the given rate is within 5% of {@code TARGET_RATE_PER_TASK}.
+   *
+   * @param rate the measured rate computed by the calling test
+   */
+  private void verifyRate(long rate) {
+    // The measured rate will rarely equal the target exactly, so a 5% tolerance is allowed.
+    // org.junit.Assert (already imported by this file) is used instead of the deprecated
+    // junit.framework.Assert, matching the two-argument verifyRate overload.
+    long tolerance = TARGET_RATE_PER_TASK * 5 / 100;
+    Assert.assertTrue(Math.abs(rate - TARGET_RATE_PER_TASK) <= tolerance);
+  }
+
+  /**
+   * Initializes the given rate limiter with a mocked {@link Config} and a mocked
+   * {@link TaskContext} whose container context comes from {@link #mockSamzaContainerContext()}.
+   *
+   * @param rateLimiter the limiter to initialize
+   */
+  static void initRateLimiter(RateLimiter rateLimiter) {
+    SamzaContainerContext mockedContainer = mockSamzaContainerContext();
+    TaskContext mockedTask = mock(TaskContext.class);
+    when(mockedTask.getSamzaContainerContext()).thenReturn(mockedContainer);
+
+    Config mockedConfig = mock(Config.class);
+    rateLimiter.init(mockedConfig, mockedTask);
+  }
+
+  /**
+   * Builds a mocked {@link SamzaContainerContext} whose {@code taskNames} field is set, via
+   * reflection, to a mocked collection that reports {@code NUMBER_OF_TASKS} as its size.
+   *
+   * @return the mocked container context
+   * @throws SamzaException wrapping any exception raised while mocking or setting the field
+   */
+  static SamzaContainerContext mockSamzaContainerContext() {
+    try {
+      SamzaContainerContext mockedContext = mock(SamzaContainerContext.class);
+
+      Collection<String> mockedTaskNames = mock(Collection.class);
+      when(mockedTaskNames.size()).thenReturn(NUMBER_OF_TASKS);
+
+      // The field is written reflectively; accessibility is switched back off afterwards.
+      Field field = SamzaContainerContext.class.getDeclaredField("taskNames");
+      field.setAccessible(true);
+      field.set(mockedContext, mockedTaskNames);
+      field.setAccessible(false);
+
+      return mockedContext;
+    } catch (Exception cause) {
+      throw new SamzaException(cause);
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
index 4af0f1d..b494eba 100644
--- 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
+++ 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/BaseLocalStoreBackedTableProvider.java
@@ -25,18 +25,26 @@ import org.apache.samza.SamzaException;
 import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.config.MapConfig;
 import org.apache.samza.config.StorageConfig;
-import org.apache.samza.storage.StorageEngine;
-import org.apache.samza.table.LocalStoreBackedTableProvider;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.table.ReadableTable;
 import org.apache.samza.table.Table;
+import org.apache.samza.table.TableProvider;
 import org.apache.samza.table.TableSpec;
+import org.apache.samza.task.TaskContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.common.base.Preconditions;
+
 
 /**
- * Base class for tables backed by Samza stores, see {@link 
LocalStoreBackedTableProvider}.
+ * Base class for tables backed by Samza local stores. The backing stores are
+ * injected during initialization of the table. Since the lifecycle
+ * of the underlying stores are already managed by Samza container,
+ * the table provider will not manage the lifecycle of the backing
+ * stores.
  */
-abstract public class BaseLocalStoreBackedTableProvider implements 
LocalStoreBackedTableProvider {
+abstract public class BaseLocalStoreBackedTableProvider implements 
TableProvider {
 
   protected final Logger logger = LoggerFactory.getLogger(getClass());
 
@@ -44,13 +52,28 @@ abstract public class BaseLocalStoreBackedTableProvider 
implements LocalStoreBac
 
   protected KeyValueStore kvStore;
 
+  protected SamzaContainerContext containerContext;
+
+  protected TaskContext taskContext;
+
   public BaseLocalStoreBackedTableProvider(TableSpec tableSpec) {
     this.tableSpec = tableSpec;
   }
 
   @Override
-  public void init(StorageEngine store) {
-    kvStore = (KeyValueStore) store;
+  public void init(SamzaContainerContext containerContext, TaskContext 
taskContext) {
+    this.containerContext = containerContext;
+    this.taskContext = taskContext;
+
+    Preconditions.checkNotNull(this.taskContext, "Must specify task context 
for local tables.");
+
+    kvStore = (KeyValueStore) taskContext.getStore(tableSpec.getId());
+
+    if (kvStore == null) {
+      throw new SamzaException(String.format(
+          "Backing store for table %s was not injected by SamzaContainer", 
tableSpec.getId()));
+    }
+
     logger.info("Initialized backing store for table " + tableSpec.getId());
   }
 
@@ -59,17 +82,9 @@ abstract public class BaseLocalStoreBackedTableProvider 
implements LocalStoreBac
     if (kvStore == null) {
       throw new SamzaException("Store not initialized for table " + 
tableSpec.getId());
     }
-    return new LocalStoreBackedReadWriteTable(kvStore);
-  }
-
-  @Override
-  public void start() {
-    logger.info("Starting table provider for table " + tableSpec.getId());
-  }
-
-  @Override
-  public void stop() {
-    logger.info("Stopping table provider for table " + tableSpec.getId());
+    ReadableTable table = new 
LocalStoreBackedReadWriteTable(tableSpec.getId(), kvStore);
+    table.init(containerContext, taskContext);
+    return table;
   }
 
   protected Map<String, String> generateCommonStoreConfig(Map<String, String> 
config) {
@@ -89,4 +104,9 @@ abstract public class BaseLocalStoreBackedTableProvider 
implements LocalStoreBac
 
     return storeConfig;
   }
+
+  /** Logs that the provider for this table is shutting down; no other action is taken here. */
+  @Override
+  public void close() {
+    final String tableId = tableSpec.getId();
+    logger.info("Shutting down table provider for table " + tableId);
+  }
 }

http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
index 3149c86..4037f60 100644
--- 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
+++ 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java
@@ -36,8 +36,8 @@ public class LocalStoreBackedReadWriteTable<K, V> extends 
LocalStoreBackedReadab
    * Constructs an instance of {@link LocalStoreBackedReadWriteTable}
+   * @param tableId the table id, forwarded to the superclass constructor
    * @param kvStore the backing store
    */
-  public LocalStoreBackedReadWriteTable(KeyValueStore kvStore) {
-    super(kvStore);
+  public LocalStoreBackedReadWriteTable(String tableId, KeyValueStore kvStore) 
{
+    super(tableId, kvStore);
   }
 
   @Override

http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
index fead086..5ff58ab 100644
--- 
a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
+++ 
b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java
@@ -24,6 +24,8 @@ import java.util.stream.Collectors;
 
 import org.apache.samza.table.ReadableTable;
 
+import com.google.common.base.Preconditions;
+
 
 /**
  * A store backed readable table
@@ -34,12 +36,16 @@ import org.apache.samza.table.ReadableTable;
 public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, 
V> {
 
   protected KeyValueStore<K, V> kvStore;
+  protected String tableId;
 
   /**
    * Constructs an instance of {@link LocalStoreBackedReadableTable}
+   * @param tableId the table id; must be neither null nor empty
    * @param kvStore the backing store
+   * @throws IllegalArgumentException if {@code tableId} is null or empty
+   * @throws NullPointerException if {@code kvStore} is null
    */
-  public LocalStoreBackedReadableTable(KeyValueStore<K, V> kvStore) {
+  public LocalStoreBackedReadableTable(String tableId, KeyValueStore<K, V> kvStore) {
+    // Short-circuit && is required here: with the non-short-circuit & operator a null
+    // tableId still evaluates tableId.isEmpty() and fails with NullPointerException
+    // instead of the intended IllegalArgumentException.
+    Preconditions.checkArgument(tableId != null && !tableId.isEmpty(), "invalid tableId");
+    Preconditions.checkNotNull(kvStore, "null KeyValueStore");
+    this.tableId = tableId;
     this.kvStore = kvStore;
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java
----------------------------------------------------------------------
diff --git 
a/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java
 
b/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java
index 9c95637..d30c18f 100644
--- 
a/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java
+++ 
b/samza-kv/src/test/java/org/apache/samza/storage/kv/TestLocalBaseStoreBackedTableProvider.java
@@ -27,11 +27,13 @@ import org.apache.samza.config.JavaTableConfig;
 import org.apache.samza.config.StorageConfig;
 import org.apache.samza.storage.StorageEngine;
 import org.apache.samza.table.TableSpec;
+import org.apache.samza.task.TaskContext;
 import org.junit.Before;
 import org.junit.Test;
 
 import junit.framework.Assert;
 
+import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -55,7 +57,9 @@ public class TestLocalBaseStoreBackedTableProvider {
   @Test
   public void testInit() {
     StorageEngine store = mock(KeyValueStorageEngine.class);
-    tableProvider.init(store);
+    // The store is handed to the provider through a stubbed task context;
+    // no container context is supplied (null).
+    TaskContext mockedTaskContext = mock(TaskContext.class);
+    when(mockedTaskContext.getStore(any())).thenReturn(store);
+    tableProvider.init(null, mockedTaskContext);
     Assert.assertNotNull(tableProvider.getTable());
   }
 

http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
----------------------------------------------------------------------
diff --git 
a/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java 
b/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
index 8f7eb5d..23fa9e6 100644
--- a/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
+++ b/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java
@@ -74,7 +74,7 @@ public class TestLocalTable extends 
AbstractIntegrationTestHarness {
     Profile[] profiles = TestTableData.generateProfiles(count);
 
     int partitionCount = 4;
-    Map<String, String> configs = getBaseJobConfig();
+    Map<String, String> configs = getBaseJobConfig(bootstrapUrl(), 
zkConnect());
 
     configs.put("streams.Profile.samza.system", "test");
     configs.put("streams.Profile.source", 
Base64Serializer.serialize(profiles));
@@ -112,7 +112,7 @@ public class TestLocalTable extends 
AbstractIntegrationTestHarness {
     Profile[] profiles = TestTableData.generateProfiles(count);
 
     int partitionCount = 4;
-    Map<String, String> configs = getBaseJobConfig();
+    Map<String, String> configs = getBaseJobConfig(bootstrapUrl(), 
zkConnect());
 
     configs.put("streams.PageView.samza.system", "test");
     configs.put("streams.PageView.source", 
Base64Serializer.serialize(pageViews));
@@ -170,7 +170,7 @@ public class TestLocalTable extends 
AbstractIntegrationTestHarness {
     Profile[] profiles = TestTableData.generateProfiles(count);
 
     int partitionCount = 4;
-    Map<String, String> configs = getBaseJobConfig();
+    Map<String, String> configs = getBaseJobConfig(bootstrapUrl(), 
zkConnect());
 
     configs.put("streams.Profile1.samza.system", "test");
     configs.put("streams.Profile1.source", 
Base64Serializer.serialize(profiles));
@@ -239,7 +239,7 @@ public class TestLocalTable extends 
AbstractIntegrationTestHarness {
     assertTrue(joinedPageViews2.get(0) instanceof EnrichedPageView);
   }
 
-  private Map<String, String> getBaseJobConfig() {
+  static Map<String, String> getBaseJobConfig(String bootstrapUrl, String 
zkConnect) {
     Map<String, String> configs = new HashMap<>();
     configs.put("systems.test.samza.factory", 
ArraySystemFactory.class.getName());
 
@@ -251,8 +251,8 @@ public class TestLocalTable extends 
AbstractIntegrationTestHarness {
 
     // For intermediate streams
     configs.put("systems.kafka.samza.factory", 
"org.apache.samza.system.kafka.KafkaSystemFactory");
-    configs.put("systems.kafka.producer.bootstrap.servers", bootstrapUrl());
-    configs.put("systems.kafka.consumer.zookeeper.connect", zkConnect());
+    configs.put("systems.kafka.producer.bootstrap.servers", bootstrapUrl);
+    configs.put("systems.kafka.consumer.zookeeper.connect", zkConnect);
     configs.put("systems.kafka.samza.key.serde", "int");
     configs.put("systems.kafka.samza.msg.serde", "json");
     configs.put("systems.kafka.default.stream.replication.factor", "1");
@@ -281,7 +281,7 @@ public class TestLocalTable extends 
AbstractIntegrationTestHarness {
     }
   }
 
-  private class PageViewToProfileJoinFunction implements 
StreamTableJoinFunction
+  static class PageViewToProfileJoinFunction implements StreamTableJoinFunction
       <Integer, KV<Integer, PageView>, KV<Integer, Profile>, EnrichedPageView> 
{
     private int count;
     @Override

http://git-wip-us.apache.org/repos/asf/samza/blob/2be7061d/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java
----------------------------------------------------------------------
diff --git 
a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java 
b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java
new file mode 100644
index 0000000..a260c3f
--- /dev/null
+++ b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.samza.test.table;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+import org.apache.samza.SamzaException;
+import org.apache.samza.application.StreamApplication;
+import org.apache.samza.config.MapConfig;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.metrics.Counter;
+import org.apache.samza.metrics.MetricsRegistry;
+import org.apache.samza.metrics.Timer;
+import org.apache.samza.operators.KV;
+import org.apache.samza.runtime.LocalApplicationRunner;
+import org.apache.samza.serializers.NoOpSerde;
+import org.apache.samza.table.Table;
+import org.apache.samza.table.remote.TableReadFunction;
+import org.apache.samza.table.remote.TableWriteFunction;
+import org.apache.samza.table.remote.RemoteReadableTable;
+import org.apache.samza.table.remote.RemoteTableDescriptor;
+import org.apache.samza.table.remote.RemoteReadWriteTable;
+import org.apache.samza.task.TaskContext;
+import org.apache.samza.test.harness.AbstractIntegrationTestHarness;
+import org.apache.samza.test.util.Base64Serializer;
+import org.apache.samza.util.RateLimiter;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.mock;
+
+
+public class TestRemoteTable extends AbstractIntegrationTestHarness {
+  /**
+   * Creates a read function that answers lookups from an in-memory map of the given
+   * profiles keyed by member id; a key with no entry yields {@code null}.
+   *
+   * @param profiles the profiles to index
+   * @return a reader backed by the in-memory map
+   */
+  private TableReadFunction<Integer, TestTableData.Profile> getInMemoryReader(TestTableData.Profile[] profiles) {
+    final Map<Integer, TestTableData.Profile> profilesByMemberId = Arrays.stream(profiles)
+        .collect(Collectors.toMap(profile -> profile.getMemberId(), Function.identity()));
+    return (TableReadFunction<Integer, TestTableData.Profile>)
+        memberId -> profilesByMemberId.getOrDefault(memberId, null);
+  }
+
+  static List<TestTableData.EnrichedPageView> writtenRecords = new 
LinkedList<>();
+
+  /**
+   * Write function that records written values in the shared static {@code writtenRecords}
+   * list so the test can verify them. The target list is only bound in {@code readObject},
+   * i.e. after an instance has been deserialized.
+   */
+  static class InMemoryWriteFunction implements TableWriteFunction<Integer, TestTableData.EnrichedPageView> {
+    private transient List<TestTableData.EnrichedPageView> records;
+
+    // Verify serializable functionality
+    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+      in.defaultReadObject();
+
+      // Write to the global list for verification
+      records = writtenRecords;
+    }
+
+    @Override
+    public void put(Integer key, TestTableData.EnrichedPageView record) {
+      records.add(record);
+    }
+
+    @Override
+    public void delete(Integer key) {
+      // List.remove(Object) with an Integer key compares the key against EnrichedPageView
+      // elements and never matches, so remove by the record's member id (the table key) instead.
+      records.removeIf(record -> key.equals(record.getMemberId()));
+    }
+
+    @Override
+    public void deleteAll(Collection<Integer> keys) {
+      // removeAll(keys) has the same element-type mismatch as delete(); match on member id.
+      records.removeIf(record -> keys.contains(record.getMemberId()));
+    }
+  }
+
+  /**
+   * Runs a stream application that joins the PageView input stream against a remote
+   * profile table and sends the joined records to a second remote table, then checks
+   * how many records were received and written.
+   */
+  @Test
+  public void testStreamTableJoinRemoteTable() throws Exception {
+    List<TestTableData.PageView> received = new LinkedList<>();
+    // NOTE(review): the writer binds its record list only in readObject(), so put() presumably
+    // runs on a deserialized copy created by the framework - confirm.
+    final InMemoryWriteFunction writer = new InMemoryWriteFunction();
+
+    int count = 10;
+    TestTableData.PageView[] pageViews = 
TestTableData.generatePageViews(count);
+    TestTableData.Profile[] profiles = TestTableData.generateProfiles(count);
+
+    int partitionCount = 4;
+    Map<String, String> configs = 
TestLocalTable.getBaseJobConfig(bootstrapUrl(), zkConnect());
+
+    configs.put("streams.PageView.samza.system", "test");
+    configs.put("streams.PageView.source", 
Base64Serializer.serialize(pageViews));
+    configs.put("streams.PageView.partitionCount", 
String.valueOf(partitionCount));
+
+    // Both rate limiters are plain mocks; each is attached to one table descriptor below.
+    final RateLimiter readRateLimiter = mock(RateLimiter.class);
+    final RateLimiter writeRateLimiter = mock(RateLimiter.class);
+    final LocalApplicationRunner runner = new LocalApplicationRunner(new 
MapConfig(configs));
+    final StreamApplication app = (streamGraph, cfg) -> {
+      RemoteTableDescriptor<Integer, TestTableData.Profile> inputTableDesc = 
new RemoteTableDescriptor<>("profile-table-1");
+      inputTableDesc
+          .withReadFunction(getInMemoryReader(profiles))
+          .withRateLimiter(readRateLimiter, null, null);
+
+      RemoteTableDescriptor<Integer, TestTableData.EnrichedPageView> 
outputTableDesc = new RemoteTableDescriptor<>("enriched-page-view-table-1");
+      outputTableDesc
+          .withReadFunction(key -> null) // dummy reader
+          .withWriteFunction(writer)
+          .withRateLimiter(writeRateLimiter, null, null);
+
+      Table<KV<Integer, TestTableData.Profile>> inputTable = 
streamGraph.getTable(inputTableDesc);
+      Table<KV<Integer, TestTableData.EnrichedPageView>> outputTable = 
streamGraph.getTable(outputTableDesc);
+
+      // Record every incoming page view, key it by member id, join it with the profile
+      // table, re-key the joined record by member id and write it to the output table.
+      streamGraph.getInputStream("PageView", new 
NoOpSerde<TestTableData.PageView>())
+          .map(pv -> {
+              received.add(pv);
+              return new KV<Integer, TestTableData.PageView>(pv.getMemberId(), 
pv);
+            })
+          .join(inputTable, new TestLocalTable.PageViewToProfileJoinFunction())
+          .map(m -> new KV(m.getMemberId(), m))
+          .sendTo(outputTable);
+    };
+
+    runner.run(app);
+    runner.waitForFinish();
+
+    // Presumably every partition replays the full pageViews array, hence count * partitionCount - confirm.
+    int numExpected = count * partitionCount;
+    Assert.assertEquals(numExpected, received.size());
+    Assert.assertEquals(numExpected, writtenRecords.size());
+    Assert.assertTrue(writtenRecords.get(0) instanceof 
TestTableData.EnrichedPageView);
+  }
+
+  /**
+   * Creates a mocked {@link TaskContext} whose metrics registry is a mock that returns one
+   * pre-built {@link Counter} and one pre-built {@link Timer} for any pair of string arguments.
+   *
+   * @return the mocked task context
+   */
+  private TaskContext createMockTaskContext() {
+    Counter stubCounter = new Counter("");
+    Timer stubTimer = new Timer("");
+
+    MetricsRegistry registry = mock(MetricsRegistry.class);
+    doReturn(stubCounter).when(registry).newCounter(anyString(), anyString());
+    doReturn(stubTimer).when(registry).newTimer(anyString(), anyString());
+
+    TaskContext mockedContext = mock(TaskContext.class);
+    doReturn(registry).when(mockedContext).getMetricsRegistry();
+    return mockedContext;
+  }
+
+  /** Expects a SamzaException when the table's read function throws a RuntimeException. */
+  @Test(expected = SamzaException.class)
+  public void testCatchReaderException() {
+    TableReadFunction<String, ?> failingReader = mock(TableReadFunction.class);
+    RuntimeException readFailure = new RuntimeException("Expected test exception");
+    doThrow(readFailure).when(failingReader).get(anyString());
+
+    RemoteReadableTable<String, ?> remoteTable = new RemoteReadableTable<>("table1", failingReader, null, null);
+    remoteTable.init(mock(SamzaContainerContext.class), createMockTaskContext());
+    remoteTable.get("abc");
+  }
+
+  /** Expects a SamzaException when the table's write function throws a RuntimeException. */
+  @Test(expected = SamzaException.class)
+  public void testCatchWriterException() {
+    TableReadFunction<String, String> stubReader = mock(TableReadFunction.class);
+    TableWriteFunction<String, String> failingWriter = mock(TableWriteFunction.class);
+    RuntimeException writeFailure = new RuntimeException("Expected test exception");
+    doThrow(writeFailure).when(failingWriter).put(anyString(), any());
+
+    RemoteReadWriteTable<String, String> remoteTable =
+        new RemoteReadWriteTable<>("table1", stubReader, failingWriter, null, null, null);
+    remoteTable.init(mock(SamzaContainerContext.class), createMockTaskContext());
+    remoteTable.put("abc", "efg");
+  }
+}

Reply via email to