xidulu commented on a change in pull request #17360: add random.multivariate_normal, fix empty_like dtype problem, fix gat…
URL: https://github.com/apache/incubator-mxnet/pull/17360#discussion_r367862554
 
 

 ##########
 File path: python/mxnet/numpy_op_fallback.py
 ##########
 @@ -158,3 +165,62 @@ def infer_shape(self, in_shape):
 
     def create_operator(self, ctx, in_shapes, in_dtypes):
         return Unravel_index(self._shape)
+
+@use_np
+class MultivariateNormal(operator.CustomOp):
+    """Fallback to the front-end implementation of 
random.multivariate_normal."""
+    def __init__(self, size=None):
+        super(MultivariateNormal, self).__init__()
+        self._size = size
+
+    def forward(self, is_train, req, in_data, out_data, aux):
+        loc = in_data[0]
+        cov = in_data[1]
+        scale = _mx_np.linalg.cholesky(cov)
+        # set context
+        noise = _mx_np.random.normal(size=out_data[0].shape, dtype=loc.dtype, 
ctx=loc.ctx)
+        out = loc + _mx_np.einsum('...jk,...j->...k', scale, noise)
+        self.assign(out_data[0], req[0], out)
+
+    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
+        raise NotImplementedError('Operator random.multivariate_normal'
+                                  ' does not support gradient computation')
+
+
+@register('mvn_fallback')
+class MultivariateNormalProp(operator.CustomOpProp):
+    """Fallback np.random.multivariate_normal operator properties."""
+
+    def __init__(self, size=None):
+        super(MultivariateNormalProp, self).__init__(need_top_grad=True)
+        self._size = ast.literal_eval(
+            size) if size is not None else None
+
+    def list_arguments(self):
+        return ['mean', 'variance']
+
+    def infer_shape(self, in_shape):
+        loc_shape = in_shape[0]
+        cov_shape = in_shape[1]
+        if len(loc_shape) < 1:
+            raise ValueError("mean must be at least 1 dimensional")
+        if len(cov_shape) < 2:
+            raise ValueError("cov must be at least 2 dimensional")
+        if cov_shape[-1] != cov_shape[-1]:
+            raise ValueError("the last two dimentions of the parameter cov 
have to be the same,"
+                             " whereas the shape of cov is 
{}".format(cov_shape))
+        if cov_shape[-1] != loc_shape[-1]:
+            raise ValueError("mean and cov must have same length."
+                             "The shape of mean is {} but the shape of cov is 
{}"
+                             .format(loc_shape[-1:], cov_shape[-2:]))
+        # handle shape mismatch here
+        out_shape = np.broadcast(np.empty(loc_shape), 
np.empty(cov_shape[:-1])).shape
+        if self._size is not None:
+            self._size = [self._size] if np.isscalar(
+                self._size) else list(self._size)
+            out_shape = self._size + list(out_shape)
 
 Review comment:
   You could concatenate the two tuples directly, e.g. `(1, 2) + (3, 4)`, instead of converting them to lists first.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to