This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2dcf9ec5a6 [Keras] Enable Dense operator for any input dims (#16526)
2dcf9ec5a6 is described below

commit 2dcf9ec5a6d6873cee0461d4c3f3e6990916e020
Author: Egor Churaev <egor.chur...@gmail.com>
AuthorDate: Wed Feb 7 13:40:21 2024 +0300

    [Keras] Enable Dense operator for any input dims (#16526)
    
    Our dense op expects 2D input, but Keras places no limitations on the
    shape of the input tensor. This commit adds reshaping of all "batch"
    axes into one, so a Dense layer with an ND input tensor can now be
    imported from Keras to TVM.
---
 python/tvm/relay/frontend/keras.py          | 14 ++++++++------
 tests/python/frontend/keras/test_forward.py | 10 ++++++++++
 2 files changed, 18 insertions(+), 6 deletions(-)
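
The reshape trick described in the commit message can be sketched with plain
NumPy; this is a minimal illustration only, and the shapes and names below are
assumptions rather than anything taken from the patch itself:

    import numpy as np

    x = np.random.rand(1, 10, 12, 2560).astype("float32")  # ND input: (N, d1, d2, d3)
    w = np.random.rand(32, 2560).astype("float32")          # Dense weight: (units, in_features)

    new_batch = int(np.prod(x.shape[:-1]))                  # collapse N * d1 * d2 into one axis
    y = x.reshape(new_batch, x.shape[-1]) @ w.T             # 2D dense: (N*d1*d2, units)
    y = y.reshape(*x.shape[:-1], 32)                        # restore batch axes: (N, d1, d2, units)
    assert y.shape == (1, 10, 12, 32)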

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index 2186208994..d53647cc68 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -266,11 +266,12 @@ def _convert_dense(
     # In case of RNN dense, input shape will be (1, 1, n)
     if input_dim > 2:
         input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
-        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
-            raise tvm.error.OpAttributeInvalid(
-                f"Input shape {input_shape} is not valid for operator Dense."
-            )
-        inexpr = _op.squeeze(inexpr, axis=[0])
+        # Keras has no limitations on the shape of the input tensor, but our
+        # dense op expects 2D input. For inputs with more than 2 dimensions,
+        # all "batch" axes are reshaped into one.
+        # For example: (N, d1, d2, d3) -> (N * d1 * d2, d3)
+        new_batch_size = np.prod(input_shape[:-1])
+        inexpr = _op.reshape(inexpr, newshape=(new_batch_size, input_shape[-1]))
     out = _op.nn.dense(data=inexpr, **params)
     if keras_layer.use_bias:
         bias = etab.new_const(weightList[1])
@@ -283,7 +284,8 @@ def _convert_dense(
     if act_type != "linear":
         out = _convert_activation(out, act_type, etab, data_layout)
     if input_dim > 2:
-        out = _op.expand_dims(out, axis=0)
+        out_shape = (*input_shape[:-1], units)
+        out = _op.reshape(out, newshape=out_shape)
     return out
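
For reference, here is a hedged sketch of the Relay expression the converter now
builds for an ND Dense input; the variable names and shapes are illustrative
assumptions, not the frontend's exact code:

    import numpy as np
    import tvm
    from tvm import relay

    input_shape = (1, 10, 12, 2560)   # illustrative ND input shape
    units = 32

    data = relay.var("data", shape=input_shape, dtype="float32")
    weight = relay.var("weight", shape=(units, input_shape[-1]), dtype="float32")

    new_batch_size = int(np.prod(input_shape[:-1]))                    # collapse all "batch" axes
    x = relay.reshape(data, newshape=(new_batch_size, input_shape[-1]))
    x = relay.nn.dense(x, weight, units=units)                         # 2D dense
    out = relay.reshape(x, newshape=(*input_shape[:-1], units))        # restore batch axes
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
    print(mod)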
 
 
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index aef137e634..0d05e34a15 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -285,6 +285,16 @@ class TestKeras:
         keras_model = keras_mod.models.Model(data, x)
         verify_keras_frontend(keras_model, need_transpose=False)
 
+        data = keras_mod.layers.Input(shape=(120, 2560), name="image_set")
+        x = keras_mod.layers.Dense(1, activation="linear", name="e")(data)
+        keras_model = keras_mod.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+        data = keras_mod.layers.Input(shape=(10, 12, 2560), name="image_set")
+        x = keras_mod.layers.Dense(32, activation="linear", name="e")(data)
+        keras_model = keras_mod.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
     def test_forward_permute(self, keras_mod):
         data = keras_mod.layers.Input(shape=(2, 3, 4))
         x = keras_mod.layers.Permute([2, 3, 1])(data)
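
An end-to-end sketch mirroring the new test cases, importing an ND-input Dense
model through the Keras frontend; the layout and shape-dict arguments below are
assumptions, not the test helper's exact settings:

    import tensorflow.keras as keras
    from tvm import relay

    data = keras.layers.Input(shape=(10, 12, 2560), name="image_set")
    x = keras.layers.Dense(32, activation="linear", name="e")(data)
    keras_model = keras.models.Model(data, x)

    # The batch dimension is made explicit in the shape dict passed to the frontend.
    shape_dict = {"image_set": (1, 10, 12, 2560)}
    mod, params = relay.frontend.from_keras(keras_model, shape=shape_dict, layout="NHWC")
    print(mod)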
