reminisce commented on a change in pull request #10451: [WIP] Add Foreach URL: https://github.com/apache/incubator-mxnet/pull/10451#discussion_r188504970
########## File path: python/mxnet/symbol/contrib.py ########## @@ -91,3 +98,99 @@ def rand_zipfian(true_classes, num_sampled, range_max): expected_prob_sampled = ((sampled_cls_fp64 + 2.0) / (sampled_cls_fp64 + 1.0)).log() / log_range expected_count_sampled = expected_prob_sampled * num_sampled return sampled_classes, expected_count_true, expected_count_sampled + +def _get_graph_inputs(subg, name, prefix): + num_handles = ctypes.c_int(1000) + handles = c_array(SymbolHandle, [SymbolHandle(0) for i in range(1000)]) + check_call(_LIB.MXSymbolGetInputSymbols(subg.handle, handles, + ctypes.byref(num_handles))) + + syms = [] + for i in range(num_handles.value): + s = Symbol(handles[i]) + syms.append(s) + return syms + +def foreach(func, input, init_states, back_prop=False, name="foreach"): + assert isinstance(init_states, list), "init_states should be a list" + states = [] + with AttrScope(subgraph_name=name): + if isinstance(input, list): + in_eles = [symbol.var(sym.name) for sym in input] + else: + in_eles = symbol.var(input.name) + for s in init_states: + states.append(symbol.var(s.name)) + + sym_out = func(in_eles, states) + # The function should return a tuple. The first element goes to + # the output of the function. The second element is a list. + assert isinstance(sym_out, tuple), "func should return a tuple (out, states)" + assert isinstance(sym_out[1], list), \ + "the second element in the returned tuple should be a list" + assert len(sym_out[1]) == len(init_states), \ + "the number of output states (%d) should be the same as input states (%d)" \ + % (len(sym_out[1]), len(init_states)) + + if (isinstance(sym_out[0], list)): Review comment: No parentheses needed. It would result in a coding style error in PyCharm. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services