This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0dfc5f955e [Unity] Check for transpose and dynamic shape in 
AdjustMatmulOrder (#16589)
0dfc5f955e is described below

commit 0dfc5f955e2dd883527638e3d5b1f6844971af3a
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Mon May 13 14:46:22 2024 -0500

    [Unity] Check for transpose and dynamic shape in AdjustMatmulOrder (#16589)
    
    When determining whether to evaluate matrix multiplications as
    `(A*B)*C` or as `A*(B*C)`, dynamic shapes may occur (e.g. a dynamic
    LoRA rank), and the inner product may be transposed with
    `permute_dims` (e.g. `x * permute_dims(A*B)`).  This commit tests
    for these cases, and improves the arithmetic bounds used to prove
    which order of evaluation is preferred.
    
    As part of the implementation, this commit also adds a utility
    `CollectNonNegativeExpressions`, exposed to the Python API as
    `relax.analysis.collect_non_negative_expressions`.  This utility
    collects expressions within a `StructInfo` which must be non-negative,
    based on the location where they appear.  For example, the size of a
    tensor along each dimension must be non-negative.  Unlike the existing
    `definable_tir_vars_in_struct_info`, this will include the `N-2`
    expression in `R.Tensor([N-2])`.
---
 include/tvm/relax/analysis.h                       |  13 ++
 python/tvm/relax/analysis/__init__.py              |   1 +
 python/tvm/relax/analysis/analysis.py              |  27 ++++
 src/relax/analysis/struct_info_analysis.cc         |  45 ++++++
 src/relax/ir/expr_functor.cc                       |  11 ++
 src/relax/transform/adjust_matmul_order.cc         |  83 ++++++++---
 .../relax/test_analysis_struct_info_analysis.py    |  43 ++++++
 .../relax/test_transform_adjust_matmul_order.py    | 164 +++++++++++++++++++++
 8 files changed, 368 insertions(+), 19 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index fa928d082d..527327d56a 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -304,6 +304,19 @@ TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const 
StructInfo& sinfo);
  */
 TVM_DLL Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo);
 
+/*! \brief Collect expressions whose usage requires them to be non-negative
+ *
+ * Any PrimExpr that is used as a tensor shape, or as an element in a
+ * ShapeExpr, may not be negative.  This utility function can be used
+ * to generate assertions prior to calling a kernel, or to provide
+ * assumptions within a kernel that may be useful for simplification.
+ *
+ * \param sinfo The struct info to be analyzed
+ *
+ * \return A list of non-negative expressions.
+ */
+TVM_DLL Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo);
+
 /*!
  * \brief Get the TIR variables that defined in the input function.
  * The returned list is deduplicated - each TIR variable will appear at most 
once.
diff --git a/python/tvm/relax/analysis/__init__.py 
b/python/tvm/relax/analysis/__init__.py
index 06b4f64326..592e3bb5db 100644
--- a/python/tvm/relax/analysis/__init__.py
+++ b/python/tvm/relax/analysis/__init__.py
@@ -21,6 +21,7 @@ from .analysis import (
     all_global_vars,
     all_vars,
     bound_vars,
+    collect_non_negative_expressions,
     computable_at_compile_time,
     contains_impure_call,
     definable_tir_vars_in_struct_info,
diff --git a/python/tvm/relax/analysis/analysis.py 
b/python/tvm/relax/analysis/analysis.py
index e6eaff3711..edcf02bf6a 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -202,6 +202,33 @@ def definable_tir_vars_in_struct_info(sinfo: StructInfo) 
-> List[tir.Var]:
     return _ffi_api.DefinableTIRVarsInStructInfo(sinfo)  # type: ignore
 
 
+def collect_non_negative_expressions(sinfo: StructInfo) -> List[tir.PrimExpr]:
+    """Collect TIR expressions used in non-negative contexts
+
+    Get TIR expressions that must be non-negative where the struct
+    info is used, such as tensor shapes.  Non-negative integer
+    constants are omitted from the result.
+
+    The returned list is deduplicated - each TIR expression will
+    appear at most once.  The order of the list is in the order of
+    occurrence within the struct info.
+
+    Parameters
+    ----------
+    sinfo : StructInfo
+        The struct info object to be analyzed.
+
+    Returns
+    -------
+    ret : List[tir.PrimExpr]
+
+        The list of TIR expressions that must be non-negative
+
+    """
+
+    return _ffi_api.CollectNonNegativeExpressions(sinfo)  # type: ignore
+
+
 def defined_symbolic_vars(func: Function) -> List[Var]:
     """Get the TIR variables that defined in the input function.
     The returned list is deduplicated - each TIR variable will appear at most 
once.
diff --git a/src/relax/analysis/struct_info_analysis.cc 
b/src/relax/analysis/struct_info_analysis.cc
index 0432c96e2e..e811b01cf5 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -1231,6 +1231,51 @@ 
TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVars
 TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo")
     .set_body_typed(DefinableTIRVarsInStructInfo);
 
+class NonNegativeExpressionCollector : relax::StructInfoVisitor {
+ public:
+  static Array<PrimExpr> Collect(const StructInfo& sinfo) {
+    NonNegativeExpressionCollector visitor;
+    visitor(sinfo);
+    return visitor.expressions_;
+  }
+
+ private:
+  void VisitStructInfo_(const TensorStructInfoNode* op) override {
+    if (op->shape.defined()) {
+      VisitStructInfo(GetStructInfo(op->shape.value()));
+    }
+  }
+
+  void VisitStructInfo_(const PrimStructInfoNode* op) override {
+    // Unlike the expressions in TensorStructInfo or ShapeStructInfo,
+    // PrimStructInfo may contain negative values.  This override
+    // prevents calling VisitStructInfoExprField from the default
+    // StructInfoVisitor implementation.
+  }
+
+  void VisitStructInfoExprField(const PrimExpr& size_expr) override {
+    if (auto size_int = size_expr.as<IntImmNode>(); size_int && 
size_int->value >= 0) {
+      // Avoid cluttering the result with non-negative integers
+      return;
+    }
+
+    if (!dedup_lookup_.count(size_expr)) {
+      expressions_.push_back(size_expr);
+      dedup_lookup_.insert(size_expr);
+    }
+  }
+
+  Array<PrimExpr> expressions_;
+  std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> dedup_lookup_;
+};
+
+Array<PrimExpr> CollectNonNegativeExpressions(const StructInfo& sinfo) {
+  return NonNegativeExpressionCollector::Collect(sinfo);
+}
+
+TVM_REGISTER_GLOBAL("relax.analysis.CollectNonNegativeExpressions")
+    .set_body_typed(CollectNonNegativeExpressions);
+
 class SymbolicVarCollector : public relax::ExprVisitor,
                              public relax::StructInfoVisitor,
                              public tir::ExprVisitor {
diff --git a/src/relax/ir/expr_functor.cc b/src/relax/ir/expr_functor.cc
index e01b710df1..dbfaf60fec 100644
--- a/src/relax/ir/expr_functor.cc
+++ b/src/relax/ir/expr_functor.cc
@@ -779,7 +779,18 @@ Var ExprMutator::VisitVarDef(const Var& var) {
 Expr ExprMutator::VisitWithNewScope(const Expr& expr, Optional<Array<Var>> 
params) {
   ICHECK(expr->IsInstance<SeqExprNode>())
       << "Normal form requires all new scope is stored as SeqExpr";
+
+  PrimExpr constraint = Bool(true);
+  if (params.defined()) {
+    auto non_negative_expressions =
+        
CollectNonNegativeExpressions(TupleStructInfo(params.value().Map(GetStructInfo)));
+    for (const auto& expr : non_negative_expressions) {
+      constraint = constraint && (expr >= 0);
+    }
+  }
+
   builder_->BeginScope(params);
+  With<arith::ConstraintContext> context(builder_->GetAnalyzer(), constraint);
   Expr ret = this->VisitExpr(expr);
   builder_->EndScope();
   return ret;
diff --git a/src/relax/transform/adjust_matmul_order.cc 
b/src/relax/transform/adjust_matmul_order.cc
index 399860987c..10b0267851 100644
--- a/src/relax/transform/adjust_matmul_order.cc
+++ b/src/relax/transform/adjust_matmul_order.cc
@@ -33,6 +33,7 @@
 #include <vector>
 
 #include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
 
 namespace tvm {
 namespace relax {
@@ -60,11 +61,34 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, 
Map<DFPattern, Expr>)>> CreateP
   DFPattern pat_c = WildcardPattern();
 
   auto pat_matmul = IsOp("relax.matmul");
+  auto pat_permute_dims = IsOp("relax.permute_dims");
 
   auto pat_matmul_on_lhs = pat_matmul(pat_matmul(pat_a, pat_b), pat_c);
   auto pat_matmul_on_rhs = pat_matmul(pat_a, pat_matmul(pat_b, pat_c));
 
-  auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs;
+  auto pat_permuted_matmul_on_lhs = 
pat_matmul(pat_permute_dims(pat_matmul(pat_b, pat_a)), pat_c);
+  auto pat_permuted_matmul_on_rhs = pat_matmul(pat_a, 
pat_permute_dims(pat_matmul(pat_c, pat_b)));
+
+  auto pat = pat_matmul_on_lhs | pat_matmul_on_rhs | 
pat_permuted_matmul_on_lhs |
+             pat_permuted_matmul_on_rhs;
+
+  PrimExpr symbolic_var_constraints = Bool(true);
+  if (auto upper_bounds = func->GetAttr<Map<ObjectRef, 
ObjectRef>>("tir_var_upper_bound")) {
+    Map<String, tir::Var> name_lookup;
+    for (const auto& tir_var : TIRVarsInStructInfo(GetStructInfo(func))) {
+      name_lookup.Set(tir_var->name_hint, tir_var);
+      symbolic_var_constraints = symbolic_var_constraints && (0 <= tir_var);
+    }
+
+    for (const auto& [key, obj_bound] : upper_bounds.value()) {
+      auto tir_var_name = Downcast<String>(key);
+      if (auto opt_var = name_lookup.Get(tir_var_name)) {
+        auto var = opt_var.value();
+        auto expr_bound = Downcast<PrimExpr>(obj_bound);
+        symbolic_var_constraints = symbolic_var_constraints && (var < 
expr_bound);
+      }
+    }
+  }
 
   auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
     auto expr_a = matches[pat_a];
@@ -78,23 +102,6 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, 
Map<DFPattern, Expr>)>> CreateP
       return expr;
     }
 
-    // If two of the three are compile-time, group those two values
-    // together, to allow them to be lifted out and pre-computed.
-    if (is_compile_time(expr_a) && is_compile_time(expr_b)) {
-      return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, 
DataType::Void());
-    } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) {
-      return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), 
DataType::Void());
-    }
-
-    // Otherwise, select the order that reduces the total number of
-    // operations required, assuming a naive matmul.
-
-    // Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
-    // Matmul on RHS: [N,R] * ([R,M]*[M,batch])
-    //
-    // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
-    // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`
-
     auto get_shape = [](Expr expr) -> Optional<Array<PrimExpr>> {
       auto sinfo = expr->struct_info_.as<TensorStructInfoNode>();
       if (sinfo) {
@@ -115,6 +122,39 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, 
Map<DFPattern, Expr>)>> CreateP
     auto shape_b = opt_shape_b.value();
     auto shape_c = opt_shape_c.value();
 
+    if (matches.count(pat_permuted_matmul_on_lhs)) {
+      expr_a = permute_dims(expr_a, NullOpt);
+      expr_b = permute_dims(expr_b, NullOpt);
+      CHECK_EQ(shape_a.size(), 2);
+      CHECK_EQ(shape_b.size(), 2);
+      shape_a = {shape_a[1], shape_a[0]};
+      shape_b = {shape_b[1], shape_b[0]};
+    } else if (matches.count(pat_permuted_matmul_on_rhs)) {
+      expr_b = permute_dims(expr_b, NullOpt);
+      expr_c = permute_dims(expr_c, NullOpt);
+      CHECK_EQ(shape_b.size(), 2);
+      CHECK_EQ(shape_c.size(), 2);
+      shape_b = {shape_b[1], shape_b[0]};
+      shape_c = {shape_c[1], shape_c[0]};
+    }
+
+    // If two of the three are compile-time, group those two values
+    // together, to allow them to be lifted out and pre-computed.
+    if (is_compile_time(expr_a) && is_compile_time(expr_b)) {
+      return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, 
DataType::Void());
+    } else if (is_compile_time(expr_b) && is_compile_time(expr_c)) {
+      return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), 
DataType::Void());
+    }
+
+    // Otherwise, select the order that reduces the total number of
+    // operations required, assuming a naive matmul.
+
+    // Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
+    // Matmul on RHS: [N,R] * ([R,M]*[M,batch])
+    //
+    // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
+    // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`
+
     if (shape_a.size() == 1) {
       shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]};
     }
@@ -142,8 +182,13 @@ std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, 
Map<DFPattern, Expr>)>> CreateP
     auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B;
 
     arith::Analyzer analyzer;
+    
analyzer.rewrite_simplify.SetEnabledExtensions(static_cast<arith::RewriteSimplifier::Extension>(
+        analyzer.rewrite_simplify.GetEnabledExtensions() |
+        arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum));
+    With<arith::ConstraintContext> func_attr_constraint(&analyzer, 
symbolic_var_constraints);
     With<arith::ConstraintContext> analyzer_constraint(
-        &analyzer, size_N >= 0 && size_R >= 0 && size_M >= 0 && size_B >= 0);
+        &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);
+
     if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) {
       return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, 
DataType::Void());
     } else if (analyzer.CanProve(ops_with_rhs_first < ops_with_lhs_first)) {
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py 
b/tests/python/relax/test_analysis_struct_info_analysis.py
index b28df7b224..83b1ddd4fc 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -24,6 +24,7 @@ import tvm.testing
 from tvm import TVMError
 from tvm import relax as rx
 from tvm import tir, ir
+from tvm.script import relax as R
 
 
 def test_get_static_type_basic():
@@ -718,5 +719,47 @@ def 
test_collect_symbolic_var_from_non_tensor_params(param_type, param_order):
     assert free_vars == set()
 
 
+def test_collect_nonnegative_expressions():
+    @R.function
+    def func(
+        A: R.Tensor([1024, "M", "N-2"]),
+        B: R.Tensor([128, "N", "M+2"]),
+        C: R.Shape(["M", "N"]),
+        D: R.Prim(value="N"),
+    ):
+        return R.tuple()
+
+    M, N = list(func.params[2].struct_info.values)
+
+    # Expressions are de-duplicated, in order of their first appearance
+    tvm.ir.assert_structural_equal(
+        rx.analysis.collect_non_negative_expressions(func.struct_info),
+        [M, N - 2, N, M + 2],
+    )
+
+    # Tensor shapes can imply that their shapes are non-negative
+    tvm.ir.assert_structural_equal(
+        
rx.analysis.collect_non_negative_expressions(func.params[0].struct_info),
+        [M, N - 2],
+    )
+    tvm.ir.assert_structural_equal(
+        
rx.analysis.collect_non_negative_expressions(func.params[1].struct_info),
+        [N, M + 2],
+    )
+
+    # ShapeExpr values can imply that their contents are non-negative
+    tvm.ir.assert_structural_equal(
+        
rx.analysis.collect_non_negative_expressions(func.params[2].struct_info),
+        [M, N],
+    )
+
+    # PrimValue instances may contain negative values, and do not
+    # imply that their contents are non-negative.
+    tvm.ir.assert_structural_equal(
+        
rx.analysis.collect_non_negative_expressions(func.params[3].struct_info),
+        [],
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py 
b/tests/python/relax/test_transform_adjust_matmul_order.py
index 8b5a26682a..5112bf5384 100644
--- a/tests/python/relax/test_transform_adjust_matmul_order.py
+++ b/tests/python/relax/test_transform_adjust_matmul_order.py
@@ -347,5 +347,169 @@ class TestNoOpForFullyDynamicOnRHS(Base):
     Expected = Before
 
 
+class TestRHSPermuteDims(Base):
+    """Prefer (x*A)*B instead of x*(A*B)
+
+    Like `TestRHS`, but the weights on the RHS are transposed.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([16]),
+            A: R.Tensor([32, 2]),
+            B: R.Tensor([2, 16]),
+        ) -> R.Tensor([32]):
+            linear_weight: R.Tensor([32, 16]) = R.matmul(A, B)
+            matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight)
+            out: R.Tensor([32]) = R.matmul(x, matmul_weight)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([16]),
+            A: R.Tensor([32, 2]),
+            B: R.Tensor([2, 16]),
+        ) -> R.Tensor([32]):
+            B_transpose = R.permute_dims(B)
+            x: R.Tensor([2]) = R.matmul(x, B_transpose)
+            A_transpose = R.permute_dims(A)
+            x: R.Tensor([32]) = R.matmul(x, A_transpose)
+            return x
+
+
+class TestRHSPermuteDimsDynamic(Base):
+    """Prefer (x*A)*B instead of x*(A*B)
+
+    Like `TestRHSPermuteDims`, but the weights on the RHS have a
+    dynamic shape.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([16]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 16]),
+        ) -> R.Tensor([32]):
+            linear_weight: R.Tensor([32, 16]) = R.matmul(A, B)
+            matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight)
+            out: R.Tensor([32]) = R.matmul(x, matmul_weight)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([16]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 16]),
+        ) -> R.Tensor([32]):
+            lora_r = T.int64()
+            B_transpose = R.permute_dims(B)
+            x: R.Tensor([lora_r]) = R.matmul(x, B_transpose)
+            A_transpose = R.permute_dims(A)
+            x: R.Tensor([32]) = R.matmul(x, A_transpose)
+            return x
+
+
+class TestRHSPermuteDimsWithDynamicBatch(Base):
+    """Prefer (x*A)*B instead of x*(A*B)
+
+    Like `TestRHSPermuteDims`, but both the weights on the RHS and the
+    activations on the LHS have a dynamic dimension.
+
+    Unlike most of the tests for this transform, the
+    `tir_var_upper_bound` attribute is required.  In order to make a
+    change, `AdjustMatmulOrder` must first prove that the modified
+    execution order reduces the number of computations.
+
+        ops_left_to_right = (4096 + 4096)*batch_size*lora_r
+        ops_right_to_left = (batch_size + lora_r)*4096*4096
+
+    Without an upper bound on `lora_r`, we cannot prove which of these
+    is the preferred execution order.  With the upper bound, TVM can
+    determine the preferred order using the following arithmetic
+    reasoning.
+
+        (4096 + 4096)*batch_size*lora_r < (batch_size + lora_r)*4096*4096
+        batch_size*lora_r < (batch_size + lora_r)*2048
+        1/2048 < 1/batch_size + 1/lora_r  (holds because lora_r <= 2048)
+
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 4096]),
+            A: R.Tensor([4096, "lora_r"]),
+            B: R.Tensor(["lora_r", 4096]),
+        ) -> R.Tensor(["batch_size", 4096]):
+            R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
+            batch_size = T.int64()
+            linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B)
+            matmul_weight: R.Tensor([4096, 4096]) = 
R.permute_dims(linear_weight)
+            out: R.Tensor([batch_size, 4096]) = R.matmul(x, matmul_weight)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 4096]),
+            A: R.Tensor([4096, "lora_r"]),
+            B: R.Tensor(["lora_r", 4096]),
+        ) -> R.Tensor(["batch_size", 4096]):
+            R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
+            lora_r = T.int64()
+            batch_size = T.int64()
+            B_transpose = R.permute_dims(B)
+            x: R.Tensor([batch_size, lora_r]) = R.matmul(x, B_transpose)
+            A_transpose = R.permute_dims(A)
+            x: R.Tensor([batch_size, 4096]) = R.matmul(x, A_transpose)
+            return x
+
+
+class TestRHSPermuteDimsDynamicWithSquareMatrix(Base):
+    """Prefer (x*A)*B instead of x*(A*B)
+
+    Like `TestRHSPermuteDimsDynamic`, but `A*B` produces a square
+    `[32, 32]` weight matrix.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([32]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 32]),
+        ) -> R.Tensor([32]):
+            linear_weight: R.Tensor([32, 32]) = R.matmul(A, B)
+            matmul_weight: R.Tensor([32, 32]) = R.permute_dims(linear_weight)
+            out: R.Tensor([32]) = R.matmul(x, matmul_weight)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([32]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 32]),
+        ) -> R.Tensor([32]):
+            lora_r = T.int64()
+            B_transpose = R.permute_dims(B)
+            x: R.Tensor([lora_r]) = R.matmul(x, B_transpose)
+            A_transpose = R.permute_dims(A)
+            x: R.Tensor([32]) = R.matmul(x, A_transpose)
+            return x
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to