Lunderberg commented on issue #17211: URL: https://github.com/apache/tvm/issues/17211#issuecomment-2256481944
Good catch, and I think this is arising from a number of different edge cases:

* The `StructInfo` for `R.call_tir` is always inferred from the `out_sinfo` argument, not from the TIR function's signature. This is for historical reasons, as TIR functions only recently started holding the annotations that would allow them to perform shape inference. As a result, no errors are seen during the initial well-formed check.
* The defaults of `T.Buffer` and `R.Tensor` are different. If unspecified, `T.Buffer` defaults to a `"float32"` datatype, whereas `R.Tensor` defaults to `DataType::Void`, which represents an unknown datatype that might be inferred later in compilation. There is no equivalent in TIR, which must have a known datatype for each buffer.
* There is no rule that would infer the unknown Relax datatype from the mandatory TIR datatype. As a result, the `out_sinfo` remains the incomplete `R.Tensor(shape)`, rather than `R.Tensor(shape, dtype="float32")`.
* The error is raised during `CallTIRRewrite`, which rewrites low-level calls from having an implied allocation for the output to having an explicit argument for the output. Here, this rewrites `R.call_tir(cls.relu, [x], out_sinfo=R.Tensor([1,512,64,64]))` into `cls.relu(x, output_allocation)`, where `output_allocation` has shape `R.Tensor([1,512,64,64])`. This is the first point at which the TIR function's signature is actually inspected.
* Currently, when checking the constraints required by a subroutine, the check must either pass or fail outright. There is no mechanism for hoisting the subroutine's constraints into the calling scope. Since a "tensor of arbitrary element type" is not a valid argument for a "tensor with float32 element type", the check fails.

I think there are a number of improvements that could be made to close each of these loopholes:

1. Improved well-formed checking. If `out_sinfo` is explicitly stated in `R.call_tir`, then `IsBaseOf(inferred_sinfo, out_sinfo)` must return true.
2. Inference of the dtype of `out_sinfo` in `R.call_tir`. If `out_sinfo` is a tensor, or a tuple of tensors, and one of those tensors has `DataType::Void()`, normalize the `out_sinfo` argument to include the datatype from the PrimFunc.
3. Improved struct inference for `R.call_tir`. Now that PrimFuncs have a known shape for each argument, the inferred output of `R.call_tir` could be improved. For backwards compatibility, an explicit `out_sinfo` argument would still take precedence. However, if `out_sinfo` were omitted (which currently causes an immediate error), the output struct info would instead be inferred by assuming that the last `len(params) - len(args)` parameters are output parameters.
4. Improved normalization in the block builder. If an operator has restrictions on an argument, normalization could expose those constraints at the Relax level, rather than only marking the call as pass/fail. For example, normalization of an operator whose argument must be `DataType::Float(32)`, but which received `DataType::Void()`, could produce a binding of `new_arg = R.match_cast(arg, R.Tensor(arg.struct_info.shape, "float32"))`, then use `new_arg` in its call.

I think all of these would be useful changes to make, but some would have wider impacts than others. The well-formed checks could be added with the smallest risk of breakage, but would also place the greatest load on new developers. Improved normalization would provide the greatest ease of use, but would require the most widespread changes.

@tqchen, since some of these would be much more involved changes, do you have preferences/thoughts on them?
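To make the dtype-inference rule in proposal (2) concrete, here is a minimal, framework-free Python sketch. The names `VOID` and `normalize_out_dtypes` are hypothetical stand-ins for illustration only; the real change would operate on `relax::TensorStructInfo` and the PrimFunc's buffer map rather than plain strings.

```python
VOID = ""  # stand-in for DataType::Void, Relax's "unknown dtype" marker


def normalize_out_dtypes(out_dtypes, prim_func_dtypes):
    """Fill in any unknown (void) output dtypes from the TIR signature.

    out_dtypes:       dtypes taken from ``out_sinfo`` ("" means unknown)
    prim_func_dtypes: dtypes of the PrimFunc's output buffers (always known)
    """
    if len(out_dtypes) != len(prim_func_dtypes):
        raise ValueError("out_sinfo does not match the PrimFunc's outputs")

    normalized = []
    for relax_dtype, tir_dtype in zip(out_dtypes, prim_func_dtypes):
        if relax_dtype == VOID:
            # Unknown on the Relax side: take the mandatory TIR dtype.
            normalized.append(tir_dtype)
        elif relax_dtype != tir_dtype:
            # Both known but inconsistent: should be a well-formed error.
            raise ValueError(
                f"out_sinfo dtype {relax_dtype!r} conflicts with "
                f"TIR dtype {tir_dtype!r}"
            )
        else:
            normalized.append(relax_dtype)
    return normalized
```

Under this rule, the failing example's `out_sinfo=R.Tensor([1,512,64,64])` (void dtype) would be normalized against the PrimFunc's `"float32"` buffer, so `normalize_out_dtypes([""], ["float32"])` yields `["float32"]` and `CallTIRRewrite` never sees an unknown dtype.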
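Similarly, the trailing-parameter convention in proposal (3) can be sketched as a small helper. `infer_out_params` and the `(name, shape, dtype)` tuples are illustrative, not TVM API; the real implementation would read the PrimFunc's `params` and `buffer_map`.

```python
def infer_out_params(prim_func_params, num_args):
    """Return the parameters assumed to be outputs of a destination-passing
    PrimFunc, given how many input arguments the R.call_tir site supplies.

    Assumes the last ``len(prim_func_params) - num_args`` parameters are
    the outputs, as in proposal (3).
    """
    num_outputs = len(prim_func_params) - num_args
    if num_outputs <= 0:
        raise ValueError("call site supplies too many arguments")
    return prim_func_params[-num_outputs:]


# Example: a relu-style PrimFunc with one input buffer and one output
# buffer, called with a single argument, yields one inferred output.
params = [
    ("x", [1, 512, 64, 64], "float32"),
    ("out", [1, 512, 64, 64], "float32"),
]
outputs = infer_out_params(params, num_args=1)
```

With this rule, omitting `out_sinfo` for `cls.relu` would infer `R.Tensor([1,512,64,64], "float32")` directly from the TIR signature, sidestepping the void-dtype mismatch entirely.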