gemini-code-assist[bot] commented on code in PR #18736:
URL: https://github.com/apache/tvm/pull/18736#discussion_r2780403242


##########
src/relax/transform/fold_constant.cc:
##########
@@ -175,25 +174,73 @@ class ConstantFolder : public ExprMutator {
     return Constant(ret_tensor);
   }
 
+  // Try constant evaluate a call_tir with tuple outputs (multiple output 
tensors).
+  // Returns std::nullopt on failure.
+  ffi::Optional<Expr> ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func,
+                                                ffi::Array<runtime::Tensor> 
arr_args,
+                                                const TupleStructInfoNode* 
tuple_sinfo) {
+    ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
+    if (!func) return std::nullopt;
+
+    DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
+    size_t num_outputs = tuple_sinfo->fields.size();
+
+    // Match shapes and dtypes for all output fields.
+    std::vector<runtime::Tensor> ret_tensors;
+    for (size_t i = 0; i < num_outputs; ++i) {
+      ffi::Optional<ffi::Shape> shape = 
MatchConstShape(tuple_sinfo->fields[i]);
+      if (!shape) return std::nullopt;
+      auto tensor_sinfo = tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+      if (!tensor_sinfo || tensor_sinfo->IsUnknownDtype()) return std::nullopt;

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Since `MatchConstShape` is called on line 191 and has succeeded at this point, we know that 
`tuple_sinfo->fields[i]` is a `TensorStructInfo`. Therefore, you can use 
`Downcast<TensorStructInfo>` here instead of `as<TensorStructInfoNode>()` to 
make this assumption explicit and drop the redundant null check. `Downcast` 
performs a checked cast, so no safety is lost.
   
   ```suggestion
         auto tensor_sinfo = Downcast<TensorStructInfo>(tuple_sinfo->fields[i]);
         if (tensor_sinfo->IsUnknownDtype()) return std::nullopt;
   ```



##########
src/relax/transform/fold_constant.cc:
##########
@@ -175,25 +174,73 @@ class ConstantFolder : public ExprMutator {
     return Constant(ret_tensor);
   }
 
+  // Try constant evaluate a call_tir with tuple outputs (multiple output 
tensors).
+  // Returns std::nullopt on failure.
+  ffi::Optional<Expr> ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func,
+                                                ffi::Array<runtime::Tensor> 
arr_args,
+                                                const TupleStructInfoNode* 
tuple_sinfo) {
+    ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
+    if (!func) return std::nullopt;
+
+    DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
+    size_t num_outputs = tuple_sinfo->fields.size();
+
+    // Match shapes and dtypes for all output fields.
+    std::vector<runtime::Tensor> ret_tensors;
+    for (size_t i = 0; i < num_outputs; ++i) {
+      ffi::Optional<ffi::Shape> shape = 
MatchConstShape(tuple_sinfo->fields[i]);
+      if (!shape) return std::nullopt;
+      auto tensor_sinfo = tuple_sinfo->fields[i].as<TensorStructInfoNode>();
+      if (!tensor_sinfo || tensor_sinfo->IsUnknownDtype()) return std::nullopt;
+      ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), 
tensor_sinfo->dtype, cpu_dev));
+    }
+
+    // Pack input args + all output tensors.
+    std::vector<AnyView> packed_args(arr_args.size() + num_outputs);
+    std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
+
+    size_t arg_offset = 0;
+    for (; arg_offset < arr_args.size(); ++arg_offset) {
+      packed_args[arg_offset] = temp_args[arg_offset];
+    }
+    for (size_t i = 0; i < num_outputs; ++i) {
+      packed_args[arg_offset++] = ret_tensors[i];
+    }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The argument packing logic can be simplified for readability and made more 
idiomatic C++. Instead of index-based loops with a manually advanced offset, you can 
`reserve` the capacity up front and use range-based for loops to populate `packed_args`.
   
   ```cpp
       std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
       std::vector<AnyView> packed_args;
       packed_args.reserve(temp_args.size() + num_outputs);
       for (const auto& arg : temp_args) {
         packed_args.push_back(arg);
       }
       for (const auto& out_tensor : ret_tensors) {
         packed_args.push_back(out_tensor);
       }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to