This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e07a8ed66 [TOPI] remove the i32 cast for output shape of pool (#14549)
4e07a8ed66 is described below

commit 4e07a8ed6687a08b6b27db21af019a5a179b9ee1
Author: Yong Wu <[email protected]>
AuthorDate: Sun Apr 9 18:59:16 2023 -0700

    [TOPI] remove the i32 cast for output shape of pool (#14549)
    
    [TOPI] remove the cast for output shape of pool
---
 include/tvm/topi/nn/dilate.h              |  3 +-
 include/tvm/topi/nn/pooling.h             | 58 +++++++++++++------------------
 include/tvm/topi/transform.h              |  6 +---
 tests/python/relay/test_op_grad_level2.py |  2 +-
 tests/python/relay/test_op_level2.py      |  7 ++--
 5 files changed, 30 insertions(+), 46 deletions(-)

diff --git a/include/tvm/topi/nn/dilate.h b/include/tvm/topi/nn/dilate.h
index 3369316e4d..74c46e2694 100644
--- a/include/tvm/topi/nn/dilate.h
+++ b/include/tvm/topi/nn/dilate.h
@@ -76,8 +76,7 @@ inline Tensor dilate(const Tensor& x, Array<PrimExpr> strides, double dilation_v
   Array<PrimExpr> out_shape;
   arith::Analyzer analyzer;
   for (size_t i = 0; i < n; ++i) {
-    out_shape.push_back(
-        analyzer.Simplify((x->shape[i] - 1) * cast(DataType::Int(32), strides[i] + 1)));
+    out_shape.push_back(analyzer.Simplify((x->shape[i] - 1) * (strides[i] + 1)));
   }
 
   return tvm::te::compute(
diff --git a/include/tvm/topi/nn/pooling.h b/include/tvm/topi/nn/pooling.h
index 5f365e5192..ac048f585c 100644
--- a/include/tvm/topi/nn/pooling.h
+++ b/include/tvm/topi/nn/pooling.h
@@ -57,18 +57,18 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
 
-  auto kernel_height = cast(DataType::DataType::Int(32), kernel_size[0]);
-  auto kernel_width = cast(DataType::DataType::Int(32), kernel_size[1]);
-  auto stride_height = cast(DataType::DataType::Int(32), stride_size[0]);
-  auto stride_width = cast(DataType::DataType::Int(32), stride_size[1]);
+  auto kernel_height = kernel_size[0];
+  auto kernel_width = kernel_size[1];
+  auto stride_height = stride_size[0];
+  auto stride_width = stride_size[1];
 
-  auto height = cast(DataType::DataType::Int(32), x->shape[height_axis]);
-  auto width = cast(DataType::DataType::Int(32), x->shape[width_axis]);
+  auto height = x->shape[height_axis];
+  auto width = x->shape[width_axis];
 
-  auto pad_top = cast(DataType::DataType::Int(32), padding_size[0]);
-  auto pad_left = cast(DataType::DataType::Int(32), padding_size[1]);
-  auto pad_bottom = cast(DataType::DataType::Int(32), padding_size[2]);
-  auto pad_right = cast(DataType::DataType::Int(32), padding_size[3]);
+  auto pad_top = padding_size[0];
+  auto pad_left = padding_size[1];
+  auto pad_bottom = padding_size[2];
+  auto pad_right = padding_size[3];
 
   if (ceil_mode) {
     // Additional padding to ensure we do ceil instead of floor when
@@ -94,10 +94,6 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
   auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
 
   Array<PrimExpr> data_shape = x->shape;
-  for (size_t i = 0; i < data_shape.size(); ++i) {
-    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
-  }
-
   Array<PrimExpr> out_shape = data_shape;
   out_shape.Set(height_axis, out_height);
   out_shape.Set(width_axis, out_width);
@@ -148,10 +144,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
 
           PrimExpr out_idx_lower_h = tir::Select(
-              pad_inds[height_axis] < kernel_height, make_const(DataType::DataType::Int(32), 0),
+              pad_inds[height_axis] < kernel_height, make_const(pad_inds[height_axis].dtype(), 0),
              (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w = tir::Select(
-              pad_inds[width_axis] < kernel_width, make_const(DataType::DataType::Int(32), 0),
+              pad_inds[width_axis] < kernel_width, make_const(pad_inds[width_axis].dtype(), 0),
               (pad_inds[width_axis] - kernel_width) / stride_width + 1);
 
           return tvm::sum(
@@ -179,10 +175,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
           out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
 
           PrimExpr out_idx_lower_h =
-              tir::Select(pad_h_idx < kernel_height, make_const(DataType::Int(32), 0),
+              tir::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
-              tir::Select(pad_w_idx < kernel_width, make_const(DataType::Int(32), 0),
+              tir::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0),
                           (pad_w_idx - kernel_width) / stride_width + 1);
 
           PrimExpr divide_factor;  // number of pooled elements
@@ -194,10 +190,10 @@ inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
 
             PrimExpr h_end = min(h_start + kernel_height, height);
             PrimExpr w_end = min(w_start + kernel_width, width);
-            h_start = max(h_start, make_const(DataType::Int(32), 0));
-            w_start = max(w_start, make_const(DataType::Int(32), 0));
+            h_start = max(h_start, make_const(h_start.dtype(), 0));
+            w_start = max(w_start, make_const(w_start.dtype(), 0));
             divide_factor =
-                max((h_end - h_start) * (w_end - w_start), make_const(DataType::Int(32), 1));
+                max((h_end - h_start) * (w_end - w_start), make_const(h_end.dtype(), 1));
           }
           return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
@@ -329,14 +325,11 @@ inline Tensor adaptive_pool_impl(const Tensor& x, const Array<PrimExpr>& output_
  ICHECK_EQ(axes.size(), n_dim) << "The number of axes not equal to the in/out dimension";
 
   Array<PrimExpr> data_shape = x->shape;
-  for (size_t i = 0; i < data_shape.size(); ++i) {
-    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
-  }
   Array<PrimExpr> out_shape = data_shape;
   Array<PrimExpr> in_size, out_size;
   for (size_t i = 0; i < n_dim; ++i) {
     in_size.push_back(data_shape[axes[i]]);
-    out_size.push_back(cast(DataType::Int(32), output_size[i]));
+    out_size.push_back(output_size[i]);
     out_shape.Set(axes[i], out_size[i]);
   }
 
@@ -532,19 +525,16 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
   Array<PrimExpr> pad_before(std::vector<PrimExpr>(x_size, 0));
   Array<PrimExpr> pad_after(std::vector<PrimExpr>(x_size, 0));
   Array<PrimExpr> data_shape = x->shape;
-  for (size_t i = 0; i < data_shape.size(); ++i) {
-    data_shape.Set(i, cast(DataType::DataType::Int(32), data_shape[i]));
-  }
   Array<PrimExpr> out_shape = data_shape;
 
   bool do_pad = false;
   for (int i = 0; i < k_size; i++) {
     int ii = axis[i];
-    kernel[i] = cast(DataType::Int(32), kernel_size[i]);
-    stride[i] = cast(DataType::Int(32), stride_size[i]);
-    dilation[i] = cast(DataType::Int(32), dilation_size[i]);
-    pad_head[i] = cast(DataType::Int(32), padding_size[i]);
-    pad_tail[i] = cast(DataType::Int(32), padding_size[i + k_size]);
+    kernel[i] = kernel_size[i];
+    stride[i] = stride_size[i];
+    dilation[i] = dilation_size[i];
+    pad_head[i] = padding_size[i];
+    pad_tail[i] = padding_size[i + k_size];
 
     if (ceil_mode) {
      // The offset[i] is an additional padding to ensure we do ceil instead of floor when
@@ -650,7 +640,7 @@ inline Tensor pool_impl_nd(const Tensor& x, const Array<PrimExpr>& kernel_size,
              // number that represents the number of steps along the dilated kernel to reach a
               // non-padded value. Otherwise this should be 0.
              PrimExpr jumps_to_non_pad = (dilation[i] - 1 - start[i]) / dilation[i];
-              jumps_to_non_pad = max(jumps_to_non_pad, make_const(DataType::Int(32), 0));
+              jumps_to_non_pad = max(jumps_to_non_pad, make_const(jumps_to_non_pad.dtype(), 0));
 
               end[i] = min(end[i], data_shape[ii] - 1);
              num_el *= (end[i] - (start[i] + dilation[i] * jumps_to_non_pad)) / dilation[i] + 1;
diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h
index aa75fb05a0..7fe56d9532 100644
--- a/include/tvm/topi/transform.h
+++ b/include/tvm/topi/transform.h
@@ -323,11 +323,7 @@ inline Tensor reshape(const Tensor& x, Array<PrimExpr> newshape, std::string nam
   Array<PrimExpr> target_shape;
 
   for (const auto& ele : newshape) {
-    if (ele.as<IntImmNode>()) {
-      target_shape.push_back(cast(DataType::Int(32), ele));
-    } else {
-      target_shape.push_back(ele);
-    }
+    target_shape.push_back(ele);
   }
 
  // If either the input shape or the target shape contains a zero, return an empty tensor.
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index bbd851dc9c..7a40a58ee8 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -154,7 +154,7 @@ def test_avg_pool2d_grad(executor_kind):
         ceil_mode=False,
         count_include_pad=False,
         executor_kind=executor_kind,
-        dtype="int32",
+        dtype="float16",
     )
 
 
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index f7cfc81fb2..b2a8c2cdb0 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -1108,7 +1108,6 @@ def test_pool2d():
         yy = run_infer_type(y)
         assert yy.checked_type == relay.TensorType((n, 10, 224, 224), dtype)
         # test execution
-        dtype = "int32"
         dshape = (1, 3, 28, 28)
         for shape_dtype in ["int32", "int64"]:
             x = relay.var("x", shape=[tvm.tir.IntImm(shape_dtype, x) for x in 
dshape], dtype=dtype)
@@ -1129,8 +1128,8 @@ def test_pool2d():
     _test_pool2d(relay.nn.avg_pool2d, "avg", pool_size=2, strides=2, padding=0)
     _test_pool2d(relay.nn.avg_pool2d, "avg", pool_size=2, strides=2, 
padding=0, dilation=2)
 
-    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "int32")
-    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "uint16")
+    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "int64")
+    _test_pool2d_int(relay.nn.avg_pool2d, np.mean, "float16")
     _test_global_pool2d(relay.nn.global_max_pool2d, np.max)
     _test_global_pool2d(relay.nn.global_avg_pool2d, np.mean)
 
@@ -1201,7 +1200,7 @@ def test_pool1d():
     _test_pool1d(relay.nn.max_pool1d, "max", pool_size=2, strides=2, padding=0)
     _test_pool1d(relay.nn.max_pool1d, "max", pool_size=2, strides=2, 
padding=0, dilation=2)
     _test_pool1d(relay.nn.avg_pool1d, "avg")
-    _test_pool1d(relay.nn.avg_pool1d, "avg", dtype="int32")
+    _test_pool1d(relay.nn.avg_pool1d, "avg", dtype="int64")
     _test_pool1d(relay.nn.avg_pool1d, "avg", pool_size=2, strides=2, padding=0)
     _test_pool1d(relay.nn.avg_pool1d, "avg", pool_size=2, strides=2, 
padding=0, dilation=2)
     _test_global_pool1d(relay.nn.global_max_pool1d, np.max)
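
For context, a minimal sketch (not part of the commit) of the user-visible
effect: with the Int(32) casts removed, pooling an input whose shape is
expressed as int64 extents keeps int64 extents in the inferred output type
instead of narrowing them to int32. This assumes a TVM build with Relay and
mirrors the shape_dtype loop in the updated test_pool2d above.

    import tvm
    from tvm import relay

    # Shape dims as explicit int64 IntImms, as the updated test does for
    # shape_dtype in ["int32", "int64"].
    dshape = (1, 3, 28, 28)
    x = relay.var(
        "x", shape=[tvm.tir.IntImm("int64", d) for d in dshape], dtype="float32"
    )
    y = relay.nn.max_pool2d(x, pool_size=(2, 2), strides=(2, 2))
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(y))
    # The inferred output extents stay int64 rather than being cast to int32.
    print(mod["main"].ret_type)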
