szha commented on a change in pull request #12512: [MXNET-938] Allow fusing weight ahead of calling forward() URL: https://github.com/apache/incubator-mxnet/pull/12512#discussion_r217132819
########## File path: python/mxnet/gluon/rnn/rnn_layer.py ##########
@@ -209,12 +220,16 @@ def _forward_kernel(self, F, inputs, states, **kwargs):
         """forward using CUDNN or CPU kernel"""
         if self._layout == 'NTC':
             inputs = F.swapaxes(inputs, dim1=0, dim2=1)
-        params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
-                  for t in ['weight', 'bias']
-                  for l in range(self._num_layers)
-                  for d in ['l', 'r'][:self._dir]
-                  for g in ['i2h', 'h2h'])
-        params = F._internal._rnn_param_concat(*params, dim=0)
+
+        if F is ndarray and self.fused_params is not None:
+            params = self.fused_params

Review comment:
   hybrid_forward isn't always invoked with the same context, so even if the parameter values don't change, the cached fused parameter may live on the wrong context.
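To illustrate the reviewer's concern, here is a minimal, framework-free sketch of one way to make such a cache context-aware: key the cached fused parameter by the context it was built for and rebuild on a mismatch. The names (`FusedParamCache`, the `fuse_fn` callable, the string contexts) are illustrative stand-ins, not MXNet API; in the real code the fuse step would correspond to `_rnn_param_concat` plus a copy to the target device.

```python
class FusedParamCache:
    """Caches a fused parameter per context, rebuilding on a context mismatch.

    This is a hypothetical sketch of the pattern, not MXNet's implementation:
    `fuse_fn` stands in for the parameter-concatenation + device-copy step.
    """

    def __init__(self, fuse_fn):
        self._fuse_fn = fuse_fn   # callable: context -> fused parameter blob
        self._context = None      # context the cached blob was built for
        self._fused = None        # the cached fused parameter blob

    def get(self, context):
        # Rebuild whenever the requested context differs from the cached one;
        # a cache keyed only on "fused once" would return a blob on the wrong
        # device when forward() is later called under another context.
        if self._fused is None or self._context != context:
            self._fused = self._fuse_fn(context)
            self._context = context
        return self._fused


# Usage: track how often the (expensive) fuse step actually runs.
calls = []

def fuse(ctx):
    calls.append(ctx)
    return ("fused", ctx)

cache = FusedParamCache(fuse)
a = cache.get("gpu(0)")   # first call: builds the fused blob
b = cache.get("gpu(0)")   # same context: cache hit, no rebuild
c = cache.get("cpu(0)")   # context changed: rebuilt for the new device
```

The design choice here mirrors the review feedback: the validity of the cached value depends not only on the parameter contents but also on where the computation runs, so the context must be part of the cache key (or checked explicitly) before reuse.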