comaniac commented on a change in pull request #6107:
URL: https://github.com/apache/incubator-tvm/pull/6107#discussion_r458955101



##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,275 @@ String 
ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+// Common part for steps that add new stages
+// (e.g. CacheReadStep, CacheWriteStep, RfactorStep)
+void AddStageModificationSteps(int step_id, const Array<Step>& transform_steps,
+                               Array<Step>* replay_steps) {

Review comment:
       The description and function name could be improved. This function 
checks whether the step at `step_id` adds a new stage and, if so, appends it to 
`replay_steps`. Will this function be extended when adding Rfactor? If not, it 
seems unnecessary to me.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,275 @@ String 
ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+// Common part for steps that add new stages
+// (e.g. CacheReadStep, CacheWriteStep, RfactorStep)
+void AddStageModificationSteps(int step_id, const Array<Step>& transform_steps,
+                               Array<Step>* replay_steps) {
+  const Step& step = transform_steps[step_id];
+  if (step->IsInstance<CacheWriteStepNode>() || 
step->IsInstance<CacheReadStepNode>()) {
+    replay_steps->push_back(step);
+  }
+  // TODO(jcf94): add rfactor support
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const 
{
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+
+  // target -> target + target_store
+  // Should update target's op, insert new stage, update the later stage's op
+  int added_stage_id = stage_id + 1;
+  Stage tmp_stage = pstate->stages[stage_id];
+  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+  pstate->stages.Set(stage_id, std::move(tmp_stage));
+  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+                        Stage(current_compute_dag->ops[added_stage_id]));
+  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+    tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, 
readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, 
StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Copy stage here, for the original stage will change after apply

Review comment:
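       Improve this comment: "for the original stage will change after apply" is 
unclear. Explain what changes (e.g. `ApplyToSchedule` inserts the new stage into 
`stages`, shifting the indices of later stages) and why the stages therefore have 
to be looked up before the step is applied.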

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,275 @@ String 
ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+// Common part for steps that add new stages
+// (e.g. CacheReadStep, CacheWriteStep, RfactorStep)
+void AddStageModificationSteps(int step_id, const Array<Step>& transform_steps,
+                               Array<Step>* replay_steps) {
+  const Step& step = transform_steps[step_id];
+  if (step->IsInstance<CacheWriteStepNode>() || 
step->IsInstance<CacheReadStepNode>()) {
+    replay_steps->push_back(step);
+  }
+  // TODO(jcf94): add rfactor support
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const 
{
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+
+  // target -> target + target_store
+  // Should update target's op, insert new stage, update the later stage's op
+  int added_stage_id = stage_id + 1;
+  Stage tmp_stage = pstate->stages[stage_id];
+  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+  pstate->stages.Set(stage_id, std::move(tmp_stage));
+  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+                        Stage(current_compute_dag->ops[added_stage_id]));
+  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+    tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, 
readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, 
StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Copy stage here, for the original stage will change after apply
+  auto stage = (*stages)[stage_id];
+  std::vector<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->name) << " = "
+     << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name 
<< "\", ["
+     << CleanName(reader_stages[0]->op->name);
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->name);
+  }
+  ss << "])\n";
+
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheWriteStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) 
const {
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  int last_dag_op_size = pstate->current_compute_dag.defined()
+                             ? 
pstate->current_compute_dag.as<ComputeDAGNode>()->ops.size()
+                             : dag->ops.size();
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+  int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+  CHECK_GE(added_ops, 1);
+
+  // target -> target_compute + target
+  // Assume target stage has never been applied any steps before cache_write
+  // Should insert new stage, update target stage, update the later stage's op
+  pstate->stages.insert(pstate->stages.begin() + stage_id,
+                        Stage(current_compute_dag->ops[stage_id]));
+  pstate->stages.Set(stage_id + 1, Stage(current_compute_dag->ops[stage_id + 
1]));
+  int next_stage_id = stage_id + 2;
+  // Notice: added_ops should actually assert to be 1
+  // branch of 2 here is somehow a hack to TVM's cache_write bug with multi 
outputs
+  // see 
`tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write`
 test for
+  // more information
+  // TODO(jcf94): Fix the cache write bug in TVM and remove these branches here

Review comment:
       ```suggestion
     // TODO(jcf94): Fix the cache write bug in TVM and remove added_ops == 2 support.
     // TVM's cache_write has a bug with multiple outputs.
     // See `tests/python/unittest/test_auto_scheduler_loop_state.py::test_cache_read_write` test.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,275 @@ String 
ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+// Common part for steps that add new stages
+// (e.g. CacheReadStep, CacheWriteStep, RfactorStep)
+void AddStageModificationSteps(int step_id, const Array<Step>& transform_steps,
+                               Array<Step>* replay_steps) {
+  const Step& step = transform_steps[step_id];
+  if (step->IsInstance<CacheWriteStepNode>() || 
step->IsInstance<CacheReadStepNode>()) {
+    replay_steps->push_back(step);
+  }
+  // TODO(jcf94): add rfactor support
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const 
{
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+
+  // target -> target + target_store
+  // Should update target's op, insert new stage, update the later stage's op
+  int added_stage_id = stage_id + 1;
+  Stage tmp_stage = pstate->stages[stage_id];
+  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+  pstate->stages.Set(stage_id, std::move(tmp_stage));
+  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+                        Stage(current_compute_dag->ops[added_stage_id]));
+  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+    tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, 
readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, 
StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Copy stage here, for the original stage will change after apply
+  auto stage = (*stages)[stage_id];
+  std::vector<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->name) << " = "
+     << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name 
<< "\", ["
+     << CleanName(reader_stages[0]->op->name);
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->name);
+  }
+  ss << "])\n";
+
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheWriteStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) 
const {
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  int last_dag_op_size = pstate->current_compute_dag.defined()
+                             ? 
pstate->current_compute_dag.as<ComputeDAGNode>()->ops.size()
+                             : dag->ops.size();
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+  int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+  CHECK_GE(added_ops, 1);
+
+  // target -> target_compute + target
+  // Assume target stage has never been applied any steps before cache_write
+  // Should insert new stage, update target stage, update the later stage's op

Review comment:
       ```suggestion
     // target_stage -> cache_write_stage + target_stage
     // Assume no step has been applied to the target stage before cache_write.
     // Insert a new cache write stage, update the target stage and the op-to-stage-id map for ops in later stages.
   ```

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,275 @@ String 
ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+// Common part for steps that add new stages
+// (e.g. CacheReadStep, CacheWriteStep, RfactorStep)
+void AddStageModificationSteps(int step_id, const Array<Step>& transform_steps,
+                               Array<Step>* replay_steps) {
+  const Step& step = transform_steps[step_id];
+  if (step->IsInstance<CacheWriteStepNode>() || 
step->IsInstance<CacheReadStepNode>()) {
+    replay_steps->push_back(step);
+  }
+  // TODO(jcf94): add rfactor support
+}
+
+/********** Cache Read **********/
+CacheReadStep::CacheReadStep(int stage_id, String scope_name,
+                             const Array<Integer>& reader_stage_ids) {
+  auto node = make_object<CacheReadStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  node->reader_stage_ids = reader_stage_ids;
+  data_ = std::move(node);
+}
+
+CacheReadStep::CacheReadStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheReadStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::vector<int> int_list;
+  reader->Read(&int_list);
+  Array<Integer> reader_stage_ids;
+  for (int i : int_list) {
+    reader_stage_ids.push_back(i);
+  }
+  node->reader_stage_ids = std::move(reader_stage_ids);
+  data_ = std::move(node);
+}
+
+void CacheReadStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+  writer->WriteArrayItem(IntArrayToVector(reader_stage_ids));
+}
+
+int CacheReadStepNode::ApplyToState(State* state, const ComputeDAG& dag) const 
{
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+
+  // target -> target + target_store
+  // Should update target's op, insert new stage, update the later stage's op
+  int added_stage_id = stage_id + 1;
+  Stage tmp_stage = pstate->stages[stage_id];
+  tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[stage_id];
+  pstate->stages.Set(stage_id, std::move(tmp_stage));
+  pstate->stages.insert(pstate->stages.begin() + added_stage_id,
+                        Stage(current_compute_dag->ops[added_stage_id]));
+  for (size_t i = added_stage_id + 1; i < pstate->stages.size(); ++i) {
+    tmp_stage = pstate->stages[i];
+    tmp_stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(tmp_stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOffset(added_stage_id);
+  pstate->current_compute_dag = std::move(current_compute_dag);
+
+  return added_stage_id;
+}
+
+te::Tensor CacheReadStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes,
+                                              te::Schedule* schedule) const {
+  const te::Stage& stage = (*stages)[stage_id];
+
+  Array<te::Operation> readers;
+  for (const auto& i : reader_stage_ids) {
+    readers.push_back((*stages)[i]->origin_op);
+  }
+  auto out = schedule->cache_read(stage->origin_op.output(0), scope_name, 
readers);
+
+  const auto& new_stage = (*schedule)[out->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id + 1, new_stage);
+
+  return out;
+}
+
+String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, 
StageToAxesMap* stage_to_axes,
+                                           te::Schedule* schedule) const {
+  std::stringstream ss;
+  // Copy stage here, for the original stage will change after apply
+  auto stage = (*stages)[stage_id];
+  std::vector<te::Stage> reader_stages;
+  for (size_t i = 0; i < reader_stage_ids.size(); ++i) {
+    reader_stages.push_back((*stages)[reader_stage_ids[i]]);
+  }
+
+  auto out = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  ss << CleanName(out->op->name) << " = "
+     << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name 
<< "\", ["
+     << CleanName(reader_stages[0]->op->name);
+  for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
+    ss << ", " << CleanName(reader_stages[i]->op->name);
+  }
+  ss << "])\n";
+
+  const auto& iters = out->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+
+  return ss.str();
+}
+
+/********** Cache Write **********/
+CacheWriteStep::CacheWriteStep(int stage_id, String scope_name) {
+  auto node = make_object<CacheWriteStepNode>();
+  node->stage_id = stage_id;
+  node->scope_name = std::move(scope_name);
+  data_ = std::move(node);
+}
+
+CacheWriteStep::CacheWriteStep(dmlc::JSONReader* reader) {
+  auto node = make_object<CacheWriteStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->scope_name = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void CacheWriteStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(scope_name);
+}
+
+int CacheWriteStepNode::ApplyToState(State* state, const ComputeDAG& dag) 
const {
+  StateNode* pstate = state->CopyOnWrite();
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  int last_dag_op_size = pstate->current_compute_dag.defined()
+                             ? 
pstate->current_compute_dag.as<ComputeDAGNode>()->ops.size()
+                             : dag->ops.size();
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+  int added_ops = current_compute_dag->ops.size() - last_dag_op_size;
+  CHECK_GE(added_ops, 1);

Review comment:
       As the note at L1150 says (`added_ops` should actually be 1), please add a 
TODO here as well to change this check to `CHECK_EQ`.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to