This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 783b467845 [Unity] CUDA Graph update  (#15320)
783b467845 is described below

commit 783b46784530eb21e927cb62e1675afd680610ad
Author: masahi <[email protected]>
AuthorDate: Sat Jul 15 08:23:31 2023 +0900

    [Unity] CUDA Graph update  (#15320)
    
    * allow capturing input parameters in a cuda graph
    
    * remove unnecessary cudaGraphLaunch
    
    * support cuda graph for cutlass
    
    * add test
    
    * add test for cutlass
    
    * revert LiftTransformParams change
    
    * comment
    
    * update test
    
    * update builtin
    
    * update
    
    * delete exec properly
    
    * run cuda graph twice in the test to make sure cached launch works
---
 python/tvm/contrib/cutlass/attention_operation.py  |   6 +-
 python/tvm/contrib/cutlass/conv2d_operation.py     |   7 +-
 python/tvm/contrib/cutlass/gemm_operation.py       |  47 +--
 python/tvm/contrib/cutlass/gen_tensor_op.py        |   2 +-
 python/tvm/contrib/cutlass/layer_norm_operation.py |   6 +-
 python/tvm/contrib/cutlass/rms_norm_operation.py   |   6 +-
 src/relax/transform/rewrite_cuda_graph.cc          |  20 +-
 src/runtime/cuda/cuda_device_api.cc                |   4 +
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc    |  47 +--
 tests/python/relax/test_codegen_cutlass.py         |  70 ++++-
 .../relax/test_transform_rewrite_cuda_graph.py     | 322 +++++++++++++++++++++
 11 files changed, 465 insertions(+), 72 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 7240f24de4..b6a9517f80 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -139,7 +139,11 @@ def instantiate_attention_template(attrs):
   }
 
   CHECK(Attention::check_supported(p));
-  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, stream>>>(p);
 
   if (accumulator_buf_allocated) {
     cudaFree(p.output_accum_ptr);
diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py
index 8f85fc5382..77f4449db2 100644
--- a/python/tvm/contrib/cutlass/conv2d_operation.py
+++ b/python/tvm/contrib/cutlass/conv2d_operation.py
@@ -423,7 +423,12 @@ def instantiate_conv2d_template(attrs):
   status = conv2d_op.initialize(arguments, workspace.get());
   CHECK(status == cutlass::Status::kSuccess);
   ${split_k_update}
-  status = conv2d_op();
+
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+  status = conv2d_op(stream);
   CHECK(status == cutlass::Status::kSuccess);
   ${split_k_reduction}
 """
diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py
index 86f6e97719..3fa6e9be8d 100644
--- a/python/tvm/contrib/cutlass/gemm_operation.py
+++ b/python/tvm/contrib/cutlass/gemm_operation.py
@@ -344,7 +344,12 @@ def instantiate_gemm_template(attrs):
   CHECK(status == cutlass::Status::kSuccess);
   status = gemm_op.initialize(arguments, workspace.get());
   CHECK(status == cutlass::Status::kSuccess);
-  status = gemm_op();
+
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+  status = gemm_op(stream);
   CHECK(status == cutlass::Status::kSuccess);
 """
     op_type = attrs["op_type"]
@@ -416,38 +421,34 @@ def emit_fp16A_int4B_matmul(attrs):
   int m = ${A_arg}->shape[${batch_offset}];
   int n = ${B_arg}->shape[1] * 2;
   int k = ${B_arg}->shape[0];
+
+  auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+  ICHECK(func != nullptr);
+  cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
     """,
         attrs,
     )
 
     template = """
   ${template_common}
-  gemm_fp16_int4(static_cast<cutlass::half_t*>(${A_arg}->data),
-                 static_cast<cutlass::uint4b_t*>(${B_arg}->data),
-                 static_cast<cutlass::half_t*>(${scales_arg}->data),
-                 static_cast<cutlass::half_t*>(out0->data),
-                 m, n, k, nullptr, 0, nullptr);
-"""
-
-    template_bias = """
-  ${template_common}
-  gemm_fp16_int4_bias(static_cast<cutlass::half_t*>(${A_arg}->data),
-                 static_cast<cutlass::uint4b_t*>(${B_arg}->data),
-                 static_cast<cutlass::half_t*>(${scales_arg}->data),
-                 static_cast<cutlass::half_t*>(${bias_arg}->data),
-                 static_cast<cutlass::half_t*>(out0->data),
-                 m, n, k, nullptr, 0, nullptr);
+  gemm_fp16_int_bias_act(static_cast<cutlass::half_t*>(${A_arg}->data),
+                static_cast<${weight_dtype}*>(${B_arg}->data),
+                static_cast<cutlass::half_t*>(${scales_arg}->data),
+                ${bias},
+                static_cast<cutlass::half_t*>(out0->data),
+                "${activation}",
+                m, n, k, ${bias_stride}, nullptr, 0, stream);
 """
 
     template_residual = """
   ${template_common}
-  
gemm_fp16_int4_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
-                 static_cast<cutlass::uint4b_t*>(${B_arg}->data),
-                 static_cast<cutlass::half_t*>(${scales_arg}->data),
-                 ${bias},
-                 static_cast<cutlass::half_t*>(${residual_arg}->data),
-                 static_cast<cutlass::half_t*>(out0->data), "${activation}", 
"${binary_op}", "${unary_op}",
-                 m, n, k, nullptr, 0, nullptr);
+  
gemm_fp16_int_bias_act_residual(static_cast<cutlass::half_t*>(${A_arg}->data),
+                static_cast<${weight_dtype}*>(${B_arg}->data),
+                static_cast<cutlass::half_t*>(${scales_arg}->data),
+                ${bias},
+                static_cast<cutlass::half_t*>(${residual_arg}->data),
+                static_cast<cutlass::half_t*>(out0->data), "${activation}", 
"${binary_op}", "${unary_op}",
+                m, n, k, nullptr, 0, stream);
 """
 
     if "residual_arg" in attrs and "bias_arg" in attrs:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 0aaafe8505..8c8bcc20c3 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -483,7 +483,7 @@ def instantiate_template(func_name, annotations, func_args):
         if k in annotations:
             attrs[k] = annotations[k]
 
-    headers = []
+    headers = ["tvm/runtime/registry.h"]
 
     if "relu" in func_name:
         
headers.append("cutlass/epilogue/thread/linear_combination_bias_relu.h")
diff --git a/python/tvm/contrib/cutlass/layer_norm_operation.py b/python/tvm/contrib/cutlass/layer_norm_operation.py
index 589f559e93..ad2e730d27 100644
--- a/python/tvm/contrib/cutlass/layer_norm_operation.py
+++ b/python/tvm/contrib/cutlass/layer_norm_operation.py
@@ -39,6 +39,10 @@ def instantiate_layer_norm_template(attrs):
     cutlass::TensorRef<data_type, RowMajor> _beta((data_type*)${beta}->data, 
layout_channels);
     cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, 
layout_2D);
 
-    cutlass::layernorm(size, _output, _input, _gamma, _beta, NULL);
+    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+    ICHECK(func != nullptr);
+    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator 
void*());
+
+    cutlass::layernorm(size, _output, _input, _gamma, _beta, stream);
     """
     return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/rms_norm_operation.py b/python/tvm/contrib/cutlass/rms_norm_operation.py
index e24d6bc39a..ef0d8ef61f 100644
--- a/python/tvm/contrib/cutlass/rms_norm_operation.py
+++ b/python/tvm/contrib/cutlass/rms_norm_operation.py
@@ -38,6 +38,10 @@ def instantiate_rms_norm_template(attrs):
     cutlass::TensorRef<data_type, RowMajor> 
_weight((data_type*)${weight}->data, layout_channels);
     cutlass::TensorRef<data_type, RowMajor> _output((data_type*)out0->data, 
layout_2D);
 
-    cutlass::rmsnorm(size, _output, _input, _weight, nullptr);
+    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+    ICHECK(func != nullptr);
+    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator 
void*());
+
+    cutlass::rmsnorm(size, _output, _input, _weight, stream);
     """
     return substitute_template(template, attrs)
diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index af10640635..088fa38758 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -150,8 +150,16 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   explicit CUDAGraphRewritePlanner(const IRModule& mod) : mod_(mod) {}
   std::vector<LiftedFunctionRewritePlan> Plan() {
     for (const auto& pair : mod_->functions) {
-      const auto& func = pair.second;
-      if (func->IsInstance<FunctionNode>()) {
+      if (pair.second->IsInstance<FunctionNode>()) {
+        // If a function has the num_input attribute, the last 
func->params.size() - num_inputs
+        // inputs are assumed to be fixed and thus they can be captured into a 
cuda graph.
+        static const char* attr_num_input = "num_input";
+        const auto& func = Downcast<Function>(pair.second);
+        if (auto num_input = func->attrs.GetAttr<Integer>(attr_num_input)) {
+          for (size_t i = num_input.value().IntValue(); i < 
func->params.size(); ++i) {
+            static_vars_.insert(func->params[i].get());
+          }
+        }
         VisitExpr(func);
       }
     }
@@ -349,7 +357,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
       if (vars_collector != nullptr) {
         vars_collector->push_back(var);
       }
-      return static_bindings_.count(var);
+      return static_vars_.count(var);
     }
 
     if (const auto* shape = expr.as<ShapeExprNode>()) {
@@ -402,7 +410,7 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
       current_.capture_builder->AddBinding(binding);
       binding_to_region_[binding->var.get()] = current_.capture_builder;
     }
-    static_bindings_.emplace(binding->var.get(), GetRef<VarBinding>(binding));
+    static_vars_.emplace(binding->var.get());
   }
 
   /*! \brief The states of the current scope (the BindingBlock) which is a 
pair of FuncBuilder.
@@ -419,8 +427,8 @@ class CUDAGraphRewritePlanner : public ExprVisitor {
   IRModule mod_;
   // States of the current scope
   Scope current_;
-  // All the static bindings
-  std::unordered_map<const VarNode*, VarBinding> static_bindings_;
+  // Variables whose buffer address is fixed
+  std::unordered_set<const VarNode*> static_vars_;
   // Binding to the FuncBuilder if the binding is lifted. This is used to 
update the inputs/outputs
   // of the lifted function when its binding is used outside.
   std::unordered_map<const VarNode*, FuncBuilder*> binding_to_region_;
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index 71788e5299..b8854f88cb 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -300,5 +300,9 @@ TVM_DLL String GetCudaFreeMemory() {
 
 
TVM_REGISTER_GLOBAL("runtime.GetCudaFreeMemory").set_body_typed(GetCudaFreeMemory);
 
+// Expose the calling thread's current CUDA stream (CUDAThreadEntry::ThreadLocal()->stream) as an
+// opaque pointer. Generated CUTLASS code queries this so its kernels are enqueued on the same
+// stream as the rest of TVM, including while that stream is being captured into a CUDA graph.
+TVM_REGISTER_GLOBAL("runtime.get_cuda_stream").set_body_typed([]() -> void* {
+  cudaStream_t current_stream = CUDAThreadEntry::ThreadLocal()->stream;
+  return static_cast<void*>(current_stream);
+});
+
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 9d2025d647..f6eef9ca25 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -31,45 +31,22 @@ namespace tvm {
 namespace runtime {
 namespace relax_vm {
 
-/*! \brief Represents a CUDA graph. */
-class CUDAGraphNode : public Object {
- public:
-  cudaGraph_t handle_ = nullptr;
-
-  ~CUDAGraphNode() {
-    if (handle_ != nullptr) {
-      cudaGraphDestroy(handle_);
-    }
-  }
-
-  TVM_DECLARE_FINAL_OBJECT_INFO(CUDAGraphNode, Object);
-};
-
-/*!
- * \brief Managed reference to CUDAGraphNode
- * \sa CUDAGraphNode
- */
-class CUDAGraph : public ObjectRef {
- public:
-  explicit CUDAGraph(cudaGraph_t handle) {
-    auto n = make_object<CUDAGraphNode>();
-    n->handle_ = handle;
-    data_ = std::move(n);
-  }
-  TVM_DEFINE_OBJECT_REF_METHODS(CUDAGraph, ObjectRef, CUDAGraphNode);
-};
-
 /*! \brief The cache states of a CUDA graph. */
 class CUDAGraphCache : public Object {
  public:
   struct CaptureResult {
+    // Owns `exec`: the destructor below releases the instantiated graph exec, if any.
+    // NOTE(review): the struct is still implicitly copyable, so copying an entry whose `exec`
+    // is already set would make two objects destroy the same handle. As far as the visible
+    // capture path shows, `entry` is copied into capture_cache_ while `exec` is still nullptr
+    // and the graph is then instantiated directly into the cached element, so this is safe
+    // today; consider deleting the copy operations to make the ownership explicit.
+    // NOTE(review): if CUDA_CALL reports failures by throwing (TODO confirm), an error here
+    // escapes an implicitly noexcept destructor and terminates the process.
+    ~CaptureResult() {
+      if (exec) {
+        CUDA_CALL(cudaGraphExecDestroy(exec));
+      }
+    }
     /*!
      * \brief Tuple of intemediate tensors in the capture func that will be 
used outside the
      * capture func
      */
     ObjectRef states;
-    /*! \brief The cuda graph instance */
-    CUDAGraph graph;
+    /*! \brief The instantiated cuda graph (owned; nullptr until instantiated) */
+    cudaGraphExec_t exec = nullptr;
   };
 
+  /*! \brief Return this thread's CUDAGraphCache instance, obtained from dmlc::ThreadLocalStore. */
   static CUDAGraphCache* Get() { return 
dmlc::ThreadLocalStore<CUDAGraphCache>::Get(); }
@@ -88,11 +65,8 @@ class CUDAGraphCache : public Object {
                          int64_t entry_index) {
     if (auto it = capture_cache_.find(entry_index); it != 
capture_cache_.end()) {
       // Launch CUDA graph
-      const auto& [states, cuda_graph] = it->second;
-      cudaGraphExec_t cuda_graph_exec;
-      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, cuda_graph->handle_, 
NULL, NULL, 0));
-      CUDA_CALL(cudaGraphLaunch(cuda_graph_exec, 
CUDAThreadEntry::ThreadLocal()->stream));
-      CUDA_CALL(cudaGraphExecDestroy(cuda_graph_exec));
+      const auto& [states, exec] = it->second;
+      CUDA_CALL(cudaGraphLaunch(exec, CUDAThreadEntry::ThreadLocal()->stream));
       return states;
     }
 
@@ -129,9 +103,10 @@ class CUDAGraphCache : public Object {
     CUDA_CALL(cudaStreamEndCapture(CUDAThreadEntry::ThreadLocal()->stream, 
&graph));
     std::swap(capture_stream, CUDAThreadEntry::ThreadLocal()->stream);
 
-    entry.graph = CUDAGraph(graph);
     capture_cache_[entry_index] = entry;
+    CUDA_CALL(cudaGraphInstantiate(&capture_cache_[entry_index].exec, graph, 
NULL, NULL, 0));
     CUDA_CALL(cudaStreamDestroy(capture_stream));
+    CUDA_CALL(cudaGraphDestroy(graph));
     return entry.states;
   }
 
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index e1ce46ecb0..1528141e4a 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -85,15 +85,23 @@ cutlass_enabled = pytest.mark.skipif(
 pytestmark = [cutlass_enabled]
 
 
-def build_and_run(mod, inputs_np, target, legalize=True):
+def build_and_run(mod, inputs_np, target, legalize=True, cuda_graph=False):
+    """Build `mod` for `target`, run its "main" function on `inputs_np`, return a numpy array.
+
+    If `legalize` is True, LegalizeOps is applied before building. `cuda_graph` is forwarded
+    as the "relax.backend.use_cuda_graph" PassContext option; when it is True, "main" is
+    invoked twice and the result of the second call is returned.
+    """
     if legalize:
         mod = relax.transform.LegalizeOps()(mod)  # For cpu reference, nop for 
cutlass.
 
+    with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": 
cuda_graph}):
+        ex = relax.build(mod, target)
+
     dev = tvm.device(target, 0)
-    ex = relax.build(mod, target)
     vm = relax.VirtualMachine(ex, dev)
     f = vm["main"]
     inputs = [tvm.nd.array(inp, dev) for inp in inputs_np]
+
+    # For cuda graph, run the compiled function twice to make sure that we can 
launch the cached
+    # graph on the second run.
+    if cuda_graph:
+        f(*inputs)
+
     return f(*inputs).numpy()
 
 
@@ -1554,5 +1562,63 @@ def test_rms_norm():
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_conv2d_cuda_graph():
+    @tvm.script.ir_module
+    class Conv2d:
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 32, 16), "float16"),
+            weight1: R.Tensor((16, 3, 3, 16), "float16"),
+            weight2: R.Tensor((16, 3, 3, 16), "float16"),
+            weight3: R.Tensor((16, 3, 3, 16), "float16"),
+            gamma: R.Tensor((16,), "float16"),
+            beta: R.Tensor((16,), "float16"),
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                conv1 = R.nn.relu(
+                    R.nn.conv2d(
+                        data, weight1, padding=(1, 1), data_layout="NHWC", 
kernel_layout="OHWI"
+                    )
+                )
+                conv2 = R.nn.relu(
+                    R.nn.conv2d(
+                        conv1, weight2, padding=(1, 1), data_layout="NHWC", 
kernel_layout="OHWI"
+                    )
+                )
+                ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
+                conv3 = R.nn.relu(
+                    R.nn.conv2d(
+                        ln, weight3, padding=(1, 1), data_layout="NHWC", 
kernel_layout="OHWI"
+                    )
+                )
+                R.output(conv3)
+
+            return conv3
+
+    low, high = -1, 1
+    data_shape = (16, 32, 32, 16)
+    weight_shape = (16, 3, 3, 16)
+    dtype = "float16"
+    data = np.random.randint(low, high, size=data_shape).astype(dtype)
+    weight1 = np.random.randint(low, high, size=weight_shape).astype(dtype)
+    weight2 = np.random.randint(low, high, size=weight_shape).astype(dtype)
+    weight3 = np.random.randint(low, high, size=weight_shape).astype(dtype)
+    gamma = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype)
+    beta = np.random.randint(low, high, size=(weight_shape[0],)).astype(dtype)
+    inputs = [data, weight1, weight2, weight3, gamma, beta]
+
+    mod = partition_for_cutlass(Conv2d)
+    mod = relax.transform.RunCodegen({"cutlass": {"sm": 80, 
"find_first_valid": True}})(mod)
+    mod = relax.pipeline.get_pipeline()(mod)  # pylint: 
disable=no-value-for-parameter
+
+    with tvm.target.Target("cuda"):
+        mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
+
+    out = build_and_run(mod, inputs, "cuda", cuda_graph=True)
+    ref = build_and_run(Conv2d, inputs, "llvm", legalize=True)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 52362eae01..106147ef9a 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -339,5 +339,327 @@ def test_vm_builtin():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_capture_fixed_inputs():
+    @tvm.script.ir_module
+    class Conv2dx3:
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 32, 16), "float16"),
+            weight1: R.Tensor((16, 3, 3, 16), "float16"),
+            weight2: R.Tensor((16, 3, 3, 16), "float16"),
+            weight3: R.Tensor((16, 3, 3, 16), "float16"),
+            gamma: R.Tensor((16,), "float16"),
+            beta: R.Tensor((16,), "float16"),
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                conv1 = R.nn.relu(
+                    R.nn.conv2d(
+                        data, weight1, padding=(1, 1), data_layout="NHWC", 
kernel_layout="OHWI"
+                    )
+                )
+
+                
###############################################################################
+                # The second conv2d and layer norm can be captured into a graph
+                conv2 = R.nn.relu(
+                    R.nn.conv2d(
+                        conv1, weight2, padding=(1, 1), data_layout="NHWC", 
kernel_layout="OHWI"
+                    )
+                )
+                ln = R.nn.layer_norm(conv2, gamma, beta, axes=[-1])
+                
###############################################################################
+
+                conv3 = R.nn.relu(
+                    R.nn.conv2d(
+                        ln, weight3, padding=(1, 1), data_layout="NHWC", 
kernel_layout="OHWI"
+                    )
+                )
+                R.output(conv3)
+
+            return conv3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def fused_conv2d_relu(
+            data: T.Buffer((T.int64(16), T.int64(32), T.int64(32), 
T.int64(16)), "float16"),
+            weight1: T.Buffer((T.int64(16), T.int64(3), T.int64(3), 
T.int64(16)), "float16"),
+            var_compute_intermediate: T.Buffer(
+                (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            pad_temp = T.alloc_buffer(
+                (T.int64(16), T.int64(34), T.int64(34), T.int64(16)), "float16"
+            )
+            var_conv2d_nhwc_intermediate = T.alloc_buffer(
+                (T.int64(16), T.int64(32), T.int64(32), T.int64(16)), "float16"
+            )
+            for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(34), 
T.int64(34), T.int64(16)):
+                with T.block("pad_temp"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), 
v_i3])
+                    T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3])
+                    pad_temp[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(
+                        T.int64(1) <= v_i1
+                        and v_i1 < T.int64(33)
+                        and T.int64(1) <= v_i2
+                        and v_i2 < T.int64(33),
+                        data[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1), v_i3],
+                        T.float16(0),
+                    )
+            for nn, yy, xx, ff, ry, rx, rc in T.grid(
+                T.int64(16),
+                T.int64(32),
+                T.int64(32),
+                T.int64(16),
+                T.int64(3),
+                T.int64(3),
+                T.int64(16),
+            ):
+                with T.block("conv2d_nhwc"):
+                    v_nn, v_yy, v_xx, v_ff, v_ry, v_rx, v_rc = T.axis.remap(
+                        "SSSSRRR", [nn, yy, xx, ff, ry, rx, rc]
+                    )
+                    T.reads(
+                        pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc],
+                        weight1[v_ff, v_ry, v_rx, v_rc],
+                    )
+                    T.writes(var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, 
v_ff])
+                    with T.init():
+                        var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] = 
T.float16(0)
+                    var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff] = (
+                        var_conv2d_nhwc_intermediate[v_nn, v_yy, v_xx, v_ff]
+                        + pad_temp[v_nn, v_yy + v_ry, v_xx + v_rx, v_rc]
+                        * weight1[v_ff, v_ry, v_rx, v_rc]
+                    )
+            for i0, i1, i2, i3 in T.grid(T.int64(16), T.int64(32), 
T.int64(32), T.int64(16)):
+                with T.block("compute"):
+                    v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, 
i3])
+                    T.reads(var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2, 
v_i3])
+                    T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3])
+                    var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.max(
+                        var_conv2d_nhwc_intermediate[v_i0, v_i1, v_i2, v_i3], 
T.float16(0)
+                    )
+
+        @T.prim_func
+        def layer_norm(
+            A: T.Buffer((T.int64(16), T.int64(32), T.int64(32), T.int64(16)), 
"float16"),
+            B: T.Buffer((T.int64(16),), "float16"),
+            C: T.Buffer((T.int64(16),), "float16"),
+            T_layer_norm: T.Buffer((T.int64(16), T.int64(32), T.int64(32), 
T.int64(16)), "float16"),
+        ):
+            T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            A_red_temp_v0 = T.alloc_buffer((T.int64(16), T.int64(32), 
T.int64(32)))
+            A_red_temp_v1 = T.alloc_buffer((T.int64(16), T.int64(32), 
T.int64(32)))
+            for ax0, ax1, ax2, k3 in T.grid(T.int64(16), T.int64(32), 
T.int64(32), T.int64(16)):
+                with T.block("A_red_temp"):
+                    v_ax0, v_ax1, v_ax2, v_k3 = T.axis.remap("SSSR", [ax0, 
ax1, ax2, k3])
+                    T.reads(A[v_ax0, v_ax1, v_ax2, v_k3])
+                    T.writes(A_red_temp_v0[v_ax0, v_ax1, v_ax2], 
A_red_temp_v1[v_ax0, v_ax1, v_ax2])
+                    with T.init():
+                        A_red_temp_v0[v_ax0, v_ax1, v_ax2] = T.float32(0)
+                        A_red_temp_v1[v_ax0, v_ax1, v_ax2] = T.float32(0)
+                    v_A_red_temp_v0: T.float32 = A_red_temp_v0[v_ax0, v_ax1, 
v_ax2] + T.Cast(
+                        "float32", A[v_ax0, v_ax1, v_ax2, v_k3]
+                    )
+                    v_A_red_temp_v1: T.float32 = A_red_temp_v1[v_ax0, v_ax1, 
v_ax2] + T.Cast(
+                        "float32", A[v_ax0, v_ax1, v_ax2, v_k3]
+                    ) * T.Cast("float32", A[v_ax0, v_ax1, v_ax2, v_k3])
+                    A_red_temp_v0[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v0
+                    A_red_temp_v1[v_ax0, v_ax1, v_ax2] = v_A_red_temp_v1
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(16), T.int64(32), 
T.int64(32), T.int64(16)):
+                with T.block("T_layer_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, 
ax1, ax2, ax3])
+                    T.reads(
+                        A[v_ax0, v_ax1, v_ax2, v_ax3],
+                        A_red_temp_v0[v_ax0, v_ax1, v_ax2],
+                        A_red_temp_v1[v_ax0, v_ax1, v_ax2],
+                        B[v_ax3],
+                        C[v_ax3],
+                    )
+                    T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        T.Cast(
+                            "float16",
+                            (
+                                T.Cast("float32", A[v_ax0, v_ax1, v_ax2, 
v_ax3])
+                                - A_red_temp_v0[v_ax0, v_ax1, v_ax2] * 
T.float32(0.0625)
+                            )
+                            * T.rsqrt(
+                                A_red_temp_v1[v_ax0, v_ax1, v_ax2] * 
T.float32(0.0625)
+                                - A_red_temp_v0[v_ax0, v_ax1, v_ax2]
+                                * T.float32(0.0625)
+                                * (A_red_temp_v0[v_ax0, v_ax1, v_ax2] * 
T.float32(0.0625))
+                                + T.float32(1.0000000000000001e-05)
+                            ),
+                        )
+                        * B[v_ax3]
+                        + C[v_ax3]
+                    )
+
+        @R.function(private=True)
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
+            R.func_attr({"relax.force_pure": True})
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([524288]), R.prim_value(0), R.str("global"), 
R.dtype("float16")
+            )
+            storage1: R.Object = R.memory.alloc_storage(
+                R.shape([524288]), R.prim_value(0), R.str("global"), 
R.dtype("float16")
+            )
+            gv: R.Tuple(R.Object, R.Object) = storage, storage1
+            return gv
+
+        @R.function(private=True)
+        def cuda_graph_capture(
+            lv: R.Tensor((16, 32, 32, 16), dtype="float16"),
+            lv1: R.Tensor((16, 3, 3, 16), dtype="float16"),
+            alloc1: R.Tensor((16, 32, 32, 16), dtype="float16"),
+            alloc: R.Tensor((16, 32, 32, 16), dtype="float16"),
+            params: R.Tuple(
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+            ),
+            storage: R.Object,
+        ) -> R.Tuple(
+            R.Tensor((16, 32, 32, 16), dtype="float16"),
+            R.Tensor((16, 3, 3, 16), dtype="float16"),
+            R.Tensor((16, 32, 32, 16), dtype="float16"),
+        ):
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            _1: R.Tuple = cls.fused_conv2d_relu(lv, lv1, alloc1)
+            _: R.Tuple = R.memory.kill_tensor(alloc)
+            lv1_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc1
+            lv2: R.Tensor((16,), dtype="float16") = params[3]
+            lv3: R.Tensor((16,), dtype="float16") = params[4]
+            alloc2: R.Tensor((16, 32, 32, 16), dtype="float16") = 
R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([16, 32, 32, 16]), 
R.dtype("float16")
+            )
+            _2: R.Tuple = cls.layer_norm(lv1_1, lv2, lv3, alloc2)
+            _1_1: R.Tuple = R.memory.kill_tensor(alloc1)
+            ln: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc2
+            lv4: R.Tensor((16, 3, 3, 16), dtype="float16") = params[2]
+            gv: R.Tuple(
+                R.Tensor((16, 32, 32, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 32, 32, 16), dtype="float16"),
+            ) = (ln, lv4, alloc2)
+            return gv
+
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, 3, 3, 16), dtype="float16"),
+            R.Tensor((16, 3, 3, 16), dtype="float16"),
+            R.Tensor((16, 3, 3, 16), dtype="float16"),
+            R.Tensor((16,), dtype="float16"),
+            R.Tensor((16,), dtype="float16"),
+        ):
+            R.func_attr({"relax.force_pure": True})
+            lv: R.Tensor((16, 3, 3, 16), dtype="float16") = params[0]
+            lv1: R.Tensor((16, 3, 3, 16), dtype="float16") = params[1]
+            lv2: R.Tensor((16, 3, 3, 16), dtype="float16") = params[2]
+            lv3: R.Tensor((16,), dtype="float16") = params[3]
+            lv4: R.Tensor((16,), dtype="float16") = params[4]
+            gv: R.Tuple(
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+            ) = (lv, lv1, lv2, lv3, lv4)
+            return gv
+
+        @R.function
+        def main(
+            data: R.Tensor((16, 32, 32, 16), dtype="float16"),
+            params: R.Tuple(
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+                R.Tensor((16,), dtype="float16"),
+            ),
+        ) -> R.Tensor((16, 32, 32, 16), dtype="float16"):
+            R.func_attr({"num_input": 1, "relax.force_pure": True})
+            cls = Expected
+            lv: R.Tensor((16, 3, 3, 16), dtype="float16") = params[0]
+            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.get_cached_alloc",
+                (cls.cuda_graph_alloc, R.prim_value(0)),
+                sinfo_args=(R.Tuple(R.Object, R.Object),),
+            )
+            storage: R.Object = gv[0]
+            alloc: R.Tensor((16, 32, 32, 16), dtype="float16") = 
R.memory.alloc_tensor(
+                storage, R.prim_value(0), R.shape([16, 32, 32, 16]), 
R.dtype("float16")
+            )
+            _: R.Tuple = cls.fused_conv2d_relu(data, lv, alloc)
+            lv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc
+            lv1: R.Tensor((16, 3, 3, 16), dtype="float16") = params[1]
+            storage1: R.Object = gv[1]
+            alloc1: R.Tensor((16, 32, 32, 16), dtype="float16") = 
R.memory.alloc_tensor(
+                storage1, R.prim_value(0), R.shape([16, 32, 32, 16]), 
R.dtype("float16")
+            )
+            gv1: R.Tuple(
+                R.Tensor((16, 32, 32, 16), dtype="float16"),
+                R.Tensor((16, 3, 3, 16), dtype="float16"),
+                R.Tensor((16, 32, 32, 16), dtype="float16"),
+            ) = R.call_builtin_with_ctx(
+                "vm.builtin.cuda_graph.run_or_capture",
+                (
+                    cls.cuda_graph_capture,
+                    (lv_1, lv1, alloc1, alloc, params, storage),
+                    R.prim_value(0),
+                ),
+                sinfo_args=(
+                    R.Tuple(
+                        R.Tensor((16, 32, 32, 16), dtype="float16"),
+                        R.Tensor((16, 3, 3, 16), dtype="float16"),
+                        R.Tensor((16, 32, 32, 16), dtype="float16"),
+                    ),
+                ),
+            )
+            alloc2: R.Tensor((16, 32, 32, 16), dtype="float16") = gv1[2]
+            ln: R.Tensor((16, 32, 32, 16), dtype="float16") = gv1[0]
+            lv4: R.Tensor((16, 3, 3, 16), dtype="float16") = gv1[1]
+            alloc3: R.Tensor((16, 32, 32, 16), dtype="float16") = 
R.builtin.alloc_tensor(
+                R.shape([16, 32, 32, 16]), R.dtype("float16"), R.prim_value(0)
+            )
+            _3: R.Tuple = cls.fused_conv2d_relu(ln, lv4, alloc3)
+            _2: R.Tuple = R.memory.kill_tensor(alloc2)
+            gv_1: R.Tensor((16, 32, 32, 16), dtype="float16") = alloc3
+            _3_1: R.Tuple = R.memory.kill_storage(storage)
+            _4: R.Tuple = R.memory.kill_storage(storage1)
+            return gv_1
+
+    mod = tvm.transform.Sequential(
+        [
+            relax.pipeline.get_pipeline(),
+            relax.transform.LiftTransformParams(),
+            relax.transform.ToNonDataflow(),
+            relax.transform.RemovePurityChecking(),
+            relax.transform.CallTIRRewrite(),
+            relax.transform.StaticPlanBlockMemory(),
+        ]
+    )(Conv2dx3)
+
+    mod["main"] = mod["main"].with_attr({"num_input": 1})
+    after = relax.transform.RewriteCUDAGraph()(mod)
+    tvm.ir.assert_structural_equal(after, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to