This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4021eec821 [fix] MXNet dot for all tensor dimensions (#11760)
4021eec821 is described below

commit 4021eec8214afb01daf3793552ca05702e642534
Author: Uroš Petković <[email protected]>
AuthorDate: Sun Jan 1 11:51:39 2023 +0100

    [fix] MXNet dot for all tensor dimensions (#11760)
    
    * [fix] MXNet dot for all tensor dimensions
    
    * Fixing the MxNet structure
---
 python/tvm/relay/frontend/mxnet.py          | 45 +++++++++++++++++++++++++----
 tests/python/frontend/mxnet/test_forward.py | 11 +++++++
 2 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 1b1d601199..4e6540fb08 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -795,19 +795,52 @@ def _mx_multibox_detection(inputs, attrs):
 
 def _mx_dot(inputs, attrs):
     assert len(inputs) == 2
-    a, b = inputs
+
+    a = inputs[0]
+    b = inputs[1]
+
     rank_a = len(_infer_type(a).checked_type.shape)
     rank_b = len(_infer_type(b).checked_type.shape)
-    if rank_a != 2 or rank_b != 2:
-        raise tvm.error.OpAttributeUnimplemented("Only 2-D arrays are supported.")
+
+    if rank_a < 1 or rank_b < 1:
+        raise tvm.error.OpAttributeInvalid("Unsupported shape of input tensors.")
+
     transpose_a = attrs.get_bool("transpose_a", False)
     transpose_b = attrs.get_bool("transpose_b", False)
+
     if transpose_a is True:
         msg = 'Value {} in attribute "transpose_a" of operator dot ' "is not valid."
         raise tvm.error.OpAttributeInvalid(msg.format(transpose_a))
-    if transpose_b is False:
-        b = _op.transpose(b, axes=[1, 0])
-    return _op.nn.dense(a, b)
+
+    # When performing dot product we need to properly handle shape of result -> out_shape
+    if rank_a == 1:
+        out_shape = list()
+        a = _op.expand_dims(a, axis=0)
+    else:
+        shape_a = list(_infer_type(a).checked_type.shape)
+        out_shape = shape_a[:-1]
+        a = _op.reshape(a, newshape=(-1, shape_a[-1]))
+
+    if rank_b == 1:
+        if not out_shape:
+            out_shape = [
+                1,
+            ]
+        b = _op.expand_dims(b, axis=1)
+    else:
+        # Transpose matrix b if needed
+        if transpose_b:
+            trans_axes = list(range(rank_b))
+            trans_axes = trans_axes[-1:] + trans_axes[:-1]
+            b = _op.transpose(b, axes=trans_axes)
+
+        shape_b = list(_infer_type(b).checked_type.shape)
+        out_shape += shape_b[1:]
+        b = _op.reshape(b, newshape=(shape_b[0], -1))
+
+    out = _op.reshape(_op.nn.matmul(a, b), newshape=out_shape)
+
+    return out
 
 
 def _mx_batch_dot(inputs, attrs):
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 44aa93061a..0e34719ea2 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -690,6 +690,17 @@ def test_forward_dot():
 
     verify((1, 256), (256, 1))
     verify((1, 256), (1, 256), transpose_b=True)
+    verify((5,), (5,))
+    verify((3,), (3, 5))
+    verify((3,), (5, 3), transpose_b=True)
+    verify((3,), (3, 5, 3, 5))
+    verify((3,), (5, 5, 3, 3), transpose_b=True)
+    verify((10, 1), (1,))
+    verify((1, 1), (4, 3, 2, 1), transpose_b=True)
+    verify((4, 3, 2, 1), (1,))
+    verify((1, 2, 3, 4), (1, 4), transpose_b=True)
+    verify((4, 1, 1), (1, 2, 3))
+    verify((1, 1, 4), (2, 3, 4), transpose_b=True)
 
 
 @tvm.testing.uses_gpu

Reply via email to