Lunderberg opened a new pull request, #17080: URL: https://github.com/apache/tvm/pull/17080
This is a follow-up commit to https://github.com/apache/tvm/pull/16637, which updated `relax.transform.FuseOps` to provide additional parameters defining the symbolic variables required by the fused functions. While this ensures that `relax.transform.FuseOps` produces well-formed Relax functions, these additional arguments can break some kernel implementations.

This commit implements a new transform, `RemoveSymbolicExpressionsInSubroutine`, to resolve this issue. The transform identifies function arguments whose sole purpose is to compute a symbolic expression, when that symbolic expression could instead be inferred from tensor shapes.

For example, consider the following Relax function:

```python
from tvm.script import relax as R, tir as T

@R.function
def func(
    data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):
    batch_size = T.int64()
    seq_len = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
    return output
```

The `data` tensor may be used to infer `hidden_size`, but cannot be used to infer `batch_size` or `seq_len`. The `R.Shape` parameter exists solely to define `batch_size` and `seq_len`, since all symbolic variables must be defined. However, neither `batch_size` nor `seq_len` is ever used outside of the expression `batch_size * seq_len`, and the value of `batch_size * seq_len` could be inferred from the shape of the `data` tensor.

This new transform identifies cases where an argument is otherwise unnecessary, and replaces the symbolic expression with a new symbolic variable. This leaves `dummy_arg: R.Shape` entirely unused, so a later use of `relax.transform.RemoveUnusedParameters()` can remove the parameter altogether.

```python
@R.function
def func(
    data: R.Tensor(["data_dim0", "hidden_size"]),
    weights: R.Tensor(["hidden_size", "intermediate_size"]),
    dummy_arg: R.Shape(["batch_size", "seq_len"]),
):
    data_dim0 = T.int64()
    intermediate_size = T.int64()
    hidden_size = T.int64()

    output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
    return output
```
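
To illustrate the rule described above, here is a toy sketch in plain Python. It is not the TVM implementation and does not use any TVM API. It models each tensor dimension as a string and assumes that identical expressions are spelled identically. The helper names `free_vars`, `replace_inferable_expressions`, and `unused_shape_params` are hypothetical, as is the `<param>_dim<i>` naming scheme, which is chosen only to mirror `data_dim0` in the example.

```python
import re

# Toy model only: dimensions are plain strings, not TVM PrimExpr objects.
IDENT = re.compile(r"[A-Za-z_]\w*")


def free_vars(expr):
    """Return the set of identifier names appearing in a dimension string."""
    return set(IDENT.findall(expr))


def replace_inferable_expressions(tensor_params, body_exprs):
    """Rename composite tensor dimensions whose variables occur nowhere else.

    tensor_params: dict mapping tensor parameter name -> list of dimension strings
    body_exprs:    dimension strings used in the function body or return type
    Returns (new_tensor_params, new_body_exprs, renames).
    """
    all_exprs = [d for dims in tensor_params.values() for d in dims]
    all_exprs += list(body_exprs)

    renames = {}
    for name, dims in tensor_params.items():
        for i, dim in enumerate(dims):
            variables = free_vars(dim)
            if IDENT.fullmatch(dim) or not variables or dim in renames:
                continue  # bare variable, constant, or already handled
            # Safe only if no *other* expression mentions these variables.
            used_elsewhere = any(
                e != dim and free_vars(e) & variables for e in all_exprs
            )
            if not used_elsewhere:
                renames[dim] = f"{name}_dim{i}"

    new_params = {
        name: [renames.get(d, d) for d in dims]
        for name, dims in tensor_params.items()
    }
    new_body = [renames.get(e, e) for e in body_exprs]
    return new_params, new_body, renames


def unused_shape_params(shape_params, tensor_params, body_exprs):
    """List shape parameters none of whose variables are still referenced."""
    used = set()
    for dims in tensor_params.values():
        for d in dims:
            used |= free_vars(d)
    for e in body_exprs:
        used |= free_vars(e)
    return [name for name, dims in shape_params.items() if not used & set(dims)]


# The example from this PR description, restated in the toy model.
tensor_params = {
    "data": ["batch_size * seq_len", "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
body_exprs = ["batch_size * seq_len", "intermediate_size"]
shape_params = {"dummy_arg": ["batch_size", "seq_len"]}

new_params, new_body, renames = replace_inferable_expressions(tensor_params, body_exprs)
removable = unused_shape_params(shape_params, new_params, new_body)
```

In this sketch, `batch_size * seq_len` is renamed to `data_dim0`, after which `dummy_arg` no longer defines any variable that is still referenced. The check is deliberately conservative: if either variable appeared in any other expression, such as a bare `seq_len` dimension, the expression would be left unchanged.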