zheng-da closed pull request #12151: fix a minor bug in while_loop URL: https://github.com/apache/incubator-mxnet/pull/12151
This is a PR merged from a forked repository. Because GitHub hides the diff of a foreign (forked) pull request once it is merged, the diff is reproduced below for the sake of provenance: diff --git a/python/mxnet/symbol/contrib.py b/python/mxnet/symbol/contrib.py index 38195bd62ff..f89c73164fe 100644 --- a/python/mxnet/symbol/contrib.py +++ b/python/mxnet/symbol/contrib.py @@ -539,9 +539,6 @@ def _union_inputs(*graphs): # find symbols used in either cond_g or func_g input_syms, ((cond_input_locs, _), (func_input_locs, func_var_locs)) = \ _union_inputs(cond_g, func_g) - for i_th, loc in enumerate(func_var_locs, 1): - if loc == -1: - raise ValueError("The %d-th loop_var doesn't involve into the computation" % i_th) result = symbol._internal._while_loop( # [cond, func_g, *input_syms] cond_g, diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py index a4b794c9595..7205b55ec52 100644 --- a/tests/python/unittest/test_contrib_control_flow.py +++ b/tests/python/unittest/test_contrib_control_flow.py @@ -139,6 +139,30 @@ def hybrid_forward(self, F, *loop_vars): assert result_s.asscalar() == 0 +def test_while_loop2(): + class TestBlock(gluon.HybridBlock): + def __init__(self, prefix=None, params=None): + super(TestBlock, self).__init__(prefix=prefix, params=params) + + # In this test, body_func only accesses one of the states, + # so not all loop variables are used.
+ def hybrid_forward(self, F, data): + def cond_func(state1, state2): + return state1 > 0 + def body_func(state1, state2): + return (state2, [state2 + 1, state2 + 2]) + return F.contrib.while_loop( + cond=cond_func, + func=body_func, + loop_vars=[data, data + 1], + max_iterations=10) + + block = TestBlock() + block.initialize(ctx=default_context()) + block.hybridize() + block(mx.nd.ones((1))) + + def _verify_while_loop(cond, func, loop_var_shapes, free_var_shapes, is_train, max_iterations, is_for, n_steps): def _create_vars(num, prefix): ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services