This is an automated email from the ASF dual-hosted git repository. mbrookhart pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new ec6a817  [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)
ec6a817 is described below

commit ec6a817eaed246ffcf925f295b587cfc0af15035
Author: Rohan Mukherjee <mukro...@amazon.com>
AuthorDate: Wed Jun 16 14:05:33 2021 -0700

    [Frontend, Tensorflow] Support for broadcasting in batch_matmul when shapes differ (#8251)

    * Support for broadcasting in batch_matmul when shapes differ

    * refactor

    * refactor logic for reshape in conditional

    * refactor
---
 python/tvm/relay/frontend/tensorflow_ops.py      | 16 +++++++++-------
 tests/python/frontend/tensorflow/test_forward.py | 17 +++++++++++++++++
 2 files changed, 26 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/tensorflow_ops.py b/python/tvm/relay/frontend/tensorflow_ops.py
index c738556..3c4a9b6 100644
--- a/python/tvm/relay/frontend/tensorflow_ops.py
+++ b/python/tvm/relay/frontend/tensorflow_ops.py
@@ -1132,22 +1132,23 @@ def _batch_matmul():
         orig_shape_x = _infer_shape(input_x, mod)
         orig_shape_y = _infer_shape(input_y, mod)
         ndim = len(orig_shape_x)
+        ndim_y = len(orig_shape_y)
 
         is_static = not check_symbolic_shape(orig_shape_x)
 
-        if ndim > 3 and not is_static:
-            shape_of_x = list_shape_of(inputs[0], ndim)
-            shape_of_y = list_shape_of(inputs[1], ndim)
-
         # reshape n-dimensional batch matmul into 3d
         if ndim > 3:
             outer_dims = [orig_shape_x[i] for i in range(0, len(orig_shape_x) - 2)]
             if is_static:
                 num_outer_elts = np.prod(outer_dims)
                 new_shape_x = (num_outer_elts, orig_shape_x[-2], orig_shape_x[-1])
-                new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+                if ndim_y > 2:
+                    new_shape_y = (num_outer_elts, orig_shape_y[-2], orig_shape_y[-1])
+                elif ndim_y == 2:
+                    new_shape_y = (1, orig_shape_y[-2], orig_shape_y[-1])
             else:  # handle dynamic shape (dyn.reshape op)
-                # new shape = [prod(shape[:-2]), -2, -1]
+                shape_of_x = list_shape_of(inputs[0], ndim)
+                shape_of_y = list_shape_of(inputs[1], ndim)
                 new_shape_x = [_op.const(1), shape_of_x[-2], shape_of_x[-1]]
                 new_shape_y = [_op.const(1), shape_of_y[-2], shape_of_y[-1]]
                 for i in range(ndim - 2):
@@ -1158,7 +1159,8 @@ def _batch_matmul():
 
             input_x = _op.reshape(input_x, newshape=new_shape_x)
             input_y = _op.reshape(input_y, newshape=new_shape_y)
-
+        elif ndim_y == 2:
+            input_y = _op.reshape(input_y, (1, orig_shape_y[-2], orig_shape_y[-1]))
         adj_x = attr["adj_x"]
         adj_y = attr["adj_y"]
         input_x = _op.transpose(input_x, axes=[0, 2, 1]) if adj_x else input_x
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 3315533..57497d0 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -1843,6 +1843,9 @@ def test_forward_batch_matmul():
     _test_batch_matmul((1, 2, 3, 4, 5, 6), (1, 2, 3, 4, 6, 5), "float32", True, True)
     _test_batch_matmul((3, 4, 5, 6), (3, 4, 5, 6), "int32", True, False)
     _test_batch_matmul((2, 3, 4, 2, 3, 4, 5, 6), (2, 3, 4, 2, 3, 4, 5, 6), "float32", False, True)
+    _test_batch_matmul((1, 8, 64, 2), (2, 1), "float32", False, False)
+    _test_batch_matmul((1, 8, 8, 64), (64, 1), "float32", False, False)
+    _test_batch_matmul((1, 8, 64), (64, 1), "float32", False, False)
 
 
 @tvm.testing.requires_cuda
@@ -1870,6 +1873,20 @@ def test_forward_batch_matmul_dynamic():
         (2, 3, 4, 6, 5),
         "float32",
     )
+    _test_batch_matmul_dynamic(
+        (None, None, None, 5, 6),
+        (6, None),
+        (2, 3, 4, 5, 6),
+        (6, 1),
+        "float32",
+    )
+    _test_batch_matmul_dynamic(
+        (None, 5, 6),
+        (6, None),
+        (24, 5, 6),
+        (6, 1),
+        "float32",
+    )
 
 
 #######################################################################