reminisce commented on a change in pull request #10451: [WIP] Add Foreach
URL: https://github.com/apache/incubator-mxnet/pull/10451#discussion_r188506820
 
 

 ##########
 File path: python/mxnet/symbol/contrib.py
 ##########
 @@ -91,3 +98,99 @@ def rand_zipfian(true_classes, num_sampled, range_max):
     expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range
     expected_count_sampled = expected_prob_sampled * num_sampled
     return sampled_classes, expected_count_true, expected_count_sampled
+
+def _get_graph_inputs(subg, name, prefix):
+    num_handles = ctypes.c_int(1000)
+    handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)])
+    check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles,
+        ctypes.byref(num_handles)))
+
+    syms = []
+    for i in range(num_handles.value):
+        s = Symbol(handles[i])
+        syms.append(s)
+    return syms
+
+def foreach(func, input, init_states, back_prop=False, name="foreach"):
+    assert isinstance(init_states, list), "init_states should be a list"
+    states = []
+    with AttrScope(subgraph_name=name):
+        if isinstance(input, list):
+            in_eles = [symbol.var(sym.name) for sym in input]
+        else:
+            in_eles = symbol.var(input.name)
+        for s in init_states:
+            states.append(symbol.var(s.name))
+
+        sym_out = func(in_eles, states)
+        # The function should return a tuple. The first element goes to
+        # the output of the function. The second element is a list.
+        assert isinstance(sym_out, tuple), "func should return a tuple (out, states)"
+        assert isinstance(sym_out[1], list), \
+                "the second element in the returned tuple should be a list"
+        assert len(sym_out[1]) == len(init_states), \
+                "the number of output states (%d) should be the same as input states (%d)" \
+                % (len(sym_out[1]), len(init_states))
+
+        if (isinstance(sym_out[0], list)):
+            flat_out = sym_out[0]
+        else:
+            flat_out = [sym_out[0]]
+        num_out_data = len(flat_out)
+        for s in sym_out[1]:
+            # There is a problem if the outputs are the same as the inputs
+            # or the first output.
+            # TODO this is a temp fix.
+            flat_out.append(symbol.op.identity(s))
+    g = symbol.Group(flat_out)
+    input_syms = _get_graph_inputs(g, name, "ro_var")
+
+    if (isinstance(input, list)):
+        num_inputs = len(input)
+    else:
+        num_inputs = 1
+
+    # Here we need to find out how the input symbols are ordered as well as
+    # where the loop states are located in the list of inputs.
+
+    # This dict contains the symbols of the subgraph.
+    input_syms = {sym.name:sym for sym in input_syms}
+    gin_names = input_syms.keys()
+    # This array contains the symbols for the inputs of foreach.
+    # They are ordered according to the inputs of the subgraph.
+    ordered_ins = []
+    states_map = {sym.name:sym for sym in init_states}
+    state_names = states_map.keys()
+    data_syms = _as_list(input)
+    data_map = {sym.name:sym for sym in data_syms}
+    data_names = data_map.keys()
+    in_state_locs = []
+    in_data_locs = []
+    for in_name in g.list_inputs():
+        assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \
+                % (in_name, str(gin_names))
+        if (in_name in state_names):
+            ordered_ins.append(states_map[in_name])
+            in_state_locs.append(len(ordered_ins) - 1)
+        elif (in_name in data_names):
+            ordered_ins.append(data_map[in_name])
+            in_data_locs.append(len(ordered_ins) - 1)
+        else:
+            ordered_ins.append(input_syms[in_name])
+
+    num_outputs = len(flat_out)
+    num_states = len(state_names)
+    ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs,
+                                    num_out_data=num_out_data, in_state_locs=in_state_locs,
+                                    in_data_locs=in_data_locs)
+    if (num_outputs - num_states > 1):
 
 Review comment:
   No parentheses are needed around the `if` condition.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to