Repository: samza Updated Branches: refs/heads/master 8aa879aa6 -> a08040dcb
http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java index 2e40358..49c72dc 100644 --- a/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java +++ b/samza-core/src/test/java/org/apache/samza/table/caching/TestCachingTable.java @@ -20,29 +20,41 @@ package org.apache.samza.table.caching; import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; -import java.util.Random; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import org.apache.commons.lang3.tuple.Pair; import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.metrics.Counter; +import org.apache.samza.metrics.Gauge; +import org.apache.samza.metrics.MetricsRegistry; +import org.apache.samza.metrics.Timer; import org.apache.samza.operators.TableImpl; +import org.apache.samza.storage.kv.Entry; import org.apache.samza.table.ReadWriteTable; import org.apache.samza.table.ReadableTable; import org.apache.samza.table.Table; import org.apache.samza.table.TableSpec; +import org.apache.samza.table.caching.guava.GuavaCacheTable; import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor; import org.apache.samza.table.caching.guava.GuavaCacheTableProvider; +import org.apache.samza.table.remote.TableRateLimiter; +import org.apache.samza.table.remote.RemoteReadWriteTable; +import org.apache.samza.table.remote.TableReadFunction; +import org.apache.samza.table.remote.TableWriteFunction; import org.apache.samza.task.TaskContext; import org.apache.samza.util.NoOpMetricsRegistry; import org.junit.Assert; import org.junit.Test; +import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; import static org.mockito.Matchers.any; @@ -81,7 +93,6 @@ public class TestCachingTable { desc.withCache(cache); } - desc.withStripes(32); desc.withWriteAround(); TableSpec spec = desc.getTableSpec(); @@ -94,7 +105,6 @@ public class TestCachingTable { Assert.assertTrue(spec.getConfig().containsKey(CachingTableProvider.CACHE_TABLE_ID)); } - Assert.assertEquals("32", spec.getConfig().get(CachingTableProvider.LOCK_STRIPES)); Assert.assertEquals("true", spec.getConfig().get(CachingTableProvider.WRITE_AROUND)); desc.validate(); @@ -106,33 +116,39 @@ public class TestCachingTable { // CHM for each does not serialize such two-step operation so the atomicity is still tested. // Regular HashMap is not thread-safe even for disjoint keys. final Map<String, String> cacheStore = new ConcurrentHashMap<>(); - final ReadWriteTable tableCache = mock(ReadWriteTable.class); + final ReadWriteTable cacheTable = mock(ReadWriteTable.class); doAnswer(invocation -> { String key = invocation.getArgumentAt(0, String.class); String value = invocation.getArgumentAt(1, String.class); cacheStore.put(key, value); return null; - }).when(tableCache).put(any(), any()); + }).when(cacheTable).put(any(), any()); doAnswer(invocation -> { String key = invocation.getArgumentAt(0, String.class); return cacheStore.get(key); - }).when(tableCache).get(any()); + }).when(cacheTable).get(any()); doAnswer(invocation -> { String key = invocation.getArgumentAt(0, String.class); return cacheStore.remove(key); - }).when(tableCache).delete(any()); + }).when(cacheTable).delete(any()); - return Pair.of(tableCache, cacheStore); + return Pair.of(cacheTable, cacheStore); } - private void initTable(CachingTable cachingTable) { + private void initTables(ReadableTable ... tables) { SamzaContainerContext containerContext = mock(SamzaContainerContext.class); TaskContext taskContext = mock(TaskContext.class); - when(taskContext.getMetricsRegistry()).thenReturn(new NoOpMetricsRegistry()); - cachingTable.init(containerContext, taskContext); + MetricsRegistry metricsRegistry = mock(MetricsRegistry.class); + doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString()); + doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString()); + doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any()); + when(taskContext.getMetricsRegistry()).thenReturn(metricsRegistry); + for (ReadableTable table : tables) { + table.init(containerContext, taskContext); + } } private void doTestCacheOps(boolean isWriteAround) { @@ -147,14 +163,16 @@ public class TestCachingTable { SamzaContainerContext containerContext = mock(SamzaContainerContext.class); TaskContext taskContext = mock(TaskContext.class); - final ReadWriteTable tableCache = getMockCache().getLeft(); + final ReadWriteTable cacheTable = getMockCache().getLeft(); final ReadWriteTable realTable = mock(ReadWriteTable.class); doAnswer(invocation -> { String key = invocation.getArgumentAt(0, String.class); - return "test-data-" + key; - }).when(realTable).get(any()); + return CompletableFuture.completedFuture("test-data-" + key); + }).when(realTable).getAsync(any()); + + doReturn(CompletableFuture.completedFuture(null)).when(realTable).putAsync(any(), any()); doAnswer(invocation -> { String tableId = invocation.getArgumentAt(0, String.class); @@ -162,7 +180,7 @@ public class TestCachingTable { // cache return realTable; } else if (tableId.equals("cacheTable")) { - return tableCache; + return cacheTable; } Assert.fail(); @@ -173,39 +191,39 @@ public class TestCachingTable { tableProvider.init(containerContext, taskContext); - CachingTable cacheTable = (CachingTable) tableProvider.getTable(); + CachingTable cachingTable = (CachingTable) tableProvider.getTable(); - Assert.assertEquals("test-data-1", cacheTable.get("1")); - verify(realTable, times(1)).get(any()); - verify(tableCache, times(2)).get(any()); // cache miss leads to 2 more get() calls - verify(tableCache, times(1)).put(any(), any()); - Assert.assertEquals(cacheTable.hitRate(), 0.0, 0.0); // 0 hit, 1 request - Assert.assertEquals(cacheTable.missRate(), 1.0, 0.0); + Assert.assertEquals("test-data-1", cachingTable.get("1")); + verify(realTable, times(1)).getAsync(any()); + verify(cacheTable, times(1)).get(any()); // cache miss + verify(cacheTable, times(1)).put(any(), any()); + Assert.assertEquals(cachingTable.hitRate(), 0.0, 0.0); // 0 hit, 1 request + Assert.assertEquals(cachingTable.missRate(), 1.0, 0.0); - Assert.assertEquals("test-data-1", cacheTable.get("1")); - verify(realTable, times(1)).get(any()); // no change - verify(tableCache, times(3)).get(any()); - verify(tableCache, times(1)).put(any(), any()); // no change - Assert.assertEquals(cacheTable.hitRate(), 0.5, 0.0); // 1 hit, 2 requests - Assert.assertEquals(cacheTable.missRate(), 0.5, 0.0); + Assert.assertEquals("test-data-1", cachingTable.get("1")); + verify(realTable, times(1)).getAsync(any()); // no change + verify(cacheTable, times(2)).get(any()); + verify(cacheTable, times(1)).put(any(), any()); // no change + Assert.assertEquals(0.5, cachingTable.hitRate(), 0.0); // 1 hit, 2 requests + Assert.assertEquals(0.5, cachingTable.missRate(), 0.0); - cacheTable.put("2", "test-data-XXXX"); - verify(tableCache, times(isWriteAround ? 1 : 2)).put(any(), any()); - verify(realTable, times(1)).put(any(), any()); + cachingTable.put("2", "test-data-XXXX"); + verify(cacheTable, times(isWriteAround ? 1 : 2)).put(any(), any()); + verify(realTable, times(1)).putAsync(any(), any()); if (isWriteAround) { - Assert.assertEquals("test-data-2", cacheTable.get("2")); // expects value from table - verify(tableCache, times(5)).get(any()); // cache miss leads to 2 more get() calls - Assert.assertEquals(cacheTable.hitRate(), 0.33, 0.1); // 1 hit, 3 requests + Assert.assertEquals("test-data-2", cachingTable.get("2")); // expects value from table + verify(realTable, times(2)).getAsync(any()); // should have one more fetch + Assert.assertEquals(cachingTable.hitRate(), 0.33, 0.1); // 1 hit, 3 requests } else { - Assert.assertEquals("test-data-XXXX", cacheTable.get("2")); // expect value from cache - verify(tableCache, times(4)).get(any()); // cache hit - Assert.assertEquals(cacheTable.hitRate(), 0.66, 0.1); // 2 hits, 3 requests + Assert.assertEquals("test-data-XXXX", cachingTable.get("2")); // expect value from cache + verify(realTable, times(1)).getAsync(any()); // no change + Assert.assertEquals(cachingTable.hitRate(), 0.66, 0.1); // 2 hits, 3 requests } } @Test - public void testCacheOps() { + public void testCacheOpsWriteThrough() { doTestCacheOps(false); } @@ -217,12 +235,12 @@ public class TestCachingTable { @Test public void testNonexistentKeyInTable() { ReadableTable<String, String> table = mock(ReadableTable.class); - doReturn(null).when(table).get(any()); + doReturn(CompletableFuture.completedFuture(null)).when(table).getAsync(any()); ReadWriteTable<String, String> cache = getMockCache().getLeft(); - CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, 16, false); - initTable(cachingTable); + CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, false); + initTables(cachingTable); Assert.assertNull(cachingTable.get("abc")); - verify(cache, times(2)).get(any()); + verify(cache, times(1)).get(any()); Assert.assertNull(cache.get("abc")); verify(cache, times(0)).put(any(), any()); } @@ -230,86 +248,119 @@ public class TestCachingTable { @Test public void testKeyEviction() { ReadableTable<String, String> table = mock(ReadableTable.class); - doReturn("3").when(table).get(any()); + doReturn(CompletableFuture.completedFuture("3")).when(table).getAsync(any()); ReadWriteTable<String, String> cache = mock(ReadWriteTable.class); // no handler added to mock cache so get/put are noop, this can simulate eviction - CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, 16, false); - initTable(cachingTable); + CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cache, false); + initTables(cachingTable); cachingTable.get("abc"); - verify(table, times(1)).get(any()); + verify(table, times(1)).getAsync(any()); // get() should go to table again cachingTable.get("abc"); - verify(table, times(2)).get(any()); + verify(table, times(2)).getAsync(any()); } /** - * Test the atomic operations in CachingTable by simulating 10 threads each executing - * 5000 random operations (GET/PUT/DELETE) with random keys, which are picked from a - * narrow range (0-9) for higher concurrency. Consistency is verified by comparing - * the cache content and table content both of which should match exactly. Eviction - * is not simulated because it would be impossible to compare the cache/table. - * @throws InterruptedException + * Testing caching in a more realistic scenario with Guava cache + remote table */ @Test - public void testConcurrentAccess() throws InterruptedException { - final int numThreads = 10; - final int iterations = 5000; - ExecutorService executor = Executors.newFixedThreadPool(numThreads); - - // Ensure all threads to reach rendezvous before starting the simulation - final CountDownLatch startLatch = new CountDownLatch(numThreads); - - Pair<ReadWriteTable<String, String>, Map<String, String>> tableMapPair = getMockCache(); - final ReadableTable<String, String> table = tableMapPair.getLeft(); - - Pair<ReadWriteTable<String, String>, Map<String, String>> cacheMapPair = getMockCache(); - final CachingTable<String, String> cachingTable = new CachingTable<>("myTable", table, cacheMapPair.getLeft(), 16, false); - - Map<String, String> cacheMap = cacheMapPair.getRight(); - Map<String, String> tableMap = tableMapPair.getRight(); - - final Random rand = new Random(System.currentTimeMillis()); - - for (int i = 0; i < numThreads; i++) { - executor.submit(() -> { - try { - startLatch.countDown(); - startLatch.await(); - } catch (InterruptedException e) { - Assert.fail(); - } - - String lastPutKey = null; - for (int j = 0; j < iterations; j++) { - int cmd = rand.nextInt(3); - String key = String.valueOf(rand.nextInt(10)); - switch (cmd) { - case 0: - cachingTable.get(key); - break; - case 1: - cachingTable.put(key, "test-data-" + rand.nextInt()); - lastPutKey = key; - break; - case 2: - if (lastPutKey != null) { - cachingTable.delete(lastPutKey); - } - break; - } - } - }); + public void testGuavaCacheAndRemoteTable() throws Exception { + String tableId = "testGuavaCacheAndRemoteTable"; + Cache<String, String> guavaCache = CacheBuilder.newBuilder().initialCapacity(100).build(); + final ReadWriteTable<String, String> guavaTable = new GuavaCacheTable<>(tableId, guavaCache); + + // It is okay to share rateLimitHelper and async helper for read/write in test + TableRateLimiter<String, String> rateLimitHelper = mock(TableRateLimiter.class); + TableReadFunction<String, String> readFn = mock(TableReadFunction.class); + TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class); + final RemoteReadWriteTable<String, String> remoteTable = new RemoteReadWriteTable<>( + tableId, readFn, writeFn, rateLimitHelper, rateLimitHelper, + Executors.newSingleThreadExecutor(), Executors.newSingleThreadExecutor()); + + final CachingTable<String, String> cachingTable = new CachingTable<>( + tableId, remoteTable, guavaTable, false); + + initTables(cachingTable, guavaTable, remoteTable); + + // GET + doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(any()); + Assert.assertEquals(cachingTable.getAsync("foo").get(), "bar"); + // Ensure cache is updated + Assert.assertEquals(guavaCache.getIfPresent("foo"), "bar"); + + // PUT + doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAsync(any(), any()); + cachingTable.putAsync("foo", "baz").get(); + // Ensure cache is updated + Assert.assertEquals(guavaCache.getIfPresent("foo"), "baz"); + + // DELETE + doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAsync(any()); + cachingTable.deleteAsync("foo").get(); + // Ensure cache is updated + Assert.assertNull(guavaCache.getIfPresent("foo")); + + // GET-ALL + Map<String, String> records = new HashMap<>(); + records.put("foo1", "bar1"); + records.put("foo2", "bar2"); + doReturn(CompletableFuture.completedFuture(records)).when(readFn).getAllAsync(any()); + Assert.assertEquals(cachingTable.getAllAsync(Arrays.asList("foo1", "foo2")).get(), records); + // Ensure cache is updated + Assert.assertEquals(guavaCache.getIfPresent("foo1"), "bar1"); + Assert.assertEquals(guavaCache.getIfPresent("foo2"), "bar2"); + + // GET-ALL with partial miss + doReturn(CompletableFuture.completedFuture(Collections.singletonMap("foo3", "bar3"))).when(readFn).getAllAsync(any()); + records = cachingTable.getAllAsync(Arrays.asList("foo1", "foo2", "foo3")).get(); + Assert.assertEquals(records.get("foo3"), "bar3"); + // Ensure cache is updated + Assert.assertEquals(guavaCache.getIfPresent("foo3"), "bar3"); + + // Calling again for the same keys should not trigger IO, ie. no exception is thrown + CompletableFuture<String> exFuture = new CompletableFuture<>(); + exFuture.completeExceptionally(new RuntimeException("Test exception")); + doReturn(exFuture).when(readFn).getAllAsync(any()); + cachingTable.getAllAsync(Arrays.asList("foo1", "foo2", "foo3")).get(); + + // Partial results should throw + try { + cachingTable.getAllAsync(Arrays.asList("foo1", "foo2", "foo5")).get(); + Assert.fail(); + } catch (Exception e) { } - executor.shutdown(); - - // Wait up to 1 minute for all threads to finish - Assert.assertTrue(executor.awaitTermination(60, TimeUnit.MINUTES)); - - // Verify cache and table contents fully match - Assert.assertEquals(cacheMap.size(), tableMap.size()); - cacheMap.keySet().forEach(k -> Assert.assertEquals(cacheMap.get(k), tableMap.get(k))); + // PUT-ALL + doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any()); + List<Entry<String, String>> entries = new ArrayList<>(); + entries.add(new Entry<>("foo1", "bar111")); + entries.add(new Entry<>("foo2", "bar222")); + cachingTable.putAllAsync(entries).get(); + // Ensure cache is updated + Assert.assertEquals(guavaCache.getIfPresent("foo1"), "bar111"); + Assert.assertEquals(guavaCache.getIfPresent("foo2"), "bar222"); + + // PUT-ALL with delete + doReturn(CompletableFuture.completedFuture(null)).when(writeFn).putAllAsync(any()); + doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any()); + entries = new ArrayList<>(); + entries.add(new Entry<>("foo1", "bar111")); + entries.add(new Entry<>("foo2", null)); + cachingTable.putAllAsync(entries).get(); + // Ensure cache is updated + Assert.assertNull(guavaCache.getIfPresent("foo2")); + + // At this point, foo1 and foo3 should still exist + Assert.assertNotNull(guavaCache.getIfPresent("foo1")); + Assert.assertNotNull(guavaCache.getIfPresent("foo3")); + + // DELETE-ALL + doReturn(CompletableFuture.completedFuture(null)).when(writeFn).deleteAllAsync(any()); + cachingTable.deleteAllAsync(Arrays.asList("foo1", "foo3")).get(); + // Ensure foo1 and foo3 are gone + Assert.assertNull(guavaCache.getIfPresent("foo1")); + Assert.assertNull(guavaCache.getIfPresent("foo3")); } } \ No newline at end of file http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java new file mode 100644 index 0000000..21fc6a5 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTable.java @@ -0,0 +1,413 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.table.remote; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.metrics.Counter; +import org.apache.samza.metrics.Gauge; +import org.apache.samza.metrics.MetricsRegistry; +import org.apache.samza.metrics.Timer; +import org.apache.samza.storage.kv.Entry; +import org.apache.samza.task.TaskContext; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import junit.framework.Assert; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyCollection; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + + +public class TestRemoteTable { + private <K, V, T extends RemoteReadableTable<K, V>> T getTable(String tableId, + TableReadFunction<K, V> readFn, TableWriteFunction<K, V> writeFn) { + return getTable(tableId, readFn, writeFn, null); + } + + private <K, V, T extends RemoteReadableTable<K, V>> T getTable(String tableId, + TableReadFunction<K, V> readFn, TableWriteFunction<K, V> writeFn, ExecutorService cbExecutor) { + RemoteReadableTable<K, V> table; + + TableRateLimiter<K, V> readRateLimiter = mock(TableRateLimiter.class); + TableRateLimiter<K, V> writeRateLimiter = mock(TableRateLimiter.class); + doReturn(true).when(readRateLimiter).isRateLimited(); + doReturn(true).when(writeRateLimiter).isRateLimited(); + + ExecutorService tableExecutor = Executors.newSingleThreadExecutor(); + + if (writeFn == null) { + table = new RemoteReadableTable<K, V>(tableId, readFn, readRateLimiter, tableExecutor, cbExecutor); + } else { + table = new RemoteReadWriteTable<K, V>(tableId, readFn, writeFn, readRateLimiter, writeRateLimiter, tableExecutor, cbExecutor); + } + + TaskContext taskContext = mock(TaskContext.class); + MetricsRegistry metricsRegistry = mock(MetricsRegistry.class); + doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString()); + doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString()); + doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any()); + doReturn(metricsRegistry).when(taskContext).getMetricsRegistry(); + + SamzaContainerContext containerContext = mock(SamzaContainerContext.class); + + table.init(containerContext, taskContext); + + return (T) table; + } + + private void doTestGet(boolean sync, boolean error) throws Exception { + TableReadFunction<String, String> readFn = mock(TableReadFunction.class); + // Sync is backed by async so needs to mock the async method + CompletableFuture<String> future; + if (error) { + future = new CompletableFuture(); + future.completeExceptionally(new RuntimeException("Test exception")); + } else { + future = CompletableFuture.completedFuture("bar"); + } + doReturn(future).when(readFn).getAsync(anyString()); + RemoteReadableTable<String, String> table = getTable("testGet-" + sync + error, readFn, null); + Assert.assertEquals("bar", sync ? table.get("foo") : table.getAsync("foo").get()); + verify(table.readRateLimiter, times(1)).throttle(anyString()); + } + + @Test + public void testGet() throws Exception { + doTestGet(true, false); + } + + @Test + public void testGetAsync() throws Exception { + doTestGet(false, false); + } + + @Test(expected = ExecutionException.class) + public void testGetAsyncError() throws Exception { + doTestGet(false, true); + } + + @Test + public void testGetMultipleTables() { + TableReadFunction<String, String> readFn1 = mock(TableReadFunction.class); + TableReadFunction<String, String> readFn2 = mock(TableReadFunction.class); + + // Sync is backed by async so needs to mock the async method + doReturn(CompletableFuture.completedFuture("bar1")).when(readFn1).getAsync(anyString()); + doReturn(CompletableFuture.completedFuture("bar2")).when(readFn1).getAsync(anyString()); + + RemoteReadableTable<String, String> table1 = getTable("testGetMultipleTables-1", readFn1, null); + RemoteReadableTable<String, String> table2 = getTable("testGetMultipleTables-2", readFn2, null); + + CompletableFuture<String> future1 = table1.getAsync("foo1"); + CompletableFuture<String> future2 = table2.getAsync("foo2"); + + CompletableFuture.allOf(future1, future2) + .thenAccept(u -> { + Assert.assertEquals(future1.join(), "bar1"); + Assert.assertEquals(future2.join(), "bar1"); + }); + } + + private void doTestPut(boolean sync, boolean error, boolean isDelete) throws Exception { + TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class); + RemoteReadWriteTable<String, String> table = getTable("testPut-" + sync + error + isDelete, + mock(TableReadFunction.class), writeFn); + CompletableFuture<Void> future; + if (error) { + future = new CompletableFuture(); + future.completeExceptionally(new RuntimeException("Test exception")); + } else { + future = CompletableFuture.completedFuture(null); + } + // Sync is backed by async so needs to mock the async method + if (isDelete) { + doReturn(future).when(writeFn).deleteAsync(any()); + } else { + doReturn(future).when(writeFn).putAsync(any(), any()); + } + if (sync) { + table.put("foo", isDelete ? null : "bar"); + } else { + table.putAsync("foo", isDelete ? null : "bar").get(); + } + ArgumentCaptor<String> keyCaptor = ArgumentCaptor.forClass(String.class); + ArgumentCaptor<String> valCaptor = ArgumentCaptor.forClass(String.class); + if (isDelete) { + verify(writeFn, times(1)).deleteAsync(keyCaptor.capture()); + } else { + verify(writeFn, times(1)).putAsync(keyCaptor.capture(), valCaptor.capture()); + Assert.assertEquals("bar", valCaptor.getValue()); + } + Assert.assertEquals("foo", keyCaptor.getValue()); + if (isDelete) { + verify(table.writeRateLimiter, times(1)).throttle(anyString()); + } else { + verify(table.writeRateLimiter, times(1)).throttle(anyString(), anyString()); + } + } + + @Test + public void testPut() throws Exception { + doTestPut(true, false, false); + } + + @Test + public void testPutDelete() throws Exception { + doTestPut(true, false, true); + } + + @Test + public void testPutAsync() throws Exception { + doTestPut(false, false, false); + } + + @Test + public void testPutAsyncDelete() throws Exception { + doTestPut(false, false, true); + } + + @Test(expected = ExecutionException.class) + public void testPutAsyncError() throws Exception { + doTestPut(false, true, false); + } + + private void doTestDelete(boolean sync, boolean error) throws Exception { + TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class); + RemoteReadWriteTable<String, String> table = getTable("testDelete-" + sync + error, + mock(TableReadFunction.class), writeFn); + CompletableFuture<Void> future; + if (error) { + future = new CompletableFuture(); + future.completeExceptionally(new RuntimeException("Test exception")); + } else { + future = CompletableFuture.completedFuture(null); + } + // Sync is backed by async so needs to mock the async method + doReturn(future).when(writeFn).deleteAsync(any()); + ArgumentCaptor<String> argCaptor = ArgumentCaptor.forClass(String.class); + if (sync) { + table.delete("foo"); + } else { + table.deleteAsync("foo").get(); + } + verify(writeFn, times(1)).deleteAsync(argCaptor.capture()); + Assert.assertEquals("foo", argCaptor.getValue()); + verify(table.writeRateLimiter, times(1)).throttle(anyString()); + } + + @Test + public void testDelete() throws Exception { + doTestDelete(true, false); + } + + @Test + public void testDeleteAsync() throws Exception { + doTestDelete(false, false); + } + + @Test(expected = ExecutionException.class) + public void testDeleteAsyncError() throws Exception { + doTestDelete(false, true); + } + + private void doTestGetAll(boolean sync, boolean error, boolean partial) throws Exception { + TableReadFunction<String, String> readFn = mock(TableReadFunction.class); + Map<String, String> res = new HashMap<>(); + res.put("foo1", "bar1"); + if (!partial) { + res.put("foo2", "bar2"); + } + CompletableFuture<Map<String, String>> future; + if (error) { + future = new CompletableFuture(); + future.completeExceptionally(new RuntimeException("Test exception")); + } else { + future = CompletableFuture.completedFuture(res); + } + // Sync is backed by async so needs to mock the async method + doReturn(future).when(readFn).getAllAsync(any()); + RemoteReadableTable<String, String> table = getTable("testGetAll-" + sync + error + partial, readFn, null); + Assert.assertEquals(res, sync ? table.getAll(Arrays.asList("foo1", "foo2")) + : table.getAllAsync(Arrays.asList("foo1", "foo2")).get()); + verify(table.readRateLimiter, times(1)).throttle(anyCollection()); + } + + @Test + public void testGetAll() throws Exception { + doTestGetAll(true, false, false); + } + + @Test + public void testGetAllAsync() throws Exception { + doTestGetAll(false, false, false); + } + + @Test(expected = ExecutionException.class) + public void testGetAllAsyncError() throws Exception { + doTestGetAll(false, true, false); + } + + // Partial result is an acceptable scenario + @Test + public void testGetAllPartialResult() throws Exception { + doTestGetAll(false, false, true); + } + + public void doTestPutAll(boolean sync, boolean error, boolean hasDelete) throws Exception { + TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class); + RemoteReadWriteTable<String, String> table = getTable("testPutAll-" + sync + error + hasDelete, + mock(TableReadFunction.class), writeFn); + CompletableFuture<Void> future; + if (error) { + future = new CompletableFuture(); + future.completeExceptionally(new RuntimeException("Test exception")); + } else { + future = CompletableFuture.completedFuture(null); + } + // Sync is backed by async so needs to mock the async method + doReturn(future).when(writeFn).putAllAsync(any()); + if (hasDelete) { + doReturn(future).when(writeFn).deleteAllAsync(any()); + } + List<Entry<String, String>> entries = Arrays.asList( + new Entry<>("foo1", "bar1"), new Entry<>("foo2", hasDelete ? null : "bar2")); + ArgumentCaptor<List> argCaptor = ArgumentCaptor.forClass(List.class); + if (sync) { + table.putAll(entries); + } else { + table.putAllAsync(entries).get(); + } + verify(writeFn, times(1)).putAllAsync(argCaptor.capture()); + if (hasDelete) { + ArgumentCaptor<List> delArgCaptor = ArgumentCaptor.forClass(List.class); + verify(writeFn, times(1)).deleteAllAsync(delArgCaptor.capture()); + Assert.assertEquals(Arrays.asList("foo2"), delArgCaptor.getValue()); + Assert.assertEquals(1, argCaptor.getValue().size()); + Assert.assertEquals("foo1", ((Entry) argCaptor.getValue().get(0)).getKey()); + verify(table.writeRateLimiter, times(1)).throttle(anyCollection()); + } else { + Assert.assertEquals(entries, argCaptor.getValue()); + } + verify(table.writeRateLimiter, times(1)).throttleRecords(anyCollection()); + } + + @Test + public void testPutAll() throws Exception { + doTestPutAll(true, false, false); + } + + @Test + public void testPutAllHasDelete() throws Exception { + doTestPutAll(true, false, true); + } + + @Test + public void testPutAllAsync() throws Exception { + doTestPutAll(false, false, false); + } + + @Test + public void testPutAllAsyncHasDelete() throws Exception { + doTestPutAll(false, false, true); + } + + @Test(expected = ExecutionException.class) + public void testPutAllAsyncError() throws Exception { + doTestPutAll(false, true, false); + } + + public void doTestDeleteAll(boolean sync, boolean error) throws Exception { + TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class); + RemoteReadWriteTable<String, String> table = getTable("testDeleteAll-" + sync + error, + mock(TableReadFunction.class), writeFn); + CompletableFuture<Void> future; + if (error) { + future = new CompletableFuture(); + future.completeExceptionally(new RuntimeException("Test exception")); + } else { + future = CompletableFuture.completedFuture(null); + } + // Sync is backed by async so needs to mock the async method + doReturn(future).when(writeFn).deleteAllAsync(any()); + List<String> keys = Arrays.asList("foo1", "foo2"); + ArgumentCaptor<List> argCaptor = ArgumentCaptor.forClass(List.class); + if (sync) { + table.deleteAll(keys); + } else { + table.deleteAllAsync(keys).get(); + } + verify(writeFn, times(1)).deleteAllAsync(argCaptor.capture()); + Assert.assertEquals(keys, argCaptor.getValue()); + verify(table.writeRateLimiter, times(1)).throttle(anyCollection()); + } + + @Test + public void testDeleteAll() throws Exception { + doTestDeleteAll(true, false); + } + + @Test + public void testDeleteAllAsync() throws Exception { + doTestDeleteAll(false, false); + } + + @Test(expected = ExecutionException.class) + public void testDeleteAllAsyncError() throws Exception { + doTestDeleteAll(false, true); + } + + @Test + public void testFlush() { + TableWriteFunction<String, String> writeFn = mock(TableWriteFunction.class); + RemoteReadWriteTable<String, String> table = getTable("testFlush", mock(TableReadFunction.class), writeFn); + table.flush(); + verify(writeFn, times(1)).flush(); + } + + @Test + public void testGetWithCallbackExecutor() throws Exception { + TableReadFunction<String, String> readFn = mock(TableReadFunction.class); + // Sync is backed by async so needs to mock the async method + doReturn(CompletableFuture.completedFuture("bar")).when(readFn).getAsync(anyString()); + RemoteReadableTable<String, String> table = getTable("testGetWithCallbackExecutor", readFn, null, + Executors.newSingleThreadExecutor()); + Thread testThread = Thread.currentThread(); + + table.getAsync("foo").thenAccept(result -> { + Assert.assertEquals("bar", result); + // Must be executed on the executor thread + Assert.assertNotSame(testThread, Thread.currentThread()); + }); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java index acf3d61..e30da12 100644 --- a/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java +++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestRemoteTableDescriptor.java @@ -22,13 +22,13 @@ package org.apache.samza.table.remote; import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ThreadPoolExecutor; import org.apache.samza.container.SamzaContainerContext; import org.apache.samza.container.TaskName; import org.apache.samza.metrics.Counter; import org.apache.samza.metrics.MetricsRegistry; import org.apache.samza.metrics.Timer; -import org.apache.samza.operators.KV; import org.apache.samza.table.Table; import org.apache.samza.table.TableSpec; import org.apache.samza.task.TaskContext; @@ -39,19 +39,16 @@ import org.junit.Test; import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_READ_TAG; import static org.apache.samza.table.remote.RemoteTableDescriptor.RL_WRITE_TAG; -import static org.mockito.Matchers.anyMap; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; public class TestRemoteTableDescriptor { private void doTestSerialize(RateLimiter rateLimiter, - CreditFunction readCredFn, - CreditFunction writeCredFn) { + TableRateLimiter.CreditFunction readCredFn, + TableRateLimiter.CreditFunction writeCredFn) { RemoteTableDescriptor desc = new RemoteTableDescriptor("1"); desc.withReadFunction(mock(TableReadFunction.class)); desc.withWriteFunction(mock(TableWriteFunction.class)); @@ -79,17 +76,17 @@ public class TestRemoteTableDescriptor { @Test public void testSerializeWithLimiterAndReadCredFn() { - doTestSerialize(mock(RateLimiter.class), kv -> 1, null); + doTestSerialize(mock(RateLimiter.class), (k, v) -> 1, null); } @Test public void testSerializeWithLimiterAndWriteCredFn() { - doTestSerialize(mock(RateLimiter.class), null, kv -> 1); + doTestSerialize(mock(RateLimiter.class), null, (k, v) -> 1); } @Test public void testSerializeWithLimiterAndReadWriteCredFns() { - doTestSerialize(mock(RateLimiter.class), kv -> 1, kv -> 1); + doTestSerialize(mock(RateLimiter.class), (key, value) -> 1, (key, value) -> 1); } @Test @@ -129,10 +126,10 @@ public class TestRemoteTableDescriptor { return taskContext; } - static class CountingCreditFunction<K, V> implements CreditFunction<K, V> { + static class CountingCreditFunction<K, V> implements TableRateLimiter.CreditFunction<K, V> { int numCalls = 0; @Override - public Integer apply(KV<K, V> kv) { + public int getCredits(K key, V value) { numCalls++; return 1; } @@ -143,6 +140,8 @@ public class TestRemoteTableDescriptor { RemoteTableDescriptor<String, String> desc = new RemoteTableDescriptor("1"); desc.withReadFunction(mock(TableReadFunction.class)); desc.withWriteFunction(mock(TableWriteFunction.class)); + desc.withAsyncCallbackExecutorPoolSize(10); + if (rateOnly) { if (rlGets) { desc.withReadRateLimit(1000); @@ -172,39 +171,13 @@ public class TestRemoteTableDescriptor { Table table = provider.getTable(); Assert.assertTrue(table instanceof RemoteReadWriteTable); RemoteReadWriteTable rwTable = (RemoteReadWriteTable) table; - Assert.assertNotNull(rwTable.readFn); - Assert.assertNotNull(rwTable.writeFn); if (numRateLimitOps > 0) { - Assert.assertNotNull(rwTable.rateLimiter); + Assert.assertTrue(!rlGets || rwTable.readRateLimiter != null); + Assert.assertTrue(!rlPuts || rwTable.writeRateLimiter != null); } - // Verify rate limiter usage - if (numRateLimitOps > 0) { - rwTable.get("xxx"); - rwTable.put("yyy", "zzz"); - - if (!rateOnly) { - verify(rwTable.rateLimiter, times(numRateLimitOps)).acquire(anyMap()); - - CountingCreditFunction<?, ?> readCreditFn = (CountingCreditFunction<?, ?>) rwTable.readCreditFn; - CountingCreditFunction<?, ?> writeCreditFn = (CountingCreditFunction<?, ?>) rwTable.writeCreditFn; - - Assert.assertNotNull(readCreditFn); - Assert.assertNotNull(writeCreditFn); - - Assert.assertEquals(readCreditFn.numCalls, rlGets ? 1 : 0); - Assert.assertEquals(writeCreditFn.numCalls, rlPuts ? 1 : 0); - } else { - Assert.assertTrue(rwTable.rateLimiter instanceof EmbeddedTaggedRateLimiter); - Assert.assertEquals(rwTable.rateLimiter.getSupportedTags().size(), numRateLimitOps); - if (rlGets) { - Assert.assertTrue(rwTable.rateLimiter.getSupportedTags().contains(RL_READ_TAG)); - } - if (rlPuts) { - Assert.assertTrue(rwTable.rateLimiter.getSupportedTags().contains(RL_WRITE_TAG)); - } - } - } + ThreadPoolExecutor callbackExecutor = (ThreadPoolExecutor) rwTable.callbackExecutor; + Assert.assertEquals(10, callbackExecutor.getCorePoolSize()); } @Test http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java b/samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java new file mode 100644 index 0000000..ea9acbd --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/table/remote/TestTableRateLimiter.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.samza.table.remote; + +import java.util.Arrays; +import java.util.Collections; + +import org.apache.samza.metrics.Timer; +import org.apache.samza.storage.kv.Entry; +import org.apache.samza.util.RateLimiter; +import org.junit.Test; + +import junit.framework.Assert; + +import static org.mockito.Matchers.anyLong; +import static org.mockito.Matchers.anyMap; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + + +public class TestTableRateLimiter { + private static final String DEFAULT_TAG = "mytag"; + + public TableRateLimiter<String, String> getThrottler() { + return getThrottler(DEFAULT_TAG); + } + + public TableRateLimiter<String, String> getThrottler(String tag) { + TableRateLimiter.CreditFunction<String, String> credFn = + (TableRateLimiter.CreditFunction<String, String>) (key, value) -> { + int credits = key == null ? 0 : 3; + credits += value == null ? 0 : 3; + return credits; + }; + RateLimiter rateLimiter = mock(RateLimiter.class); + doReturn(Collections.singleton(DEFAULT_TAG)).when(rateLimiter).getSupportedTags(); + TableRateLimiter<String, String> rateLimitHelper = new TableRateLimiter<>("foo", rateLimiter, credFn, tag); + Timer timer = mock(Timer.class); + rateLimitHelper.setTimerMetric(timer); + return rateLimitHelper; + } + + @Test + public void testCreditKeyOnly() { + TableRateLimiter<String, String> rateLimitHelper = getThrottler(); + Assert.assertEquals(3, rateLimitHelper.getCredits("abc", null)); + } + + @Test + public void testCreditKeyValue() { + TableRateLimiter<String, String> rateLimitHelper = getThrottler(); + Assert.assertEquals(6, rateLimitHelper.getCredits("abc", "efg")); + } + + @Test + public void testCreditKeys() { + TableRateLimiter<String, String> rateLimitHelper = getThrottler(); + Assert.assertEquals(9, rateLimitHelper.getCredits(Arrays.asList("abc", "efg", "hij"))); + } + + @Test + public void testCreditEntries() { + TableRateLimiter<String, String> rateLimitHelper = getThrottler(); + Assert.assertEquals(12, rateLimitHelper.getEntryCredits( + Arrays.asList(new Entry<>("abc", "efg"), new Entry<>("hij", "lmn")))); + } + + @Test + public void testThrottle() { + TableRateLimiter<String, String> rateLimitHelper = getThrottler(); + Timer timer = mock(Timer.class); + rateLimitHelper.setTimerMetric(timer); + rateLimitHelper.throttle("foo"); + verify(rateLimitHelper.rateLimiter, times(1)).acquire(anyMap()); + verify(timer, times(1)).update(anyLong()); + } + + @Test + public void testThrottleUnknownTag() { + TableRateLimiter<String, String> rateLimitHelper = getThrottler("unknown_tag"); + rateLimitHelper.throttle("foo"); + verify(rateLimitHelper.rateLimiter, times(0)).acquire(anyMap()); + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java ---------------------------------------------------------------------- diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java index 882ae0d..98c3e3c 100644 --- a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java +++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadWriteTable.java @@ -19,6 +19,7 @@ package org.apache.samza.storage.kv; import java.util.List; +import java.util.concurrent.CompletableFuture; import org.apache.samza.container.SamzaContainerContext; import org.apache.samza.table.ReadWriteTable; @@ -67,6 +68,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab } @Override + public CompletableFuture<Void> putAsync(K key, V value) { + CompletableFuture<Void> future = new CompletableFuture(); + try { + put(key, value); + future.complete(null); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } + + @Override public void putAll(List<Entry<K, V>> entries) { writeMetrics.numPutAlls.inc(); long startNs = System.nanoTime(); @@ -75,6 +88,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab } @Override + public CompletableFuture<Void> putAllAsync(List<Entry<K, V>> entries) { + CompletableFuture<Void> future = new CompletableFuture(); + try { + putAll(entries); + future.complete(null); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } + + @Override public void delete(K key) { writeMetrics.numDeletes.inc(); long startNs = System.nanoTime(); @@ -83,6 +108,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab } @Override + public CompletableFuture<Void> deleteAsync(K key) { + CompletableFuture<Void> future = new CompletableFuture(); + try { + delete(key); + future.complete(null); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } + + @Override public void deleteAll(List<K> keys) { writeMetrics.numDeleteAlls.inc(); long startNs = System.nanoTime(); @@ -91,6 +128,18 @@ public class LocalStoreBackedReadWriteTable<K, V> extends LocalStoreBackedReadab } @Override + public CompletableFuture<Void> deleteAllAsync(List<K> keys) { + CompletableFuture<Void> future = new CompletableFuture(); + try { + deleteAll(keys); + future.complete(null); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } + + @Override public void flush() { writeMetrics.numFlushes.inc(); long startNs = System.nanoTime(); http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java ---------------------------------------------------------------------- diff --git a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java index 8d79e0d..1c59eb6 100644 --- a/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java +++ b/samza-kv/src/main/java/org/apache/samza/storage/kv/LocalStoreBackedReadableTable.java @@ -20,6 +20,7 @@ package org.apache.samza.storage.kv; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; import com.google.common.base.Preconditions; import org.apache.samza.container.SamzaContainerContext; @@ -70,6 +71,17 @@ public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, V> } @Override + public CompletableFuture<V> getAsync(K key) { + CompletableFuture<V> future = new CompletableFuture(); + try { + future.complete(get(key)); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } + + @Override public Map<K, V> getAll(List<K> keys) { readMetrics.numGetAlls.inc(); long startNs = System.nanoTime(); @@ -79,6 +91,17 @@ public class LocalStoreBackedReadableTable<K, V> implements ReadableTable<K, V> } @Override + public CompletableFuture<Map<K, V>> getAllAsync(List<K> keys) { + CompletableFuture<Map<K, V>> future = new CompletableFuture(); + try { + future.complete(getAll(keys)); + } catch (Exception e) { + future.completeExceptionally(e); + } + return future; + } + + @Override public void close() { // The KV store is not closed here as it may still be needed by downstream operators, // it will be closed by the SamzaContainer http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java ---------------------------------------------------------------------- diff --git a/samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java b/samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java index 8a20239..7068e9b 100644 --- a/samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java +++ b/samza-sql/src/test/java/org/apache/samza/sql/testutil/TestIOResolverFactory.java @@ -23,6 +23,8 @@ import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; + import org.apache.commons.lang.NotImplementedException; import org.apache.samza.config.Config; import org.apache.samza.container.SamzaContainerContext; @@ -78,11 +80,21 @@ public class TestIOResolverFactory implements SqlIOResolverFactory { } @Override + public CompletableFuture getAsync(Object key) { + throw new NotImplementedException(); + } + + @Override public Map getAll(List keys) { throw new NotImplementedException(); } @Override + public CompletableFuture<Map> getAllAsync(List keys) { + throw new NotImplementedException(); + } + + @Override public void close() { } @@ -98,16 +110,36 @@ public class TestIOResolverFactory implements SqlIOResolverFactory { } @Override + public CompletableFuture<Void> putAsync(Object key, Object value) { + throw new NotImplementedException(); + } + + @Override + public CompletableFuture<Void> putAllAsync(List list) { + throw new NotImplementedException(); + } + + @Override public void delete(Object key) { records.remove(key); } @Override + public CompletableFuture<Void> deleteAsync(Object key) { + throw new NotImplementedException(); + } + + @Override public void deleteAll(List keys) { records.clear(); } @Override + public CompletableFuture<Void> deleteAllAsync(List keys) { + throw new NotImplementedException(); + } + + @Override public void flush() { } http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java ---------------------------------------------------------------------- diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java b/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java index d7f0570..14ef751 100644 --- a/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java +++ b/samza-test/src/test/java/org/apache/samza/test/table/TestLocalTable.java @@ -20,6 +20,7 @@ package org.apache.samza.test.table; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.LinkedList; import java.util.List; @@ -32,7 +33,12 @@ import org.apache.samza.config.JobConfig; import org.apache.samza.config.JobCoordinatorConfig; import org.apache.samza.config.MapConfig; import org.apache.samza.config.TaskConfig; +import org.apache.samza.container.SamzaContainerContext; import org.apache.samza.container.grouper.task.SingleContainerGrouperFactory; +import org.apache.samza.metrics.Counter; +import org.apache.samza.metrics.Gauge; +import org.apache.samza.metrics.MetricsRegistry; +import org.apache.samza.metrics.Timer; import org.apache.samza.operators.KV; import org.apache.samza.operators.MessageStream; import org.apache.samza.operators.functions.MapFunction; @@ -43,6 +49,9 @@ import org.apache.samza.serializers.KVSerde; import org.apache.samza.serializers.NoOpSerde; import org.apache.samza.standalone.PassthroughCoordinationUtilsFactory; import org.apache.samza.standalone.PassthroughJobCoordinatorFactory; +import org.apache.samza.storage.kv.Entry; +import org.apache.samza.storage.kv.KeyValueStore; +import org.apache.samza.storage.kv.LocalStoreBackedReadWriteTable; import org.apache.samza.storage.kv.inmemory.InMemoryTableDescriptor; import org.apache.samza.table.ReadableTable; import org.apache.samza.table.Table; @@ -61,6 +70,14 @@ import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyList; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + /** * This test class tests sendTo() and join() for local tables @@ -360,4 +377,48 @@ public class TestLocalTable extends AbstractIntegrationTestHarness { return record.getKey(); } } + + @Test + public void testAsyncOperation() throws Exception { + KeyValueStore kvStore = mock(KeyValueStore.class); + LocalStoreBackedReadWriteTable<String, String> table = new LocalStoreBackedReadWriteTable<>("table1", kvStore); + TaskContext taskContext = mock(TaskContext.class); + MetricsRegistry metricsRegistry = mock(MetricsRegistry.class); + doReturn(mock(Timer.class)).when(metricsRegistry).newTimer(anyString(), anyString()); + doReturn(mock(Counter.class)).when(metricsRegistry).newCounter(anyString(), anyString()); + doReturn(mock(Gauge.class)).when(metricsRegistry).newGauge(anyString(), any()); + doReturn(metricsRegistry).when(taskContext).getMetricsRegistry(); + + SamzaContainerContext containerContext = mock(SamzaContainerContext.class); + + table.init(containerContext, taskContext); + + // GET + doReturn("bar").when(kvStore).get(anyString()); + Assert.assertEquals("bar", table.getAsync("foo").get()); + + // GET-ALL + Map<String, String> recordMap = new HashMap<>(); + recordMap.put("foo1", "bar1"); + recordMap.put("foo2", "bar2"); + doReturn(recordMap).when(kvStore).getAll(anyList()); + Assert.assertEquals(recordMap, table.getAllAsync(Arrays.asList("foo1", "foo2")).get()); + + // PUT + table.putAsync("foo1", "bar1").get(); + verify(kvStore, times(1)).put(anyString(), anyString()); + + // PUT-ALL + List<Entry<String, String>> records = Arrays.asList(new Entry<>("foo1", "bar1"), new Entry<>("foo2", "bar2")); + table.putAllAsync(records).get(); + verify(kvStore, times(1)).putAll(anyList()); + + // DELETE + table.deleteAsync("foo").get(); + verify(kvStore, times(1)).delete(anyString()); + + // DELETE-ALL + table.deleteAllAsync(Arrays.asList("foo1", "foo2")).get(); + verify(kvStore, times(1)).deleteAll(anyList()); + } } http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java ---------------------------------------------------------------------- diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java index 8d07570..2d07b01 100644 --- a/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java +++ b/samza-test/src/test/java/org/apache/samza/test/table/TestRemoteTable.java @@ -24,10 +24,11 @@ import java.io.ObjectInputStream; import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.function.Function; import java.util.stream.Collectors; @@ -46,6 +47,7 @@ import org.apache.samza.serializers.NoOpSerde; import org.apache.samza.table.Table; import org.apache.samza.table.caching.CachingTableDescriptor; import org.apache.samza.table.caching.guava.GuavaCacheTableDescriptor; +import org.apache.samza.table.remote.TableRateLimiter; import org.apache.samza.table.remote.TableReadFunction; import org.apache.samza.table.remote.TableWriteFunction; import org.apache.samza.table.remote.RemoteReadableTable; @@ -63,7 +65,6 @@ import com.google.common.cache.CacheBuilder; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -86,8 +87,8 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness { } @Override - public TestTableData.Profile get(Integer key) { - return profileMap.getOrDefault(key, null); + public CompletableFuture<TestTableData.Profile> getAsync(Integer key) { + return CompletableFuture.completedFuture(profileMap.get(key)); } static InMemoryReadFunction getInMemoryReadFunction(String serializedProfiles) { @@ -112,18 +113,15 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness { } @Override - public void put(Integer key, TestTableData.EnrichedPageView record) { + public CompletableFuture<Void> putAsync(Integer key, TestTableData.EnrichedPageView record) { records.add(record); + return CompletableFuture.completedFuture(null); } @Override - public void delete(Integer key) { + public CompletableFuture<Void> deleteAsync(Integer key) { records.remove(key); - } - - @Override - public void deleteAll(Collection<Integer> keys) { - records.removeAll(keys); + return CompletableFuture.completedFuture(null); } } @@ -187,9 +185,7 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness { } streamGraph.getInputStream("PageView", new NoOpSerde<TestTableData.PageView>()) - .map(pv -> { - return new KV<Integer, TestTableData.PageView>(pv.getMemberId(), pv); - }) + .map(pv -> new KV<>(pv.getMemberId(), pv)) .join(inputTable, new TestLocalTable.PageViewToProfileJoinFunction()) .map(m -> new KV(m.getMemberId(), m)) .sendTo(outputTable); @@ -230,8 +226,12 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness { @Test(expected = SamzaException.class) public void testCatchReaderException() { TableReadFunction<String, ?> reader = mock(TableReadFunction.class); - doThrow(new RuntimeException("Expected test exception")).when(reader).get(anyString()); - RemoteReadableTable<String, ?> table = new RemoteReadableTable<>("table1", reader, null, null); + CompletableFuture<String> future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException("Expected test exception")); + doReturn(future).when(reader).getAsync(anyString()); + TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class); + RemoteReadableTable<String, ?> table = new RemoteReadableTable<>( + "table1", reader, rateLimitHelper, Executors.newSingleThreadExecutor(), null); table.init(mock(SamzaContainerContext.class), createMockTaskContext()); table.get("abc"); } @@ -240,8 +240,12 @@ public class TestRemoteTable extends AbstractIntegrationTestHarness { public void testCatchWriterException() { TableReadFunction<String, String> reader = mock(TableReadFunction.class); TableWriteFunction<String, String> writer = mock(TableWriteFunction.class); - doThrow(new RuntimeException("Expected test exception")).when(writer).put(anyString(), any()); - RemoteReadWriteTable<String, String> table = new RemoteReadWriteTable<>("table1", reader, writer, null, null, null); + CompletableFuture<String> future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException("Expected test exception")); + doReturn(future).when(writer).putAsync(anyString(), any()); + TableRateLimiter rateLimitHelper = mock(TableRateLimiter.class); + RemoteReadWriteTable<String, String> table = new RemoteReadWriteTable<String, String>( + "table1", reader, writer, rateLimitHelper, rateLimitHelper, Executors.newSingleThreadExecutor(), null); table.init(mock(SamzaContainerContext.class), createMockTaskContext()); table.put("abc", "efg"); } http://git-wip-us.apache.org/repos/asf/samza/blob/a08040dc/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java ---------------------------------------------------------------------- diff --git a/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java b/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java index 817fb9f..38cc47c 100644 --- a/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java +++ b/samza-test/src/test/java/org/apache/samza/test/table/TestTableDescriptorsProvider.java @@ -81,7 +81,8 @@ public class TestTableDescriptorsProvider { String tableRewriterName = "tableRewriter"; configs.put("tables.descriptors.provider.class", MySampleTableDescriptorsProvider.class.getName()); Config resultConfig = new MySampleTableConfigRewriter().rewrite(tableRewriterName, new MapConfig(configs)); - Assert.assertTrue(resultConfig.size() == 18); + Assert.assertNotNull(resultConfig); + Assert.assertTrue(!resultConfig.isEmpty()); String localTableId = "local-table-1"; String remoteTableId = "remote-table-1";