masahi commented on code in PR #11631:
URL: https://github.com/apache/tvm/pull/11631#discussion_r894911347


##########
python/tvm/contrib/cutlass/build.py:
##########
@@ -346,187 +357,271 @@ def tune_cutlass_kernels(
     for var in mod.get_global_vars():
         fun_name = var.name_hint
         func = mod[fun_name]
-        annotator = OpAnnotator()
         if "cutlass" in fun_name:
             num_cutlass_partition += 1
-            annotator.visit(func)
-            out_shape = annotator.signature["ret_shape"]
-            out_dtype = annotator.signature["ret_dtype"]
-            op_type = annotator.signature["op_type"]
-
-            new_attrs = {"op_type": op_type}
-            new_attrs.update(annotator.signature)
-            new_attrs.update(func.attrs)
-            arg0_shape = new_attrs["arg0_shape"]
-            arg1_shape = new_attrs["arg1_shape"]
-            arg0_dtype = new_attrs["arg0_dtype"]
-            arg1_dtype = new_attrs["arg1_dtype"]
-
-            if "conv2d" in op_type:
-                new_attrs["padding"] = annotator.op_attrs.padding
-                new_attrs["strides"] = annotator.op_attrs.strides
-                new_attrs["dilation"] = annotator.op_attrs.dilation
-
-                if "conv2d_transpose" in op_type:
-                    d_shape = out_shape
-                    w_shape = arg1_shape
-                elif "conv2d_backward_weight" in op_type:
-                    d_shape = arg1_shape
-                    w_shape = out_shape
-                else:
-                    d_shape = arg0_shape
-                    w_shape = arg1_shape
-
-                new_attrs.update(
-                    handle_conv2d(
-                        conv2d_profiler,
-                        op_type,
-                        d_shape,
-                        w_shape,
-                        annotator.op_attrs.padding,
-                        annotator.op_attrs.strides,
-                        annotator.op_attrs.dilation,
-                        out_dtype,
-                        arg0_dtype,
-                        arg1_dtype,
-                        use_3xtf32,
-                        split_k_slices,
-                        profile_all_alignments,
-                        find_first_valid,
-                        use_multiprocessing,
-                    )
-                )
-            elif "batch_matmul" in op_type:
-                new_attrs.update(
-                    handle_batch_matmul(
-                        gemm_profiler,
-                        op_type,
-                        arg0_shape,
-                        arg1_shape,
-                        out_dtype,
-                        arg0_dtype,
-                        arg1_dtype,
-                        use_3xtf32,
-                        find_first_valid,
-                        use_multiprocessing,
-                    )
-                )
-            elif "dense" in op_type:
-                new_attrs.update(
-                    handle_dense(
-                        gemm_profiler,
-                        op_type,
-                        arg0_shape,
-                        arg1_shape,
-                        out_dtype,
-                        arg0_dtype,
-                        arg1_dtype,
-                        use_3xtf32,
-                        find_first_valid,
-                        use_multiprocessing,
-                    )
-                )
-            else:
-                raise ValueError("%s unsupported composite" % op_type)
-
-            new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
-            new_func = relay.Function(
-                func.params,
-                func.body,
-                ret_type=func.ret_type,
-                type_params=func.type_params,
-                attrs=new_attrs,
+            new_func = tune_cutlass_function(
+                func,
+                use_3xtf32,
+                split_k_slices,
+                profile_all_alignments,
+                find_first_valid,
+                use_multiprocessing,
+                gemm_profiler,
+                conv2d_profiler,
             )
             mod.update_func(var, new_func)
 
     return mod, num_cutlass_partition
 
 
-def build_cutlass_kernels(
-    lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False
+def tune_cutlass_function(
+    func,
+    use_3xtf32,
+    split_k_slices,
+    profile_all_alignments,
+    find_first_valid,
+    use_multiprocessing,
+    gemm_profiler,
+    conv2d_profiler,
 ):
-    """Compile CUTLASS kernels in lib and return the runtime module ready to run.
+    """Given a function intended to be offloaded to CUTLASS, profile each workload to select which
+    kernels to emit.
 
     Parameters
     ----------
-    lib : GraphExecutorFactoryModule
-        The output from relay.build containing compiled host code and non-cutlass kernels.
+    func : Function
+        The Relay Function to tune for.
 
-    sm : int
-        An integer specifying the compute capability. For example, 75 for Turing and
-        80 or 86 for Ampere.
+    use_3xtf32 : bool
+        Whether or not to use slower but very accurate (compared to tf32) 3xtf32 mode for
+        fp32 inputs on tensorcore.
 
-    tmp_dir : string, optional
-        A temporary directory where intermediate compiled artifacts will be stored.
+    split_k_slices : list of int
+        Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in
+        parallel across split-K blocks, and a separate global reduction kernel is launched to
+        accumulate partial reductions. The profiler will pick the best split-k factor from the
+        given candidate list. Note that a larger split-K factor requires a larger workspace.
+        Currently, parallel split-k has been tested only for wgrad. For GEMM and other conv2d
+        kinds, split_k_slices is ignored.
+
+    profile_all_alignments : bool
+        When True, profile all kernel variants with smaller alignments than the largest possible.
 
-    lib_path : string, optional
-        The path to a shared library which will be generated as the result of the build process.
+    find_first_valid : bool
+        Whether to stop profiling after the first applicable kernel is found, instead of
+        profiling all candidate kernels.
 
-    threads : int, optional
-        The number of threads to use for compiling generated kernels. Only available for
-        CUDA 11.2 or later. Use all physical cores by default.
+    use_multiprocessing : bool
+        Whether or not to compile profiler executables for different kernels in parallel.
+
+    gemm_profiler : CutlassGemmProfiler
+        Profiler for dense and batch_matmul operators. May cache results between tuned functions.
 
-    use_fast_math : bool, optional
-        Whether or not to use faster but less accurate math intrinsics.
+    conv2d_profiler : CutlassConv2DProfiler
+        Profiler for conv2d operators. May cache results between tuned functions.
 
     Returns
     -------
-    updated_lib : runtime.Module
-        The updated module with compiled cutlass kernels.
+    annot_func : Function
+        The input function with attributes capturing the best CUTLASS kernel found by tuning.
     """
-    kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
-    lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
-    return runtime.load_module(lib_path)
+    annotator = OpAnnotator()
+    annotator.visit(func)
+    out_shape = annotator.signature["ret_shape"]
+    out_dtype = annotator.signature["ret_dtype"]
+    op_type = annotator.signature["op_type"]
+
+    new_attrs = {"op_type": op_type}
+    new_attrs.update(annotator.signature)
+    new_attrs.update(func.attrs)
+    arg0_shape = new_attrs["arg0_shape"]
+    arg1_shape = new_attrs["arg1_shape"]
+    arg0_dtype = new_attrs["arg0_dtype"]
+    arg1_dtype = new_attrs["arg1_dtype"]
+
+    if "conv2d" in op_type:
+        new_attrs["padding"] = annotator.op_attrs.padding
+        new_attrs["strides"] = annotator.op_attrs.strides
+        new_attrs["dilation"] = annotator.op_attrs.dilation
+
+        if "conv2d_transpose" in op_type:
+            d_shape = out_shape
+            w_shape = arg1_shape
+        elif "conv2d_backward_weight" in op_type:
+            d_shape = arg1_shape
+            w_shape = out_shape
+        else:
+            d_shape = arg0_shape
+            w_shape = arg1_shape
+
+        new_attrs.update(
+            handle_conv2d(
+                conv2d_profiler,
+                op_type,
+                d_shape,
+                w_shape,
+                annotator.op_attrs.padding,
+                annotator.op_attrs.strides,
+                annotator.op_attrs.dilation,
+                out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
+                split_k_slices,
+                profile_all_alignments,
+                find_first_valid,
+                use_multiprocessing,
+            )
+        )
+    elif "batch_matmul" in op_type:
+        new_attrs.update(
+            handle_batch_matmul(
+                gemm_profiler,
+                op_type,
+                arg0_shape,
+                arg1_shape,
+                out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
+                find_first_valid,
+                use_multiprocessing,
+            )
+        )
+    elif "dense" in op_type:
+        new_attrs.update(
+            handle_dense(
+                gemm_profiler,
+                op_type,
+                arg0_shape,
+                arg1_shape,
+                out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
+                find_first_valid,
+                use_multiprocessing,
+            )
+        )
+    else:
+        raise ValueError("%s unsupported composite" % op_type)
+
+    new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
+    return relay.Function(
+        func.params,
+        func.body,
+        ret_type=func.ret_type,
+        type_params=func.type_params,
+        attrs=new_attrs,
+    )
 
 
-def build_cutlass_kernels_vm(
-    vm_exec,
-    sm,
-    tmp_dir="./tmp",
-    lib_path="compile.so",
-    vmcode_path="vmcode.ro",
-    threads=-1,
-    use_fast_math=False,
-):
-    """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run.
+@register_func("relay.ext.cutlass.compile_for_cutlass")
+def compile_for_cutlass(mod, cutlass_target):
+    """Given an IRModule with at least one Compiler='cutlass' Relay function, return a
+    LibraryModule with all such functions compiled into their PackedFunc-compatible form.
+     - First runs CUTLASS tuning to decide on the best kernels, which itself requires the
+       repeated compilation and execution of CUDA code using nvcc. The results of this
+       are captured as annotations on each relevant function. Kernel performance is cached
+       across all functions.
+     - Then generates a single CSourceModule containing C code implementing all the
+       Compiler='cutlass' Relay functions, accounting for the tuning done above.
+     - Then compiles that CSourceModule with the appropriate nvcc arguments to yield
+       a static .o library. An export_library step will be required on the final runtime
+       module to link that library into the overall .so library.
+     See CompileForCutlass in src/relay/backend/contrib/cutlass/codegen.cc for where this
+     helper function is used to implement the RelayToTIR pass hook for CUTLASS."""
+
+    # Recover options from the current 'cutlass' Target
+    assert cutlass_target.kind.name == "cutlass"
+    tuning_config = {
+        key: cutlass_target.attrs.get(key)
+        for key in [
+            "sm",
+            "use_3xtf32",
+            "split_k_slices",
+            "profile_all_alignments",
+            "find_first_valid",
+            "use_multiprocessing",
+        ]
+    }
+    compile_config = {
+        key: cutlass_target.attrs.get(key) for key in ["sm", "threads", "use_fast_math"]
+    }
+    tmp_dir = cutlass_target.attrs.get("tmp_dir")
+
+    # Tune
+    logger.info("Tuning for CUTLASS")
+    mod, _ = tune_cutlass_kernels(mod, tmp_dir=tmp_dir, **tuning_config)
+
+    # Compile
+    logger.info("Creating CSource module for CUTLASS")
+    create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
+    c_module = create_c_source_module(mod)
+    function_names = c_module.get_function("get_func_names")()
+    compile_options = _get_cutlass_compile_options(**compile_config)
+    lib_path = os.path.join(tmp_dir, "cutlass.o")
+    logger.info("Compiling generated CUTLASS code")
+    c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)
+
+    # Recover static library
+    logger.info("Loading compiled CUTLASS code")
+    final_mod = tvm.runtime.load_static_library(lib_path, function_names)
+
+    logger.info("Done with CUTLASS compilation")
+    return final_mod
+
+
+def finalize_modules(lib, lib_path, tmp_dir):
+    """Returns lib with any C source, LLVM and static library modules compiled and linked in ready
+    for use by the graph or AOT executors. This method is not specific to CUTLASS, however it does
+    assume nvcc will be used for final compilation and linking. It is provided here for
+    convenience.
 
     Parameters
     ----------
-    vm_exec : vm.Executable
-        The output from relay.vm.compile containing compiled host code and non-cutlass kernels.
+    lib : runtime.Module
+        The output from relay.build.
 
-    sm : int
-        An integer specifying the compute capability. For example, 75 for Turing and
-        80 or 86 for Ampere.
+    lib_path : string
+        Name for temporary library .so file.
 
-    tmp_dir : string, optional
-        A temporary directory where intermediate compiled artifacts will be stored.
+    tmp_dir : string
+        Working temporary directory.
+
+    Returns
+    -------
+    updated_lib : runtime.Module
+        The given lib with any final compilation and linking steps completed.
+
+    """
+    lib_path = os.path.join(tmp_dir, lib_path)
+    lib.export_library(lib_path, workspace_dir=tmp_dir, cc="nvcc")
+    return runtime.load_module(lib_path)
 
-    lib_path : string, optional
-        The path to a shared library which will be generated as the result of the build process.
 
-    vmcode_path : string, optional
-        The path where the VM bytecode will be serialized to.
+def finalize_modules_vm(vm_exec, lib_path, tmp_dir):

Review Comment:
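   I'm using `vmcode` (the VM bytecode serialized to `vmcode_path`) in my CUTLASS BYOC benchmark, so the new `finalize_modules_vm` signature, which drops `vmcode_path`, would break it. Please keep that output: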
   
   https://github.com/masahi/tvm-cutlass-eval/blob/master/yolo5/run.py#L133-L136



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to