This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9c5b95a  Partition API adding and deleting new params to Block and Symbol (#18405)
9c5b95a is described below

commit 9c5b95a9c5d6f83a067504fb47fac4e3aed27e81
Author: Serge Panev <spa...@nvidia.com>
AuthorDate: Mon Jul 13 11:45:29 2020 -0700

    Partition API adding and deleting new params to Block and Symbol (#18405)
    
    * Add deleting of args and aux to Partition API
    
    Signed-off-by: Serge Panev <spa...@nvidia.com>
    
    * Delete args from Block.params
    
    Signed-off-by: Serge Panev <spa...@nvidia.com>
    
    * Fix to use arg/aux dict when optimize_for is called in HybridBlock
    
    Signed-off-by: Serge Panev <spa...@nvidia.com>
    
    * Address PR comments
    
    Signed-off-by: Serge Panev <spa...@nvidia.com>
---
 python/mxnet/gluon/block.py   | 105 ++++++++++++++++++++++++++++++------------
 python/mxnet/symbol/symbol.py |  32 ++++++++++++-
 2 files changed, 106 insertions(+), 31 deletions(-)
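
For context, here is a rough sketch (not part of the commit) of the Gluon flow
these changes affect. The backend name 'myProp' is hypothetical and stands in
for whatever subgraph property is registered in your MXNet build:

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.Dense(16)
    net.initialize(ctx=mx.cpu())
    net.hybridize(backend='myProp')   # partitioning backend; name is hypothetical
    net(mx.nd.ones((1, 8)))           # first forward call builds and partitions the cached graph

During that first call the backend may add or delete parameters; the changes
below keep _cached_op_args and the arg/aux dictionaries consistent with the
partitioned graph.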

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2fda080..1f9cd43 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -1040,41 +1040,69 @@ class HybridBlock(Block):
             warnings.warn("Parameter %s is not used by any computation. "
                           "Is this intended?"%unused, stacklevel=4)
 
-        data_indices = []
-        param_indices = []
-        self._cached_op_args = []
-        for i, name in enumerate(input_names):
-            if name in data_names:
-                data_indices.append(i)
-                self._cached_op_args.append((True, data_names[name]))
-            else:
-                param_indices.append(i)
-                self._cached_op_args.append((False, params[name]))
-        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
-                self._flags
-
         args, _ = _flatten(args, "input")
         try:
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i.data()
+            for name in input_names:
+                if name in params:
+                    params[name].data()
         except DeferredInitializationError:
             self._deferred_infer_shape(*args)
-            for is_arg, i in self._cached_op_args:
-                if not is_arg:
-                    i._finish_deferred_init()
+            for name in input_names:
+                if name in params:
+                    params[name]._finish_deferred_init()
 
+        arg_dict, aux_dict = dict(), dict()
         if self._backend:
             ctx = args[0].context
             # get list of params in the order of out.list_arguments
-            arg_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_arguments()}
-            aux_dict = {name:args[data_names[name]] if name in data_names.keys() else params[name].data()
-                        for name in out.list_auxiliary_states()}
+            arg_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_arguments()})
+            aux_dict.update({name:args[data_names[name]] if name in data_names.keys() else params[name].data()
+                             for name in out.list_auxiliary_states()})
             # Partition the graph.
+            out = out.optimize_for(self._backend, arg_dict, aux_dict, ctx, **self._backend_opts)
+
             #update cached graph with partitioned graph
             self._cached_graph = data, out
+
+        input_names = out.list_inputs()
+        data_indices = []
+        param_indices = []
+
+        # In the default case, _cached_op_args contains all the parameters from params (the sets are identical).
+        # In the case of a Partition API optimized graph, _cached_op_args might contain some parameters from params,
+        # might contain some new parameters created during optimization and added to `arg_dict/aux_dict`,
+        # and might not contain some parameters that were deleted during optimization.
+        self._cached_op_args = []
+        for i, name in enumerate(input_names):
+            pair = None
+            if name in data_names:
+                data_indices.append(i)
+                pair = (True, data_names[name])
+            else:
+                param_indices.append(i)
+                if name in params:
+                    param = params[name]
+                else:
+                    # The param is missing from the original params dictionary, which means the param must have
+                    # been added by the Partition API backend
+                    if name in arg_dict:
+                        param_data = arg_dict[name]
+                    elif name in aux_dict:
+                        param_data = aux_dict[name]
+                    else:
+                        raise RuntimeError('A parameter was added to the graph during optimization but it was not '
+                                           'added to the parameter dicts.\n'
+                                           'Please check the backend.')
+
+                    param = Parameter(name)
+                    param._load_init(param_data, args[0].context)
+                pair = (False, param)
+
+            self._cached_op_args.append(pair)
+
+        flags = [('data_indices', data_indices), ('param_indices', param_indices)] + \
+                self._flags
         self._cached_op = ndarray.CachedOp(out, flags)
 
 
@@ -1321,12 +1349,14 @@ class HybridBlock(Block):
         arg_names = set(sym.list_arguments())
         aux_names = set(sym.list_auxiliary_states())
         arg_dict = {}
-        for param in self.collect_params().values():
-            if param.name in arg_names:
-                arg_dict['arg:%s'%param.name] = param._reduce()
-            else:
-                assert param.name in aux_names
-                arg_dict['aux:%s'%param.name] = param._reduce()
+        for is_arg, param in self._cached_op_args:
+            if not is_arg:
+                name = param.name
+                if name in arg_names:
+                    arg_dict['arg:{}'.format(name)] = param._reduce()
+                else:
+                    assert name in aux_names
+                    arg_dict['aux:{}'.format(name)] = param._reduce()
         save_fn = _mx_npx.save if is_np_array() else ndarray.save
         params_filename = '%s-%04d.params'%(path, epoch)
         save_fn(params_filename, arg_dict)
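
Continuing the sketch above, export now collects the weights it saves from
_cached_op_args rather than collect_params(), so parameters added by the
partitioning backend are written out as well (again a sketch, not from the
commit):

    net.export('partitioned_net', epoch=0)   # writes partitioned_net-0000.params with
                                             # arg:/aux: entries taken from _cached_op_args
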
@@ -1437,6 +1467,23 @@ class HybridBlock(Block):
         # pylint: disable= invalid-name
         raise NotImplementedError
 
+    def reset_ctx(self, ctx):
+        """Re-assign all Parameters to other contexts. If the Block is hybridized, it will reset the _cached_op_args.
+
+        Parameters
+        ----------
+        ctx : Context or list of Context, default :py:meth:`context.current_context()`.
+            Assign Parameter to given context. If ctx is a list of Context, a
+            copy will be made for each context.
+        """
+        params = self.collect_params()
+        if self._cached_op:
+            for is_arg, p in self._cached_op_args:
+                # reset parameters created by the partitioning backend
+                if not is_arg and p.name not in params:
+                    p.reset_ctx(ctx)
+        for p in params.values():
+            p.reset_ctx(ctx)
 
 class SymbolBlock(HybridBlock):
     """Construct block from symbol. This is useful for using pre-trained models
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 39b8799..89ff6bf 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -1544,8 +1544,36 @@ class Symbol(SymbolBase):
             raise RuntimeError('Cannot add new aux in optimize_for since aux is None\n' +
                                'Provide a dictionary to the aux argument to optimize_for')
 
-        # return modified symbol
-        return Symbol(out)
+        new_sym = Symbol(out)
+
+        arg_names = self.list_arguments()
+        new_arg_names = new_sym.list_arguments()
+        deleted_arg_names = set([item for item in arg_names
+                                 if item not in set(new_arg_names)])
+
+        if len(deleted_arg_names) > 0:
+            if args is not None:
+                for a_n in deleted_arg_names:
+                    if a_n in args:
+                        args.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no args dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
+        aux_names = self.list_auxiliary_states()
+        new_aux_names = new_sym.list_auxiliary_states()
+        deleted_aux_names = set([item for item in aux_names
+                                 if item not in set(new_aux_names)])
+        if len(deleted_aux_names) > 0:
+            if aux is not None:
+                for a_n in deleted_aux_names:
+                    if a_n in aux:
+                        aux.pop(a_n)
+            else:
+                warnings.warn('A param was deleted during optimization, but no aux dictionary was provided.\n' +
+                              'Please ensure that your model weights match the newly optimized model.')
+
+        return new_sym
 
     # pylint: disable=too-many-locals
     def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
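
At the Symbol level, the pruning added above can be sketched as follows. The
backend name 'myProp' is again hypothetical; the point is that entries for
parameters the backend deletes are now popped from the args/aux dictionaries
passed in:

    import mxnet as mx

    data = mx.sym.var('data')
    weight = mx.sym.var('weight')
    sym = mx.sym.FullyConnected(data, weight=weight, num_hidden=8, no_bias=True, name='fc')

    args = {'weight': mx.nd.ones((8, 16))}
    aux = {}
    new_sym = sym.optimize_for('myProp', args=args, aux=aux, ctx=mx.cpu())
    # args/aux now match new_sym: entries for deleted parameters are removed, and
    # the backend may have inserted entries for any parameters it created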
