This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f988cb48d1 [Bugfix][SLM] Produce well-formed Relax for
nn.modules.KVCache (#16684)
f988cb48d1 is described below
commit f988cb48d116defec5bbf6a1dd70cc1e538af203
Author: Eric Lunderberg <[email protected]>
AuthorDate: Sun Mar 10 16:12:02 2024 -0500
[Bugfix][SLM] Produce well-formed Relax for nn.modules.KVCache (#16684)
* [Bugfix][SLM] Produce well-formed Relax for nn.modules.KVCache
Prior to this commit, the `nn.modules.KVCache` implementations used
`R.call_packed(...)` to call the `"vm.builtin.attention_*"` functions.
Since `nn.Module` emits all relax functions within a
`relax.DataflowBlock`, where impure expressions are forbidden, this is
ill-formed.
This commit updates the implementations in `nn.modules.KVCache` to use
`R.call_pure_packed` instead of `R.call_packed`. This assertion
that the callee is pure allows the call to occur within a
`relax.DataflowBlock`.
* Correct import for relax
* Fix unit test
---
python/tvm/relax/frontend/nn/exporter.py | 4 +++-
python/tvm/relax/frontend/nn/modules.py | 27 +++++++++++++++++++-------
tests/python/relax/test_frontend_nn_modules.py | 10 +++++-----
3 files changed, 28 insertions(+), 13 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/exporter.py
b/python/tvm/relax/frontend/nn/exporter.py
index d452af69d3..1a7dcd6a64 100644
--- a/python/tvm/relax/frontend/nn/exporter.py
+++ b/python/tvm/relax/frontend/nn/exporter.py
@@ -21,7 +21,7 @@ import typing
from tvm import tir
from tvm.ir import IRModule
-from ... import expr as rx
+from .... import relax as rx
from ...block_builder import BlockBuilder
from ...struct_info import ObjectStructInfo, ShapeStructInfo, TupleStructInfo
from . import core, extern
@@ -136,6 +136,8 @@ class Exporter:
outputs, inputs = _emit_method(self.builder,
method_spec, params, effects)
self.builder.emit_func_output(outputs, inputs)
mod = self.builder.finalize()
+ assert rx.analysis.well_formed(mod)
+
return mod, params, ext_mods
diff --git a/python/tvm/relax/frontend/nn/modules.py
b/python/tvm/relax/frontend/nn/modules.py
index e69660f708..1579c5b512 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -19,7 +19,7 @@
from typing import List, Optional, Sequence, Union
from tvm import relax as rx
-from tvm import tir
+from tvm import tir, ir
from . import op
from .core import Effect, Module, ModuleList, Parameter, Tensor,
get_default_dtype
@@ -600,8 +600,13 @@ class KVCache(Effect):
return [
bb.emit(
rx.Call(
- rx.extern("vm.builtin.attention_kv_cache_create"),
- args=[rx.op.zeros(init_shape, self.dtype), init_shape,
rx.PrimValue(0)],
+ ir.Op.get("relax.call_pure_packed"),
+ args=[
+ rx.extern("vm.builtin.attention_kv_cache_create"),
+ rx.op.zeros(init_shape, self.dtype),
+ init_shape,
+ rx.PrimValue(0),
+ ],
sinfo_args=[rx.ObjectStructInfo()],
),
name_hint=name_hint,
@@ -671,8 +676,12 @@ class KVCache(Effect):
return Tensor(
_expr=rx.BlockBuilder.current().emit(
rx.Call(
- rx.extern("vm.builtin.attention_kv_cache_view"),
- args=[self.cache, shape],
+ ir.Op.get("relax.call_pure_packed"),
+ args=[
+ rx.extern("vm.builtin.attention_kv_cache_view"),
+ self.cache,
+ shape,
+ ],
sinfo_args=[rx.TensorStructInfo(shape, self.dtype)],
)
)
@@ -694,8 +703,12 @@ class KVCache(Effect):
)
self.cache = rx.BlockBuilder.current().emit(
rx.Call(
- rx.extern("vm.builtin.attention_kv_cache_append"),
- args=[self.cache, new_element._expr],
+ ir.Op.get("relax.call_pure_packed"),
+ args=[
+ rx.extern("vm.builtin.attention_kv_cache_append"),
+ self.cache,
+ new_element._expr,
+ ],
sinfo_args=[rx.ObjectStructInfo()],
)
)
diff --git a/tests/python/relax/test_frontend_nn_modules.py
b/tests/python/relax/test_frontend_nn_modules.py
index 6966a5f2a9..9b357114d3 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -484,15 +484,15 @@ def test_kv_cache():
lv: R.Tensor((8, 2, 4), dtype="float32") = R.zeros(
R.shape([8, 2, 4]), dtype="float32"
)
- cache: R.Object = R.call_packed(
+ cache: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_create",
lv,
R.shape([8, 2, 4]),
R.prim_value(0),
sinfo_args=(R.Object,),
)
- lv1: R.Tuple(R.Object, R.Object) = _io, cache
- gv: R.Tuple(R.Object, R.Object) = lv1
+ lv1 = _io, cache
+ gv = lv1
R.output(gv)
return gv
@@ -502,10 +502,10 @@ def test_kv_cache():
) -> R.Tuple(R.Tensor((4, 2, 4), dtype="float32"), R.Tuple(R.Object,
R.Object)):
R.func_attr({"num_input": 3})
with R.dataflow():
- lv2: R.Object = R.call_packed(
+ lv2: R.Object = R.call_pure_packed(
"vm.builtin.attention_kv_cache_append", cache, x,
sinfo_args=(R.Object,)
)
- lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_packed(
+ lv3: R.Tensor((4, 2, 4), dtype="float32") = R.call_pure_packed(
"vm.builtin.attention_kv_cache_view",
lv2,
R.shape([4, 2, 4]),