This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8820ce48d1 [Unity][MSC][special] Change special names (#15691)
8820ce48d1 is described below

commit 8820ce48d13fbe983c8286d06a377c8e54863a3d
Author: Archermmt <[email protected]>
AuthorDate: Thu Sep 7 22:39:24 2023 +0800

    [Unity][MSC][special] Change special names (#15691)
---
 python/tvm/contrib/msc/core/ir/graph.py         | 16 +++++++-------
 src/contrib/msc/core/ir/graph.cc                | 28 ++++++++++++-------------
 src/contrib/msc/core/ir/graph.h                 | 28 ++++++++++++-------------
 src/contrib/msc/core/ir/graph_builder.cc        |  8 +++----
 src/contrib/msc/core/transform/set_expr_name.cc |  8 +++----
 5 files changed, 44 insertions(+), 44 deletions(-)

diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py
index 0f4d94290e..5475b83005 100644
--- a/python/tvm/contrib/msc/core/ir/graph.py
+++ b/python/tvm/contrib/msc/core/ir/graph.py
@@ -126,8 +126,8 @@ class MSCJoint(BaseJoint):
         The index of the node.
     name: string
         The name of the node.
-    master_name: string
-        The master name of the node.
+    shared_ref: string
+        The share reference of the node.
     optype: string
         The optype of the node.
     attrs: dict<string, string>
@@ -144,7 +144,7 @@ class MSCJoint(BaseJoint):
         self,
         index: int,
         name: str,
-        master_name: str,
+        shared_ref: str,
         optype: str,
         attrs: Dict[str, str],
         inputs: List[Tuple[BaseJoint, int]],
@@ -158,7 +158,7 @@ class MSCJoint(BaseJoint):
             _ffi_api.MSCJoint,
             index,
             name,
-            master_name,
+            shared_ref,
             optype,
             attrs,
             parents,
@@ -289,8 +289,8 @@ class WeightJoint(BaseJoint):
         The index of the node.
     name: string
         The name of the node.
-    master_name: string
-        The master name of the node.
+    shared_ref: string
+        The share reference of the node.
     optype: string
         The optype of the node.
     wtype: string
@@ -309,7 +309,7 @@ class WeightJoint(BaseJoint):
         self,
         index: int,
         name: str,
-        master_name: str,
+        shared_ref: str,
         optype: str,
         wtype: str,
         attrs: Dict[str, str],
@@ -322,7 +322,7 @@ class WeightJoint(BaseJoint):
             _ffi_api.WeightJoint,
             index,
             name,
-            master_name,
+            shared_ref,
             optype,
             wtype,
             attrs,
diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc
index 7595559a60..b07f2ba422 100644
--- a/src/contrib/msc/core/ir/graph.cc
+++ b/src/contrib/msc/core/ir/graph.cc
@@ -248,14 +248,14 @@ bool BaseJointNode::GetAttr(const String& key, std::vector<float>* val) const {
   return false;
 }
 
-MSCJoint::MSCJoint(int index, const String& name, const String& master_name, 
const String& optype,
+MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, 
const String& optype,
                    const Map<String, String>& attrs, const Array<String>& 
scope,
                    const std::vector<std::pair<BaseJoint, size_t>>& inputs,
                    const Array<MSCTensor>& outputs, const Map<String, 
MSCTensor>& weights) {
   ObjectPtr<MSCJointNode> n = make_object<MSCJointNode>();
   n->index = index;
   n->name = std::move(name);
-  n->master_name = std::move(master_name);
+  n->shared_ref = std::move(shared_ref);
   n->optype = std::move(optype);
   n->attrs = std::move(attrs);
   n->scope = std::move(scope);
@@ -301,15 +301,15 @@ MSCJoint::MSCJoint(const std::string& json_str, const Map<String, BaseJoint>& no
 
 const MSCJoint MSCJoint::Clone(const MSCJoint& node,
                                const std::vector<std::pair<BaseJoint, 
size_t>>& inputs) {
-  return MSCJoint(node->index, node->name, node->master_name, node->optype, 
node->attrs,
-                  node->scope, inputs, node->outputs, node->weights);
+  return MSCJoint(node->index, node->name, node->shared_ref, node->optype, 
node->attrs, node->scope,
+                  inputs, node->outputs, node->weights);
 }
 
 const JsonMSCJoint MSCJointNode::ToJson() const {
   JsonMSCJoint j_joint;
   j_joint.index = index;
   j_joint.name = name;
-  j_joint.master_name = master_name;
+  j_joint.shared_ref = shared_ref;
   j_joint.optype = optype;
   for (const auto& pair : attrs) {
     j_joint.attrs[pair.first] = pair.second;
@@ -335,7 +335,7 @@ const JsonMSCJoint MSCJointNode::ToJson() const {
 void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map<String, 
BaseJoint>& nodes) {
   index = j_joint.index;
   name = j_joint.name;
-  master_name = j_joint.master_name;
+  shared_ref = j_joint.shared_ref;
   optype = j_joint.optype;
   for (const auto& pair : j_joint.attrs) {
     attrs.Set(pair.first, pair.second);
@@ -453,14 +453,14 @@ const std::pair<MSCJoint, size_t> MSCJointNode::ProducerAndIdxOf(const MSCTensor
   return ProducerAndIdxOf(input->name);
 }
 
-WeightJoint::WeightJoint(int index, const String& name, const String& 
master_name,
+WeightJoint::WeightJoint(int index, const String& name, const String& 
shared_ref,
                          const String& optype, const String& wtype,
                          const Map<String, String>& attrs, const MSCTensor& 
weight,
                          const Array<BaseJoint> parents, const 
Array<BaseJoint>& friends) {
   ObjectPtr<WeightJointNode> n = make_object<WeightJointNode>();
   n->index = index;
   n->name = std::move(name);
-  n->master_name = std::move(master_name);
+  n->shared_ref = std::move(shared_ref);
   n->optype = std::move(optype);
   n->wtype = std::move(wtype);
   n->attrs = std::move(attrs);
@@ -823,8 +823,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 
 #define MSC_NODE_BASE_HEAD(Stream, Joint)                                      
          \
   Stream << "ID_" << Joint->index << " " << Joint->name;                       
          \
-  if (Joint->master_name.size() > 0) {                                         
          \
-    Stream << "(M: " << Joint->master_name << ")";                             
          \
+  if (Joint->shared_ref.size() > 0) {                                          
          \
+    Stream << "(M: " << Joint->shared_ref << ")";                              
          \
   }                                                                            
          \
   Stream << " <PARENTS: ";                                                     
          \
   if (Joint->parents.size() > 0) {                                             
          \
@@ -946,7 +946,7 @@ TVM_REGISTER_GLOBAL("msc.core.MSCTensor")
     });
 
 TVM_REGISTER_GLOBAL("msc.core.MSCJoint")
-    .set_body_typed([](Integer index, const String& name, const String& 
master_name,
+    .set_body_typed([](Integer index, const String& name, const String& 
shared_ref,
                        const String& optype, const Map<String, String>& attrs,
                        const Array<String>& scope, const Array<MSCJoint>& 
parents,
                        const Array<Integer> out_indices, const 
Array<MSCTensor>& outputs,
@@ -955,12 +955,12 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint")
       for (size_t i = 0; i < parents.size(); i++) {
         inputs.push_back(std::make_pair(parents[i], out_indices[i]->value));
       }
-      return MSCJoint(index->value, name, master_name, optype, attrs, scope, 
inputs, outputs,
+      return MSCJoint(index->value, name, shared_ref, optype, attrs, scope, 
inputs, outputs,
                       weights);
     });
 
 TVM_REGISTER_GLOBAL("msc.core.WeightJoint")
-    .set_body_typed([](Integer index, const String& name, const String& 
master_name,
+    .set_body_typed([](Integer index, const String& name, const String& 
shared_ref,
                        const String& optype, const String& wtype, const 
Map<String, String>& attrs,
                        const MSCTensor& weight, const Array<WeightJoint> 
parents,
                        const Array<WeightJoint>& friends) -> WeightJoint {
@@ -971,7 +971,7 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJoint")
       for (const auto& f : friends) {
         b_friends.push_back(f);
       }
-      return WeightJoint(index->value, name, master_name, optype, wtype, 
attrs, weight, b_parents,
+      return WeightJoint(index->value, name, shared_ref, optype, wtype, attrs, 
weight, b_parents,
                          b_friends);
     });
 
diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h
index 981824e0a9..8179471d5a 100644
--- a/src/contrib/msc/core/ir/graph.h
+++ b/src/contrib/msc/core/ir/graph.h
@@ -91,7 +91,7 @@ struct JsonMSCTensor {
 struct JsonMSCJoint {
   size_t index;
   std::string name;
-  std::string master_name;
+  std::string shared_ref;
   std::string optype;
   std::vector<std::string> scope;
   std::vector<std::string> parents;
@@ -104,7 +104,7 @@ struct JsonMSCJoint {
     writer->BeginObject();
     writer->WriteObjectKeyValue("index", index);
     writer->WriteObjectKeyValue("name", name);
-    writer->WriteObjectKeyValue("master_name", master_name);
+    writer->WriteObjectKeyValue("shared_ref", shared_ref);
     writer->WriteObjectKeyValue("optype", optype);
     writer->WriteObjectKeyValue("parents", parents);
     writer->WriteObjectKeyValue("inputs", inputs);
@@ -125,8 +125,8 @@ struct JsonMSCJoint {
       } else if (key == "name") {
         reader->Read(&name);
         bitmask |= 2;
-      } else if (key == "master_name") {
-        reader->Read(&master_name);
+      } else if (key == "shared_ref") {
+        reader->Read(&shared_ref);
       } else if (key == "optype") {
         reader->Read(&optype);
         bitmask |= 4;
@@ -288,8 +288,8 @@ class BaseJointNode : public Object {
   mutable int index;
   /*! \brief The name of node. */
   String name;
-  /*! \brief The master_name of node, can be changed. */
-  String master_name;
+  /*! \brief The shared_ref of node, can be changed. */
+  String shared_ref;
   /*! \brief The op type of node. */
   String optype;
   /*! \brief The attributes of node. */
@@ -332,7 +332,7 @@ class BaseJointNode : public Object {
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("index", &index);
     v->Visit("name", &name);
-    v->Visit("master_name", &master_name);
+    v->Visit("shared_ref", &shared_ref);
     v->Visit("optype", &optype);
     v->Visit("attrs", &attrs);
     v->Visit("parents", &parents);
@@ -341,14 +341,14 @@ class BaseJointNode : public Object {
 
   bool SEqualReduce(const BaseJointNode* other, SEqualReducer equal) const {
     return equal(name, other->name) &&
-           equal(master_name, other->master_name) & equal(optype, 
other->optype) &&
+           equal(shared_ref, other->shared_ref) & equal(optype, other->optype) 
&&
            equal(attrs, other->attrs) && equal(parents, other->parents) &&
            equal(children, other->children);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(name);
-    hash_reduce(master_name);
+    hash_reduce(shared_ref);
     hash_reduce(optype);
     hash_reduce(attrs);
     hash_reduce(parents);
@@ -450,14 +450,14 @@ class MSCJoint : public BaseJoint {
    * \brief The constructor.
    * \param index The index of the node.
    * \param name The name of the node.
-   * \param master_name The master_name of the node.
+   * \param shared_ref The shared_ref of the node.
    * \param optype The op type the node.
    * \param attrs The attributes of the node.
    * \param inputs The inputs of the node.
    * \param outputs The outputs of the node.
    * \param weights The weights of the node.
    */
-  TVM_DLL MSCJoint(int index, const String& name, const String& master_name, 
const String& optype,
+  TVM_DLL MSCJoint(int index, const String& name, const String& shared_ref, 
const String& optype,
                    const Map<String, String>& attrs, const Array<String>& 
scope,
                    const std::vector<std::pair<BaseJoint, size_t>>& inputs,
                    const Array<MSCTensor>& outputs, const Map<String, 
MSCTensor>& weights);
@@ -531,7 +531,7 @@ class WeightJoint : public BaseJoint {
    * \brief The constructor.
    * \param index The index of the node.
    * \param name The name of the node.
-   * \param master_name The master_name of the node.
+   * \param shared_ref The shared_ref of the node.
    * \param optype The optype of the node.
    * \param wtype The weight type of the node.
    * \param attrs The attributes of the node.
@@ -539,8 +539,8 @@ class WeightJoint : public BaseJoint {
    * \param parents The parents of the node.
    * \param friends The friends of the node.
    */
-  TVM_DLL WeightJoint(int index, const String& name, const String& master_name,
-                      const String& optype, const String& wtype, const 
Map<String, String>& attrs,
+  TVM_DLL WeightJoint(int index, const String& name, const String& shared_ref, 
const String& optype,
+                      const String& wtype, const Map<String, String>& attrs,
                       const MSCTensor& weight, const Array<BaseJoint> parents,
                       const Array<BaseJoint>& friends);
 
diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc
index bca650a586..cadebcfcb0 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -103,7 +103,7 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) {
 const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const 
Optional<Expr>& binding_var,
                                           const String& name) {
   const auto& node_name = name.size() > 0 ? name : 
SpanUtils::GetAttr(expr->span, "name");
-  const auto& master_name = SpanUtils::GetAttr(expr->span, "master_name");
+  const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref");
   String optype;
   if (expr->IsInstance<relax::VarNode>()) {
     optype = "input";
@@ -256,7 +256,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional<Expr>
     LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" << 
sinfo;
   }
   // Build node
-  const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype, 
attrs, scope, inputs,
+  const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, 
attrs, scope, inputs,
                               outputs, node_weights);
   Array<String> output_names;
   for (size_t i = 0; i < outputs.size(); i++) {
@@ -419,7 +419,7 @@ MSCGraph RelayGraphBuilder::Build(const relay::Function& func) {
 
 MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) {
   const auto& node_name = name.size() > 0 ? name : 
SpanUtils::GetAttr(expr->span, "name");
-  const auto& master_name = SpanUtils::GetAttr(expr->span, "master_name");
+  const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref");
   String optype;
   if (expr->IsInstance<relay::VarNode>()) {
     optype = "input";
@@ -570,7 +570,7 @@ MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) {
   }
 
   // Build node
-  const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype, 
attrs, scope, inputs,
+  const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, 
attrs, scope, inputs,
                               outputs, node_weights);
   Array<String> output_names;
   for (size_t i = 0; i < outputs.size(); i++) {
diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc
index b35ac821d9..25c6499c82 100644
--- a/src/contrib/msc/core/transform/set_expr_name.cc
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -130,7 +130,7 @@ class RelaxExprNameSetter : public ExprVisitor {
     if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
       val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
     }
-    // set constant consumer && master_name
+    // set constant consumer && shared_ref
     Array<String> input_types;
     try {
       input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true);
@@ -145,7 +145,7 @@ class RelaxExprNameSetter : public ExprVisitor {
       if (const auto* c_node = val->args[i].as<ConstantNode>()) {
         const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
         if (constant_consumers_.count(const_name)) {
-          val->span = SpanUtils::SetAttr(val->span, "master_name", 
constant_consumers_[const_name]);
+          val->span = SpanUtils::SetAttr(val->span, "shared_ref", 
constant_consumers_[const_name]);
         } else {
           constant_consumers_.Set(const_name, unique_name);
         }
@@ -272,7 +272,7 @@ class RelayExprNameSetter : public ExprVisitor {
     if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
       op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
     }
-    // set constant consumer && master_name
+    // set constant consumer && shared_ref
     Array<String> input_types;
     try {
       input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false);
@@ -287,7 +287,7 @@ class RelayExprNameSetter : public ExprVisitor {
       if (const auto* c_node = op->args[i].as<ConstantNode>()) {
         const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
         if (constant_consumers_.count(const_name)) {
-          op->span = SpanUtils::SetAttr(op->span, "master_name", 
constant_consumers_[const_name]);
+          op->span = SpanUtils::SetAttr(op->span, "shared_ref", 
constant_consumers_[const_name]);
         } else {
           constant_consumers_.Set(const_name, unique_name);
         }

Reply via email to