Author: russellm
Date: 2010-10-12 18:37:47 -0500 (Tue, 12 Oct 2010)
New Revision: 14191

Modified:
   django/trunk/django/test/__init__.py
   django/trunk/django/test/client.py
   django/trunk/docs/topics/testing.txt
   django/trunk/tests/modeltests/test_client/models.py
Log:
Fixed #9002 -- Added a RequestFactory. This allows you to create request 
instances so you can unit test views as standalone functions. Thanks to Simon 
Willison for the suggestion and snippet on which this patch was originally 
based.

Modified: django/trunk/django/test/__init__.py
===================================================================
--- django/trunk/django/test/__init__.py        2010-10-12 17:27:07 UTC (rev 14190)
+++ django/trunk/django/test/__init__.py        2010-10-12 23:37:47 UTC (rev 14191)
@@ -2,6 +2,6 @@
 Django Unit Test and Doctest framework.
 """
 
-from django.test.client import Client
+from django.test.client import Client, RequestFactory
 from django.test.testcases import TestCase, TransactionTestCase, skipIfDBFeature, skipUnlessDBFeature
 from django.test.utils import Approximate

Modified: django/trunk/django/test/client.py
===================================================================
--- django/trunk/django/test/client.py  2010-10-12 17:27:07 UTC (rev 14190)
+++ django/trunk/django/test/client.py  2010-10-12 23:37:47 UTC (rev 14191)
@@ -156,8 +156,166 @@
         file.read()
     ]
 
-class Client(object):
+
+
+class RequestFactory(object):
     """
+    Class that lets you create mock Request objects for use in testing.
+
+    Usage:
+
+    rf = RequestFactory()
+    get_request = rf.get('/hello/')
+    post_request = rf.post('/submit/', {'foo': 'bar'})
+
+    Once you have a request object you can pass it to any view function,
+    just as if that view had been hooked up using a URLconf.
+    """
+    def __init__(self, **defaults):
+        self.defaults = defaults
+        self.cookies = SimpleCookie()
+        self.errors = StringIO()
+
+    def _base_environ(self, **request):
+        """
+        The base environment for a request.
+        """
+        environ = {
+            'HTTP_COOKIE':       self.cookies.output(header='', sep='; '),
+            'PATH_INFO':         '/',
+            'QUERY_STRING':      '',
+            'REMOTE_ADDR':       '127.0.0.1',
+            'REQUEST_METHOD':    'GET',
+            'SCRIPT_NAME':       '',
+            'SERVER_NAME':       'testserver',
+            'SERVER_PORT':       '80',
+            'SERVER_PROTOCOL':   'HTTP/1.1',
+            'wsgi.version':      (1,0),
+            'wsgi.url_scheme':   'http',
+            'wsgi.errors':       self.errors,
+            'wsgi.multiprocess': True,
+            'wsgi.multithread':  False,
+            'wsgi.run_once':     False,
+        }
+        environ.update(self.defaults)
+        environ.update(request)
+        return environ
+
+    def request(self, **request):
+        "Construct a generic request object."
+        return WSGIRequest(self._base_environ(**request))
+
+    def get(self, path, data={}, **extra):
+        "Construct a GET request"
+
+        parsed = urlparse(path)
+        r = {
+            'CONTENT_TYPE':    'text/html; charset=utf-8',
+            'PATH_INFO':       urllib.unquote(parsed[2]),
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
+            'REQUEST_METHOD': 'GET',
+            'wsgi.input':      FakePayload('')
+        }
+        r.update(extra)
+        return self.request(**r)
+
+    def post(self, path, data={}, content_type=MULTIPART_CONTENT,
+             **extra):
+        "Construct a POST request."
+
+        if content_type is MULTIPART_CONTENT:
+            post_data = encode_multipart(BOUNDARY, data)
+        else:
+            # Encode the content so that the byte representation is correct.
+            match = CONTENT_TYPE_RE.match(content_type)
+            if match:
+                charset = match.group(1)
+            else:
+                charset = settings.DEFAULT_CHARSET
+            post_data = smart_str(data, encoding=charset)
+
+        parsed = urlparse(path)
+        r = {
+            'CONTENT_LENGTH': len(post_data),
+            'CONTENT_TYPE':   content_type,
+            'PATH_INFO':      urllib.unquote(parsed[2]),
+            'QUERY_STRING':   parsed[4],
+            'REQUEST_METHOD': 'POST',
+            'wsgi.input':     FakePayload(post_data),
+        }
+        r.update(extra)
+        return self.request(**r)
+
+    def head(self, path, data={}, **extra):
+        "Construct a HEAD request."
+
+        parsed = urlparse(path)
+        r = {
+            'CONTENT_TYPE':    'text/html; charset=utf-8',
+            'PATH_INFO':       urllib.unquote(parsed[2]),
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
+            'REQUEST_METHOD': 'HEAD',
+            'wsgi.input':      FakePayload('')
+        }
+        r.update(extra)
+        return self.request(**r)
+
+    def options(self, path, data={}, **extra):
+        "Construct an OPTIONS request."
+
+        parsed = urlparse(path)
+        r = {
+            'PATH_INFO':       urllib.unquote(parsed[2]),
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
+            'REQUEST_METHOD': 'OPTIONS',
+            'wsgi.input':      FakePayload('')
+        }
+        r.update(extra)
+        return self.request(**r)
+
+    def put(self, path, data={}, content_type=MULTIPART_CONTENT,
+            **extra):
+        "Construct a PUT request."
+
+        if content_type is MULTIPART_CONTENT:
+            post_data = encode_multipart(BOUNDARY, data)
+        else:
+            post_data = data
+
+        # Make `data` into a querystring only if it's not already a string. If
+        # it is a string, we'll assume that the caller has already encoded it.
+        query_string = None
+        if not isinstance(data, basestring):
+            query_string = urlencode(data, doseq=True)
+
+        parsed = urlparse(path)
+        r = {
+            'CONTENT_LENGTH': len(post_data),
+            'CONTENT_TYPE':   content_type,
+            'PATH_INFO':      urllib.unquote(parsed[2]),
+            'QUERY_STRING':   query_string or parsed[4],
+            'REQUEST_METHOD': 'PUT',
+            'wsgi.input':     FakePayload(post_data),
+        }
+        r.update(extra)
+        return self.request(**r)
+
+    def delete(self, path, data={}, **extra):
+        "Construct a DELETE request."
+
+        parsed = urlparse(path)
+        r = {
+            'PATH_INFO':       urllib.unquote(parsed[2]),
+            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
+            'REQUEST_METHOD': 'DELETE',
+            'wsgi.input':      FakePayload('')
+        }
+        r.update(extra)
+        return self.request(**r)
+
+
+class Client(RequestFactory):
+    """
     A class that can act as a client for testing purposes.
 
     It allows the user to compose GET and POST requests, and
@@ -175,11 +333,9 @@
     HTML rendered to the end-user.
     """
     def __init__(self, enforce_csrf_checks=False, **defaults):
+        super(Client, self).__init__(**defaults)
         self.handler = ClientHandler(enforce_csrf_checks)
-        self.defaults = defaults
-        self.cookies = SimpleCookie()
         self.exc_info = None
-        self.errors = StringIO()
 
     def store_exc_info(self, **kwargs):
         """
@@ -199,6 +355,7 @@
         return {}
     session = property(_session)
 
+
     def request(self, **request):
         """
         The master request method. Composes the environment dictionary
@@ -206,25 +363,7 @@
         Assumes defaults for the query environment, which can be overridden
         using the arguments to the request.
         """
-        environ = {
-            'HTTP_COOKIE':       self.cookies.output(header='', sep='; '),
-            'PATH_INFO':         '/',
-            'QUERY_STRING':      '',
-            'REMOTE_ADDR':       '127.0.0.1',
-            'REQUEST_METHOD':    'GET',
-            'SCRIPT_NAME':       '',
-            'SERVER_NAME':       'testserver',
-            'SERVER_PORT':       '80',
-            'SERVER_PROTOCOL':   'HTTP/1.1',
-            'wsgi.version':      (1,0),
-            'wsgi.url_scheme':   'http',
-            'wsgi.errors':       self.errors,
-            'wsgi.multiprocess': True,
-            'wsgi.multithread':  False,
-            'wsgi.run_once':     False,
-        }
-        environ.update(self.defaults)
-        environ.update(request)
+        environ = self._base_environ(**request)
 
         # Curry a data dictionary into an instance of the template renderer
         # callback function.
@@ -290,22 +429,11 @@
             signals.template_rendered.disconnect(dispatch_uid="template-render")
             got_request_exception.disconnect(dispatch_uid="request-exception")
 
-
     def get(self, path, data={}, follow=False, **extra):
         """
         Requests a response from the server using GET.
         """
-        parsed = urlparse(path)
-        r = {
-            'CONTENT_TYPE':    'text/html; charset=utf-8',
-            'PATH_INFO':       urllib.unquote(parsed[2]),
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
-            'REQUEST_METHOD': 'GET',
-            'wsgi.input':      FakePayload('')
-        }
-        r.update(extra)
-
-        response = self.request(**r)
+        response = super(Client, self).get(path, data=data, **extra)
         if follow:
             response = self._handle_redirects(response, **extra)
         return response
@@ -315,29 +443,7 @@
         """
         Requests a response from the server using POST.
         """
-        if content_type is MULTIPART_CONTENT:
-            post_data = encode_multipart(BOUNDARY, data)
-        else:
-            # Encode the content so that the byte representation is correct.
-            match = CONTENT_TYPE_RE.match(content_type)
-            if match:
-                charset = match.group(1)
-            else:
-                charset = settings.DEFAULT_CHARSET
-            post_data = smart_str(data, encoding=charset)
-
-        parsed = urlparse(path)
-        r = {
-            'CONTENT_LENGTH': len(post_data),
-            'CONTENT_TYPE':   content_type,
-            'PATH_INFO':      urllib.unquote(parsed[2]),
-            'QUERY_STRING':   parsed[4],
-            'REQUEST_METHOD': 'POST',
-            'wsgi.input':     FakePayload(post_data),
-        }
-        r.update(extra)
-
-        response = self.request(**r)
+        response = super(Client, self).post(path, data=data, content_type=content_type, **extra)
         if follow:
             response = self._handle_redirects(response, **extra)
         return response
@@ -346,17 +452,7 @@
         """
         Request a response from the server using HEAD.
         """
-        parsed = urlparse(path)
-        r = {
-            'CONTENT_TYPE':    'text/html; charset=utf-8',
-            'PATH_INFO':       urllib.unquote(parsed[2]),
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
-            'REQUEST_METHOD': 'HEAD',
-            'wsgi.input':      FakePayload('')
-        }
-        r.update(extra)
-
-        response = self.request(**r)
+        response = super(Client, self).head(path, data=data, **extra)
         if follow:
             response = self._handle_redirects(response, **extra)
         return response
@@ -365,16 +461,7 @@
         """
         Request a response from the server using OPTIONS.
         """
-        parsed = urlparse(path)
-        r = {
-            'PATH_INFO':       urllib.unquote(parsed[2]),
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
-            'REQUEST_METHOD': 'OPTIONS',
-            'wsgi.input':      FakePayload('')
-        }
-        r.update(extra)
-
-        response = self.request(**r)
+        response = super(Client, self).options(path, data=data, **extra)
         if follow:
             response = self._handle_redirects(response, **extra)
         return response
@@ -384,29 +471,7 @@
         """
         Send a resource to the server using PUT.
         """
-        if content_type is MULTIPART_CONTENT:
-            post_data = encode_multipart(BOUNDARY, data)
-        else:
-            post_data = data
-
-        # Make `data` into a querystring only if it's not already a string. If
-        # it is a string, we'll assume that the caller has already encoded it.
-        query_string = None
-        if not isinstance(data, basestring):
-            query_string = urlencode(data, doseq=True)
-
-        parsed = urlparse(path)
-        r = {
-            'CONTENT_LENGTH': len(post_data),
-            'CONTENT_TYPE':   content_type,
-            'PATH_INFO':      urllib.unquote(parsed[2]),
-            'QUERY_STRING':   query_string or parsed[4],
-            'REQUEST_METHOD': 'PUT',
-            'wsgi.input':     FakePayload(post_data),
-        }
-        r.update(extra)
-
-        response = self.request(**r)
+        response = super(Client, self).put(path, data=data, content_type=content_type, **extra)
         if follow:
             response = self._handle_redirects(response, **extra)
         return response
@@ -415,23 +480,14 @@
         """
         Send a DELETE request to the server.
         """
-        parsed = urlparse(path)
-        r = {
-            'PATH_INFO':       urllib.unquote(parsed[2]),
-            'QUERY_STRING':    urlencode(data, doseq=True) or parsed[4],
-            'REQUEST_METHOD': 'DELETE',
-            'wsgi.input':      FakePayload('')
-        }
-        r.update(extra)
-
-        response = self.request(**r)
+        response = super(Client, self).delete(path, data=data, **extra)
         if follow:
             response = self._handle_redirects(response, **extra)
         return response
 
     def login(self, **credentials):
         """
         Sets the Client to appear as if it has successfully logged into a site.
 
         Returns True if login is possible; False if the provided credentials
         are incorrect, or the user is inactive, or if the sessions framework is
@@ -506,4 +562,3 @@
             if response.redirect_chain[-1] in response.redirect_chain[0:-1]:
                 break
         return response
-

Modified: django/trunk/docs/topics/testing.txt
===================================================================
--- django/trunk/docs/topics/testing.txt        2010-10-12 17:27:07 UTC (rev 14190)
+++ django/trunk/docs/topics/testing.txt        2010-10-12 23:37:47 UTC (rev 14191)
@@ -1014,6 +1014,51 @@
             # Check that the rendered context contains 5 customers.
             self.assertEqual(len(response.context['customers']), 5)
 
+The request factory
+-------------------
+
+.. class:: RequestFactory
+
+The :class:`~django.test.client.RequestFactory` is a simplified
+version of the test client that provides a way to generate a request
+instance that can be used as the first argument to any view. This
+means you can test a view function the same way as you would test any
+other function -- as a black box, with exactly known inputs, testing
+for specific outputs.
+
+The API for the :class:`~django.test.client.RequestFactory` is a slightly
+restricted subset of the test client API:
+
+    * It only has access to the HTTP methods :meth:`~Client.get()`,
+      :meth:`~Client.post()`, :meth:`~Client.put()`,
+      :meth:`~Client.delete()`, :meth:`~Client.head()` and
+      :meth:`~Client.options()`.
+
+    * These methods accept all the same arguments *except* for
+      ``follow``. Since this is just a factory for producing
+      requests, it's up to you to handle the response.
+
+Example
+~~~~~~~
+
+The following is a simple unit test using the request factory::
+
+    from django.utils import unittest
+    from django.test.client import RequestFactory
+
+    class SimpleTest(unittest.TestCase):
+        def setUp(self):
+            # Every test needs access to the request factory.
+            self.factory = RequestFactory()
+
+        def test_details(self):
+            # Issue a GET request.
+            request = self.factory.get('/customer/details')
+
+            # Test my_view() as if it were deployed at /customer/details
+            response = my_view(request)
+            self.assertEqual(response.status_code, 200)
+
 TestCase
 --------
 

Modified: django/trunk/tests/modeltests/test_client/models.py
===================================================================
--- django/trunk/tests/modeltests/test_client/models.py 2010-10-12 17:27:07 UTC (rev 14190)
+++ django/trunk/tests/modeltests/test_client/models.py 2010-10-12 23:37:47 UTC (rev 14191)
@@ -20,10 +20,13 @@
 rather than the HTML rendered to the end-user.
 
 """
-from django.test import Client, TestCase
 from django.conf import settings
 from django.core import mail
+from django.test import Client, TestCase, RequestFactory
 
+from views import get_view
+
+
 class ClientTest(TestCase):
     fixtures = ['testdata.json']
 
@@ -469,3 +472,12 @@
         """A test case can specify a custom class for self.client."""
         self.assertEqual(hasattr(self.client, "i_am_customized"), True)
 
+
+class RequestFactoryTest(TestCase):
+    def test_request_factory(self):
+        factory = RequestFactory()
+        request = factory.get('/somewhere/')
+        response = get_view(request)
+
+        self.assertEqual(response.status_code, 200)
+        self.assertContains(response, 'This is a test')

-- 
You received this message because you are subscribed to the Google Groups 
"Django updates" group.
To post to this group, send email to django-upda...@googlegroups.com.
To unsubscribe from this group, send email to 
django-updates+unsubscr...@googlegroups.com.
For more options, visit this group at 
http://groups.google.com/group/django-updates?hl=en.

Reply via email to