This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 8218b18da3 [Relax] Add mod operator support (#18559)
8218b18da3 is described below
commit 8218b18da331f887934f72ab4f4b4a5f2c0dc082
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Dec 9 20:11:48 2025 +0800
[Relax] Add mod operator support (#18559)
## How
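- Resolve the TODO in `ExprWithOp.__mod__`: instead of raising `ValueError("relax.mod is not supported yet.")`, it now builds the expression via `_binary_op_helper(self, other, _op_ffi_api.mod)`
- Add `relax.op.mod` and `relax.op.floor_mod` to `test_op_correctness` and to the parametrized binary-op inference test cases (paired with `tir.Mod` and `tir.FloorMod`)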
---
python/tvm/relax/expr.py | 3 +--
tests/python/relax/test_op_binary.py | 4 ++++
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 8dd4eff5c7..e9bc9a7a3e 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -185,8 +185,7 @@ class ExprWithOp(Expr, Scriptable):
return _binary_rhs_helper(other)
def __mod__(self, other: Expr) -> "ExprWithOp":
- # TODO(siyuan): Support it after mod operator is supported in relax
- raise ValueError("relax.mod is not supported yet.")
+ return _binary_op_helper(self, other, _op_ffi_api.mod) # type: ignore
def __rmod__(self, other: Expr) -> "ExprWithOp":
return _binary_rhs_helper(other)
diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py
index 20c111495d..3376569bf3 100644
--- a/tests/python/relax/test_op_binary.py
+++ b/tests/python/relax/test_op_binary.py
@@ -33,6 +33,8 @@ def test_op_correctness():
assert relax.op.multiply(x, y).op == Op.get("relax.multiply")
assert relax.op.power(x, y).op == Op.get("relax.power")
assert relax.op.subtract(x, y).op == Op.get("relax.subtract")
+ assert relax.op.mod(x, y).op == Op.get("relax.mod")
+ assert relax.op.floor_mod(x, y).op == Op.get("relax.floor_mod")
assert relax.op.equal(x, y).op == Op.get("relax.equal")
assert relax.op.greater(x, y).op == Op.get("relax.greater")
@@ -70,6 +72,8 @@ def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: r
(relax.op.subtract, tir.Sub),
(relax.op.maximum, tir.Max),
(relax.op.minimum, tir.Min),
+ (relax.op.mod, tir.Mod),
+ (relax.op.floor_mod, tir.FloorMod),
)