This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 25cf489b04 [TOPI] Layout Rewriting in TE (#11844)
25cf489b04 is described below

commit 25cf489b0410dc8cf4c938e9337a31b9c5ddd3b6
Author: Hongyi Jin <3231950...@qq.com>
AuthorDate: Thu Jun 23 11:43:45 2022 +0800

    [TOPI] Layout Rewriting in TE (#11844)
---
 python/tvm/auto_scheduler/relay_integration.py |  5 ++++
 python/tvm/topi/cuda/conv2d_winograd.py        |  1 +
 python/tvm/topi/nn/batch_matmul.py             | 14 +++++++++-
 python/tvm/topi/nn/conv2d.py                   | 30 ++++++++++++++++++---
 python/tvm/topi/nn/conv3d.py                   |  7 ++++-
 python/tvm/topi/nn/dense.py                    | 36 +++++++++++++++++++++++---
 src/auto_scheduler/compute_dag.cc              | 10 +++++++
 7 files changed, 95 insertions(+), 8 deletions(-)

diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index e9bf1ccfd7..ee166e8679 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -467,6 +467,11 @@ def rewrite_compute_body(compute_tensor, new_layout):
     return outputs[0] if num == 1 else outputs
 
 
+def rewrite_tensor_shape(tensor, shape):
+    """Rewrite the tensor shape"""
+    _ffi_api.RewriteTensorShape(tensor, shape)
+
+
 def is_auto_scheduler_enabled():
     """Return whether the auto-scheduler is enabled.
 
diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py
index d2b373ba87..89a21f5c02 100644
--- a/python/tvm/topi/cuda/conv2d_winograd.py
+++ b/python/tvm/topi/cuda/conv2d_winograd.py
@@ -379,6 +379,7 @@ def conv2d_winograd_nhwc_cuda(
     out_dtype,
     pre_computed=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
diff --git a/python/tvm/topi/nn/batch_matmul.py b/python/tvm/topi/nn/batch_matmul.py
index 26d45feb03..2156fe11ed 100644
--- a/python/tvm/topi/nn/batch_matmul.py
+++ b/python/tvm/topi/nn/batch_matmul.py
@@ -17,8 +17,10 @@
 """Batch matrix multiplication"""
 # pylint: disable=invalid-name
 import logging
+
 import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
+
 from ..utils import get_const_tuple
 
 logger = logging.getLogger("topi")
@@ -32,6 +34,7 @@ def batch_matmul(
     transpose_a=False,
     transpose_b=True,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Compute batch matrix multiplication of `tensor_a` and `tensor_b`.
 
@@ -62,6 +65,9 @@ def batch_matmul(
     auto_scheduler_rewritten_layout: Optional[str] = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the tensor
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -78,6 +84,12 @@ def batch_matmul(
             auto_scheduler_rewritten_layout, ["b", "k", "j"]
         )
         auto_scheduler.remove_index_check(tensor_b)
+    elif meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
+        if transpose_b:
+            YB, YJ, YK = get_const_tuple(tensor_b.shape)
+        else:
+            YB, YK, YJ = get_const_tuple(tensor_b.shape)
     else:
         assert len(tensor_b.shape) == 3, "tensor_b only support 3-dim"
         if transpose_b:
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index b7ae9b3e1c..5db752f6d5 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -280,6 +280,7 @@ def conv2d_nhwc(
     dilation,
     out_dtype="float32",
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Convolution operator in NHWC layout.
 
@@ -308,6 +309,9 @@ def conv2d_nhwc(
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -323,6 +327,7 @@ def conv2d_nhwc(
         "NHWC",
         out_dtype,
         auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
         auto_scheduler_should_rewrite_layout=True,
     )
 
@@ -716,6 +721,7 @@ def conv(
     order: str,
     out_dtype: Union[str, None] = None,
     auto_scheduler_rewritten_layout: Optional[str] = None,
+    meta_schedule_original_shape=None,
     auto_scheduler_should_rewrite_layout: bool = False,
 ):
     """Convolution operator in NCHW or NHWC layout.
@@ -755,14 +761,17 @@ def conv(
         Elements are converted to this type before elementwise multiplication
         and summation.
 
+    auto_scheduler_rewritten_layout: str
+        Layout from autoscheduler's layout rewriting.
+
+    meta_schedule_original_shape : Optional[List[PrimExpr]]
+        The original shape of the input tensor.
+
     auto_scheduler_should_rewrite_layout : bool
         Should auto scheduler be allowed to rewrite the layout of the filter
         tensor. Defaults to false. This can cause errors if used with grouped
         convs.
 
-    auto_scheduler_rewritten_layout: str
-        Layout from autoscheduler's layout rewritting.
-
     Returns
     -------
     Output : tvm.te.Tensor
@@ -802,6 +811,8 @@ def conv(
         permutation_to_kernel = [dim + 1, dim] + list(range(dim))
     permutation_from_kernel = np.argsort(permutation_to_kernel)
 
+    if meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(filt, meta_schedule_original_shape)
     batch, in_channel, *dimensions = np.array(get_const_tuple(inp.shape))[permutation_to].tolist()
     num_filter, _, *kernel_dimensions = np.array(get_const_tuple(filt.shape))[
         permutation_to_kernel
@@ -959,6 +970,7 @@ def _conv2d_winograd_nhwc_impl(
     tile_size,
     pre_computed=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd implementation in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -983,6 +995,8 @@ def _conv2d_winograd_nhwc_impl(
         Whether the kernel is precomputed
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
 
     Returns
     -------
@@ -994,6 +1008,8 @@ def _conv2d_winograd_nhwc_impl(
         dilation_h = dilation_w = dilation
     else:
         dilation_h, dilation_w = dilation
+    if meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(weight, meta_schedule_original_shape)
 
     assert (dilation_h, dilation_w) == (1, 1), "Does not support dilation"
     if not pre_computed:
@@ -1136,6 +1152,7 @@ def conv2d_winograd_nhwc(
     out_dtype,
     pre_computed=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -1158,6 +1175,8 @@ def conv2d_winograd_nhwc(
         Whether the kernel is precomputed
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
 
     Returns
     -------
@@ -1176,6 +1195,7 @@ def conv2d_winograd_nhwc(
         tile_size,
         pre_computed,
         auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
     )
 
 
@@ -1187,6 +1207,7 @@ def conv2d_winograd_nhwc_without_weight_transform(
     dilation,
     out_dtype,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """Conv2D Winograd without layout transform in NHWC layout.
     This is a clean version to be used by the auto-scheduler for both CPU and GPU.
@@ -1207,6 +1228,8 @@ def conv2d_winograd_nhwc_without_weight_transform(
         Specifies the output data type.
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
 
     Returns
     -------
@@ -1223,4 +1246,5 @@ def conv2d_winograd_nhwc_without_weight_transform(
         out_dtype,
         pre_computed=True,
         auto_scheduler_rewritten_layout=auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape=meta_schedule_original_shape,
     )
diff --git a/python/tvm/topi/nn/conv3d.py b/python/tvm/topi/nn/conv3d.py
index 2915b886a5..591c643a95 100644
--- a/python/tvm/topi/nn/conv3d.py
+++ b/python/tvm/topi/nn/conv3d.py
@@ -21,8 +21,8 @@ import tvm
 from tvm import te
 
 from ..utils import get_const_tuple
-from .winograd_util import winograd_transform_matrices
 from .conv2d import conv
+from .winograd_util import winograd_transform_matrices
 
 
 def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=None):
@@ -65,6 +65,7 @@ def conv3d_ndhwc(
     groups,
     out_dtype="float32",
     auto_scheduler_rewritten_layout="",
+    meta_schedule_origin_shape=None,
 ):
     """Convolution operator in NDHWC layout.
 
@@ -94,6 +95,9 @@ def conv3d_ndhwc(
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_origin_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     Output : tvm.te.Tensor
@@ -109,6 +113,7 @@ def conv3d_ndhwc(
         "NDHWC",
         out_dtype,
         auto_scheduler_rewritten_layout,
+        meta_schedule_origin_shape,
     )
 
 
diff --git a/python/tvm/topi/nn/dense.py b/python/tvm/topi/nn/dense.py
index 69fac92c7c..61f9c4e17c 100644
--- a/python/tvm/topi/nn/dense.py
+++ b/python/tvm/topi/nn/dense.py
@@ -17,7 +17,8 @@
 # pylint: disable=invalid-name,unused-argument
 """TVM operator fully connected compute."""
 import tvm
-from tvm import te, auto_scheduler
+from tvm import auto_scheduler, te
+
 from .. import tag
 
 
@@ -29,6 +30,7 @@ def matmul(
     transpose_a=False,
     transpose_b=False,
     auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
 ):
     """The default implementation of matmul in topi.
 
@@ -55,6 +57,9 @@ def matmul(
     auto_scheduler_rewritten_layout: Optional[str] = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     output : tvm.te.Tensor
@@ -77,6 +82,12 @@ def matmul(
             auto_scheduler_rewritten_layout, ["j", "k"]
         )
         auto_scheduler.remove_index_check(tensor_b)
+    elif meta_schedule_original_shape:
+        auto_scheduler.rewrite_tensor_shape(tensor_b, meta_schedule_original_shape)
+        if transpose_b:
+            out_dim, red_dim = tensor_b.shape
+        else:
+            red_dim, out_dim = tensor_b.shape
     elif transpose_b:
         out_dim, red_dim = tensor_b.shape
     else:
@@ -156,7 +167,14 @@ def matmul_legalize(attrs, inputs, types):
     return None
 
 
-def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layout=""):
+def dense(
+    data,
+    weight,
+    bias=None,
+    out_dtype=None,
+    auto_scheduler_rewritten_layout="",
+    meta_schedule_original_shape=None,
+):
     """The default implementation of dense in topi.
     This is an alias of matmul_nt operator for data tensor in non-transposed format and weight
     tensor in transposed format.
@@ -178,12 +196,24 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
     auto_scheduler_rewritten_layout: str = ""
         The layout after auto-scheduler's layout rewrite pass.
 
+    meta_schedule_original_shape: Optional[List[PrimExpr]] = None
+        The original shape of the input tensor.
+
     Returns
     -------
     output : tvm.te.Tensor
         2-D with shape [batch, out_dim]
     """
-    return matmul(data, weight, bias, out_dtype, False, True, auto_scheduler_rewritten_layout)
+    return matmul(
+        data,
+        weight,
+        bias,
+        out_dtype,
+        False,
+        True,
+        auto_scheduler_rewritten_layout,
+        meta_schedule_original_shape,
+    )
 
 
 @tvm.target.generic_func
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
index e82830fa4d..dad55db030 100644
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -1517,6 +1517,16 @@ TVM_REGISTER_GLOBAL("auto_scheduler.RewriteIndexForNewLayout")
       return index_rewriter.Rewrite(body);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.RewriteTensorShape")
+    .set_body_typed([](te::Tensor tensor, Array<PrimExpr> new_shape) -> void {
+      ICHECK(tensor->op->IsInstance<te::PlaceholderOpNode>());
+      te::PlaceholderOpNode* op =
+          const_cast<te::PlaceholderOpNode*>(tensor->op.as<te::PlaceholderOpNode>());
+      te::TensorNode* t = const_cast<te::TensorNode*>(tensor.get());
+      op->shape = new_shape;
+      t->shape = new_shape;
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.GetShapeFromRewrittenLayout")
     .set_body_typed(GetShapeFromRewrittenLayout);
 
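Usage sketch (not part of the commit; tensor names and shapes below are hypothetical):
the new meta_schedule_original_shape argument carries the pre-rewrite shape of a
placeholder whose layout has already been rewritten, and the TOPI compute calls the
new auto_scheduler.rewrite_tensor_shape helper to restore that shape before building
the expression. For example:

    # Hypothetical sketch: tensor names and shapes are illustrative only.
    from tvm import te, topi

    data = te.placeholder((8, 128), name="data", dtype="float32")
    # Weight placeholder in a hypothetical packed layout produced by a
    # layout-rewrite pass; its logical layout is (out_dim, in_dim) = (64, 128).
    weight = te.placeholder((4, 128, 16), name="weight", dtype="float32")

    # dense() forwards meta_schedule_original_shape to matmul(), which calls
    # auto_scheduler.rewrite_tensor_shape(weight, (64, 128)) so the compute
    # is built against the original 2-D shape; the output shape is (8, 64).
    out = topi.nn.dense(
        data,
        weight,
        out_dtype="float32",
        meta_schedule_original_shape=(64, 128),
    )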
