mbs-octoml commented on code in PR #11631:
URL: https://github.com/apache/tvm/pull/11631#discussion_r894903266
##########
python/tvm/contrib/cutlass/build.py:
##########
@@ -346,187 +357,271 @@ def tune_cutlass_kernels(
     for var in mod.get_global_vars():
         fun_name = var.name_hint
         func = mod[fun_name]
-        annotator = OpAnnotator()
         if "cutlass" in fun_name:
             num_cutlass_partition += 1
-            annotator.visit(func)
-            out_shape = annotator.signature["ret_shape"]
-            out_dtype = annotator.signature["ret_dtype"]
-            op_type = annotator.signature["op_type"]
-
-            new_attrs = {"op_type": op_type}
-            new_attrs.update(annotator.signature)
-            new_attrs.update(func.attrs)
-            arg0_shape = new_attrs["arg0_shape"]
-            arg1_shape = new_attrs["arg1_shape"]
-            arg0_dtype = new_attrs["arg0_dtype"]
-            arg1_dtype = new_attrs["arg1_dtype"]
-
-            if "conv2d" in op_type:
-                new_attrs["padding"] = annotator.op_attrs.padding
-                new_attrs["strides"] = annotator.op_attrs.strides
-                new_attrs["dilation"] = annotator.op_attrs.dilation
-
-                if "conv2d_transpose" in op_type:
-                    d_shape = out_shape
-                    w_shape = arg1_shape
-                elif "conv2d_backward_weight" in op_type:
-                    d_shape = arg1_shape
-                    w_shape = out_shape
-                else:
-                    d_shape = arg0_shape
-                    w_shape = arg1_shape
-
-                new_attrs.update(
-                    handle_conv2d(
-                        conv2d_profiler,
-                        op_type,
-                        d_shape,
-                        w_shape,
-                        annotator.op_attrs.padding,
-                        annotator.op_attrs.strides,
-                        annotator.op_attrs.dilation,
-                        out_dtype,
-                        arg0_dtype,
-                        arg1_dtype,
-                        use_3xtf32,
-                        split_k_slices,
-                        profile_all_alignments,
-                        find_first_valid,
-                        use_multiprocessing,
-                    )
-                )
-            elif "batch_matmul" in op_type:
-                new_attrs.update(
-                    handle_batch_matmul(
-                        gemm_profiler,
-                        op_type,
-                        arg0_shape,
-                        arg1_shape,
-                        out_dtype,
-                        arg0_dtype,
-                        arg1_dtype,
-                        use_3xtf32,
-                        find_first_valid,
-                        use_multiprocessing,
-                    )
-                )
-            elif "dense" in op_type:
-                new_attrs.update(
-                    handle_dense(
-                        gemm_profiler,
-                        op_type,
-                        arg0_shape,
-                        arg1_shape,
-                        out_dtype,
-                        arg0_dtype,
-                        arg1_dtype,
-                        use_3xtf32,
-                        find_first_valid,
-                        use_multiprocessing,
-                    )
-                )
-            else:
-                raise ValueError("%s unsupported composite" % op_type)
-
-            new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
-            new_func = relay.Function(
-                func.params,
-                func.body,
-                ret_type=func.ret_type,
-                type_params=func.type_params,
-                attrs=new_attrs,
+            new_func = tune_cutlass_function(
+                func,
+                use_3xtf32,
+                split_k_slices,
+                profile_all_alignments,
+                find_first_valid,
+                use_multiprocessing,
+                gemm_profiler,
+                conv2d_profiler,
             )
             mod.update_func(var, new_func)
     return mod, num_cutlass_partition
 
 
-def build_cutlass_kernels(
-    lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False
+def tune_cutlass_function(
+    func,
+    use_3xtf32,
+    split_k_slices,
+    profile_all_alignments,
+    find_first_valid,
+    use_multiprocessing,
+    gemm_profiler,
+    conv2d_profiler,
 ):
-    """Compile CUTLASS kernels in lib and return the runtime module ready to run.
+    """Given a function intended to be offloaded to CUTLASS, profile each workload to select which
+    kernels to emit.
 
     Parameters
     ----------
-    lib : GraphExecutorFactoryModule
-        The output from relay.build containing compiled host code and non-cutlass kernels.
+    func : relay.Function
+        The Relay Function to tune for.
 
-    sm : int
-        An integer specifying the compute capability. For example, 75 for Turing and
-        80 or 86 for Ampere.
+    use_3xtf32 : bool
+        Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
+        fp32 inputs on tensorcore.
 
-    tmp_dir : string, optional
-        A temporary directory where intermediate compiled artifacts will be stored.
+    split_k_slices : list of int
+        Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in
+        parallel across split-K blocks, and a separate global reduction kernel is launched to
+        accumulate partial reductions. The profiler will pick the best split-K factor from the
+        given candidate list. Note that a larger split-K factor requires a larger workspace.
+        Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d
+        kinds, split_k_slices is ignored.
+
+    profile_all_alignments : bool
+        When True, profile all kernel variants with smaller alignments than the largest possible.
 
-    lib_path : string, optional
-        The path to a shared library which will be generated as the result of the build process.
+    find_first_valid : bool
+        Whether to profile all candidate kernels, or stop profiling after
+        the first applicable kernel is found.
 
-    threads : int, optional
-        The number of threads to use for compiling generated kernels. Only available for
-        CUDA 11.2 or later. Use all physical cores by default.
+    use_multiprocessing : bool
+        Whether or not to compile profiler executables for different kernels in parallel.
+
+    gemm_profiler : CutlassGemmProfiler
+        Profiler for dense operators. May cache results between tuned functions.
 
-    use_fast_math : bool, optional
-        Whether or not to use faster but less accurate math intrinsics.
+    conv2d_profiler : CutlassConv2DProfiler
+        Profiler for conv2d operators. May cache results between tuned functions.
 
     Returns
     -------
-    updated_lib : runtime.Module
-        The updated module with compiled cutlass kernels.
+    annot_func : Function
+        The input function with attributes capturing the best CUTLASS kernel found by tuning.
     """
-    kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
-    lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
-    return runtime.load_module(lib_path)
+    annotator = OpAnnotator()
+    annotator.visit(func)
+    out_shape = annotator.signature["ret_shape"]
+    out_dtype = annotator.signature["ret_dtype"]
+    op_type = annotator.signature["op_type"]
+
+    new_attrs = {"op_type": op_type}
+    new_attrs.update(annotator.signature)
+    new_attrs.update(func.attrs)
+    arg0_shape = new_attrs["arg0_shape"]
+    arg1_shape = new_attrs["arg1_shape"]
+    arg0_dtype = new_attrs["arg0_dtype"]
+    arg1_dtype = new_attrs["arg1_dtype"]
+
+    if "conv2d" in op_type:
+        new_attrs["padding"] = annotator.op_attrs.padding
+        new_attrs["strides"] = annotator.op_attrs.strides
+        new_attrs["dilation"] = annotator.op_attrs.dilation
+
+        if "conv2d_transpose" in op_type:
+            d_shape = out_shape
+            w_shape = arg1_shape
+        elif "conv2d_backward_weight" in op_type:
+            d_shape = arg1_shape
+            w_shape = out_shape
+        else:
+            d_shape = arg0_shape
+            w_shape = arg1_shape
+
+        new_attrs.update(
+            handle_conv2d(
+                conv2d_profiler,
+                op_type,
+                d_shape,
+                w_shape,
+                annotator.op_attrs.padding,
+                annotator.op_attrs.strides,
+                annotator.op_attrs.dilation,
+                out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
+                split_k_slices,
+                profile_all_alignments,
+                find_first_valid,
+                use_multiprocessing,
+            )
+        )
+    elif "batch_matmul" in op_type:
+        new_attrs.update(
+            handle_batch_matmul(
+                gemm_profiler,
+                op_type,
+                arg0_shape,
+                arg1_shape,
+                out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
+                find_first_valid,
+                use_multiprocessing,
+            )
+        )
+    elif "dense" in op_type:
+        new_attrs.update(
+            handle_dense(
+                gemm_profiler,
+                op_type,
+                arg0_shape,
+                arg1_shape,
+                out_dtype,
+                arg0_dtype,
+                arg1_dtype,
+                use_3xtf32,
+                find_first_valid,
+                use_multiprocessing,
+            )
+        )
+    else:
+        raise ValueError("%s unsupported composite" % op_type)
+
+    new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
+    return relay.Function(
+        func.params,
+        func.body,
+        ret_type=func.ret_type,
+        type_params=func.type_params,
+        attrs=new_attrs,
+    )
 
 
-def build_cutlass_kernels_vm(
-    vm_exec,
-    sm,
-    tmp_dir="./tmp",
-    lib_path="compile.so",
-    vmcode_path="vmcode.ro",
-    threads=-1,
-    use_fast_math=False,
-):
-    """Compile CUTLASS kernels in vm_exec and return a VM executable ready to run.
+@register_func("relay.ext.cutlass.compile_for_cutlass")
+def compile_for_cutlass(mod, cutlass_target):
+    """Given an IRModule with at least one Compiler='cutlass' Relay function, return a
+    LibraryModule with all such functions compiled into their PackedFunc-compatible form.
+    - First runs CUTLASS tuning to decide on the best kernels, which itself requires the
+      repeated compilation and execution of CUDA code using nvcc. The results of this
+      are captured as annotations on each relevant function. Kernel performance is cached
+      across all functions.
+    - Then generates a single CSourceModule containing C code implementing all the
+      Compiler='cutlass' Relay functions, accounting for the tuning done above.
+    - Then compiles that CSourceModule with the appropriate nvcc arguments to yield
+      a static .o library. An export_library step will be required on the final runtime
+      module to link that library into the overall .so library.
+    See CompileForCutlass in src/relay/backend/contrib/cutlass/codegen.cc for where this
+    helper function is used to implement the RelayToTIR pass hook for CUTLASS."""
+
+    # Recover options from the current 'cutlass' Target
+    assert cutlass_target.kind.name == "cutlass"
+    tuning_config = {
+        key: cutlass_target.attrs.get(key)
+        for key in [
+            "sm",
+            "use_3xtf32",
+            "split_k_slices",
+            "profile_all_alignments",
+            "find_first_valid",
+            "use_multiprocessing",
+        ]
+    }
+    compile_config = {
+        key: cutlass_target.attrs.get(key) for key in ["sm", "threads", "use_fast_math"]
+    }
+    tmp_dir = cutlass_target.attrs.get("tmp_dir")
+
+    # Tune
+    logger.info("Tuning for CUTLASS")
+    mod, _ = tune_cutlass_kernels(mod, tmp_dir=tmp_dir, **tuning_config)
+
+    # Compile
+    logger.info("Creating CSource module for CUTLASS")
+    create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
+    c_module = create_c_source_module(mod)
+    function_names = c_module.get_function("get_func_names")()
+    compile_options = _get_cutlass_compile_options(**compile_config)
+    lib_path = os.path.join(tmp_dir, "cutlass.o")
+    logger.info("Compiling generated CUTLASS code")
+    c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)
+
+    # Recover static library
+    logger.info("Loading compiled CUTLASS code")
+    final_mod = tvm.runtime.load_static_library(lib_path, function_names)
+
+    logger.info("Done with CUTLASS compilation")
+    return final_mod
+
+
 def finalize_modules(lib, lib_path, tmp_dir):

Review Comment:
   done
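
A minimal NumPy sketch of the split-K scheme that the `split_k_slices` docstring describes: the K-loop is partitioned into independent partial GEMMs, and a separate reduction accumulates the partials. This is illustration only, not CUTLASS code; CUTLASS performs both steps on-device.

```python
import numpy as np

def split_k_gemm(a, b, split_k):
    """Compute a @ b by splitting the K axis into `split_k` slices."""
    _, k = a.shape
    assert k % split_k == 0, "sketch assumes K divides evenly"
    slice_len = k // split_k
    # Each slice is an independent partial GEMM (computed in parallel on device).
    partials = [
        a[:, i * slice_len : (i + 1) * slice_len] @ b[i * slice_len : (i + 1) * slice_len, :]
        for i in range(split_k)
    ]
    # A separate global reduction accumulates the partial results,
    # which is why a larger split-K factor needs a larger workspace.
    return np.sum(partials, axis=0)

a = np.random.rand(8, 32).astype("float32")
b = np.random.rand(32, 4).astype("float32")
assert np.allclose(split_k_gemm(a, b, split_k=4), a @ b, atol=1e-5)
```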
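To see where `compile_for_cutlass` and `finalize_modules` sit in the overall flow, here is a rough end-to-end usage sketch. The general shape follows this PR, but the exact Target attribute names, the target-list form of `relay.build`, and the `finalize_modules` arguments are assumptions read off the diff above, not a stable API.

```python
# Hedged sketch of the CUTLASS BYOC flow; details may differ from the PR's tests.
import tvm
from tvm import relay
from tvm.contrib.cutlass import partition_for_cutlass, finalize_modules

def build_with_cutlass(mod, params):
    # Wrap supported patterns into Compiler='cutlass' functions.
    mod = partition_for_cutlass(mod)
    # The 'cutlass' target carries the tuning/compile options that
    # compile_for_cutlass() recovers via cutlass_target.attrs.get(...).
    host = tvm.target.Target("llvm")
    cuda = tvm.target.Target("cuda", host=host)
    cutlass = tvm.target.Target(
        {"kind": "cutlass", "sm": 80, "use_3xtf32": True, "find_first_valid": False},
        host=host,
    )
    with tvm.transform.PassContext(opt_level=3):
        # The RelayToTIR hook invokes compile_for_cutlass for the 'cutlass' target.
        lib = relay.build(mod, target=[cuda, cutlass], params=params)
    # Link the static cutlass.o produced above into the final .so.
    return finalize_modules(lib, "compile.so", "./tmp")
```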