This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c49e01a0a [Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul (#16315)
2c49e01a0a is described below
commit 2c49e01a0a1a69ae578fb27f7f78225d4a55dffd
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Jan 23 16:43:08 2024 -0600
[Unity][Transform] Implement relax.transform.ReorderTakeAfterMatmul (#16315)
If `R.matmul(x, R.take(weights, indices))` occurs, with `R.take`
selecting along the output feature dimension, it can be
rearranged to `R.take(R.matmul(x, weights), indices)`.
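
For example, a sketch of the rewrite in the simple case (names follow
the commit message, shapes follow the comments in the pass):

    x: [batch, infeatures]
    weights: [infeatures, table_size]
    indices: [outfeatures]

    # before: gather columns of the weight table, then matmul
    out = R.matmul(x, R.take(weights, indices, axis=1))
    # after: matmul against the full table, then gather columns
    out = R.take(R.matmul(x, weights), indices, axis=1)

Both forms produce the same [batch, outfeatures] result, since a gather
along the output-feature axis commutes with the contraction over
infeatures.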
---
python/tvm/relax/transform/__init__.py | 1 +
python/tvm/relax/transform/transform.py | 15 ++
src/relax/transform/reorder_take_after_matmul.cc | 164 ++++++++++++++++++
.../test_transform_reorder_take_after_matmul.py | 186 +++++++++++++++++++++
4 files changed, 366 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py b/python/tvm/relax/transform/__init__.py
index eeac5f82c8..7efe144c50 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -63,6 +63,7 @@ from .transform import (
RemovePurityChecking,
RemoveUnusedParameters,
RemoveUnusedOutputs,
+ ReorderTakeAfterMatmul,
RewriteCUDAGraph,
RewriteDataflowReshape,
RunCodegen,
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 03d0878810..1f390adb2e 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1302,6 +1302,21 @@ def ExpandMatmulOfSum():
return _ffi_api.ExpandMatmulOfSum() # type: ignore
+def ReorderTakeAfterMatmul():
+ """Reorder `matmul(x, take(weights, indices))` to
`take(matmul(x,weights),indices)`
+
+ Useful for optimizing LoRA computations, where several LoRAs may
+ be batched together.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The corresponding pass.
+ """
+
+ return _ffi_api.ReorderTakeAfterMatmul() # type: ignore
+
+
 def CombineParallelMatmul(check=None):
     """Combine multiple matmul operators sharing the same LHS matrix into one,
     followed by slicing. When all matmul branches in a tree have the same set
     of fused ops,
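
As a usage sketch of the new API (the module below is illustrative and
mirrors the TestSimple case in the tests further down), applying the
pass rewrites the matmul-of-take into a take-of-matmul:

    import tvm
    from tvm import relax
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor([1, 16], "float32"),
            weight_table: R.Tensor([16, 64], "float32"),
            routing_table: R.Tensor([32], "int64"),
        ) -> R.Tensor([1, 32], "float32"):
            with R.dataflow():
                weight = R.take(weight_table, routing_table, axis=1)
                out = R.matmul(x, weight)
                R.output(out)
            return out

    rewritten = relax.transform.ReorderTakeAfterMatmul()(Module)
    rewritten.show()  # matmul now runs against the full table, followed by take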
diff --git a/src/relax/transform/reorder_take_after_matmul.cc b/src/relax/transform/reorder_take_after_matmul.cc
new file mode 100644
index 0000000000..9e037f05f0
--- /dev/null
+++ b/src/relax/transform/reorder_take_after_matmul.cc
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform/reorder_take_after_matmul.cc
+ * \brief Reorder `matmul(x, take(weights, indices))` to `take(matmul(x, weights), indices)`
+ */
+
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/dataflow_matcher.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <optional>
+#include <unordered_set>
+#include <vector>
+
+#include "../op/tensor/index.h"
+#include "../op/tensor/linear_algebra.h"
+#include "../op/tensor/manipulate.h"
+
+namespace tvm {
+namespace relax {
+
+namespace {
+std::tuple<DFPattern, TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>> CreatePatterns() {
+ auto pat_lhs = WildcardPattern();
+
+ auto pat_weights = WildcardPattern();
+ auto pat_indices = WildcardPattern();
+ auto pat_rhs = IsOp("relax.take")(pat_weights, pat_indices);
+
+ auto pat_matmul = IsOp("relax.matmul")(pat_lhs, pat_rhs);
+
+ auto rewriter = [=](Expr expr, Map<DFPattern, Expr> matches) -> Expr {
+ auto lhs = matches[pat_lhs];
+ auto weights = matches[pat_weights];
+ auto indices = matches[pat_indices];
+
+ const auto* take_call = matches[pat_rhs].as<CallNode>();
+    ICHECK(take_call) << "InternalError: "
+                      << "Match of relax.take operator should produce Call, "
+                      << "but instead produces " << matches[pat_rhs] << " with type "
+                      << matches[pat_rhs]->GetTypeKey();
+    const auto* attrs = take_call->attrs.as<TakeAttrs>();
+    ICHECK(attrs) << "InternalError: "
+                  << "Attributes for relax.take operator should be TakeAttrs, "
+                  << "but were instead " << take_call->attrs << " with type "
+                  << take_call->GetTypeKey();
+
+ const auto* lhs_sinfo = lhs->struct_info_.as<TensorStructInfoNode>();
+ if (!lhs_sinfo) return expr;
+
+    const auto* weights_sinfo = weights->struct_info_.as<TensorStructInfoNode>();
+ if (!weights_sinfo) return expr;
+
+    const auto* indices_sinfo = indices->struct_info_.as<TensorStructInfoNode>();
+ if (!indices_sinfo) return expr;
+
+ const auto* matmul_sinfo = expr->struct_info_.as<TensorStructInfoNode>();
+ if (!matmul_sinfo) return expr;
+
+ if (!attrs->axis.defined()) return expr;
+ auto axis = attrs->axis.value()->value;
+
+ if (lhs_sinfo->IsUnknownNdim() || indices_sinfo->IsUnknownNdim() ||
+ matmul_sinfo->IsUnknownNdim() || weights_sinfo->IsUnknownNdim())
+ return expr;
+
+ if (indices_sinfo->ndim == 1 && axis + 1 == weights_sinfo->ndim) {
+ // Simpler case. The activations may have batch dimensions, but
+ // the weights do not.
+
+ // lhs.shape = [*batch, infeatures]
+ // weights.shape = [infeatures, table_size]
+ // indices.shape = [outfeatures]
+
+ // out_table.shape = [*batch, table_size]
+ auto out_table = matmul(lhs, weights, DataType::Void());
+ // new_output.shape = [*batch, outfeatures]
+      auto new_output = take(out_table, indices, Integer(matmul_sinfo->ndim - 1));
+
+ return new_output;
+    } else if (lhs_sinfo->ndim == 3 && weights_sinfo->ndim == 3 && indices_sinfo->ndim == 1 &&
+               axis == 0 && weights_sinfo->GetShape().defined() &&
+               lhs_sinfo->GetShape().defined()) {
+      // More complicated case, used for batched LoRA.  The conditions
+      // on the argument dimensions could probably be relaxed, but doing
+      // so would likely require removing the use of the einsum operator.
+
+ auto lhs_shape = lhs_sinfo->GetShape().value();
+ auto weight_shape = weights_sinfo->GetShape().value();
+
+ // lhs.shape = [batch1, batch2, infeatures]
+ // weights.shape = [table_size, infeatures, outfeatures]
+ // indices.shape = [batch1]
+
+ // reordered_weight.shape = [infeatures, table_size, outfeatures]
+      auto reordered_weight = permute_dims(weights, Array{Integer(1), Integer(0), Integer(2)});
+ // fused_weight.shape = [infeatures, table_size * outfeatures]
+      auto fused_weight = reshape(reordered_weight,
+                                  ShapeExpr({weight_shape[1], weight_shape[0] * weight_shape[2]}));
+ // fused_output.shape = [batch1, batch2, table_size * outfeatures]
+ auto fused_output = matmul(lhs, fused_weight, DataType::Void());
+ // indexed_output.shape = [batch1, batch2, table_size, outfeatures]
+      auto indexed_output = reshape(
+          fused_output, ShapeExpr({lhs_shape[0], lhs_shape[1], weight_shape[0], weight_shape[2]}));
+
+ // TODO(Lunderberg): Find a better way to express these last two
+ // steps. For an output at [i,j,k], the value is
+ // `indexed_output[i, j, indices[i], k]`, but there doesn't seem
+ // to be a good way to express that in relax. It could be
+ // written using `call_te`, but that would prevent later
+ // optimizations from recognizing the high-level relax
+ // operations.
+
+ // duplicated_output.shape = [batch1, batch2, batch1, outfeatures]
+ auto duplicated_output = take(indexed_output, indices, Integer(2));
+ // new_output.shape = [batch1, batch2, outfeatures]
+ auto new_output = einsum(Tuple({duplicated_output}), "ijik->ijk");
+
+ return new_output;
+ } else {
+ return expr;
+ }
+ };
+
+ return {pat_matmul, rewriter};
+}
+
+} // namespace
+
+namespace transform {
+Pass ReorderTakeAfterMatmul() {
+ auto pass_func = [=](Function func, IRModule mod, PassContext pc) {
+ auto [pattern, rewriter] = CreatePatterns();
+ return RewriteCall(pattern, rewriter, func);
+ };
+ return CreateFunctionPass(pass_func, 1, "ReorderTakeAfterMatmul", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.ReorderTakeAfterMatmul")
+ .set_body_typed(ReorderTakeAfterMatmul);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
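
The take-plus-einsum tail of the batched branch implements a per-row
gather: the output at [i, j, k] equals indexed_output[i, j, indices[i], k].
A minimal NumPy check of that equivalence (NumPy stands in for the Relax
operators; the shapes are small illustrative values):

    import numpy as np

    batch1, batch2, table_size, outfeatures = 4, 2, 3, 5
    indexed_output = np.random.rand(batch1, batch2, table_size, outfeatures)
    indices = np.random.randint(0, table_size, size=batch1)

    # What the pass emits: take along the table axis duplicates batch1,
    # then einsum extracts the diagonal over the two batch1 axes.
    duplicated = np.take(indexed_output, indices, axis=2)  # [batch1, batch2, batch1, outfeatures]
    via_einsum = np.einsum("ijik->ijk", duplicated)        # [batch1, batch2, outfeatures]

    # Reference: the per-row gather described in the TODO comment.
    reference = indexed_output[np.arange(batch1), :, indices, :]
    np.testing.assert_allclose(via_einsum, reference)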
diff --git a/tests/python/relax/test_transform_reorder_take_after_matmul.py b/tests/python/relax/test_transform_reorder_take_after_matmul.py
new file mode 100644
index 0000000000..bf969fb3fe
--- /dev/null
+++ b/tests/python/relax/test_transform_reorder_take_after_matmul.py
@@ -0,0 +1,186 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import inspect
+
+import pytest
+
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I, relax as R, tir as T
+
+
+class Base:
+ def test_compare(self):
+ transform = relax.transform.ReorderTakeAfterMatmul()
+
+        if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
+ with pytest.raises(self.Expected):
+ transform(self.Before)
+ else:
+ after = transform(self.Before)
+ tvm.ir.assert_structural_equal(self.Expected, after)
+
+
+class TestSimple(Base):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor([1, 16], "float32"),
+ weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+ routing_table: R.Tensor([32], "int64"),
+ ) -> R.Tensor([1, 32], "float32"):
+ weight_table_size = T.int64()
+ with R.dataflow():
+                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, routing_table, axis=1)
+ out: R.Tensor([1, 32], "float32") = R.matmul(x, weight)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor([1, 16], "float32"),
+ weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+ routing_table: R.Tensor([32], "int64"),
+ ) -> R.Tensor([1, 32], "float32"):
+ weight_table_size = T.int64()
+ with R.dataflow():
+                out_table: R.Tensor([1, weight_table_size], "float32") = R.matmul(x, weight_table)
+                out: R.Tensor([1, 32], "float32") = R.take(out_table, routing_table, axis=1)
+ R.output(out)
+ return out
+
+
+class TestBatchedActivations(Base):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16], "float32"),
+ weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+ routing_table: R.Tensor([32], "int64"),
+ ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+ batch_size = T.int64()
+ weight_table_size = T.int64()
+ with R.dataflow():
+                weight: R.Tensor([16, 32], "float32") = R.take(weight_table, routing_table, axis=1)
+                out: R.Tensor([batch_size, 1, 32], "float32") = R.matmul(x, weight)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16], "float32"),
+ weight_table: R.Tensor([16, "weight_table_size"], "float32"),
+ routing_table: R.Tensor([32], "int64"),
+ ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+ batch_size = T.int64()
+ weight_table_size = T.int64()
+ with R.dataflow():
+                out_table: R.Tensor([batch_size, 1, weight_table_size], "float32") = R.matmul(
+                    x, weight_table
+                )
+                out: R.Tensor([batch_size, 1, 32], "float32") = R.take(
+                    out_table, routing_table, axis=2
+                )
+ R.output(out)
+ return out
+
+
+class TestStaticBatchedActivationsAndWeights(Base):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor([128, 1, 16], "float32"),
+ weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+ routing_table: R.Tensor([128], "int64"),
+ ) -> R.Tensor([128, 1, 32], "float32"):
+ batch_size = T.int64()
+ routing_table_size = T.int64()
+ with R.dataflow():
+ weight = R.take(weight_table, routing_table, axis=0)
+ out = R.matmul(x, weight)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor([128, 1, 16], "float32"),
+ weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+ routing_table: R.Tensor([128], "int64"),
+ ) -> R.Tensor([128, 1, 32], "float32"):
+ batch_size = T.int64()
+ routing_table_size = T.int64()
+ with R.dataflow():
+ reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
+                fused_weight = R.reshape(reordered_weight, [16, routing_table_size * 32])
+                fused_output = R.matmul(x, fused_weight)
+                reordered_output = R.reshape(fused_output, [128, 1, routing_table_size, 32])
+                tabular_output = R.take(reordered_output, routing_table, axis=2)
+ out = R.einsum([tabular_output], "ijik->ijk")
+ R.output(out)
+ return out
+
+
+class TestDynamicBatchedActivationsAndWeights(Base):
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16], "float32"),
+ weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+ routing_table: R.Tensor(["batch_size"], "int64"),
+ ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+ batch_size = T.int64()
+ routing_table_size = T.int64()
+ with R.dataflow():
+ weight = R.take(weight_table, routing_table, axis=0)
+ out = R.matmul(x, weight)
+ R.output(out)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16], "float32"),
+ weight_table: R.Tensor(["routing_table_size", 16, 32], "float32"),
+ routing_table: R.Tensor(["batch_size"], "int64"),
+ ) -> R.Tensor(["batch_size", 1, 32], "float32"):
+ batch_size = T.int64()
+ routing_table_size = T.int64()
+ with R.dataflow():
+ reordered_weight = R.permute_dims(weight_table, [1, 0, 2])
+                fused_weight = R.reshape(reordered_weight, [16, routing_table_size * 32])
+                fused_output = R.matmul(x, fused_weight)
+                reordered_output = R.reshape(fused_output, [batch_size, 1, routing_table_size, 32])
+                tabular_output = R.take(reordered_output, routing_table, axis=2)
+ out = R.einsum([tabular_output], "ijik->ijk")
+ R.output(out)
+ return out
+
+
+if __name__ == "__main__":
+ tvm.testing.main()