This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new 07d8e02367 [Unity][nnModule] Dynamic shape support in nn Module (#16284)
07d8e02367 is described below

commit 07d8e0236791e5d9069f2b4f9227bc33f77b328b
Author: Charlie Ruan <53290280+charliefr...@users.noreply.github.com>
AuthorDate: Sat Jan 13 03:44:01 2024 +0800

    [Unity][nnModule] Dynamic shape support in nn Module (#16284)

    * [Unity][nnModule] Dynamic shape support in nn Module
---
 python/tvm/relax/frontend/nn/core.py     | 15 +++++++++++----
 python/tvm/relax/frontend/nn/exporter.py | 21 ++++++++++++++++++---
 python/tvm/relax/frontend/nn/modules.py  | 11 ++++++++---
 3 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/core.py b/python/tvm/relax/frontend/nn/core.py
index 9c99ba6177..8eeffd8758 100644
--- a/python/tvm/relax/frontend/nn/core.py
+++ b/python/tvm/relax/frontend/nn/core.py
@@ -128,13 +128,15 @@ class Tensor(_TensorOp):
 
     @staticmethod
     def placeholder(
-        shape: Sequence[Union[int, tir.PrimExpr]],
+        shape: Sequence[Union[int, str, tir.PrimExpr]],
         dtype: str,
         name: str = "tensor",
     ) -> "Tensor":
         """Create a placeholder tensor with given shape and dtype. A placeholder tensor should
         never be created directly by users in usual cases, and the only exception is to indicate
         the shape/dtype of return values of an external function.
+
+        If shape is a string `name`, we create a symbolic shape `tvm.tir.Var(name, "int64")`.
         """
         new_shape = []
         for expr in shape:
@@ -143,6 +145,10 @@ class Tensor(_TensorOp):
                 assert expr >= 0
                 new_shape.append(expr)
                 continue
+            if isinstance(expr, str):
+                expr = tir.Var(expr, "int64")
+                new_shape.append(expr)
+                continue
             if not isinstance(expr, tir.PrimExpr):
                 raise TypeError(f"Invalid shape: {shape}")
             assert expr.dtype == "int64"
@@ -214,7 +220,7 @@ class Parameter(Tensor):
 
     def __init__(
         self,
-        shape: Sequence[Union[int, tir.PrimExpr]],
+        shape: Sequence[Union[int, str, tir.PrimExpr]],
         dtype: Optional[str] = None,
     ) -> None:
         """Create a parameter with given shape and dtype. The parameter is not bound to any
@@ -222,8 +228,9 @@ class Parameter(Tensor):
 
         Parameters
         ----------
-        shape : Sequence[Union[int, tir.PrimExpr]]
-            The shape of the parameter
+        shape : Sequence[Union[int, str, tir.PrimExpr]]
+            The shape of the parameter. If it is a string `name`, we create a symbolic shape
+            `tvm.tir.Var(name, "int64")`.
         dtype : Optional[str]
             The data type of the parameter. If not specified, the default dtype will be used.
""" diff --git a/python/tvm/relax/frontend/nn/exporter.py b/python/tvm/relax/frontend/nn/exporter.py index 99591c8a3e..d452af69d3 100644 --- a/python/tvm/relax/frontend/nn/exporter.py +++ b/python/tvm/relax/frontend/nn/exporter.py @@ -111,8 +111,7 @@ class Exporter: return result # pylint: enable=protected-access - - params = _params() + params = None effects = _effects() ext_mods = self.extern_mods with self: @@ -122,6 +121,7 @@ class Exporter: outputs = _emit_effect_init(self.builder, effects) self.builder.emit_func_output(outputs, params=[]) for method_name, method_spec in zip(spec.method_names, spec.method_specs): + params = _params() # Re-initialize so symbolic shapes not shared across methods len_args = len(method_spec.arg_specs) len_effects = { "packed": 1, @@ -159,6 +159,9 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many- effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]], ): # pylint: disable=protected-access + # symbolic shape's name mapping to its tir.Var for reuse + str2var_params: typing.Dict[str, tir.Var] = {} + def _unwrap_ret(expr: typing.Any) -> typing.Any: if isinstance(expr, (core.Tensor, core.Object)): return expr._expr @@ -184,8 +187,20 @@ def _emit_method( # pylint: disable=too-many-locals,too-many-branches,too-many- def _params(mode: str) -> typing.List[rx.Var]: inputs: typing.List[rx.Var] = [] + + def _get_var(shape_var: tir.Var) -> tir.Var: + name = shape_var.name + if name in str2var_params: + return str2var_params[name] + var = tir.Var(name, "int64") + str2var_params[name] = var + return var + for name, param in params: - var = core.Tensor.placeholder(param.shape, param.dtype, name)._expr + # Make sure the a symbolic shape is not re-registered (same as _method_spec_to_inputs) + # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens` + new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape] + var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr inputs.append(var) param._expr = var if mode == "none": diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py index 03d6a06994..f1b785b51a 100644 --- a/python/tvm/relax/frontend/nn/modules.py +++ b/python/tvm/relax/frontend/nn/modules.py @@ -99,8 +99,8 @@ class Linear(Module): def __init__( self, - in_features: int, - out_features: int, + in_features: Union[int, str, tir.PrimExpr], + out_features: Union[int, str, tir.PrimExpr], bias: bool = True, dtype: Optional[str] = None, out_dtype: Optional[str] = None, @@ -617,7 +617,12 @@ class Embedding(Module): Module for embedding layer. """ - def __init__(self, num: int, dim: int, dtype: Optional[str] = None): + def __init__( + self, + num: Union[int, str, tir.PrimExpr], + dim: Union[int, str, tir.PrimExpr], + dtype: Optional[str] = None, + ): self.num = num self.dim = dim self.weight = Parameter((num, dim), dtype=dtype)