Repository: samza
Updated Branches:
  refs/heads/master 6af104ea3 -> 963bd3085


SAMZA-1597: Expose an interface for throttling


Project: http://git-wip-us.apache.org/repos/asf/samza/repo
Commit: http://git-wip-us.apache.org/repos/asf/samza/commit/963bd308
Tree: http://git-wip-us.apache.org/repos/asf/samza/tree/963bd308
Diff: http://git-wip-us.apache.org/repos/asf/samza/diff/963bd308

Branch: refs/heads/master
Commit: 963bd3085e0284b40917c3ddfd02e70f882d1a19
Parents: 6af104e
Author: Wei Song <ws...@linkedin.com>
Authored: Tue Feb 27 11:47:51 2018 -0800
Committer: Jagadish <jvenkatra...@linkedin.com>
Committed: Tue Feb 27 11:47:51 2018 -0800

----------------------------------------------------------------------
 .../java/org/apache/samza/util/RateLimiter.java | 120 ++++++++++++++
 .../apache/samza/util/EmbeddedRateLimiter.java  |  98 +++++++++++
 .../samza/util/EmbeddedTaggedRateLimiter.java   | 138 ++++++++++++++++
 .../samza/util/TestEmbeddedRateLimiter.java     | 155 ++++++++++++++++++
 .../util/TestEmbeddedTaggedRateLimiter.java     | 161 +++++++++++++++++++
 5 files changed, 672 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java
----------------------------------------------------------------------
diff --git a/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java b/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java
new file mode 100644
index 0000000..75818dd
--- /dev/null
+++ b/samza-api/src/main/java/org/apache/samza/util/RateLimiter.java
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.util;
+
+import java.io.Serializable;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.samza.annotation.InterfaceStability;
+import org.apache.samza.config.Config;
+import org.apache.samza.task.TaskContext;
+
+/**
+ * A rate limiter interface used by Samza components to limit throughput of operations
+ * against a resource. Operations against a resource are represented by credits.
+ * Resources could be streams, databases, web services, etc.
+ *
+ * <p>
+ * This interface supports two categories of policies: tagged and non-tagged.
+ * A tagged rate limiter is used when further differentiation is required within a resource.
+ * For example: messages in a stream may be treated differently depending on the
+ * overall situation of processing; or read and write operations to a database may be
+ * limited separately. Tagging is the mechanism to allow this differentiation.
+ * An implementation may support only one of the two categories and reject calls to the
+ * methods of the other.
+ *
+ * <p>
+ * The following types of invocations are provided:
+ * <ul>
+ *   <li>Block indefinitely until requested credits become available</li>
+ *   <li>Block for up to a provided amount of time, then return the credits acquired</li>
+ *   <li>Non-blocking, return immediately the credits acquired</li>
+ * </ul>
+ *
+ */
+@InterfaceStability.Unstable
+public interface RateLimiter extends Serializable {
+
+  /**
+   * Initialize this rate limiter. This method should be called during container initialization,
+   * before any of the acquire methods is invoked.
+   *
+   * @param config job configuration
+   * @param taskContext task context that owns this rate limiter
+   */
+  void init(Config config, TaskContext taskContext);
+
+  /**
+   * Attempt to acquire the provided number of credits, blocking indefinitely until
+   * all requested credits become available.
+   *
+   * @param numberOfCredit requested number of credits
+   */
+  void acquire(int numberOfCredit);
+
+  /**
+   * Attempt to acquire the provided number of credits, blocking for up to the provided amount
+   * of time for credits to become available. If not all requested credits can be acquired
+   * before the timeout elapses, it returns the number of credits actually acquired, which may
+   * be fewer than requested (possibly zero). It may return immediately if it determines that
+   * the requested credits cannot be acquired within the provided amount of time.
+   *
+   * @param numberOfCredit requested number of credits
+   * @param timeout maximum amount of time to wait, expressed in {@code unit}
+   * @param unit time unit of {@code timeout}
+   * @return number of credits acquired
+   */
+  int acquire(int numberOfCredit, long timeout, TimeUnit unit);
+
+  /**
+   * Attempt to acquire the provided number of credits without blocking; returns immediately
+   * the number of credits acquired, which may be fewer than requested (possibly zero).
+   *
+   * @param numberOfCredit requested number of credits
+   * @return number of credits acquired
+   */
+  int tryAcquire(int numberOfCredit);
+
+  /**
+   * Attempt to acquire the provided number of credits for a number of tags, blocking
+   * indefinitely until all requested credits become available.
+   *
+   * @param tagToCreditMap a map of requested number of credits keyed by tag
+   */
+  void acquire(Map<String, Integer> tagToCreditMap);
+
+  /**
+   * Attempt to acquire the provided number of credits for a number of tags, blocking for up to
+   * the provided amount of time for credits to become available. If not all requested credits
+   * can be acquired before the timeout elapses, the returned map contains, for each tag, the
+   * number of credits actually acquired, which may be fewer than requested (possibly zero).
+   * It may return immediately if it determines that the requested credits cannot be acquired
+   * within the provided amount of time.
+   *
+   * @param tagToCreditMap a map of requested number of credits keyed by tag
+   * @param timeout maximum amount of time to wait, expressed in {@code unit}
+   * @param unit time unit of {@code timeout}
+   * @return a map of number of credits acquired keyed by tag
+   */
+  Map<String, Integer> acquire(Map<String, Integer> tagToCreditMap, long timeout, TimeUnit unit);
+
+  /**
+   * Attempt to acquire the provided number of credits for a number of tags without blocking;
+   * returns immediately, for each tag, the number of credits acquired, which may be fewer than
+   * requested (possibly zero).
+   *
+   * @param tagToCreditMap a map of requested number of credits keyed by tag
+   * @return a map of number of credits acquired keyed by tag
+   */
+  Map<String, Integer> tryAcquire(Map<String, Integer> tagToCreditMap);
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java b/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java
new file mode 100644
index 0000000..9ccf2f4
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/util/EmbeddedRateLimiter.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.util;
+
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.samza.config.Config;
+import org.apache.samza.task.TaskContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+
+
+/**
+ * An embedded rate limiter that enforces a single (non-tagged) target rate using Guava's
+ * {@link com.google.common.util.concurrent.RateLimiter}.
+ *
+ * <p>
+ * When {@link #init(Config, TaskContext)} is given a non-null {@link TaskContext}, the target
+ * rate is divided evenly among the tasks of the container.
+ *
+ * <p>
+ * Only the non-tagged methods are supported; the tagged methods throw
+ * {@link IllegalArgumentException}.
+ */
+public class EmbeddedRateLimiter implements RateLimiter {
+
+  static final private Logger LOGGER = LoggerFactory.getLogger(EmbeddedRateLimiter.class);
+
+  private final int targetRate;
+  // Guava's RateLimiter is not Serializable while this class is (through RateLimiter); the
+  // limiter is (re-)created by init(), so keep it out of the serialized form.
+  private transient com.google.common.util.concurrent.RateLimiter rateLimiter;
+
+  /**
+   * Creates a rate limiter; {@link #init(Config, TaskContext)} must be called before use.
+   *
+   * @param creditsPerSecond target rate in credits per second; must be positive
+   * @throws IllegalArgumentException if {@code creditsPerSecond} is not positive
+   */
+  public EmbeddedRateLimiter(int creditsPerSecond) {
+    Preconditions.checkArgument(creditsPerSecond > 0,
+        "Rate must be positive, got %s", creditsPerSecond);
+    this.targetRate = creditsPerSecond;
+  }
+
+  /**
+   * Blocks until {@code numberOfCredits} credits are available.
+   *
+   * @throws IllegalStateException if {@link #init(Config, TaskContext)} has not been called
+   */
+  @Override
+  public void acquire(int numberOfCredits) {
+    ensureInitialized();
+    rateLimiter.acquire(numberOfCredits);
+  }
+
+  /**
+   * Waits up to the given timeout for {@code numberOfCredits} credits. All-or-nothing: returns
+   * {@code numberOfCredits} if they were acquired, 0 otherwise.
+   *
+   * @throws IllegalStateException if {@link #init(Config, TaskContext)} has not been called
+   */
+  @Override
+  public int acquire(int numberOfCredits, long timeout, TimeUnit unit) {
+    ensureInitialized();
+    return rateLimiter.tryAcquire(numberOfCredits, timeout, unit)
+        ? numberOfCredits
+        : 0;
+  }
+
+  /**
+   * Acquires {@code numberOfCredits} credits only if they are immediately available.
+   * All-or-nothing: returns {@code numberOfCredits} on success, 0 otherwise.
+   *
+   * @throws IllegalStateException if {@link #init(Config, TaskContext)} has not been called
+   */
+  @Override
+  public int tryAcquire(int numberOfCredits) {
+    ensureInitialized();
+    return rateLimiter.tryAcquire(numberOfCredits)
+        ? numberOfCredits
+        : 0;
+  }
+
+  /** Not supported by this non-tagged rate limiter; always throws IllegalArgumentException. */
+  @Override
+  public void acquire(Map<String, Integer> tagToCreditsMap) {
+    throw new IllegalArgumentException("This method is not applicable");
+  }
+
+  /** Not supported by this non-tagged rate limiter; always throws IllegalArgumentException. */
+  @Override
+  public Map<String, Integer> acquire(Map<String, Integer> tagToCreditsMap, long timeout, TimeUnit unit) {
+    throw new IllegalArgumentException("This method is not applicable");
+  }
+
+  /** Not supported by this non-tagged rate limiter; always throws IllegalArgumentException. */
+  @Override
+  public Map<String, Integer> tryAcquire(Map<String, Integer> tagToCreditsMap) {
+    throw new IllegalArgumentException("This method is not applicable");
+  }
+
+  /**
+   * Creates the underlying limiter. With a non-null task context, the target rate is split
+   * evenly among the tasks of the container.
+   *
+   * @throws IllegalStateException if the container of {@code taskContext} has no tasks
+   */
+  @Override
+  public void init(Config config, TaskContext taskContext) {
+    // Floating-point division: integer division would truncate the per-task share, down to 0
+    // (which Guava's RateLimiter.create rejects) when there are more tasks than credits/second.
+    double effectiveRate = targetRate;
+    if (taskContext != null) {
+      int numberOfTasks = taskContext.getSamzaContainerContext().taskNames.size();
+      Preconditions.checkState(numberOfTasks > 0, "No tasks found in container");
+      effectiveRate /= numberOfTasks;
+      LOGGER.info("Effective rate limit for task {} is {}", taskContext.getTaskName(), effectiveRate);
+    }
+    this.rateLimiter = com.google.common.util.concurrent.RateLimiter.create(effectiveRate);
+  }
+
+  /** @throws IllegalStateException if {@link #init(Config, TaskContext)} has not been called */
+  private void ensureInitialized() {
+    Preconditions.checkState(rateLimiter != null, "Not initialized");
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java
----------------------------------------------------------------------
diff --git a/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java b/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java
new file mode 100644
index 0000000..9c20eee
--- /dev/null
+++ b/samza-core/src/main/java/org/apache/samza/util/EmbeddedTaggedRateLimiter.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.util;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.samza.config.Config;
+import org.apache.samza.task.TaskContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Stopwatch;
+
+import static java.util.concurrent.TimeUnit.NANOSECONDS;
+
+
+/**
+ * An embedded rate limiter that supports tags. Each tag has its own target rate, enforced by a
+ * separate Guava {@link com.google.common.util.concurrent.RateLimiter}.
+ *
+ * <p>
+ * When {@link #init(Config, TaskContext)} is given a non-null {@link TaskContext}, the target
+ * rate of every tag is divided evenly among the tasks of the container.
+ *
+ * <p>
+ * Only the tagged methods are supported; the non-tagged methods throw
+ * {@link IllegalArgumentException}.
+ */
+public class EmbeddedTaggedRateLimiter implements RateLimiter {
+
+  static final private Logger LOGGER = LoggerFactory.getLogger(EmbeddedTaggedRateLimiter.class);
+
+  private final Map<String, Integer> tagToTargetRateMap;
+  // Guava's RateLimiter is not Serializable while this class is (through RateLimiter); the map
+  // is (re-)created by init(), so keep it out of the serialized form.
+  private transient Map<String, com.google.common.util.concurrent.RateLimiter> tagToRateLimiterMap;
+
+  /**
+   * Creates a tagged rate limiter; {@link #init(Config, TaskContext)} must be called before use.
+   *
+   * @param tagToCreditsPerSecondMap target rate in credits per second keyed by tag; must be
+   *        non-empty and every rate must be positive. The map is copied.
+   * @throws IllegalArgumentException if the map is empty or contains a missing/non-positive rate
+   */
+  public EmbeddedTaggedRateLimiter(Map<String, Integer> tagToCreditsPerSecondMap) {
+    Preconditions.checkArgument(tagToCreditsPerSecondMap.size() > 0, "Map of tags can't be empty");
+    tagToCreditsPerSecondMap.forEach((tag, rate) -> Preconditions.checkArgument(
+        rate != null && rate > 0, "Rate for tag %s must be positive, got %s", tag, rate));
+    // Defensive, serializable copy: later changes to the caller's map must not affect this limiter
+    this.tagToTargetRateMap = Collections.unmodifiableMap(new java.util.HashMap<>(tagToCreditsPerSecondMap));
+  }
+
+  /**
+   * Blocks until the requested credits of every tag are available; tags are acquired one
+   * after another.
+   *
+   * @throws IllegalStateException if not initialized
+   * @throws IllegalArgumentException if the map contains an unknown tag
+   */
+  @Override
+  public void acquire(Map<String, Integer> tagToCreditsMap) {
+    ensureTagsAreValid(tagToCreditsMap);
+    tagToCreditsMap.forEach((tag, numberOfCredits) -> tagToRateLimiterMap.get(tag).acquire(numberOfCredits));
+  }
+
+  /**
+   * Attempts to acquire the requested credits of every tag, waiting at most {@code timeout} in
+   * total across all tags. Each tag is all-or-nothing: its entry in the result is either the
+   * requested number of credits or 0.
+   *
+   * @throws IllegalStateException if not initialized
+   * @throws IllegalArgumentException if the map contains an unknown tag
+   */
+  @Override
+  public Map<String, Integer> acquire(Map<String, Integer> tagToCreditsMap, long timeout, TimeUnit unit) {
+    ensureTagsAreValid(tagToCreditsMap);
+
+    long timeoutInNanos = NANOSECONDS.convert(timeout, unit);
+
+    Stopwatch stopwatch = Stopwatch.createStarted();
+    return tagToCreditsMap.entrySet().stream()
+        .map(e -> {
+            String tag = e.getKey();
+            int requiredCredits = e.getValue();
+            // The timeout is a budget shared by all tags: each tag only gets what is left of it
+            long remainingTimeoutInNanos = Math.max(0L, timeoutInNanos - stopwatch.elapsed(NANOSECONDS));
+            com.google.common.util.concurrent.RateLimiter rateLimiter = tagToRateLimiterMap.get(tag);
+            int availableCredits = rateLimiter.tryAcquire(requiredCredits, remainingTimeoutInNanos, NANOSECONDS)
+                ? requiredCredits
+                : 0;
+            return new ImmutablePair<String, Integer>(tag, availableCredits);
+          })
+        .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue));
+  }
+
+  /**
+   * Acquires, for every tag, the requested credits only if they are immediately available.
+   * Each tag is all-or-nothing: its entry in the result is either the requested number of
+   * credits or 0.
+   *
+   * @throws IllegalStateException if not initialized
+   * @throws IllegalArgumentException if the map contains an unknown tag
+   */
+  @Override
+  public Map<String, Integer> tryAcquire(Map<String, Integer> tagToCreditsMap) {
+    ensureTagsAreValid(tagToCreditsMap);
+    return tagToCreditsMap.entrySet().stream()
+        .map(e -> {
+            String tag = e.getKey();
+            int requiredCredits = e.getValue();
+            int availableCredits = tagToRateLimiterMap.get(tag).tryAcquire(requiredCredits)
+                ? requiredCredits
+                : 0;
+            return new ImmutablePair<String, Integer>(tag, availableCredits);
+          })
+        .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue));
+  }
+
+  /** Not supported by this tagged rate limiter; always throws IllegalArgumentException. */
+  @Override
+  public void acquire(int numberOfCredits) {
+    throw new IllegalArgumentException("This method is not applicable");
+  }
+
+  /** Not supported by this tagged rate limiter; always throws IllegalArgumentException. */
+  @Override
+  public int acquire(int numberOfCredit, long timeout, TimeUnit unit) {
+    throw new IllegalArgumentException("This method is not applicable");
+  }
+
+  /** Not supported by this tagged rate limiter; always throws IllegalArgumentException. */
+  @Override
+  public int tryAcquire(int numberOfCredit) {
+    throw new IllegalArgumentException("This method is not applicable");
+  }
+
+  /**
+   * Creates one underlying limiter per tag. With a non-null task context, each tag's target
+   * rate is split evenly among the tasks of the container.
+   *
+   * @throws IllegalStateException if the container of {@code taskContext} has no tasks
+   */
+  @Override
+  public void init(Config config, TaskContext taskContext) {
+    int numberOfTasks = taskContext != null
+        ? taskContext.getSamzaContainerContext().taskNames.size()
+        : 1;
+    Preconditions.checkState(numberOfTasks > 0, "No tasks found in container");
+    this.tagToRateLimiterMap = Collections.unmodifiableMap(tagToTargetRateMap.entrySet().stream()
+        .map(e -> {
+            String tag = e.getKey();
+            // Floating-point division: integer division would truncate the per-task share, down
+            // to 0 (which Guava's RateLimiter.create rejects) when there are more tasks than
+            // credits/second.
+            double effectiveRate = (double) e.getValue() / numberOfTasks;
+            if (taskContext != null) {
+              LOGGER.info("Effective rate limit for task {} and tag {} is {}",
+                  taskContext.getTaskName(), tag, effectiveRate);
+            }
+            return new ImmutablePair<String, com.google.common.util.concurrent.RateLimiter>(
+                tag, com.google.common.util.concurrent.RateLimiter.create(effectiveRate));
+          })
+        .collect(Collectors.toMap(ImmutablePair::getKey, ImmutablePair::getValue))
+    );
+  }
+
+  /** @throws IllegalStateException if {@link #init(Config, TaskContext)} has not been called */
+  private void ensureInitialized() {
+    Preconditions.checkState(tagToRateLimiterMap != null, "Not initialized");
+  }
+
+  /**
+   * @throws IllegalStateException if not initialized
+   * @throws IllegalArgumentException if {@code tagMap} contains a tag unknown to this limiter
+   */
+  private void ensureTagsAreValid(Map<String, ?> tagMap) {
+    ensureInitialized();
+    tagMap.keySet().forEach(tag ->
+        Preconditions.checkArgument(tagToRateLimiterMap.containsKey(tag), "Invalid tag: " + tag));
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java
new file mode 100644
index 0000000..1b3f687
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedRateLimiter.java
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.util;
+
+import java.lang.reflect.Field;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.samza.SamzaException;
+import org.apache.samza.config.Config;
+import org.apache.samza.container.SamzaContainerContext;
+import org.apache.samza.task.TaskContext;
+import org.junit.Test;
+
+import junit.framework.Assert;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+
+public class TestEmbeddedRateLimiter {
+
+  final static private int TEST_INTERVAL = 200; // ms
+  final static private int TARGET_RATE = 4000;
+  final static private int NUMBER_OF_TASKS = 2;
+  final static private int TARGET_RATE_PER_TASK = TARGET_RATE / NUMBER_OF_TASKS;
+  final static private int INCREMENT = 2;
+
+  @Test
+  public void testAcquire() {
+    RateLimiter rateLimiter = new EmbeddedRateLimiter(TARGET_RATE);
+    initRateLimiter(rateLimiter);
+
+    int count = 0;
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      rateLimiter.acquire(INCREMENT);
+      count += INCREMENT;
+    }
+
+    long rate = count * 1000 / TEST_INTERVAL;
+    verifyRate(rate);
+  }
+
+  @Test
+  public void testTryAcquire() {
+    RateLimiter rateLimiter = new EmbeddedRateLimiter(TARGET_RATE);
+    initRateLimiter(rateLimiter);
+
+    boolean hasSeenZeros = false;
+
+    int count = 0;
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      int availableCredits = rateLimiter.tryAcquire(INCREMENT);
+      if (availableCredits <= 0) {
+        hasSeenZeros = true;
+      } else {
+        count += INCREMENT;
+      }
+    }
+
+    long rate = count * 1000 / TEST_INTERVAL;
+    verifyRate(rate);
+    Assert.assertTrue("Expected tryAcquire to be rejected at least once", hasSeenZeros);
+  }
+
+  @Test
+  public void testAcquireWithTimeout() {
+    RateLimiter rateLimiter = new EmbeddedRateLimiter(TARGET_RATE);
+    initRateLimiter(rateLimiter);
+
+    boolean hasSeenZeros = false;
+
+    int count = 0;
+    int callCount = 0;
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      ++callCount;
+      int availableCredits = rateLimiter.acquire(INCREMENT, 20, MILLISECONDS);
+      if (availableCredits <= 0) {
+        hasSeenZeros = true;
+      } else {
+        count += INCREMENT;
+      }
+    }
+
+    long rate = count * 1000 / TEST_INTERVAL;
+    verifyRate(rate);
+    // Number of INCREMENT-sized batches the per-task rate allows within TEST_INTERVAL
+    int expectedCallCount = TARGET_RATE_PER_TASK * TEST_INTERVAL / 1000 / INCREMENT;
+    Assert.assertTrue(
+        String.format("Call count %d differs from expected %d by more than 2", callCount, expectedCallCount),
+        Math.abs(callCount - expectedCallCount) <= 2);
+    Assert.assertFalse("Expected every acquire with timeout to succeed", hasSeenZeros);
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testFailsWhenUninitialized() {
+    new EmbeddedRateLimiter(100).acquire(1);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testFailsWhenUsingTags() {
+    RateLimiter rateLimiter = new EmbeddedRateLimiter(10);
+    initRateLimiter(rateLimiter);
+    Map<String, Integer> tagToCredits = new HashMap<>();
+    tagToCredits.put("red", 1);
+    tagToCredits.put("green", 1);
+    rateLimiter.acquire(tagToCredits);
+  }
+
+  private void verifyRate(long rate) {
+    // As the actual rate would likely not be exactly the same as target rate, the calculation below
+    // verifies the actual rate is within 5% of the target rate per task
+    Assert.assertTrue(
+        String.format("Rate %d is not within 5%% of target rate %d", rate, TARGET_RATE_PER_TASK),
+        Math.abs(rate - TARGET_RATE_PER_TASK) <= TARGET_RATE_PER_TASK * 5 / 100);
+  }
+
+  /**
+   * Initializes the rate limiter with a mocked task context whose container reports
+   * NUMBER_OF_TASKS tasks, so the expected per-task rate is TARGET_RATE_PER_TASK.
+   */
+  static void initRateLimiter(RateLimiter rateLimiter) {
+    Config config = mock(Config.class);
+    TaskContext taskContext = mock(TaskContext.class);
+    SamzaContainerContext containerContext = mockSamzaContainerContext();
+    when(taskContext.getSamzaContainerContext()).thenReturn(containerContext);
+    rateLimiter.init(config, taskContext);
+  }
+
+  /**
+   * Mocks a container context whose {@code taskNames} collection reports NUMBER_OF_TASKS tasks.
+   * {@code taskNames} is a field rather than a method, so it is injected via reflection.
+   */
+  static SamzaContainerContext mockSamzaContainerContext() {
+    try {
+      @SuppressWarnings("unchecked") // mocking a generic interface through its raw class
+      Collection<String> taskNames = mock(Collection.class);
+      when(taskNames.size()).thenReturn(NUMBER_OF_TASKS);
+      SamzaContainerContext containerContext = mock(SamzaContainerContext.class);
+      Field taskNamesField = SamzaContainerContext.class.getDeclaredField("taskNames");
+      taskNamesField.setAccessible(true);
+      taskNamesField.set(containerContext, taskNames);
+      taskNamesField.setAccessible(false);
+      return containerContext;
+    } catch (Exception ex) {
+      throw new SamzaException(ex);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/samza/blob/963bd308/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
----------------------------------------------------------------------
diff --git a/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
new file mode 100644
index 0000000..a295d8f
--- /dev/null
+++ b/samza-core/src/test/java/org/apache/samza/util/TestEmbeddedTaggedRateLimiter.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.samza.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
+
+public class TestEmbeddedTaggedRateLimiter {
+
+  final static private int TEST_INTERVAL = 200; // ms
+  final static private int NUMBER_OF_TASKS = 2;
+  final static private int TARGET_RATE_RED = 1000;
+  final static private int TARGET_RATE_PER_TASK_RED = TARGET_RATE_RED / NUMBER_OF_TASKS;
+  final static private int TARGET_RATE_GREEN = 2000;
+  final static private int TARGET_RATE_PER_TASK_GREEN = TARGET_RATE_GREEN / NUMBER_OF_TASKS;
+  final static private int INCREMENT = 2;
+
+  @Test
+  public void testAcquire() {
+    RateLimiter rateLimiter = createRateLimiter();
+
+    Map<String, Integer> tagToCount = new HashMap<>();
+    tagToCount.put("red", 0);
+    tagToCount.put("green", 0);
+
+    Map<String, Integer> tagToCredits = new HashMap<>();
+    tagToCredits.put("red", INCREMENT);
+    tagToCredits.put("green", INCREMENT);
+
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      rateLimiter.acquire(tagToCredits);
+      tagToCount.put("red", tagToCount.get("red") + INCREMENT);
+      tagToCount.put("green", tagToCount.get("green") + INCREMENT);
+    }
+
+    verifyRate("red", tagToCount.get("red") * 1000 / TEST_INTERVAL, TARGET_RATE_PER_TASK_RED);
+    // Note: due to blocking, green is capped at red's QPS
+    verifyRate("green", tagToCount.get("green") * 1000 / TEST_INTERVAL, TARGET_RATE_PER_TASK_RED);
+  }
+
+  @Test
+  public void testTryAcquire() {
+
+    RateLimiter rateLimiter = createRateLimiter();
+
+    Map<String, Integer> tagToCount = new HashMap<>();
+    tagToCount.put("red", 0);
+    tagToCount.put("green", 0);
+
+    Map<String, Integer> tagToCredits = new HashMap<>();
+    tagToCredits.put("red", INCREMENT);
+    tagToCredits.put("green", INCREMENT);
+
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      Map<String, Integer> resultMap = rateLimiter.tryAcquire(tagToCredits);
+      tagToCount.put("red", tagToCount.get("red") + resultMap.get("red"));
+      tagToCount.put("green", tagToCount.get("green") + resultMap.get("green"));
+    }
+
+    verifyRate("red", tagToCount.get("red") * 1000 / TEST_INTERVAL, TARGET_RATE_PER_TASK_RED);
+    verifyRate("green", tagToCount.get("green") * 1000 / TEST_INTERVAL, TARGET_RATE_PER_TASK_GREEN);
+  }
+
+  @Test
+  public void testAcquireWithTimeout() {
+
+    RateLimiter rateLimiter = createRateLimiter();
+
+    Map<String, Integer> tagToCount = new HashMap<>();
+    tagToCount.put("red", 0);
+    tagToCount.put("green", 0);
+
+    Map<String, Integer> tagToCredits = new HashMap<>();
+    tagToCredits.put("red", INCREMENT);
+    tagToCredits.put("green", INCREMENT);
+
+    long start = System.currentTimeMillis();
+    while (System.currentTimeMillis() - start < TEST_INTERVAL) {
+      Map<String, Integer> resultMap = rateLimiter.acquire(tagToCredits, 20, MILLISECONDS);
+      tagToCount.put("red", tagToCount.get("red") + resultMap.get("red"));
+      tagToCount.put("green", tagToCount.get("green") + resultMap.get("green"));
+    }
+
+    verifyRate("red", tagToCount.get("red") * 1000 / TEST_INTERVAL, TARGET_RATE_PER_TASK_RED);
+    // Note: due to blocking, green is capped at red's QPS
+    verifyRate("green", tagToCount.get("green") * 1000 / TEST_INTERVAL, TARGET_RATE_PER_TASK_RED);
+  }
+
+  @Test(expected = IllegalStateException.class)
+  public void testFailsWhenUninitialized() {
+    Map<String, Integer> tagToTargetRateMap = new HashMap<>();
+    tagToTargetRateMap.put("red", 1000);
+    tagToTargetRateMap.put("green", 2000);
+    new EmbeddedTaggedRateLimiter(tagToTargetRateMap).acquire(tagToTargetRateMap);
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testFailsWhenNotUsingTags() {
+    Map<String, Integer> tagToCredits = new HashMap<>();
+    tagToCredits.put("red", 1);
+    tagToCredits.put("green", 1);
+    RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToCredits);
+    TestEmbeddedRateLimiter.initRateLimiter(rateLimiter);
+    rateLimiter.acquire(1);
+  }
+
+  /**
+   * Asserts that {@code rate} observed for {@code tag} is within 5% of {@code targetRate}.
+   * The actual rate is unlikely to match the target exactly, hence the tolerance.
+   */
+  private void verifyRate(String tag, long rate, long targetRate) {
+    Assert.assertTrue(
+        String.format("Rate %d for tag %s is not within 5%% of target rate %d", rate, tag, targetRate),
+        Math.abs(rate - targetRate) <= targetRate * 5 / 100);
+  }
+
+  /** Creates and initializes a limiter with "red" and "green" tags for NUMBER_OF_TASKS tasks. */
+  private RateLimiter createRateLimiter() {
+    Map<String, Integer> tagToTargetRateMap = new HashMap<>();
+    tagToTargetRateMap.put("red", TARGET_RATE_RED);
+    tagToTargetRateMap.put("green", TARGET_RATE_GREEN);
+    RateLimiter rateLimiter = new EmbeddedTaggedRateLimiter(tagToTargetRateMap);
+    TestEmbeddedRateLimiter.initRateLimiter(rateLimiter);
+    return rateLimiter;
+  }
+
+}

Reply via email to