zheng-da closed pull request #12151: fix a minor bug in while_loop
URL: https://github.com/apache/incubator-mxnet/pull/12151
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py
index 38195bd62ff..f89c73164fe 100644
--- a/python/mxnet/symbol/contrib.py
+++ b/python/mxnet/symbol/contrib.py
@@ -539,9 +539,6 @@ def _union_inputs(*graphs):
     # find symbols used in either cond_g or func_g
     input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \
         _union_inputs(cond_g, func_g)
-    for i_th, loc in enumerate(func_var_locs, 1):
-        if loc == -1:
-            raise ValueError("The %d-th loop_var doesn't involve into the computation" % i_th)
     result = symbol._internal._while_loop(
         # [cond, func_g, *input_syms]
         cond_g,
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index a4b794c9595..7205b55ec52 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -139,6 +139,30 @@ def hybrid_forward(self, F, *loop_vars):
         assert result_s.asscalar() == 0
 
 
+def test_while_loop2():
+    class TestBlock(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(TestBlock, self).__init__(prefix=prefix, params=params)
+
+        # In this test, body_func only accesses one of the states,
+        # so not all loop variables are used.
+        def hybrid_forward(self, F, data):
+            def cond_func(state1, state2):
+                return state1 > 0
+            def body_func(state1, state2):
+                return (state2, [state2 + 1, state2 + 2])
+            return F.contrib.while_loop(
+                    cond=cond_func,
+                    func=body_func,
+                    loop_vars=[data, data + 1],
+                    max_iterations=10)
+
+    block = TestBlock()
+    block.initialize(ctx=default_context())
+    block.hybridize()
+    block(mx.nd.ones((1)))
+
+
 def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps):
 
     def _create_vars(num, prefix):


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to