Author: russellm
Date: 2010-03-07 01:13:55 -0600 (Sun, 07 Mar 2010)
New Revision: 12701

Modified:
   django/trunk/django/db/models/query.py
   django/trunk/tests/regressiontests/multiple_database/tests.py
Log:
Fixed #13003 -- Ensured that ._state.db is set correctly for select_related() 
queries. Thanks to Alex Gaynor for the report.

Modified: django/trunk/django/db/models/query.py
===================================================================
--- django/trunk/django/db/models/query.py      2010-03-07 07:11:22 UTC (rev 12700)
+++ django/trunk/django/db/models/query.py      2010-03-07 07:13:55 UTC (rev 12701)
@@ -267,7 +267,7 @@
         for row in compiler.results_iter():
             if fill_cache:
                 obj, _ = get_cached_row(self.model, row,
-                            index_start, max_depth,
+                            index_start, using=self.db, max_depth=max_depth,
                             requested=requested, offset=len(aggregate_select),
                             only_load=only_load)
             else:
@@ -279,6 +279,9 @@
                     # Omit aggregates in object creation.
                     obj = self.model(*row[index_start:aggregate_start])
 
+                # Store the source database of the object
+                obj._state.db = self.db
+
             for i, k in enumerate(extra_select):
                 setattr(obj, k, row[i])
 
@@ -286,9 +289,6 @@
             for i, aggregate in enumerate(aggregate_select):
                 setattr(obj, aggregate, row[i+aggregate_start])
 
-            # Store the source database of the object
-            obj._state.db = self.db
-
             yield obj
 
     def aggregate(self, *args, **kwargs):
@@ -1112,7 +1112,7 @@
     value_annotation = False
 
 
-def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0,
+def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
                    requested=None, offset=0, only_load=None):
     """
     Helper function that recursively returns an object with the specified
@@ -1126,6 +1126,7 @@
      * row - the row of data returned by the database cursor
      * index_start - the index of the row at which data for this
        object is known to start
+     * using - the database alias on which the query is being executed.
      * max_depth - the maximum depth to which a select_related()
        relationship should be explored.
      * cur_depth - the current depth in the select_related() tree.
@@ -1170,6 +1171,7 @@
             obj = klass(**dict(zip(init_list, fields)))
         else:
             obj = klass(*fields)
+
     else:
         # Load all fields on klass
         field_count = len(klass._meta.fields)
@@ -1182,6 +1184,10 @@
         else:
             obj = klass(*fields)
 
+    # If an object was retrieved, set the database state.
+    if obj:
+        obj._state.db = using
+
     index_end = index_start + field_count + offset
     # Iterate over each related object, populating any
     # select_related() fields
@@ -1193,8 +1199,8 @@
         else:
             next = None
         # Recursively retrieve the data for the related object
-        cached_row = get_cached_row(f.rel.to, row, index_end, max_depth,
-                cur_depth+1, next)
+        cached_row = get_cached_row(f.rel.to, row, index_end, using,
+                max_depth, cur_depth+1, next)
         # If the recursive descent found an object, populate the
         # descriptor caches relevant to the object
         if cached_row:
@@ -1222,8 +1228,8 @@
                 continue
             next = requested[f.related_query_name()]
             # Recursively retrieve the data for the related object
-            cached_row = get_cached_row(model, row, index_end, max_depth,
-                cur_depth+1, next)
+            cached_row = get_cached_row(model, row, index_end, using,
+                max_depth, cur_depth+1, next)
             # If the recursive descent found an object, populate the
             # descriptor caches relevant to the object
             if cached_row:

Modified: django/trunk/tests/regressiontests/multiple_database/tests.py
===================================================================
--- django/trunk/tests/regressiontests/multiple_database/tests.py       2010-03-07 07:11:22 UTC (rev 12700)
+++ django/trunk/tests/regressiontests/multiple_database/tests.py       2010-03-07 07:13:55 UTC (rev 12701)
@@ -641,6 +641,20 @@
         val = Book.objects.raw('SELECT id FROM 
"multiple_database_book"').using('other')
         self.assertEqual(map(lambda o: o.pk, val), [dive.pk])
 
+    def test_select_related(self):
+        "Database assignment is retained if an object is retrieved with 
select_related()"
+        # Create a book and author on the other database
+        mark = Person.objects.using('other').create(name="Mark Pilgrim")
+        dive = Book.objects.using('other').create(title="Dive into Python",
+                                                  
published=datetime.date(2009, 5, 4),
+                                                  editor=mark)
+
+        # Retrieve the Person using select_related()
+        book = 
Book.objects.using('other').select_related('editor').get(title="Dive into 
Python")
+
+        # The editor instance should have a db state
+        self.assertEqual(book.editor._state.db, 'other')
+
 class TestRouter(object):
     # A test router. The behaviour is vaguely master/slave, but the
     # databases aren't assumed to propagate changes.

-- 
You received this message because you are subscribed to the Google Groups 
"Django updates" group.
To post to this group, send email to django-upda...@googlegroups.com.
To unsubscribe from this group, send email to 
django-updates+unsubscr...@googlegroups.com.
For more options, visit this group at 
http://groups.google.com/group/django-updates?hl=en.

Reply via email to