Ghosts381937 commented on PR #17765:
URL: https://github.com/apache/tvm/pull/17765#issuecomment-2868940770
Sorry for the late reply.
After taking a deep dive into the code infrastructure, I have a few
questions about the following code. I'm not sure whether this is the
correct simplification, but it looks like a bug in the `x * c` pattern of
floormod.
```python
from tvm import tir
from tvm.arith import Analyzer
from tvm.tir.op import floormod
# Define symbolic variable
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
# Create expressions with common factor
expr1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 64)
divisor1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 32)
# Create Analyzer
analyzer = Analyzer()
# Same logic as CanProveDivisible().
print(analyzer.can_prove_equal(expr1, divisor1) or
      analyzer.can_prove(floormod(expr1, divisor1) == 0))
# Expected: True; actual: False
# The root cause is the following simplification:
print(analyzer.rewrite_simplify(floormod(expr1, divisor1)))
# Expected output: 0
# Actual output:
# T.int64(64) * (past_decoder_sequence_length + T.int64(1)) %
# (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))
```
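The expected answer of 0 can be checked numerically without TVM. The sketch below relies on Python's integer `%` being floor modulo, the same semantics as `floormod` for a nonzero divisor, and loops over sample values of `n`, which stands in for `past_decoder_sequence_length`. The helper names `floormod_int` and `check_common_factor` are hypothetical and exist only for this illustration.

```python
def floormod_int(a: int, b: int) -> int:
    # Python's % rounds the quotient toward negative infinity (floor modulo),
    # matching floormod semantics for a nonzero divisor.
    return a % b


def check_common_factor(c1: int, c2: int, values) -> bool:
    """Return True if floormod(c1 * (n + 1), c2 * (n + 1)) == 0 for every tested n."""
    for n in values:
        e = n + 1
        if e == 0:
            continue  # the divisor would be zero, where floormod is undefined
        if floormod_int(c1 * e, c2 * e) != 0:
            return False
    return True


# n plays the role of past_decoder_sequence_length (int64 in the TVM snippet).
print(check_common_factor(64, 32, range(-1000, 1000)))  # True: 32 divides 64
print(check_common_factor(32, 64, range(0, 1000)))      # False: (32 * e) % (64 * e) == 32 * e for e > 0
```

Because the shared factor `(n + 1)` cancels whenever it is nonzero, the result depends only on whether `c2` divides `c1`, which is why the simplifier could in principle fold this case to 0.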
Detailed rewrite_simplify trace (up to 2 iterations):
```
# iter 0: (past_decoder_sequence_length * T.int64(64) + T.int64(64)) % (past_decoder_sequence_length * T.int64(32) + T.int64(32))
# iter 1: T.int64(64) * (past_decoder_sequence_length + T.int64(1)) % (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))
```
This is caused by the rule at
https://github.com/apache/tvm/blob/f4704f2288b14c05e99de9f74bcb9b530c2dd7e6/src/arith/rewrite_simplify.cc#L449
which leads to the following case:
```
# iter 0 -> iter 1:
# x: T.int64(64)
# y: past_decoder_sequence_length
# => x * (y + 1) = T.int64(64) * (past_decoder_sequence_length + T.int64(1))
# => c1 * (y + 1) = T.int64(64) * (past_decoder_sequence_length + T.int64(1))
```
After this factoring, the floormod `x * c` pattern no longer matches, so
the expression is never folded to 0.
Perhaps we should consider supporting the `x * c + c => (x + 1) * c`
pattern here as well, like the recent commit to the rewrite rules for add?
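To make the suggestion concrete, here is a toy sketch in plain Python (not TVM's C++ rewrite-rule machinery) of the two steps involved: factoring `x * c + c` into `(x + 1) * c`, then folding `floormod(c1 * y, c2 * y)` to 0 when `c2` divides `c1`. The tuple-based expression encoding and the functions `factor_add`, `split_const` and `simplify_floormod` are hypothetical, for illustration only. A real rule would likely also need to account for the shared factor `y` possibly being zero.

```python
from typing import Union

# Hypothetical toy expression encoding:
#   int                -> constant
#   str                -> variable
#   ("add", a, b)      -> a + b
#   ("mul", a, b)      -> a * b
#   ("floormod", a, b) -> floormod(a, b)
Expr = Union[int, str, tuple]


def factor_add(e: Expr) -> Expr:
    """Rewrite x * c + c => (x + 1) * c, where c is an integer constant."""
    if isinstance(e, tuple) and e[0] == "add":
        lhs, rhs = e[1], e[2]
        if (isinstance(lhs, tuple) and lhs[0] == "mul"
                and isinstance(lhs[2], int) and lhs[2] == rhs):
            return ("mul", ("add", lhs[1], 1), rhs)
    return e


def split_const(e: Expr):
    """Split ("mul", y, c) into (c, y); otherwise treat e as 1 * e."""
    if isinstance(e, tuple) and e[0] == "mul" and isinstance(e[2], int):
        return e[2], e[1]
    return 1, e


def simplify_floormod(e: Expr) -> Expr:
    """Fold floormod(c1 * y, c2 * y) => 0 when c2 divides c1 (y assumed nonzero)."""
    if isinstance(e, tuple) and e[0] == "floormod":
        a, b = factor_add(e[1]), factor_add(e[2])
        c1, y1 = split_const(a)
        c2, y2 = split_const(b)
        if y1 == y2 and c2 != 0 and c1 % c2 == 0:
            return 0
        return ("floormod", a, b)
    return e


n = "past_decoder_sequence_length"
# Same shape as "iter 0" above: (n * 64 + 64) % (n * 32 + 32)
expr = ("floormod", ("add", ("mul", n, 64), 64), ("add", ("mul", n, 32), 32))
print(simplify_floormod(expr))  # 0
```

The design point is that the divisibility check runs on the already factored form, so it still fires after the add rule has rewritten `n * 64 + 64` into `(n + 1) * 64`.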