gemini-code-assist[bot] commented on code in PR #18764:
URL: https://github.com/apache/tvm/pull/18764#discussion_r2796508268
##########
src/relax/transform/fold_constant.cc:
##########
@@ -137,9 +137,51 @@ class ConstantFolder : public ExprMutator {
* of the program.
*/
bool ShouldBeFolded(Expr expr) {
-    // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or
-    // not.
-    return true;
+    // Skip folding for creation ops (no tensor inputs) that produce large outputs.
+    // These ops (e.g., zeros, ones, full, arange) are cheap to compute at runtime,
+    // and folding them would materialize large constants in the binary.
+    static constexpr int64_t kMaxFoldElements = 1024;
+
+    const auto* call = expr.as<CallNode>();
+    if (!call) return true;
+
+    const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    if (!tensor_sinfo) return true;
+
+    auto opt_shape = tensor_sinfo->GetShape();
+    if (!opt_shape) return true;
+
+    int64_t num_elements = 1;
+    for (const auto& dim : opt_shape.value()) {
+      const auto* int_dim = dim.as<IntImmNode>();
+      if (!int_dim) return true;
+      num_elements *= int_dim->value;
Review Comment:

The direct multiplication `num_elements *= int_dim->value;` can overflow when the
tensor has a very large number of elements, and signed integer overflow is
undefined behavior in C++. It is safer to check for potential overflow before
performing the multiplication. As a bonus, the loop can exit early once
`num_elements` exceeds `kMaxFoldElements`.
```suggestion
      int64_t d = int_dim->value;
      if (d < 0) {
        // Dimensions should be non-negative.
        return true;
      }
      if (d == 0) {
        num_elements = 0;
        break;
      }
      // Check for overflow and if we've already exceeded the threshold.
      if (num_elements > kMaxFoldElements / d) {
        num_elements = kMaxFoldElements + 1;
        break;
      }
      num_elements *= d;
```
##########
src/relax/transform/fold_constant.cc:
##########
@@ -137,9 +137,51 @@ class ConstantFolder : public ExprMutator {
* of the program.
*/
bool ShouldBeFolded(Expr expr) {
-    // TODO(prakalp): Implement a heuristic to check if folding this expr is actually useful or
-    // not.
-    return true;
+    // Skip folding for creation ops (no tensor inputs) that produce large outputs.
+    // These ops (e.g., zeros, ones, full, arange) are cheap to compute at runtime,
+    // and folding them would materialize large constants in the binary.
+    static constexpr int64_t kMaxFoldElements = 1024;
+
+    const auto* call = expr.as<CallNode>();
+    if (!call) return true;
+
+    const auto* tensor_sinfo = call->struct_info_.as<TensorStructInfoNode>();
+    if (!tensor_sinfo) return true;
+
+    auto opt_shape = tensor_sinfo->GetShape();
+    if (!opt_shape) return true;
+
+    int64_t num_elements = 1;
+    for (const auto& dim : opt_shape.value()) {
+      const auto* int_dim = dim.as<IntImmNode>();
+      if (!int_dim) return true;
+      num_elements *= int_dim->value;
+    }
+
+    if (num_elements <= kMaxFoldElements) return true;
+
+    // Large output. Only skip if there are no tensor inputs,
+    // i.e., this is a pure creation op.
+    bool has_tensor_input = false;
+    for (const auto& arg : call->args) {
+      auto arg_sinfo = GetStructInfo(arg);
+      if (arg_sinfo.as<TensorStructInfoNode>()) {
+        has_tensor_input = true;
+        break;
+      }
+      // Check for tuple of tensors (e.g., call_tir args)
+      if (const auto* tuple = arg.as<TupleNode>()) {
+        for (const auto& field : tuple->fields) {
+          if (GetStructInfo(field).as<TensorStructInfoNode>()) {
+            has_tensor_input = true;
+            break;
+          }
+        }
+        if (has_tensor_input) break;
+      }
+    }
+
+    return has_tensor_input;
Review Comment:

The current logic for detecting tensor inputs doesn't handle nested tuples
(e.g., a tuple containing another tuple of tensors). This could lead to
incorrect folding decisions for operations with such argument structures.
A recursive helper function would make the logic cleaner, more robust, and
correct for nested structures. I recommend adding a private static helper
method to `ConstantFolder` like this:
```cpp
  static bool ExprContainsTensor(const Expr& expr) {
    if (GetStructInfo(expr).as<TensorStructInfoNode>()) {
      return true;
    }
    if (const auto* tuple = expr.as<TupleNode>()) {
      for (const auto& field : tuple->fields) {
        if (ExprContainsTensor(field)) {
          return true;
        }
      }
    }
    return false;
  }
```
Using this helper simplifies the main logic and correctly handles
arbitrarily nested tuples. It would also be beneficial to add a test case for
this scenario.
```suggestion
    // Large output. Only skip if there are no tensor inputs,
    // i.e., this is a pure creation op.
    for (const auto& arg : call->args) {
      if (ExprContainsTensor(arg)) {
        return true;
      }
    }
    return false;
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]