comaniac commented on a change in pull request #9737:
URL: https://github.com/apache/tvm/pull/9737#discussion_r769005791
##
File path: python/tvm/contrib/cutlass/gen_conv2d.py
##
@@ -121,27 +131,70 @@ def get_default(self, out_dtype):
data_type = gemm_profile_result["data_type"]
return create_conv2d_operator([tile_description], data_type,
[alignment])[0]
+def check_align(self, op_name, C, K):
+"""Filter out kernels that cannot be supported."""
+aligns = re.findall(r"align[1|2|4|8]", op_name)
+assert len(aligns) == 1
+align = int(aligns[0][-1])
+return all([dim % align == 0 for dim in [C, K]])
+
def profile(
-self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
+self,
+d_shape,
+w_shape,
+padding,
+stride,
+dilation,
+out_dtype,
+profile_all=True,
+use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable
kernel is found.
If use_multiprocessing is True, compile all profiler executables in
parallel.
"""
-B, _, _, IC = d_shape
+N, H, W, IC = d_shape
OC, R, S, _ = w_shape
-_, P, Q, _ = out_shape
+workload = (
+N,
+H,
+W,
+IC,
+OC,
+R,
+S,
+padding[0],
+padding[1],
+stride[0],
+stride[1],
+dilation[0],
+dilation[1],
+)
-M = B * P * Q
-N = OC
-K = R * S * IC
+if workload in self.cache:
+return self.cache[workload]
-gemm_profile_result = self.gemm_profiler.profile(
-M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
-)
+ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
+ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))
-tile_description = gemm_profile_result["tile_description"]
-alignment = gemm_profile_result["alignment"]
-data_type = gemm_profile_result["data_type"]
+for op in ops:
+op["runtime"] = -1
-return create_conv2d_operator([tile_description], data_type, [alignment])[0]
+if profile_all:
+self.engine.compile_all(ops, use_multiprocessing)
+
+args = (
+"--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d "
+"--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
+) % workload
+
+for op in ops:
+out = self.engine.evaluate(op, args.split(" "))
+op["runtime"] = out
+if out > 0 and profile_all is False:
Review comment:
nit
```suggestion
if out > 0 and not profile_all:
```
##
File path: python/tvm/contrib/cutlass/conv2d_profiler.py
##
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-outside-toplevel, invalid-name
+"""Instantiate a C++ source for profiling CUTLASS kernels."""
+
+
+class Conv2dProfilerEmitter(object):
Review comment:
I raised this topic before in the GEMM profiler PR, but I agreed with
@masahi that there does not seem to be much to share, and CUTLASS basically
only supports GEMM and Conv2D. Accordingly, it might be a bit overkill to have
a common base class, at least for now.
##
File path: python/tvm/contrib/cutlass/gen_conv2d.py
##
@@ -121,27 +131,70 @@ def get_default(self, out_dtype):
data_type = gemm_profile_result["data_type"]
return create_conv2d_operator([tile_description], data_type,
[alignment])[0]
+def check_align(self, op_name, C, K):
+"""Filter out kernels that cannot be supported."""
+aligns = re.findall(r"align[1|2|4|8]", op_name)
+assert len(aligns) == 1
+align = int(aligns[0][-1])
+return all([dim % align == 0 for dim in [C, K]])
+
def profile(
-