This is an automated email from the ASF dual-hosted git repository.

sanirudh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ab7c1a91d8 [Relax] Support `input_axis_separator` to allow 2D to 1D conversion (#17115)
ab7c1a91d8 is described below

commit ab7c1a91d81ae91ad806c2f97c11f6b104ab2ec5
Author: Abhikrant Sharma <quic_abhik...@quicinc.com>
AuthorDate: Mon Jul 1 12:31:07 2024 +0530

    [Relax] Support `input_axis_separator` to allow 2D to 1D conversion (#17115)
    
    * [Relax] Support input axis_separator to allow 2D to 1D conversion
    
    Introduce input_axis_separators in the relax.layout_transform op to allow
    converting 2D buffers to 1D buffers (see the sketch below).
    The 2D->1D conversion is handled while lowering the layout_transform operator.
    Also introduce support for input_axis_separators in the AlterOpImpl pass.
    
    * Fix LINT errors
    
    * Address review comments
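
    A minimal usage sketch of the new argument at the Relax op level; the
    tensor name, shapes, and index map below are illustrative only and are
    not taken from this patch:

        from tvm import relax
        from tvm.tir import IndexMap

        # A 2D tensor whose physical layout carries an axis separator after axis 0.
        x = relax.Var("x", relax.TensorStructInfo((4, 4), "float32"))

        # Collapse the 2D buffer back to 1D. input_axis_separators records that
        # the input buffer is split at axis 1, so the lowering of layout_transform
        # can set the separator on the read-side buffer.
        y = relax.op.layout_transform(
            x,
            index_map=IndexMap.from_func(lambda i, j: (i * 4 + j,)),
            axis_separators=[],
            input_axis_separators=[1],
        )
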
---
 include/tvm/relax/attrs/manipulate.h               |  8 ++
 include/tvm/relax/transform.h                      |  4 +-
 python/tvm/relax/op/manipulate.py                  |  8 +-
 .../tvm/relax/transform/legalize_ops/manipulate.py | 13 +++-
 python/tvm/relax/transform/transform.py            | 12 ++-
 src/relax/op/tensor/manipulate.cc                  |  4 +-
 src/relax/op/tensor/manipulate.h                   |  4 +-
 src/relax/transform/alter_op_impl.cc               | 68 ++++++++++++-----
 tests/python/relax/test_transform_alter_op_impl.py | 85 +++++++++++++++++++++-
 9 files changed, 179 insertions(+), 27 deletions(-)

diff --git a/include/tvm/relax/attrs/manipulate.h b/include/tvm/relax/attrs/manipulate.h
index b9d0b9f53b..ef4265d73b 100644
--- a/include/tvm/relax/attrs/manipulate.h
+++ b/include/tvm/relax/attrs/manipulate.h
@@ -66,6 +66,12 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
    * first input axis that is part of a new flattened axis.
    */
   Optional<Array<IntImm>> axis_separators;
+  /*!
+   * \brief axis_separators for the input buffer.
+   * Needed to identify whether the input buffer to layout_transform
+   * contains axis separators.
+   */
+  Optional<Array<IntImm>> input_axis_separators;
 
   TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
     TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
@@ -74,6 +80,8 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
         "padding. If not specified, the compiler is free to choose any value.");
     TVM_ATTR_FIELD(axis_separators)
         .describe("The separators between input axes when generating flat output axes");
+    TVM_ATTR_FIELD(input_axis_separators)
+        .describe("The axis separators of the input buffer, used when regenerating the output");
   }
 };  // struct LayoutTransformAttrs
 
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index d8f36e4786..5a7b85ac13 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -559,11 +559,13 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
 * \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
 * PrimFunc i/o buffers.
 * \param axis_separators Map from kOperatorName attr to axis_separators of each buffer_transforms
+ * \param input_axis_separators Map from kOperatorName attr to axis_separators for the input buffers
  * \return The Pass.
  */
 TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
                          const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
-                         const Map<String, Array<Array<IntImm>>>& axis_separators);
+                         const Map<String, Array<Array<IntImm>>>& axis_separators,
+                         const Map<String, Array<Array<IntImm>>>& input_axis_separators);
 
 /*!
  * \brief Layout conversion pass.
diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 9bd99020e9..da0a09cc7b 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -115,6 +115,7 @@ def layout_transform(
     index_map: Union[Callable, IndexMap],
     pad_value: Optional[Union[int, float, PrimValue]] = None,
     axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
+    input_axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
 ):
     """Modifies the layout of a tensor.
 
@@ -158,7 +159,12 @@ def layout_transform(
     if axis_separators is None:
         axis_separators = []
 
-    return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators)  # type: ignore
+    if input_axis_separators is None:
+        input_axis_separators = []
+
+    return _ffi_api.layout_transform(
+        x, index_map, pad_value, axis_separators, input_axis_separators
+    )
 
 
 def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index e56240dc0d..4d30b97f64 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -181,6 +181,9 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
             name=name,
         )
 
+    def set_axis_sep(axis_sep: list, sch: tir.Schedule, buffer_type: str):
+        sch.set_axis_separator(primfunc_name, (buffer_type, 0), axis_separators=axis_sep)
+
     index_map: tvm.tir.IndexMap = call.attrs.index_map
     pad_value = call.attrs.pad_value
     if pad_value is not None:
@@ -192,8 +195,10 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
             pad_value = float(0.0)
 
     axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.axis_separators
+    input_axis_separators: tvm.tir.IndexMap.AXIS_SEPARATOR = call.attrs.input_axis_separators
+
     # Convert to list from array
-    axis_separators = list(map(lambda x: x.value, axis_separators))
+    axis_separators = [int(sep) for sep in axis_separators]
     primfunc_name = "te_layout_transform"
     _, padding_predicate = index_map.non_surjective_inverse(call.args[0].struct_info.shape)
     if not isinstance(padding_predicate, tvm.tir.expr.IntImm):
@@ -206,8 +211,10 @@ def _layout_transform(bb: BlockBuilder, call: Call) -> Expr:
     # Create TIR schedule to apply layout changes with axis separators
     sch = tir.Schedule(tir_func)
     sch.transform_layout(primfunc_name, ("write", 0), index_map, pad_value)
-    if len(axis_separators) != 0:
-        sch.set_axis_separator(primfunc_name, ("write", 0), axis_separators=axis_separators)
+    set_axis_sep(axis_separators, sch, "write")
+    if input_axis_separators is not None:
+        input_axis_separators = [int(sep) for sep in input_axis_separators]
+        set_axis_sep(input_axis_separators, sch, "read")
     gvar = bb.add_func(sch.mod["main"], primfunc_name)
     output_shape = index_map.map_shape(list(call_args[0].struct_info.shape))
     output_dtype = call_args[0].struct_info.dtype
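
The read-side handling above mirrors the standalone TIR scheduling primitive. A small
self-contained sketch (the PrimFunc, buffer shapes, and block name here are illustrative
and not taken from this patch):

    from tvm import tir
    from tvm.script import tir as T

    @T.prim_func
    def flatten_2d(src: T.Buffer((4, 4), "float32"), dst: T.Buffer((16,), "float32")):
        for i, j in T.grid(4, 4):
            with T.block("T_layout_trans"):
                vi, vj = T.axis.remap("SS", [i, j])
                dst[vi * 4 + vj] = src[vi, vj]

    sch = tir.Schedule(flatten_2d)
    # Mark the read (input) buffer as split at axis 1, which is what
    # input_axis_separators requests during legalization of layout_transform.
    sch.set_axis_separator("T_layout_trans", ("read", 0), axis_separators=[1])
    print(sch.mod)
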
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 38e7994eb9..3528b4429e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -24,6 +24,7 @@ from typing import Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Uni
 import numpy as np  # type: ignore
 
 import tvm.ir
+from tvm.ir.container import Array
 from tvm.relax import Expr, Var, StructInfo
 from tvm.relax.dpl import DFPattern
 from tvm.runtime import NDArray, Object
@@ -1280,6 +1281,7 @@ def AlterOpImpl(
     op_impl_map: Dict[str, PrimFunc],
     op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
     op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
+    op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
 ):
     """Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
     PrimFunc that could possibly have different layouts on i/o buffers. The layout
@@ -1295,6 +1297,8 @@ def AlterOpImpl(
         op_kind to layout transformation map for each of the buffers
     op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
         op_kind to axis_separator for each index_map
+    op_buffer_input_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
+        op_kind to axis_separators of the input buffers for each index_map
 
     Returns
     -------
@@ -1303,13 +1307,19 @@ def AlterOpImpl(
     for operator_name, transform_list in op_buffer_transforms.items():
         l = []
         for transform in transform_list:
+            # Extract the index_map
             if isinstance(transform, Callable):
                 transform = IndexMap.from_func_with_separators(transform)[0]
+            elif isinstance(transform, (Array, tuple)) and isinstance(transform[0], IndexMap):
+                transform = transform[0]
             l.append(transform)
         op_buffer_transforms[operator_name] = l
 
     return _ffi_api.AlterOpImpl(
-        op_impl_map, op_buffer_transforms, op_buffer_axis_separators
+        op_impl_map,
+        op_buffer_transforms,
+        op_buffer_axis_separators,
+        op_buffer_input_axis_separators,
     )  # type: ignore
 
 
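On the Python side the pass now takes four dictionaries, with the new
op_buffer_input_axis_separators mirroring op_buffer_axis_separators. A short sketch of
how the separators are typically produced and wired in (the operator name, replacement
PrimFunc, and module in the commented call are placeholders echoing the test added at
the end of this patch):

    from tvm.tir import IndexMap

    # from_func_with_separators returns the IndexMap plus the axis separators
    # implied by IndexMap.AXIS_SEPARATOR markers in the mapping.
    index_map, axis_sep = IndexMap.from_func_with_separators(
        lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
    )

    # Call shape only (replacement_primfunc and mod are placeholders):
    # relax.transform.AlterOpImpl(
    #     {"relax.some_op": replacement_primfunc},
    #     {"relax.some_op": [index_map, index_map]},
    #     {"relax.some_op": [axis_sep, []]},   # axis_separators per buffer
    #     {"relax.some_op": [[], axis_sep]},   # input_axis_separators per buffer
    # )(mod)
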
diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc
index ad2a812c82..07c90756bf 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -472,11 +472,13 @@ TVM_REGISTER_OP("relax.flatten")
 TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);
 
 Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
-                      Optional<Array<IntImm>> axis_separators) {
+                      Optional<Array<IntImm>> axis_separators,
+                      Optional<Array<IntImm>> input_axis_separators) {
   ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
   attrs->index_map = std::move(index_map);
   attrs->pad_value = std::move(pad_value);
   attrs->axis_separators = std::move(axis_separators);
+  attrs->input_axis_separators = std::move(input_axis_separators);
 
   static const Op& op = Op::Get("relax.layout_transform");
   return Call(op, {std::move(x)}, Attrs{attrs}, {});
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index b19e3b8507..32aa107768 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -67,10 +67,12 @@ Expr flatten(Expr x);
  * not specified, any value can be used.
  * \param axis_separators Array of values to differentiate between input axes
  * when generating flattened output axes.
+ * \param input_axis_separators Array of axis separators for the input buffer.
  * \return The transformed result.
  */
 Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
-                      Optional<Array<IntImm>> axis_separators);
+                      Optional<Array<IntImm>> axis_separators,
+                      Optional<Array<IntImm>> input_axis_separators = NullOpt);
 
 /*!
  * \brief Permutes the dimensions of an array.
diff --git a/src/relax/transform/alter_op_impl.cc b/src/relax/transform/alter_op_impl.cc
index 2cb226d56e..aaf643f801 100644
--- a/src/relax/transform/alter_op_impl.cc
+++ b/src/relax/transform/alter_op_impl.cc
@@ -81,12 +81,14 @@ class AlterOpImplMutator : public ExprMutator {
  public:
   AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
                      const Map<String, Array<IndexMap>>& op_buffer_transforms_,
-                     const Map<String, Array<Array<IntImm>>>& axis_separators_)
+                     const Map<String, Array<Array<IntImm>>>& axis_separators_,
+                     const Map<String, Array<Array<IntImm>>>& input_axis_separators_)
       : ExprMutator(mod),
         mod_(mod),
         op_impl_map_(op_impl_map),
         op_buffer_transforms__(op_buffer_transforms_),
-        op_buffer_axis_separators__(axis_separators_) {}
+        op_buffer_axis_separators__(axis_separators_),
+        op_buffer_input_axis_separators__(input_axis_separators_) {}
 
   IRModule Run() {
     for (const auto& gv : mod_->GetGlobalVars()) {
@@ -127,9 +129,12 @@ class AlterOpImplMutator : public ExprMutator {
 
     Array<IndexMap> buffer_transforms;
     Optional<Array<Array<IntImm>>> axis_separators;
+    Optional<Array<Array<IntImm>>> input_axis_separators;
     if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
     if (op_buffer_axis_separators__.count(op_kind))
       axis_separators = op_buffer_axis_separators__[op_kind];
+    if (op_buffer_input_axis_separators__.count(op_kind))
+      input_axis_separators = op_buffer_input_axis_separators__[op_kind];
 
     ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
         << "Either the i/o buffers do not require any transformations or transformations for each "
@@ -140,7 +145,8 @@ class AlterOpImplMutator : public ExprMutator {
     GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);
 
     auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
-    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);
+    Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators,
+                                        input_axis_separators);
 
     ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
     StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
@@ -148,7 +154,8 @@ class AlterOpImplMutator : public ExprMutator {
         Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));
 
     // Now transform each of the outputs to previous layout.
-    return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
+    return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators,
+                            input_axis_separators);
   }
 
   Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -175,7 +182,8 @@ class AlterOpImplMutator : public ExprMutator {
   }
 
   Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
-                       const Array<IntImm>& axis_separators) {
+                       const Array<IntImm>& axis_separators,
+                       const Array<IntImm>& input_axis_separators) {
     if (IsScalarConstant(expr) || index_map.get() == nullptr) {
       return expr;
     }
@@ -185,6 +193,7 @@ class AlterOpImplMutator : public ExprMutator {
     // so would confuse the structural equality check.
     attrs->index_map = std::move(DeepCopyIndexMap(index_map));
     attrs->axis_separators = std::move(axis_separators);
+    attrs->input_axis_separators = std::move(input_axis_separators);
     return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
   }
 
@@ -232,7 +241,8 @@ class AlterOpImplMutator : public ExprMutator {
 
   Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
                               const TensorStructInfo& old_tensor_sinfo,
-                              const Array<IntImm>& axis_separator) {
+                              const Array<IntImm>& axis_separator,
+                              const Array<IntImm>& input_axis_separator) {
     if (IsScalarConstant(expr) || index_map.get() == nullptr) {
       return expr;
     }
@@ -243,10 +253,10 @@ class AlterOpImplMutator : public ExprMutator {
         index_map.NonSurjectiveInverse(initial_ranges, &analyzer);
 
     if (tir::is_zero(padding_predicate)) {
-      return TransformLayout(expr, inverse_index_map, axis_separator);
+      return TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator);
     } else {
-      auto padded_expr =
-          builder_->Normalize(TransformLayout(expr, inverse_index_map, axis_separator));
+      auto padded_expr = builder_->Normalize(
+          TransformLayout(expr, inverse_index_map, axis_separator, input_axis_separator));
       const auto& tensor_sinfo = Downcast<TensorStructInfo>(padded_expr->struct_info_);
 
       GlobalVar gv_remove_pad = GetOrCreateRemovePadOp(old_shape, tensor_sinfo->dtype);
@@ -277,19 +287,26 @@ class AlterOpImplMutator : public ExprMutator {
    * \brief Updates call inputs with layout transformed inputs
    */
   Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
-                     const Optional<Array<Array<IntImm>>>& axis_separators) {
+                     const Optional<Array<Array<IntImm>>>& axis_separators,
+                     const Optional<Array<Array<IntImm>>>& input_axis_separators) {
     if (transforms.empty()) return inputs;
 
     Array<Expr> updated_inputs;
     int index = 0;
     for (const auto& input : inputs->fields) {
       Array<IntImm> axis_separator;
+      Array<IntImm> input_axis_separator;
       if (axis_separators.defined()) {
         Array<Array<IntImm>> axis_separators_value = axis_separators.value();
         axis_separator = axis_separators_value[index];
       }
+      if (input_axis_separators.defined()) {
+        Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+        input_axis_separator = input_axis_separators_value[index];
+      }
       auto transform = transforms[index++];
-      updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
+      updated_inputs.push_back(
+          TransformLayout(input, transform, axis_separator, input_axis_separator));
     }
     return Tuple(updated_inputs);
   }
@@ -338,12 +355,13 @@ class AlterOpImplMutator : public ExprMutator {
 
   Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
                         const StructInfo& old_struct_info,
-                        const Optional<Array<Array<IntImm>>>& axis_separators) {
+                        const Optional<Array<Array<IntImm>>>& axis_separators,
+                        const Optional<Array<Array<IntImm>>>& input_axis_separators) {
     if (buffer_transforms.empty()) return expr;
 
     Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);
 
-    Array<IntImm> axis_sep;
+    Array<IntImm> axis_sep, input_axis_sep;
     size_t num_outputs = old_output_sinfo.size();
     if (num_outputs == 0) return expr;
 
@@ -355,7 +373,12 @@ class AlterOpImplMutator : public ExprMutator {
         Array<Array<IntImm>> axis_separators_value = axis_separators.value();
         axis_sep = axis_separators_value[first_output_index];
       }
-      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
+      if (input_axis_separators.defined()) {
+        Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+        input_axis_sep = input_axis_separators_value[first_output_index];
+      }
+      return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep,
+                                    input_axis_sep);
     }
 
     // In case of more than one output, we would have to get each item of the output tuple,
@@ -367,9 +390,13 @@ class AlterOpImplMutator : public ExprMutator {
         Array<Array<IntImm>> axis_separators_value = axis_separators.value();
         axis_sep = axis_separators_value[i + first_output_index];
       }
+      if (input_axis_separators.defined()) {
+        Array<Array<IntImm>> input_axis_separators_value = input_axis_separators.value();
+        input_axis_sep = input_axis_separators_value[i + first_output_index];
+      }
       auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
-      transformed_outputs.push_back(
-          TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
+      transformed_outputs.push_back(TransformLayoutInverse(output, output_map, old_output_sinfo[i],
+                                                           axis_sep, input_axis_sep));
     }
     return Tuple(transformed_outputs);
   }
@@ -387,6 +414,8 @@ class AlterOpImplMutator : public ExprMutator {
   const Map<String, Array<IndexMap>>& op_buffer_transforms__;
   /*! \brief Map from kOperatorName attribute to the axis separators on i/o buffers */
   const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;
+  /*! \brief Map from kOperatorName attribute to the input axis separators */
+  const Map<String, Array<Array<IntImm>>>& op_buffer_input_axis_separators__;
 
   const Op& call_tir_op_ = Op::Get("relax.call_tir");
   const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -396,10 +425,13 @@ namespace transform {
 
 Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
                  const Map<String, Array<IndexMap>>& op_buffer_transforms_,
-                 const Map<String, Array<Array<IntImm>>>& axis_separators_) {
+                 const Map<String, Array<Array<IntImm>>>& axis_separators_,
+                 const Map<String, Array<Array<IntImm>>>& input_axis_separators_) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
                                                                             PassContext pc) {
-    return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
+    return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_,
+                              input_axis_separators_)
+        .Run();
   };
   return CreateModulePass(/*pass_function=*/pass_func,  //
                           /*opt_level=*/0,              //
diff --git a/tests/python/relax/test_transform_alter_op_impl.py b/tests/python/relax/test_transform_alter_op_impl.py
index f2bad31f21..f1824eba6b 100644
--- a/tests/python/relax/test_transform_alter_op_impl.py
+++ b/tests/python/relax/test_transform_alter_op_impl.py
@@ -26,12 +26,19 @@ kOperatorName = "operator_name"
 
 
 def _check(
-    before, expected, operator_name, replacement_primfunc, layout_changes, axis_separator=None
+    before,
+    expected,
+    operator_name,
+    replacement_primfunc,
+    layout_changes,
+    axis_separator=None,
+    input_axis_separator=None,
 ):
     after = relax.transform.AlterOpImpl(
         {operator_name: replacement_primfunc},
         {operator_name: layout_changes},
         {operator_name: axis_separator},
+        {operator_name: input_axis_separator},
     )(before)
     after = relax.transform.DeadCodeElimination()(after)
     tvm.ir.assert_structural_equal(after, expected)
@@ -572,5 +579,81 @@ def test_reshape():
     )
 
 
+def test_input_axis_separator():
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def some_op(arg0: T.Buffer((16,), "float32"), arg1: T.Buffer((16,), "float32"), output0: T.Buffer((16,), "float32"), output1: T.Buffer((16,), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0 in range(16):
+                with T.block("T_add"):
+                    v_ax0 = T.axis.spatial(16, ax0)
+                    T.reads(arg0[v_ax0], arg1[v_ax0])
+                    T.writes(output0[v_ax0], output1[v_ax0])
+                    output0[v_ax0] = arg0[v_ax0] + arg1[v_ax0]
+                    output1[v_ax0] = arg0[v_ax0] - arg1[v_ax0]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+            with R.dataflow():
+                gv = R.call_tir(Before.some_op, (x, y), out_sinfo=[R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")])
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func(private=True)
+        def relax_some_op_replacement(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+            T.func_attr({"operator_name": "relax.some_op"})
+            for ax0, ax1 in T.grid(4, 4):
+                with T.block("T_add"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                    output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+
+        @R.function
+        def main(x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")) -> R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(x, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+                lv1: R.Tensor((4, 4), dtype="float32") = R.layout_transform(y, index_map=lambda i: (i // 4, i % 4), pad_value=None, axis_separators=[1])
+                lv2 = R.call_tir(Expected.relax_some_op_replacement, (lv, lv1), out_sinfo=[R.Tensor((4, 4), dtype="float32"), R.Tensor((4, 4), dtype="float32")])
+                lv3: R.Tensor((4, 4), dtype="float32") = lv2[0]
+                lv4: R.Tensor((16,), dtype="float32") = R.layout_transform(lv3, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1])
+                lv5: R.Tensor((4, 4), dtype="float32") = lv2[1]
+                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), pad_value=None, axis_separators=[], input_axis_separators=[1])
+                gv: R.Tuple(R.Tensor((16,), dtype="float32"), R.Tensor((16,), dtype="float32")) = (lv4, lv6)
+                R.output(gv)
+            return gv
+
+    @T.prim_func(private=True)
+    def some_op_2d(arg0: T.Buffer((4, 4), "float32"), arg1: T.Buffer((4, 4), "float32"), output0: T.Buffer((4, 4), "float32"), output1: T.Buffer((4, 4), "float32")):
+        for ax0, ax1 in T.grid(4, 4):
+            with T.block("T_add"):
+                v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                output0[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] + arg1[v_ax0, v_ax1]
+                output1[v_ax0, v_ax1] = arg0[v_ax0, v_ax1] - arg1[v_ax0, v_ax1]
+    # fmt: on
+
+    index_map_axis_sep = IndexMap.from_func_with_separators(
+        lambda i: (i // 4, IndexMap.AXIS_SEPARATOR, i % 4)
+    )
+
+    _check(
+        Before,
+        Expected,
+        operator_name="relax.some_op",
+        replacement_primfunc=some_op_2d,
+        layout_changes=[
+            index_map_axis_sep,
+            index_map_axis_sep,
+            index_map_axis_sep,
+            index_map_axis_sep,
+        ],
+        axis_separator=[index_map_axis_sep[1], index_map_axis_sep[1], [], []],
+        input_axis_separator=[[], [], index_map_axis_sep[1], index_map_axis_sep[1]],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
