This is an automated email from the ASF dual-hosted git repository. taolv pushed a commit to branch v1.7.x in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.7.x by this push:
     new 8695537  add logic for no batch size while getting data arrays from executors (#17772) (#18122)

8695537 is described below

commit 86955370cd868b5d4f46f2f80f7632fd864773e3
Author: Manu Seth <22492939+mset...@users.noreply.github.com>
AuthorDate: Thu Apr 23 01:14:44 2020 -0700

    add logic for no batch size while getting data arrays from executors (#17772) (#18122)

    Co-authored-by: Ubuntu <ubuntu@ip-172-31-94-123.ec2.internal>

    Co-authored-by: Ubuntu <ubuntu@ip-172-31-94-123.ec2.internal>
---
 python/mxnet/module/executor_group.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/module/executor_group.py b/python/mxnet/module/executor_group.py
index d47665d..f2cb62f 100755
--- a/python/mxnet/module/executor_group.py
+++ b/python/mxnet/module/executor_group.py
@@ -308,8 +308,16 @@ class DataParallelExecutorGroup(object):
     def _collect_arrays(self):
         """Collect internal arrays from executors."""
         # convenient data structures
-        self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
-                            for name, _ in self.data_shapes]
+
+        # check if self.slices is populated, if not then that means that there is no batch size
+        if self.slices:
+            # based on batch size, slice up data for the given contexts (self.execs)
+            self.data_arrays = [[(self.slices[i], e.arg_dict[name]) for i, e in enumerate(self.execs)]
+                                for name, _ in self.data_shapes]
+        else:
+            # just use the context index as index into the data
+            self.data_arrays = [[(slice(i, i+1), e.arg_dict[name]) for i, e in enumerate(self.execs)]
+                                for name, _ in self.data_shapes]
 
         self.state_arrays = [[e.arg_dict[name] for e in self.execs]
                              for name in self.state_names]