This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.7.x by this push:
     new 8695537  add logic for no batch size while getting data arrays from executors (#17772) (#18122)
8695537 is described below

commit 86955370cd868b5d4f46f2f80f7632fd864773e3
Author: Manu Seth <22492939+mset...@users.noreply.github.com>
AuthorDate: Thu Apr 23 01:14:44 2020 -0700

    add logic for no batch size while getting data arrays from executors (#17772) (#18122)
    
    Co-authored-by: Ubuntu <ubuntu@ip-172-31-94-123.ec2.internal>
    
    Co-authored-by: Ubuntu <ubuntu@ip-172-31-94-123.ec2.internal>
---
 python/mxnet/module/executor_group.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index d47665d..f2cb62f 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -308,8 +308,16 @@ class DataParallelExecutorGroup(object):
     def _collect_arrays(self):
         """Collect internal arrays from executors."""
         # convenient data structures
-        self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in 
enumerate(self.execs)]
-                            for name, _ in self.data_shapes]
+
+        # check if self.slices is populated, if not then that means that there 
is no batch size
+        if self.slices:
+            # based on batch size, slice up data for the given contexts 
(self.execs)
+            self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e 
in enumerate(self.execs)]
+                                for name, _ in self.data_shapes]
+        else:
+            # just use the context index as index into the data
+            self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in 
enumerate(self.execs)]
+                                for name, _ in self.data_shapes]
 
         self.state_arrays = [[e.arg_dict[name] for e in self.execs]
                              for name in self.state_names]

Reply via email to