This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4021eec821 [fix] MXNet dot for all tensor dimensions (#11760)
4021eec821 is described below
commit 4021eec8214afb01daf3793552ca05702e642534
Author: Uroš Petković <[email protected]>
AuthorDate: Sun Jan 1 11:51:39 2023 +0100
[fix] MXNet dot for all tensor dimensions (#11760)
* [fix] MXNet dot for all tensor dimensions
* Fixing the MxNet structure
---
python/tvm/relay/frontend/mxnet.py | 45 +++++++++++++++++++++++++----
tests/python/frontend/mxnet/test_forward.py | 11 +++++++
2 files changed, 50 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 1b1d601199..4e6540fb08 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -795,19 +795,52 @@ def _mx_multibox_detection(inputs, attrs):
def _mx_dot(inputs, attrs):
assert len(inputs) == 2
- a, b = inputs
+
+ a = inputs[0]
+ b = inputs[1]
+
rank_a = len(_infer_type(a).checked_type.shape)
rank_b = len(_infer_type(b).checked_type.shape)
- if rank_a != 2 or rank_b != 2:
- raise tvm.error.OpAttributeUnimplemented("Only 2-D arrays are supported.")
+
+ if rank_a < 1 or rank_b < 1:
+ raise tvm.error.OpAttributeInvalid("Unsupported shape of input tensors.")
+
transpose_a = attrs.get_bool("transpose_a", False)
transpose_b = attrs.get_bool("transpose_b", False)
+
if transpose_a is True:
msg = 'Value {} in attribute "transpose_a" of operator dot ' "is not valid."
raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
- if transpose_b is False:
- b = _op.transpose(b, axes=[1, 0])
- return _op.nn.dense(a, b)
+
+ # When performing dot product we need to properly handle shape of result -> out_shape
+ if rank_a == 1:
+ out_shape = list()
+ a = _op.expand_dims(a, axis=0)
+ else:
+ shape_a = list(_infer_type(a).checked_type.shape)
+ out_shape = shape_a[:-1]
+ a = _op.reshape(a, newshape=(-1, shape_a[-1]))
+
+ if rank_b == 1:
+ if not out_shape:
+ out_shape = [
+ 1,
+ ]
+ b = _op.expand_dims(b, axis=1)
+ else:
+ # Transpose matrix b if needed
+ if transpose_b:
+ trans_axes = list(range(rank_b))
+ trans_axes = trans_axes[-1:] + trans_axes[:-1]
+ b = _op.transpose(b, axes=trans_axes)
+
+ shape_b = list(_infer_type(b).checked_type.shape)
+ out_shape += shape_b[1:]
+ b = _op.reshape(b, newshape=(shape_b[0], -1))
+
+ out = _op.reshape(_op.nn.matmul(a, b), newshape=out_shape)
+
+ return out
def _mx_batch_dot(inputs, attrs):
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 44aa93061a..0e34719ea2 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -690,6 +690,17 @@ def test_forward_dot():
verify((1, 256), (256, 1))
verify((1, 256), (1, 256), transpose_b=True)
+ verify((5,), (5,))
+ verify((3,), (3, 5))
+ verify((3,), (5, 3), transpose_b=True)
+ verify((3,), (3, 5, 3, 5))
+ verify((3,), (5, 5, 3, 3), transpose_b=True)
+ verify((10, 1), (1,))
+ verify((1, 1), (4, 3, 2, 1), transpose_b=True)
+ verify((4, 3, 2, 1), (1,))
+ verify((1, 2, 3, 4), (1, 4), transpose_b=True)
+ verify((4, 1, 1), (1, 2, 3))
+ verify((1, 1, 4), (2, 3, 4), transpose_b=True)
@tvm.testing.uses_gpu