Gehel has uploaded a new change for review. ( 
https://gerrit.wikimedia.org/r/367664 )

Change subject: [WIP] playing around throttling filter
......................................................................

[WIP] playing around throttling filter

Change-Id: If3c0c28c47f953fdb7f3b6186da8a9535cc18bdf
---
M blazegraph/pom.xml
A 
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
A 
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
A 
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
A 
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
A 
blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
A 
blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
A 
blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
M pom.xml
M war/src/main/webapp/WEB-INF/web.xml
10 files changed, 461 insertions(+), 1 deletion(-)


  git pull ssh://gerrit.wikimedia.org:29418/wikidata/query/rdf 
refs/changes/64/367664/2

diff --git a/blazegraph/pom.xml b/blazegraph/pom.xml
index 86187ef..d219c25 100644
--- a/blazegraph/pom.xml
+++ b/blazegraph/pom.xml
@@ -74,6 +74,10 @@
       <artifactId>jetty-http</artifactId>
     </dependency>
     <dependency>
+      <groupId>org.isomorphism</groupId>
+      <artifactId>token-bucket</artifactId>
+    </dependency>
+    <dependency>
       <groupId>org.linkeddatafragments</groupId>
       <artifactId>ldfserver</artifactId>
       <classifier>classes</classifier>
diff --git 
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
new file mode 100644
index 0000000..0668725
--- /dev/null
+++ 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Bucketting.java
@@ -0,0 +1,10 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import javax.servlet.http.HttpServletRequest;
+
+/**
+ * Strategy that assigns an HTTP request to a bucket of type {@code T}.
+ *
+ * Created by gehel on 25.07.17.
+ *
+ * @param <T> type of the bucket returned for a request
+ */
+public interface Bucketting<T extends Object> {
+    /**
+     * Computes the bucket of the given request.
+     *
+     * @param request the incoming HTTP request
+     * @return the bucket this request belongs to
+     */
+    T bucket(HttpServletRequest request);
+}
diff --git 
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
new file mode 100644
index 0000000..d2b8c6e
--- /dev/null
+++ 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/Throttler.java
@@ -0,0 +1,82 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import org.isomorphism.util.TokenBuckets;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.time.Duration;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutionException;
+
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.time.temporal.ChronoUnit.MILLIS;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.MINUTES;
+
+/**
+ * Tracks the resource usage of HTTP requests per bucket and tells whether a
+ * request should be throttled.
+ *
+ * <p>Requests are grouped by the key computed by the {@link Bucketting}. Each
+ * key has its own {@link ThrottlingState}, kept in a size-bounded cache whose
+ * entries expire when they have not been accessed for the configured TTL.</p>
+ *
+ * @param <K> type of the bucket key
+ */
+public class Throttler<K extends Object> {
+
+    private static final Logger log = LoggerFactory.getLogger(Throttler.class);
+
+    /** Computes the key under which the state of a request is tracked. */
+    private final Bucketting<K> bucketting;
+    /** Throttling state of each known bucket. */
+    private final Cache<K, ThrottlingState> state;
+    /** Successful requests not longer than this duration are not tracked. */
+    private final Duration requestTimeThreshold;
+    /** Creates the state of a bucket that is not yet in the cache. */
+    private final Callable<ThrottlingState> createThrottlingState;
+
+    /**
+     * Creates a throttler.
+     *
+     * @param requestTimeThreshold successful requests are only accounted for
+     *                             when they took longer than this duration
+     * @param bucketting strategy grouping requests into buckets
+     * @param maxStateEntries maximum number of buckets kept in the state cache
+     * @param stateTTL time after which the state of a bucket that has not been
+     *                 accessed is dropped
+     * @param createThrottlingState factory for the state of a new bucket
+     */
+    public Throttler(
+            Duration requestTimeThreshold,
+            Bucketting<K> bucketting,
+            int maxStateEntries,
+            Duration stateTTL,
+            Callable<ThrottlingState> createThrottlingState) {
+        this.requestTimeThreshold = requestTimeThreshold;
+        this.bucketting = bucketting;
+        this.state = CacheBuilder.newBuilder()
+                .maximumSize(maxStateEntries)
+                // Duration.get(TemporalUnit) only supports SECONDS and NANOS
+                // and throws for MILLIS; toMillis() converts the whole value.
+                .expireAfterAccess(stateTTL.toMillis(), MILLISECONDS)
+                .build();
+        this.createThrottlingState = createThrottlingState;
+    }
+
+    /**
+     * Tells whether the bucket of this request is currently throttled.
+     *
+     * @param request the incoming request
+     * @return {@code true} if a state exists for the bucket of the request and
+     *         that state is throttled, {@code false} otherwise
+     */
+    public boolean isThrottled(HttpServletRequest request) {
+        ThrottlingState throttlingState = state.getIfPresent(bucketting.bucket(request));
+        if (throttlingState == null) {
+            // Nothing recorded for this bucket yet.
+            return false;
+        }
+
+        return throttlingState.isThrottled();
+    }
+
+    /**
+     * Records a request that completed without exception.
+     *
+     * @param request the request that was processed
+     * @param response the response that was produced (currently unused)
+     * @param elapsed time spent processing the request
+     */
+    public void success(HttpServletRequest request, HttpServletResponse response, Duration elapsed) {
+        try {
+            // only start to keep track of time usage if requests are expensive
+            if (elapsed.compareTo(requestTimeThreshold) > 0) {
+                ThrottlingState throttlingState = state.get(bucketting.bucket(request), createThrottlingState);
+
+                throttlingState.consumeTime(elapsed);
+            }
+        } catch (ExecutionException ee) {
+            log.warn("Could not create throttling state", ee);
+        }
+    }
+
+    /**
+     * Records a request that ended with an exception. Unlike
+     * {@link #success}, the request is always accounted for, whatever its
+     * duration: one error and the elapsed time are consumed.
+     *
+     * @param request the request that was processed
+     * @param response the response that was produced (currently unused)
+     * @param elapsed time spent processing the request
+     * @param e the exception raised while processing (currently unused)
+     */
+    public void failure(HttpServletRequest request, HttpServletResponse response, Duration elapsed, Exception e) {
+        try {
+            ThrottlingState throttlingState = state.get(bucketting.bucket(request), createThrottlingState);
+
+            throttlingState.consumeError();
+            throttlingState.consumeTime(elapsed);
+        } catch (ExecutionException ee) {
+            log.warn("Could not create throttling state", ee);
+        }
+    }
+
+    /**
+     * Hook called when a request has been rejected. Currently does nothing.
+     */
+    public void throttled() {
+    }
+
+    /**
+     * Returns the delay the client of this request should wait before
+     * retrying.
+     *
+     * @param request the incoming request
+     * @return the backoff delay of the bucket of the request, or a zero
+     *         duration if no state exists for that bucket
+     */
+    public Duration getBackoffDelay(HttpServletRequest request) {
+        ThrottlingState throttlingState = state.getIfPresent(bucketting.bucket(request));
+        if (throttlingState == null) {
+            return Duration.ZERO;
+        }
+
+        return throttlingState.getBackoffDelay();
+    }
+
+}
diff --git 
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
new file mode 100644
index 0000000..de8e4da
--- /dev/null
+++ 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingFilter.java
@@ -0,0 +1,82 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+
+import com.google.common.base.Stopwatch;
+import org.isomorphism.util.TokenBuckets;
+import 
org.wikidata.query.rdf.blazegraph.throttling.UserAgentIpAddressBucketting.Bucket;
+
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+
+import static java.time.temporal.ChronoUnit.SECONDS;
+import static java.util.concurrent.TimeUnit.MINUTES;
+
+/**
+ * Servlet filter rejecting requests of clients whose bucket is throttled.
+ *
+ * <p>Requests are bucketed by remote address and user agent. Throttled
+ * requests get a 429 response with a {@code Retry-After} header; all other
+ * requests are passed down the chain and their duration and outcome are
+ * reported to the {@link Throttler}.</p>
+ */
+public class ThrottlingFilter implements Filter {
+
+    /** Created in {@link #init(FilterConfig)}. */
+    private Throttler<Bucket> throttler;
+
+    /**
+     * Creates the throttler with hard-coded settings.
+     *
+     * @param filterConfig filter configuration (currently unused)
+     */
+    @Override
+    public void init(FilterConfig filterConfig) throws ServletException {
+        // TODO: we need to extract all those parameters, not sure if we want
+        // to expose them as servlet filter params or in a specific config
+        // file.
+
+        // TODO: the value of the parameters are mostly invented, we need to
+        // find the correct ones.
+
+        throttler = new Throttler<>(
+                Duration.of(30, SECONDS),
+                new UserAgentIpAddressBucketting(),
+                10000, Duration.of(15, ChronoUnit.MINUTES),
+                () -> new ThrottlingState(
+                        // time bucket
+                        TokenBuckets.builder()
+                                .withCapacity(10000)
+                                .withFixedIntervalRefillStrategy(1000000, 1, MINUTES)
+                                .build(),
+                        // errors bucket
+                        TokenBuckets.builder()
+                                .withCapacity(10)
+                                .withFixedIntervalRefillStrategy(100, 1, MINUTES)
+                                .build()));
+    }
+
+    /**
+     * Rejects the request if its bucket is throttled, otherwise processes it
+     * and records its duration and outcome.
+     *
+     * @param request the incoming request, cast to {@link HttpServletRequest}
+     * @param response the response, cast to {@link HttpServletResponse}
+     * @param chain the remaining filter chain
+     * @throws IOException if thrown by the chain or while sending the error
+     * @throws ServletException if thrown by the chain
+     */
+    @Override
+    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
+        HttpServletRequest httpRequest = (HttpServletRequest) request;
+        HttpServletResponse httpResponse = (HttpServletResponse) response;
+
+        if (throttler.isThrottled(httpRequest)) {
+            throttler.throttled();
+            notifyUser(httpResponse, throttler.getBackoffDelay(httpRequest));
+            // TODO: we probably want to publish metrics on throttling rate
+            return;
+        }
+
+        Stopwatch stopwatch = Stopwatch.createStarted();
+        try {
+            chain.doFilter(request, response);
+            throttler.success(httpRequest, httpResponse, stopwatch.elapsed());
+        } catch (Exception e) {
+            // Account for the failure, then let the exception propagate.
+            throttler.failure(httpRequest, httpResponse, stopwatch.elapsed(), e);
+            throw e;
+        }
+    }
+
+    /**
+     * Sends a 429 response telling the client when to retry.
+     *
+     * @param response the response to send the error on
+     * @param backoffDelay delay the client should wait before retrying
+     * @throws IOException if the error cannot be sent
+     */
+    public void notifyUser(HttpServletResponse response, Duration backoffDelay) throws IOException {
+        // Retry-After is expressed in whole seconds: round up, so that a
+        // sub-second remainder does not tell the client to retry too early
+        // (a 500ms delay used to be sent as "Retry-After: 0").
+        long retryAfterSeconds = backoffDelay.getSeconds();
+        if (backoffDelay.getNano() > 0) {
+            retryAfterSeconds++;
+        }
+        response.setHeader("Retry-After", Long.toString(retryAfterSeconds));
+        response.sendError(429, "Too Many Requests - Please lower your request rate.");
+    }
+
+    @Override
+    public void destroy() {
+        // Nothing to destroy
+    }
+}
diff --git 
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
new file mode 100644
index 0000000..7311a23
--- /dev/null
+++ 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingState.java
@@ -0,0 +1,46 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import org.isomorphism.util.TokenBucket;
+
+import java.time.Duration;
+
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+import static java.time.temporal.ChronoUnit.MILLIS;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+
+/**
+ * Throttling state of a single bucket of requests, backed by two token
+ * buckets: one consumed by the time spent processing requests (one token per
+ * millisecond), one consumed by failed requests (one token per error).
+ *
+ * <p>The state is throttled as soon as one of the token buckets is empty.
+ * All public methods are synchronized on this instance.</p>
+ */
+public final class ThrottlingState {
+
+    /** Tokens represent milliseconds of processing time. */
+    private final TokenBucket timeBucket;
+    /** Tokens represent failed requests. */
+    private final TokenBucket errorsBucket;
+
+    /**
+     * Creates a state from the given token buckets.
+     *
+     * @param timeBucket bucket consumed by {@link #consumeTime(Duration)}
+     * @param errorsBucket bucket consumed by {@link #consumeError()}
+     */
+    public ThrottlingState(TokenBucket timeBucket, TokenBucket errorsBucket) {
+        this.timeBucket = timeBucket;
+        this.errorsBucket = errorsBucket;
+    }
+
+    /**
+     * Tells whether requests should be throttled.
+     *
+     * @return {@code true} if at least one of the token buckets is empty
+     */
+    public synchronized boolean isThrottled() {
+        return timeBucket.getNumTokens() == 0 || errorsBucket.getNumTokens() == 0;
+    }
+
+    /**
+     * Returns how long a client should wait before retrying.
+     *
+     * @return the largest of the delays of the two token buckets; zero when
+     *         both buckets still hold tokens
+     */
+    public synchronized Duration getBackoffDelay() {
+        return Duration.of(
+                max(backoffDelayMillis(timeBucket), backoffDelayMillis(errorsBucket)),
+                MILLIS
+        );
+    }
+
+    /**
+     * Consumes one token of the time bucket per elapsed millisecond, but
+     * never more tokens than the bucket currently holds.
+     *
+     * @param elapsed time spent processing a request
+     */
+    public synchronized void consumeTime(Duration elapsed) {
+        long tokens = min(elapsed.toMillis(), timeBucket.getNumTokens());
+        // Nothing to consume if the bucket is already empty or the request
+        // took less than a millisecond. The guard is needed because
+        // TokenBucket rejects non-positive amounts with an
+        // IllegalArgumentException.
+        if (tokens > 0) {
+            timeBucket.consume(tokens);
+        }
+    }
+
+    /**
+     * Consumes one token of the errors bucket, if any is left.
+     */
+    public synchronized void consumeError() {
+        errorsBucket.tryConsume();
+    }
+
+    /**
+     * Computes the delay imposed by a single token bucket.
+     *
+     * @param bucket the token bucket to inspect
+     * @return 0 if the bucket holds tokens, else the number of milliseconds
+     *         until its next refill
+     */
+    private static long backoffDelayMillis(TokenBucket bucket) {
+        if (bucket.getNumTokens() > 0) {
+            return 0;
+        }
+        return bucket.getDurationUntilNextRefill(MILLISECONDS);
+    }
+
+}
diff --git 
a/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
new file mode 100644
index 0000000..22fd77b
--- /dev/null
+++ 
b/blazegraph/src/main/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBucketting.java
@@ -0,0 +1,42 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Objects;
+
+/**
+ * Buckets requests by the pair (remote address, {@code User-Agent} header).
+ */
+public class UserAgentIpAddressBucketting implements Bucketting<UserAgentIpAddressBucketting.Bucket> {
+
+    /**
+     * Builds the bucket made of the remote address and user agent of the
+     * request.
+     *
+     * @param request the incoming HTTP request
+     * @return the bucket of the request
+     */
+    @Override
+    public Bucket bucket(HttpServletRequest request) {
+        // TODO: we need to check if Jetty is configured to honor
+        // x-forwarded-for HTTP header, else request.getRemoteAddr() will just
+        // return the IP of the load balancer
+        String clientAddress = request.getRemoteAddr();
+        String clientAgent = request.getHeader("User-Agent");
+        return new Bucket(clientAddress, clientAgent);
+    }
+
+    /**
+     * Immutable key made of a remote address and a user agent. Two buckets
+     * are equal when both components are equal; either may be {@code null}.
+     */
+    public static final class Bucket {
+        private final String remoteAddr;
+        private final String userAgent;
+
+        private Bucket(String remoteAddr, String userAgent) {
+            this.remoteAddr = remoteAddr;
+            this.userAgent = userAgent;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            // Bucket is final, so instanceof is equivalent to a class check.
+            if (!(o instanceof Bucket)) {
+                return false;
+            }
+
+            Bucket other = (Bucket) o;
+            if (!Objects.equals(remoteAddr, other.remoteAddr)) {
+                return false;
+            }
+            return Objects.equals(userAgent, other.userAgent);
+        }
+
+        @Override
+        public int hashCode() {
+            // Same value as Objects.hash(remoteAddr, userAgent).
+            int result = 31 + Objects.hashCode(remoteAddr);
+            return 31 * result + Objects.hashCode(userAgent);
+        }
+    }
+}
diff --git 
a/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
 
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
new file mode 100644
index 0000000..b665b3b
--- /dev/null
+++ 
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/ThrottlingStateTest.java
@@ -0,0 +1,123 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import org.isomorphism.util.TokenBucket;
+import org.isomorphism.util.TokenBuckets;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+
+import java.time.Duration;
+import java.time.temporal.ChronoUnit;
+import java.util.concurrent.TimeUnit;
+
+import static java.time.temporal.ChronoUnit.MILLIS;
+import static java.time.temporal.ChronoUnit.SECONDS;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static java.util.concurrent.TimeUnit.MINUTES;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests of {@link ThrottlingState}.
+ */
+public class ThrottlingStateTest {
+
+    @Test
+    public void fullBucketsAreNotThrottled() {
+        // JUnit assertions instead of the assert keyword: the latter is a
+        // no-op when the JVM runs without -ea.
+        Assert.assertFalse(new ThrottlingState(fullBucket(), fullBucket()).isThrottled());
+    }
+
+    @Test
+    public void anyEmptyBucketIsThrottled() {
+        Assert.assertTrue(new ThrottlingState(fullBucket(), emptyBucket()).isThrottled());
+        Assert.assertTrue(new ThrottlingState(emptyBucket(), fullBucket()).isThrottled());
+        Assert.assertTrue(new ThrottlingState(emptyBucket(), emptyBucket()).isThrottled());
+    }
+
+    @Test
+    public void backoffDelayIsZeroWhenBucketsAreNonEmpty() {
+        ThrottlingState state = new ThrottlingState(fullBucket(), fullBucket());
+        assertThat(state.getBackoffDelay(), equalTo(Duration.of(0, SECONDS)));
+    }
+
+    @Test
+    public void backoffDelayIsTheLargestDelay() {
+        TokenBucket emptyBucket1 = Mockito.mock(TokenBucket.class);
+        when(emptyBucket1.getNumTokens()).thenReturn(0L);
+        when(emptyBucket1.getDurationUntilNextRefill(MILLISECONDS)).thenReturn(1000L);
+
+        TokenBucket emptyBucket2 = Mockito.mock(TokenBucket.class);
+        when(emptyBucket2.getNumTokens()).thenReturn(0L);
+        when(emptyBucket2.getDurationUntilNextRefill(MILLISECONDS)).thenReturn(10000L);
+
+        assertThat(
+                new ThrottlingState(fullBucket(), emptyBucket1).getBackoffDelay(),
+                equalTo(Duration.of(1, SECONDS)));
+
+        assertThat(
+                new ThrottlingState(emptyBucket1, emptyBucket2).getBackoffDelay(),
+                equalTo(Duration.of(10, SECONDS)));
+    }
+
+    @Test
+    public void canConsumeTime() {
+        TokenBucket timeBucket = fullBucket();
+        ThrottlingState state = new ThrottlingState(timeBucket, fullBucket());
+        long tokensBefore = timeBucket.getNumTokens();
+
+        state.consumeTime(Duration.of(500, MILLIS));
+
+        assertThat(timeBucket.getNumTokens() + 500, equalTo(tokensBefore));
+    }
+
+    @Test
+    public void canConsumeTimeEvenIfNotEnoughTokensAvailable() {
+        TokenBucket timeBucket = fullBucket();
+        ThrottlingState state = new ThrottlingState(timeBucket, fullBucket());
+
+        // 5 seconds is more than the 1000 tokens held by the bucket.
+        state.consumeTime(Duration.of(5, SECONDS));
+
+        assertThat(timeBucket.getNumTokens(), equalTo(0L));
+    }
+
+    @Test
+    public void canConsumeErrors() {
+        TokenBucket errorsBucket = fullBucket();
+        ThrottlingState state = new ThrottlingState(fullBucket(), errorsBucket);
+        long tokensBefore = errorsBucket.getNumTokens();
+
+        state.consumeError();
+
+        assertThat(errorsBucket.getNumTokens() + 1, equalTo(tokensBefore));
+    }
+
+    @Test
+    public void canConsumeErrorsEvenIfNotEnoughTokensAvailable() {
+        TokenBucket errorsBucket = emptyBucket();
+        // The errors bucket is the SECOND constructor argument; it used to be
+        // passed as the time bucket, so the empty bucket was never exercised.
+        ThrottlingState state = new ThrottlingState(fullBucket(), errorsBucket);
+
+        state.consumeError();
+
+        assertThat(errorsBucket.getNumTokens(), equalTo(0L));
+    }
+
+    /** Creates a bucket holding 1000 tokens, refilled once per minute. */
+    private TokenBucket fullBucket() {
+        return TokenBuckets.builder()
+                .withCapacity(1000)
+                .withFixedIntervalRefillStrategy(1000, 1, MINUTES)
+                .build();
+    }
+
+    /** Creates a bucket of capacity 1000 and drains all of its tokens. */
+    private TokenBucket emptyBucket() {
+        TokenBucket bucket = fullBucket();
+        while (bucket.tryConsume()) {
+            // consume all tokens
+        }
+        return bucket;
+    }
+
+}
diff --git 
a/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
 
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
new file mode 100644
index 0000000..2914db9
--- /dev/null
+++ 
b/blazegraph/src/test/java/org/wikidata/query/rdf/blazegraph/throttling/UserAgentIpAddressBuckettingTest.java
@@ -0,0 +1,58 @@
+package org.wikidata.query.rdf.blazegraph.throttling;
+
+import org.junit.Before;
+import org.junit.Test;
+import 
org.wikidata.query.rdf.blazegraph.throttling.UserAgentIpAddressBucketting.Bucket;
+
+import javax.servlet.http.HttpServletRequest;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.not;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Unit tests of {@link UserAgentIpAddressBucketting}.
+ */
+public class UserAgentIpAddressBuckettingTest {
+
+    private UserAgentIpAddressBucketting bucketting;
+
+    @Before
+    public void createBuckettingUnderTest() {
+        bucketting = new UserAgentIpAddressBucketting();
+    }
+
+    @Test
+    public void sameUserAgentAndIpAddressAreInTheSameBucket() {
+        Bucket first = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+        Bucket second = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+
+        assertThat(first, equalTo(second));
+        assertThat(first.hashCode(), equalTo(second.hashCode()));
+    }
+
+    @Test
+    public void differentUserAgentsAreInDifferentBuckets() {
+        Bucket first = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+        Bucket second = bucketting.bucket(mockRequest("1.2.3.4", "UA2"));
+
+        assertThat(first, not(equalTo(second)));
+        assertThat(first.hashCode(), not(equalTo(second.hashCode())));
+    }
+
+    @Test
+    public void differentIpAddressesAreInDifferentBuckets() {
+        Bucket first = bucketting.bucket(mockRequest("1.2.3.4", "UA1"));
+        Bucket second = bucketting.bucket(mockRequest("4.3.2.1", "UA1"));
+
+        assertThat(first, not(equalTo(second)));
+        assertThat(first.hashCode(), not(equalTo(second.hashCode())));
+    }
+
+    /**
+     * Creates a mocked request exposing the given remote address and
+     * {@code User-Agent} header.
+     */
+    private HttpServletRequest mockRequest(String ipAddress, String userAgent) {
+        HttpServletRequest request = mock(HttpServletRequest.class);
+        when(request.getRemoteAddr()).thenReturn(ipAddress);
+        when(request.getHeader("User-Agent")).thenReturn(userAgent);
+        return request;
+    }
+
+}
diff --git a/pom.xml b/pom.xml
index 7d65391..97457f0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -247,7 +247,7 @@
       <dependency>
         <groupId>com.google.guava</groupId>
         <artifactId>guava</artifactId>
-        <version>21.0</version>
+        <version>22.0</version>
       </dependency>
       <dependency>
         <groupId>com.googlecode.json-simple</groupId>
@@ -422,6 +422,11 @@
         <version>${jetty.version}</version>
       </dependency>
       <dependency>
+        <groupId>org.isomorphism</groupId>
+        <artifactId>token-bucket</artifactId>
+        <version>1.6</version>
+      </dependency>
+      <dependency>
         <groupId>org.jolokia</groupId>
         <artifactId>jolokia-jvm</artifactId>
         <version>1.3.1</version>
diff --git a/war/src/main/webapp/WEB-INF/web.xml 
b/war/src/main/webapp/WEB-INF/web.xml
index 91efc26..502085f 100644
--- a/war/src/main/webapp/WEB-INF/web.xml
+++ b/war/src/main/webapp/WEB-INF/web.xml
@@ -84,6 +84,14 @@
   <listener>
    
<listener-class>org.wikidata.query.rdf.blazegraph.WikibaseContextListener</listener-class>
   </listener>
+  <filter>
+      <filter-name>Throttling Filter</filter-name>
+      
<filter-class>org.wikidata.query.rdf.blazegraph.throttling.ThrottlingFilter</filter-class>
+  </filter>
+  <filter-mapping>
+      <filter-name>Throttling Filter</filter-name>
+      <url-pattern>/*</url-pattern>
+  </filter-mapping>
   <servlet>
    <servlet-name>REST API</servlet-name>
    <display-name>REST API</display-name>

-- 
To view, visit https://gerrit.wikimedia.org/r/367664
To unsubscribe, visit https://gerrit.wikimedia.org/r/settings

Gerrit-MessageType: newchange
Gerrit-Change-Id: If3c0c28c47f953fdb7f3b6186da8a9535cc18bdf
Gerrit-PatchSet: 2
Gerrit-Project: wikidata/query/rdf
Gerrit-Branch: master
Gerrit-Owner: Gehel <[email protected]>
Gerrit-Reviewer: Smalyshev <[email protected]>
Gerrit-Reviewer: Volans <[email protected]>
Gerrit-Reviewer: jenkins-bot <>

_______________________________________________
MediaWiki-commits mailing list
[email protected]
https://lists.wikimedia.org/mailman/listinfo/mediawiki-commits

Reply via email to