This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 783b467845 [Unity] CUDA Graph update (#15320)
783b467845 is described below
commit 783b46784530eb21e927cb62e1675afd680610ad
Author: masahi <[email protected]>
AuthorDate: Sat Jul 15 08:23:31 2023 +0900
[Unity] CUDA Graph update (#15320)
* allow capturing input parameters in a cuda graph
* remove unnecessary cudaGraphLaunch
* support cuda graph for cutlass
* add test
* add test for cutlass
* revert LiftTransformParams change
* comment
* update test
* update builtin
* update
* delete exec properly
* run cuda graph twice in the test to make sure cached launch works
---
python/tvm/contrib/cutlass/attention_operation.py | 6 +-
python/tvm/contrib/cutlass/conv2d_operation.py | 7 +-
python/tvm/contrib/cutlass/gemm_operation.py | 47 +--
python/tvm/contrib/cutlass/gen_tensor_op.py | 2 +-
python/tvm/contrib/cutlass/layer_norm_operation.py | 6 +-
python/tvm/contrib/cutlass/rms_norm_operation.py | 6 +-
src/relax/transform/rewrite_cuda_graph.cc | 20 +-
src/runtime/cuda/cuda_device_api.cc | 4 +
src/runtime/relax_vm/cuda/cuda_graph_builtin.cc | 47 +--
tests/python/relax/test_codegen_cutlass.py | 70 ++++-
.../relax/test_transform_rewrite_cuda_graph.py | 322 +++++++++++++++++++++
11 files changed, 465 insertions(+), 72 deletions(-)
diff --git a/python/tvm/contrib/cutlass/attention_operation.py
b/python/tvm/contrib/cutlass/attention_operation.py
index 7240f24de4..b6a9517f80 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -139,7 +139,11 @@ def instantiate_attention_template(attrs):
}
CHECK(Attention::check_supported(p));
- kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+ kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
if (accumulator_buf_allocated) {
cudaFree(p.output_accum_ptr);
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py
b/python/tvm/contrib/cutlass/conv2d_operation.py
index 8f85fc5382..77f4449db2 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -423,7 +423,12 @@ def instantiate_conv2d_template(attrs):
status = conv2d_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
${split_k_update}
- status = conv2d_op();
+
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+ status = conv2d_op(stream);
CHECK(status == cutlass::Status::kSuccess);
${split_k_reduction}
"""
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py
b/python/tvm/contrib/cutlass/gemm_operation.py
index 86f6e97719..3fa6e9be8d 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -344,7 +344,12 @@ def instantiate_gemm_template(attrs):
CHECK(status == cutlass::Status::kSuccess);
status = gemm_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
- status = gemm_op();
+
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+ status = gemm_op(stream);
CHECK(status == cutlass::Status::kSuccess);
"""
op_type = attrs["op_type"]
@@ -416,38 +421,34 @@ def emit_fp16A_int4B_matmul(attrs):
int m = ${A_arg}->shape[${batch_offset}];
int n = ${B_arg}->shape[1] * 2;
int k = ${B_arg}->shape[0];
+
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
""",
attrs,
)
template = """
${template_common}
- gemm_fp16_int4(static_cast<cutlass::half_t*>(${A_arg}->data),
- static_cast<cutlass::uint4b_t*>(${B_arg}->data),
- static_cast<cutlass::half_t*>(${scales_arg}->data),
- static_cast<cutlass::half_t*>(out0->data),
- m, n, k, nullptr, 0, nullptr);
-"""
-
- template_bias = """
- ${template_common}
- gemm_fp16_int4_bias(static_cast<cutlass::half_t*>(${A_arg}->data),
- static_cast<cutlass::uint4b_t*>(${B_arg}->data),
- static_cast<cutlass::half_t*>(${scales_arg}->data),
- static_cast<cutlass::half_t*>(${bias_arg}->data),
- static_cast<cutlass::half_t*>(out0->data),
- m, n, k, nullptr, 0, nullptr);
+ gemm_fp16_int_bias_act(static_cast<cutlass::half_t*>(${A_arg}->data),
+ static_cast<${weight_dtype}*>(${B_arg}->data),
+ static_cast<cutlass::half_t*>(${scales_arg}->data),
+ ${bias},
+ static_cast<cutlass::half_t*>(out0->data),
+ "${activation}",
+ m, n, k, ${bias_stride}, nullptr, 0, stream);
"""
template_residual = """
${template_common}
-
gemm_fp16_int4_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
- static_cast<cutlass::uint4b_t*>(${B_arg}->data),
- static_cast<cutlass::half_t*>(${scales_arg}->data),
- ${bias},
- static_cast<cutlass::half_t*>(${residual_arg}->data),
- static_cast<cutlass::half_t*>(out0->data), "${activation}",
"${binary_op}", "${unary_op}",
- m, n, k, nullptr, 0, nullptr);
+
gemm_fp16_int_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
+ static_cast<${weight_dtype}*>(${B_arg}->data),
+ static_cast<cutlass::half_t*>(${scales_arg}->data),
+ ${bias},
+ static_cast<cutlass::half_t*>(${residual_arg}->data),
+ static_cast<cutlass::half_t*>(out0->data), "${activation}",
"${binary_op}", "${unary_op}",
+ m, n, k, nullptr, 0, stream);
"""
if "residual_arg" in attrs and "bias_arg" in attrs:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 0aaafe8505..8c8bcc20c3 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -483,7 +483,7 @@ def instantiate_template(func_name, annotations, func_args):
if k in annotations:
attrs[k] = annotations[k]
- headers = []
+ headers = ["tvm/runtime/registry.h"]
if "relu" in func_name:
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py
b/python/tvm/contrib/cutlass/layer_norm_operation.py
index 589f559e93..ad2e730d27 100644
--- a/python/tvm/contrib/cutlass/layer_norm_operation.py
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -39,6 +39,10 @@ def instantiate_layer_norm_template(attrs):
cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data,
layout_channels);
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data,
layout_2D);
- cutlass::layernorm(size, _output, _input, _gamma, _beta, NULL);
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator
void*());
+
+ cutlass::layernorm(size, _output, _input, _gamma, _beta, stream);
"""
return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py
b/python/tvm/contrib/cutlass/rms_norm_operation.py
index e24d6bc39a..ef0d8ef61f 100644
--- a/python/tvm/contrib/cutlass/rms_norm_operation.py
+++ b/python/tvm/contrib/cutlass/rms_norm_operation.py
@@ -38,6 +38,10 @@ def instantiate_rms_norm_template(attrs):
cutlass::TensorRef<data_type, RowMajor>
_weight((data_type*)${weight}->data, layout_channels);
cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data,
layout_2D);
- cutlass::rmsnorm(size, _output, _input, _weight, nullptr);
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator
void*());
+
+ cutlass::rmsnorm(size, _output, _input, _weight, stream);
"""
return substitute_template(template, attrs)
diff --git a/src/relax/transform/rewrite_cuda_graph.cc
b/src/relax/transform/rewrite_cuda_graph.cc
index af10640635..088fa38758 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -150,8 +150,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
std::vector<LiftedFunctionRewritePlan> Plan() {
for (const auto& pair : mod_->functions) {
- const auto& func = pair.second;
- if (func->IsInstance<FunctionNode>()) {
+ if (pair.second->IsInstance<FunctionNode>()) {
+ // If a function has the num_input attribute, the last
func->params.size() - num_inputs
+ // inputs are assumed to be fixed and thus they can be captured into a
cuda graph.
+ static const char* attr_num_input = "num_input";
+ const auto& func = Downcast<Function>(pair.second);
+ if (auto num_input = func->attrs.GetAttr<Integer>(attr_num_input)) {
+ for (size_t i = num_input.value().IntValue(); i <
func->params.size(); ++i) {
+ static_vars_.insert(func->params[i].get());
+ }
+ }
VisitExpr(func);
}
}
@@ -349,7 +357,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
if (vars_collector != nullptr) {
vars_collector->push_back(var);
}
- return static_bindings_.count(var);
+ return static_vars_.count(var);
}
if (const auto* shape = expr.as<ShapeExprNode>()) {
@@ -402,7 +410,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
current_.capture_builder->AddBinding(binding);
binding_to_region_[binding->var.get()] = current_.capture_builder;
}
- static_bindings_.emplace(binding->var.get(), GetRef<VarBinding>(binding));
+ static_vars_.emplace(binding->var.get());
}
/*! \brief The states of the current scope (the BindingBlock) which is a
pair of FuncBuilder.
@@ -419,8 +427,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
IRModule mod_;
// States of the current scope
Scope current_;
- // All the static bindings
- std::unordered_map<const VarNode*, VarBinding> static_bindings_;
+ // Variables whose buffer address is fixed
+ std::unordered_set<const VarNode*> static_vars_;
// Binding to the FuncBuilder if the binding is lifted. This is used to
update the inputs/outputs
// of the lifted function when its binding is used outside.
std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
diff --git a/src/runtime/cuda/cuda_device_api.cc
b/src/runtime/cuda/cuda_device_api.cc
index 71788e5299..b8854f88cb 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -300,5 +300,9 @@ TVM_DLL String GetCudaFreeMemory() {
TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory);
+TVM_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() {
+ return static_cast<void*>(CUDAThreadEntry::ThreadLocal()->stream);
+});
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 9d2025d647..f6eef9ca25 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -31,45 +31,22 @@ namespace tvm {
namespace runtime {
namespace relax_vm {
-/*! \brief Represents a CUDA graph. */
-class CUDAGraphNode : public Object {
- public:
- cudaGraph_t handle_ = nullptr;
-
- ~CUDAGraphNode() {
- if (handle_ != nullptr) {
- cudaGraphDestroy(handle_);
- }
- }
-
- TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphNode, Object);
-};
-
-/*!
- * \brief Managed reference to CUDAGraphNode
- * \sa CUDAGraphNode
- */
-class CUDAGraph : public ObjectRef {
- public:
- explicit CUDAGraph(cudaGraph_t handle) {
- auto n = make_object<CUDAGraphNode>();
- n->handle_ = handle;
- data_ = std::move(n);
- }
- TVM_DEFINE_OBJECT_REF_METHODS(CUDAGraph, ObjectRef, CUDAGraphNode);
-};
-
/*! \brief The cache states of a CUDA graph. */
class CUDAGraphCache : public Object {
public:
struct CaptureResult {
+ ~CaptureResult() {
+ if (exec) {
+ CUDA_CALL(cudaGraphExecDestroy(exec));
+ }
+ }
/*!
     * \brief Tuple of intermediate tensors in the capture func that will be
used outside the
* capture func
*/
ObjectRef states;
- /*! \brief The cuda graph instance */
- CUDAGraph graph;
+ /*! \brief The instantiated cuda graph */
+ cudaGraphExec_t exec = nullptr;
};
static CUDAGraphCache* Get() { return
dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
@@ -88,11 +65,8 @@ class CUDAGraphCache : public Object {
int64_t entry_index) {
if (auto it = capture_cache_.find(entry_index); it !=
capture_cache_.end()) {
// Launch CUDA graph
- const auto& [states, cuda_graph] = it->second;
- cudaGraphExec_t cuda_graph_exec;
- CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, cuda_graph->handle_,
NULL, NULL, 0));
- CUDA_CALL(cudaGraphLaunch(cuda_graph_exec,
CUDAThreadEntry::ThreadLocal()->stream));
- CUDA_CALL(cudaGraphExecDestroy(cuda_graph_exec));
+ const auto& [states, exec] = it->second;
+ CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
return states;
}
@@ -129,9 +103,10 @@ class CUDAGraphCache : public Object {
CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream,
&graph));
std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
- entry.graph = CUDAGraph(graph);
capture_cache_[entry_index] = entry;
+ CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph,
NULL, NULL, 0));
CUDA_CALL(cudaStreamDestroy(capture_stream));
+ CUDA_CALL(cudaGraphDestroy(graph));
return entry.states;
}
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index e1ce46ecb0..1528141e4a 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -85,15 +85,23 @@ cutlass_enabled = pytest.mark.skipif(
pytestmark = [cutlass_enabled]
-def build_and_run(mod, inputs_np, target, legalize=True):
+def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
if legalize:
mod = relax.transform.LegalizeOps()(mod) # For cpu reference, nop for
cutlass.
+ with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph":
cuda_graph}):
+ ex = relax.build(mod, target)
+
dev = tvm.device(target, 0)
- ex = relax.build(mod, target)
vm = relax.VirtualMachine(ex, dev)
f = vm["main"]
inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+
+ # For cuda graph, run the compiled function twice to make sure that we can
launch the cached
+ # graph on the second run.
+ if cuda_graph:
+ f(*inputs)
+
return f(*inputs).numpy()
@@ -1554,5 +1562,63 @@ def test_rms_norm():
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_conv2d_cuda_graph():
+ @tvm.script.ir_module
+ class Conv2d:
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 32, 16), "float16"),
+ weight1: R.Tensor((16, 3, 3, 16), "float16"),
+ weight2: R.Tensor((16, 3, 3, 16), "float16"),
+ weight3: R.Tensor((16, 3, 3, 16), "float16"),
+ gamma: R.Tensor((16,), "float16"),
+ beta: R.Tensor((16,), "float16"),
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ conv1 = R.nn.relu(
+ R.nn.conv2d(
+ data, weight1, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ )
+ conv2 = R.nn.relu(
+ R.nn.conv2d(
+ conv1, weight2, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ )
+ ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
+ conv3 = R.nn.relu(
+ R.nn.conv2d(
+ ln, weight3, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ )
+ R.output(conv3)
+
+ return conv3
+
+ low, high = -1, 1
+ data_shape = (16, 32, 32, 16)
+ weight_shape = (16, 3, 3, 16)
+ dtype = "float16"
+ data = np.random.randint(low, high, size=data_shape).astype(dtype)
+ weight1 = np.random.randint(low, high, size=weight_shape).astype(dtype)
+ weight2 = np.random.randint(low, high, size=weight_shape).astype(dtype)
+ weight3 = np.random.randint(low, high, size=weight_shape).astype(dtype)
+ gamma = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype)
+ beta = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype)
+ inputs = [data, weight1, weight2, weight3, gamma, beta]
+
+ mod = partition_for_cutlass(Conv2d)
+ mod = relax.transform.RunCodegen({"cutlass": {"sm": 80,
"find_first_valid": True}})(mod)
+ mod = relax.pipeline.get_pipeline()(mod) # pylint:
disable=no-value-for-parameter
+
+ with tvm.target.Target("cuda"):
+ mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
+ out = build_and_run(mod, inputs, "cuda", cuda_graph=True)
+ ref = build_and_run(Conv2d, inputs, "llvm", legalize=True)
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py
b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 52362eae01..106147ef9a 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -339,5 +339,327 @@ def test_vm_builtin():
tvm.ir.assert_structural_equal(after, Expected)
+def test_capture_fixed_inputs():
+ @tvm.script.ir_module
+ class Conv2dx3:
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 32, 16), "float16"),
+ weight1: R.Tensor((16, 3, 3, 16), "float16"),
+ weight2: R.Tensor((16, 3, 3, 16), "float16"),
+ weight3: R.Tensor((16, 3, 3, 16), "float16"),
+ gamma: R.Tensor((16,), "float16"),
+ beta: R.Tensor((16,), "float16"),
+ ):
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ conv1 = R.nn.relu(
+ R.nn.conv2d(
+ data, weight1, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ )
+
+
###############################################################################
+ # The second conv2d and layer norm can be captured into a graph
+ conv2 = R.nn.relu(
+ R.nn.conv2d(
+ conv1, weight2, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ )
+ ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
+
###############################################################################
+
+ conv3 = R.nn.relu(
+ R.nn.conv2d(
+ ln, weight3, padding=(1, 1), data_layout="NHWC",
kernel_layout="OHWI"
+ )
+ )
+ R.output(conv3)
+
+ return conv3
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def fused_conv2d_relu(
+ data: T.Buffer((T.int64(16), T.int64(32), T.int64(32),
T.int64(16)), "float16"),
+ weight1: T.Buffer((T.int64(16), T.int64(3), T.int64(3),
T.int64(16)), "float16"),
+ var_compute_intermediate: T.Buffer(
+ (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"
+ ),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ pad_temp = T.alloc_buffer(
+ (T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16"
+ )
+ var_conv2d_nhwc_intermediate = T.alloc_buffer(
+ (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"
+ )
+ for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34),
T.int64(34), T.int64(16)):
+ with T.block("pad_temp"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1),
v_i3])
+ T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
+ pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(
+ T.int64(1) <= v_i1
+ and v_i1 < T.int64(33)
+ and T.int64(1) <= v_i2
+ and v_i2 < T.int64(33),
+ data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3],
+ T.float16(0),
+ )
+ for nn, yy, xx, ff, ry, rx, rc in T.grid(
+ T.int64(16),
+ T.int64(32),
+ T.int64(32),
+ T.int64(16),
+ T.int64(3),
+ T.int64(3),
+ T.int64(16),
+ ):
+ with T.block("conv2d_nhwc"):
+ v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap(
+ "SSSSRRR", [nn, yy, xx, ff, ry, rx, rc]
+ )
+ T.reads(
+ pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc],
+ weight1[v_ff, v_ry, v_rx, v_rc],
+ )
+ T.writes(var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx,
v_ff])
+ with T.init():
+ var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] =
T.float16(0)
+ var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] = (
+ var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff]
+ + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc]
+ * weight1[v_ff, v_ry, v_rx, v_rc]
+ )
+ for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32),
T.int64(32), T.int64(16)):
+ with T.block("compute"):
+ v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2,
i3])
+ T.reads(var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2,
v_i3])
+ T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
+ var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.max(
+ var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2, v_i3],
T.float16(0)
+ )
+
+ @T.prim_func
+ def layer_norm(
+ A: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)),
"float16"),
+ B: T.Buffer((T.int64(16),), "float16"),
+ C: T.Buffer((T.int64(16),), "float16"),
+ T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32),
T.int64(16)), "float16"),
+ ):
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ A_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32),
T.int64(32)))
+ A_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32),
T.int64(32)))
+ for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32),
T.int64(32), T.int64(16)):
+ with T.block("A_red_temp"):
+ v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0,
ax1, ax2, k3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_k3])
+ T.writes(A_red_temp_v0[v_ax0, v_ax1, v_ax2],
A_red_temp_v1[v_ax0, v_ax1, v_ax2])
+ with T.init():
+ A_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
+ A_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
+ v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1,
v_ax2] + T.Cast(
+ "float32", A[v_ax0, v_ax1, v_ax2, v_k3]
+ )
+ v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1,
v_ax2] + T.Cast(
+ "float32", A[v_ax0, v_ax1, v_ax2, v_k3]
+ ) * T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_k3])
+ A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0
+ A_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v1
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32),
T.int64(32), T.int64(16)):
+ with T.block("T_layer_norm"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(
+ A[v_ax0, v_ax1, v_ax2, v_ax3],
+ A_red_temp_v0[v_ax0, v_ax1, v_ax2],
+ A_red_temp_v1[v_ax0, v_ax1, v_ax2],
+ B[v_ax3],
+ C[v_ax3],
+ )
+ T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ T.Cast(
+ "float16",
+ (
+ T.Cast("float32", A[v_ax0, v_ax1, v_ax2,
v_ax3])
+ - A_red_temp_v0[v_ax0, v_ax1, v_ax2] *
T.float32(0.0625)
+ )
+ * T.rsqrt(
+ A_red_temp_v1[v_ax0, v_ax1, v_ax2] *
T.float32(0.0625)
+ - A_red_temp_v0[v_ax0, v_ax1, v_ax2]
+ * T.float32(0.0625)
+ * (A_red_temp_v0[v_ax0, v_ax1, v_ax2] *
T.float32(0.0625))
+ + T.float32(1.0000000000000001e-05)
+ ),
+ )
+ * B[v_ax3]
+ + C[v_ax3]
+ )
+
+ @R.function(private=True)
+ def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+ R.func_attr({"relax.force_pure": True})
+ storage: R.Object = R.memory.alloc_storage(
+ R.shape([524288]), R.prim_value(0), R.str("global"),
R.dtype("float16")
+ )
+ storage1: R.Object = R.memory.alloc_storage(
+ R.shape([524288]), R.prim_value(0), R.str("global"),
R.dtype("float16")
+ )
+ gv: R.Tuple(R.Object, R.Object) = storage, storage1
+ return gv
+
+ @R.function(private=True)
+ def cuda_graph_capture(
+ lv: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ lv1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+ alloc1: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ alloc: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ params: R.Tuple(
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ ),
+ storage: R.Object,
+ ) -> R.Tuple(
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ ):
+ R.func_attr({"relax.force_pure": True})
+ cls = Expected
+ _1: R.Tuple = cls.fused_conv2d_relu(lv, lv1, alloc1)
+ _: R.Tuple = R.memory.kill_tensor(alloc)
+ lv1_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc1
+ lv2: R.Tensor((16,), dtype="float16") = params[3]
+ lv3: R.Tensor((16,), dtype="float16") = params[4]
+ alloc2: R.Tensor((16, 32, 32, 16), dtype="float16") =
R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([16, 32, 32, 16]),
R.dtype("float16")
+ )
+ _2: R.Tuple = cls.layer_norm(lv1_1, lv2, lv3, alloc2)
+ _1_1: R.Tuple = R.memory.kill_tensor(alloc1)
+ ln: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc2
+ lv4: R.Tensor((16, 3, 3, 16), dtype="float16") = params[2]
+ gv: R.Tuple(
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ ) = (ln, lv4, alloc2)
+ return gv
+
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ ):
+ R.func_attr({"relax.force_pure": True})
+ lv: R.Tensor((16, 3, 3, 16), dtype="float16") = params[0]
+ lv1: R.Tensor((16, 3, 3, 16), dtype="float16") = params[1]
+ lv2: R.Tensor((16, 3, 3, 16), dtype="float16") = params[2]
+ lv3: R.Tensor((16,), dtype="float16") = params[3]
+ lv4: R.Tensor((16,), dtype="float16") = params[4]
+ gv: R.Tuple(
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ ) = (lv, lv1, lv2, lv3, lv4)
+ return gv
+
+ @R.function
+ def main(
+ data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+ params: R.Tuple(
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ R.Tensor((16,), dtype="float16"),
+ ),
+ ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+ R.func_attr({"num_input": 1, "relax.force_pure": True})
+ cls = Expected
+ lv: R.Tensor((16, 3, 3, 16), dtype="float16") = params[0]
+ gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.get_cached_alloc",
+ (cls.cuda_graph_alloc, R.prim_value(0)),
+ sinfo_args=(R.Tuple(R.Object, R.Object),),
+ )
+ storage: R.Object = gv[0]
+ alloc: R.Tensor((16, 32, 32, 16), dtype="float16") =
R.memory.alloc_tensor(
+ storage, R.prim_value(0), R.shape([16, 32, 32, 16]),
R.dtype("float16")
+ )
+ _: R.Tuple = cls.fused_conv2d_relu(data, lv, alloc)
+ lv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc
+ lv1: R.Tensor((16, 3, 3, 16), dtype="float16") = params[1]
+ storage1: R.Object = gv[1]
+ alloc1: R.Tensor((16, 32, 32, 16), dtype="float16") =
R.memory.alloc_tensor(
+ storage1, R.prim_value(0), R.shape([16, 32, 32, 16]),
R.dtype("float16")
+ )
+ gv1: R.Tuple(
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ ) = R.call_builtin_with_ctx(
+ "vm.builtin.cuda_graph.run_or_capture",
+ (
+ cls.cuda_graph_capture,
+ (lv_1, lv1, alloc1, alloc, params, storage),
+ R.prim_value(0),
+ ),
+ sinfo_args=(
+ R.Tuple(
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ R.Tensor((16, 3, 3, 16), dtype="float16"),
+ R.Tensor((16, 32, 32, 16), dtype="float16"),
+ ),
+ ),
+ )
+ alloc2: R.Tensor((16, 32, 32, 16), dtype="float16") = gv1[2]
+ ln: R.Tensor((16, 32, 32, 16), dtype="float16") = gv1[0]
+ lv4: R.Tensor((16, 3, 3, 16), dtype="float16") = gv1[1]
+ alloc3: R.Tensor((16, 32, 32, 16), dtype="float16") =
R.builtin.alloc_tensor(
+ R.shape([16, 32, 32, 16]), R.dtype("float16"), R.prim_value(0)
+ )
+ _3: R.Tuple = cls.fused_conv2d_relu(ln, lv4, alloc3)
+ _2: R.Tuple = R.memory.kill_tensor(alloc2)
+ gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc3
+ _3_1: R.Tuple = R.memory.kill_storage(storage)
+ _4: R.Tuple = R.memory.kill_storage(storage1)
+ return gv_1
+
+ mod = tvm.transform.Sequential(
+ [
+ relax.pipeline.get_pipeline(),
+ relax.transform.LiftTransformParams(),
+ relax.transform.ToNonDataflow(),
+ relax.transform.RemovePurityChecking(),
+ relax.transform.CallTIRRewrite(),
+ relax.transform.StaticPlanBlockMemory(),
+ ]
+ )(Conv2dx3)
+
+ mod["main"] = mod["main"].with_attr({"num_input": 1})
+ after = relax.transform.RewriteCUDAGraph()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()