This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git
commit c21a7ddfccb57909faaeac127a99784867571108 Author: Yuchen Jin <yuch...@cs.washington.edu> AuthorDate: Sat Feb 11 09:26:36 2023 -0800 [Unity] e2e Relax minimum build flow (#13961) This PR introduces the e2e Relax lowering flow (`relax.vm.build`). Tests for each pass in the flow are added. Co-Authored-by: Altan Haan <alt...@cs.washington.edu> Co-Authored-by: Andrew Liu <andrewl...@gmail.com> Co-Authored-by: Hongyi Jin <3231950...@qq.com> Co-Authored-by: Jiawei Liu <jaway....@gmail.com> Co-Authored-by: Junru Shao <junrushao1...@gmail.com> Co-Authored-by: Prakalp Srivastava <prak...@octoml.ai> Co-Authored-by: Ruihang Lai <ruiha...@cs.cmu.edu> Co-Authored-by: Siyuan Feng <hzfen...@sjtu.edu.cn> Co-Authored-by: Steven S. <Lyubomirsky slyubomir...@octoml.ai> Co-Authored-by: Sunghyun Park <49998730+sun...@users.noreply.github.com> Co-Authored-by: Tianqi Chen <tianqi.tc...@gmail.com> Co-Authored-by: Yong Wu <yongc...@gmail.com> Co-Authored-by: Ziheng Jiang <zih...@apache.org> --- CMakeLists.txt | 1 + include/tvm/relax/analysis.h | 16 + include/tvm/relax/backend.h | 7 + include/tvm/relax/transform.h | 35 + python/tvm/relax/analysis/analysis.py | 45 + python/tvm/relax/op/__init__.py | 3 + python/tvm/relax/op/{ => builtin}/__init__.py | 6 +- .../relax/op/{__init__.py => builtin/_ffi_api.py} | 9 +- python/tvm/relax/op/builtin/builtin.py | 44 + python/tvm/relax/op/manipulate.py | 62 ++ python/tvm/relax/op/{ => memory}/__init__.py | 6 +- .../relax/op/{__init__.py => memory/_ffi_api.py} | 9 +- python/tvm/relax/op/memory/memory.py | 108 +++ python/tvm/relax/{op => testing}/__init__.py | 6 +- python/tvm/relax/testing/nn.py | 194 +++++ python/tvm/relax/transform/transform.py | 53 ++ python/tvm/relax/vm.py | 4 +- python/tvm/script/ir_builder/relax/ir.py | 6 + src/relax/analysis/tir_op_pattern_kind.cc | 447 ++++++++++ src/relax/backend/vm/vm_builtin_lower.cc | 208 +++++ src/relax/op/op.cc | 81 ++ src/relax/op/tensor/manipulate.cc | 163 ++++ .../backend.h => src/relax/op/tensor/manipulate.h | 25 +- src/relax/transform/attach_global_symbol.cc | 68 ++ src/relax/transform/call_tir_rewrite.cc | 137 ++++ src/relax/transform/rewrite_dataflow_reshape.cc | 110 +++ src/relax/transform/to_non_dataflow.cc | 67 ++ tests/python/relax/test_analysis.py | 172 ++++ tests/python/relax/test_transform.py | 141 ++++ .../relax/test_transform_attach_global_symbol.py | 88 ++ .../test_transform_rewrite_dataflow_reshape.py | 166 ++++ tests/python/relax/test_vm_build.py | 908 +++++++++++++++++++++ 32 files changed, 3358 insertions(+), 37 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index eecd67be94..d0470677e1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -292,6 +292,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/relax/ir/*.cc src/relax/op/*.cc src/relax/analysis/*.cc + src/relax/transform/*.cc src/relax/backend/vm/*.cc src/relax/utils.cc ) diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h index ad2bd19aa4..24cfe5b9bf 100644 --- a/include/tvm/relax/analysis.h +++ b/include/tvm/relax/analysis.h @@ -259,6 +259,22 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived, */ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs, arith::Analyzer* ana = nullptr); + +/*! + * \brief Check if the given PrimFunc is essentially doing a reshape operation. + * The reshape operation also includes expand_dims, squeeze, flatten, etc. 
+ * \details Here the allowed reshape pattern is: for example, assume the operation is + * `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove that the flattened + * index of l_0, ..., l_b under buffer B equals to the flattened index of r_0, ..., r_a under + * buffer A. + * \param func The function to be examined. + * \return A boolean indicating if the given PrimFunc is doing a reshape. + * \note According to the description above, the returned result can only be false-negative and + * cannot be false-positive, since whenever we cannot prove the equality, we return false. This + * property guarantees the safety of this function. + */ +TVM_DLL bool HasReshapePattern(const tir::PrimFunc& func); + } // namespace relax } // namespace tvm diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h index 4ebeacac0f..2fb11f5a6f 100644 --- a/include/tvm/relax/backend.h +++ b/include/tvm/relax/backend.h @@ -30,6 +30,13 @@ namespace tvm { namespace relax { namespace transform { +/*! + * \brief Perform builtin lowering to map most of the op to VM builtin functions. + * + * \return The Pass. + */ +TVM_DLL Pass VMBuiltinLower(); + /*! * \brief Lower the shape expression in relax to VM shape heap and TIR functions. * diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h index fa288a7f06..ff98b16d25 100644 --- a/include/tvm/relax/transform.h +++ b/include/tvm/relax/transform.h @@ -65,6 +65,41 @@ TVM_DLL Pass CreateDataflowBlockPass( const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)>& pass_func, int opt_level, String name, tvm::Array<String> required); +/*! + * \brief Transform all dataflow structure to non-dataflow version. + * + * \return The Pass. + */ +TVM_DLL Pass ToNonDataflow(); + +/*! + * \brief Perform explicit tensor allocation for call_tir. + * + * \return The Pass. + */ +TVM_DLL Pass CallTIRRewrite(); + +/*! + * \brief Convert all reshape-like call_tir whose corresponding binding + * vars are DataflowVars to relax.reshape operator calls. The relax.reshape + * calls will be lowered an external builtin function call in a subsequent + * pass, where the external builtin function does a CreateView operation + * at runtime, instead of doing real data copy. + * Here "reshape-like" includes reshape, expand_dims, flatten, etc. + * + * \return The Pass. + * \note The pass is applied at the first stage of Relax VM build, before + * rewriting call_tir, as this pass requires dataflow information. + */ +TVM_DLL Pass RewriteDataflowReshape(); + +/*! + * \brief Attach global_symbol to Relax functions and TIR Primfuncs for codegen. + * + * \return The Pass. + */ +TVM_DLL Pass AttachGlobalSymbol(); + } // namespace transform } // namespace relax } // namespace tvm diff --git a/python/tvm/relax/analysis/analysis.py b/python/tvm/relax/analysis/analysis.py index d81c477145..27416c3a79 100644 --- a/python/tvm/relax/analysis/analysis.py +++ b/python/tvm/relax/analysis/analysis.py @@ -162,3 +162,48 @@ def struct_info_lca(lhs: StructInfo, rhs: StructInfo) -> StructInfo: The corresponding lca result. """ return _ffi_api.StructInfoLCA(lhs, rhs) # type: ignore + + +def post_order_visit(expr, fvisit): + """Recursively visit the ir in post DFS order node, + apply fvisit. Each node is guaranteed to be visited + only once. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + fvisit : function + The visitor function to be applied. 
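+
+    Examples
+    --------
+    A minimal usage sketch (assuming ``fn`` is a Relax function already in
+    scope):
+
+    .. code-block:: python
+
+        visited = []
+        post_order_visit(fn.body, lambda node: visited.append(node))
+        # `visited` now holds each sub-node of fn.body in post-DFS order.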
+ """ + return _ffi_api.post_order_visit(expr, fvisit) # type: ignore + + +def has_reshape_pattern(func: tir.PrimFunc) -> bool: + """Check if the given PrimFunc is essentially doing a reshape operation. + The reshape operation also includes expand_dims, squeeze, flatten, etc. + + Here the allowed reshape pattern is: for example, assume the operation is + `B[l_0, l_1, ..., l_b] = A[r_0, r_1, ..., r_a]`, we check if we can prove + that the flattened index of l_0, ..., l_b under buffer B equals to the + flattened index of r_0, ..., r_a under buffer A. + + Parameters + ---------- + func : tir.PrimFunc + The function to be examined. + + Returns + ------- + ret : bool + A boolean indicating if the given PrimFunc is doing a reshape. + + Notes + ----- + According to the description above, the returned result can only be + false-negative and cannot be false-positive, since whenever we cannot + prove the equality, we return false. This property guarantees the safety + of this function. + """ + return _ffi_api.has_reshape_pattern(func) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 101b0827d6..9a131cdf95 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,3 +20,6 @@ # Operators from .base import * from .binary import * +from .manipulate import * +from . import builtin +from . import memory diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/builtin/__init__.py similarity index 91% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/op/builtin/__init__.py index 101b0827d6..04837724b1 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/builtin/__init__.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" +"""Relax builtin operators.""" -# Operators -from .base import * -from .binary import * +from .builtin import * diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/builtin/_ffi_api.py similarity index 83% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/op/builtin/_ffi_api.py index 101b0827d6..42fe8cb652 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/builtin/_ffi_api.py @@ -13,10 +13,7 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" +"""FFI APIs for tvm.relax.op.builtin""" +import tvm._ffi -# Operators -from .base import * -from .binary import * +tvm._ffi._init_api("relax.op.builtin", __name__) diff --git a/python/tvm/relax/op/builtin/builtin.py b/python/tvm/relax/op/builtin/builtin.py new file mode 100644 index 0000000000..0afe6a42d0 --- /dev/null +++ b/python/tvm/relax/op/builtin/builtin.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""The builtin Relax operators.""" + +from ...expr import Call, Expr +from ...utils import args_converter +from . import _ffi_api + + +@args_converter.auto +def alloc_tensor(shape: Expr, dtype: str, runtime_device_index: int) -> Call: + """Construct a Call to allocate a tensor with specific shape, dtype, runtime_device_index. + + Parameters + ---------- + shape : Expr + The shape of the tensor to be allocated. + + dtype : str + The datatype of the tensor to be allocated. + + runtime_device_index : int + The device index indicating on which device the tensor is to be allocated at runtime. + Index -1 is reserved for the host device. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. + """ + return _ffi_api.alloc_tensor(shape, dtype, runtime_device_index) # type: ignore diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py new file mode 100644 index 0000000000..fa9c815225 --- /dev/null +++ b/python/tvm/relax/op/manipulate.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Manipulation operators.""" +from typing import Tuple, Union + +from tvm.ir.expr import PrimExpr + + +from . import _ffi_api +from ..expr import Expr + + +PrimExprLike = Union[int, PrimExpr] + + +def reshape(x: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr: + """Reshape the input array. + + ``-1`` infers the dimension of the output shape by using the remainder of + the input dimensions keeping the size of the new array same as that of the input array. + At most one dimension of shape can be -1. + + .. code-block:: python + + x.shape = (2, 3, 4), shape = (6, 1, -1), result.shape = (6, 1, 4) + x.shape = (2, 3, 4), shape = (3, -1, 8), result.shape = (3, 1, 8) + x.shape = (2, 3, 4), shape = (-1,), result.shape = (24,) + + Parameters + ---------- + x : relax.Expr + The input data to the operator. + + shape : Union[Tuple[PrimExprLike], Expr] + The new shape. Should be compatible with the original shape. + + Returns + ------- + result : relax.Expr + The reshaped result. + + Note + ---- + The ``-1`` inference is only performed at compile-time. + That is to say, in any case the dimension length of ``-1`` cannot be inferred in + compile-time, an error will be thrown. 
+ """ + return _ffi_api.reshape(x, shape) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/memory/__init__.py similarity index 91% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/op/memory/__init__.py index 101b0827d6..e039590251 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/memory/__init__.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" +"""Relax memory primitives.""" -# Operators -from .base import * -from .binary import * +from .memory import * diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/memory/_ffi_api.py similarity index 83% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/op/memory/_ffi_api.py index 101b0827d6..475de481b2 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/memory/_ffi_api.py @@ -13,10 +13,7 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the License. -# pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" +"""FFI APIs for tvm.relax.op.memory""" +import tvm._ffi -# Operators -from .base import * -from .binary import * +tvm._ffi._init_api("relax.op.memory", __name__) diff --git a/python/tvm/relax/op/memory/memory.py b/python/tvm/relax/op/memory/memory.py new file mode 100644 index 0000000000..b58b987d2a --- /dev/null +++ b/python/tvm/relax/op/memory/memory.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +"""Relax memory primitives.""" + +from . import _ffi_api +from ...expr import Expr, Call +from ...utils import args_converter + + +@args_converter.auto +def alloc_storage(size: Expr, virtual_device_index: int, storage_scope: str, dtype: str) -> Call: + """Construct a Call to allocate a storage with specific size, virtual_device_index, + storage_scope and dtype. + + Parameters + ---------- + size : Expr + The size of the storage to be allocated. + + virtual_device_index : int + The virtual device index indicating on which device the storage is to be allocated. + Index -1 is reserved for the host device. + + storage_scope : str + The storage scope to allocate the storage to. + + dtype : str + The datatype of the storage to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated storage. 
+ """ + return _ffi_api.alloc_storage(size, virtual_device_index, storage_scope, dtype) # type: ignore + + +@args_converter.auto +def alloc_tensor(storage: Expr, offset: int, shape: Expr, dtype: str) -> Call: + """Construct a Call to allocate a tensor on a certain storage starting from the given offset. + + Parameters + ---------- + storage : Expr + The storage to allocate the tensor to. + + offset : int + The storage offset to allocate the tensor. + + shape : Expr + The shape of the tensor to be allocated. + + dtype : str + The datatype of the tensor to be allocated. + + Returns + ------- + result : Call + A relax Call, which gets the allocated tensor. + """ + return _ffi_api.alloc_tensor(storage, offset, shape, dtype) # type: ignore + + +@args_converter.auto +def kill_storage(storage: Expr) -> Call: + """Construct a Call to kill a storage. + + Parameters + ---------- + storage : Expr + The storage to be killed. + + Returns + ------- + result : Call + A relax Call to kill a storage. + """ + return _ffi_api.kill_storage(storage) # type: ignore + + +@args_converter.auto +def kill_tensor(tensor: Expr) -> Call: + """Construct a Call to kill a tensor. + + Parameters + ---------- + tensor : Expr + The tensor to be killed. + + Returns + ------- + result : Call + A relax Call to kill a tensor. + """ + return _ffi_api.kill_tensor(tensor) # type: ignore diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/testing/__init__.py similarity index 91% copy from python/tvm/relax/op/__init__.py copy to python/tvm/relax/testing/__init__.py index 101b0827d6..ab1dd6f515 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/testing/__init__.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=wildcard-import, redefined-builtin -"""Relax core operators.""" +"""The Relax testing namespace containing nn and translator.""" -# Operators -from .base import * -from .binary import * +from .nn import * diff --git a/python/tvm/relax/testing/nn.py b/python/tvm/relax/testing/nn.py new file mode 100644 index 0000000000..830ddd779f --- /dev/null +++ b/python/tvm/relax/testing/nn.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=redefined-builtin
+"""PyTorch-like nn.Module API for constructing workloads."""
+
+
+from typing import List, Any, Callable, Union
+import typing
+import numpy as np  # type: ignore
+
+import tvm
+from tvm import relax, topi, tir
+
+
+def emit_te(func: Callable, *args: Any, **kwargs: Any) -> relax.Var:
+    return relax.BlockBuilder.current().emit_te(func, *args, **kwargs)
+
+
+class Placeholder(relax.Var):
+    """A placeholder variable that can represent model input."""
+
+    def __init__(
+        self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="data"
+    ):
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError("the shape of Placeholder is expected to be a list or a tuple")
+        super().__init__(
+            relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype)
+        )
+
+
+class Parameter(relax.Var):
+    """A special kind of relax Var that represents a model parameter (weight)."""
+
+    def __init__(
+        self, shape: Union[List[Any], typing.Tuple[Any, ...]], dtype="float32", name="param"
+    ):
+        if not isinstance(shape, (list, tuple)):
+            raise TypeError("the shape of Parameter is expected to be a list or a tuple")
+        super().__init__(
+            relax.BlockBuilder.current().get_unique_name(name), relax.TensorStructInfo(shape, dtype)
+        )
+
+
+class Module:
+    """Base class for all model modules.
+
+    A neural network or a layer can subclass this class.
+
+    Example
+    -------
+    .. code-block:: python
+
+        # Define a linear layer
+        class Linear(Module):
+            def __init__(self, in_features, out_features, bias=True):
+                self.in_features = in_features
+                self.out_features = out_features
+                self.weight = Parameter((in_features, out_features), name="linear_weight")
+                if bias:
+                    self.bias = Parameter((out_features,), name="linear_bias")
+                else:
+                    self.bias = None
+
+            # All submodules should implement forward.
+            # It defines the forward computation performed at every call.
+ def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y + """ + + def parameters(self) -> List[Parameter]: + """Return the list of parameters in the module.""" + return _unpack_params(self.__dict__) + + def forward(self, input: relax.Expr): + """Define the computation performed at every call.""" + raise NotImplementedError() + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +def _unpack_params(value: object) -> List[relax.Var]: + if isinstance(value, Parameter): + return [value] + if isinstance(value, Module): + return value.parameters() + if isinstance(value, dict): + params = [] + for v in value.values(): + params += _unpack_params(v) + return params + if isinstance(value, (list, tuple)): + params = [] + for v in value: + params += _unpack_params(v) + return params + if value is None or isinstance(value, (int, float, str)): + return [] + raise TypeError("not supported type when unpacking parameters: {}".format(type(value))) + + +def init_params(mod: tvm.IRModule) -> List[tvm.nd.array]: + """Utility function to initialize model's parameters.""" + shape_dict = {v.name_hint: v.struct_info.shape for v in mod["main"].params} + params = [] + for k, v in shape_dict.items(): + if k.startswith("data"): + continue + if isinstance(v, relax.ShapeExpr): + shape = [] + for i in v: + if isinstance(i, tir.IntImm): + shape.append(int(i)) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + params.append(tvm.nd.array(np.zeros(shape).astype(np.float32))) + else: + raise TypeError("cannot initialize for unknown-shape parameters.") + return params + + +class Sequential(Module): + """A sequential container that concatenates modules in it. + + Example + ------- + .. 
code-block:: python + + model = nn.Sequential( + nn.Conv2d(1, 20, 5), + nn.ReLU(), + nn.Conv2d(20, 64, 5), + nn.ReLU() + ) + """ + + def __init__(self, *modules: Module): + self.modules = modules + + def forward(self, input: relax.Expr) -> relax.Var: + for module in self.modules: + input = module(input) + return input + + +class ReLU(Module): + """Applies the rectified linear unit activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.relu, input) + + +class LogSoftmax(Module): + """Applies log softmax activation function on the input.""" + + def forward(self, input: relax.Expr) -> relax.Var: + return emit_te(topi.nn.log_softmax, input) + + +class Linear(Module): + """Applies a linear transformation to the input data: :math:`y = xA + b`.""" + + def __init__(self, in_features, out_features, bias=True): + self.in_features = in_features + self.out_features = out_features + self.weight = Parameter((in_features, out_features), name="linear_weight") + if bias: + self.bias = Parameter((out_features,), name="linear_bias") + else: + self.bias = None + + def forward(self, input: relax.Expr) -> relax.Var: + y = emit_te(topi.matmul, input, self.weight) + if self.bias is not None: + y = emit_te(topi.add, y, self.bias) + return y diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index f20f06c522..cab18797c6 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -37,6 +37,49 @@ class DataflowBlockPass(tvm.ir.transform.Pass): """A pass that works on each tvm.relax.DataflowBlock in a module.""" +def ToNonDataflow() -> tvm.ir.transform.Pass: + """Transform all dataflow structure to non-dataflow version. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.ToNonDataflow() # type: ignore + + +def CallTIRRewrite() -> tvm.ir.transform.Pass: + """Perform explicit tensor allocation for call_tir. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.CallTIRRewrite() # type: ignore + + +def RewriteDataflowReshape() -> tvm.ir.transform.Pass: + """Convert all reshape-like call_tir to VM reshape operator call. + The VM reshape operator calls will be further lowered to a CreateView + operation at runtime, instead of doing real data copy. + Here "reshape-like" includes reshape, expand_dims, flatten, etc. + + Returns + ------- + ret : tvm.ir.transform.Pass + """ + return _ffi_api.RewriteDataflowReshape() # type: ignore + + +def VMBuiltinLower() -> tvm.ir.transform.Pass: + """Lowering generic intrinsic to VM intrinsics. + + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.VMBuiltinLower() # type: ignore + + def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: """Lower the symbolic shape and argument and match-cast structinfo matching. @@ -52,6 +95,16 @@ def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass: return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore +def AttachGlobalSymbol() -> tvm.ir.transform.Pass: + """Attach global_symbol to Relax functions and TIR Primfuncs for codegen. 
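+
+    The ``global_symbol`` attribute is the externally visible (linkage) name
+    under which a function can be looked up in the compiled artifact.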
+ + Returns + ------- + ret: tvm.ir.transform.Pass + """ + return _ffi_api.AttachGlobalSymbol() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/python/tvm/relax/vm.py b/python/tvm/relax/vm.py index ba16dfb079..ff6bf816b6 100644 --- a/python/tvm/relax/vm.py +++ b/python/tvm/relax/vm.py @@ -581,7 +581,9 @@ def build( if isinstance(target, str): target = tvm.target.Target(target) - passes = [relax.transform.ToNonDataflow()] + passes = [] + passes.append(relax.transform.RewriteDataflowReshape()) + passes.append(relax.transform.ToNonDataflow()) passes.append(relax.transform.CallTIRRewrite()) passes.append(relax.transform.VMBuiltinLower()) passes.append(relax.transform.VMShapeLower()) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 647ef8f25a..0692ec5683 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -31,13 +31,16 @@ from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const from tvm.relax.op import ( add, assert_op, + builtin, call_builtin_with_ctx, call_tir, invoke_closure, make_closure, + memory, multiply, null_value, print, + reshape, shape_of, ) from tvm.relax.struct_info import StructInfo @@ -381,6 +384,7 @@ __all__ = [ "add", "arg", "assert_op", + "builtin", "call_packed", "call_tir", "call_builtin_with_ctx", @@ -396,11 +400,13 @@ __all__ = [ "function", "invoke_closure", "make_closure", + "memory", "multiply", "null_value", "output", "prim_value", "print", + "reshape", "shape_of", "str", "tuple", diff --git a/src/relax/analysis/tir_op_pattern_kind.cc b/src/relax/analysis/tir_op_pattern_kind.cc new file mode 100644 index 0000000000..b7ac8faddd --- /dev/null +++ b/src/relax/analysis/tir_op_pattern_kind.cc @@ -0,0 +1,447 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/expr_functor.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+using namespace tir;
+
+class PatternKindAnalyzer : public StmtExprVisitor {
+ public:
+  explicit PatternKindAnalyzer(const tir::PrimFunc& func) {
+    for (const tir::Var& param : func->params) {
+      Optional<Buffer> param_buf = func->buffer_map.Get(param);
+      if (param_buf.defined()) {
+        param_buffers_.insert(param_buf.value());
+      }
+    }
+  }
+
+ private:
+  bool IsOutputBlock(const BlockNode* block) {
+    for (const BufferRegion& write_region : block->writes) {
+      if (param_buffers_.count(write_region->buffer)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  void VisitStmt_(const BufferStoreNode* op) final {
+    // We only support one buffer store in a block (usually generated by TE compute).
+    // If we have already seen a buffer store in the current block, classify as Opaque.
+    if (store_.defined()) {
+      kind_ = relay::kOpaque;
+      return;
+    }
+    store_ = GetRef<BufferStore>(op);
+    StmtVisitor::VisitStmt_(op);
+  }
+
+  void VisitExpr_(const BufferLoadNode* op) final {
+    loads_.push_back(GetRef<BufferLoad>(op));
+    ExprVisitor::VisitExpr_(op);
+  }
+
+  void VisitStmt_(const BlockNode* op) final {
+    if (op->name_hint == "root") {
+      // Skip the root block
+      StmtVisitor::VisitStmt(op->body);
+      return;
+    }
+
+    // Step 1. Clear loads and store.
+    loads_.clear();
+    store_ = NullOpt;
+    // Step 2. Visit block body.
+    StmtVisitor::VisitStmt(op->body);
+    BufferStore store = store_.value();
+
+    // Step 3. Check the load/store indices pattern.
+    relay::OpPatternKind index_pair_pattern = relay::kElemWise;
+    bool has_elem_wise = false;
+    for (const BufferLoad& load : loads_) {
+      // Since elemwise is stricter than broadcast, and broadcast is stricter than injective,
+      // and the enums are ordered as kElemWise < kBroadcast < kInjective,
+      // we can simply use `std::max` to detect these three patterns.
+      // E.g. here there is only one store node but two load nodes, as in C[i, j] = A[i, j] + B[i]:
+      // buffers C and A are elemwise but C and B are broadcast, so the whole block follows
+      // the broadcast pattern.
+      if (IsElemwisePattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kElemWise);
+        has_elem_wise = true;
+      } else if (IsBroadcastPattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kBroadcast);
+      } else if (IsInjectivePattern(store, load)) {
+        index_pair_pattern = std::max(index_pair_pattern, relay::kInjective);
+      } else {
+        index_pair_pattern = relay::kOpaque;
+        break;
+      }
+    }
+    // If one index pair is kElemWise and the others are kBroadcast, we regard the block as
+    // kElemWise, e.g. A[i, j] = B[i, j] + C[i]
+    if (index_pair_pattern == relay::kBroadcast && has_elem_wise) {
+      index_pair_pattern = relay::kElemWise;
+    }
+    // If the block index pattern is not opaque, update kind.
+    if (index_pair_pattern != relay::kOpaque) {
+      // This rule is for softmax: reduce + injective.
+      if (IsOutputBlock(op) && kind_ == relay::kCommReduce) {
+        kind_ = relay::kOutEWiseFusable;
+      } else {
+        kind_ = std::max(kind_, index_pair_pattern);
+      }
+      return;
+    }
+
+    // Step 4. Check whether the block contains reduce axes by looking into block iterators.
+    bool has_reduction = false;
+    Array<tir::Var> reduce_vars;
+    for (const IterVar& it : op->iter_vars) {
+      if (it->iter_type == kCommReduce) {
+        has_reduction = true;
+        reduce_vars.push_back(it->var);
+      }
+    }
+
+    if (has_reduction) {
+      if (IsFMA(op->body)) {
+        // FMA is regarded as kOutEWiseFusable, e.g. Matmul or Conv.
+        kind_ = std::max(kind_, relay::kOutEWiseFusable);
+        return;
+      } else {
+        for (size_t i = 0; i < loads_.size(); ++i) {
+          // If it's not a pure reduce, regard it as kOutEWiseFusable.
+          // This rule works for pooling for now.
+          if (!IsPureReducePattern(reduce_vars, loads_[i]->indices)) {
+            kind_ = std::max(kind_, relay::kOutEWiseFusable);
+            return;
+          }
+        }
+      }
+      kind_ = std::max(kind_, relay::kCommReduce);
+    } else {
+      kind_ = relay::kOpaque;
+    }
+  }
+
+  /********** Helper Functions **********/
+
+  /*! \brief Check if two arrays contain the same elements. */
+  static bool IsSameArray(const Array<PrimExpr>& lhs, const Array<PrimExpr>& rhs) {
+    if (lhs.size() != rhs.size()) {
+      return false;
+    }
+    for (size_t i = 0; i < lhs.size(); ++i) {
+      if (!lhs[i].same_as(rhs[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Check if the load indices and store indices follow the elemwise pattern.
+   * It's the elemwise pattern iff the load indices and store indices are the same.
+   * E.g. A[i, j] = B[i, j]
+   */
+  static bool IsElemwisePattern(const BufferStore& store, const BufferLoad& load) {
+    return IsSameArray(store->indices, load->indices);
+  }
+
+  /*!
+   * \brief Check if the load indices and store indices follow the broadcast pattern.
+   * It's the broadcast pattern iff all load indices appear in the store indices in order.
+   * E.g. A[i, j] = B[i] is broadcast since all load indices (`i`) are in the store indices.
+   *      A[i, j] = B[i, k] is not broadcast since `k` is not in the store indices.
+   *      A[i, j] = B[j, i] is not broadcast since the load indices are not in the same order
+   *      as the store's.
+   */
+  static bool IsBroadcastPattern(const BufferStore& store, const BufferLoad& load) {
+    size_t ndim_load_buf = load->buffer->shape.size();
+    size_t ndim_store_buf = store->buffer->shape.size();
+
+    for (size_t i = 0, j = 0; i < ndim_load_buf; ++i) {
+      if (is_const_int(load->buffer->shape[i], 1) && is_const_int(load->indices[i], 0)) {
+        // Skip unit load dimensions
+        // E.g. A[i, j] = B[1, j] is still broadcast
+        continue;
+      }
+
+      // Try to find the i-th load index in the store indices.
+      while (j < ndim_store_buf && !store->indices[j].same_as(load->indices[i])) {
+        ++j;
+      }
+
+      // It's not broadcast if we cannot find the load indices in the store indices in order.
+      if (j == ndim_store_buf) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Check if the load indices and store indices follow the injective pattern.
+   * It's the injective pattern iff all load index vars are in the store indices, regardless
+   * of order. Note that we only support store indices that are direct vars so far; this can
+   * be enhanced later.
+   * E.g. A[i, j] = B[j, i] is injective.
+   *      A[i, j] = B[i - j] is injective since the load index vars are only i, j.
+   */
+  static bool IsInjectivePattern(const BufferStore& store, const BufferLoad& load) {
+    std::unordered_set<const tir::VarNode*> vars;
+    for (const PrimExpr& store_index : store->indices) {
+      if (const auto* v = store_index.as<tir::VarNode>()) {
+        vars.insert(v);
+      } else {
+        return false;
+      }
+    }
+    for (const PrimExpr& load_index : load->indices) {
+      // Return false if there are vars used in the load indices but not in the store indices.
+      if (tir::UsesVar(load_index, [&vars](const tir::VarNode* var) { return !vars.count(var); })) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /*!
+   * \brief Check if the load indices and store indices allow data reuse.
+   * Data reuse is allowed iff some vars in the load indices are not in the store indices.
+   * E.g. Store = A[i, j] and Load = B[i, j, k] allow data reuse.
+   *      Store = A[i, j] and Load = B[i, j + k] allow data reuse.
+   */
+  static bool IsAllowReusePattern(const BufferStore& store, const BufferLoad& load) {
+    std::unordered_set<const tir::VarNode*> vars;
+    for (const PrimExpr& index : store->indices) {
+      if (const auto* v = index.as<tir::VarNode>()) {
+        vars.insert(v);
+      } else {
+        return false;
+      }
+    }
+    for (const PrimExpr& index : load->indices) {
+      PreOrderVisit(index, [&](const ObjectRef& node) {
+        if (const auto* v = node.as<tir::VarNode>()) {
+          if (vars.count(v)) {
+            vars.erase(v);
+          }
+        }
+        return true;
+      });
+    }
+    return !vars.empty();
+  }
+
+  /*! \brief Check if the stmt is a multiply-add. E.g. C[i, j] += A[i, k] * B[j, k] */
+  static bool IsFMA(const Stmt& body) {
+    if (const auto* store = body.as<BufferStoreNode>()) {
+      if (const auto* add = store->value.as<AddNode>()) {
+        if (const auto* l = add->a.as<BufferLoadNode>()) {
+          if (const auto* r = add->b.as<MulNode>()) {
+            bool incremental =
+                store->buffer.same_as(l->buffer) && IsSameArray(store->indices, l->indices);
+            const auto* l_load = r->a.as<BufferLoadNode>();
+            const auto* r_load = r->b.as<BufferLoadNode>();
+            if (incremental && l_load && r_load) {
+              return IsAllowReusePattern(GetRef<BufferStore>(store), GetRef<BufferLoad>(l_load)) &&
+                     IsAllowReusePattern(GetRef<BufferStore>(store), GetRef<BufferLoad>(r_load));
+            }
+          }
+        }
+      }
+    }
+    return false;
+  }
+
+  /*!
+   * \brief Check if it is the pure reduce pattern.
+   * It's the pure reduce pattern iff all reduce axes are direct reduce vars.
+   * E.g. A[i] = sum(B[i, j]) is pure reduce.
+   *      A[i] = sum(B[i, j + k]) is not pure reduce.
+   *      Pooling is not pure reduce.
+   */
+  static bool IsPureReducePattern(Array<tir::Var> reduce_loops, Array<PrimExpr> indices) {
+    for (const PrimExpr& e : indices) {
+      int id = -1;
+      if (UsesVar(e, [&](const tir::VarNode* var) {
+            for (size_t i = 0; i < reduce_loops.size(); ++i) {
+              if (reduce_loops[i].get() == var) {
+                id = i;
+                return true;
+              }
+            }
+            return false;
+          })) {
+        if (!reduce_loops[id].same_as(e)) {
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+ private:
+  /*!
+   * \brief The BufferStore node in the current block.
+   * \note We only support one BufferStore node in a block (usually generated by TE compute).
+   */
+  Optional<BufferStore> store_;
+  /*! \brief The BufferLoad nodes in the current block. */
+  Array<BufferLoad> loads_;
+  /*! \brief The result of op pattern. */
+  relay::OpPatternKind kind_ = relay::kElemWise;
+  /*! \brief The buffers from function params, i.e. the input and output buffers.
*/ + std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> param_buffers_; + + public: + relay::OpPatternKind GetResult() { return kind_; } +}; + +relay::OpPatternKind AnalyzeOpPatternKind(const PrimFunc& func) { + PatternKindAnalyzer analyzer(func); + analyzer(func->body); + return analyzer.GetResult(); +} + +bool HasReshapePattern(const PrimFunc& func) { + class ReshapeDetector : public StmtVisitor { + public: + static bool Detect(const Buffer& src_buffer, const Buffer& dst_buffer, Stmt stmt) { + ReshapeDetector detector(src_buffer, dst_buffer); + detector(stmt); + return detector.is_reshape_; + } + + private: + explicit ReshapeDetector(const Buffer& src_buffer, const Buffer& dst_buffer) + : is_reshape_(false), src_buffer_(src_buffer), dst_buffer_(dst_buffer) {} + + void VisitStmt_(const ForNode* loop) final { + ana_.Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); + // To detect the reshape pattern, we require each For to have + // either another For or a BlockRealize as body. + if (!(loop->body->IsInstance<ForNode>() || loop->body->IsInstance<BlockRealizeNode>())) { + return; + } + this->VisitStmt(loop->body); + } + + void VisitStmt_(const BlockRealizeNode* block_realize) final { + // Constructing the mapping from block iterators to iterator + // binding values. The mapping will be used in the substitution of + // the flattened buffer access index. + const Block& block = block_realize->block; + const Array<IterVar>& block_iter = block->iter_vars; + const Array<PrimExpr>& iter_values = block_realize->iter_values; + ICHECK_EQ(block_iter.size(), iter_values.size()); + int n_iter = block_iter.size(); + for (int i = 0; i < n_iter; ++i) { + // To detect the reshape pattern, we require each block iter to be data-parallel. + if (block_iter[i]->iter_type != tir::IterVarType::kDataPar) { + return; + } + var_map_.Set(block_iter[i]->var, iter_values[i]); + } + + // Recurse into the block. + this->VisitStmt(block); + } + + void VisitStmt_(const BlockNode* block) final { + // Step 0. If the block body is a ForNode, recurse into it. + if (block->body->IsInstance<ForNode>()) { + this->VisitStmt(block->body); + return; + } + + // Step 1. Get the load/store pattern of the block body. + // To detect the reshape pattern, we require the block body to be a + // BufferStore, which has a BufferLoad as value. + const auto* buffer_store = block->body.as<BufferStoreNode>(); + if (buffer_store == nullptr) { + return; + } + const auto* buffer_load = buffer_store->value.as<BufferLoadNode>(); + if (buffer_load == nullptr) { + return; + } + // Further, we require the buffer being stored and being loaded to + // match the parameter of the PrimFunc, namely `dst_buffer_` and `src_buffer_`. + if (!(buffer_store->buffer.same_as(dst_buffer_) && + buffer_load->buffer.same_as(src_buffer_))) { + return; + } + + // Step 3. Calculate the flattened access index according to the load/store pattern. + auto f_calc_flattened_idx = [](const Buffer& buffer, const Array<PrimExpr>& indices) { + ICHECK_EQ(indices.size(), buffer->shape.size()); + int ndim = indices.size(); + PrimExpr idx = 0; + for (int i = 0; i < ndim; ++i) { + idx = idx * buffer->shape[i] + indices[i]; + } + return idx; + }; + PrimExpr src_idx = f_calc_flattened_idx(src_buffer_, buffer_load->indices); + PrimExpr dst_idx = f_calc_flattened_idx(dst_buffer_, buffer_store->indices); + + // Step 4. Substitute the block iterators in the flattened index + // with loop variables, and check if we can prove their equality. 
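+ // E.g. for a (2, 3) -> (6,) reshape where B[k] = A[k // 3, k % 3], the two
+ // flattened indices are (k // 3) * 3 + k % 3 and k, which the analyzer can
+ // prove equal.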
+ src_idx = tir::Substitute(std::move(src_idx), var_map_); + dst_idx = tir::Substitute(std::move(dst_idx), var_map_); + if (ana_.CanProveEqual(src_idx, dst_idx)) { + this->is_reshape_ = true; + } + } + + bool is_reshape_; + /*! \brief The mapping from block vars to block binding values. */ + Map<tir::Var, PrimExpr> var_map_; + const Buffer& src_buffer_; + const Buffer& dst_buffer_; + arith::Analyzer ana_; + }; + + if (func->params.size() < 2) { + return false; + } + Optional<Buffer> src_buffer = func->buffer_map.Get(func->params.front()); + Optional<Buffer> dst_buffer = func->buffer_map.Get(func->params.back()); + if (!(src_buffer.defined() && dst_buffer.defined())) { + return false; + } + + // To detect the reshape pattern, we require each For to have + // either another For or a BlockRealize as body. + ICHECK(func->body->IsInstance<BlockRealizeNode>()); + return ReshapeDetector::Detect(src_buffer.value(), dst_buffer.value(), func->body); +} + +TVM_REGISTER_GLOBAL("relax.analysis.has_reshape_pattern").set_body_typed(HasReshapePattern); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/backend/vm/vm_builtin_lower.cc b/src/relax/backend/vm/vm_builtin_lower.cc new file mode 100644 index 0000000000..6613b39626 --- /dev/null +++ b/src/relax/backend/vm/vm_builtin_lower.cc @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/backend/vm/vm_builtin_lower.cc + * \brief Lowers most builtin functions and packed calls. + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/backend.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/type.h> +#include <tvm/runtime/data_type.h> +#include <tvm/tir/op.h> + +namespace tvm { +namespace relax { + +// This pass lowers most ops to VM specific builtins. +// TODO(relax-team): revisit after PrimValue. +class VMBuiltinLowerMutator : public ExprMutator { + public: + using ExprMutator::VisitExpr_; + + // A workaround to remove the CallNodes of killing tensors and storages. 
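+ // memory.kill_storage and memory.kill_tensor have no VM builtin to lower to
+ // in this minimum build flow, so their bindings are dropped entirely.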
+ void VisitBinding_(const VarBindingNode* binding) final { + const auto* call = binding->value.as<CallNode>(); + if (call != nullptr && (call->op == mem_kill_storage_op_ || call->op == mem_kill_tensor_op_)) { + return; + } + ExprMutator::VisitBinding_(binding); + } + + Expr VisitExpr_(const CallNode* call_node) final { + // post-order mutation + Call call = Downcast<Call>(VisitExprPostOrder_(call_node)); + + if (call->op == call_tir_dyn_op_) { + return CallTIRDyn(call); + } else if (call->op == reshape_op_) { + return Reshape(call); + } else if (call->op == make_closure_op_) { + return MakeClosure(call); + } else if (call->op == invoke_closure_op_) { + return InvokeClosure(call); + } else if (call->op == alloc_tensor_op_) { + return MakeAllocTensor(call); + } else if (call->op == mem_alloc_storage_op_) { + return MakeMemAllocStorage(call); + } else if (call->op == mem_alloc_tensor_op_) { + return MakeMemAllocTensor(call); + } else { + return call; + } + } + + Expr ComputeStorageSize(const Expr& shape, const DataType& dtype) const { + // Question: what if the dtype of tensor_type is unknown? + // Symbolic/static shape case + if (auto* shape_expr = shape.as<ShapeExprNode>()) { + int64_t elem_bytes = runtime::GetVectorBytes(dtype); + PrimExpr ret = IntImm(DataType::Int(64), elem_bytes); + for (PrimExpr dim : shape_expr->values) { + ret = ret * dim; + } + return ShapeExpr({ret}); + } else { + return Call(builtin_compute_alloc_shape_, {shape, DataTypeImm(dtype)}, Attrs(), + {GetStructInfo(shape)}); + } + } + + Expr MakeAllocTensor(const Call& call) { + ShapeExpr output_shape = Downcast<ShapeExpr>(call->args[0]); + DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[1]); + DataType dtype = output_dtype->value; + Expr storage_size = ComputeStorageSize(output_shape, dtype); + PrimValue runtime_device_index = Downcast<PrimValue>(call->args[2]); + Var storage = builder_->Emit( + Call(vm_alloc_storage_op_, {storage_size, runtime_device_index, output_dtype}, Attrs()), + "storage"); + Expr shape = call->args[0]; + PrimValue offset = PrimValue::Int64(0); + return Call(vm_alloc_tensor_op_, {storage, offset, shape, DataTypeImm(dtype)}, Attrs()); + } + + Expr MakeMemAllocStorage(const Call& call) { + PrimValue runtime_device_index = Downcast<PrimValue>(call->args[1]); + DataTypeImm output_dtype = Downcast<DataTypeImm>(call->args[3]); + return Call(vm_alloc_storage_op_, {call->args[0], runtime_device_index, output_dtype}, Attrs()); + } + + Expr MakeMemAllocTensor(const Call& call) { + PrimValue offset = Downcast<PrimValue>(call->args[1]); + DataTypeImm dtype = Downcast<DataTypeImm>(call->args[3]); + return Call(vm_alloc_tensor_op_, {call->args[0], offset, call->args[2], dtype}, Attrs()); + } + + Expr CallTIRDyn(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>()); + ICHECK(call_node->args[1]->IsInstance<TupleNode>()); + Array<Expr> args; + + auto tir_args = Downcast<Tuple>(call_node->args[1]); + args.push_back(call_node->args[0]); + for (Expr arg : tir_args->fields) { + args.push_back(arg); + } + return Call(builtin_call_tir_dyn_, args, Attrs(), {void_sinfo_}); + } + + Expr Reshape(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->struct_info_.defined()); + CHECK(call_node->args[1]->IsInstance<ShapeExprNode>()) + << "VMBuiltinLower expects the shape arg of reshape op to be a ShapeExpr"; + return Call(builtin_reshape_, call_node->args, Attrs(), {GetStructInfo(call_node)}); + } + + Expr 
MakeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>()); + ICHECK(call_node->args[1]->IsInstance<TupleNode>()); + + Array<Expr> args; + auto func = call_node->args[0]; + auto closure_args = Downcast<Tuple>(call_node->args[1]); + + args.push_back(func); + for (Expr arg : closure_args->fields) { + args.push_back(arg); + } + + return Call(builtin_make_closure_, args, Attrs(), {object_sinfo_}); + } + + Expr InvokeClosure(const Call& call_node) { + ICHECK(call_node->args.size() == 2); + ICHECK(call_node->args[0]->IsInstance<VarNode>()); + ICHECK(call_node->args[1]->IsInstance<TupleNode>()); + + Array<Expr> args; + + args.push_back(call_node->args[0]); + + // args for the invoke_closure + auto invoke_closure_args = Downcast<Tuple>(call_node->args[1]); + for (Expr arg : invoke_closure_args->fields) { + args.push_back(arg); + } + return Call(call_builtin_with_ctx_op_, {builtin_invoke_closure_, Tuple(args)}, Attrs(), + {object_sinfo_}); + } + + const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx"); + const StructInfo object_sinfo_ = ObjectStructInfo(); + const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({})); + // object to pattern match. + const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn"); + const Op& reshape_op_ = Op::Get("relax.reshape"); + const Op& make_closure_op_ = Op::Get("relax.make_closure"); + const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure"); + const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor"); + const Op& mem_alloc_storage_op_ = Op::Get("relax.memory.alloc_storage"); + const Op& mem_alloc_tensor_op_ = Op::Get("relax.memory.alloc_tensor"); + const Op& mem_kill_storage_op_ = Op::Get("relax.memory.kill_storage"); + const Op& mem_kill_tensor_op_ = Op::Get("relax.memory.kill_tensor"); + // functions to lower to + const Op& vm_alloc_storage_op_ = Op::Get("relax.vm.alloc_storage"); + const Op& vm_alloc_tensor_op_ = Op::Get("relax.vm.alloc_tensor"); + // Function to compute allocated shape. 
+ const ExternFunc builtin_compute_alloc_shape_{"vm.builtin.compute_alloc_shape"}; + const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"}; + const ExternFunc builtin_reshape_{"vm.builtin.reshape"}; + const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"}; + const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"}; +}; + +Expr VMBuiltinLower(const Expr& e) { return VMBuiltinLowerMutator().VisitExpr(e); } + +namespace transform { + +Pass VMBuiltinLower() { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(VMBuiltinLower(f)); }; + return CreateFunctionPass(pass_func, 0, "VMBuiltinLower", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.VMBuiltinLower").set_body_typed(VMBuiltinLower); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ca66b0a9ef..ba167a45bc 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -226,6 +226,87 @@ Expr MakeAllocTensor(Expr shape, DataType dtype, int64_t runtime_device_index) { TVM_REGISTER_GLOBAL("relax.op.builtin.alloc_tensor").set_body_typed(MakeAllocTensor); +// memory planning alloc_storage + +RELAY_REGISTER_OP("relax.memory.alloc_storage") + .set_num_inputs(4) + .add_argument("total_space", "Expr", "The total space of the storage to allocate.") + .add_argument( + "virtual_device_index", "int64_t", + "The virtual device index indicating on which device the storage is to be allocated, " + "Index -1 is reserved for the host device.") + .add_argument("storage_scope", "string", + "The storage scope of the storage to allocate. Default is global.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo); + +Expr MakeAllocStorage(Expr size, int64_t virtual_device_index, std::string storage_scope, + DataType dtype) { + static const Op& op = Op::Get("relax.memory.alloc_storage"); + return Call( + op, + {size, PrimValue::Int64(virtual_device_index), StringImm(storage_scope), DataTypeImm(dtype)}, + Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.alloc_storage").set_body_typed(MakeAllocStorage); + +// memory planning alloc_tensor + +StructInfo InferStructInfoMemAllocTensor(const Call& call, const BlockBuilder& ctx) { + ICHECK(GetStructInfoAs<ShapeStructInfoNode>(call->args[2])) + << "must be a Expr of ShapeStructInfo, but got " << call->args[1]->GetTypeKey(); + DataType out_dtype; + if (const auto* dtype_node = call->args[3].as<DataTypeImmNode>()) { + const DataTypeImm dtype_imm = GetRef<DataTypeImm>(dtype_node); + out_dtype = dtype_imm->value; + } + return TensorStructInfo(call->args[2], out_dtype); +} + +RELAY_REGISTER_OP("relax.memory.alloc_tensor") + .set_num_inputs(4) + .add_argument("storage", "Expr", "The storage to allocate the tensor to.") + .add_argument("offset", "int", "Storage offset to allocate the tensor.") + .add_argument("shape", "Expr", "The shape of the tensor to allocate.") + .add_argument("dtype", "DataType", "The dtype of the tensor to allocate.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMemAllocTensor); + +Expr MakeMemAllocTensor(Expr storage, int offset, Expr shape, DataType dtype) { + static const Op& op = Op::Get("relax.memory.alloc_tensor"); + return Call(op, {storage, PrimValue::Int64(offset), shape, DataTypeImm(dtype)}, Attrs(), {}); +} + 
+TVM_REGISTER_GLOBAL("relax.op.memory.alloc_tensor").set_body_typed(MakeMemAllocTensor); + +// memory planning kill_storage + +RELAY_REGISTER_OP("relax.memory.kill_storage") + .set_num_inputs(1) + .add_argument("storage", "Expr", "The storage to be killed.") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeMemKillStorage(Expr storage) { + static const Op& op = Op::Get("relax.memory.kill_storage"); + return Call(op, {storage}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.kill_storage").set_body_typed(MakeMemKillStorage); + +// memory planning kill_tensor + +RELAY_REGISTER_OP("relax.memory.kill_tensor") + .set_num_inputs(1) + .add_argument("tensor", "Expr", "The tensor to be killed.") + .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo); + +Expr MakeMemKillTensor(Expr tensor) { + static const Op& op = Op::Get("relax.memory.kill_tensor"); + return Call(op, {tensor}, {}, {}); +} + +TVM_REGISTER_GLOBAL("relax.op.memory.kill_tensor").set_body_typed(MakeMemKillTensor); + // vm alloc_storage RELAY_REGISTER_OP("relax.vm.alloc_storage") diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc new file mode 100644 index 0000000000..2088a8306e --- /dev/null +++ b/src/relax/op/tensor/manipulate.cc @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file manipulate.cc + * \brief Manipulation operators. + */ + +#include "manipulate.h" + +#include <algorithm> +#include <numeric> +#include <utility> +#include <vector> + +namespace tvm { +namespace relax { + +// Helper function for flatten and reshape. +PrimExpr ComputeShapeProduct(const Array<PrimExpr>& shape_values) { + PrimExpr shape_prod = IntImm(DataType::Int(64), 1); + for (PrimExpr value : shape_values) { + shape_prod *= value; + } + return shape_prod; +} + +/* relax.reshape */ +Expr ConvertNewShapeToExpr(const Expr& data, const ObjectRef& shape) { + if (const auto* e = shape.as<ExprNode>()) { + return GetRef<Expr>(e); + } + + const auto* array = shape.as<ArrayNode>(); + CHECK(array != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; + int dim_to_infer = -1; + PrimExpr new_shape_prod = IntImm(DataType::Int(64), 1); + for (int i = 0; i < static_cast<int>(array->size()); ++i) { + const auto* _len = array->at(i).as<PrimExprNode>(); + CHECK(_len != nullptr) << "Reshape only expects the input new shape to be either an Expr or an " + "Array of PrimExprs. However, the given new shape is " + << shape; + PrimExpr len = GetRef<PrimExpr>(_len); + CHECK(len->dtype.is_int()) << "Reshape requires the new shape values to be all " + "integers. 
However, the given new shape is " +                               << shape; +    const auto* int_len = len.as<IntImmNode>(); +    if (int_len != nullptr && int_len->value == -1) { +      CHECK_EQ(dim_to_infer, -1) << "Reshape accepts at most one \"-1\" in the new shape. However, " +                                    "there are multiple \"-1\" in the given new shape " +                                 << shape; +      dim_to_infer = i; +    } else { +      CHECK(int_len == nullptr || int_len->value > 0) +          << "Reshape requires all values in the new shape to be positive except a single \"-1\". " +             "However, the given new shape is " +          << shape; +      // We expect any symbolic value not to signal the intent of -1, and therefore do not check +      // symbolic values here. +      new_shape_prod = new_shape_prod * len; +    } +  } + +  Array<PrimExpr> array_ref = GetRef<Array<PrimExpr>>(array); +  // When there is no dimension to infer, just return the input array as ShapeExpr. +  if (dim_to_infer == -1) { +    return ShapeExpr(array_ref); +  } + +  // Otherwise, we require the input tensor to have known shape value for inference. +  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(data); +  CHECK(data_sinfo != nullptr) +      << "Reshape expects the input data to be a Tensor. However, the given input is " +      << data->struct_info_->GetTypeKey(); +  CHECK(data_sinfo->shape.defined()) +      << "Reshape expects the input tensor to have known shape when there is some dimension length " +         "to infer. However, the given input has no shape."; +  const auto* shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value()); +  CHECK(shape_sinfo != nullptr && shape_sinfo->values.defined()) +      << "Reshape expects the input tensor to have known shape when there is some dimension length " +         "to infer. However, the given input shape is " +      << data_sinfo->shape << " whose shape value is unknown."; + +  arith::Analyzer analyzer; +  PrimExpr old_shape_prod = ComputeShapeProduct(shape_sinfo->values.value()); +  array_ref.Set(dim_to_infer, analyzer.Simplify(floordiv(old_shape_prod, new_shape_prod))); +  return ShapeExpr(array_ref); +} + +Expr reshape(Expr x, ObjectRef shape) { +  Expr shape_in_expr = ConvertNewShapeToExpr(x, shape); +  static const Op& op = Op::Get("relax.reshape"); +  return Call(op, {std::move(x), std::move(shape_in_expr)}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.reshape").set_body_typed(reshape); + +StructInfo InferStructInfoReshape(const Call& call, const BlockBuilder& ctx) { +  if (call->args.size() != 2) { +    ctx->ReportFatal(Diagnostic::Error(call->span) << "Reshape op should take 2 arguments"); +  } +  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]); +  const auto* new_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]); +  if (data_sinfo == nullptr) { +    ctx->ReportFatal(Diagnostic::Error(call->span) +                     << "Reshape requires the input data to be Tensor. However, the given one is " +                     << call->args[0]->struct_info_->GetTypeKey()); +  } +  if (new_shape_sinfo == nullptr) { +    ctx->ReportFatal( +        Diagnostic::Error(call->span) +        << "Reshape requires the input new shape to be Shape. 
However, the given one is " + << call->args[1]->struct_info_->GetTypeKey()); + } + + Optional<Array<PrimExpr>> old_shape_values; + if (data_sinfo->shape.defined()) { + const auto* old_shape_sinfo = GetStructInfoAs<ShapeStructInfoNode>(data_sinfo->shape.value()); + ICHECK_NOTNULL(old_shape_sinfo); + old_shape_values = old_shape_sinfo->values; + } + + if (new_shape_sinfo->values.defined() && old_shape_values.defined()) { + PrimExpr new_shape_prod = ComputeShapeProduct(new_shape_sinfo->values.value()); + PrimExpr old_shape_prod = ComputeShapeProduct(old_shape_values.value()); + if (ctx->GetAnalyzer()->CanProve(old_shape_prod != new_shape_prod)) { + ctx->ReportFatal(Diagnostic::Error(call->span) + << "Reshape expects the new shape to be convertible from the old shape. " + "However, the old shape is " + << data_sinfo->shape << ", with product " << old_shape_prod + << ", while the new shape is " << call->args[1] << ", with product " + << new_shape_prod); + } + } + return TensorStructInfo(call->args[1], data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.reshape") + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("shape", "Shape", "The input new shape.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoReshape); + +} // namespace relax +} // namespace tvm diff --git a/include/tvm/relax/backend.h b/src/relax/op/tensor/manipulate.h similarity index 58% copy from include/tvm/relax/backend.h copy to src/relax/op/tensor/manipulate.h index 4ebeacac0f..1a3eb0547d 100644 --- a/include/tvm/relax/backend.h +++ b/src/relax/op/tensor/manipulate.h @@ -18,27 +18,28 @@ */ /*! - * \file tvm/relax/backend.h - * \brief Relax backend specific transformation passes. + * \file manipulate.h + * \brief The functions to make Relax tensor manipulation operator calls. */ -#ifndef TVM_RELAX_BACKEND_H_ -#define TVM_RELAX_BACKEND_H_ +#ifndef TVM_RELAX_OP_TENSOR_MANIPULATE_H_ +#define TVM_RELAX_OP_TENSOR_MANIPULATE_H_ -#include <tvm/relax/transform.h> +#include "../op_common.h" namespace tvm { namespace relax { -namespace transform { /*! - * \brief Lower the shape expression in relax to VM shape heap and TIR functions. - * - * \return The Pass. + * \brief Reshape the input array, supporting `-1` inference in the new + * shape when the new shape is given as an Array of PrimExpr. + * \param x The input data to the operator. + * \param shape The new shape. Should be compatible with the original shape. + * It is required to be either an Array of PrimExpr, or a Shape in Relax + * \return The reshaped result. */ -TVM_DLL Pass VMShapeLower(); +Expr reshape(Expr x, ObjectRef shape); -} // namespace transform } // namespace relax } // namespace tvm -#endif // TVM_RELAX_BACKEND_H_ +#endif // TVM_RELAX_OP_TENSOR_MANIPULATE_H_ diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc new file mode 100644 index 0000000000..be779e97bc --- /dev/null +++ b/src/relax/transform/attach_global_symbol.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/attach_global_symbol.cc + * \brief Attach global_symbol to Relax functions and TIR PrimFuncs for codegen. + */ + +#include <tvm/relax/transform.h> +#include <tvm/tir/function.h> + +namespace tvm { +namespace relax { + +class GlobalSymbolAttacher { + public: +  explicit GlobalSymbolAttacher(IRModule mod) : mod_(mod) {} + +  IRModule Attach() { +    IRModule ret; +    for (const auto& p : mod_->functions) { +      BaseFunc func = p.second; +      if (auto* prim_func = func.as<tir::PrimFuncNode>()) { +        func = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol", p.first->name_hint); +      } else if (auto* relax_func = func.as<FunctionNode>()) { +        func = WithAttr(GetRef<Function>(relax_func), "global_symbol", p.first->name_hint); +      } else { +        LOG(FATAL) << "Unsupported function type: " << func->GetTypeKey(); +        throw; +      } +      ret->Add(p.first, func); +    } +    return ret; +  } + + private: +  IRModule mod_; +}; + +namespace transform { + +Pass AttachGlobalSymbol() { +  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = +      [=](IRModule mod, PassContext pc) { return GlobalSymbolAttacher(mod).Attach(); }; +  return CreateModulePass(pass_func, 0, "AttachGlobalSymbol", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AttachGlobalSymbol").set_body_typed(AttachGlobalSymbol); + +}  // namespace transform + +}  // namespace relax +}  // namespace tvm diff --git a/src/relax/transform/call_tir_rewrite.cc b/src/relax/transform/call_tir_rewrite.cc new file mode 100644 index 0000000000..2ea039e022 --- /dev/null +++ b/src/relax/transform/call_tir_rewrite.cc @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/call_tir_rewrite.cc + * \brief Perform explicit tensor allocation for call_tir. + */ +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/struct_info.h> +#include <tvm/relax/transform.h> +#include <tvm/relax/type.h> +#include <tvm/tir/op.h> + +#include "../../relay/transforms/pattern_utils.h" + +namespace tvm { +namespace relax { + +// ================== +// CallTIRMutator +// Perform explicit tensor allocation for call_tir. 
+// Example: +//   lv0: Tensor(n, m) = rx.call_tir(func, (x), (n, m), dtype="float32") +// --> +//   gv0 = rx.call("relax.builtin.alloc_tensor", [n, m], dtype="float32") +//   rx.call_packed(func, x, gv0) + +class CallTIRMutator : public ExprMutator { + public: +  using ExprMutator::VisitExpr_; +  Expr VisitExpr_(const CallNode* call) override { +    // post-order mutation +    Expr expr = VisitExprPostOrder_(call); +    call = expr.as<CallNode>(); + +    static const Op& call_tir_op = Op::Get("relax.call_tir"); +    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor"); +    static const Op& call_tir_dyn_op = Op::Get("relax.vm.call_tir_dyn"); + +    if (call->op == call_tir_op) { +      Array<Expr> outs; +      if (const auto& _tensor_sinfo = MatchStructInfo<TensorStructInfo>(expr)) { +        // single output case +        const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value(); +        ICHECK(tensor_sinfo->shape.defined()) +            << "the TensorStructInfo shape of call_tir has not been populated"; +        outs.push_back( +            builder_->Emit(Call(alloc_tensor_op,  // +                                {Downcast<ShapeExpr>(tensor_sinfo->shape.value()), +                                 DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)},  // +                                Attrs()), +                           "alloc")); +      } else if (const auto& _tuple_sinfo = MatchStructInfo<TupleStructInfo>(expr)) { +        // multiple output case +        const TupleStructInfo& tuple_sinfo = _tuple_sinfo.value(); +        for (size_t i = 0; i < tuple_sinfo->fields.size(); ++i) { +          const auto& field = tuple_sinfo->fields[i]; + +          ICHECK(field->IsInstance<TensorStructInfoNode>()) +              << "call_tir expects Tuple of TensorStructInfo, but got " << field +              << " as an element of TupleStructInfo"; +          const auto& field_tensor = Downcast<TensorStructInfo>(field); +          ICHECK(field_tensor->shape.defined()) +              << "call_tir expects all TensorStructInfo to have a defined shape, but got " +              << field_tensor << " as an element of TupleStructInfo"; +          outs.push_back( +              builder_->Emit(Call(alloc_tensor_op, +                                  {Downcast<ShapeExpr>(field_tensor->shape.value()), +                                   DataTypeImm(field_tensor->dtype), PrimValue::Int64(0)}, +                                  Attrs()), +                             "alloc")); +        } +      } else { +        LOG(FATAL) << "TypeError: The struct info of call_tir is expected to be TensorStructInfo " +                      "or TupleStructInfo, but got " +                   << expr->struct_info_; +      } + +      Array<Expr> args; +      if (call->args[1].as<TupleNode>()) { +        args = Downcast<Tuple>(call->args[1])->fields; +        args.insert(args.end(), outs.begin(), outs.end()); + +        if (call->args.size() == 2) { +          builder_->Emit(Call(call->args[0], args), "_"); +        } else { +          // unpack semantics +          args.push_back(call->args[2]); +          builder_->Emit(Call(call_tir_dyn_op, {call->args[0], Tuple(args)}), "_"); +        } +      } else { +        args = outs; +        args.insert(args.begin(), call->args[1]); +        builder_->Emit(Call(call->args[0], args), "_"); +      } + +      if (outs.size() == 1) { +        return outs[0]; +      } +      return Tuple(outs); +    } + +    return GetRef<Expr>(call); +  } +}; + +Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); } + +namespace transform { + +Pass CallTIRRewrite() { +  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = +      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); }; +  return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite); + +}  // namespace transform + +}  // namespace relax +}  // namespace tvm diff --git a/src/relax/transform/rewrite_dataflow_reshape.cc b/src/relax/transform/rewrite_dataflow_reshape.cc new file mode 100644 index 0000000000..aec0911ecc --- /dev/null 
+++ b/src/relax/transform/rewrite_dataflow_reshape.cc @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + *   http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.  See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/rewrite_dataflow_reshape.cc + * \brief Rewrite all reshape-pattern call_tir calls within dataflow blocks to the + * relax.reshape operator + */ +#include <tvm/relax/analysis.h> +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/transform.h> + +#include "../op/tensor/manipulate.h" + +namespace tvm { +namespace relax { + +class DataflowReshapeRewriter : public ExprMutator { + public: +  explicit DataflowReshapeRewriter(const IRModule& mod) : mod_(mod) {} + + private: +  using ExprMutator::VisitExpr_; + +  BindingBlock VisitBindingBlock(const BindingBlock& block) final { +    // We only rewrite the bindings inside dataflow blocks. +    if (const auto* dataflow_block = block.as<DataflowBlockNode>()) { +      return VisitBindingBlock_(dataflow_block); +    } else { +      return block; +    } +  } + +  void VisitBinding_(const VarBindingNode* binding) final { +    // We only rewrite bindings to DataflowVars. Dataflow block outputs are not +    // DataflowVars, may be referenced outside the block, and are therefore +    // emitted unchanged. +    if (!binding->var->IsInstance<DataflowVarNode>()) { +      this->builder_->EmitNormalized(GetRef<VarBinding>(binding)); +    } else { +      ExprMutator::VisitBinding_(binding); +    } +  } + +  Expr VisitExpr_(const CallNode* call) final { +    if (!IsCallingTIRReshape(call)) { +      return GetRef<Call>(call); +    } + +    // We bring the calls of reshape PrimFunc back to calls of high-level +    // relax.reshape op, which will be lowered to calls of the ExternFunc +    // vm.builtin.reshape in the VMBuiltinLower pass. 
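+    // Concretely (mirroring test_transform_rewrite_dataflow_reshape.py in this
+    // PR), a dataflow binding such as
+    //   y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), "float32"))
+    // becomes
+    //   y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3))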
+ Array<Expr> args = Downcast<Tuple>(call->args[1])->fields; + ICHECK_EQ(args.size(), 1); + TensorStructInfo res_sinfo = Downcast<TensorStructInfo>(call->struct_info_); + ICHECK(res_sinfo->shape.defined()); + return reshape(args[0], res_sinfo->shape.value()); + } + + bool IsCallingTIRReshape(const CallNode* call) { + static const Op& call_tir_op = Op::Get("relax.call_tir"); + if (call->op != call_tir_op) { + return false; + } + const auto* gv = call->args[0].as<GlobalVarNode>(); + if (gv == nullptr) { + return false; + } + const auto* func = mod_->functions.Get(GetRef<GlobalVar>(gv)).as<tir::PrimFuncNode>(); + ICHECK_NOTNULL(func); + return HasReshapePattern(GetRef<tir::PrimFunc>(func)); + } + + const IRModule& mod_; +}; + +Expr RewriteDataflowReshape(const Function& f, const IRModule& mod) { + return DataflowReshapeRewriter(mod)(f); +} + +namespace transform { + +Pass RewriteDataflowReshape() { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast<Function>(RewriteDataflowReshape(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "RewriteDataflowReshape", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.RewriteDataflowReshape") + .set_body_typed(RewriteDataflowReshape); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/src/relax/transform/to_non_dataflow.cc b/src/relax/transform/to_non_dataflow.cc new file mode 100644 index 0000000000..db2e9d7ee5 --- /dev/null +++ b/src/relax/transform/to_non_dataflow.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/to_non_dataflow.cc + * \brief Transform all dataflow structure to non-dataflow version. 
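+ *
+ * Every DataflowVar is rebound to an ordinary Var, and every DataflowBlock is
+ * re-emitted as a plain BindingBlock; the bindings themselves are left intact.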
+ */ +#include <tvm/relax/expr_functor.h> +#include <tvm/relax/struct_info.h> +#include <tvm/relax/transform.h> +#include <tvm/relax/type.h> +#include <tvm/tir/op.h> + +namespace tvm { +namespace relax { + +class ToNonDFMutator : public ExprMutator { + public: + Var VisitVarDef(const Var& var) final { + if (var.as<DataflowVarNode>()) { + Var new_var = Var(var->vid, GetStructInfo(var), var->span); + this->var_remap_[var->vid] = new_var; + return new_var; + } + return var; + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final { + builder_->BeginBindingBlock(); + for (Binding binding : block->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } +}; + +Expr ToNonDataflow(const Expr& e) { return ToNonDFMutator().VisitExpr(e); } + +namespace transform { + +Pass ToNonDataflow() { + runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = + [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(ToNonDataflow(f)); }; + return CreateFunctionPass(pass_func, 0, "ToNonDataflow", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.ToNonDataflow").set_body_typed(ToNonDataflow); + +} // namespace transform + +} // namespace relax +} // namespace tvm diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py new file mode 100644 index 0000000000..5dd83f2da2 --- /dev/null +++ b/tests/python/relax/test_analysis.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
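+
+# Checks for relax.analysis.has_reshape_pattern: reshape-like PrimFuncs (plain,
+# scheduled, expand_dims, and ragged variants) should match, while multi-block
+# and reduction PrimFuncs should be rejected.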
+ +from typing import List, Set, Union + +import tvm +import tvm.testing +from tvm import tir +from tvm import relax as rx +from tvm.relax.analysis import has_reshape_pattern +from tvm.script import relax as R, tir as T + + +def test_reshape_pattern_reshape(): + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), + T_reshape: T.Buffer((8, 3), "float32"), + ): + for i0, i1 in T.grid(8, 3): + with T.block("T_reshape"): + ax0, ax1 = T.axis.remap("SS", [i0, i1]) + T.reads( + rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + + assert has_reshape_pattern(reshape) + + +def test_reshape_pattern_reshape_scheduled(): + @T.prim_func + def reshape_scheduled( + rxplaceholder: T.Buffer((1, 2, 3, 4), "float32"), + T_reshape: T.Buffer((8, 3), "float32"), + ): + for i0_i1_fused_0 in T.thread_binding(1, thread="blockIdx.x"): + for i0_i1_fused_1 in T.thread_binding(24, thread="threadIdx.x"): + with T.block("T_reshape"): + ax0 = T.axis.spatial(8, (i0_i1_fused_0 * 24 + i0_i1_fused_1) // 3) + ax1 = T.axis.spatial(3, (i0_i1_fused_0 * 24 + i0_i1_fused_1) % 3) + T.reads( + rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + ) + T.writes(T_reshape[ax0, ax1]) + T_reshape[ax0, ax1] = rxplaceholder[ + (ax0 * 3 + ax1) // 24, + (ax0 * 3 + ax1) % 24 // 12, + (ax0 * 3 + ax1) % 12 // 4, + (ax0 * 3 + ax1) % 4, + ] + + assert has_reshape_pattern(reshape_scheduled) + + +def test_reshape_pattern_expand_dims(): + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((2, 3, 4), "float32"), + expand_dims: T.Buffer((2, 1, 1, 1, 3, 1, 4, 1), "float32"), + ): + T.func_attr({"tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7 in T.grid(2, 1, 1, 1, 3, 1, 4, 1): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1 = T.axis.remap( + "SSSSSSSS", [i0, i1, i2, i3, i4, i5, i6, i7] + ) + T.reads(rxplaceholder[i0_1, i4_1, i6_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1, i5_1, i6_1, i7_1] = rxplaceholder[ + i0_1, i4_1, i6_1 + ] + + assert has_reshape_pattern(expand_dims) + + +def test_reshape_pattern_with_raggedness(): + @T.prim_func + def reshape_raggedness( + A: T.Buffer((100, 768), "float32"), + src_indptr: T.Buffer((9,), "int32"), + B: T.Buffer((100, 12, 64), "float32"), + ): + for b in T.serial(8): + with T.block("block0"): + vb = T.axis.spatial(8, b) + for i in T.serial(src_indptr[vb + 1] - src_indptr[vb]): + for h in T.serial(12): + for f in T.serial(64): + with T.block("block1"): + vi, vh, vf = T.axis.remap("SSS", [i, h, f]) + B[src_indptr[vb] + vi, vh, vf] = A[ + src_indptr[vb] + vi, vh * 64 + vf + ] + + assert has_reshape_pattern(reshape_raggedness) + + +def test_reshape_pattern_reject_seqstmt(): + @T.prim_func + def identity_bias(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + T.float32(1) + + @T.prim_func + def identity_identity(A: T.Buffer((4, 4), 
"float32"), B: T.Buffer((4, 4), "float32")): + C = T.alloc_buffer((128, 128), "float32") + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + C[vi0, vi1] = A[vi0, vi1] + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = C[vi0, vi1] + + assert not has_reshape_pattern(identity_bias) + assert not has_reshape_pattern(identity_identity) + + +def test_reshape_pattern_reject_reduction(): + @T.prim_func + def reduction(A: T.Buffer((4, 4), "float32"), B: T.Buffer((4,), "float32")): + for i0, i1 in T.grid(4, 4): + with T.block("identity"): + vi0, vi1 = T.axis.remap("SR", [i0, i1]) + with T.init(): + B[vi0] = T.float32(0) + B[vi0] = B[vi0] + A[vi0, vi1] + + assert not has_reshape_pattern(reduction) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py new file mode 100644 index 0000000000..624b7877cd --- /dev/null +++ b/tests/python/relax/test_transform.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest +import tvm +from tvm import relax +from tvm.ir import structural_equal +from tvm.ir.base import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +def test_to_non_dataflow(): + @tvm.script.ir_module + class TestToNonDataflow: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + with R.dataflow(): + lv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + gv0 = R.call_tir("test.op.identity", (lv0,), R.Tensor((m, n), dtype="float32")) + R.output(gv0) + return gv0 + + mod = TestToNonDataflow + + old_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal old_vars + old_vars.append(e) + + relax.analysis.post_order_visit(mod["foo"], fvisit) + x, lv0, gv0 = old_vars + + new_mod = relax.transform.ToNonDataflow()(mod) + + new_vars = [] + + def fvisit(e): + if isinstance(e, relax.Var): + nonlocal new_vars + new_vars.append(e) + + relax.analysis.post_order_visit(new_mod["foo"], fvisit) + + assert x == new_vars[0] + assert lv0 != new_vars[1] + assert isinstance(lv0, relax.DataflowVar) + assert not isinstance(new_vars[1], relax.DataflowVar) + + assert isinstance(gv0, relax.Var) + assert isinstance(new_vars[2], relax.Var) + assert gv0 == new_vars[2] + + +def test_call_tir_rewrite(): + @tvm.script.ir_module + class TestCallTIRRewrite: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")): + m, n = T.var("int64"), T.var("int64") + gv0 = R.call_tir("test.op.identity", (x,), R.Tensor((m, n), dtype="float32")) + return gv0 + + mod = TestCallTIRRewrite + + # before rewrite + v0 = mod["foo"].body.blocks[0].bindings[0].var + s0 = mod["foo"].body.blocks[0].bindings[0].value + assert isinstance(s0, relax.Call) + assert s0.op.name == "relax.call_tir" + + # after rewrite + new_mod = relax.transform.CallTIRRewrite()(mod) + func = new_mod["foo"] + + block = func.body.blocks[0] + assert not isinstance(block, relax.DataflowBlock) + + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.builtin.alloc_tensor" + assert isinstance(s1.args[0], relax.ShapeExpr) + assert structural_equal(s1.args[0], s0.sinfo_args[0].shape) + s2 = block.bindings[1].value + assert s2.op.global_symbol == "test.op.identity" + + +def test_vm_builtin_lower(): + @tvm.script.ir_module + class TestVMBuiltinLower: + @R.function + def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: + m, n = T.var("int64"), T.var("int64") + alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") + _ = R.call_packed( + "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) + ) + gv0 = alloc + return gv0 + + mod = TestVMBuiltinLower + + # after vm builtin lowering + new_mod = relax.transform.VMBuiltinLower()(mod) + func = new_mod["foo"] + + assert isinstance(new_mod, tvm.IRModule) + assert isinstance(func, tvm.relax.expr.Function) + + block = func.body.blocks[0] + s1 = block.bindings[0].value + assert isinstance(s1, relax.Call) + assert s1.op.name == "relax.vm.alloc_storage" + s2 = block.bindings[1].value + assert isinstance(s2, relax.Call) + s3 = block.bindings[2].value + assert isinstance(s3, relax.Call) + assert isinstance(s3.op, relax.ExternFunc) + assert s3.op.global_symbol == "test.op.identity" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_attach_global_symbol.py b/tests/python/relax/test_transform_attach_global_symbol.py new file mode 100644 index 
0000000000..edfc646e21 --- /dev/null +++ b/tests/python/relax/test_transform_attach_global_symbol.py @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest +import tvm +from tvm import tir, relax +from tvm.ir import assert_structural_equal + +import tvm.script +from tvm.script import tir as T, relax as R + + +@tvm.script.ir_module +class Before: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main(x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32")) -> R.Tensor: + m, n, k = T.var("int64"), T.var("int64"), T.var("int64") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + +def test_basic(): + @tvm.script.ir_module + class Expected: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int64") + n = T.var("int64") + k = T.var("int64") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def main( + x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") + ) -> R.Tensor: + R.func_attr({"global_symbol": "main"}) + m, n, k = T.var("int64"), T.var("int64"), T.var("int64") + gv0 = R.call_tir("tir_matmul", (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + before = Before + expected = Expected + after = relax.transform.AttachGlobalSymbol()(before) + assert_structural_equal(after, expected) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relax/test_transform_rewrite_dataflow_reshape.py b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py new file mode 100644 index 0000000000..2c53d85c56 --- /dev/null +++ b/tests/python/relax/test_transform_rewrite_dataflow_reshape.py @@ -0,0 +1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +import tvm.testing +from tvm import relax +from tvm.script import relax as R, tir as T + + +def test_reshape_expand_dims(): + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), + "float32", + ), + ): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = rxplaceholder[i0_1, i2_1, i4_1] + + @R.function + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + with R.dataflow(): + y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) + z = R.call_tir(expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), "float32")) + R.output(z) + return z + + @tvm.script.ir_module + class Expected: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3), + (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * T.int64(12) + v_ax1 * T.int64(3) + v_ax2) // T.int64(3), + (v_ax1 * T.int64(12) + v_ax2 * T.int64(3) + v_ax2) % T.int64(3), + ] + + @T.prim_func + def expand_dims( + rxplaceholder: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + expand_dims: T.Buffer( + (T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3)), "float32" + ), + ): + for i0, i1, i2, i3, i4 in T.grid( + T.int64(2), T.int64(1), T.int64(4), T.int64(1), T.int64(3) + ): + with T.block("expand_dims"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(rxplaceholder[i0_1, i2_1, i4_1]) + T.writes(expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1]) + expand_dims[i0_1, i1_1, i2_1, i3_1, i4_1] = 
rxplaceholder[i0_1, i2_1, i4_1] + + @R.function + def main( + x: R.Tensor((8, 3), dtype="float32") + ) -> R.Tensor((2, 1, 4, 1, 3), dtype="float32"): + with R.dataflow(): + y: R.Tensor((2, 4, 3), "float32") = R.reshape(x, (2, 4, 3)) + # Note: `z` is the output var of the dataflow block, and is thus + # not expected to be rewritten. + z = R.call_tir( + expand_dims, (y,), out_sinfo=R.Tensor((2, 1, 4, 1, 3), dtype="float32") + ) + R.output(z) + return z + + assert relax.analysis.has_reshape_pattern(Module["expand_dims"]) + mod = relax.transform.RewriteDataflowReshape()(Module) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_reshape_non_dataflow(): + @tvm.script.ir_module + class Module: + @T.prim_func + def reshape( + rxplaceholder: T.Buffer((T.int64(8), T.int64(3)), "float32"), + T_reshape: T.Buffer((T.int64(2), T.int64(4), T.int64(3)), "float32"), + ): + for ax0, ax1, ax2 in T.grid(T.int64(2), T.int64(4), T.int64(3)): + with T.block("T_reshape"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads( + rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + ) + T.writes(T_reshape[v_ax0, v_ax1, v_ax2]) + T_reshape[v_ax0, v_ax1, v_ax2] = rxplaceholder[ + (v_ax0 * 12 + v_ax1 * 3 + v_ax2) // T.int64(3), + (v_ax1 * 12 + v_ax2 * 3 + v_ax2) % T.int64(3), + ] + + @R.function + def main(x: R.Tensor((8, 3), dtype="float32")) -> R.Tensor((2, 4, 3), dtype="float32"): + y = R.call_tir(reshape, (x,), out_sinfo=R.Tensor((2, 4, 3), dtype="float32")) + return y + + assert relax.analysis.has_reshape_pattern(Module["reshape"]) + # The binding var of the call_tir is not a DataflowVar. So the pass does no change. + mod = relax.transform.RewriteDataflowReshape()(Module) + tvm.ir.assert_structural_equal(mod, Module) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py new file mode 100644 index 0000000000..534d2308da --- /dev/null +++ b/tests/python/relax/test_vm_build.py @@ -0,0 +1,908 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
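+
+# End-to-end tests for the minimal build flow: modules are lowered with
+# relax.vm.build and the resulting executables are run on relax.VirtualMachine.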
+import os +from typing import Tuple, Callable + +import sys +import tempfile +import numpy as np +import pytest +import tvm +import tvm.script +import tvm.testing +from tvm import relax, rpc, te, tir, topi +from tvm.contrib import utils +from tvm.relax.testing import nn +from tvm.script import relax as R, tir as T +from tvm.relax.testing.vm import check_saved_func + +EXEC_MODE = ["bytecode"] + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_simple(exec_mode): +    @tvm.script.ir_module +    class TestVMCompileStage0: +        @R.function +        def foo(x: R.Tensor((3, 4), "float32"), y: R.Tensor((3, 4), "float32")): +            z = R.call_packed( +                "test.vm.identity", x, y, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) +            ) +            return y + +    mod = TestVMCompileStage0 +    target = tvm.target.Target("llvm", host="llvm") +    ex = relax.vm.build(mod, target, exec_mode=exec_mode) +    inp1 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) +    inp2 = tvm.nd.array(np.random.rand(3, 4).astype(np.float32)) +    vm = relax.VirtualMachine(ex, tvm.cpu()) +    vm["foo"](inp1, inp2) +    tvm.testing.assert_allclose(inp2.numpy(), inp1.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_match_check(exec_mode): +    @tvm.script.ir_module +    class TestMatchCheck: +        @R.function +        def foo(x: R.Tensor(["n", "m"], "int32"), y: R.Object) -> R.Tensor(["m", "n"], dtype=None): +            return y + +    mod = TestMatchCheck +    target = tvm.target.Target("llvm", host="llvm") +    ex = relax.vm.build(mod, target, exec_mode=exec_mode) +    vm = relax.VirtualMachine(ex, tvm.cpu()) +    x0 = tvm.nd.array(np.zeros((1, 2)).astype("int32")) +    y0 = tvm.nd.array(np.zeros((2, 1)).astype("float32")) +    y1 = tvm.nd.array(np.zeros((1, 2)).astype("float32")) +    y2 = tvm.nd.array(np.zeros((2, 1, 1)).astype("float32")) + +    vm["foo"](x0, y0) + +    with pytest.raises(RuntimeError, match=".*return.*"): +        vm["foo"](x0, y1) + +    with pytest.raises(ValueError, match=".*return.*"): +        vm["foo"](x0, y2) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_stage2(exec_mode): +    @tvm.script.ir_module +    class TestVMCompileStage2: +        @R.function +        def foo(x: R.Tensor(dtype="float32")) -> R.Shape: +            n, m = T.var("int64"), T.var("int64") +            _ = R.match_cast(x, R.Tensor((n, m), "float32")) +            return (n * 2, m * 3) + +    mod = TestVMCompileStage2 +    target = tvm.target.Target("llvm", host="llvm") +    ex = relax.vm.build(mod, target, exec_mode=exec_mode) +    vm = relax.VirtualMachine(ex, tvm.cpu()) + +    shape = (32, 16) +    arr = tvm.nd.array(np.random.rand(*shape).astype("float32")) +    res = vm["foo"](arr) +    assert res[0] == shape[0] * 2 +    assert res[1] == shape[1] * 3 + +    # dtype mismatch +    with pytest.raises(ValueError, match=".*dtype.*"): +        vm["foo"](tvm.nd.array(np.zeros((1, 2)).astype("int32"))) + +    # ndim mismatch +    with pytest.raises(ValueError, match=".*match_cast.*ndim.*"): +        vm["foo"](tvm.nd.array(np.zeros((1,)).astype("float32"))) + +    # type mismatch +    with pytest.raises(TypeError): +        vm["foo"]([]) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_stage3(exec_mode): +    @tvm.script.ir_module +    class TestVMCompileStage3: +        @R.function +        def foo(x: R.Tensor((32, 16), "float32")) -> R.Tensor: +            with R.dataflow(): +                y = R.call_tir("test.vm.identity", (x), R.Tensor((32, 16), dtype="float32")) +                R.output(y) +            return y + +    mod = TestVMCompileStage3 +    target = tvm.target.Target("llvm", host="llvm") +    ex = relax.vm.build(mod, target, exec_mode=exec_mode) +    vm = relax.VirtualMachine(ex, tvm.cpu()) + +    shape = (32, 16) +    inp
= tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = vm["foo"](inp) + tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_e2e(exec_mode): + @tvm.script.ir_module + class TestVMCompileE2E: + @R.function + def foo(x: R.Tensor(dtype="float32")) -> R.Tensor: + with R.dataflow(): + n, m = T.var("int64"), T.var("int64") + _ = R.match_cast(x, R.Tensor((n, m), "float32")) + y = R.call_tir("test.vm.tile", (x), R.Tensor((n, m * 2), dtype="float32")) + R.output(y) + return y + + mod = TestVMCompileE2E + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + shape = (32, 16) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "foo", inp) + tvm.testing.assert_allclose(res.numpy(), np.tile(inp.numpy(), (1, 2)), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_compile_e2e_func_param_with_shape(exec_mode): + @tvm.script.ir_module + class TestVMCompileE2E2: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def func( + x: R.Tensor(("m", "n"), "float32"), w: R.Tensor(("n", "k"), "float32") + ) -> R.Tensor: + m, k = T.var("int64"), T.var("int64") + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((m, k), dtype="float32")) + return gv0 + + mod = TestVMCompileE2E2 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + res = check_saved_func(vm, "func", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_extern(exec_mode): + if not tvm.get_global_func("tvm.contrib.cblas.matmul", True): + print("skip because extern function is not available") + return + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([n, m], "float32")) + y = relax.Var("y", R.Tensor([m, n], "float32")) + + with bb.function("rx_cblas_matmul", [x, y]): + out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False) + bb.emit_func_output(out) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + + data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32)) + weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32)) + res = check_saved_func(vm, "rx_cblas_matmul", data, weight) + expected = np.dot(data.numpy(), weight.numpy()) + tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_concat(exec_mode): + # concatenate of two vectors of size (n,) 
and (m,) + bb = relax.BlockBuilder() + n, m = tir.Var("n", "int64"), tir.Var("m", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + y = relax.Var("y", R.Tensor([m], "float32")) + + def te_func(A, B): + C = te.compute((n + m), lambda i: tvm.tir.if_then_else(i < n, A[i], B[i - n])) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + inp2 = tvm.nd.array( + np.random.rand( + 2, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp, inp2) + tvm.testing.assert_allclose( + res.numpy(), np.append(inp.numpy(), inp2.numpy()), rtol=1e-7, atol=1e-7 + ) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_dtype_change(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + + # convert a tensor with dtype of float32 to int16 + def te_func(A): + B = te.compute((n,), lambda i: A[i].astype("int16")) + return B + + with bb.function("rx_func", [x]): + y = bb.emit_te(te_func, x) + bb.emit_func_output(y) + + mod = bb.get() + + new_mod = relax.transform.CallTIRRewrite()(mod) + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array( + np.random.rand( + 1, + ).astype(np.float32) + ) + res = check_saved_func(vm, "rx_func", inp) + np.testing.assert_allclose(res.numpy(), inp.numpy().astype("int16")) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_floor_symbolic_shape(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + + def te_func(A): + C = te.compute((tir.floordiv(n, 2),), lambda i: A[i] + 1) + return C + + with bb.function("rx_func", [x]): + x1 = bb.emit_te(te_func, x) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (9,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp) + + def expected_output(): + output_shape = (shape[0] // 2,) + return inp.numpy()[: output_shape[0]] + 1 + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_emit_te_constant_param_cpu(exec_mode): + x_np = np.random.rand(2, 2).astype("float32") + c_np = np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 2), "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + exec = relax.vm.build(mod, "llvm", exec_mode=exec_mode) + dev = tvm.cpu() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +@tvm.testing.requires_gpu +def test_vm_emit_te_constant_param_gpu(exec_mode): + x_np = np.random.rand(2, 2).astype("float32") + c_np = 
np.random.rand(2, 2).astype("float32") + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 2), "float32")) + c = relax.const(c_np, "float32") + with bb.function("main", [x]): + with bb.dataflow(): + lv0 = bb.emit_te(topi.add, x, c) + gv = bb.emit_output(lv0) + bb.emit_func_output(gv) + + mod = bb.get() + sch = tvm.tir.Schedule(mod, debug_mask="all") + loops = sch.get_loops(sch.get_block(name="T_add", func_name="add")) + sch.bind(loops[0], "threadIdx.x") + + exec = relax.vm.build(sch.mod, "cuda", exec_mode=exec_mode) + dev = tvm.cuda() + vm = relax.VirtualMachine(exec, dev) + + add_res = check_saved_func(vm, "main", tvm.nd.array(x_np, dev)) + tvm.testing.assert_allclose(add_res.numpy(), x_np + c_np, rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_symbolic_shape(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + x = relax.Var("x", R.Tensor([n], "float32")) + y = relax.Var("y", R.Tensor([(n // 2) + 1], "float32")) + + def te_func(A, B): + C = te.compute((n,), lambda i: A[i] + B[i // 2]) + return C + + with bb.function("rx_func", [x, y]): + x1 = bb.emit_te(te_func, x, y) + bb.emit_func_output(x1) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape1 = (5,) + shape2 = (3,) + inp = tvm.nd.array(np.random.rand(*shape1).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape2).astype(np.float32)) + res = check_saved_func(vm, "rx_func", inp, inp2) + + def expected_output(): + return inp.numpy() + np.repeat(inp2.numpy(), 2)[:5] + + tvm.testing.assert_allclose(res.numpy(), expected_output(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_relax_dyn_tir_shape(exec_mode): + # case where TIR variables are unbound in generated PrimFunc + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + def te_func(A): + C = te.compute((n + 1), lambda i: A[i]) + return C + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n + 1,), dtype="float32", name="y") + + x1 = bb.emit_te(te_func, y) + bb.emit_func_output(x1, params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + ex.mod.export_library("exec.so") + exec1 = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + assert ex.as_text() == exec1.as_text() + + vm = relax.VirtualMachine(ex, tvm.cpu()) + inp = tvm.nd.array(np.random.rand(2).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(3).astype(np.float32)) + + res = check_saved_func(vm, "rx_func", inp, inp2) + + tvm.testing.assert_allclose(res.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_tuple(exec_mode): + bb = relax.BlockBuilder() + n = tir.Var("n", "int64") + + with bb.function("rx_func"): + x = nn.Placeholder((n,), dtype="float32", name="x") + y = nn.Placeholder((n,), dtype="float32", name="y") + tup = relax.Tuple([x, y]) + item = tup[0] + bb.emit_func_output([tup, item], params=[x, y]) + + mod = bb.get() + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + + vm = relax.VirtualMachine(ex, tvm.cpu()) + shape = (5,) + inp = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + inp2 = tvm.nd.array(np.random.rand(*shape).astype(np.float32)) + (res1, res2), res3 
= vm["rx_func"](inp, inp2) + + tvm.testing.assert_allclose(res1.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res2.numpy(), inp2.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res3.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_vm_tuplegetitem(exec_mode): + @tvm.script.ir_module + class TestVMTupleGetItem: + @R.function + def tuple_get_item( + x: R.Tensor(ndim=2, dtype="float32"), + y: R.Tensor(ndim=2, dtype="float32"), + ): + t = (x, y) + a = t[0] + b = t[1] + c = R.call_packed("test.vm.add", a, b, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return c + + mod = TestVMTupleGetItem + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + res = check_saved_func(vm, "tuple_get_item", x_inp, y_inp) + tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_lower_memory_alloc_storage_tensor(exec_mode): + @tvm.script.ir_module + class TestMemoryAllocStorageTensor: + @R.function + def main(x: R.Tensor((2, 3), dtype="float32")): + storage = R.memory.alloc_storage( + (24,), virtual_device_index=0, storage_scope="global", dtype="float32" + ) + y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32") + _ = copy(x, y) + return y + + @T.prim_func + def copy(A: T.Buffer((2, 3), "float32"), B: T.Buffer((2, 3), "float32")): + for i0, i1 in T.grid(2, 3): + with T.block("block"): + vi0, vi1 = T.axis.remap("SS", [i0, i1]) + B[vi0, vi1] = A[vi0, vi1] + + mod = TestMemoryAllocStorageTensor + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(mod, target, exec_mode=exec_mode) + vm = relax.VirtualMachine(ex, tvm.cpu()) + x = tvm.nd.array(np.random.rand(2, 3).astype("float32")) + y = vm["main"](x) + tvm.testing.assert_allclose(y.numpy(), x.numpy(), rtol=1e-7, atol=1e-7) + + +@pytest.mark.parametrize("exec_mode", EXEC_MODE) +def test_sub_func_call(exec_mode): + @tvm.script.ir_module + class TestVMSubFunction: + @T.prim_func + def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None: + T.func_attr({"global_symbol": "tir_matmul"}) + m = T.var("int32") + n = T.var("int32") + k = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (n, k)) + C = T.match_buffer(z, (m, k)) + + for i, j, k in T.grid(m, k, n): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + @R.function + def relax_matmul_tir( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") + ) -> R.Tensor((32, 32), dtype="float32"): + with R.dataflow(): + gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32")) + R.output(gv0) + return gv0 + + @R.function + def relax_matmul_packed( + x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32") + ) -> R.Object: + gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32"))) + return gv0 + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object: + gv0 = relax_matmul_tir(x, w) + gv1 = relax_matmul_packed(gv0, gv0) + return gv1 + + target = tvm.target.Target("llvm", host="llvm") + ex = relax.vm.build(TestVMSubFunction, 
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_sub_func_call(exec_mode):
+    @tvm.script.ir_module
+    class TestVMSubFunction:
+        @T.prim_func
+        def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+            T.func_attr({"global_symbol": "tir_matmul"})
+            m = T.var("int32")
+            n = T.var("int32")
+            k = T.var("int32")
+            A = T.match_buffer(x, (m, n))
+            B = T.match_buffer(y, (n, k))
+            C = T.match_buffer(z, (m, k))
+
+            for i, j, k in T.grid(m, k, n):
+                with T.block("matmul"):
+                    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+                    with T.init():
+                        C[vi, vj] = T.float32(0)
+                    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+        @R.function
+        def relax_matmul_tir(
+            x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
+        ) -> R.Tensor((32, 32), dtype="float32"):
+            with R.dataflow():
+                gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((32, 32), dtype="float32"))
+                R.output(gv0)
+            return gv0
+
+        @R.function
+        def relax_matmul_packed(
+            x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")
+        ) -> R.Object:
+            gv0 = R.call_packed("test.vm.mul", x, w, sinfo_args=(R.Tensor(ndim=2, dtype="float32")))
+            return gv0
+
+        @R.function
+        def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Object:
+            gv0 = relax_matmul_tir(x, w)
+            gv1 = relax_matmul_packed(gv0, gv0)
+            return gv1
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(TestVMSubFunction, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32))
+    y_inp = tvm.nd.array(np.random.rand(32, 32).astype(np.float32))
+    res = check_saved_func(vm, "main", x_inp, y_inp)
+    product = np.dot(x_inp.numpy(), y_inp.numpy())
+    expected = product * product
+    tvm.testing.assert_allclose(res.numpy(), expected, rtol=1e-6, atol=1e-6)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_recursion(exec_mode):
+    @tvm.script.ir_module
+    class TestVMRecursion:
+        @R.function
+        def recursion(n: R.Tensor((1,), "float32")) -> R.Tensor:
+            cond = R.call_packed(
+                "test.vm.equal_zero", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+            )
+            if cond:
+                res = R.const(1.0)
+            else:
+                gv0 = R.call_packed(
+                    "test.vm.subtract_one", n, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+                )
+                tmp = recursion(gv0)
+                res = R.call_packed(
+                    "test.vm.add", tmp, tmp, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+                )
+            return res
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(TestVMRecursion, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    inp = np.empty(1).astype("float32")
+    recursion_runs = np.random.randint(1, 10)
+    inp.fill(recursion_runs)
+    inp = tvm.nd.array(inp)
+    res = check_saved_func(vm, "recursion", inp)
+    tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_closure(exec_mode):
+    @tvm.script.ir_module
+    class TestClosure:
+        @R.function
+        def lifted_func_1(x: R.Tensor((2, 3), "float32"), env: R.Tensor((2, 3), "float32")):
+            return R.call_packed("test.vm.add", x, env, sinfo_args=(R.Tensor))
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3), "float32"),
+            y: R.Tensor((2, 3), "float32"),
+        ):
+            clo = R.make_closure(lifted_func_1, (x,))
+            res = R.invoke_closure(clo, (y,), sinfo_args=(R.Tensor))
+            return res
+
+    mod = TestClosure
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    y_inp = tvm.nd.array(np.array([[3.1, 4.0, 5.0], [6.0, 7.1, 9.0]], dtype="float32"))
+    res = check_saved_func(vm, "main", x_inp, y_inp)
+    tvm.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy())
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_time_evaluator(exec_mode):
+    @tvm.script.ir_module
+    class TestTimeEvaluator:
+        @R.function
+        def main(x: R.Tensor((1,), "float32"), y: R.Tensor((1,), "float32")):
+            return R.call_packed(
+                "test.vm.add", x, y, sinfo_args=(R.Tensor(ndim=1, dtype="float32"))
+            )
+
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.vm.build(TestTimeEvaluator, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x = tvm.nd.array(np.random.rand(1).astype("float32"))
+    y = tvm.nd.array(np.random.rand(1).astype("float32"))
+
+    # ensure we can use time_evaluator with the stateful API
+    vm.set_input("main", x, y)
+    timing_res = vm.time_evaluator("invoke_stateful", tvm.cpu())("main")
+    # just checking that it has some results at all
+    assert timing_res.results
+
+    # ensure we can use it with a closure
+    vm.save_function("main", "saved_main", x, y)
+    timing_res = vm.time_evaluator("saved_main", tvm.cpu())()
+    assert timing_res.results
+
+
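+# The remaining tests exercise the stateful VM API: set_input binds inputs to
+# a function, invoke_stateful runs it, and get_outputs retrieves the results.
+# A minimal sketch of the flow, assuming `ex` is a built Executable:
+#     vm = relax.VirtualMachine(ex, tvm.cpu())
+#     vm.set_input("main", a, b)
+#     vm.invoke_stateful("main")
+#     out = vm.get_outputs("main")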
T.func_attr({"global_symbol": "test_vm_mul"}) + m = T.var("int32") + n = T.var("int32") + A = T.match_buffer(x, (m, n)) + B = T.match_buffer(y, (m, n)) + C = T.match_buffer(z, (m, n)) + + for i, j in T.grid(m, n): + with T.block("mul"): + vi = T.axis.spatial(m, i) + vj = T.axis.spatial(n, j) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = A[vi, vj] * B[vi, vj] + + # test returning a tuple + @R.function + def test_vm_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple(R.Tensor((), "int32"), R.Tensor((), "int32")): + return (x, x) + + # nested tuple too + @R.function + def test_vm_nested_tuple( + x: R.Tensor((), "int32") + ) -> R.Tuple( + R.Tuple( + R.Tensor((), "int32"), + R.Tuple( + R.Tensor((), "int32"), + ), + ), + R.Tensor((), "int32"), + ): + return ((x, (x,)), x) + + @R.function + def main(x: R.Tensor((32, 32), "float32"), w: R.Tensor((32, 32), "float32")) -> R.Tensor: + gv0 = R.call_tir("test_vm_mul", (x, w), R.Tensor((32, 32), dtype="float32")) + return gv0 + + +def set_input_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + vm.invoke_stateful("main") + res0 = vm.get_outputs("main") + + data_dict = {"x": a, "w": b} + vm.set_input("main", **data_dict) + vm.invoke_stateful("main") + res1 = vm.get_outputs("main") + tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7) + tvm.testing.assert_allclose(res0.numpy(), res1.numpy(), rtol=1e-7, atol=1e-7) + + # bug! If you don't bind the NDArray to a var, the memory will get corrupted. + # Possibly due to object lifecycles and other FFI issues + a = tvm.nd.array(np.array(2).astype("int32"), device) + vm.set_input("test_vm_tuple", a) + vm.invoke_stateful("test_vm_tuple") + res2 = vm.get_outputs("test_vm_tuple") + # the results are NDArrays wrapped around scalars, + # so we have to get the scalar out of the NDArray + assert tuple(map(lambda a: int(a.numpy()), res2)) == (2, 2) + + b = tvm.nd.array(np.array(1).astype("int32"), device) + vm.set_input("test_vm_nested_tuple", b) + vm.invoke_stateful("test_vm_nested_tuple") + res3 = vm.get_outputs("test_vm_nested_tuple") + assert len(res3) == 2 and len(res3[0]) == 2 and len(res3[0][1]) == 1 + result_cast = ((int(res3[0][0].numpy()), (int(res3[0][1][0].numpy()),)), int(res3[1].numpy())) + assert result_cast == ((1, (1,)), 1) + + +def set_input_attempt_stateless(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: once you set inputs, you cannot run statelessly + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + # must use invoke stateful! 
+ vm["main"]() + + +def set_input_attempt_invoke(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: if the function needs inputs, you can't invoke directly + vm.invoke_stateful("main") + + +def set_input_attempt_get(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None: + # this should fail: you can't get outputs without invoking the function first + a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device) + vm.set_input("main", a, b) + _ = vm.get_outputs("main") + + +def make_vm(mod, exec_mode) -> Tuple[relax.VirtualMachine, tvm.runtime.Device]: + """Returns a local VM for the given mod and the device""" + target = tvm.target.Target("llvm", host="llvm") + exec = relax.vm.build(TestVMSetInput, target, exec_mode=exec_mode) + exec.mod.export_library("exec.so") + exec_loaded = relax.vm.Executable(tvm.runtime.load_module("exec.so")) + os.remove("exec.so") + device = tvm.cpu() + return relax.VirtualMachine(exec_loaded, device), device + + +def run_on_rpc( + mod: tvm.IRModule, + trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None], + exec_mode: str, +): + """ + Sets up a VM over localhost using the given mod and runs the given trial function. + The trial function should take a VM and a device + """ + target = tvm.target.Target("llvm", host="llvm") + exec = relax.vm.build(mod, target, exec_mode=exec_mode) + temp = utils.tempdir() + path = temp.relpath("vm_library.so") + exec.mod.export_library(path) + + # Use local rpc server for testing. + # Server must use popen so it doesn't inherit the current process state. It + # will crash otherwise. + # Adapted from relay/test_vm.py + def check_remote(server): + remote = rpc.connect(server.host, server.port, session_timeout=10) + + # Upload the serialized Executable. + remote.upload(path) + # Get a handle to remote Executable. + rexec = remote.load_module("vm_library.so") + + device = remote.cpu() + # Build a VM out of the executable and context. 
+def run_on_rpc(
+    mod: tvm.IRModule,
+    trial_func: Callable[[relax.VirtualMachine, tvm.runtime.Device], None],
+    exec_mode: str,
+):
+    """
+    Sets up a VM over localhost using the given mod and runs the given trial function.
+    The trial function should take a VM and a device.
+    """
+    target = tvm.target.Target("llvm", host="llvm")
+    exec = relax.vm.build(mod, target, exec_mode=exec_mode)
+    temp = utils.tempdir()
+    path = temp.relpath("vm_library.so")
+    exec.mod.export_library(path)
+
+    # Use local rpc server for testing.
+    # Server must use popen so it doesn't inherit the current process state. It
+    # will crash otherwise.
+    # Adapted from relay/test_vm.py
+    def check_remote(server):
+        remote = rpc.connect(server.host, server.port, session_timeout=10)
+
+        # Upload the serialized Executable.
+        remote.upload(path)
+        # Get a handle to remote Executable.
+        rexec = remote.load_module("vm_library.so")
+
+        device = remote.cpu()
+        # Build a VM out of the executable and context.
+        vm = relax.vm.VirtualMachine(exec=rexec, device=device)
+        trial_func(vm, device)
+
+    check_remote(rpc.Server("127.0.0.1"))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_set_input(exec_mode):
+    set_input_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+def save_function_kwargs_trial(vm: relax.VirtualMachine, device: tvm.runtime.Device) -> None:
+    # just checking that we can use kwargs for the args when saving a function
+    a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.save_function("main", "saved_main", x=a, w=b)
+    res0 = vm["saved_main"]()
+    tvm.testing.assert_allclose(res0.numpy(), a.numpy() * b.numpy(), rtol=1e-7, atol=1e-7)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs(exec_mode):
+    save_function_kwargs_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_kwargs_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, save_function_kwargs_trial, exec_mode)
+
+
+def save_function_time_evaluator_trial(
+    vm: relax.VirtualMachine, device: tvm.runtime.Device
+) -> None:
+    # just checking that the saved function can be called in the time evaluator
+    a = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    b = tvm.nd.array(np.random.rand(32, 32).astype("float32"), device)
+    vm.save_function("main", "saved_main", a, b)
+    vm.time_evaluator("saved_main", device)()
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator(exec_mode):
+    save_function_time_evaluator_trial(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_save_function_time_evaluator_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, save_function_time_evaluator_trial, exec_mode)
+
+
+# if you set an input, you should not be able to call statelessly
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure(exec_mode):
+    set_input_attempt_stateless(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_stateless_failure_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, set_input_attempt_stateless, exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure(exec_mode):
+    set_input_attempt_invoke(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_invoke_failure_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, set_input_attempt_invoke, exec_mode)
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure(exec_mode):
+    set_input_attempt_get(*make_vm(TestVMSetInput, exec_mode))
+
+
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+@pytest.mark.xfail()
+def test_set_input_get_failure_rpc(exec_mode):
+    run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()