This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c49e01a0a [Unity][Transform] Implement 
relax.transform.ReorderTakeAfterMatmul (#16315)
2c49e01a0a is described below

commit 2c49e01a0a1a69ae578fb27f7f78225d4a55dffd
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Jan 23 16:43:08 2024 -0600

    [Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul (#16315)
    
    If `R.matmul(x, R.take(weights, indices))` occurs, with `R.take`
    selecting along the output feature dimension, it can be
    rearranged to `R.take(R.matmul(x, weights), indices)`.
---
 python/tvm/relax/transform/__init__.py             |   1 +
 python/tvm/relax/transform/transform.py            |  15 ++
 src/relax/transform/reorder_take_after_matmul.cc   | 164 ++++++++++++++++++
 .../test_transform_reorder_take_after_matmul.py    | 186 +++++++++++++++++++++
 4 files changed, 366 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index eeac5f82c8..7efe144c50 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -63,6 +63,7 @@ from .transform import (
     RemovePurityChecking,
     RemoveUnusedParameters,
     RemoveUnusedOutputs,
+    ReorderTakeAfterMatmul,
     RewriteCUDAGraph,
     RewriteDataflowReshape,
     RunCodegen,
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 03d0878810..1f390adb2e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1302,6 +1302,21 @@ def ExpandMatmulOfSum():
     return _ffi_api.ExpandMatmulOfSum()  # type: ignore
 
 
+def ReorderTakeAfterMatmul():
+    """Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`.
+
+    The rewrite applies when `take` selects along the output feature
+    dimension of `weights`.  Useful for optimizing LoRA computations,
+    where several LoRAs may be batched together.
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The corresponding pass.
+    """
+
+    return _ffi_api.ReorderTakeAfterMatmul()  # type: ignore
+
+
 def CombineParallelMatmul(check=None):
     """Combine multiple matmul operators sharing the same LHS matrix into one,
     followed by slicing. When all matmul branches in a tree have the same set 
of fused ops,
diff --git a/src/relax/transform/reorder_take_after_matmul.cc 
b/src/relax/transform/reorder_take_after_matmul.cc
new file mode 100644
index 0000000000..9e037f05f0
--- /dev/null
+++ b/src/relax/transform/reorder_take_after_matmul.cc
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/reorder_take_after_matmul.cc
+ * \brief Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+/*!
+ * \brief Build the pattern and rewriter used by ReorderTakeAfterMatmul.
+ *
+ * The pattern matches `matmul(lhs, take(weights, indices))`.  The
+ * rewriter returns the matched expression unchanged unless all
+ * operands (and the matmul result) are tensors of known
+ * dimensionality and the `take` has an explicit axis.  Two forms are
+ * rewritten:
+ *
+ * - 1-d `indices` selecting along the last axis of `weights` (which
+ *   must be at least 2-d): becomes `take(matmul(lhs, weights), indices)`
+ *   along the last axis of the matmul output.
+ *
+ * - 3-d `lhs`, 3-d `weights`, 1-d `indices` with `axis == 0` and known
+ *   shapes (batched LoRA): becomes a single matmul against the
+ *   flattened weight table, followed by `take` and `einsum`.
+ *
+ * \return A tuple of the pattern to match and the rewriter callback,
+ * in the form passed to `RewriteCall`.
+ */
+std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
+  auto pat_lhs = WildcardPattern();
+
+  auto pat_weights = WildcardPattern();
+  auto pat_indices = WildcardPattern();
+  auto pat_rhs = IsOp("relax.take")(pat_weights, pat_indices);
+
+  auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs);
+
+  auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
+    auto lhs = matches[pat_lhs];
+    auto weights = matches[pat_weights];
+    auto indices = matches[pat_indices];
+
+    const auto* take_call = matches[pat_rhs].as<CallNode>();
+    ICHECK(take_call) << "InternalError: "
+                      << "Match of relax.take operator should produce Call, "
+                      << "but instead produces " << matches[pat_rhs] << " with type "
+                      << matches[pat_rhs]->GetTypeKey();
+    const auto* attrs = take_call->attrs.as<TakeAttrs>();
+    // Report the type of the attributes themselves (not of the Call
+    // node), and avoid dereferencing undefined attributes while
+    // building the message.
+    ICHECK(attrs) << "InternalError: "
+                  << "Attributes for relax.take operator should be TakeAttrs, "
+                  << "but were instead " << take_call->attrs << " with type "
+                  << (take_call->attrs.defined() ? take_call->attrs->GetTypeKey() : "(undefined)");
+
+    // Every operand, and the matmul result, must carry tensor struct
+    // info; otherwise the expression is left untouched.
+    const auto* lhs_sinfo = lhs->struct_info_.as<TensorStructInfoNode>();
+    if (!lhs_sinfo) return expr;
+
+    const auto* weights_sinfo = weights->struct_info_.as<TensorStructInfoNode>();
+    if (!weights_sinfo) return expr;
+
+    const auto* indices_sinfo = indices->struct_info_.as<TensorStructInfoNode>();
+    if (!indices_sinfo) return expr;
+
+    const auto* matmul_sinfo = expr->struct_info_.as<TensorStructInfoNode>();
+    if (!matmul_sinfo) return expr;
+
+    if (!attrs->axis.defined()) return expr;
+    auto axis = attrs->axis.value()->value;
+
+    if (lhs_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim() ||
+        matmul_sinfo->IsUnknownNdim() || weights_sinfo->IsUnknownNdim())
+      return expr;
+
+    if (indices_sinfo->ndim == 1 && weights_sinfo->ndim >= 2 &&
+        axis + 1 == weights_sinfo->ndim) {
+      // Simpler case.  The activations may have batch dimensions, but
+      // the weights do not.
+      //
+      // The weights must be at least 2-d.  For 1-d weights, the only
+      // axis is the reduction axis of the matmul, so selecting along
+      // it cannot be moved after the matmul.
+
+      // lhs.shape = [*batch, infeatures]
+      // weights.shape = [infeatures, table_size]
+      // indices.shape = [outfeatures]
+
+      // out_table.shape = [*batch, table_size]
+      auto out_table = matmul(lhs, weights, DataType::Void());
+      // new_output.shape = [*batch, outfeatures]
+      auto new_output = take(out_table, indices, Integer(matmul_sinfo->ndim - 1));
+
+      return new_output;
+    } else if (lhs_sinfo->ndim == 3 && weights_sinfo->ndim == 3 && indices_sinfo->ndim == 1 &&
+               axis == 0 && weights_sinfo->GetShape().defined() &&
+               lhs_sinfo->GetShape().defined()) {
+      // More complicated case, used for batched LoRA.  The conditions
+      // on the argument dimensions can probably be relaxed, but would
+      // probably need to remove the use of the einsum operator.
+      //
+      // NOTE(review): the "ijik->ijk" einsum below needs
+      // lhs.shape[0] == indices.shape[0].  That equality is not
+      // checked here; presumably it holds unless the original matmul
+      // relied on broadcasting a size-1 batch dimension -- confirm.
+
+      auto lhs_shape = lhs_sinfo->GetShape().value();
+      auto weight_shape = weights_sinfo->GetShape().value();
+
+      // lhs.shape = [batch1, batch2, infeatures]
+      // weights.shape = [table_size, infeatures, outfeatures]
+      // indices.shape = [batch1]
+
+      // reordered_weight.shape = [infeatures, table_size, outfeatures]
+      auto reordered_weight = permute_dims(weights, Array{Integer(1), Integer(0), Integer(2)});
+      // fused_weight.shape = [infeatures, table_size * outfeatures]
+      auto fused_weight = reshape(reordered_weight,
+                                  ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]}));
+      // fused_output.shape = [batch1, batch2, table_size * outfeatures]
+      auto fused_output = matmul(lhs, fused_weight, DataType::Void());
+      // indexed_output.shape = [batch1, batch2, table_size, outfeatures]
+      auto indexed_output = reshape(
+          fused_output, ShapeExpr({lhs_shape[0], lhs_shape[1], weight_shape[0], weight_shape[2]}));
+
+      // TODO(Lunderberg): Find a better way to express these last two
+      // steps.  For an output at [i,j,k], the value is
+      // `indexed_output[i, j, indices[i], k]`, but there doesn't seem
+      // to be a good way to express that in relax.  It could be
+      // written using `call_te`, but that would prevent later
+      // optimizations from recognizing the high-level relax
+      // operations.
+
+      // duplicated_output.shape = [batch1, batch2, batch1, outfeatures]
+      auto duplicated_output = take(indexed_output, indices, Integer(2));
+      // new_output.shape = [batch1, batch2, outfeatures]
+      auto new_output = einsum(Tuple({duplicated_output}), "ijik->ijk");
+
+      return new_output;
+    } else {
+      return expr;
+    }
+  };
+
+  return {pat_matmul, rewriter};
+}
+
+}  // namespace
+
+namespace transform {
+/*!
+ * \brief Create the function pass that rewrites
+ * `matmul(x, take(weights, indices))` using the pattern and rewriter
+ * from CreatePatterns.
+ *
+ * \return The function-level pass, named "ReorderTakeAfterMatmul".
+ */
+Pass ReorderTakeAfterMatmul() {
+  // Neither the enclosing module nor the pass context is consulted;
+  // each function is rewritten independently.
+  auto rewrite_function = [](Function func, IRModule /*mod*/, PassContext /*ctx*/) {
+    auto [pattern, rewriter] = CreatePatterns();
+    return RewriteCall(pattern, rewriter, func);
+  };
+  return CreateFunctionPass(rewrite_function, /*opt_level=*/1, "ReorderTakeAfterMatmul", {});
+}
+
+// Register the pass factory under the global name
+// "relax.transform.ReorderTakeAfterMatmul".
+TVM_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul")
+    .set_body_typed(ReorderTakeAfterMatmul);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/tests/python/relax/test_transform_reorder_take_after_matmul.py 
b/tests/python/relax/test_transform_reorder_take_after_matmul.py
new file mode 100644
index 0000000000..bf969fb3fe
--- /dev/null
+++ b/tests/python/relax/test_transform_reorder_take_after_matmul.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import inspect
+
+import pytest
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class Base:
+    def test_compare(self):
+        transform = relax.transform.ReorderTakeAfterMatmul()
+
+        if inspect.isclass(self.Expected) and issubclass(self.Expected, 
Exception):
+            with pytest.raises(self.Expected):
+                transform(self.Before)
+        else:
+            after = transform(self.Before)
+            tvm.ir.assert_structural_equal(self.Expected, after)
+
+
+class TestSimple(Base):
+    """2-d activations with a 2-d weight table indexed along axis=1.
+
+    The `R.take` along the last axis of the weight table is expected
+    to move after the `R.matmul`, selecting along axis=1 of its output.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor([1, 32], "float32"):
+            weight_table_size = T.int64()
+            with R.dataflow():
+                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, 
routing_table, axis=1)
+                out: R.Tensor([1, 32], "float32") = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor([1, 32], "float32"):
+            weight_table_size = T.int64()
+            with R.dataflow():
+                out_table: R.Tensor([1, weight_table_size], "float32") = 
R.matmul(x, weight_table)
+                out: R.Tensor([1, 32], "float32") = R.take(out_table, 
routing_table, axis=1)
+                R.output(out)
+            return out
+
+
+class TestBatchedActivations(Base):
+    """3-d activations (symbolic batch size) with a 2-d weight table.
+
+    As in `TestSimple`, the `R.take` moves after the `R.matmul`, but
+    now selects along axis=2, the last axis of the 3-d matmul output.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            weight_table_size = T.int64()
+            with R.dataflow():
+                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, 
routing_table, axis=1)
+                out: R.Tensor([batch_size, 1, 32], "float32") = R.matmul(x, 
weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+            routing_table: R.Tensor([32], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            weight_table_size = T.int64()
+            with R.dataflow():
+                out_table: R.Tensor([batch_size, 1, weight_table_size], 
"float32") = R.matmul(
+                    x, weight_table
+                )
+                out: R.Tensor([batch_size, 1, 32], "float32") = R.take(
+                    out_table, routing_table, axis=2
+                )
+                R.output(out)
+            return out
+
+
+class TestStaticBatchedActivationsAndWeights(Base):
+    """Per-batch weight selection (axis=0) with a static batch size of 128.
+
+    The expected output flattens the weight table into a single 2-d
+    weight (permute_dims + reshape), performs one matmul, reshapes the
+    result, and then picks each batch element's entry using
+    `R.take` followed by `R.einsum("ijik->ijk")`.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor([128, 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor([128], "int64"),
+        ) -> R.Tensor([128, 1, 32], "float32"):
+            # NOTE(review): `batch_size` is declared but not used by any
+            # shape in this static test case.
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                weight = R.take(weight_table, routing_table, axis=0)
+                out = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor([128, 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor([128], "int64"),
+        ) -> R.Tensor([128, 1, 32], "float32"):
+            # NOTE(review): `batch_size` is declared but not used by any
+            # shape in this static test case.
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
+                fused_weight = R.reshape(reordered_weight, [16, 
routing_table_size * 32])
+                fused_output = R.matmul(x, fused_weight)
+                reordered_output = R.reshape(fused_output, [128, 1, 
routing_table_size, 32])
+                tabular_output = R.take(reordered_output, routing_table, 
axis=2)
+                out = R.einsum([tabular_output], "ijik->ijk")
+                R.output(out)
+            return out
+
+
+class TestDynamicBatchedActivationsAndWeights(Base):
+    """Per-batch weight selection (axis=0) with a symbolic batch size.
+
+    Same rewrite as `TestStaticBatchedActivationsAndWeights`, except
+    that the activations and routing table share the symbolic
+    `batch_size` dimension instead of a static 128.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor(["batch_size"], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                weight = R.take(weight_table, routing_table, axis=0)
+                out = R.matmul(x, weight)
+                R.output(out)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16], "float32"),
+            weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+            routing_table: R.Tensor(["batch_size"], "int64"),
+        ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+            batch_size = T.int64()
+            routing_table_size = T.int64()
+            with R.dataflow():
+                reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
+                fused_weight = R.reshape(reordered_weight, [16, 
routing_table_size * 32])
+                fused_output = R.matmul(x, fused_weight)
+                reordered_output = R.reshape(fused_output, [batch_size, 1, 
routing_table_size, 32])
+                tabular_output = R.take(reordered_output, routing_table, 
axis=2)
+                out = R.einsum([tabular_output], "ijik->ijk")
+                R.output(out)
+            return out
+
+
+# Entry point when this file is executed directly as a script;
+# presumably `tvm.testing.main()` runs the tests defined above.
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to