changeset f6615c7c6c21 in trytond:default
details: https://hg.tryton.org/trytond?cmd=changeset;node=f6615c7c6c21
description:
        Update MPTT only for affected fields

        issue10130
        review353331002
diffstat:

 trytond/model/modelsql.py     |  89 +++++++++++++++++++++---------------------
 trytond/model/modelstorage.py |  11 +++++
 2 files changed, 55 insertions(+), 45 deletions(-)

diffs (151 lines):

diff -r 318efe295241 -r f6615c7c6c21 trytond/model/modelsql.py
--- a/trytond/model/modelsql.py Tue Mar 02 08:38:10 2021 +0100
+++ b/trytond/model/modelsql.py Fri Mar 05 18:22:55 2021 +0100
@@ -1,7 +1,7 @@
 # This file is part of Tryton.  The COPYRIGHT file at the top level of
 # this repository contains the full copyright notices and license terms.
 import datetime
-from itertools import islice, chain, product, groupby
+from itertools import islice, chain, product, groupby, repeat
 from collections import OrderedDict, defaultdict
 from functools import wraps
 
@@ -665,8 +665,9 @@
 
         cls._insert_history(new_ids)
 
-        field_names = list(cls._fields.keys())
-        cls._update_mptt(field_names, [new_ids] * len(field_names))
+        if cls._mptt_fields:
+            field_names = list(sorted(cls._mptt_fields))
+            cls._update_mptt(field_names, repeat(new_ids, len(field_names)))
 
         cls.__check_domain_rule(new_ids, 'create')
         records = cls.browse(new_ids)
@@ -1027,9 +1028,12 @@
                 if hasattr(field, 'set'):
                     fields_to_set.setdefault(fname, []).extend((ids, value))
 
-            field_names = list(values.keys())
-            cls._update_mptt(field_names, [ids] * len(field_names), values)
-            all_field_names |= set(field_names)
+            mptt_fields = cls._mptt_fields & set(values)
+            if mptt_fields:
+                cls._update_mptt(
+                    list(sorted(mptt_fields)), repeat(ids, len(mptt_fields)),
+                    values)
+            all_field_names |= set(values)
 
         for fname in sorted(fields_to_set, key=cls.index_set_field):
             fargs = fields_to_set[fname]
@@ -1067,20 +1071,18 @@
         cls.__check_timestamp(ids)
         cls.__check_domain_rule(ids, 'delete')
 
-        has_translation = False
         tree_ids = {}
-        for fname, field in cls._fields.items():
-            if (isinstance(field, fields.Many2One)
-                    and field.model_name == cls.__name__
-                    and field.left and field.right):
-                tree_ids[fname] = []
-                for sub_ids in grouped_slice(ids):
-                    where = reduce_ids(field.sql_column(table), sub_ids)
-                    cursor.execute(*table.select(table.id, where=where))
-                    tree_ids[fname] += [x[0] for x in cursor.fetchall()]
-            if (getattr(field, 'translate', False)
-                    and not hasattr(field, 'set')):
-                has_translation = True
+        for fname in cls._mptt_fields:
+            field = cls._fields[fname]
+            tree_ids[fname] = []
+            for sub_ids in grouped_slice(ids):
+                where = reduce_ids(field.sql_column(table), sub_ids)
+                cursor.execute(*table.select(table.id, where=where))
+                tree_ids[fname] += [x[0] for x in cursor.fetchall()]
+
+        has_translation = any(
+            getattr(f, 'translate', False) and not hasattr(f, 'set')
+            for f in cls._fields.values())
 
         foreign_keys_tocheck = []
         foreign_keys_toupdate = []
@@ -1463,34 +1465,31 @@
         cursor = Transaction().connection.cursor()
         for field_name, ids in zip(field_names, list_ids):
             field = cls._fields[field_name]
-            if (isinstance(field, fields.Many2One)
-                    and field.model_name == cls.__name__
-                    and field.left and field.right):
-                if (values is not None
-                        and (field.left in values or field.right in values)):
-                    raise Exception('ValidateError',
-                        'You can not update fields: "%s", "%s"' %
-                        (field.left, field.right))
+            if (values is not None
+                    and (field.left in values or field.right in values)):
+                raise Exception('ValidateError',
+                    'You can not update fields: "%s", "%s"' %
+                    (field.left, field.right))
 
-                # Nested creation require a rebuild
-                # because initial values are 0
-                # and thus _update_tree can not find the children
-                table = cls.__table__()
-                parent = cls.__table__()
-                cursor.execute(*table.join(parent,
-                        condition=Column(table, field_name) == parent.id
-                        ).select(table.id,
-                        where=(Column(parent, field.left) == 0)
-                        & (Column(parent, field.right) == 0),
-                        limit=1))
-                nested_create = cursor.fetchone()
+            # Nested creation require a rebuild
+            # because initial values are 0
+            # and thus _update_tree can not find the children
+            table = cls.__table__()
+            parent = cls.__table__()
+            cursor.execute(*table.join(parent,
+                    condition=Column(table, field_name) == parent.id
+                    ).select(table.id,
+                    where=(Column(parent, field.left) == 0)
+                    & (Column(parent, field.right) == 0),
+                    limit=1))
+            nested_create = cursor.fetchone()
 
-                if not nested_create and len(ids) < 2:
-                    for id_ in ids:
-                        cls._update_tree(id_, field_name,
-                            field.left, field.right)
-                else:
-                    cls._rebuild_tree(field_name, None, 0)
+            if not nested_create and len(ids) < 2:
+                for id_ in ids:
+                    cls._update_tree(id_, field_name,
+                        field.left, field.right)
+            else:
+                cls._rebuild_tree(field_name, None, 0)
 
     @classmethod
     def _rebuild_tree(cls, parent, parent_id, left):
diff -r 318efe295241 -r f6615c7c6c21 trytond/model/modelstorage.py
--- a/trytond/model/modelstorage.py     Tue Mar 02 08:38:10 2021 +0100
+++ b/trytond/model/modelstorage.py     Fri Mar 05 18:22:55 2021 +0100
@@ -141,6 +141,17 @@
                     'import_data': RPC(readonly=False),
                     })
 
+    @classmethod
+    def __post_setup__(cls):
+        super().__post_setup__()
+
+        cls._mptt_fields = set()
+        for name, field in cls._fields.items():
+            if (isinstance(field, fields.Many2One)
+                    and field.model_name == cls.__name__
+                    and field.left and field.right):
+                cls._mptt_fields.add(name)
+
     @staticmethod
     def default_create_uid():
         "Default value for uid field."

Reply via email to