reminisce commented on a change in pull request #10451: [WIP] Add Foreach
URL: https://github.com/apache/incubator-mxnet/pull/10451#discussion_r188506820
########## File path: python/mxnet/symbol/contrib.py ########## @@ -91,3 +98,99 @@ def rand_zipfian(true_classes, num_sampled, range_max): expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range expected_count_sampled = expected_prob_sampled * num_sampled return sampled_classes, expected_count_true, expected_count_sampled + +def _get_graph_inputs(subg, name, prefix): + num_handles = ctypes.c_int(1000) + handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)]) + check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles, + ctypes.byref(num_handles))) + + syms = [] + for i in range(num_handles.value): + s = Symbol(handles[i]) + syms.append(s) + return syms + +def foreach(func, input, init_states, back_prop=False, name="foreach"): + assert isinstance(init_states, list), "init_states should be a list" + states = [] + with AttrScope(subgraph_name=name): + if isinstance(input, list): + in_eles = [symbol.var(sym.name) for sym in input] + else: + in_eles = symbol.var(input.name) + for s in init_states: + states.append(symbol.var(s.name)) + + sym_out = func(in_eles, states) + # The function should return a tuple. The first element goes to + # the output of the function. The second element is a list. + assert isinstance(sym_out, tuple), "func should return a tuple (out, states)" + assert isinstance(sym_out[1], list), \ + "the second element in the returned tuple should be a list" + assert len(sym_out[1]) == len(init_states), \ + "the number of output states (%d) should be the same as input states (%d)" \ + % (len(sym_out[1]), len(init_states)) + + if (isinstance(sym_out[0], list)): + flat_out = sym_out[0] + else: + flat_out = [sym_out[0]] + num_out_data = len(flat_out) + for s in sym_out[1]: + # There is a problem if the outputs are the same as the inputs + # or the first output. + # TODO this is a temp fix. 
+ flat_out.append(symbol.op.identity(s)) + g = symbol.Group(flat_out) + input_syms = _get_graph_inputs(g, name, "ro_var") + + if (isinstance(input, list)): + num_inputs = len(input) + else: + num_inputs = 1 + + # Here we need to find out how the input symbols are ordered as well as + # where the loop states are located in the list of inputs. + + # This dict contains the symbols of the subgraph. + input_syms = {sym.name:sym for sym in input_syms} + gin_names = input_syms.keys() + # This array contains the symbols for the inputs of foreach. + # They are ordered according to the inputs of the subgraph. + ordered_ins = [] + states_map = {sym.name:sym for sym in init_states} + state_names = states_map.keys() + data_syms = _as_list(input) + data_map = {sym.name:sym for sym in data_syms} + data_names = data_map.keys() + in_state_locs = [] + in_data_locs = [] + for in_name in g.list_inputs(): + assert in_name in gin_names, "The input variable %s can't be found in graph inputs: %s" \ + % (in_name, str(gin_names)) + if (in_name in state_names): + ordered_ins.append(states_map[in_name]) + in_state_locs.append(len(ordered_ins) - 1) + elif (in_name in data_names): + ordered_ins.append(data_map[in_name]) + in_data_locs.append(len(ordered_ins) - 1) + else: + ordered_ins.append(input_syms[in_name]) + + num_outputs = len(flat_out) + num_states = len(state_names) + ret = symbol._internal._foreach(g, *ordered_ins, num_outputs=num_outputs, + num_out_data=num_out_data, in_state_locs=in_state_locs, + in_data_locs=in_data_locs) + if (num_outputs - num_states > 1): Review comment: No parentheses. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services