slyubomirsky commented on code in PR #16129:
URL: https://github.com/apache/tvm/pull/16129#discussion_r1450847241


##########
src/relax/transform/dataflow_inplace.cc:
##########
@@ -0,0 +1,1017 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/dataflow_inplace.cc
+ * \brief Pass that converts eligible operator calls in dataflow blocks
+ *   into in-place versions.
+ */
+
+#include <tvm/ir/transform.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "utils.h"
+
+namespace tvm {
+namespace relax {
+
+// Perform liveness analysis on a dataflow block, returning a map of vars to
+// pairs of indices (the liveness interval, from the starting index to the end 
index).
+// A starting index of -1 means the var is defined before the block starts and 
an end index
+// of block->bindings.size() (one past the last index) means it is live after 
the block ends.
+std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash, ObjectPtrEqual> 
AnalyzeLiveness(
+    const DataflowBlock& block) {
+  std::unordered_map<Var, std::pair<int, int>, ObjectPtrHash, ObjectPtrEqual> 
ret;
+  for (int i = block->bindings.size() - 1; i >= 0; i--) {
+    Binding b = block->bindings[i];
+    Var defined_var = b->var;
+    Expr value = GetBoundValue(b);
+    Array<Var> used_vars;
+    // for a function literal, we consider only the free vars
+    // (those captured from the outer scope)
+    if (value.as<FunctionNode>()) {
+      used_vars = FreeVars(value);
+    } else if (value.as<TupleGetItemNode>()) {
+      // Special case: we do not consider a tuple index to be a "use."
+      // This is a bit of a hack, but it allows operations that create
+      // tuples to be done in-place (otherwise, any index of the tuple
+      // would be considered a use and so the tuple would be live later).
+      // Hence we keep the array empty.
+    } else {
+      used_vars = AllVars(value);
+    }
+
+    for (auto var : used_vars) {
+      if (!ret.count(var)) {
+        ret[var] = {-1, i};
+      }
+    }
+
+    if (!ret.count(defined_var)) {
+      // if it's an output, then it lives past the end of the block
+      if (!defined_var.as<DataflowVarNode>()) {
+        ret[defined_var] = {i, block->bindings.size()};
+      } else {
+        // otherwise, it's live only here
+        ret[defined_var] = {i, i};
+      }
+    } else {
+      // this means the var is used later but we encountered its definition now
+      auto last_range = ret[defined_var];
+      CHECK_EQ(last_range.first, -1);
+      std::pair<int, int> new_range = {i, last_range.second};
+      ret[defined_var] = new_range;
+    }
+  }
+  return ret;
+}
+
+class AliasAnalyzer {
+ public:
+  AliasAnalyzer() : alias_map_(), tuple_map_(), mem_idx_(0) {}
+
+  // The analysis returns a map of vars to memory locations that it *could* 
map to
+  // (any unique allocation = one memory location), plus a map of memory 
locations
+  // that correspond to tuples (this maps to sets of memory locations for each 
tuple element).
+  // Note: inputs are values that should be assumed not to be aliased and are 
therefore
+  // (in the case of in-place ops) safe to overwrite. This may not be true of 
function args.
+  std::pair<std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash, 
ObjectPtrEqual>,
+            std::unordered_map<int, std::vector<std::unordered_set<int>>>>
+  Analyze(const DataflowBlock& block, const Array<Var>& inputs) {
+    for (auto input : inputs) {
+      int curr_idx = get_fresh_idx();
+      alias_map_[input] = {curr_idx};
+      if (auto* tup_info = GetStructInfoAs<TupleStructInfoNode>(input)) {
+        InsertFreshTuple(curr_idx, tup_info);
+      }
+    }
+
+    for (const Binding& binding : block->bindings) {
+      Var current_var = binding->var;
+      Expr value = GetBoundValue(binding);
+      alias_map_[current_var] = GetAliasSet(value, current_var);
+    }
+
+    return {alias_map_, tuple_map_};
+  }
+
+ private:
+  int get_fresh_idx() {
+    int ret = mem_idx_;
+    mem_idx_++;
+    return ret;
+  }
+
+  // Fresh tuple = each element is assumed to be a unique allocation
+  void InsertFreshTuple(int tup_idx, const TupleStructInfoNode* tup_info) {
+    std::vector<std::unordered_set<int>> tuple_set;
+    for (int i = 0; i < static_cast<int>(tup_info->fields.size()); i++) {
+      int curr_field = get_fresh_idx();
+      tuple_set.push_back({curr_field});
+      if (auto* nested_tup_info = 
tup_info->fields[i].as<TupleStructInfoNode>()) {
+        InsertFreshTuple(curr_field, nested_tup_info);
+      }
+    }
+    tuple_map_[tup_idx] = tuple_set;
+  }
+
+  // given a tuple index, add the given memory location indices to each 
component's
+  // alias set
+  void UpdateTupleComponents(int tup_idx, const std::unordered_set<int>& 
insert_idxs) {
+    if (tuple_map_.count(tup_idx)) {
+      auto tuple_comps = tuple_map_[tup_idx];
+      for (size_t i = 0; i < tuple_comps.size(); i++) {
+        auto comp_set = tuple_comps[i];
+
+        // if a member is a tuple, update its components as well
+        for (int member : comp_set) {
+          if (tuple_map_.count(member)) {
+            UpdateTupleComponents(member, insert_idxs);
+          }
+        }
+
+        // update after iterating to avoid iterating over the inserted elements
+        tuple_map_[tup_idx][i].insert(insert_idxs.begin(), insert_idxs.end());
+      }
+    }
+  }
+
+  // capture the given index and also its tuple components (including 
recursively)
+  // if they exist
+  void AddCapturedIndices(std::unordered_set<int>* captured_set, int idx) {
+    captured_set->insert(idx);
+    if (tuple_map_.count(idx)) {
+      for (auto comp_set : tuple_map_[idx]) {
+        for (auto tup_comp_idx : comp_set) {
+          AddCapturedIndices(captured_set, tup_comp_idx);
+        }
+      }
+    }
+  }
+
+  // Conservative, extremely pessimistic assumption:
+  // assume that the result of a non-op call can be aliased to any argument
+  // or that it could be a newly allocated value.
+  // For tuples, assume all members are aliased. Yeah, it's bad.
+  // (skip_first_arg is for handling call_pure_packed, where the first arg is an ExternFunc
+  // that we should ignore)
+  std::unordered_set<int> HandleMysteryCall(const CallNode* call_node, const 
Var& bound_var,
+                                            bool skip_first_arg = false) {
+    // the result may or may not be newly allocated
+    std::unordered_set<int> ret;
+    int res_idx = get_fresh_idx();
+    // the result may be a tuple
+    if (auto* tup_info_node = GetStructInfoAs<TupleStructInfoNode>(bound_var)) 
{
+      InsertFreshTuple(res_idx, tup_info_node);
+    }
+    AddCapturedIndices(&ret, res_idx);
+
+    for (size_t i = (skip_first_arg) ? 1 : 0; i < call_node->args.size(); i++) 
{
+      auto arg = call_node->args[i];
+      auto arg_alias_set = GetAliasSet(arg, bound_var);
+      for (int alias_idx : arg_alias_set) {
+        AddCapturedIndices(&ret, alias_idx);
+      }
+    }
+    // if the result is a tuple, the components can also potentially be 
aliased to any arg
+    // or, in fact, to each other
+    UpdateTupleComponents(res_idx, ret);
+    return ret;
+  }
+
+  // given the expression value, return the set of memory locations 
corresponding to it
+  // (the var the expression is being bound to is needed for struct info)
+  std::unordered_set<int> GetAliasSet(const Expr& value, const Var& bound_var) 
{
+    std::unordered_set<int> ret;
+
+    // cases for value:
+    // constant: it's a fresh index
+    // var: look up in alias map (-1 if not present)
+    // op call: assume it's fresh (may need to make list of exceptions)
+    // tuple: fresh entry in tuple index, recurse to determine indices for 
values
+    // function/packed call: chaos reigns, alias with any other argument
+    //   (if tuple is passed, assume also aliased with all members of the 
tuple)
+    // tuple index: -1 if tuple is not in tuple map, otherwise look up 
corresponding entry
+    // function constant: give them a fresh index (TODO: we can handle in more detail
+    //   if this is a case we need to support)
+    // prim value: fresh index
+    // if node: should not happen inside dataflow block
+    if (value.as<ConstantNode>() || value.as<PrimValueNode>() || 
value.as<FunctionNode>()) {
+      // TODO(@slyubomirsky): We will probably want special handling for 
closures
+      ret.insert(get_fresh_idx());
+    } else if (auto* target_var_node = value.as<VarNode>()) {
+      auto target_var = GetRef<Var>(target_var_node);
+      if (alias_map_.count(target_var)) {
+        ret.insert(alias_map_[target_var].begin(), 
alias_map_[target_var].end());
+      } else {
+        ret.insert(-1);
+      }
+    } else if (auto* target_tuple = value.as<TupleNode>()) {
+      // fresh idx but we update the tuple map
+      int tup_idx = get_fresh_idx();
+      ret.insert(tup_idx);
+      std::vector<std::unordered_set<int>> new_tuple_map;
+      for (auto field : target_tuple->fields) {
+        new_tuple_map.push_back(GetAliasSet(field, bound_var));
+      }
+      tuple_map_[tup_idx] = new_tuple_map;
+    } else if (auto* target_tgi = value.as<TupleGetItemNode>()) {
+      std::unordered_set<int> tuple_set = GetAliasSet(target_tgi->tuple, 
bound_var);
+      // if -1 is a member of the tuple set, then we have to assume the result 
is -1
+      if (tuple_set.count(-1)) {
+        ret.insert(-1);
+      } else {
+        // otherwise, consider all members that are tuples of appropriate size 
and index into them
+        // (this is safe because the type system will ensure we're not 
indexing into a tuple
+        // of the wrong size)
+        for (int member : tuple_set) {
+          if (tuple_map_.count(member) &&
+              static_cast<int>(tuple_map_[member].size()) > target_tgi->index) 
{
+            auto member_set = tuple_map_[member][target_tgi->index];
+            ret.insert(member_set.begin(), member_set.end());
+          }
+        }
+      }
+    } else if (auto* call_node = value.as<CallNode>()) {
+      if (auto* op_node = call_node->op.as<OpNode>()) {
+        // call_pure_packed: treat as non-op call
+        if (op_node->name == "relax.call_pure_packed") {
+          return HandleMysteryCall(call_node, bound_var, true);
+        } else if (op_node->name == "relax.split") {
+          // split: Returns a tuple, treat as allocation
+
+          // tuple is freshly allocated, but also add components to the tuple 
map
+          int tup_idx = get_fresh_idx();
+          ret.insert(tup_idx);
+          // the LHS (the bound var) will definitely have a tuple struct info
+          InsertFreshTuple(tup_idx, 
GetStructInfoAs<TupleStructInfoNode>(bound_var));
+        } else if (op_node->name == "relax.call_tir") {
+          // call_tir: can potentially return a tuple
+          if (auto* tuple_struct_info = 
call_node->sinfo_args[0].as<TupleStructInfoNode>()) {
+            int tup_idx = get_fresh_idx();
+            ret.insert(tup_idx);
+            InsertFreshTuple(tup_idx, tuple_struct_info);
+          } else {
+            ret.insert(get_fresh_idx());
+          }
+        } else {
+          // We are assuming most op calls return a single fresh allocation.
+          // We may have to track more exceptions
+          ret.insert(get_fresh_idx());
+        }
+      } else {
+        // assume any non-op call can be extremely dangerous and do anything
+        return HandleMysteryCall(call_node, bound_var);
+      }
+    }
+
+    return ret;
+  }
+
+  std::unordered_map<Var, std::unordered_set<int>, ObjectPtrHash, 
ObjectPtrEqual> alias_map_;
+  std::unordered_map<int, std::vector<std::unordered_set<int>>> tuple_map_;
+  int mem_idx_;
+};
+
+// given a shape, return the allocation size corresponding to it (product of 
elements)
+int ShapeSize(const ShapeExpr& shape) {
+  int ret = 1;
+  for (auto dim : shape->values) {
+    if (auto int_dim = dim.as<IntImmNode>()) {
+      ret *= static_cast<int>(int_dim->value);
+    } else {
+      return -1;
+    }
+  }
+  return ret;
+}
+
+// Given the struct info of the result, return any struct info nested in it
+// that is eligible to be used for in-place computations (tensors are eligible
+// only if all their dimensions are integer constants, tuples are eligible if
+// all members are eligible though we can consider only individual members 
separately)
+std::unordered_set<StructInfo, ObjectPtrHash, ObjectPtrEqual> 
GatherCandidateSinfo(
+    const StructInfo& result_sinfo) {
+  if (auto* tensor_info = result_sinfo.as<TensorStructInfoNode>()) {
+    // don't consider void dtype (don't know the size at compile time)
+    if (tensor_info->dtype.is_void()) {
+      return {};
+    }
+    // don't consider cases where we don't know the shape at compile time
+    // TODO(@slyubomirsky): variables might be okay if we use the arithmetic 
analyzer
+    if (auto* shape_node = tensor_info->shape.as<ShapeExprNode>()) {
+      for (auto dim : shape_node->values) {
+        if (!dim.as<IntImmNode>()) {
+          return {};
+        }
+      }
+      return {GetRef<TensorStructInfo>(tensor_info)};
+    } else {
+      return {};
+    }
+  } else if (auto* tuple_info = result_sinfo.as<TupleStructInfoNode>()) {
+    // we can see if the whole tuple matches or go for any of the components
+    std::unordered_set<StructInfo, ObjectPtrHash, ObjectPtrEqual> ret;
+    for (auto field : tuple_info->fields) {
+      auto field_candidates = GatherCandidateSinfo(field);
+      ret.insert(field_candidates.begin(), field_candidates.end());
+    }
+    // at least one field should be eligible to be done in-place
+    if (!ret.empty()) {
+      ret.insert(GetRef<StructInfo>(tuple_info));
+    }
+    return ret;
+  } else {
+    // don't consider any other types
+    return {};
+  }
+}
+
+// Given the two struct info, return a pair of bools where the first element 
is true if
+// the two struct info are the same _size_ in memory and the second element is 
true
+// if the shapes match _exactly_. Performs this check recursively and ensures 
the
+// stated condition is true for all tensor members of the struct info (return 
false
+// if a single pair of corresponding tensors does not meet the condition).
+std::pair<bool, bool> SizeMatches(const StructInfo& target_info, const 
StructInfo& arg_info) {
+  if (target_info.as<TensorStructInfoNode>() && 
arg_info.as<TensorStructInfoNode>()) {
+    auto target_tensor = Downcast<TensorStructInfo>(target_info);
+    auto arg_tensor = Downcast<TensorStructInfo>(arg_info);
+    if (target_tensor->shape.defined() && 
target_tensor->shape.as<ShapeExprNode>() &&
+        arg_tensor->shape.defined() && arg_tensor->shape.as<ShapeExprNode>()) {
+      if (target_tensor->dtype != arg_tensor->dtype) {
+        return {false, false};
+      }
+      auto target_shape = Downcast<ShapeExpr>(target_tensor->shape);
+      auto arg_shape = Downcast<ShapeExpr>(arg_tensor->shape);
+      int target_size = ShapeSize(target_shape);
+      int arg_size = ShapeSize(arg_shape);
+      if (target_size == -1 || arg_size == -1 || target_size < arg_size) {
+        return {false, false};
+      }
+      // exact match: number of dims and each dim matches
+      if (target_shape->values.size() == arg_shape->values.size()) {
+        for (size_t i = 0; i < target_shape->values.size(); i++) {
+          if (!arg_shape->values[i].as<IntImmNode>()) {
+            return {false, false};
+          }
+          if (Downcast<IntImm>(target_shape->values[i])->value !=
+              Downcast<IntImm>(arg_shape->values[i])->value) {
+            return {true, false};
+          }
+        }
+        return {true, true};
+      }
+      return {true, false};
+    } else {
+      return {false, false};
+    }
+  } else if (target_info.as<TupleStructInfoNode>() && 
arg_info.as<TupleStructInfoNode>()) {
+    auto target_tup = Downcast<TupleStructInfo>(target_info);
+    auto arg_tup = Downcast<TupleStructInfo>(arg_info);
+    if (target_tup->fields.size() != arg_tup->fields.size()) {
+      return {false, false};
+    }
+    bool all_exact = true;
+    for (size_t i = 0; i < target_tup->fields.size(); i++) {
+      auto element_match = SizeMatches(target_tup->fields[i], 
arg_tup->fields[i]);
+      if (!element_match.first) {
+        return {false, false};
+      }
+      if (!element_match.second) {
+        all_exact = false;
+      }
+    }
+    return {true, all_exact};
+  } else if (target_info.as<PrimStructInfoNode>() && 
arg_info.as<PrimStructInfoNode>()) {

Review Comment:
   Good point.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to