This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 25cf489b04 [TOPI] Layout Rewriting in TE (#11844)
25cf489b04 is described below

commit 25cf489b0410dc8cf4c938e9337a31b9c5ddd3b6
Author: Hongyi Jin <3231950...@qq.com>
AuthorDate: Thu Jun 23 11:43:45 2022 +0800

    [TOPI] Layout Rewriting in TE (#11844)
---
 python/tvm/auto_scheduler/relay_integration.py |  5 ++++
 python/tvm/topi/cuda/conv2d_winograd.py        |  1 +
 python/tvm/topi/nn/batch_matmul.py             | 14 +++++++++-
 python/tvm/topi/nn/conv2d.py                   | 30 ++++++++++++++++++---
 python/tvm/topi/nn/conv3d.py                   |  7 ++++-
 python/tvm/topi/nn/dense.py                    | 36 +++++++++++++++++++++++---
 src/auto_scheduler/compute_dag.cc              | 10 +++++++
 7 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index e9bf1ccfd7..ee166e8679 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -467,6 +467,11 @@ def rewrite_compute_body(compute_tensor, new_layout):
     return outputs[0] if num == 1 else outputs
 
 
+def rewrite_tensor_shape(tensor, shape):
+    """Rewrite the tensor shape"""
+    _ffi_api.RewriteTensorShape(tensor, shape)
+
+
 def is_auto_scheduler_enabled():
     """Return whether the auto-scheduler is enabled.
 
diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py
index d2b373ba87..89a21f5c02 100644
--- a/python/tvm/topi/cuda/conv2d_winograd.py
+++ b/python/tvm/topi/cuda/conv2d_winograd.py
@@ -379,6 +379,7 @@ def conv2d_winograd_nhwc_cuda(
     out_dtype,
     pre_computed=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 26d45feb03..2156fe11ed 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -17,8 +17,10 @@
 """Batch matrix multiplication"""
 # pylint: disable=invalid-name
 import logging
+
 import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
+
 from ..utils import get_const_tuple
 
 logger = logging.getLogger("topi")
@@ -32,6 +34,7 @@ def batch_matmul(
     transpose_a=False,
     transpose_b=True,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
 
@@ -62,6 +65,9 @@ def batch_matmul(
     auto_scheduler_rewritten_layout: Optional[str] = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the tensor
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -78,6 +84,12 @@ def batch_matmul(
             auto_scheduler_rewritten_layout, ["b", "k", "j"]
         )
         auto_scheduler.remove_index_check(tensor_b)
+    elif meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
+        if transpose_b:
+            YB, YJ, YK = get_const_tuple(tensor_b.shape)
+        else:
+            YB, YK, YJ = get_const_tuple(tensor_b.shape)
     else:
         assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim"
         if transpose_b:
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index b7ae9b3e1c..5db752f6d5 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -280,6 +280,7 @@ def conv2d_nhwc(
     dilation,
     out_dtype="float32",
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Convolution operator in NHWC layout.
 
@@ -308,6 +309,9 @@ def conv2d_nhwc(
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -323,6 +327,7 @@ def conv2d_nhwc(
         "NHWC",
         out_dtype,
         auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
         auto_scheduler_should_rewrite_layout=True,
     )
 
@@ -716,6 +721,7 @@ def conv(
     order: str,
     out_dtype: Union[str, None] = None,
     auto_scheduler_rewritten_layout: Optional[str] = None,
+    meta_schedule_original_shape=None,
     auto_scheduler_should_rewrite_layout: bool = False,
 ):
     """Convolution operator in NCHW or NHWC layout.
@@ -755,14 +761,17 @@ def conv(
         Elements are converted to this type before elementwise multiplication
         and summation.
 
+    auto_scheduler_rewritten_layout: str
+        Layout from autoscheduler's layout rewriting.
+
+    meta_schedule_original_shape : Optional[List[PrimExpr]]
+        The original shape of the input tensor.
+
     auto_scheduler_should_rewrite_layout : bool
         Should auto scheduler be allowed to rewrite the layout of the filter
         tensor. Defaults to false. This can cause errors if used with grouped
         convs.
 
-    auto_scheduler_rewritten_layout: str
-        Layout from autoscheduler's layout rewritting.
-
     Returns
     -------
     Output : tvm.te.Tensor
@@ -802,6 +811,8 @@ def conv(
         permutation_to_kernel = [dim + 1, dim] + list(range(dim))
         permutation_from_kernel = np.argsort(permutation_to_kernel)
 
+    if meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(filt, meta_schedule_original_shape)
     batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[permutation_to].tolist()
     num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[
         permutation_to_kernel
@@ -959,6 +970,7 @@ def _conv2d_winograd_nhwc_impl(
     tile_size,
     pre_computed=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd implementation in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -983,6 +995,8 @@ def _conv2d_winograd_nhwc_impl(
         Whether the kernel is precomputed
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
 
     Returns
     -------
@@ -994,6 +1008,8 @@ def _conv2d_winograd_nhwc_impl(
         dilation_h = dilation_w = dilation
     else:
         dilation_h, dilation_w = dilation
+    if meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(weight, meta_schedule_original_shape)
     assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
 
     if not pre_computed:
@@ -1136,6 +1152,7 @@ def conv2d_winograd_nhwc(
     out_dtype,
     pre_computed=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -1158,6 +1175,8 @@ def conv2d_winograd_nhwc(
         Whether the kernel is precomputed
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
 
     Returns
     -------
@@ -1176,6 +1195,7 @@ def conv2d_winograd_nhwc(
         tile_size,
         pre_computed,
         auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
     )
 
 
@@ -1187,6 +1207,7 @@ def conv2d_winograd_nhwc_without_weight_transform(
     dilation,
     out_dtype,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd without layout transform in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -1207,6 +1228,8 @@ def conv2d_winograd_nhwc_without_weight_transform(
         Specifies the output data type.
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
 
     Returns
     -------
@@ -1223,4 +1246,5 @@ def conv2d_winograd_nhwc_without_weight_transform(
         out_dtype,
         pre_computed=True,
         auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape=meta_schedule_original_shape,
     )
diff --git a/python/tvm/topi/nn/conv3d.py b/python/tvm/topi/nn/conv3d.py
index 2915b886a5..591c643a95 100644
--- a/python/tvm/topi/nn/conv3d.py
+++ b/python/tvm/topi/nn/conv3d.py
@@ -21,8 +21,8 @@
 import tvm
 from tvm import te
 
 from ..utils import get_const_tuple
-from .winograd_util import winograd_transform_matrices
 from .conv2d import conv
+from .winograd_util import winograd_transform_matrices
 
 
 def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
@@ -65,6 +65,7 @@ def conv3d_ndhwc(
     groups,
     out_dtype="float32",
     auto_scheduler_rewritten_layout="",
+    meta_schedule_origin_shape=None,
 ):
     """Convolution operator in NDHWC layout.
 
@@ -94,6 +95,9 @@ def conv3d_ndhwc(
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_origin_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     Output : tvm.te.Tensor
@@ -109,6 +113,7 @@ def conv3d_ndhwc(
         "NDHWC",
         out_dtype,
         auto_scheduler_rewritten_layout,
+        meta_schedule_origin_shape,
     )
 
 
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 69fac92c7c..61f9c4e17c 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -17,7 +17,8 @@
 # pylint: disable=invalid-name,unused-argument
 """TVM operator fully connected compute."""
 import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
+
 from .. import tag
 
 
@@ -29,6 +30,7 @@ def matmul(
     transpose_a=False,
     transpose_b=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """The default implementation of matmul in topi.
 
@@ -55,6 +57,9 @@ def matmul(
     auto_scheduler_rewritten_layout: Optional[str] = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -77,6 +82,12 @@ def matmul(
             auto_scheduler_rewritten_layout, ["j", "k"]
         )
         auto_scheduler.remove_index_check(tensor_b)
+    elif meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
+        if transpose_b:
+            out_dim, red_dim = tensor_b.shape
+        else:
+            red_dim, out_dim = tensor_b.shape
     elif transpose_b:
         out_dim, red_dim = tensor_b.shape
     else:
@@ -156,7 +167,14 @@ def matmul_legalize(attrs, inputs, types):
     return None
 
 
-def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
+def dense(
+    data,
+    weight,
+    bias=None,
+    out_dtype=None,
+    auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
+):
     """The default implementation of dense in topi.
     This is an alias of matmul_nt operator for data tensor in non-transposed format
     and weight tensor in transposed format.
@@ -178,12 +196,24 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    return matmul(data, weight, bias, out_dtype, False, True, auto_scheduler_rewritten_layout)
+    return matmul(
+        data,
+        weight,
+        bias,
+        out_dtype,
+        False,
+        True,
+        auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
+    )
 
 
 @tvm.target.generic_func
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index e82830fa4d..dad55db030 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -1517,6 +1517,16 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
       return index_rewriter.Rewrite(body);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.RewriteTensorShape")
+    .set_body_typed([](te::Tensor tensor, Array<PrimExpr> new_shape) -> void {
+      ICHECK(tensor->op->IsInstance<te::PlaceholderOpNode>());
+      te::PlaceholderOpNode* op =
+          const_cast<te::PlaceholderOpNode*>(tensor->op.as<te::PlaceholderOpNode>());
+      te::TensorNode* t = const_cast<te::TensorNode*>(tensor.get());
+      op->shape = new_shape;
+      t->shape = new_shape;
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
     .set_body_typed(GetShapeFromRewrittenLayout);
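
The core of this patch is the new `rewrite_tensor_shape` helper (backed by the
`auto_scheduler.RewriteTensorShape` global above), which mutates a placeholder's
shape in place on both the PlaceholderOpNode and the TensorNode, so the TOPI
workloads can restore a tensor's pre-rewrite shape when
`meta_schedule_original_shape` is passed. A minimal usage sketch, assuming a TVM
build that includes this patch and that the helper is reachable as
`auto_scheduler.rewrite_tensor_shape` (as the TOPI code above assumes); the
tensor names and shapes below are illustrative, not taken from the commit:

    import tvm
    from tvm import auto_scheduler, te, topi

    # A placeholder whose current shape reflects a rewritten layout,
    # e.g. the out_dim of a [32, 64] weight split into 4 x 8.
    weight = te.placeholder((4, 8, 64), name="weight")

    # Restore the shape recorded before the layout rewrite. The registered
    # global mutates the PlaceholderOpNode and the TensorNode in place, so
    # `weight` itself now reports the original shape.
    auto_scheduler.rewrite_tensor_shape(weight, [32, 64])
    print(weight.shape)  # -> [32, 64]

    # TOPI workloads accept the original shape directly and perform the same
    # in-place rewrite before inferring output dimensions:
    data = te.placeholder((16, 64), name="data")
    weight2 = te.placeholder((4, 8, 64), name="weight2")
    out = topi.nn.dense(data, weight2, meta_schedule_original_shape=[32, 64])
    print(out.shape)  # -> [16, 32]

Because the rewrite mutates shared nodes (note the const_casts in
compute_dag.cc) rather than returning a new tensor, the change is visible to
every existing reference to the placeholder.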