comaniac commented on a change in pull request #9737: URL: https://github.com/apache/tvm/pull/9737#discussion_r769071320
########## File path: python/tvm/contrib/cutlass/gen_conv2d.py ########## @@ -121,27 +131,67 @@ def get_default(self, out_dtype): data_type = gemm_profile_result["data_type"] return create_conv2d_operator([tile_description], data_type, [alignment])[0] + def check_align(self, op_name, C, K): + """Filter out kernels that cannot be supported.""" + aligns = re.findall(r"align[1|2|4|8]", op_name) + assert len(aligns) == 1 + align = int(aligns[0][-1]) + return all([dim % align == 0 for dim in [C, K]]) + def profile( - self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False + self, + d_shape, + w_shape, + padding, + stride, + dilation, + out_dtype, + profile_all=True, + use_multiprocessing=False, ): """Profile and select the best kernel from candidate kernels. If profile_all is False, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ - B, _, _, IC = d_shape + N, H, W, IC = d_shape OC, R, S, _ = w_shape - _, P, Q, _ = out_shape + workload = ( + N, + H, + W, + IC, + OC, + R, + S, + padding[0], + padding[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + ) - M = B * P * Q - N = OC - K = R * S * IC + if workload in self.cache: + return self.cache[workload] - gemm_profile_result = self.gemm_profiler.profile( - M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing - ) + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator) + ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) - tile_description = gemm_profile_result["tile_description"] - alignment = gemm_profile_result["alignment"] - data_type = gemm_profile_result["data_type"] + if profile_all: + self.engine.compile_all(ops, use_multiprocessing) - return create_conv2d_operator([tile_description], data_type, [alignment])[0] + args = ( + "--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d " + 
"--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d" + ) % workload + + for op in ops: + out = self.engine.evaluate(op, args.split(" ")) + op["runtime"] = out + if out > 0 and not profile_all: Review comment: IIUC, you have now changed `evaluate` to return `float("inf")` when a kernel is invalid. Then, when `profile_all` is False, the first invalid kernel will be selected because `float("inf") > 0`, right? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org