This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 07d8e02367 [Unity][nnModule] Dynamic shape support in nn Module (#16284)
07d8e02367 is described below

commit 07d8e0236791e5d9069f2b4f9227bc33f77b328b
Author: Charlie Ruan <53290280+charliefr...@users.noreply.github.com>
AuthorDate: Sat Jan 13 03:44:01 2024 +0800

    [Unity][nnModule] Dynamic shape support in nn Module (#16284)
    
    * [Unity][nnModule] Dynamic shape support in nn Module
---
 python/tvm/relax/frontend/nn/core.py     | 15 +++++++++++----
 python/tvm/relax/frontend/nn/exporter.py | 21 ++++++++++++++++++---
 python/tvm/relax/frontend/nn/modules.py  | 11 ++++++++---
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 9c99ba6177..8eeffd8758 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -128,13 +128,15 @@ class Tensor(_TensorOp):
 
     @staticmethod
     def placeholder(
-        shape: Sequence[Union[int, tir.PrimExpr]],
+        shape: Sequence[Union[int, str, tir.PrimExpr]],
         dtype: str,
         name: str = "tensor",
     ) -> "Tensor":
         """Create a placeholder tensor with given shape and dtype. A 
placeholder tensor should
         never be created directly by users in usual cases, and the only 
exception is to indicate
         the shape/dtype of return values of an external function.
+
+        If an element of `shape` is a string `name`, we create a symbolic shape `tvm.tir.Var(name, "int64")`.
         """
         new_shape = []
         for expr in shape:
@@ -143,6 +145,10 @@ class Tensor(_TensorOp):
                 assert expr >= 0
                 new_shape.append(expr)
                 continue
+            if isinstance(expr, str):
+                expr = tir.Var(expr, "int64")
+                new_shape.append(expr)
+                continue
             if not isinstance(expr, tir.PrimExpr):
                 raise TypeError(f"Invalid shape: {shape}")
             assert expr.dtype == "int64"
@@ -214,7 +220,7 @@ class Parameter(Tensor):
 
     def __init__(
         self,
-        shape: Sequence[Union[int, tir.PrimExpr]],
+        shape: Sequence[Union[int, str, tir.PrimExpr]],
         dtype: Optional[str] = None,
     ) -> None:
         """Create a parameter with given shape and dtype. The parameter is not 
bound to any
@@ -222,8 +228,9 @@ class Parameter(Tensor):
 
         Parameters
         ----------
-        shape : Sequence[Union[int, tir.PrimExpr]]
-            The shape of the parameter
+        shape : Sequence[Union[int, str, tir.PrimExpr]]
+            The shape of the parameter. If an element is a string `name`, we create a symbolic shape
+            `tvm.tir.Var(name, "int64")`.
         dtype : Optional[str]
             The data type of the parameter. If not specified, the default dtype will be used.
         """
diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py
index 99591c8a3e..d452af69d3 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -111,8 +111,7 @@ class Exporter:
             return result
 
         # pylint: enable=protected-access
-
-        params = _params()
+        params = None
         effects = _effects()
         ext_mods = self.extern_mods
         with self:
@@ -122,6 +121,7 @@ class Exporter:
                         outputs = _emit_effect_init(self.builder, effects)
                     self.builder.emit_func_output(outputs, params=[])
             for method_name, method_spec in zip(spec.method_names, spec.method_specs):
+                params = _params()  # Re-initialize so symbolic shapes are not shared across methods
                 len_args = len(method_spec.arg_specs)
                 len_effects = {
                     "packed": 1,
@@ -159,6 +159,9 @@ def _emit_method(  # pylint: disable=too-many-locals,too-many-branches,too-many-
     effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
 ):
     # pylint: disable=protected-access
+    # Map each symbolic shape name to its tir.Var so the same Var is reused
+    str2var_params: typing.Dict[str, tir.Var] = {}
+
     def _unwrap_ret(expr: typing.Any) -> typing.Any:
         if isinstance(expr, (core.Tensor, core.Object)):
             return expr._expr
@@ -184,8 +187,20 @@ def _emit_method(  # pylint: disable=too-many-locals,too-many-branches,too-many-
 
     def _params(mode: str) -> typing.List[rx.Var]:
         inputs: typing.List[rx.Var] = []
+
+        def _get_var(shape_var: tir.Var) -> tir.Var:
+            name = shape_var.name
+            if name in str2var_params:
+                return str2var_params[name]
+            var = tir.Var(name, "int64")
+            str2var_params[name] = var
+            return var
+
         for name, param in params:
-            var = core.Tensor.placeholder(param.shape, param.dtype, name)._expr
+            # Make sure a symbolic shape is not re-registered (same as _method_spec_to_inputs),
+            # e.g. we do not end up with `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens`
+            new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape]
+            var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
             inputs.append(var)
             param._expr = var
         if mode == "none":
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index 03d6a06994..f1b785b51a 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -99,8 +99,8 @@ class Linear(Module):
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
+        in_features: Union[int, str, tir.PrimExpr],
+        out_features: Union[int, str, tir.PrimExpr],
         bias: bool = True,
         dtype: Optional[str] = None,
         out_dtype: Optional[str] = None,
@@ -617,7 +617,12 @@ class Embedding(Module):
     Module for embedding layer.
     """
 
-    def __init__(self, num: int, dim: int, dtype: Optional[str] = None):
+    def __init__(
+        self,
+        num: Union[int, str, tir.PrimExpr],
+        dim: Union[int, str, tir.PrimExpr],
+        dtype: Optional[str] = None,
+    ):
         self.num = num
         self.dim = dim
         self.weight = Parameter((num, dim), dtype=dtype)
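
Putting the pieces together, a hedged end-to-end sketch: a module whose weights carry a symbolic `vocab_size` dimension, exported with a symbolic `seq_len` input. The `export_tvm`/`nn.spec` usage is assumed from the surrounding nn frontend, not shown in this diff.

    from tvm.relax.frontend import nn

    class TinyLM(nn.Module):
        def __init__(self):
            # "vocab_size" stays symbolic in both parameter shapes and, via
            # the exporter change above, maps to a single tir.Var.
            self.embed_tokens = nn.Embedding("vocab_size", 64, dtype="float32")
            self.lm_head = nn.Linear(64, "vocab_size", bias=False, dtype="float32")

        def forward(self, tokens: nn.Tensor) -> nn.Tensor:
            return self.lm_head(self.embed_tokens(tokens))

    # Export with a symbolic sequence length as well (assumed spec API).
    mod, named_params = TinyLM().export_tvm(
        spec={"forward": {"tokens": nn.spec.Tensor(("seq_len",), "int32")}}
    )
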
