masahi commented on a change in pull request #5306: [Torch] Support Python
list, more realistic recurrent networks
URL: https://github.com/apache/incubator-tvm/pull/5306#discussion_r407273147
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1077,6 +1186,50 @@ def _impl(inputs, input_types):
return _op.cast(inputs[0], "float32")
return _impl
+
+def _mm():
+ def _impl(inputs, input_types):
+ return _op.nn.dense(inputs[0], inputs[1])
+ return _impl
+
+
+def _list_getitem(prelude):
+ def _impl(inputs, input_types):
+ return prelude.nth(inputs[0], _wrap_const(inputs[1]))
+ return _impl
+
+
+def _list_len(prelude):
+ def _impl(inputs, input_types):
+ return prelude.length(inputs[0])
+ return _impl
+
+
+def _add(prelude):
+ # add_ is overloaded for tensor add and list concat
+ def _impl(inputs, input_types):
+ if input_types[0] == "ListType":
+ return prelude.concat(inputs[0], inputs[1])
+ return _elemwise("add")(inputs, input_types)
+ return _impl
+
+
+def _tensor_array_stack(prelude):
+ def _impl(inputs, input_types):
+ tensor_array = _convert_to_tensor_array(inputs[0], prelude)
+ shape = get_tensor_array_shape(tensor_array, "float32", prelude)
+ stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
Review comment:
The tensor array ops are registered in `_convert_to_tensor_array` when the tensor array is created.
Since a tensor array is created only through this function, I think this is fine, and
cleaner than adding `get_shape`, `StaticTensorArrayOps(...)`, and `register()` calls here.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services