reminisce commented on a change in pull request #10451: [WIP] Add Foreach
URL: https://github.com/apache/incubator-mxnet/pull/10451#discussion_r188504970
 
 

 ##########
 File path: python/mxnet/symbol/contrib.py
 ##########
 @@ -91,3 +98,99 @@ def rand_zipfian(true_classes, num_sampled, range_max):
     expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 
1.0)).log() / log_range
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
+
+def _get_graph_inputs(subg, name, prefix):
+    num_handles = ctypes.c_int(1000)
+    handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+    check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles,
+        ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(handles[i])
+        syms.append(s)
+    return syms
+
+def foreach(func, input, init_states, back_prop=False, name="foreach"):
+    assert isinstance(init_states, list), "init_states should be a list"
+    states = []
+    with AttrScope(subgraph_name=name):
+        if isinstance(input, list):
+            in_eles = [symbol.var(sym.name) for sym in input]
+        else:
+            in_eles = symbol.var(input.name)
+        for s in init_states:
+            states.append(symbol.var(s.name))
+
+        sym_out = func(in_eles, states)
+        # The function should return a tuple. The first element goes to
+        # the output of the function. The second element is a list.
+        assert isinstance(sym_out, tuple), "func should return a tuple (out, 
states)"
+        assert isinstance(sym_out[1], list), \
+                "the second element in the returned tuple should be a list"
+        assert len(sym_out[1]) == len(init_states), \
+                "the number of output states (%d) should be the same as input 
states (%d)" \
+                % (len(sym_out[1]), len(init_states))
+
+        if (isinstance(sym_out[0], list)):
 
 Review comment:
  No parentheses are needed here; the redundant parentheses would trigger a coding-style warning in PyCharm.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to