New code for performing explicit joins with custom join conditions. * added ExtendedManager.join_custom_field(), which uses the introspection magic from populate_relationships (now factored out) to infer the type of relationship between two models and construct the correct join. join_custom_field() presents a much simpler, more Django-y interface for doing this sort of thing -- compare with add_join() above it. * changed TKO custom fields code to use join_custom_field() * added some cases to AFE rpc_interface_unittest to ensure populate_relationships() usage didn't break * simplified _CustomQuery and got rid of _CustomSqlQ. _CustomQuery can do the work itself and it's cleaner this way. * added add_where(), an alternative to extra(where=...) that fits more into Django's normal representation of WHERE clauses, and therefore supports & and | operators later.
Signed-off-by: Steve Howard <[email protected]> --- autotest/frontend/afe/model_logic.py 2010-01-19 11:34:44.000000000 -0800 +++ autotest/frontend/afe/model_logic.py 2010-01-19 11:34:44.000000000 -0800 @@ -6,6 +6,7 @@ import django.core.exceptions from django.db import models as dbmodels, backend, connection from django.db.models.sql import query +import django.db.models.sql.where from django.utils import datastructures from autotest_lib.frontend.afe import readonly_connection @@ -94,77 +95,97 @@ """ class _CustomQuery(query.Query): - def clone(self, klass=None, **kwargs): - obj = super(ExtendedManager._CustomQuery, self).clone( - klass, _customSqlQ=self._customSqlQ) + def __init__(self, *args, **kwargs): + super(ExtendedManager._CustomQuery, self).__init__(*args, **kwargs) + self._custom_joins = [] - customQ = kwargs.get('_customSqlQ', None) - if customQ is not None: - obj._customSqlQ._joins.update(customQ._joins) - obj._customSqlQ._where.extend(customQ._where) - obj._customSqlQ._params.extend(customQ._params) + def clone(self, klass=None, **kwargs): + obj = super(ExtendedManager._CustomQuery, self).clone(klass) + obj._custom_joins = list(self._custom_joins) return obj - def get_from_clause(self): - from_, params = super( - ExtendedManager._CustomQuery, self).get_from_clause() - - join_clause = '' - for join_alias, join in self._customSqlQ._joins.iteritems(): - join_table, join_type, condition = join - join_clause += ' %s %s AS %s ON (%s)' % ( - join_type, _quote_name(join_table), - _quote_name(join_alias), condition) - if join_clause: - from_.append(join_clause) + def combine(self, rhs, connector): + super(ExtendedManager._CustomQuery, self).combine(rhs, connector) + if hasattr(rhs, '_custom_joins'): + self._custom_joins.extend(rhs._custom_joins) - return from_, params + def add_custom_join(self, table, condition, join_type, + condition_values=(), alias=None): + if alias is None: + alias = table + join_dict = dict(table=table, + condition=condition, + 
condition_values=condition_values, + join_type=join_type, + alias=alias) + self._custom_joins.append(join_dict) - class _CustomSqlQ(dbmodels.Q): - def __init__(self): - self._joins = datastructures.SortedDict() - self._where, self._params = [], [] + def get_from_clause(self): + from_, params = (super(ExtendedManager._CustomQuery, self) + .get_from_clause()) - def add_join(self, table, condition, join_type, alias=None): - if alias is None: - alias = table - self._joins[alias] = (table, join_type, condition) + for join_dict in self._custom_joins: + from_.append('%s %s AS %s ON (%s)' + % (join_dict['join_type'], + _quote_name(join_dict['table']), + _quote_name(join_dict['alias']), + join_dict['condition'])) + params.extend(join_dict['condition_values']) + return from_, params - def add_where(self, where, params=[]): - self._where.append(where) - self._params.extend(params) + @classmethod + def convert_query(self, query_set): + """ + Convert the query set's "query" attribute to a _CustomQuery. + """ + # Make a copy of the query set + query_set = query_set.all() + query_set.query = query_set.query.clone( + klass=ExtendedManager._CustomQuery, + _custom_joins=[]) + return query_set - def add_to_query(self, query, aliases): - if self._where: - where = ' AND '.join(self._where) - query.add_extra(None, None, (where,), self._params, None, None) + class _WhereClause(object): + """Object allowing us to inject arbitrary SQL into Django queries. - def _add_customSqlQ(self, query_set, filter_object): - """\ - Add a _CustomSqlQ to the query set. + By using this instead of extra(where=...), we can still freely combine + queries with & and |. 
""" - # Make a copy of the query set - query_set = query_set.all() + def __init__(self, clause, values=()): + self._clause = clause + self._values = values - query_set.query = query_set.query.clone( - ExtendedManager._CustomQuery, _customSqlQ=filter_object) - return query_set.filter(filter_object) + + def as_sql(self, qn=None): + return self._clause, self._values + + + def relabel_aliases(self, change_map): + return def add_join(self, query_set, join_table, join_key, join_condition='', - alias=None, suffix='', exclude=False, force_left_join=False): - """ - Add a join to query_set. + join_condition_values=(), join_from_key=None, alias=None, + suffix='', exclude=False, force_left_join=False): + """Add a join to query_set. + + Join looks like this: + (INNER|LEFT) JOIN <join_table> AS <alias> + ON (<this table>.<join_from_key> = <join_table>.<join_key> + and <join_condition>) + @param join_table table to join to @param join_key field referencing back to this model to use for the join @param join_condition extra condition for the ON clause of the join + @param join_condition_values values to substitute into join_condition + @param join_from_key column on this model to join from. @param alias alias to use for for join @param suffix suffix to add to join_table for the join alias, if no alias is provided @@ -173,15 +194,15 @@ @param force_left_join - if true, a LEFT OUTER JOIN will be used instead of an INNER JOIN regardless of other options """ - join_from_table = _quote_name(self.model._meta.db_table) - join_from_key = _quote_name(self.model._meta.pk.name) - if alias: - join_alias = alias - else: - join_alias = join_table + suffix - full_join_key = _quote_name(join_alias) + '.' 
+ _quote_name(join_key) - full_join_condition = '%s = %s.%s' % (full_join_key, join_from_table, - join_from_key) + join_from_table = query_set.model._meta.db_table + if join_from_key is None: + join_from_key = self.model._meta.pk.name + if alias is None: + alias = join_table + suffix + full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) + full_join_condition = '%s = %s.%s' % (full_join_key, + _quote_name(join_from_table), + _quote_name(join_from_key)) if join_condition: full_join_condition += ' AND (' + join_condition + ')' if exclude or force_left_join: @@ -189,15 +210,128 @@ else: join_type = query_set.query.INNER - filter_object = self._CustomSqlQ() - filter_object.add_join(join_table, - full_join_condition, - join_type, - alias=join_alias) + query_set = self._CustomQuery.convert_query(query_set) + query_set.query.add_custom_join(join_table, + full_join_condition, + join_type, + condition_values=join_condition_values, + alias=alias) + if exclude: - filter_object.add_where(full_join_key + ' IS NULL') + query_set = query_set.extra(where=[full_join_key + ' IS NULL']) - query_set = self._add_customSqlQ(query_set, filter_object) + return query_set + + + def _info_for_many_to_one_join(self, field, join_to_query, alias): + """ + @param field: the ForeignKey field on the related model + @param join_to_query: the query over the related model that we're + joining to + @param alias: alias of joined table + """ + info = {} + rhs_table = join_to_query.model._meta.db_table + info['rhs_table'] = rhs_table + info['rhs_column'] = field.column + info['lhs_column'] = field.rel.get_related_field().column + rhs_where = join_to_query.query.where + rhs_where.relabel_aliases({rhs_table: alias}) + initial_clause, values = rhs_where.as_sql() + all_clauses = (initial_clause,) + join_to_query.query.extra_where + info['where_clause'] = ' AND '.join('(%s)' % clause + for clause in all_clauses) + values += join_to_query.query.extra_params + info['values'] = values + return info + 
+ + def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, + m2m_is_on_this_model): + """ + @param m2m_field: a Django field representing the M2M relationship. + It uses a pivot table with the following structure: + this model table <---> M2M pivot table <---> joined model table + @param join_to_query: the query over the related model that we're + joining to. + @param alias: alias of joined table + """ + if m2m_is_on_this_model: + # referenced field on this model + lhs_id_field = self.model._meta.pk + # foreign key on the pivot table referencing lhs_id_field + m2m_lhs_column = m2m_field.m2m_column_name() + # foreign key on the pivot table referencing rhd_id_field + m2m_rhs_column = m2m_field.m2m_reverse_name() + # referenced field on related model + rhs_id_field = m2m_field.rel.get_related_field() + else: + lhs_id_field = m2m_field.rel.get_related_field() + m2m_lhs_column = m2m_field.m2m_reverse_name() + m2m_rhs_column = m2m_field.m2m_column_name() + rhs_id_field = join_to_query.model._meta.pk + + info = {} + info['rhs_table'] = m2m_field.m2m_db_table() + info['rhs_column'] = m2m_lhs_column + info['lhs_column'] = lhs_id_field.column + + # select the ID of related models relevant to this join. we can only do + # a single join, so we need to gather this information up front and + # include it in the join condition. + rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) + assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' + 'match a single related object.') + rhs_id = rhs_ids[0] + + info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), + _quote_name(m2m_rhs_column), + rhs_id) + info['values'] = () + return info + + + def join_custom_field(self, query_set, join_to_query, alias, + left_join=True): + """Join to a related model to create a custom field in the given query. + + This method is used to construct a custom field on the given query based + on a many-valued relationsip. 
join_to_query should be a simple query + (no joins) on the related model which returns at most one related row + per instance of this model. + + For many-to-one relationships, the joined table contains the matching + row from the related model it one is related, NULL otherwise. + + For many-to-many relationships, the joined table contains the matching + row if it's related, NULL otherwise. + """ + relationship_type, field = self.determine_relationship( + join_to_query.model) + + if relationship_type == self.MANY_TO_ONE: + info = self._info_for_many_to_one_join(field, join_to_query, alias) + elif relationship_type == self.M2M_ON_RELATED_MODEL: + info = self._info_for_many_to_many_join( + m2m_field=field, join_to_query=join_to_query, alias=alias, + m2m_is_on_this_model=False) + elif relationship_type ==self.M2M_ON_THIS_MODEL: + info = self._info_for_many_to_many_join( + m2m_field=field, join_to_query=join_to_query, alias=alias, + m2m_is_on_this_model=True) + + return self.add_join(query_set, info['rhs_table'], info['rhs_column'], + join_from_key=info['lhs_column'], + join_condition=info['where_clause'], + join_condition_values=info['values'], + alias=alias, + force_left_join=left_join) + + + def add_where(self, query_set, where, values=()): + query_set = query_set.all() + query_set.query.where.add(self._WhereClause(where, values), + django.db.models.sql.where.AND) return query_set @@ -235,43 +369,65 @@ return field.rel and field.rel.to is model_class - def _get_pivot_iterator(self, base_objects_by_id, related_model): + MANY_TO_ONE = object() + M2M_ON_RELATED_MODEL = object() + M2M_ON_THIS_MODEL = object() + + def determine_relationship(self, related_model): """ - Determine the relationship between this model and related_model, and - return a pivot iterator. 
- @param base_objects_by_id: dict of instances of this model indexed by - their IDs - @returns a pivot iterator, which yields a tuple (base_object, - related_object) for each relationship between a base object and a - related object. all base_object instances come from base_objects_by_id. - Note -- this depends on Django model internals and will likely need to - be updated when we move to Django 1.x. + Determine the relationship between this model and related_model. + + related_model must have some sort of many-valued relationship to this + manager's model. + @returns (relationship_type, field), where relationship_type is one of + MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field + is the Django field object for the relationship. """ - # look for a field on related_model relating to this model + # look for a foreign key field on related_model relating to this model for field in related_model._meta.fields: if self._is_relation_to(field, self.model): - # many-to-one - return self._many_to_one_pivot(base_objects_by_id, - related_model, field) + return self.MANY_TO_ONE, field + # look for an M2M field on related_model relating to this model for field in related_model._meta.many_to_many: if self._is_relation_to(field, self.model): - # many-to-many - return self._many_to_many_pivot( - base_objects_by_id, related_model, field.m2m_db_table(), - field.m2m_reverse_name(), field.m2m_column_name()) + return self.M2M_ON_RELATED_MODEL, field # maybe this model has the many-to-many field for field in self.model._meta.many_to_many: if self._is_relation_to(field, related_model): - return self._many_to_many_pivot( - base_objects_by_id, related_model, field.m2m_db_table(), - field.m2m_column_name(), field.m2m_reverse_name()) + return self.M2M_ON_THIS_MODEL, field raise ValueError('%s has no relation to %s' % (related_model, self.model)) + def _get_pivot_iterator(self, base_objects_by_id, related_model): + """ + Determine the relationship between this model and 
related_model, and + return a pivot iterator. + @param base_objects_by_id: dict of instances of this model indexed by + their IDs + @returns a pivot iterator, which yields a tuple (base_object, + related_object) for each relationship between a base object and a + related object. all base_object instances come from base_objects_by_id. + Note -- this depends on Django model internals. + """ + relationship_type, field = self.determine_relationship(related_model) + if relationship_type == self.MANY_TO_ONE: + return self._many_to_one_pivot(base_objects_by_id, + related_model, field) + elif relationship_type == self.M2M_ON_RELATED_MODEL: + return self._many_to_many_pivot( + base_objects_by_id, related_model, field.m2m_db_table(), + field.m2m_reverse_name(), field.m2m_column_name()) + else: + assert relationship_type == self.M2M_ON_THIS_MODEL + return self._many_to_many_pivot( + base_objects_by_id, related_model, field.m2m_db_table(), + field.m2m_column_name(), field.m2m_reverse_name()) + + def _many_to_one_pivot(self, base_objects_by_id, related_model, foreign_key_field): """ --- autotest/frontend/afe/rpc_interface_unittest.py 2010-01-19 11:34:44.000000000 -0800 +++ autotest/frontend/afe/rpc_interface_unittest.py 2010-01-19 11:34:44.000000000 -0800 @@ -60,6 +60,12 @@ hosts = rpc_interface.get_hosts(hostname='host1') self._check_hostnames(hosts, ['host1']) + host = hosts[0] + self.assertEquals(sorted(host['labels']), ['label1', 'myplatform']) + self.assertEquals(host['platform'], 'myplatform') + self.assertEquals(host['atomic_group'], None) + self.assertEquals(host['acls'], ['my_acl']) + self.assertEquals(host['attributes'], {}) def test_get_hosts_multiple_labels(self): --- autotest/frontend/tko/models.py 2010-01-19 11:34:44.000000000 -0800 +++ autotest/frontend/tko/models.py 2010-01-19 11:34:44.000000000 -0800 @@ -327,12 +327,11 @@ second_join_condition = ('%s.id = %s.testlabel_id' % (second_join_alias, 'tko_test_labels_tests' + suffix)) - filter_object = 
self._CustomSqlQ() - filter_object.add_join('tko_test_labels', - second_join_condition, - query_set.query.LOUTER, - alias=second_join_alias) - return self._add_customSqlQ(query_set, filter_object) + query_set.query.add_custom_join('tko_test_labels', + second_join_condition, + query_set.query.LOUTER, + alias=second_join_alias) + return query_set def _get_label_ids_from_names(self, label_names): @@ -373,12 +372,10 @@ def _join_label_column(self, query_set, label_name, label_id): - table_name = TestLabel.tests.field.m2m_db_table() alias = 'label_' + label_name - condition = "%s.testlabel_id = %s" % (_quote_name(alias), label_id) - query_set = self.add_join(query_set, table_name, - join_key='test_id', join_condition=condition, - alias=alias, force_left_join=True) + label_query = TestLabel.objects.filter(name=label_name) + query_set = Test.objects.join_custom_field(query_set, label_query, + alias) query_set = self._add_select_ifnull(query_set, alias, label_name) return query_set @@ -392,23 +389,21 @@ return query_set - def _join_attribute(self, test_view_query_set, attribute, - alias=None, extra_join_condition=None): + def _join_attribute(self, query_set, attribute, alias=None, + extra_join_condition=None): """ Join the given TestView QuerySet to TestAttribute. The resulting query has an additional column for the given attribute named "attribute_<attribute name>". 
""" - table_name = TestAttribute._meta.db_table if not alias: alias = 'attribute_' + attribute - condition = "%s.attribute = '%s'" % (_quote_name(alias), - self.escape_user_sql(attribute)) + attribute_query = TestAttribute.objects.filter(attribute=attribute) if extra_join_condition: - condition += ' AND (%s)' % extra_join_condition - query_set = self.add_join(test_view_query_set, table_name, - join_key='test_idx', join_condition=condition, - alias=alias, force_left_join=True) + attribute_query = attribute_query.extra( + where=[extra_join_condition]) + query_set = Test.objects.join_custom_field(query_set, attribute_query, + alias) query_set = self._add_select_value(query_set, alias) return query_set @@ -427,23 +422,18 @@ def _join_one_iteration_key(self, query_set, result_key, first_alias=None): - table_name = IterationResult._meta.db_table alias = 'iteration_' + result_key - condition_parts = ["%s.attribute = '%s'" % - (_quote_name(alias), - self.escape_user_sql(result_key))] + iteration_query = IterationResult.objects.filter(attribute=result_key) if first_alias: # after the first join, we need to match up iteration indices, # otherwise each join will expand the query by the number of # iterations and we'll have extraneous rows - condition_parts.append('%s.iteration = %s.iteration' % - (_quote_name(alias), - _quote_name(first_alias))) - - condition = ' and '.join(condition_parts) - # add a join to IterationResult - query_set = self.add_join(query_set, table_name, join_key='test_idx', - join_condition=condition, alias=alias) + iteration_query = iteration_query.extra( + where=['%s.iteration = %s.iteration' + % (_quote_name(alias), _quote_name(first_alias))]) + + query_set = Test.objects.join_custom_field(query_set, iteration_query, + alias, left_join=False) # select the iteration value and index for this join query_set = self._add_select_value(query_set, alias) if not first_alias: --- autotest/frontend/tko/rpc_interface_unittest.py 2010-01-19 11:34:44.000000000 -0800 
+++ autotest/frontend/tko/rpc_interface_unittest.py 2010-01-19 11:34:44.000000000 -0800 @@ -465,7 +465,7 @@ self.assertEquals(len(tests), 3) self.assertEquals(tests[0]['label_testlabel1'], 'testlabel1') - self.assert_(tests[0]['label_testlabel2'], 'testlabel2') + self.assertEquals(tests[0]['label_testlabel2'], 'testlabel2') for index in (1, 2): self.assertEquals(tests[index]['label_testlabel1'], None) --- autotest/scheduler/monitor_db.py 2010-01-19 11:34:44.000000000 -0800 +++ autotest/scheduler/monitor_db.py 2010-01-19 11:34:44.000000000 -0800 @@ -935,10 +935,10 @@ host__locked=False) # exclude hosts with active queue entries unless the SpecialTask is for # that queue entry - queued_tasks = models.Host.objects.add_join( + queued_tasks = models.SpecialTask.objects.add_join( queued_tasks, 'afe_host_queue_entries', 'host_id', join_condition='afe_host_queue_entries.active', - force_left_join=True) + join_from_key='host_id', force_left_join=True) queued_tasks = queued_tasks.extra( where=['(afe_host_queue_entries.id IS NULL OR ' 'afe_host_queue_entries.id = ' _______________________________________________ Autotest mailing list [email protected] http://test.kernel.org/cgi-bin/mailman/listinfo/autotest
