Repository: samza Updated Branches: refs/heads/master 6af104ea3 -> 963bd3085
SAMZA-1597: Expose an interface for throttling Project: http://git-wip-us.apache.org/repos/asf/samza/repo Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/963bd308 Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/963bd308 Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/963bd308 Branch: refs/heads/master Commit: 963bd3085e0284b40917c3ddfd02e70f882d1a19 Parents: 6af104e Author: Wei Song <ws...@linkedin.com> Authored: Tue Feb 27 11:47:51 2018 -0800 Committer: Jagadish <jvenkatra...@linkedin.com> Committed: Tue Feb 27 11:47:51 2018 -0800 ---------------------------------------------------------------------- .../java/org/apache/samza/util/RateLimiter.java | 120 ++++++++++++++ .../apache/samza/util/EmbeddedRateLimiter.java | 98 +++++++++++ .../samza/util/EmbeddedTaggedRateLimiter.java | 138 ++++++++++++++++ .../samza/util/TestEmbeddedRateLimiter.java | 155 ++++++++++++++++++ .../util/TestEmbeddedTaggedRateLimiter.java | 161 +++++++++++++++++++ 5 files changed, 672 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java b/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java new file mode 100644 index 0000000..75818dd --- /dev/null +++ b/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.util; + +import java.io.Serializable; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.apache.samza.annotation.InterfaceStability; +import org.apache.samza.config.Config; +import org.apache.samza.task.TaskContext; + +/** + * A rate limiter interface used by Samza components to limit throughput of operations + * against a resource. Operations against a resource are represented by credits. + * Resources could be streams, databases, web services, etc. + * + * <p> + * This interface supports two categories of policies: tagged and non-tagged. + * A tagged rate limiter is used when further differentiation is required within a resource. + * For example, messages in a stream may be treated differently depending on the + * overall situation of processing, or read and write operations against a database may be treated differently. + * Tagging is the mechanism to allow this differentiation. + * + * <p> + * The following types of invocations are provided: + * <ul> + * <li>Block indefinitely until requested credits become available</li> + * <li>Block for a provided amount of time, then return available credits</li> + * <li>Non-blocking, returns immediately with the available credits</li> + * </ul> + * + */ +@InterfaceStability.Unstable +public interface RateLimiter extends Serializable { + + /** + * Initialize this rate limiter. This method should be called during container initialization.
+ * + * @param config job configuration + * @param taskContext task context that owns this rate limiter + */ + void init(Config config, TaskContext taskContext); + + /** + * Attempt to acquire the provided number of credits, blocks indefinitely until + * all requested credits become available. + * + * @param numberOfCredit requested number of credits + */ + void acquire(int numberOfCredit); + + /** + * Attempt to acquire the provided number of credits, blocks for up to the provided amount of + * time for credits to become available. When the timeout elapses and not all required credits + * can be acquired, it returns the number of credits currently available. It may return + * immediately, if it determines no credits can be acquired during the provided amount of time. + * + * @param numberOfCredit requested number of credits + * @param timeout number of time units to wait + * @param unit time unit of the timeout + * @return number of credits acquired + */ + int acquire(int numberOfCredit, long timeout, TimeUnit unit); + + /** + * Attempt to acquire the provided number of credits, returns immediately with the number of + * credits acquired. + * + * @param numberOfCredit requested number of credits + * @return number of credits acquired + */ + int tryAcquire(int numberOfCredit); + + /** + * Attempt to acquire the provided number of credits for a number of tags, blocks indefinitely + * until all requested credits become available. + * + * @param tagToCreditMap a map of requested number of credits keyed by tag + */ + void acquire(Map<String, Integer> tagToCreditMap); + + /** + * Attempt to acquire the provided number of credits for a number of tags, blocks for up to the provided amount of + * time for credits to become available. When the timeout elapses and not all required credits + * can be acquired, it returns the number of credits currently available. It may return + * immediately, if it determines no credits can be acquired during the provided amount of time.
+ * + * @param tagToCreditMap a map of requested number of credits keyed by tag + * @param timeout number of time units to wait + * @param unit time unit of the timeout + * @return a map of number of credits acquired keyed by tag + */ + Map<String, Integer> acquire(Map<String, Integer> tagToCreditMap, long timeout, TimeUnit unit); + + /** + * Attempt to acquire the provided number of credits for a number of tags, returns immediately with the number of + * credits acquired. + * + * @param tagToCreditMap a map of requested number of credits keyed by tag + * @return a map of number of credits acquired keyed by tag + */ + Map<String, Integer> tryAcquire(Map<String, Integer> tagToCreditMap); +} http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java b/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java new file mode 100644 index 0000000..9ccf2f4 --- /dev/null +++ b/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.util; + +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import org.apache.samza.config.Config; +import org.apache.samza.task.TaskContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; + + +/** + * An embedded rate limiter without tag support, backed by a Guava + * {@code com.google.common.util.concurrent.RateLimiter}. + * + * <p> + * The rate passed to the constructor is divided by the number of tasks in the container when + * {@link #init(Config, TaskContext)} is called with a non-null task context. The credit methods + * fail with an {@link IllegalStateException} until {@code init} has been called. The timed and + * non-blocking variants are all-or-nothing: they return either the requested number of credits or 0. + * The tag-based methods always throw an {@link IllegalArgumentException}. + * + * <p> + * NOTE(review): the per-task rate is computed with integer division, so it becomes 0 when the + * target rate is smaller than the task count; Guava's {@code RateLimiter.create} presumably + * rejects a non-positive rate - confirm and guard if needed. + */ +public class EmbeddedRateLimiter implements RateLimiter { + + static final private Logger LOGGER = LoggerFactory.getLogger(EmbeddedRateLimiter.class); + + private final int targetRate; + private com.google.common.util.concurrent.RateLimiter rateLimiter; + + public EmbeddedRateLimiter(int creditsPerSecond) { + this.targetRate = creditsPerSecond; + } + + @Override + public void acquire(int numberOfCredits) { + ensureInitialized(); + rateLimiter.acquire(numberOfCredits); + } + + @Override + public int acquire(int numberOfCredits, long timeout, TimeUnit unit) { + ensureInitialized(); + return rateLimiter.tryAcquire(numberOfCredits, timeout, unit) + ? numberOfCredits + : 0; + } + + @Override + public int tryAcquire(int numberOfCredits) { + ensureInitialized(); + return rateLimiter.tryAcquire(numberOfCredits) + ? 
numberOfCredits + : 0; + } + + @Override + public void acquire(Map<String, Integer> tagToCreditsMap) { + throw new IllegalArgumentException("This method is not applicable"); + } + + @Override + public Map<String, Integer> acquire(Map<String, Integer> tagToCreditsMap, long timeout, TimeUnit unit) { + throw new IllegalArgumentException("This method is not applicable"); + } + + @Override + public Map<String, Integer> tryAcquire(Map<String, Integer> tagToCreditsMap) { + throw new IllegalArgumentException("This method is not applicable"); + } + + @Override + public void init(Config config, TaskContext taskContext) { + int effectiveRate = targetRate; + if (taskContext != null) { + effectiveRate /= taskContext.getSamzaContainerContext().taskNames.size(); + LOGGER.info(String.format("Effective rate limit for task %s is %d", + taskContext.getTaskName(), effectiveRate)); + } + this.rateLimiter = com.google.common.util.concurrent.RateLimiter.create(effectiveRate); + } + + private void ensureInitialized() { + Preconditions.checkState(rateLimiter != null, "Not initialized"); + } + +} http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java b/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java new file mode 100644 index 0000000..9c20eee --- /dev/null +++ b/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.samza.util; + +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.samza.config.Config; +import org.apache.samza.task.TaskContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.base.Preconditions; +import com.google.common.base.Stopwatch; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; + + +/** + * An embedded rate limiter that supports tags, keeping one Guava + * {@code com.google.common.util.concurrent.RateLimiter} per tag. + * + * <p> + * The constructor requires a non-empty map of tag to rate and keeps a reference to that map + * (it is not copied). When {@link #init(Config, TaskContext)} is called with a non-null task + * context, each tag's rate is divided by the number of tasks in the container. Tag-based calls + * fail with an {@link IllegalStateException} before {@code init} and with an + * {@link IllegalArgumentException} for a tag that was not configured. + * + * <p> + * The blocking variant acquires the tags one after another. The timed variant shares a single + * timeout budget across all tags, and both the timed and the non-blocking variants are + * all-or-nothing per tag: each tag maps to either its requested number of credits or 0. + * The non-tagged methods always throw an {@link IllegalArgumentException}. + */ +public class EmbeddedTaggedRateLimiter implements RateLimiter { + + static final private Logger LOGGER = LoggerFactory.getLogger(EmbeddedTaggedRateLimiter.class); + + private final Map<String, Integer> tagToTargetRateMap; + private Map<String, com.google.common.util.concurrent.RateLimiter> tagToRateLimiterMap; + + public EmbeddedTaggedRateLimiter(Map<String, Integer> tagToCreditsPerSecondMap) { + Preconditions.checkArgument(tagToCreditsPerSecondMap.size() > 0, "Map of tags can't be empty"); + this.tagToTargetRateMap = tagToCreditsPerSecondMap; + } + + @Override + public void acquire(Map<String, Integer> tagToCreditsMap) { + ensureTagsAreValid(tagToCreditsMap); + tagToCreditsMap.forEach((tag, numberOfCredits) -> tagToRateLimiterMap.get(tag).acquire(numberOfCredits)); + } + + @Override + public Map<String, Integer> 
acquire(Map<String, Integer> tagToCreditsMap, long timeout, TimeUnit unit) { + ensureTagsAreValid(tagToCreditsMap); + + long timeoutInNanos = NANOSECONDS.convert(timeout, unit); + + Stopwatch stopwatch = Stopwatch.createStarted(); + return tagToCreditsMap.entrySet().stream() + .map(e -> { + String tag = e.getKey(); + int requiredCredits = e.getValue(); + long remainingTimeoutInNanos = Math.max(0L, timeoutInNanos - stopwatch.elapsed(NANOSECONDS)); + com.google.common.util.concurrent.RateLimiter rateLimiter = tagToRateLimiterMap.get(tag); + int availableCredits = rateLimiter.tryAcquire(requiredCredits, remainingTimeoutInNanos, NANOSECONDS) + ? requiredCredits + : 0; + return new ImmutablePair<String, Integer>(tag, availableCredits); + }) + .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue)); + } + + @Override + public Map<String, Integer> tryAcquire(Map<String, Integer> tagToCreditsMap) { + ensureTagsAreValid(tagToCreditsMap); + return tagToCreditsMap.entrySet().stream() + .map(e -> { + String tag = e.getKey(); + int requiredCredits = e.getValue(); + int availableCredits = tagToRateLimiterMap.get(tag).tryAcquire(requiredCredits) + ? 
requiredCredits + : 0; + return new ImmutablePair<String, Integer>(tag, availableCredits); + }) + .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue)); + } + + @Override + public void acquire(int numberOfCredits) { + throw new IllegalArgumentException("This method is not applicable"); + } + + @Override + public int acquire(int numberOfCredit, long timeout, TimeUnit unit) { + throw new IllegalArgumentException("This method is not applicable"); + } + + @Override + public int tryAcquire(int numberOfCredit) { + throw new IllegalArgumentException("This method is not applicable"); + } + + @Override + public void init(Config config, TaskContext taskContext) { + this.tagToRateLimiterMap = Collections.unmodifiableMap(tagToTargetRateMap.entrySet().stream() + .map(e -> { + String tag = e.getKey(); + int effectiveRate = e.getValue(); + if (taskContext != null) { + effectiveRate /= taskContext.getSamzaContainerContext().taskNames.size(); + LOGGER.info(String.format("Effective rate limit for task %s and tag %s is %d", + taskContext.getTaskName(), tag, effectiveRate)); + } + return new ImmutablePair<String, com.google.common.util.concurrent.RateLimiter>( + tag, com.google.common.util.concurrent.RateLimiter.create(effectiveRate)); + }) + .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue)) + ); + } + + private void ensureInitialized() { + Preconditions.checkState(tagToRateLimiterMap != null, "Not initialized"); + } + + private void ensureTagsAreValid(Map<String, ?> tagMap) { + ensureInitialized(); + tagMap.keySet().forEach(tag -> + Preconditions.checkArgument(tagToRateLimiterMap.containsKey(tag), "Invalid tag: " + tag)); + } + +} http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java
b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java new file mode 100644 index 0000000..1b3f687 --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.samza.util; + +import java.lang.reflect.Field; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; + +import org.apache.samza.SamzaException; +import org.apache.samza.config.Config; +import org.apache.samza.container.SamzaContainerContext; +import org.apache.samza.task.TaskContext; +import org.junit.Test; + +import junit.framework.Assert; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Timing-based tests for {@link EmbeddedRateLimiter}. Each rate test drives the limiter for + * {@code TEST_INTERVAL} milliseconds of wall-clock time and checks that the observed credit rate + * is within 5% of the per-task target rate. The limiter is initialized with a mocked task context + * whose container reports {@code NUMBER_OF_TASKS} tasks. + */ +public class TestEmbeddedRateLimiter { + + final static private int TEST_INTERVAL = 200; // length of each measurement window, in milliseconds + final static private int TARGET_RATE = 4000; + final static private int NUMBER_OF_TASKS = 2; + final static private int TARGET_RATE_PER_TASK = TARGET_RATE / NUMBER_OF_TASKS; + final static private int INCREMENT = 2; + + @Test + public void testAcquire() { + RateLimiter rateLimiter = new EmbeddedRateLimiter(TARGET_RATE); + initRateLimiter(rateLimiter); + + int count = 0; + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TEST_INTERVAL) { + rateLimiter.acquire(INCREMENT); + count += INCREMENT; + } + + long rate = count * 1000 / TEST_INTERVAL; + verifyRate(rate); + } + + @Test + public void testTryAcquire() { + RateLimiter rateLimiter = new EmbeddedRateLimiter(TARGET_RATE); + initRateLimiter(rateLimiter); + + boolean hasSeenZeros = false; + + int count = 0; + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TEST_INTERVAL) { + int availableCredits = rateLimiter.tryAcquire(INCREMENT); + if (availableCredits <= 0) { + hasSeenZeros = true; + } else { + count += INCREMENT; + } + } + + long rate = count * 1000 / TEST_INTERVAL; + verifyRate(rate); + Assert.assertTrue(hasSeenZeros); + } + + @Test + public void testAcquireWithTimeout() { + RateLimiter rateLimiter = new EmbeddedRateLimiter(TARGET_RATE); + initRateLimiter(rateLimiter); + + boolean hasSeenZeros = false; + + int 
count = 0; + int callCount = 0; + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TEST_INTERVAL) { + ++callCount; + int availableCredits = rateLimiter.acquire(INCREMENT, 20, MILLISECONDS); + if (availableCredits <= 0) { + hasSeenZeros = true; + } else { + count += INCREMENT; + } + } + + long rate = count * 1000 / TEST_INTERVAL; + verifyRate(rate); + Assert.assertTrue(Math.abs(callCount - TARGET_RATE_PER_TASK * TEST_INTERVAL / 1000 / INCREMENT) <= 2); + Assert.assertFalse(hasSeenZeros); + } + + @Test(expected = IllegalStateException.class) + public void testFailsWhenUninitialized() { + new EmbeddedRateLimiter(100).acquire(1); + } + + @Test(expected = IllegalArgumentException.class) + public void testFailsWhenUsingTags() { + RateLimiter rateLimiter = new EmbeddedRateLimiter(10); + initRateLimiter(rateLimiter); + Map<String, Integer> tagToCredits = new HashMap<>(); + tagToCredits.put("red", 1); + tagToCredits.put("green", 1); + rateLimiter.acquire(tagToCredits); + } + + private void verifyRate(long rate) { + // The actual rate will rarely match the target rate exactly, so the check below + // verifies that the actual rate is within 5% of the target rate per task. + Assert.assertTrue(Math.abs(rate - TARGET_RATE_PER_TASK) <= TARGET_RATE_PER_TASK * 5 / 100); + } + + static void initRateLimiter(RateLimiter rateLimiter) { + Config config = mock(Config.class); + TaskContext taskContext = mock(TaskContext.class); + SamzaContainerContext containerContext = mockSamzaContainerContext(); + when(taskContext.getSamzaContainerContext()).thenReturn(containerContext); + rateLimiter.init(config, taskContext); + } + + static SamzaContainerContext mockSamzaContainerContext() { + try { + Collection<String> taskNames = mock(Collection.class); + when(taskNames.size()).thenReturn(NUMBER_OF_TASKS); + SamzaContainerContext containerContext = mock(SamzaContainerContext.class); + Field taskNamesField = 
SamzaContainerContext.class.getDeclaredField("taskNames"); + taskNamesField.setAccessible(true); + taskNamesField.set(containerContext, taskNames); + taskNamesField.setAccessible(false); + return containerContext; + } catch (Exception ex) { + throw new SamzaException(ex); + } + } +} http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java ---------------------------------------------------------------------- diff --git a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java new file mode 100644 index 0000000..a295d8f --- /dev/null +++ b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.samza.util; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Test; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +/** + * Timing-based tests for {@link EmbeddedTaggedRateLimiter} using two tags, "red" and "green". + * Each rate test drives the limiter for {@code TEST_INTERVAL} milliseconds of wall-clock time and + * checks that the observed per-tag credit rate is within 5% of the expected rate. + * + * <p> + * NOTE(review): {@code testFailsWhenNotUsingTags} configures a rate of 1 per tag, which the + * 2-task mock context divides down to 0 during {@code init}. If Guava rejects a zero rate, the + * expected {@link IllegalArgumentException} comes from {@code init} rather than from + * {@code acquire(1)} - confirm. + */ +public class TestEmbeddedTaggedRateLimiter { + + final static private int TEST_INTERVAL = 200; // length of each measurement window, in milliseconds + final static private int NUMBER_OF_TASKS = 2; + final static private int TARGET_RATE_RED = 1000; + final static private int TARGET_RATE_PER_TASK_RED = TARGET_RATE_RED / NUMBER_OF_TASKS; + final static private int TARGET_RATE_GREEN = 2000; + final static private int TARGET_RATE_PER_TASK_GREEN = TARGET_RATE_GREEN / NUMBER_OF_TASKS; + final static private int INCREMENT = 2; + + @Test + public void testAcquire() { + RateLimiter rateLimiter = createRateLimiter(); + + Map<String, Integer> tagToCount = new HashMap<>(); + tagToCount.put("red", 0); + tagToCount.put("green", 0); + + Map<String, Integer> tagToCredits = new HashMap<>(); + tagToCredits.put("red", INCREMENT); + tagToCredits.put("green", INCREMENT); + + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TEST_INTERVAL) { + rateLimiter.acquire(tagToCredits); + tagToCount.put("red", tagToCount.get("red") + INCREMENT); + tagToCount.put("green", tagToCount.get("green") + INCREMENT); + } + + { + long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; + verifyRate(rate, TARGET_RATE_PER_TASK_RED); + } { + // Note: each call blocks on both tags, so green is capped at red's (lower) rate + long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; + verifyRate(rate, TARGET_RATE_PER_TASK_RED); + } + } + + @Test + public void testTryAcquire() { + + RateLimiter rateLimiter = createRateLimiter(); + + Map<String, Integer> tagToCount = new HashMap<>(); + tagToCount.put("red", 0); + tagToCount.put("green", 0); + + Map<String, Integer> tagToCredits = new HashMap<>(); + tagToCredits.put("red", INCREMENT); + tagToCredits.put("green", INCREMENT); + + long start = System.currentTimeMillis(); + while 
(System.currentTimeMillis() - start < TEST_INTERVAL) { + Map<String, Integer> resultMap = rateLimiter.tryAcquire(tagToCredits); + tagToCount.put("red", tagToCount.get("red") + resultMap.get("red")); + tagToCount.put("green", tagToCount.get("green") + resultMap.get("green")); + } + + { + long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; + verifyRate(rate, TARGET_RATE_PER_TASK_RED); + } { + long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; + verifyRate(rate, TARGET_RATE_PER_TASK_GREEN); + } + } + + @Test + public void testAcquireWithTimeout() { + + RateLimiter rateLimiter = createRateLimiter(); + + Map<String, Integer> tagToCount = new HashMap<>(); + tagToCount.put("red", 0); + tagToCount.put("green", 0); + + Map<String, Integer> tagToCredits = new HashMap<>(); + tagToCredits.put("red", INCREMENT); + tagToCredits.put("green", INCREMENT); + + long start = System.currentTimeMillis(); + while (System.currentTimeMillis() - start < TEST_INTERVAL) { + Map<String, Integer> resultMap = rateLimiter.acquire(tagToCredits, 20, MILLISECONDS); + tagToCount.put("red", tagToCount.get("red") + resultMap.get("red")); + tagToCount.put("green", tagToCount.get("green") + resultMap.get("green")); + } + + { + long rate = tagToCount.get("red") * 1000 / TEST_INTERVAL; + verifyRate(rate, TARGET_RATE_PER_TASK_RED); + } { + // Note: each call waits on both tags, so green is capped at red's (lower) rate + long rate = tagToCount.get("green") * 1000 / TEST_INTERVAL; + verifyRate(rate, TARGET_RATE_PER_TASK_RED); + } + } + + @Test(expected = IllegalStateException.class) + public void testFailsWhenUninitialized() { + Map<String, Integer> tagToTargetRateMap = new HashMap<>(); + tagToTargetRateMap.put("red", 1000); + tagToTargetRateMap.put("green", 2000); + new EmbeddedTaggedRateLimiter(tagToTargetRateMap).acquire(tagToTargetRateMap); + } + + @Test(expected = IllegalArgumentException.class) + public void testFailsWhenNotUsingTags() { + Map<String, Integer> tagToCredits = new HashMap<>(); + 
tagToCredits.put("red", 1); + tagToCredits.put("green", 1); + RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToCredits); + TestEmbeddedRateLimiter.initRateLimiter(rateLimiter); + rateLimiter.acquire(1); + } + + private void verifyRate(long rate, long targetRate) { + // The actual rate will rarely match the target rate exactly, so the check below + // verifies that the actual rate is within 5% of the given target rate. + Assert.assertTrue(Math.abs(rate - targetRate) <= targetRate * 5 / 100); + } + + private RateLimiter createRateLimiter() { + Map<String, Integer> tagToTargetRateMap = new HashMap<>(); + tagToTargetRateMap.put("red", TARGET_RATE_RED); + tagToTargetRateMap.put("green", TARGET_RATE_GREEN); + RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToTargetRateMap); + TestEmbeddedRateLimiter.initRateLimiter(rateLimiter); + return rateLimiter; + } + +}