This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8820ce48d1 [Unity][MSC][special] Change special names (#15691)
8820ce48d1 is described below
commit 8820ce48d13fbe983c8286d06a377c8e54863a3d
Author: Archermmt <[email protected]>
AuthorDate: Thu Sep 7 22:39:24 2023 +0800
[Unity][MSC][special] Change special names (#15691)
---
python/tvm/contrib/msc/core/ir/graph.py | 16 +++++++-------
src/contrib/msc/core/ir/graph.cc | 28 ++++++++++++-------------
src/contrib/msc/core/ir/graph.h | 28 ++++++++++++-------------
src/contrib/msc/core/ir/graph_builder.cc | 8 +++----
src/contrib/msc/core/transform/set_expr_name.cc | 8 +++----
5 files changed, 44 insertions(+), 44 deletions(-)
diff --git a/python/tvm/contrib/msc/core/ir/graph.py
b/python/tvm/contrib/msc/core/ir/graph.py
index 0f4d94290e..5475b83005 100644
--- a/python/tvm/contrib/msc/core/ir/graph.py
+++ b/python/tvm/contrib/msc/core/ir/graph.py
@@ -126,8 +126,8 @@ class MSCJoint(BaseJoint):
The index of the node.
name: string
The name of the node.
- master_name: string
- The master name of the node.
+ shared_ref: string
+ The shared reference of the node.
optype: string
The optype of the node.
attrs: dict<string, string>
@@ -144,7 +144,7 @@ class MSCJoint(BaseJoint):
self,
index: int,
name: str,
- master_name: str,
+ shared_ref: str,
optype: str,
attrs: Dict[str, str],
inputs: List[Tuple[BaseJoint, int]],
@@ -158,7 +158,7 @@ class MSCJoint(BaseJoint):
_ffi_api.MSCJoint,
index,
name,
- master_name,
+ shared_ref,
optype,
attrs,
parents,
@@ -289,8 +289,8 @@ class WeightJoint(BaseJoint):
The index of the node.
name: string
The name of the node.
- master_name: string
- The master name of the node.
+ shared_ref: string
+ The shared reference of the node.
optype: string
The optype of the node.
wtype: string
@@ -309,7 +309,7 @@ class WeightJoint(BaseJoint):
self,
index: int,
name: str,
- master_name: str,
+ shared_ref: str,
optype: str,
wtype: str,
attrs: Dict[str, str],
@@ -322,7 +322,7 @@ class WeightJoint(BaseJoint):
_ffi_api.WeightJoint,
index,
name,
- master_name,
+ shared_ref,
optype,
wtype,
attrs,
diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc
index 7595559a60..b07f2ba422 100644
--- a/src/contrib/msc/core/ir/graph.cc
+++ b/src/contrib/msc/core/ir/graph.cc
@@ -248,14 +248,14 @@ bool BaseJointNode::GetAttr(const String& key,
std::vector<float>* val) const {
return false;
}
-MSCJoint::MSCJoint(int index, const String& name, const String& master_name,
const String& optype,
+MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref,
const String& optype,
const Map<String, String>& attrs, const Array<String>&
scope,
const std::vector<std::pair<BaseJoint, size_t>>& inputs,
const Array<MSCTensor>& outputs, const Map<String,
MSCTensor>& weights) {
ObjectPtr<MSCJointNode> n = make_object<MSCJointNode>();
n->index = index;
n->name = std::move(name);
- n->master_name = std::move(master_name);
+ n->shared_ref = std::move(shared_ref);
n->optype = std::move(optype);
n->attrs = std::move(attrs);
n->scope = std::move(scope);
@@ -301,15 +301,15 @@ MSCJoint::MSCJoint(const std::string& json_str, const
Map<String, BaseJoint>& no
const MSCJoint MSCJoint::Clone(const MSCJoint& node,
const std::vector<std::pair<BaseJoint,
size_t>>& inputs) {
- return MSCJoint(node->index, node->name, node->master_name, node->optype,
node->attrs,
- node->scope, inputs, node->outputs, node->weights);
+ return MSCJoint(node->index, node->name, node->shared_ref, node->optype,
node->attrs, node->scope,
+ inputs, node->outputs, node->weights);
}
const JsonMSCJoint MSCJointNode::ToJson() const {
JsonMSCJoint j_joint;
j_joint.index = index;
j_joint.name = name;
- j_joint.master_name = master_name;
+ j_joint.shared_ref = shared_ref;
j_joint.optype = optype;
for (const auto& pair : attrs) {
j_joint.attrs[pair.first] = pair.second;
@@ -335,7 +335,7 @@ const JsonMSCJoint MSCJointNode::ToJson() const {
void MSCJointNode::FromJson(const JsonMSCJoint& j_joint, const Map<String,
BaseJoint>& nodes) {
index = j_joint.index;
name = j_joint.name;
- master_name = j_joint.master_name;
+ shared_ref = j_joint.shared_ref;
optype = j_joint.optype;
for (const auto& pair : j_joint.attrs) {
attrs.Set(pair.first, pair.second);
@@ -453,14 +453,14 @@ const std::pair<MSCJoint, size_t>
MSCJointNode::ProducerAndIdxOf(const MSCTensor
return ProducerAndIdxOf(input->name);
}
-WeightJoint::WeightJoint(int index, const String& name, const String&
master_name,
+WeightJoint::WeightJoint(int index, const String& name, const String&
shared_ref,
const String& optype, const String& wtype,
const Map<String, String>& attrs, const MSCTensor&
weight,
const Array<BaseJoint> parents, const
Array<BaseJoint>& friends) {
ObjectPtr<WeightJointNode> n = make_object<WeightJointNode>();
n->index = index;
n->name = std::move(name);
- n->master_name = std::move(master_name);
+ n->shared_ref = std::move(shared_ref);
n->optype = std::move(optype);
n->wtype = std::move(wtype);
n->attrs = std::move(attrs);
@@ -823,8 +823,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
#define MSC_NODE_BASE_HEAD(Stream, Joint)
\
Stream << "ID_" << Joint->index << " " << Joint->name;
\
- if (Joint->master_name.size() > 0) {
\
- Stream << "(M: " << Joint->master_name << ")";
\
+ if (Joint->shared_ref.size() > 0) {
\
+ Stream << "(M: " << Joint->shared_ref << ")";
\
}
\
Stream << " <PARENTS: ";
\
if (Joint->parents.size() > 0) {
\
@@ -946,7 +946,7 @@ TVM_REGISTER_GLOBAL("msc.core.MSCTensor")
});
TVM_REGISTER_GLOBAL("msc.core.MSCJoint")
- .set_body_typed([](Integer index, const String& name, const String&
master_name,
+ .set_body_typed([](Integer index, const String& name, const String&
shared_ref,
const String& optype, const Map<String, String>& attrs,
const Array<String>& scope, const Array<MSCJoint>&
parents,
const Array<Integer> out_indices, const
Array<MSCTensor>& outputs,
@@ -955,12 +955,12 @@ TVM_REGISTER_GLOBAL("msc.core.MSCJoint")
for (size_t i = 0; i < parents.size(); i++) {
inputs.push_back(std::make_pair(parents[i], out_indices[i]->value));
}
- return MSCJoint(index->value, name, master_name, optype, attrs, scope,
inputs, outputs,
+ return MSCJoint(index->value, name, shared_ref, optype, attrs, scope,
inputs, outputs,
weights);
});
TVM_REGISTER_GLOBAL("msc.core.WeightJoint")
- .set_body_typed([](Integer index, const String& name, const String&
master_name,
+ .set_body_typed([](Integer index, const String& name, const String&
shared_ref,
const String& optype, const String& wtype, const
Map<String, String>& attrs,
const MSCTensor& weight, const Array<WeightJoint>
parents,
const Array<WeightJoint>& friends) -> WeightJoint {
@@ -971,7 +971,7 @@ TVM_REGISTER_GLOBAL("msc.core.WeightJoint")
for (const auto& f : friends) {
b_friends.push_back(f);
}
- return WeightJoint(index->value, name, master_name, optype, wtype,
attrs, weight, b_parents,
+ return WeightJoint(index->value, name, shared_ref, optype, wtype, attrs,
weight, b_parents,
b_friends);
});
diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h
index 981824e0a9..8179471d5a 100644
--- a/src/contrib/msc/core/ir/graph.h
+++ b/src/contrib/msc/core/ir/graph.h
@@ -91,7 +91,7 @@ struct JsonMSCTensor {
struct JsonMSCJoint {
size_t index;
std::string name;
- std::string master_name;
+ std::string shared_ref;
std::string optype;
std::vector<std::string> scope;
std::vector<std::string> parents;
@@ -104,7 +104,7 @@ struct JsonMSCJoint {
writer->BeginObject();
writer->WriteObjectKeyValue("index", index);
writer->WriteObjectKeyValue("name", name);
- writer->WriteObjectKeyValue("master_name", master_name);
+ writer->WriteObjectKeyValue("shared_ref", shared_ref);
writer->WriteObjectKeyValue("optype", optype);
writer->WriteObjectKeyValue("parents", parents);
writer->WriteObjectKeyValue("inputs", inputs);
@@ -125,8 +125,8 @@ struct JsonMSCJoint {
} else if (key == "name") {
reader->Read(&name);
bitmask |= 2;
- } else if (key == "master_name") {
- reader->Read(&master_name);
+ } else if (key == "shared_ref") {
+ reader->Read(&shared_ref);
} else if (key == "optype") {
reader->Read(&optype);
bitmask |= 4;
@@ -288,8 +288,8 @@ class BaseJointNode : public Object {
mutable int index;
/*! \brief The name of node. */
String name;
- /*! \brief The master_name of node, can be changed. */
- String master_name;
+ /*! \brief The shared_ref of node, can be changed. */
+ String shared_ref;
/*! \brief The op type of node. */
String optype;
/*! \brief The attributes of node. */
@@ -332,7 +332,7 @@ class BaseJointNode : public Object {
void VisitAttrs(AttrVisitor* v) {
v->Visit("index", &index);
v->Visit("name", &name);
- v->Visit("master_name", &master_name);
+ v->Visit("shared_ref", &shared_ref);
v->Visit("optype", &optype);
v->Visit("attrs", &attrs);
v->Visit("parents", &parents);
@@ -341,14 +341,14 @@ class BaseJointNode : public Object {
bool SEqualReduce(const BaseJointNode* other, SEqualReducer equal) const {
return equal(name, other->name) &&
- equal(master_name, other->master_name) & equal(optype,
other->optype) &&
+ equal(shared_ref, other->shared_ref) & equal(optype, other->optype)
&&
equal(attrs, other->attrs) && equal(parents, other->parents) &&
equal(children, other->children);
}
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
- hash_reduce(master_name);
+ hash_reduce(shared_ref);
hash_reduce(optype);
hash_reduce(attrs);
hash_reduce(parents);
@@ -450,14 +450,14 @@ class MSCJoint : public BaseJoint {
* \brief The constructor.
* \param index The index of the node.
* \param name The name of the node.
- * \param master_name The master_name of the node.
+ * \param shared_ref The shared_ref of the node.
* \param optype The op type the node.
* \param attrs The attributes of the node.
* \param inputs The inputs of the node.
* \param outputs The outputs of the node.
* \param weights The weights of the node.
*/
- TVM_DLL MSCJoint(int index, const String& name, const String& master_name,
const String& optype,
+ TVM_DLL MSCJoint(int index, const String& name, const String& shared_ref,
const String& optype,
const Map<String, String>& attrs, const Array<String>&
scope,
const std::vector<std::pair<BaseJoint, size_t>>& inputs,
const Array<MSCTensor>& outputs, const Map<String,
MSCTensor>& weights);
@@ -531,7 +531,7 @@ class WeightJoint : public BaseJoint {
* \brief The constructor.
* \param index The index of the node.
* \param name The name of the node.
- * \param master_name The master_name of the node.
+ * \param shared_ref The shared_ref of the node.
* \param optype The optype of the node.
* \param wtype The weight type of the node.
* \param attrs The attributes of the node.
@@ -539,8 +539,8 @@ class WeightJoint : public BaseJoint {
* \param parents The parents of the node.
* \param friends The friends of the node.
*/
- TVM_DLL WeightJoint(int index, const String& name, const String& master_name,
- const String& optype, const String& wtype, const
Map<String, String>& attrs,
+ TVM_DLL WeightJoint(int index, const String& name, const String& shared_ref,
const String& optype,
+ const String& wtype, const Map<String, String>& attrs,
const MSCTensor& weight, const Array<BaseJoint> parents,
const Array<BaseJoint>& friends);
diff --git a/src/contrib/msc/core/ir/graph_builder.cc
b/src/contrib/msc/core/ir/graph_builder.cc
index bca650a586..cadebcfcb0 100644
--- a/src/contrib/msc/core/ir/graph_builder.cc
+++ b/src/contrib/msc/core/ir/graph_builder.cc
@@ -103,7 +103,7 @@ const MSCGraph RelaxGraphBuilder::Build(const
relax::Function& func) {
const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const
Optional<Expr>& binding_var,
const String& name) {
const auto& node_name = name.size() > 0 ? name :
SpanUtils::GetAttr(expr->span, "name");
- const auto& master_name = SpanUtils::GetAttr(expr->span, "master_name");
+ const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref");
String optype;
if (expr->IsInstance<relax::VarNode>()) {
optype = "input";
@@ -256,7 +256,7 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr,
const Optional<Expr>
LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" <<
sinfo;
}
// Build node
- const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype,
attrs, scope, inputs,
+ const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype,
attrs, scope, inputs,
outputs, node_weights);
Array<String> output_names;
for (size_t i = 0; i < outputs.size(); i++) {
@@ -419,7 +419,7 @@ MSCGraph RelayGraphBuilder::Build(const relay::Function&
func) {
MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) {
const auto& node_name = name.size() > 0 ? name :
SpanUtils::GetAttr(expr->span, "name");
- const auto& master_name = SpanUtils::GetAttr(expr->span, "master_name");
+ const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref");
String optype;
if (expr->IsInstance<relay::VarNode>()) {
optype = "input";
@@ -570,7 +570,7 @@ MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const
String& name) {
}
// Build node
- const auto& node = MSCJoint(nodes_.size(), node_name, master_name, optype,
attrs, scope, inputs,
+ const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype,
attrs, scope, inputs,
outputs, node_weights);
Array<String> output_names;
for (size_t i = 0; i < outputs.size(); i++) {
diff --git a/src/contrib/msc/core/transform/set_expr_name.cc
b/src/contrib/msc/core/transform/set_expr_name.cc
index b35ac821d9..25c6499c82 100644
--- a/src/contrib/msc/core/transform/set_expr_name.cc
+++ b/src/contrib/msc/core/transform/set_expr_name.cc
@@ -130,7 +130,7 @@ class RelaxExprNameSetter : public ExprVisitor {
if (unique_name != SpanUtils::GetAttr(val->span, "name")) {
val->span = SpanUtils::SetAttr(val->span, "name", unique_name);
}
- // set constant consumer && master_name
+ // set constant consumer && shared_ref
Array<String> input_types;
try {
input_types = ExprUtils::GetInputTypes(optype, val->args.size(), true);
@@ -145,7 +145,7 @@ class RelaxExprNameSetter : public ExprVisitor {
if (const auto* c_node = val->args[i].as<ConstantNode>()) {
const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
if (constant_consumers_.count(const_name)) {
- val->span = SpanUtils::SetAttr(val->span, "master_name",
constant_consumers_[const_name]);
+ val->span = SpanUtils::SetAttr(val->span, "shared_ref",
constant_consumers_[const_name]);
} else {
constant_consumers_.Set(const_name, unique_name);
}
@@ -272,7 +272,7 @@ class RelayExprNameSetter : public ExprVisitor {
if (unique_name != SpanUtils::GetAttr(op->span, "name")) {
op->span = SpanUtils::SetAttr(op->span, "name", unique_name);
}
- // set constant consumer && master_name
+ // set constant consumer && shared_ref
Array<String> input_types;
try {
input_types = ExprUtils::GetInputTypes(optype, op->args.size(), false);
@@ -287,7 +287,7 @@ class RelayExprNameSetter : public ExprVisitor {
if (const auto* c_node = op->args[i].as<ConstantNode>()) {
const String& const_name = SpanUtils::GetAttr(c_node->span, "name");
if (constant_consumers_.count(const_name)) {
- op->span = SpanUtils::SetAttr(op->span, "master_name",
constant_consumers_[const_name]);
+ op->span = SpanUtils::SetAttr(op->span, "shared_ref",
constant_consumers_[const_name]);
} else {
constant_consumers_.Set(const_name, unique_name);
}