This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a393b47563 [Relax] Add FInferMixedPrecision and FRelaxInferLayout for conv transpose ops (#18629)
a393b47563 is described below

commit a393b4756368d26db927085cb1de028b567e78c0
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Jan 3 13:26:20 2026 +0800

    [Relax] Add FInferMixedPrecision and FRelaxInferLayout for conv transpose ops (#18629)
    
    ## Why
    
    The `conv1d_transpose` and `conv2d_transpose` operators were missing
    FInferMixedPrecision and FRelaxInferLayout attribute implementations,
    which are needed for:
    
    - Mixed precision training/inference support (e.g., float16 inputs with
    float32 outputs; see the sketch after this list)
    - Layout transformation optimizations during compilation
    - Consistency with the conv1d and conv2d operators, which already have
    these attributes
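
    For illustration only, a minimal sketch (not part of this commit) of the
    mixed-precision rewrite these attributes enable, assuming the existing
    `relax.transform.ToMixedPrecision` pass and a hypothetical module:

    ```python
    import tvm
    from tvm import relax
    from tvm.script import relax as R


    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((3, 4, 3, 3), "float32"),
        ) -> R.Tensor((2, 4, 30, 30), "float32"):
            with R.dataflow():
                # Transposed convolution; shapes mirror the unit tests below.
                gv = R.nn.conv2d_transpose(x, w)
                R.output(gv)
            return gv


    # With FInferMixedPrecision registered (policy kAlways), the pass can cast
    # the conv inputs to float16 while keeping a float32 output dtype.
    mod = relax.transform.ToMixedPrecision(out_dtype="float32")(Module)
    ```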
    
    ## How
    
    - Implemented InferLayoutConv1dTranspose and InferMixedPrecisionConv1dTranspose
    - Implemented InferLayoutConv2dTranspose and InferMixedPrecisionConv2dTranspose
    (a layout-conversion sketch follows this list)
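
    And a similarly hedged sketch of the layout conversion that
    FRelaxInferLayout enables, assuming the existing
    `relax.transform.ConvertLayout` pass and the `Module` above (the
    NHWC/OHWI layouts are only an illustrative choice):

    ```python
    from tvm import relax

    # Ask for NHWC activations and an OHWI kernel; the registered
    # FRelaxInferLayout lets the pass rewrite the conv attrs and insert the
    # required relax.layout_transform calls around the op.
    mod = relax.transform.ConvertLayout(
        {"relax.nn.conv2d_transpose": ["NHWC", "OHWI"]}
    )(Module)
    ```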
---
 src/relax/op/nn/convolution.cc               | 137 ++++++++++++++++++++++++++-
 tests/python/relax/test_op_nn_convolution.py |  38 ++++++++
 2 files changed, 171 insertions(+), 4 deletions(-)

diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 49e92719ba..ca09c0f1cb 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -707,14 +707,62 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose
-// and unit test for mixed_precision
+InferLayoutOutput InferLayoutConv1dTranspose(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>();
+  LayoutDecision data_layout, weight_layout, output_layout;
+  ObjectPtr<Conv1DTransposeAttrs> new_attrs = ffi::make_object<Conv1DTransposeAttrs>(*attrs);
+
+  auto it = desired_layouts.find("relax.nn.conv1d_transpose");
+  if (it != desired_layouts.end()) {
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+    ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
+        << "Axis swap only";
+    ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
+        << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout);
+    weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout);
+    output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout);
+    new_attrs->data_layout = (*it).second[0];
+    new_attrs->kernel_layout = (*it).second[1];
+    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+  } else {
+    data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+    weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+    output_layout = data_layout;
+    new_attrs->data_layout =
+        TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name();
+    new_attrs->kernel_layout =
+        TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name();
+    new_attrs->out_layout =
+        TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) {
+  const auto* conv1d_transpose_attrs = call->attrs.as<Conv1DTransposeAttrs>();
+  return Downcast<Call>(
+      conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides,
+                       conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding,
+                       conv1d_transpose_attrs->dilation, conv1d_transpose_attrs->groups,
+                       conv1d_transpose_attrs->data_layout, conv1d_transpose_attrs->kernel_layout,
+                       conv1d_transpose_attrs->out_layout, out_dtype));
+}
+
 TVM_REGISTER_OP("relax.nn.conv1d_transpose")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_attrs_type<Conv1DTransposeAttrs>()
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv1dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv1dTranspose)
     .set_attr<Bool>("FPurity", Bool(true));
 
 /* relax.nn.conv2d_transpose */
@@ -857,14 +905,95 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
   return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
 }
 
-// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv2d_transpose
-// and unit test for mixed_precision
+InferLayoutOutput InferLayoutConv2dTranspose(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision output_layout;
+  ObjectPtr<Conv2DTransposeAttrs> new_attrs = ffi::make_object<Conv2DTransposeAttrs>(*attrs);
+
+  auto it = desired_layouts.find("relax.nn.conv2d_transpose");
+  if (it != desired_layouts.end()) {
+    Layout desired_data_layout = (*it).second[0];
+    Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+
+    Layout input_layout = Layout(attrs->data_layout);
+    Layout kernel_layout = Layout(attrs->kernel_layout);
+    Layout out_layout = Layout(attrs->out_layout);
+
+    if (desired_data_layout.ndim_primal() == input_layout.ndim() &&
+        desired_weight_layout.ndim_primal() == kernel_layout.ndim() &&
+        desired_output_layout.ndim_primal() == out_layout.ndim()) {
+      data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
+      weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
+      output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
+      new_attrs->data_layout = (*it).second[0];
+      new_attrs->kernel_layout = (*it).second[1];
+      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+    } else {
+      auto data_si = GetStructInfo(call->args[0]);
+      auto kernel_si = GetStructInfo(call->args[1]);
+      TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+      TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+      ffi::Optional<ShapeExpr> data_shape =
+          ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+      ffi::Optional<ShapeExpr> kernel_shape =
+          ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+      bool can_data_proved =
+          CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
+      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
+                                                       kernel_shape.value()->values);
+
+      if (can_data_proved && can_kernel_proved) {
+        data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
+        weight_layout =
+            TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
+        output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
+        new_attrs->data_layout = (*it).second[0];
+        new_attrs->kernel_layout = (*it).second[1];
+        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+        return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+      } else {
+        data_layout = LayoutDecision(InitialLayout(4));
+        weight_layout = LayoutDecision(InitialLayout(4));
+      }
+    }
+  }
+
+  output_layout = data_layout;
+  new_attrs->data_layout =
+      TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
+  new_attrs->kernel_layout =
+      TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
+  new_attrs->out_layout =
+      TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) {
+  const auto* conv2d_transpose_attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  return Downcast<Call>(
+      conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides,
+                       conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding,
+                       conv2d_transpose_attrs->dilation, conv2d_transpose_attrs->groups,
+                       conv2d_transpose_attrs->data_layout, conv2d_transpose_attrs->kernel_layout,
+                       conv2d_transpose_attrs->out_layout, out_dtype));
+}
+
 TVM_REGISTER_OP("relax.nn.conv2d_transpose")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("weight", "Tensor", "The weight tensor.")
     .set_attrs_type<Conv2DTransposeAttrs>()
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2dTranspose)
     .set_attr<Bool>("FPurity", Bool(true));
 
 }  // namespace relax
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index 588dc9b1b1..9b913138df 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -782,6 +782,25 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.conv1d_transpose(x1, w0))
 
 
+def test_conv1d_transpose_infer_struct_info_mixed_precision():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
+    w1 = relax.Var("w", R.Tensor((3, 4, 3), "int8"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float32"),
+        relax.TensorStructInfo((2, 4, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv1d_transpose(x1, w1, out_dtype="int32"),
+        relax.TensorStructInfo((2, 4, 30), "int32"),
+    )
+
+
 def test_conv2d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
@@ -1571,6 +1590,25 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
 
 
+def test_conv2d_transpose_infer_struct_info_mixed_precision():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+    w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+    w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float32"),
+        relax.TensorStructInfo((2, 4, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d_transpose(x1, w1, out_dtype="int32"),
+        relax.TensorStructInfo((2, 4, 30, 30), "int32"),
+    )
+
+
 def test_conv3d_infer_struct_info():
     bb = relax.BlockBuilder()
     vdev0 = VDevice("llvm")
