Gehel has uploaded a new change for review. (
https://gerrit.wikimedia.org/r/367664 )
Change subject: [WIP] playing around throttling filter
......................................................................
[WIP] playing around throttling filter
Change-Id: If3c0c28c47f953fdb7f3b6186da8a9535cc18bdf
---
M blazegraph/pom.xml
A
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
A
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
A
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
A
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
A
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
A
blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
A
blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
M pom.xml
M war/src/main/webapp/WEB-INF/web.xml
10 files changed, 461 insertions(+), 1 deletion(-)
git pull ssh://gerrit.wikimedia.org:29418/wikidata/query/rdf
refs/changes/64/367664/2
diff --git a/blazegraph/pom.xml b/blazegraph/pom.xml
index 86187ef..d219c25 100644
--- a/blazegraph/pom.xml
+++ b/blazegraph/pom.xml
@@ -74,6 +74,10 @@
<artifactId>jetty-http</artifactId>
</dependency>
<dependency>
+ <groupId>org.isomorphism</groupId>
+ <artifactId>token-bucket</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.linkeddatafragments</groupId>
<artifactId>ldfserver</artifactId>
<classifier>classes</classifier>
diff --git
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
new file mode 100644
index 0000000..0668725
--- /dev/null
+++
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
@@ -0,0 +1,10 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import javax.servlet.http.HttpServletRequest;
+
+/**
+ * Strategy that maps an incoming HTTP request to a bucket.
+ *
+ * <p>Requests that map to equal buckets are meant to be grouped together by
+ * the caller (for example to share throttling state), so implementations
+ * should return objects with consistent {@code equals}/{@code hashCode}.</p>
+ *
+ * @param <T> type of the bucket identifying a group of requests
+ */
+public interface Bucketting<T> {
+    /**
+     * Computes the bucket the given request belongs to.
+     *
+     * @param request the incoming HTTP request
+     * @return the bucket for this request
+     */
+    T bucket(HttpServletRequest request);
+}
diff --git
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
new file mode 100644
index 0000000..d2b8c6e
--- /dev/null
+++
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
@@ -0,0 +1,82 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import org.isomorphism.util.TokenBuckets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.time.Duration;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.time.temporal.ChronoUnit.MILLIS;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.MINUTES;
+
+public class Throttler<K extends Object> {
+
+ private static final Logger log = LoggerFactory.getLogger(Throttler.class);
+
+ private final Bucketting<K> bucketting;
+ private final Cache<K, ThrottlingState> state;
+ private final Duration requestTimeThreshold;
+ private Callable<ThrottlingState> createThrottlingState;
+
+ public Throttler(Duration requestTimeThreshold, Bucketting<K> bucketting,
int maxStateEntries, Duration stateTTL, Callable<ThrottlingState>
createThrottlingState) {
+ this.requestTimeThreshold = requestTimeThreshold;
+ this.bucketting = bucketting;
+ this.state = CacheBuilder.newBuilder()
+ .maximumSize(maxStateEntries)
+ .expireAfterAccess(stateTTL.get(MILLIS), MILLISECONDS)
+ .build();
+ this.createThrottlingState = createThrottlingState;
+ }
+
+ public boolean isThrottled(HttpServletRequest request) {
+ ThrottlingState throttlingState =
state.getIfPresent(bucketting.bucket(request));
+ if (throttlingState == null) return false;
+
+ return throttlingState.isThrottled();
+ }
+
+
+ public void success(HttpServletRequest request, HttpServletResponse
response, Duration elapsed) {
+ try {
+ // only start to keep track of time usage if requests are expensive
+ if (elapsed.compareTo(requestTimeThreshold) > 0) {
+ ThrottlingState throttlingState =
state.get(bucketting.bucket(request), createThrottlingState);
+
+ throttlingState.consumeTime(elapsed);
+ }
+ } catch (ExecutionException ee) {
+ log.warn("Could not create throttling state", ee);
+ }
+ }
+
+ public void failure(HttpServletRequest request, HttpServletResponse
response, Duration elapsed, Exception e) {
+ try {
+ ThrottlingState throttlingState =
state.get(bucketting.bucket(request), createThrottlingState);
+
+ throttlingState.consumeError();
+ throttlingState.consumeTime(elapsed);
+ } catch (ExecutionException ee) {
+ log.warn("Could not create throttling state", ee);
+ }
+ }
+
+ public void throttled() {
+ }
+
+ public Duration getBackoffDelay(HttpServletRequest request) {
+ ThrottlingState throttlingState =
state.getIfPresent(bucketting.bucket(request));
+ if (throttlingState == null) return Duration.of(0, MILLIS);
+
+ return throttlingState.getBackoffDelay();
+ }
+
+}
diff --git
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
new file mode 100644
index 0000000..de8e4da
--- /dev/null
+++
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
@@ -0,0 +1,82 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+
+import com.google.common.base.Stopwatch;
+import org.isomorphism.util.TokenBuckets;
+import
org.wikidata.query.rdf.blazegraph.throttling.UserAgentIpAddressBucketting.Bucket;
+
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+
+import static java.time.temporal.ChronoUnit.SECONDS;
+import static java.util.concurrent.TimeUnit.MINUTES;
+
+public class ThrottlingFilter implements Filter {
+
+ private Throttler<Bucket> throttler;
+
+ @Override
+ public void init(FilterConfig filterConfig) throws ServletException {
+ // TODO: we need to extract all those parameters, not sure if we want
+ // to expose them as servlet filter params or in a specific config
+ // file.
+
+ // TODO: the value of the parameters are mostly invented, we need to
+ // find the correct ones.
+
+ throttler = new Throttler<>(
+ Duration.of(30, SECONDS),
+ new UserAgentIpAddressBucketting(),
+ 10000, Duration.of(15, ChronoUnit.MINUTES),
+ () -> new ThrottlingState(
+ TokenBuckets.builder()
+ .withCapacity(10000)
+ .withFixedIntervalRefillStrategy(1000000, 1,
MINUTES)
+ .build(),
+ TokenBuckets.builder()
+ .withCapacity(10)
+ .withFixedIntervalRefillStrategy(100, 1,
MINUTES)
+ .build()));
+ }
+
+ @Override
+ public void doFilter(ServletRequest request, ServletResponse response,
FilterChain chain) throws IOException, ServletException {
+ HttpServletRequest httpRequest = (HttpServletRequest) request;
+ HttpServletResponse httpResponse = (HttpServletResponse) response;
+
+ if (throttler.isThrottled(httpRequest)) {
+ throttler.throttled();
+ notifyUser(httpResponse, throttler.getBackoffDelay(httpRequest));
+ // TODO: we probably want to publish metrics on throttling rate
+ return;
+ }
+
+ Stopwatch stopwatch = Stopwatch.createStarted();
+ try {
+ chain.doFilter(request, response);
+ throttler.success(httpRequest, httpResponse, stopwatch.elapsed());
+ } catch (Exception e) {
+ throttler.failure(httpRequest, httpResponse, stopwatch.elapsed(),
e);
+ throw e;
+ }
+ }
+
+ public void notifyUser(HttpServletResponse response, Duration
backoffDelay) throws IOException {
+ response.setHeader("Retry-After",
Long.toString(backoffDelay.getSeconds()));
+ response.sendError(429, "Too Many Requests - Please lower your request
rate.");
+ }
+
+ @Override
+ public void destroy() {
+ // Nothing to destroy
+ }
+}
diff --git
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
new file mode 100644
index 0000000..7311a23
--- /dev/null
+++
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
@@ -0,0 +1,46 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import org.isomorphism.util.TokenBucket;
+
+import java.time.Duration;
+
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.time.temporal.ChronoUnit.MILLIS;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
+public final class ThrottlingState {
+
+ private final TokenBucket timeBucket;
+ private final TokenBucket errorsBucket;
+
+ public ThrottlingState(TokenBucket timeBucket, TokenBucket errorsBucket) {
+ this.timeBucket = timeBucket;
+ this.errorsBucket = errorsBucket;
+ }
+
+ public synchronized boolean isThrottled() {
+ return timeBucket.getNumTokens() == 0 || errorsBucket.getNumTokens()
== 0;
+ }
+
+ public synchronized Duration getBackoffDelay() {
+ return Duration.of(
+ max(backoffDelayMillis(timeBucket),
backoffDelayMillis(errorsBucket)),
+ MILLIS
+ );
+ }
+
+ public synchronized void consumeTime(Duration elapsed) {
+ timeBucket.consume(min(elapsed.toMillis(), timeBucket.getNumTokens()));
+ }
+
+ public synchronized void consumeError() {
+ errorsBucket.tryConsume();
+ }
+
+ private static long backoffDelayMillis(TokenBucket bucket) {
+ if (bucket.getNumTokens() > 0) return 0;
+ return bucket.getDurationUntilNextRefill(MILLISECONDS);
+ }
+
+}
diff --git
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
new file mode 100644
index 0000000..22fd77b
--- /dev/null
+++
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
@@ -0,0 +1,42 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Objects;
+
+public class UserAgentIpAddressBucketting implements
Bucketting<UserAgentIpAddressBucketting.Bucket> {
+
+
+ @Override
+ public Bucket bucket(HttpServletRequest request) {
+ // TODO: we need to check if Jetty is configured to honor
+ // x-forwarded-for HTTP header, else request.getRemoteAddr() will just
+ // return the IP of the load balancer
+ return new Bucket(request.getRemoteAddr(),
request.getHeader("User-Agent"));
+ }
+
+ public static final class Bucket {
+ private final String remoteAddr;
+ private final String userAgent;
+
+ private Bucket(String remoteAddr, String userAgent) {
+ this.remoteAddr = remoteAddr;
+ this.userAgent = userAgent;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+
+ Bucket bucket = (Bucket) o;
+
+ return Objects.equals(remoteAddr, bucket.remoteAddr)
+ && Objects.equals(userAgent, bucket.userAgent);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(remoteAddr, userAgent);
+ }
+ }
+}
diff --git
a/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
new file mode 100644
index 0000000..b665b3b
--- /dev/null
+++
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
@@ -0,0 +1,123 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import org.isomorphism.util.TokenBucket;
+import org.isomorphism.util.TokenBuckets;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+import java.util.concurrent.TimeUnit;
+
+import static java.time.temporal.ChronoUnit.MILLIS;
+import static java.time.temporal.ChronoUnit.SECONDS;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.MINUTES;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link ThrottlingState}.
+ */
+public class ThrottlingStateTest {
+
+    // JUnit assertions are used instead of the Java `assert` keyword, which
+    // is silently skipped when the JVM runs without assertions enabled.
+
+    @Test
+    public void fullBucketsAreNotThrottled() {
+        Assert.assertFalse(new ThrottlingState(fullBucket(), fullBucket()).isThrottled());
+    }
+
+    @Test
+    public void anyEmptyBucketIsThrottled() {
+        Assert.assertTrue(new ThrottlingState(fullBucket(), emptyBucket()).isThrottled());
+        Assert.assertTrue(new ThrottlingState(emptyBucket(), fullBucket()).isThrottled());
+        Assert.assertTrue(new ThrottlingState(emptyBucket(), emptyBucket()).isThrottled());
+    }
+
+    @Test
+    public void backoffDelayIsZeroWhenBucketsAreNonEmpty() {
+        ThrottlingState state = new ThrottlingState(fullBucket(), fullBucket());
+        assertThat(state.getBackoffDelay(), equalTo(Duration.of(0, SECONDS)));
+    }
+
+    @Test
+    public void backoffDelayIsTheLargestDelay() {
+        TokenBucket emptyBucket1 = Mockito.mock(TokenBucket.class);
+        when(emptyBucket1.getNumTokens()).thenReturn(0L);
+        when(emptyBucket1.getDurationUntilNextRefill(MILLISECONDS)).thenReturn(1000L);
+
+        TokenBucket emptyBucket2 = Mockito.mock(TokenBucket.class);
+        when(emptyBucket2.getNumTokens()).thenReturn(0L);
+        when(emptyBucket2.getDurationUntilNextRefill(MILLISECONDS)).thenReturn(10000L);
+
+        assertThat(
+                new ThrottlingState(fullBucket(), emptyBucket1).getBackoffDelay(),
+                equalTo(Duration.of(1, SECONDS)));
+
+        assertThat(
+                new ThrottlingState(emptyBucket1, emptyBucket2).getBackoffDelay(),
+                equalTo(Duration.of(10, SECONDS)));
+    }
+
+    @Test
+    public void canConsumeTime() {
+        TokenBucket timeBucket = fullBucket();
+        ThrottlingState state = new ThrottlingState(timeBucket, fullBucket());
+        long tokensBefore = timeBucket.getNumTokens();
+
+        state.consumeTime(Duration.of(500, MILLIS));
+
+        assertThat(timeBucket.getNumTokens() + 500, equalTo(tokensBefore));
+    }
+
+    @Test
+    public void canConsumeTimeEvenIfNotEnoughTokensAvailable() {
+        TokenBucket timeBucket = fullBucket();
+        ThrottlingState state = new ThrottlingState(timeBucket, fullBucket());
+
+        // 5 seconds is more than the 1000 tokens the bucket can hold.
+        state.consumeTime(Duration.of(5, SECONDS));
+
+        assertThat(timeBucket.getNumTokens(), equalTo(0L));
+    }
+
+    @Test
+    public void canConsumeErrors() {
+        TokenBucket errorsBucket = fullBucket();
+        ThrottlingState state = new ThrottlingState(fullBucket(), errorsBucket);
+        long tokensBefore = errorsBucket.getNumTokens();
+
+        state.consumeError();
+
+        assertThat(errorsBucket.getNumTokens() + 1, equalTo(tokensBefore));
+    }
+
+    @Test
+    public void canConsumeErrorsEvenIfNotEnoughTokensAvailable() {
+        TokenBucket errorsBucket = emptyBucket();
+        // The errors bucket is the SECOND constructor argument; passing it
+        // first would consume from a full bucket and never exercise the
+        // empty-bucket case.
+        ThrottlingState state = new ThrottlingState(fullBucket(), errorsBucket);
+
+        state.consumeError();
+
+        assertThat(errorsBucket.getNumTokens(), equalTo(0L));
+    }
+
+    /**
+     * Builds a bucket of capacity 1000, refilled with 1000 tokens per minute,
+     * that the tests treat as full.
+     */
+    private TokenBucket fullBucket() {
+        return TokenBuckets.builder()
+                .withCapacity(1000)
+                .withFixedIntervalRefillStrategy(1000, 1, MINUTES)
+                .build();
+    }
+
+    /**
+     * Builds the same bucket as {@link #fullBucket()} and drains every token.
+     */
+    private TokenBucket emptyBucket() {
+        TokenBucket bucket = TokenBuckets.builder()
+                .withCapacity(1000)
+                .withFixedIntervalRefillStrategy(1000, 1, MINUTES)
+                .build();
+        while (bucket.tryConsume()) {
+            // consume all tokens
+        }
+        return bucket;
+    }
+
+}
diff --git
a/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
new file mode 100644
index 0000000..2914db9
--- /dev/null
+++
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
@@ -0,0 +1,58 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import org.junit.Before;
+import org.junit.Test;
+import
org.wikidata.query.rdf.blazegraph.throttling.UserAgentIpAddressBucketting.Bucket;
+
+import javax.servlet.http.HttpServletRequest;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests for {@link UserAgentIpAddressBucketting}: buckets must be equal
+ * exactly when both the IP address and the user agent match.
+ */
+public class UserAgentIpAddressBuckettingTest {
+
+    private UserAgentIpAddressBucketting bucketting;
+
+    @Before
+    public void createBuckettingUnderTest() {
+        bucketting = new UserAgentIpAddressBucketting();
+    }
+
+    @Test
+    public void sameUserAgentAndIpAddressAreInTheSameBucket() {
+        Bucket first = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+        Bucket second = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+
+        assertThat(first, equalTo(second));
+        assertThat(first.hashCode(), equalTo(second.hashCode()));
+    }
+
+    @Test
+    public void differentUserAgentsAreInDifferentBuckets() {
+        Bucket first = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+        Bucket second = bucketting.bucket(mockRequest("1.2.3.4", "UA2"));
+
+        assertThat(first, not(equalTo(second)));
+        assertThat(first.hashCode(), not(equalTo(second.hashCode())));
+    }
+
+    @Test
+    public void differentIpAddressesAreInDifferentBuckets() {
+        Bucket first = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+        Bucket second = bucketting.bucket(mockRequest("4.3.2.1", "UA1"));
+
+        assertThat(first, not(equalTo(second)));
+        assertThat(first.hashCode(), not(equalTo(second.hashCode())));
+    }
+
+    /**
+     * Creates a mocked request reporting the given remote address and
+     * {@code User-Agent} header.
+     */
+    private HttpServletRequest mockRequest(String ipAddress, String userAgent) {
+        HttpServletRequest request = mock(HttpServletRequest.class);
+        when(request.getHeader("User-Agent")).thenReturn(userAgent);
+        when(request.getRemoteAddr()).thenReturn(ipAddress);
+        return request;
+    }
+
+}
diff --git a/pom.xml b/pom.xml
index 7d65391..97457f0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -247,7 +247,7 @@
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
- <version>21.0</version>
+ <version>22.0</version>
</dependency>
<dependency>
<groupId>com.googlecode.json-simple</groupId>
@@ -422,6 +422,11 @@
<version>${jetty.version}</version>
</dependency>
<dependency>
+ <groupId>org.isomorphism</groupId>
+ <artifactId>token-bucket</artifactId>
+ <version>1.6</version>
+ </dependency>
+ <dependency>
<groupId>org.jolokia</groupId>
<artifactId>jolokia-jvm</artifactId>
<version>1.3.1</version>
diff --git a/war/src/main/webapp/WEB-INF/web.xml
b/war/src/main/webapp/WEB-INF/web.xml
index 91efc26..502085f 100644
--- a/war/src/main/webapp/WEB-INF/web.xml
+++ b/war/src/main/webapp/WEB-INF/web.xml
@@ -84,6 +84,14 @@
<listener>
<listener-class>org.wikidata.query.rdf.blazegraph.WikibaseContextListener</listener-class>
</listener>
+ <filter>
+ <filter-name>Throttling Filter</filter-name>
+
<filter-class>org.wikidata.query.rdf.blazegraph.throttling.ThrottlingFilter</filter-class>
+ </filter>
+ <filter-mapping>
+ <filter-name>Throttling Filter</filter-name>
+ <url-pattern>/*</url-pattern>
+ </filter-mapping>
<servlet>
<servlet-name>REST API</servlet-name>
<display-name>REST API</display-name>
--
To view, visit https://gerrit.wikimedia.org/r/367664
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings
Gerrit-MessageType: newchange
Gerrit-Change-Id: If3c0c28c47f953fdb7f3b6186da8a9535cc18bdf
Gerrit-PatchSet: 2
Gerrit-Project: wikidata/query/rdf
Gerrit-Branch: master
Gerrit-Owner: Gehel <[email protected]>
Gerrit-Reviewer: Smalyshev <[email protected]>
Gerrit-Reviewer: Volans <[email protected]>
Gerrit-Reviewer: jenkins-bot <>
_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits