leezu commented on a change in pull request #18403:
URL: https://github.com/apache/incubator-mxnet/pull/18403#discussion_r432683605



##########
File path: python/mxnet/gluon/probability/block/stochastic_block.py
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Stochastic block class."""
+__all__ = ['StochasticBlock', 'StochasticSequential']
+
+from functools import wraps
+from ...block import HybridBlock
+from ...nn.basic_layers import HybridSequential
+from ...utils import _indent
+
+
+class StochasticBlock(HybridBlock):
+    """`StochasticBlock` extends `HybridBlock` to support accumulating loss
+    in the forward phase, which is extremely useful in building Bayesian Neural Network,
+    where the loss function is composed of a classification loss and a KL loss.
+
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticBlock, self).__init__(prefix=prefix, params=params)
+        self._losses = []
+        self._losscache = []
+        self._count = 0
+
+    def add_loss(self, loss):
+        self._count += 1
+        self._losscache.append(loss)
+
+    @staticmethod
+    def collectLoss(func):
+        """To accumulate loss during the forward phase, one could first 
decorate
+        hybrid_forward with `StochasticBlock.collectLos`s`,
+        and then collect the loss tensor `x` by calling self.add_loss(x).
+        For example, in the following forward function,
+        we generate samples from a Gaussian parameterized by `loc` and `scale` 
and
+        accumulate the KL-divergence between it and its prior into the block's 
loss storage.:
+        @StochasticBlock.collectLoss
+        def hybrid_forward(self, F, loc, scale):
+            qz = mgp.Normal(loc, scale)
+            # prior
+            pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale))
+            self.add_loss(mgp.kl_divergence(qz, pz))
+            return qz.sample()
+        """
+        @wraps(func)
+        def inner(self, *args, **kwargs):
+            # Loss from hybrid_forward
+            func_out = func(self, *args, **kwargs)
+            collected_loss = self._losscache
+            self._losscache = []
+            return (func_out, collected_loss)
+
+        return inner
+
+    def __call__(self, *args):
+        """Calls forward. Only accepts positional arguments."""
+        for hook in self._forward_pre_hooks.values():
+            hook(self, args)
+        self._losses = []
+        out = self.forward(*args)  # out[0]: net output, out[1]: collected loss
+        self._losses.extend(out[1])
+        for hook in self._forward_hooks.values():
+            hook(self, args, out)
+        return out[0]
+
+    @property
+    def losses(self):
+        return self._losses
+
+
+class StochasticSequential(StochasticBlock):
+    """Stack StochasticBlock sequentially.
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticSequential, self).__init__(
+            prefix=prefix, params=params)
+        self._layers = []
+
+    def add(self, *blocks):
+        """Adds block on top of the stack."""
+        for block in blocks:
+            self._layers.append(block)
+            self.register_child(block)
+
+    @StochasticBlock.collectLoss
+    def hybrid_forward(self, F, x):
+        for block in self._layers:
+            x = block(x)

Review comment:
       For the `hybrid_forward` API, it seems you forgot to use `F` here?

##########
File path: python/mxnet/gluon/probability/block/stochastic_block.py
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Stochastic block class."""
+__all__ = ['StochasticBlock', 'StochasticSequential']
+
+from functools import wraps
+from ...block import HybridBlock
+from ...nn.basic_layers import HybridSequential
+from ...utils import _indent
+
+
+class StochasticBlock(HybridBlock):
+    """`StochasticBlock` extends `HybridBlock` to support accumulating loss
+    in the forward phase, which is extremely useful in building Bayesian 
Neural Network,
+    where the loss function is composed of a classification loss and a KL loss.
+
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticBlock, self).__init__(prefix=prefix, params=params)
+        self._losses = []
+        self._losscache = []
+        self._count = 0

Review comment:
       Is it possible to just use `len(_losscache)` instead?
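
       A sketch of that suggestion (an assumption about intent: drop the separate counter and derive the count from the cache):

       ```python
       def add_loss(self, loss):
           self._losscache.append(loss)

       # wherever the number of accumulated losses is needed:
       count = len(self._losscache)
       ```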

##########
File path: python/mxnet/gluon/probability/block/stochastic_block.py
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import

Review comment:
       Let's not disable pylint flags unnecessarily

##########
File path: src/operator/random/multisample_op.h
##########
@@ -67,7 +67,7 @@ inline bool MultiSampleOpShape(const nnvm::NodeAttrs& attrs,
   const MultiSampleParam& param = nnvm::get<MultiSampleParam>(attrs.parsed);
   mxnet::TShape sshape = param.shape;
   for (int i = 0; i < sshape.ndim(); ++i) {
-    CHECK_GT(sshape[i], 0) << "shape parameter must be non-zero within each dimension";
+    CHECK_GE(sshape[i], 0) << "shape parameter must be non-zero within each dimension";

Review comment:
       Now the error message and the check don't align: `CHECK_GE(sshape[i], 0)` permits zero, but the message still says the shape parameter "must be non-zero". The message should say "must be non-negative" instead.

##########
File path: python/mxnet/gluon/probability/block/stochastic_block.py
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Stochastic block class."""
+__all__ = ['StochasticBlock', 'StochasticSequential']
+
+from functools import wraps
+from ...block import HybridBlock
+from ...nn.basic_layers import HybridSequential
+from ...utils import _indent
+
+
+class StochasticBlock(HybridBlock):
+    """`StochasticBlock` extends `HybridBlock` to support accumulating loss
+    in the forward phase, which is extremely useful in building Bayesian Neural Network,
+    where the loss function is composed of a classification loss and a KL loss.
+
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticBlock, self).__init__(prefix=prefix, params=params)
+        self._losses = []
+        self._losscache = []
+        self._count = 0
+
+    def add_loss(self, loss):
+        self._count += 1
+        self._losscache.append(loss)
+
+    @staticmethod
+    def collectLoss(func):
+        """To accumulate loss during the forward phase, one could first 
decorate
+        hybrid_forward with `StochasticBlock.collectLos`s`,
+        and then collect the loss tensor `x` by calling self.add_loss(x).
+        For example, in the following forward function,
+        we generate samples from a Gaussian parameterized by `loc` and `scale` 
and
+        accumulate the KL-divergence between it and its prior into the block's 
loss storage.:
+        @StochasticBlock.collectLoss
+        def hybrid_forward(self, F, loc, scale):
+            qz = mgp.Normal(loc, scale)
+            # prior
+            pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale))
+            self.add_loss(mgp.kl_divergence(qz, pz))
+            return qz.sample()
+        """
+        @wraps(func)
+        def inner(self, *args, **kwargs):
+            # Loss from hybrid_forward
+            func_out = func(self, *args, **kwargs)
+            collected_loss = self._losscache
+            self._losscache = []
+            return (func_out, collected_loss)
+
+        return inner
+
+    def __call__(self, *args):
+        """Calls forward. Only accepts positional arguments."""
+        for hook in self._forward_pre_hooks.values():
+            hook(self, args)
+        self._losses = []
+        out = self.forward(*args)  # out[0]: net output, out[1]: collected loss
+        self._losses.extend(out[1])
+        for hook in self._forward_hooks.values():
+            hook(self, args, out)
+        return out[0]
+
+    @property
+    def losses(self):
+        return self._losses
+
+
+class StochasticSequential(StochasticBlock):
+    """Stack StochasticBlock sequentially.
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticSequential, self).__init__(
+            prefix=prefix, params=params)
+        self._layers = []
+
+    def add(self, *blocks):
+        """Adds block on top of the stack."""
+        for block in blocks:
+            self._layers.append(block)
+            self.register_child(block)
+
+    @StochasticBlock.collectLoss
+    def hybrid_forward(self, F, x):
+        for block in self._layers:
+            x = block(x)
+            if hasattr(block, '_losses'):
+                self.add_loss(block._losses)
+        return x
+
+    def __repr__(self):
+        s = '{name}(\n{modstr}\n)'
+        modstr = '\n'.join(['  ({key}): {block}'.format(key=key,
+                                                        block=_indent(block.__repr__(), 2))
+                            for key, block in self._children.items()])
+        return s.format(name=self.__class__.__name__,
+                        modstr=modstr)
+
+    def __getitem__(self, key):
+        layers = list(self._children.values())[key]

Review comment:
       Note the recent change to how HybridSequential stores the _children: https://github.com/apache/incubator-mxnet/pull/18376

       You need to take that into account and dereference the weakref here.
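
       A minimal sketch of the dereference, assuming `_children` now maps names to `weakref.ref` objects as introduced in #18376:

       ```python
       def __getitem__(self, key):
           # each value in _children is a weakref.ref after #18376;
           # call the ref to recover the actual child block
           layers = [ref() for ref in self._children.values()]
           return layers[key]
       ```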

##########
File path: python/mxnet/gluon/probability/transformation/transformation.py
##########
@@ -0,0 +1,289 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Transformation Classes"""
+__all__ = ["Transformation", "TransformBlock","ComposeTransform", 
"ExpTransform",
+           "AffineTransform", "PowerTransform", "AbsTransform", 
'SigmoidTransform',
+           'SoftmaxTransform']
+
+from ..distributions.utils import _clip_prob, cached_property, sum_right_most
+from ...block import HybridBlock
+import weakref
+
+
+class Transformation(object):
+    r"""Abstract class for implementing invertible transformation
+    with computable log  det jacobians
+    
+    Attributes
+    ----------
+    bijective : bool
+        
+    """
+    bijective = False
+    event_dim = 0
+
+    def __init__(self, F=None):
+        self._inv = None
+        self._F = F
+        super(Transformation, self).__init__()
+
+    @property
+    def F(self):
+        return self._F
+
+    @F.setter
+    def F(self, value):
+        self._F = value
+
+    @property
+    def sign(self):
+        """
+        Returns the sign of the determinant of the Jacobian.
+        """
+        raise NotImplementedError
+
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransformation(self)
+            self._inv = weakref.ref(inv)

Review comment:
       The idea is to cache the inverse transformation as long as it's in use somewhere else?
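
       If so, the intended behavior would be roughly the following (a sketch, assuming a concrete subclass such as `ExpTransform`):

       ```python
       t = ExpTransform()
       inv1 = t.inv
       inv2 = t.inv        # same cached object while inv1 is still alive
       assert inv1 is inv2
       del inv1, inv2      # no strong references remain, so the weakref clears
       inv3 = t.inv        # a fresh _InverseTransformation is constructed
       ```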

##########
File path: python/mxnet/gluon/probability/transformation/transformation.py
##########
@@ -0,0 +1,289 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Transformation Classes"""
+__all__ = ["Transformation", "TransformBlock","ComposeTransform", 
"ExpTransform",
+           "AffineTransform", "PowerTransform", "AbsTransform", 
'SigmoidTransform',
+           'SoftmaxTransform']
+
+from ..distributions.utils import _clip_prob, cached_property, sum_right_most
+from ...block import HybridBlock
+import weakref
+
+
+class Transformation(object):
+    r"""Abstract class for implementing invertible transformation
+    with computable log  det jacobians
+    
+    Attributes
+    ----------
+    bijective : bool
+        
+    """
+    bijective = False
+    event_dim = 0
+
+    def __init__(self, F=None):
+        self._inv = None
+        self._F = F
+        super(Transformation, self).__init__()
+
+    @property
+    def F(self):
+        return self._F
+
+    @F.setter
+    def F(self, value):
+        self._F = value
+
+    @property
+    def sign(self):
+        """
+        Returns the sign of the determinant of the Jacobian.
+        """
+        raise NotImplementedError
+
+    @property
+    def inv(self):
+        inv = None
+        if self._inv is not None:
+            inv = self._inv()
+        if inv is None:
+            inv = _InverseTransformation(self)
+            self._inv = weakref.ref(inv)
+        return inv
+
+    def __call__(self, x):
+        return self._forward_compute(x)
+
+    def _inv_call(self, y):
+        return self._inverse_compute(y)
+
+    def _forward_compute(self, x):
+        raise NotImplementedError
+
+    def _inverse_compute(self, x):
+        raise NotImplementedError
+    
+    def log_det_jacobian(self, x, y):
+        """
+        Compute the value of log(|dy/dx|)
+        """
+        raise NotImplementedError
+
+
+class _InverseTransformation(Transformation):
+    """
+    A private class representing the invert of `Transformation`,
+    which should be accessed through `Transformation.inv` property.
+    """
+    def __init__(self, forward_transformation):
+        super(_InverseTransformation, self).__init__()
+        self._inv = forward_transformation
+
+    @property
+    def inv(self):
+        return self._inv
+
+    @property
+    def sign(self):
+        return self._inv.sign
+
+    @property
+    def event_dim(self):
+        return self._inv.event_dim
+
+    def __call__(self, x):
+        return self._inv._inverse_compute(x)
+
+    def log_det_jacobian(self, x, y):
+        return -self._inv.log_det_jacobian(y, x)
+
+
+class TransformBlock(Transformation, HybridBlock):

Review comment:
       There may be some issues with multiple inheritance here if constructor arguments for the Block (params, prefix) are used. Did you check that it works? But as we plan to remove the Block constructor arguments (https://github.com/apache/incubator-mxnet/pull/18413), it should be fine.
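
       For reference, a quick way to inspect the concern (a sketch; the MRO comment reflects Python's C3 linearization for these bases):

       ```python
       print(TransformBlock.mro())
       # [TransformBlock, Transformation, HybridBlock, Block, object]
       # Transformation.__init__ calls super().__init__() with no arguments,
       # so prefix/params passed through this chain would not reach
       # HybridBlock.__init__ unless the constructors cooperate via **kwargs.
       ```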

##########
File path: tests/python/unittest/test_gluon_probability.py
##########
@@ -0,0 +1,2326 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+
+import mxnet as mx
+from mxnet import np, npx, autograd
+from mxnet import gluon
+import mxnet.gluon.probability as mgp
+from mxnet.gluon.probability import StochasticBlock, StochasticSequential
+from mxnet.gluon import HybridBlock
+from mxnet.test_utils import use_np, assert_almost_equal, set_default_context
+import numpy as _np
+from common import (setup_module, with_seed, assertRaises,
+                    assert_raises_cudnn_not_satisfied)
+from numpy.testing import assert_array_equal
+import pytest
+import scipy.stats as ss
+import scipy.special as scipy_special
+import warnings
+import json
+import unittest
+import random
+import itertools
+from numbers import Number
+
+# set_default_context(mx.gpu(0))

Review comment:
       You should delete this line here and create a separate file in the `gpu` test folder. In that file, import all functions from the current `test_gluon_probability.py` file.
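
       A minimal sketch of such a file (the path and filename are assumptions, following the pattern of the existing GPU test files):

       ```python
       # tests/python/gpu/test_gluon_probability_gpu.py  (hypothetical path)
       import os
       import sys

       import mxnet as mx
       from mxnet.test_utils import set_default_context

       # make the unittest folder importable so the CPU tests can be reused
       sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'unittest'))
       from test_gluon_probability import *  # noqa: F401,F403

       set_default_context(mx.gpu(0))
       ```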

##########
File path: tests/python/unittest/test_gluon_probability.py
##########
@@ -0,0 +1,2326 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+
+import mxnet as mx
+from mxnet import np, npx, autograd
+from mxnet import gluon
+import mxnet.gluon.probability as mgp
+from mxnet.gluon.probability import StochasticBlock, StochasticSequential
+from mxnet.gluon import HybridBlock
+from mxnet.test_utils import use_np, assert_almost_equal, set_default_context
+import numpy as _np
+from common import (setup_module, with_seed, assertRaises,
+                    assert_raises_cudnn_not_satisfied)
+from numpy.testing import assert_array_equal
+import pytest
+import scipy.stats as ss
+import scipy.special as scipy_special
+import warnings
+import json
+import unittest
+import random
+import itertools
+from numbers import Number
+
+# set_default_context(mx.gpu(0))
+
+def prob_to_logit(prob):
+    return np.log(prob) - np.log1p(-prob)
+
+def _distribution_method_invoker(dist, func, *args):
+    """Wrapper for invoking different types of class methods with one unified
+    interface.
+
+    Parameters
+    ----------
+    dist : Distribution
+    func : method
+    """
+    if (len(args) == 0):
+        out = getattr(dist, func)
+        if callable(out):
+            return out()
+        else:
+            return out
+    return getattr(dist, func)(*args)
+
+
+def test_mgp_getF():
+    # Test getF
+    getF = mgp.utils.getF
+    nd = mx.nd
+    sym = mx.sym
+    assert getF(nd.ones((2,2)), nd.ones((2,2))) == nd
+    assert getF(sym.ones((2,2)), sym.ones((2,2))) == sym
+    assert getF(1.0, 2.0) == nd
+
+    # Test exception
+    try:
+        getF(nd.ones((2,2)), sym.ones((2,2)))
+    except TypeError as e:
+        pass
+
+    try:
+        getF(sym.ones((2,2)), nd.ones((2,2)))
+    except TypeError as e:
+        pass
+
+
+@with_seed()
+@use_np
+def test_gluon_uniform():
+    class TestUniform(HybridBlock):
+        def __init__(self, func):
+            super(TestUniform, self).__init__()
+            self._func = func
+
+        def hybrid_forward(self, F, low, high, *args):
+            uniform = mgp.Uniform(low, high, validate_args=True)
+            return _distribution_method_invoker(uniform, self._func, *args)
+
+    shapes = [(), (1,), (2, 3), 6]
+
+    # Test log_prob
+    for shape, hybridize in itertools.product(shapes, [True, False]):

Review comment:
       It's better to use `pytest.mark.parametrize` to enable parallelization.
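
       A sketch of the parametrized form (parameter values taken from the `shapes` list above):

       ```python
       @pytest.mark.parametrize('shape', [(), (1,), (2, 3), 6])
       @pytest.mark.parametrize('hybridize', [True, False])
       @use_np
       def test_gluon_uniform(shape, hybridize):
           ...  # same body as above, without the itertools.product loop
       ```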

##########
File path: python/mxnet/gluon/probability/block/stochastic_block.py
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Stochastic block class."""
+__all__ = ['StochasticBlock', 'StochasticSequential']
+
+from functools import wraps
+from ...block import HybridBlock
+from ...nn.basic_layers import HybridSequential
+from ...utils import _indent
+
+
+class StochasticBlock(HybridBlock):
+    """`StochasticBlock` extends `HybridBlock` to support accumulating loss
+    in the forward phase, which is extremely useful in building Bayesian Neural Network,
+    where the loss function is composed of a classification loss and a KL loss.
+
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticBlock, self).__init__(prefix=prefix, params=params)
+        self._losses = []
+        self._losscache = []
+        self._count = 0
+
+    def add_loss(self, loss):
+        self._count += 1
+        self._losscache.append(loss)
+
+    @staticmethod
+    def collectLoss(func):
+        """To accumulate loss during the forward phase, one could first 
decorate
+        hybrid_forward with `StochasticBlock.collectLos`s`,

Review comment:
       Typo: ``collectLos`s`` should be `collectLoss`.

##########
File path: python/mxnet/gluon/probability/distributions/distribution.py
##########
@@ -0,0 +1,196 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Base distribution class."""
+__all__ = ['Distribution']
+
+from .utils import cached_property
+from numbers import Number
+
+
+class Distribution(object):
+    r"""Base class for distribution.
+    
+    Parameters
+    ----------
+    F : mx.ndarray or mx.symbol.numpy._Symbol
+        Variable that stores the running mode.
+    event_dim : int, default None
+        Variable indicating the dimension of the distribution's support.
+    validate_args : bool, default None
+        Whether to validate the distribution parameters
+    """          
+
+    # Variable indicating whether the sampling method has
+    # pathwise gradient.
+    has_grad = False
+    support = None
+    has_enumerate_support = False
+    arg_constraints = {}
+    _validate_args = False
+
+    @staticmethod
+    def set_default_validate_args(value):
+        if value not in [True, False]:
+            raise ValueError
+        Distribution._validate_args = value
+
+    def __init__(self, F=None, event_dim=None, validate_args=None):
+        self.F = F

Review comment:
       Should the F be removed as it's no longer needed?

##########
File path: python/mxnet/gluon/probability/transformation/domain_map.py
##########
@@ -0,0 +1,109 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""Classes for registering and storaging bijection/transformations from

Review comment:
       Typo: "storaging" should be "storing".

##########
File path: python/mxnet/gluon/probability/block/stochastic_block.py
##########
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+"""Stochastic block class."""
+__all__ = ['StochasticBlock', 'StochasticSequential']
+
+from functools import wraps
+from ...block import HybridBlock
+from ...nn.basic_layers import HybridSequential
+from ...utils import _indent
+
+
+class StochasticBlock(HybridBlock):
+    """`StochasticBlock` extends `HybridBlock` to support accumulating loss
+    in the forward phase, which is extremely useful in building Bayesian Neural Network,
+    where the loss function is composed of a classification loss and a KL loss.
+
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticBlock, self).__init__(prefix=prefix, params=params)
+        self._losses = []
+        self._losscache = []
+        self._count = 0
+
+    def add_loss(self, loss):
+        self._count += 1
+        self._losscache.append(loss)
+
+    @staticmethod
+    def collectLoss(func):
+        """To accumulate loss during the forward phase, one could first 
decorate
+        hybrid_forward with `StochasticBlock.collectLos`s`,
+        and then collect the loss tensor `x` by calling self.add_loss(x).
+        For example, in the following forward function,
+        we generate samples from a Gaussian parameterized by `loc` and `scale` 
and
+        accumulate the KL-divergence between it and its prior into the block's 
loss storage.:
+        @StochasticBlock.collectLoss
+        def hybrid_forward(self, F, loc, scale):
+            qz = mgp.Normal(loc, scale)
+            # prior
+            pz = mgp.Normal(F.np.zeros_like(loc), F.np.ones_like(scale))
+            self.add_loss(mgp.kl_divergence(qz, pz))
+            return qz.sample()
+        """
+        @wraps(func)
+        def inner(self, *args, **kwargs):
+            # Loss from hybrid_forward
+            func_out = func(self, *args, **kwargs)
+            collected_loss = self._losscache
+            self._losscache = []
+            return (func_out, collected_loss)
+
+        return inner
+
+    def __call__(self, *args):
+        """Calls forward. Only accepts positional arguments."""
+        for hook in self._forward_pre_hooks.values():
+            hook(self, args)
+        self._losses = []
+        out = self.forward(*args)  # out[0]: net output, out[1]: collected loss
+        self._losses.extend(out[1])
+        for hook in self._forward_hooks.values():
+            hook(self, args, out)
+        return out[0]
+
+    @property
+    def losses(self):
+        return self._losses
+
+
+class StochasticSequential(StochasticBlock):
+    """Stack StochasticBlock sequentially.
+    """
+
+    def __init__(self, prefix=None, params=None):
+        super(StochasticSequential, self).__init__(
+            prefix=prefix, params=params)
+        self._layers = []
+
+    def add(self, *blocks):
+        """Adds block on top of the stack."""
+        for block in blocks:
+            self._layers.append(block)
+            self.register_child(block)
+
+    @StochasticBlock.collectLoss
+    def hybrid_forward(self, F, x):

Review comment:
       It would be better to override `def forward(self, x)` here, as `hybrid_forward` is deprecated: https://github.com/apache/incubator-mxnet/blob/5b9aedd933d0fd506de93a1680e107d1f8aa8983/python/mxnet/gluon/block.py#L862-L872
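
       A sketch of the suggested override (same body as above, just without the `F` argument):

       ```python
       @StochasticBlock.collectLoss
       def forward(self, x):
           for block in self._layers:
               x = block(x)
               if hasattr(block, '_losses'):
                   self.add_loss(block._losses)
           return x
       ```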

##########
File path: tests/python/unittest/test_gluon_probability.py
##########
@@ -0,0 +1,2326 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+
+import mxnet as mx
+from mxnet import np, npx, autograd
+from mxnet import gluon
+import mxnet.gluon.probability as mgp
+from mxnet.gluon.probability import StochasticBlock, StochasticSequential
+from mxnet.gluon import HybridBlock
+from mxnet.test_utils import use_np, assert_almost_equal, set_default_context
+import numpy as _np
+from common import (setup_module, with_seed, assertRaises,
+                    assert_raises_cudnn_not_satisfied)
+from numpy.testing import assert_array_equal
+import pytest
+import scipy.stats as ss
+import scipy.special as scipy_special
+import warnings
+import json
+import unittest
+import random
+import itertools
+from numbers import Number
+
+# set_default_context(mx.gpu(0))
+
+def prob_to_logit(prob):
+    return np.log(prob) - np.log1p(-prob)
+
+def _distribution_method_invoker(dist, func, *args):
+    """Wrapper for invoking different types of class methods with one unified
+    interface.
+
+    Parameters
+    ----------
+    dist : Distribution
+    func : method
+    """
+    if (len(args) == 0):
+        out = getattr(dist, func)
+        if callable(out):
+            return out()
+        else:
+            return out
+    return getattr(dist, func)(*args)
+
+
+def test_mgp_getF():
+    # Test getF
+    getF = mgp.utils.getF
+    nd = mx.nd
+    sym = mx.sym
+    assert getF(nd.ones((2,2)), nd.ones((2,2))) == nd
+    assert getF(sym.ones((2,2)), sym.ones((2,2))) == sym
+    assert getF(1.0, 2.0) == nd
+
+    # Test exception
+    try:
+        getF(nd.ones((2,2)), sym.ones((2,2)))
+    except TypeError as e:
+        pass

Review comment:
       This doesn't test the exception correctly: the test won't fail if no exception is raised. Use `pytest.raises` instead.
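
       For example:

       ```python
       with pytest.raises(TypeError):
           getF(nd.ones((2, 2)), sym.ones((2, 2)))
       with pytest.raises(TypeError):
           getF(sym.ones((2, 2)), nd.ones((2, 2)))
       ```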




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to