slyubomirsky commented on code in PR #16595:
URL: https://github.com/apache/tvm/pull/16595#discussion_r1499985241


##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -37,405 +38,467 @@
 namespace tvm {
 namespace relax {
 
-/*! \brief Plan of lifting transform params */
-struct LiftTransformParamsInfoPlan {
-  Function f_transform_params;  // the lifted function that transforms the 
parameters
-  std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
-      output_to_index;  // the index of the original bindings in the output 
tuple
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>
-      lifted_bindings;  // the bindings of the original function that are 
lifted
-};
+namespace {
 
-/*! \brief Builder of the function that transforms the parameters. */
-class TransformParamsFuncBuilder : public ExprMutator {
- public:
-  TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); }
+struct CollectInfo {
+  /* \brief The analyzed function */
+  Function orig_func;
+
+  /* \brief The number of parameters unknown until runtime */
+  size_t num_runtime_params;
 
-  /*! \brief Add a input parameter. */
-  void AddInput(const Var& var) {
-    inputs_.push_back(var);
-    lifted_binding_lookup_.insert(var);
+  /*! \brief Bindings that can be lifted out into a pre-processing
+   *
+   * - All bindings in `liftable_bindings` are suitable for use in a
+   *   DataflowBlock.
+   *
+   * - Do not depend on any parameter prior to attr::kNumInput.
+   *
+   * - Does not include "relax.builtin.stop_lift_params"
+   */
+  std::vector<Binding> computable_at_compile_time;
+
+  /*! \brief Variables that require a compile-time parameter
+   *
+   * Used to distinguish between parameters
+   */
+  std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, 
ObjectPtrEqual>
+      requires_compile_time_param;
+
+  /*! \brief Variables that are required at runtime */
+  std::unordered_set<Variant<relax::Var, tir::Var>, ObjectPtrHash, 
ObjectPtrEqual>
+      required_at_runtime;
+
+  Array<Var> GetCompileTimeInputs() const {
+    return Array<Var>(orig_func->params.begin() + num_runtime_params, 
orig_func->params.end());
   }
 
-  void UpdateBasedOnRuntimeInput(const Var& var) {
-    for (const auto& var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) {
-      known_symbolic_var_during_inference_.insert(var);
-    }
-    for (const auto& var : TIRVarsInStructInfo(GetStructInfo(var))) {
-      required_symbolic_var_during_inference_.insert(var);
-    }
+  Array<Var> GetRuntimeInputs() const {
+    return Array<Var>(orig_func->params.begin(), orig_func->params.begin() + 
num_runtime_params);
   }
 
-  /*! \brief Add a binding to lift. */
-  void AddInternalBinding(const VarBinding& binding) {
-    bindings_.push_back(binding);
-    lifted_binding_lookup_.insert(binding->var);
+  Array<tir::Var> GetPropagatedSymbolicVariables() const {
+    auto vars_from_any_param =
+        
DefinableTIRVarsInStructInfo(TupleStructInfo(orig_func->params.Map(GetStructInfo)));
+
+    auto vars_from_runtime_params =
+        [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+      auto tir_var_vec =
+          
DefinableTIRVarsInStructInfo(TupleStructInfo(GetRuntimeInputs().Map(GetStructInfo)));
+      return {tir_var_vec.begin(), tir_var_vec.end()};
+    }();
+
+    auto vars_from_transformed_params =
+        [&]() -> std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> {
+      auto tir_var_vec =
+          
DefinableTIRVarsInStructInfo(TupleStructInfo(GetCompileTimeOutputs().Map(GetStructInfo)));
+      return {tir_var_vec.begin(), tir_var_vec.end()};
+    }();
+
+    Array<tir::Var> output;
+    for (const auto& tir_var : vars_from_any_param) {
+      if (required_at_runtime.count(tir_var) && 
!vars_from_runtime_params.count(tir_var) &&
+          !vars_from_transformed_params.count(tir_var)) {
+        output.push_back(tir_var);
+      }
+    }
+    return output;
   }
 
-  /*! \brief Update based on bindings not being lifted. */
-  void UpdateBasedOnRuntimeBinding(const VarBinding& binding) {
-    for (const auto& producer : FreeVars(binding->value)) {
-      // An external value that uses a lifted binding requires the
-      // lifted binding to be returned as output.
-      if (lifted_binding_lookup_.count(producer)) {
-        outputs_.insert(producer);
+  Array<Var> GetCompileTimeOutputs() const {
+    Array<Var> params;
 
-        for (const auto& var : 
DefinableTIRVarsInStructInfo(GetStructInfo(producer))) {
-          known_symbolic_var_during_inference_.insert(var);
-        }
+    // Any value that is available at compile-time, but is also
+    // required at runtime, must be passed through the compile-time
+    // function.
+    for (size_t i = num_runtime_params; i < orig_func->params.size(); i++) {
+      Var var = orig_func->params[i];
+      if (required_at_runtime.count(var)) {
+        params.push_back(var);
       }
     }
 
-    // All TIR variables used in the binding must be available at runtime.
-    for (const auto& var : FreeSymbolicVars(binding->value)) {
-      required_symbolic_var_during_inference_.insert(var);
+    // Any variable that is computed at compile-time, but is required
+    // at runtime, must be provided as a parameter.
+    for (const auto& binding : computable_at_compile_time) {
+      if (requires_compile_time_param.count(binding->var) &&
+          required_at_runtime.count(binding->var)) {
+        params.push_back(binding->var);
+      }
     }
-  }
 
-  bool UsesOnlyLiftableProducers(const Expr& expr) {
-    auto producers = FreeVars(expr);
-    bool uses_only_liftable_producers = [&]() {
-      return std::all_of(producers.begin(), producers.end(),
-                         [&](const auto& var) { return 
lifted_binding_lookup_.count(var); });
-    }();
-    return uses_only_liftable_producers;
+    return params;
   }
 
-  /*!
-   * \brief Build the function that transforms the parameters
-   * \return The created function, and a map from the variable in the original 
function to the index
-   * of the element of the output tuple
-   */
-  std::pair<Function, std::unordered_map<Var, int, ObjectPtrHash, 
ObjectPtrEqual>> Build() {
-    Array<PrimExpr> extra_symbolic_vars;
-    for (const auto& var : required_symbolic_var_during_inference_) {
-      if (!known_symbolic_var_during_inference_.count(var)) {
-        extra_symbolic_vars.push_back(var);
-      }
-    }
+  Function MakeCompileTimeFunction() const {
+    auto compile_time_params = GetCompileTimeInputs();
 
-    Array<StructInfo> input_sinfo;
-    Array<Expr> output_vars;
-    std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> 
output_to_index;
+    Array<Binding> output_var_binding;
+    Array<Expr> output_exprs;
 
-    for (const auto& input : inputs_) {
-      input_sinfo.push_back(Downcast<StructInfo>(input->struct_info_.value()));
+    // Any symbolic variables that are inferrable from compile-time
+    // parameters, but are not inferrable from run-time parameters,
+    // must be propagated to the output.
+    if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); 
propagated_tir_vars.size()) {
+      output_exprs.push_back(
+          ShapeExpr(propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { 
return var; })));
     }
-    Var params("params", TupleStructInfo(input_sinfo));
 
-    if (extra_symbolic_vars.size()) {
-      output_vars.push_back(builder_->Emit(ShapeExpr(extra_symbolic_vars), 
"extra_symbolic_vars"));
+    for (const auto& var : GetCompileTimeOutputs()) {
+      Var out_var(var->name_hint() + "_output", GetStructInfo(var));
+      output_var_binding.push_back(VarBinding(out_var, var));
+      output_exprs.push_back(out_var);
     }
 
-    // Helper to add a variable to the output tuple
-    // original_var: the binding variable in the original function
-    // output_var: the variable, which is a binding in the transform_params 
function, that is added
-    // to the output tuple
-    auto f_add_output = [&](const Var& original_var, const Var& output_var) -> 
void {
-      output_to_index[original_var] = output_vars.size();
-      output_vars.push_back(output_var);
-    };
+    Var tuple_var("output_tuple", 
TupleStructInfo(output_exprs.Map(GetStructInfo)));
+    output_var_binding.push_back(VarBinding(tuple_var, Tuple(output_exprs)));
+
+    SeqExpr body(
+        {
+            DataflowBlock(computable_at_compile_time),
+            DataflowBlock(output_var_binding),
+        },
+        tuple_var);
+
+    Function func(compile_time_params, body, GetStructInfo(tuple_var));
+    func = WithAttr(func, attr::kNumInput, Integer(0));
+    func = CopyWithNewVars(func);
+    func = Downcast<Function>(CanonicalizeBindings(func));
+    return func;
+  }
 
-    // Create mapping from the original input variables to the TupleGetItem 
from the packed
-    // parameter tuple Add the parameters that are marked as the output of the 
function to the
-    // output tuple
-    for (const auto& input : inputs_) {
-      input_remap_.emplace(input.get(), TupleGetItem(params, 
input_remap_.size()));
-      if (outputs_.count(input)) {
-        auto output_var = builder_->Emit(input_remap_.at(input.get()));
-        f_add_output(input, output_var);
-      }
+  Function MakeRuntimeFunction() const {
+    Array<Binding> bindings;
+
+    // Any parameter that isn't available until runtime must be an
+    // input, along with any output from the compile-time function.
+    // Compile-time outputs must have a fresh non-dataflow var to
+    // serve as the parameter.  This trivial binding will later be
+    // removed with CanonicalizeBindings.
+    Array<Var> params = GetRuntimeInputs();
+    if (auto propagated_tir_vars = GetPropagatedSymbolicVariables(); 
propagated_tir_vars.size()) {
+      ShapeStructInfo shape_sinfo(
+          propagated_tir_vars.Map([](tir::Var var) -> PrimExpr { return var; 
}));
+      Var shape_expr("vars_from_compile_time_params", shape_sinfo);
+      params.push_back(shape_expr);
+    }
+    for (const auto& var : GetCompileTimeOutputs()) {
+      Var param_var(var->name_hint(), GetStructInfo(var));
+      bindings.push_back(VarBinding(var, param_var));
+      params.push_back(param_var);
     }
 
-    // Re-emit the bindings that are lifted. Update the output tuple if the 
binding is marked as the
-    // output.
-    for (const auto& binding : bindings_) {
-      if (outputs_.count(binding->var)) {
-        auto output_var = builder_->Emit(VisitExpr(binding->value));
-        var_remap_[binding->var->vid] = output_var;
-        f_add_output(binding->var, output_var);
-      } else {
-        VisitBinding(binding);
+    // Any binding that is computable at compile-time should be
+    // suppressed at run-time.
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
+    for (const auto& binding : computable_at_compile_time) {
+      if (requires_compile_time_param.count(binding->var)) {
+        to_suppress.insert(binding->var);
       }
     }
 
-    // Create the function.
-    Expr transformed_params = builder_->EmitOutput(Tuple(output_vars));
-    BindingBlock block = builder_->EndBlock();
-    Expr body = VisitWithNewScope(SeqExpr({block}, transformed_params), 
Array<Var>{params});
-    Function f_transform_params =
-        Function(/*params=*/{params}, /*body=*/body, 
/*ret_struct_info=*/NullOpt);
-    return {f_transform_params, output_to_index};
-  }
+    struct SuppressCompileTime : ExprMutator {
+      std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress;
+      explicit SuppressCompileTime(
+          std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> to_suppress)
+          : to_suppress(to_suppress) {}

Review Comment:
   I was just curious; I don't think it's worth holding up the review over 
this or anything like that. Having a different rule for public and private 
members is reasonable.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to