xidulu commented on a change in pull request #17360: add random.multivariate_normal, fix empty_like dtype problem, fix gat… URL: https://github.com/apache/incubator-mxnet/pull/17360#discussion_r367862554
########## File path: python/mxnet/numpy_op_fallback.py ########## @@ -158,3 +165,62 @@ def infer_shape(self, in_shape): def create_operator(self, ctx, in_shapes, in_dtypes): return Unravel_index(self._shape) + +@use_np +class MultivariateNormal(operator.CustomOp): + """Fallback to the front-end implementation of random.multivariate_normal.""" + def __init__(self, size=None): + super(MultivariateNormal, self).__init__() + self._size = size + + def forward(self, is_train, req, in_data, out_data, aux): + loc = in_data[0] + cov = in_data[1] + scale = _mx_np.linalg.cholesky(cov) + # set context + noise = _mx_np.random.normal(size=out_data[0].shape, dtype=loc.dtype, ctx=loc.ctx) + out = loc + _mx_np.einsum('...jk,...j->...k', scale, noise) + self.assign(out_data[0], req[0], out) + + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + raise NotImplementedError('Operator random.multivariate_normal' + ' does not support gradient computation') + + +@register('mvn_fallback') +class MultivariateNormalProp(operator.CustomOpProp): + """Fallback np.random.multivariate_normal operator properties.""" + + def __init__(self, size=None): + super(MultivariateNormalProp, self).__init__(need_top_grad=True) + self._size = ast.literal_eval( + size) if size is not None else None + + def list_arguments(self): + return ['mean', 'variance'] + + def infer_shape(self, in_shape): + loc_shape = in_shape[0] + cov_shape = in_shape[1] + if len(loc_shape) < 1: + raise ValueError("mean must be at least 1 dimensional") + if len(cov_shape) < 2: + raise ValueError("cov must be at least 2 dimensional") + if cov_shape[-1] != cov_shape[-1]: + raise ValueError("the last two dimentions of the parameter cov have to be the same," + " whereas the shape of cov is {}".format(cov_shape)) + if cov_shape[-1] != loc_shape[-1]: + raise ValueError("mean and cov must have same length." 
+ "The shape of mean is {} but the shape of cov is {}" + .format(loc_shape[-1:], cov_shape[-2:])) + # handle shape mismatch here + out_shape = np.broadcast(np.empty(loc_shape), np.empty(cov_shape[:-1])).shape + if self._size is not None: + self._size = [self._size] if np.isscalar( + self._size) else list(self._size) + out_shape = self._size + list(out_shape) Review comment: You could keep both shapes as tuples and concatenate them directly instead of converting to lists, e.g. `(1, 2) + (3, 4)` gives `(1, 2, 3, 4)`. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services