Author: russellm
Date: 2009-04-30 10:40:09 -0500 (Thu, 30 Apr 2009)
New Revision: 10648

Modified:
   django/trunk/django/db/models/query.py
   django/trunk/django/db/models/sql/query.py
   django/trunk/django/db/models/sql/subqueries.py
   django/trunk/tests/regressiontests/extra_regress/models.py
Log:
Fixed #10847 -- Modified handling of extra() to use a masking strategy, rather
than trimming the extra select columns at the last minute. Thanks to Tai Lee
for the report, and Alex Gaynor for his work on the patch.

This allows querysets with an extra() clause to be used in an __in filter. As a
side effect, as_sql() now returns the correct SQL for any query with an extra()
clause.

Modified: django/trunk/django/db/models/query.py
===================================================================
--- django/trunk/django/db/models/query.py      2009-04-30 13:49:14 UTC (rev 10647)
+++ django/trunk/django/db/models/query.py      2009-04-30 15:40:09 UTC (rev 10648)
@@ -715,9 +715,6 @@
 
     def iterator(self):
         # Purge any extra columns that haven't been explicitly asked for
-        if self.extra_names is not None:
-            self.query.trim_extra_select(self.extra_names)
-
         extra_names = self.query.extra_select.keys()
         field_names = self.field_names
         aggregate_names = self.query.aggregate_select.keys()
@@ -741,13 +738,18 @@
         if self._fields:
             self.extra_names = []
             self.aggregate_names = []
-            if not self.query.extra_select and not self.query.aggregate_select:
+            if not self.query.extra and not self.query.aggregates:
+                # Short cut - if there are no extra or aggregates, then
+                # the values() clause must be just field names.
                 self.field_names = list(self._fields)
             else:
                 self.query.default_cols = False
                 self.field_names = []
                 for f in self._fields:
-                    if self.query.extra_select.has_key(f):
+                    # we inspect the full extra_select list since we might
+                    # be adding back an extra select item that we hadn't
+                    # had selected previously.
+                    if self.query.extra.has_key(f):
                         self.extra_names.append(f)
                     elif self.query.aggregate_select.has_key(f):
                         self.aggregate_names.append(f)
@@ -760,6 +762,8 @@
             self.aggregate_names = None
 
         self.query.select = []
+        if self.extra_names is not None:
+            self.query.set_extra_mask(self.extra_names)
         self.query.add_fields(self.field_names, False)
         if self.aggregate_names is not None:
             self.query.set_aggregate_mask(self.aggregate_names)
@@ -816,9 +820,6 @@
 
 class ValuesListQuerySet(ValuesQuerySet):
     def iterator(self):
-        if self.extra_names is not None:
-            self.query.trim_extra_select(self.extra_names)
-
         if self.flat and len(self._fields) == 1:
             for row in self.query.results_iter():
                 yield row[0]

Modified: django/trunk/django/db/models/sql/query.py
===================================================================
--- django/trunk/django/db/models/sql/query.py  2009-04-30 13:49:14 UTC (rev 10647)
+++ django/trunk/django/db/models/sql/query.py  2009-04-30 15:40:09 UTC (rev 10648)
@@ -88,7 +88,10 @@
 
         # These are for extensions. The contents are more or less appended
         # verbatim to the appropriate clause.
-        self.extra_select = SortedDict()  # Maps col_alias -> (col_sql, 
params).
+        self.extra = SortedDict()  # Maps col_alias -> (col_sql, params).
+        self.extra_select_mask = None
+        self._extra_select_cache = None
+
         self.extra_tables = ()
         self.extra_where = ()
         self.extra_params = ()
@@ -214,13 +217,21 @@
         if self.aggregate_select_mask is None:
             obj.aggregate_select_mask = None
         else:
-            obj.aggregate_select_mask = self.aggregate_select_mask[:]
+            obj.aggregate_select_mask = self.aggregate_select_mask.copy()
         if self._aggregate_select_cache is None:
             obj._aggregate_select_cache = None
         else:
             obj._aggregate_select_cache = self._aggregate_select_cache.copy()
         obj.max_depth = self.max_depth
-        obj.extra_select = self.extra_select.copy()
+        obj.extra = self.extra.copy()
+        if self.extra_select_mask is None:
+            obj.extra_select_mask = None
+        else:
+            obj.extra_select_mask = self.extra_select_mask.copy()
+        if self._extra_select_cache is None:
+            obj._extra_select_cache = None
+        else:
+            obj._extra_select_cache = self._extra_select_cache.copy()
         obj.extra_tables = self.extra_tables
         obj.extra_where = self.extra_where
         obj.extra_params = self.extra_params
@@ -325,7 +336,7 @@
             query = self
             self.select = []
             self.default_cols = False
-            self.extra_select = {}
+            self.extra = {}
             self.remove_inherited_models()
 
         query.clear_ordering(True)
@@ -540,13 +551,20 @@
             # It would be nice to be able to handle this, but the queries don't
             # really make sense (or return consistent value sets). Not worth
             # the extra complexity when you can write a real query instead.
-            if self.extra_select and rhs.extra_select:
+            if self.extra and rhs.extra:
                 raise ValueError("When merging querysets using 'or', you "
                         "cannot have extra(select=...) on both sides.")
             if self.extra_where and rhs.extra_where:
                 raise ValueError("When merging querysets using 'or', you "
                         "cannot have extra(where=...) on both sides.")
-        self.extra_select.update(rhs.extra_select)
+        self.extra.update(rhs.extra)
+        extra_select_mask = set()
+        if self.extra_select_mask is not None:
+            extra_select_mask.update(self.extra_select_mask)
+        if rhs.extra_select_mask is not None:
+            extra_select_mask.update(rhs.extra_select_mask)
+        if extra_select_mask:
+            self.set_extra_mask(extra_select_mask)
         self.extra_tables += rhs.extra_tables
         self.extra_where += rhs.extra_where
         self.extra_params += rhs.extra_params
@@ -2011,7 +2029,7 @@
         except MultiJoin:
             raise FieldError("Invalid field name: '%s'" % name)
         except FieldError:
-            names = opts.get_all_field_names() + self.extra_select.keys() + 
self.aggregate_select.keys()
+            names = opts.get_all_field_names() + self.extra.keys() + 
self.aggregate_select.keys()
             names.sort()
             raise FieldError("Cannot resolve keyword %r into field. "
                     "Choices are: %s" % (name, ", ".join(names)))
@@ -2139,7 +2157,7 @@
                     pos = entry.find("%s", pos + 2)
                 select_pairs[name] = (entry, entry_params)
             # This is order preserving, since self.extra_select is a 
SortedDict.
-            self.extra_select.update(select_pairs)
+            self.extra.update(select_pairs)
         if where:
             self.extra_where += tuple(where)
         if params:
@@ -2213,22 +2231,26 @@
         """
         target[model] = set([f.name for f in fields])
 
-    def trim_extra_select(self, names):
-        """
-        Removes any aliases in the extra_select dictionary that aren't in
-        'names'.
-
-        This is needed if we are selecting certain values that don't incldue
-        all of the extra_select names.
-        """
-        for key in set(self.extra_select).difference(set(names)):
-            del self.extra_select[key]
-
     def set_aggregate_mask(self, names):
         "Set the mask of aggregates that will actually be returned by the 
SELECT"
-        self.aggregate_select_mask = names
+        if names is None:
+            self.aggregate_select_mask = None
+        else:
+            self.aggregate_select_mask = set(names)
         self._aggregate_select_cache = None
 
+    def set_extra_mask(self, names):
+        """
+        Set the mask of extra select items that will be returned by SELECT,
+        we don't actually remove them from the Query since they might be used
+        later
+        """
+        if names is None:
+            self.extra_select_mask = None
+        else:
+            self.extra_select_mask = set(names)
+        self._extra_select_cache = None
+
     def _aggregate_select(self):
         """The SortedDict of aggregate columns that are not masked, and should
         be used in the SELECT clause.
@@ -2247,6 +2269,19 @@
             return self.aggregates
     aggregate_select = property(_aggregate_select)
 
+    def _extra_select(self):
+        if self._extra_select_cache is not None:
+            return self._extra_select_cache
+        elif self.extra_select_mask is not None:
+            self._extra_select_cache = SortedDict([
+                (k,v) for k,v in self.extra.items()
+                if k in self.extra_select_mask
+            ])
+            return self._extra_select_cache
+        else:
+            return self.extra
+    extra_select = property(_extra_select)
+
     def set_start(self, start):
         """
         Sets the table from which to start joining. The start position is

Modified: django/trunk/django/db/models/sql/subqueries.py
===================================================================
--- django/trunk/django/db/models/sql/subqueries.py     2009-04-30 13:49:14 UTC (rev 10647)
+++ django/trunk/django/db/models/sql/subqueries.py     2009-04-30 15:40:09 UTC (rev 10648)
@@ -178,7 +178,7 @@
         # from other tables.
         query = self.clone(klass=Query)
         query.bump_prefix()
-        query.extra_select = {}
+        query.extra = {}
         query.select = []
         query.add_fields([query.model._meta.pk.name])
         must_pre_select = count > 1 and not 
self.connection.features.update_can_self_select
@@ -409,7 +409,7 @@
         self.select = [select]
         self.select_fields = [None]
         self.select_related = False # See #7097.
-        self.extra_select = {}
+        self.extra = {}
         self.distinct = True
         self.order_by = order == 'ASC' and [1] or [-1]
 

Modified: django/trunk/tests/regressiontests/extra_regress/models.py
===================================================================
--- django/trunk/tests/regressiontests/extra_regress/models.py  2009-04-30 13:49:14 UTC (rev 10647)
+++ django/trunk/tests/regressiontests/extra_regress/models.py  2009-04-30 15:40:09 UTC (rev 10648)
@@ -35,6 +35,9 @@
     second = models.CharField(max_length=20)
     third = models.CharField(max_length=20)
 
+    def __unicode__(self):
+        return u'TestObject: %s,%s,%s' % (self.first,self.second,self.third)
+
 __test__ = {"API_TESTS": """
 # Regression tests for #7314 and #7372
 
@@ -189,6 +192,19 @@
 >>> TestObject.objects.extra(select=SortedDict((('foo','first'),('bar','second'),('whiz','third')))).values_list('whiz',
 >>>  'first', 'bar', 'id')
 [(u'third', u'first', u'second', 1)]
 
-"""}
+# Regression for #10847: the list of extra columns can always be accurately 
evaluated.
+# Using an inner query ensures that as_sql() is producing correct output
+# without requiring full evaluation and execution of the inner query.
+>>> TestObject.objects.extra(select={'extra': 1}).values('pk')
+[{'pk': 1}]
 
+>>> TestObject.objects.filter(pk__in=TestObject.objects.extra(select={'extra': 
1}).values('pk'))
+[<TestObject: TestObject: first,second,third>]
 
+>>> TestObject.objects.values('pk').extra(select={'extra': 1})
+[{'pk': 1}]
+
+>>> 
TestObject.objects.filter(pk__in=TestObject.objects.values('pk').extra(select={'extra':
 1}))
+[<TestObject: TestObject: first,second,third>]
+
+"""}


--~--~---------~--~----~------------~-------~--~----~
You received this message because you are subscribed to the Google Groups 
"Django updates" group.
To post to this group, send email to django-updates@googlegroups.com
To unsubscribe from this group, send email to 
django-updates+unsubscr...@googlegroups.com
For more options, visit this group at 
http://groups.google.com/group/django-updates?hl=en
-~----------~----~----~----~------~----~------~--~---

Reply via email to