comaniac commented on a change in pull request #9737:
URL: https://github.com/apache/tvm/pull/9737#discussion_r769071320



##########
File path: python/tvm/contrib/cutlass/gen_conv2d.py
##########
@@ -121,27 +131,67 @@ def get_default(self, out_dtype):
         data_type = gemm_profile_result["data_type"]
         return create_conv2d_operator([tile_description], data_type, 
[alignment])[0]
 
+    def check_align(self, op_name, C, K):
+        """Filter out kernels that cannot be supported."""
+        aligns = re.findall(r"align[1|2|4|8]", op_name)
+        assert len(aligns) == 1
+        align = int(aligns[0][-1])
+        return all([dim % align == 0 for dim in [C, K]])
+
     def profile(
-        self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, 
use_multiprocessing=False
+        self,
+        d_shape,
+        w_shape,
+        padding,
+        stride,
+        dilation,
+        out_dtype,
+        profile_all=True,
+        use_multiprocessing=False,
     ):
         """Profile and select the best kernel from candidate kernels.
         If profile_all is False, return immediately after the first applicable 
kernel is found.
         If use_multiprocessing is True, compile all profiler executables in 
parallel.
         """
-        B, _, _, IC = d_shape
+        N, H, W, IC = d_shape
         OC, R, S, _ = w_shape
-        _, P, Q, _ = out_shape
+        workload = (
+            N,
+            H,
+            W,
+            IC,
+            OC,
+            R,
+            S,
+            padding[0],
+            padding[1],
+            stride[0],
+            stride[1],
+            dilation[0],
+            dilation[1],
+        )
 
-        M = B * P * Q
-        N = OC
-        K = R * S * IC
+        if workload in self.cache:
+            return self.cache[workload]
 
-        gemm_profile_result = self.gemm_profiler.profile(
-            M, N, K, out_dtype, profile_all=profile_all, 
use_multiprocessing=use_multiprocessing
-        )
+        ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, 
op_creator=create_conv2d_operator)
+        ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), 
ops))
 
-        tile_description = gemm_profile_result["tile_description"]
-        alignment = gemm_profile_result["alignment"]
-        data_type = gemm_profile_result["data_type"]
+        if profile_all:
+            self.engine.compile_all(ops, use_multiprocessing)
 
-        return create_conv2d_operator([tile_description], data_type, 
[alignment])[0]
+        args = (
+            "--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d 
--pad_w=%d "
+            "--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
+        ) % workload
+
+        for op in ops:
+            out = self.engine.evaluate(op, args.split(" "))
+            op["runtime"] = out
+            if out > 0 and not profile_all:

Review comment:
       IIUC, now you changed `evaluate` to return `float("inf")` when invalid. 
Then the first invalid kernel will be selected since `float("inf") > 0`, right?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to