Lunderberg opened a new pull request, #17080:
URL: https://github.com/apache/tvm/pull/17080

   This is a follow-up commit to https://github.com/apache/tvm/pull/16637, which updated `relax.transform.FuseOps` to provide additional parameters defining symbolic variables required by the fused functions. While this ensures that `relax.transform.FuseOps` produces well-formed Relax functions, these additional arguments can break some kernel implementations.
   
   This commit implements a new transform, `RemoveSymbolicExpressionsInSubroutine`, to resolve this issue. The transform identifies function parameters whose sole purpose is to define the symbolic variables of a symbolic expression, in cases where the value of that expression could instead be inferred from tensor shapes.
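The detection rule can be sketched in plain Python over a toy shape representation. This is a hypothetical model, not TVM's IR and not the transform's actual implementation: a dimension is either a bare variable name (a `str`) or a product of variables (a `tuple` of names written in a fixed order), and `replaceable_expressions` is an illustrative name. A product qualifies when it appears in a tensor parameter's shape and none of its factors is used anywhere except inside that same product. The sample input mirrors a matmul whose `data` tensor has the shape `[batch_size * seq_len, hidden_size]`.

```python
from typing import Dict, List, Set, Tuple, Union

# Toy model (hypothetical, not TVM's IR): a dimension is either a bare
# variable name (str) or a product of variable names (tuple of str, always
# written in the same order so that equal products compare equal).
Dim = Union[str, Tuple[str, ...]]


def replaceable_expressions(
    tensor_params: Dict[str, List[Dim]], body_dims: List[Dim]
) -> List[Tuple[str, ...]]:
    """Products in a tensor parameter's shape whose factors are used nowhere else."""
    all_dims = [d for dims in tensor_params.values() for d in dims] + body_dims
    # Variables that appear on their own are needed individually.
    standalone: Set[str] = {d for d in all_dims if isinstance(d, str)}
    # For each variable, the distinct products it appears in.
    products_using: Dict[str, Set[Tuple[str, ...]]] = {}
    for d in all_dims:
        if isinstance(d, tuple):
            for var in d:
                products_using.setdefault(var, set()).add(d)

    result: List[Tuple[str, ...]] = []
    for dims in tensor_params.values():
        for d in dims:
            if not isinstance(d, tuple) or d in result:
                continue
            # Each factor must be used only inside this exact product.
            if all(v not in standalone and products_using[v] == {d} for v in d):
                result.append(d)
    return result


params = {
    "data": [("batch_size", "seq_len"), "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
# Dimensions referenced by the body and the return annotation.
body = [("batch_size", "seq_len"), "intermediate_size"]
print(replaceable_expressions(params, body))  # [('batch_size', 'seq_len')]
```

Shape parameters are deliberately left out of the inputs: in this model they only define variables rather than use them. If the body also used `batch_size` on its own, the product would no longer qualify.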
   
   For example, consider the following Relax function:
   
   ```python
   from tvm.script import relax as R
   from tvm.script import tir as T

   @R.function
   def func(
       data: R.Tensor(["batch_size * seq_len", "hidden_size"]),
       weights: R.Tensor(["hidden_size", "intermediate_size"]),
       dummy_arg: R.Shape(["batch_size", "seq_len"]),
   ) -> R.Tensor(["batch_size * seq_len", "intermediate_size"]):
       # Symbolic variables referenced in the body
       batch_size = T.int64()
       seq_len = T.int64()
       intermediate_size = T.int64()
       hidden_size = T.int64()

       output: R.Tensor([batch_size * seq_len, intermediate_size]) = R.matmul(data, weights)
       return output
   ```
   
   The `data` tensor can be used to infer `hidden_size`, but not `batch_size` or `seq_len`. The `R.Shape` parameter exists solely to define `batch_size` and `seq_len`, since all symbolic variables must be defined. However, neither `batch_size` nor `seq_len` is ever used outside of the expression `batch_size * seq_len`, and the value of `batch_size * seq_len` could be inferred from the shape of the `data` tensor.
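The distinction between a bare dimension and a compound one can be illustrated with the same kind of toy model (hypothetical, not TVM's data structures; `inferable_from_tensors` is an illustrative name). Only a dimension that is a bare variable lets that variable be read directly from a runtime shape, whereas a product such as `batch_size * seq_len` pins down the product but not its factors.

```python
from typing import Dict, List, Set, Tuple, Union

# Toy model (hypothetical): a dimension is a bare variable name (str) or a
# product of variable names (tuple of str).
Dim = Union[str, Tuple[str, ...]]


def inferable_from_tensors(tensor_params: Dict[str, List[Dim]]) -> Set[str]:
    """Variables that appear as a bare dimension of some tensor parameter.

    A bare dimension can be read directly from the runtime shape; a product
    such as batch_size * seq_len fixes only the product, not its factors.
    """
    inferable: Set[str] = set()
    for dims in tensor_params.values():
        for dim in dims:
            if isinstance(dim, str):
                inferable.add(dim)
    return inferable


tensors = {
    "data": [("batch_size", "seq_len"), "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
all_vars = {"batch_size", "seq_len", "hidden_size", "intermediate_size"}
known = inferable_from_tensors(tensors)
print(sorted(known))             # ['hidden_size', 'intermediate_size']
print(sorted(all_vars - known))  # ['batch_size', 'seq_len']
```

In this model the two leftover variables can only come from `dummy_arg`, which is why that parameter cannot simply be dropped from the original function.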
   
   This new transform identifies cases where a parameter is otherwise unnecessary, and replaces the symbolic expression with a new symbolic variable that is inferred from the tensor shape (`data_dim0` below). This leaves `dummy_arg: R.Shape` entirely unused, so a later call to `relax.transform.RemoveUnusedParameters()` can remove the parameter altogether.
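The rewrite and its effect on `dummy_arg` can be sketched in the same spirit. This is again a hypothetical toy model rather than the transform's real implementation. The helpers `replace_expression` and `unused_shape_params` are illustrative names. `unused_shape_params` only approximates what a pass like `RemoveUnusedParameters` would see, by checking whether any variable defined by a shape parameter is still referenced elsewhere.

```python
from typing import Dict, Iterable, List, Set, Tuple, Union

# Toy model (hypothetical): a dimension is a bare variable name (str) or a
# product of variable names (tuple of str).
Dim = Union[str, Tuple[str, ...]]


def replace_expression(dims: List[Dim], expr: Tuple[str, ...], new_var: str) -> List[Dim]:
    """Replace every occurrence of the product `expr` in a shape with `new_var`."""
    return [new_var if dim == expr else dim for dim in dims]


def referenced_vars(shapes: Iterable[List[Dim]]) -> Set[str]:
    """Collect every symbolic variable referenced by the given shapes."""
    found: Set[str] = set()
    for dims in shapes:
        for dim in dims:
            # Wrap bare names in a list so update() adds the name, not its characters.
            found.update([dim] if isinstance(dim, str) else dim)
    return found


def unused_shape_params(
    shape_params: Dict[str, List[str]],
    tensor_params: Dict[str, List[Dim]],
    output_dims: List[Dim],
) -> List[str]:
    """Shape parameters none of whose variables are referenced anywhere else."""
    refs = referenced_vars(list(tensor_params.values()) + [output_dims])
    return [name for name, dims in shape_params.items() if not refs & set(dims)]


expr = ("batch_size", "seq_len")
tensor_params: Dict[str, List[Dim]] = {
    "data": [expr, "hidden_size"],
    "weights": ["hidden_size", "intermediate_size"],
}
shape_params = {"dummy_arg": ["batch_size", "seq_len"]}
output_dims: List[Dim] = [expr, "intermediate_size"]

# Before the rewrite, the shape of `data` still refers to batch_size * seq_len.
before = unused_shape_params(shape_params, tensor_params, output_dims)

# Replace batch_size * seq_len with a fresh variable, data_dim0, standing for
# the first dimension of `data`.
tensor_params = {n: replace_expression(d, expr, "data_dim0") for n, d in tensor_params.items()}
output_dims = replace_expression(output_dims, expr, "data_dim0")
after = unused_shape_params(shape_params, tensor_params, output_dims)

print(before)  # []
print(after)   # ['dummy_arg']
```

The check mirrors the argument in the text. `dummy_arg` becomes removable only after the product is rewritten, because until then its variables are still referenced by the shape of `data`.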
   
   ```python
   from tvm.script import relax as R
   from tvm.script import tir as T

   @R.function
   def func(
       data: R.Tensor(["data_dim0", "hidden_size"]),
       weights: R.Tensor(["hidden_size", "intermediate_size"]),
       dummy_arg: R.Shape(["batch_size", "seq_len"]),
   ):
       # data_dim0 replaces batch_size * seq_len; dummy_arg is no longer used
       data_dim0 = T.int64()
       intermediate_size = T.int64()
       hidden_size = T.int64()

       output: R.Tensor([data_dim0, intermediate_size]) = R.matmul(data, weights)
       return output
   ```

