kevinthesun commented on a change in pull request #5306: [Torch] Support Python list, more realistic recurrent networks
URL: https://github.com/apache/incubator-tvm/pull/5306#discussion_r407269285
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1077,6 +1186,50 @@ def _impl(inputs, input_types):
         return _op.cast(inputs[0], "float32")
     return _impl
+
+def _mm():
+    def _impl(inputs, input_types):
+        return _op.nn.dense(inputs[0], inputs[1])
+    return _impl
+
+
+def _list_getitem(prelude):
+    def _impl(inputs, input_types):
+        return prelude.nth(inputs[0], _wrap_const(inputs[1]))
+    return _impl
+
+
+def _list_len(prelude):
+    def _impl(inputs, input_types):
+        return prelude.length(inputs[0])
+    return _impl
+
+
+def _add(prelude):
+    # add_ is overloaded for tensor add and list concat
+    def _impl(inputs, input_types):
+        if input_types[0] == "ListType":
+            return prelude.concat(inputs[0], inputs[1])
+        return _elemwise("add")(inputs, input_types)
+    return _impl
+
+
+def _tensor_array_stack(prelude):
+    def _impl(inputs, input_types):
+        tensor_array = _convert_to_tensor_array(inputs[0], prelude)
+        shape = get_tensor_array_shape(tensor_array, "float32", prelude)
+        stack = prelude.get_var_static('tensor_array_stack', "float32", shape)
Review comment:
I think we need to register the static tensor array ops first?
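A minimal, TVM-free sketch of the ordering this comment asks about, assuming that static (shape-specialized) tensor array ops live in the prelude as one global per (op name, dtype, shape) triple, and that get_var_static only looks such a global up rather than creating it. MockPrelude, MockStaticTensorArrayOps, is_registered and the name-mangling scheme are hypothetical stand-ins, not TVM code. In tvm.relay.prelude of this period the registration step appears to be StaticTensorArrayOps(prelude, dtype, shape).register(); treat that as an assumption to check against the prelude source.

```python
# Hypothetical, TVM-free model of shape-specialized ("static") tensor array ops.
# MockPrelude, MockStaticTensorArrayOps, is_registered and the name mangling
# are illustrative assumptions, not TVM's real implementation.


class MockPrelude:
    """Holds global definitions keyed by a mangled (name, dtype, shape) string."""

    def __init__(self):
        self.globals = {}

    @staticmethod
    def mangle(name, dtype, shape):
        # One global per (op, dtype, shape), e.g. "tensor_array_stack_float32_3_4".
        dims = "_".join(str(d) for d in shape)
        return f"{name}_{dtype}_{dims}"

    def get_var_static(self, name, dtype, shape):
        key = self.mangle(name, dtype, shape)
        if key not in self.globals:
            # Looking up an op that was never registered for this shape fails.
            raise KeyError(f"{key} is not registered")
        return self.globals[key]


class MockStaticTensorArrayOps:
    """Registers every static op for a single (dtype, shape) pair."""

    OPS = ("tensor_array_stack", "tensor_get_data")

    def __init__(self, prelude, dtype, shape):
        self.prelude, self.dtype, self.shape = prelude, dtype, tuple(shape)

    def register(self):
        for name in self.OPS:
            key = self.prelude.mangle(name, self.dtype, self.shape)
            # Registration is idempotent: a second call leaves the entry unchanged.
            self.prelude.globals.setdefault(key, f"<global {key}>")


def is_registered(prelude, name, dtype, shape):
    return prelude.mangle(name, dtype, shape) in prelude.globals


prelude = MockPrelude()
shape = (3, 4)

# Lookup before registration: mirrors calling get_var_static too early.
try:
    prelude.get_var_static("tensor_array_stack", "float32", shape)
    looked_up_before = True
except KeyError:
    looked_up_before = False

# Register the ops for this exact dtype/shape, then the lookup succeeds.
MockStaticTensorArrayOps(prelude, "float32", shape).register()
stack = prelude.get_var_static("tensor_array_stack", "float32", shape)
```

Under this model the early lookup fails, the same lookup succeeds once the ops for that exact dtype and shape are registered, and any other shape still needs its own registration.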