commit 1452e777377d2ecccc4b7bf0dff3d815fbb018b0
Author: Arturo Filastò <a...@fuffa.org>
Date:   Wed Apr 23 16:41:50 2014 +0200

    Set limits on the number of headers that can be sent and their length.
---
 oonib/testhelpers/http_helpers.py |   39 +++++++++++++++++++++++--------------
 1 file changed, 24 insertions(+), 15 deletions(-)

diff --git a/oonib/testhelpers/http_helpers.py b/oonib/testhelpers/http_helpers.py
index 6c3e123..245e08a 100644
--- a/oonib/testhelpers/http_helpers.py
+++ b/oonib/testhelpers/http_helpers.py
@@ -2,18 +2,15 @@ import json
 import random
 import string
 
-from twisted.application import internet, service
-from twisted.internet import protocol, reactor, defer
-from twisted.protocols import basic
-from twisted.web import resource, server, static, http
-from twisted.web.microdom import escape
+from twisted.internet import protocol, defer
 
 from cyclone.web import RequestHandler, Application
 
 from twisted.protocols import policies, basic
 from twisted.web.http import Request
 
-from oonib import randomStr
+from oonib import log
+
 
 class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
     """
@@ -41,6 +38,7 @@ class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
 
     length = 0
     maxHeaders = 500
+    maxHeaderLineLength = 16384
     requestLine = ''
 
     timeOut = 60 * 60 * 12
@@ -53,6 +51,10 @@ class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
         self.setTimeout(self.timeOut)
 
     def lineReceived(self, line):
+        if len(self.__header) >= self.maxHeaderLineLength:
+            log.err("Maximum header length reached.")
+            return self.transport.loseConnection()
+
         if self.__first_line:
             self.requestLine = line
             self.__first_line = 0
@@ -66,7 +68,7 @@ class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
         elif line[0] in ' \t':
             # This is to support header field value folding over multiple lines
             # as specified by rfc2616.
-            self.__header = self.__header+'\n'+line
+            self.__header += '\n'+line
         else:
             if self.__header:
                 self.headerReceived(self.__header)
@@ -80,6 +82,10 @@ class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
             log.err("Got malformed HTTP Header request field")
             log.err("%s" % line)
 
+        if len(self.headers) >= self.maxHeaders:
+            log.err("Maximum number of headers received.")
+            return self.transport.loseConnection()
+
     def allHeadersReceived(self):
         headers_dict = {}
         for k, v in self.headers:
@@ -88,9 +94,9 @@ class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
             headers_dict[k].append(v)
 
         response = {'request_headers': self.headers,
-            'request_line': self.requestLine,
-            'headers_dict': headers_dict
-        }
+                    'request_line': self.requestLine,
+                    'headers_dict': headers_dict
+                    }
         json_response = json.dumps(response)
         self.transport.write('HTTP/1.1 200 OK\r\n\r\n')
         self.transport.write('%s' % json_response)
@@ -99,22 +105,25 @@ class SimpleHTTPChannel(basic.LineReceiver, policies.TimeoutMixin):
 
 class HTTPReturnJSONHeadersHelper(protocol.ServerFactory):
     protocol = SimpleHTTPChannel
+
     def buildProtocol(self, addr):
         return self.protocol()
 
+
 class HTTPTrapAll(RequestHandler):
     def _execute(self, transforms, *args, **kwargs):
         self._transforms = transforms
         defer.maybeDeferred(self.prepare).addCallbacks(
-                    self._execute_handler,
-                    lambda f: self._handle_request_exception(f.value),
-                    callbackArgs=(args, kwargs))
+            self._execute_handler,
+            lambda f: self._handle_request_exception(f.value),
+            callbackArgs=(args, kwargs))
 
     def _execute_handler(self, r, args, kwargs):
         if not self._finished:
             args = [self.decode_argument(arg) for arg in args]
             kwargs = dict((k, self.decode_argument(v, name=k))
-                            for (k, v) in kwargs.iteritems())
+                          for (k, v) in kwargs.iteritems())
+
             # This is where we do the patching
             # XXX this is somewhat hackish
             d = defer.maybeDeferred(self.all, *args, **kwargs)
@@ -130,6 +139,7 @@ class HTTPRandomPage(HTTPTrapAll):
     XXX this is currently disabled as it is not of use to any test.
     """
     isLeaf = True
+
     def _gen_random_string(self, length):
         return ''.join(random.choice(string.letters) for x in range(length))
 
@@ -152,4 +162,3 @@ HTTPRandomPageHelper = Application([
     # XXX add regexps here
     (r"/(.*)/(.*)", HTTPRandomPage)
 ])
-



_______________________________________________
tor-commits mailing list
tor-commits@lists.torproject.org
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to