This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4533d31488 [Relax] Add FDataDependent operator attribute for LegalizeOps (#18664)
4533d31488 is described below
commit 4533d314884f0e3d1c19c6c410703b552b7e4083
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Jan 21 11:04:30 2026 +0800
[Relax] Add FDataDependent operator attribute for LegalizeOps (#18664)
## Why
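The LegalizeOps transform detected data-dependent operators by string
matching: it checked whether "dynamic" appears in the operator name. This
approach is fragile. Any operator whose name happens to contain "dynamic" is
treated as data-dependent, and any data-dependent operator without "dynamic"
in its name is missed. It also doesn't scale well as new data-dependent
operators are added.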
## How
- Add an FDataDependent operator attribute to explicitly mark data-dependent operators
- Set FDataDependent=true for the relax.dynamic_strided_slice operator
- Update the LegalizeOps transform to check the FDataDependent attribute instead of matching on the operator name
---
src/relax/op/tensor/index.cc | 3 ++-
src/relax/transform/legalize_ops.cc | 15 +++++++++------
.../test_transform_legalize_ops_index_linear_algebra.py | 16 +++++++++++++---
3 files changed, 24 insertions(+), 10 deletions(-)
diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc
index a5b6225ec4..e59ba7b597 100644
--- a/src/relax/op/tensor/index.cc
+++ b/src/relax/op/tensor/index.cc
@@ -574,7 +574,8 @@ TVM_REGISTER_OP("relax.dynamic_strided_slice")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoDynStridedSlice)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout",
InferLayoutDynStridedSlice)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<Bool>("FDataDependent", Bool(true));
} // namespace relax
} // namespace tvm
diff --git a/src/relax/transform/legalize_ops.cc
b/src/relax/transform/legalize_ops.cc
index 75e0776418..723c281403 100644
--- a/src/relax/transform/legalize_ops.cc
+++ b/src/relax/transform/legalize_ops.cc
@@ -287,8 +287,15 @@ class LegalizeMutator : public ExprMutator {
return false;
}
- std::string op_name(op->name);
- bool is_data_dependent_op = (op_name.find("dynamic") !=
std::string::npos);
+ bool is_data_dependent_op = [&]() -> bool {
+ if (Op::HasAttrMap("FDataDependent")) {
+ auto op_map = Op::GetAttrMap<Bool>("FDataDependent");
+ if (op_map.count(op)) {
+ return op_map[op]->value;
+ }
+ }
+ return false;
+ }();
bool ret_shape_defined = KnowAllShapeValues(GetStructInfo(visited_call));
if (!is_data_dependent_op && !ret_shape_defined) {
// This operator cannot be legalized, because legalization by
@@ -303,10 +310,6 @@ class LegalizeMutator : public ExprMutator {
// data-dependent op, and match cast to define symbolic output
// shapes. These symbolic output shapes at compile time can
// be by later operations to refer to the runtime shape.
- //
- // TODO(Lunderberg): Make a new operator attribute
- // `.set_attr<Bool>("DataDependent")`, rather than relying on
- // the name of the operator.
return false;
}
diff --git
a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
index a6e53dab4d..44419e51e7 100644
--- a/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
+++ b/tests/python/relax/test_transform_legalize_ops_index_linear_algebra.py
@@ -16,10 +16,12 @@
# under the License.
import tvm
-from tvm.relax.transform import LegalizeOps
-from tvm.script import relax as R, tir as T, ir as I
import tvm.testing
-
+from tvm.ir import Op
+from tvm.relax.transform import LegalizeOps
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
##################### Indexing #####################
@@ -1197,5 +1199,13 @@ def test_einsum_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_data_dependent_attribute():
+ dynamic_strided_slice_op = Op.get("relax.dynamic_strided_slice")
+ assert dynamic_strided_slice_op.get_attr("FDataDependent")
+
+ strided_slice_op = Op.get("relax.strided_slice")
+ assert strided_slice_op.get_attr("FDataDependent") is None
+
+
if __name__ == "__main__":
tvm.testing.main()