Ghosts381937 commented on PR #17765:
URL: https://github.com/apache/tvm/pull/17765#issuecomment-2868940770

   Sorry for the late reply.
   After a deep dive into the code, I have a few questions about the
   following example. I'm not sure whether the simplification below is
   intended, but it looks like a bug in the `x * c` pattern of floormod.
   
   ```python
   from tvm import tir
   from tvm.arith import Analyzer
   from tvm.tir.op import floormod
   
   # Define symbolic variable
   past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")
   
   # Create expressions with common factor
   expr1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 64)
   divisor1 = (past_decoder_sequence_length + 1) * tir.IntImm("int64", 32)
   # Create Analyzer
   analyzer = Analyzer()
   
   # The logic from CanProveDivisible().
   print(
       analyzer.can_prove_equal(expr1, divisor1)
       or analyzer.can_prove(floormod(expr1, divisor1) == 0)
   )
   # Expected: True, but actual: False
   
   # The main reason is the following simplification, which does not reduce to 0.
   print(analyzer.rewrite_simplify(floormod(expr1, divisor1)))
   # Expected output: 0
   # Actual output: 
   #     T.int64(64) * (past_decoder_sequence_length + T.int64(1)) % 
   #     (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))
   ```
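
As a sanity check on the "Expected output: 0" claim, independent of TVM, here is a small plain-Python sketch. It assumes Python's `%` on integers (a floored modulo) is an acceptable stand-in for `floormod`, and it skips the value that makes the divisor zero. The helper name `floormod_of_scaled` is made up for this example.

```python
def floormod_of_scaled(n: int) -> int:
    """Evaluate floormod((n + 1) * 64, (n + 1) * 32) with Python's floored %."""
    numerator = (n + 1) * 64
    divisor = (n + 1) * 32
    return numerator % divisor


# n = -1 makes the divisor zero, so it is skipped.
samples = [n for n in range(-100, 101) if n != -1]
results = {floormod_of_scaled(n) for n in samples}
print(results)  # {0}
```

Every sampled value gives 0, because (n + 1) * 64 is exactly twice (n + 1) * 32.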
   
   Detailed rewrite_simplify process (up to 2 iterations):
   ```
   #   iter 0: (past_decoder_sequence_length * T.int64(64) + T.int64(64)) %
   #           (past_decoder_sequence_length * T.int64(32) + T.int64(32))
   #   iter 1: T.int64(64) * (past_decoder_sequence_length + T.int64(1)) %
   #           (T.int64(32) * (past_decoder_sequence_length + T.int64(1)))
   ```
   This is caused by the rule at
   https://github.com/apache/tvm/blob/f4704f2288b14c05e99de9f74bcb9b530c2dd7e6/src/arith/rewrite_simplify.cc#L449
   which leads to the following match:
   ```
   #     iter 0 -> iter 1: 
   #         x: T.int64(64)
   #         y: past_decoder_sequence_length
   #     => x * (y + 1) = T.int64(64) * (past_decoder_sequence_length + T.int64(1))
   #     => c1 * (y + 1) = T.int64(64) * (past_decoder_sequence_length + T.int64(1))
   ```
   As noted above, I'm not sure this is a bug, but the constant now sits on
   the left of the product, so the result does not seem to match the `x * c`
   pattern of floormod. Perhaps we should consider supporting an
   `x * c + c => (x + 1) * c` pattern, like the recent commit did in the
   rewrite rules for add?
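
The suggestion can be illustrated with a toy rewriter in plain Python. This is only a sketch of the hypothesis in this comment: the tuple encoding and both helper functions (`factor_const_on_right`, `divisible_by_x_times_c`) are made up for illustration and are not TVM's IR or its actual rewrite rules. It uses the algebraically valid form `x * c + c => (x + 1) * c`.

```python
# Toy expression encoding (hypothetical, not TVM's IR):
#   variable -> str, constant -> int, ("add", a, b), ("mul", a, b)

def factor_const_on_right(expr):
    """Rewrite x * c + c into (x + 1) * c, keeping the constant on the right."""
    if (
        isinstance(expr, tuple) and expr[0] == "add"
        and isinstance(expr[1], tuple) and expr[1][0] == "mul"
        and isinstance(expr[1][2], int) and expr[1][2] == expr[2]
    ):
        x, c = expr[1][1], expr[1][2]
        return ("mul", ("add", x, 1), c)
    return expr


def divisible_by_x_times_c(lhs, rhs):
    """Match (x * c1) against (x * c2) and report whether c2 divides c1."""
    for e in (lhs, rhs):
        if not (isinstance(e, tuple) and e[0] == "mul" and isinstance(e[2], int)):
            return False
    return lhs[1] == rhs[1] and rhs[2] != 0 and lhs[2] % rhs[2] == 0


n = "past_decoder_sequence_length"
num = ("add", ("mul", n, 64), 64)
den = ("add", ("mul", n, 32), 32)

# Constant on the left, as in iter 1: the toy x * c matcher does not fire.
left_num = ("mul", 64, ("add", n, 1))
left_den = ("mul", 32, ("add", n, 1))
constant_left_matches = divisible_by_x_times_c(left_num, left_den)

# Constant on the right, as the suggested rule would produce: the matcher fires.
constant_right_matches = divisible_by_x_times_c(
    factor_const_on_right(num), factor_const_on_right(den)
)
print(constant_left_matches, constant_right_matches)  # False True
```

Whether the real simplifier behaves this way depends on the actual pattern set in rewrite_simplify.cc, so this only shows why the orientation of the constant could matter.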
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

