This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6b20caee2d [Bugfix] [Relay] Insertion of "device_copy" CallNode to
Resolve Device Conflict on Unconstrained Nodes (#15090)
6b20caee2d is described below
commit 6b20caee2d4098222f7c05a894c148b09e1df911
Author: lecoan <[email protected]>
AuthorDate: Thu Jun 22 08:33:16 2023 +0800
[Bugfix] [Relay] Insertion of "device_copy" CallNode to Resolve Device
Conflict on Unconstrained Nodes (#15090)
* Fix: add a new subpass in PlanDevices to add device_copy ops for
conflicting inputs
* Fix some spelling errors in comments
* Fix some spelling errors in comments
---
src/relay/transforms/device_planner.cc | 227 ++++++++++++++++++++++++++-
tests/python/relay/test_pass_plan_devices.py | 47 ++++++
2 files changed, 268 insertions(+), 6 deletions(-)
diff --git a/src/relay/transforms/device_planner.cc
b/src/relay/transforms/device_planner.cc
index c9050c730d..80ae66ea9e 100644
--- a/src/relay/transforms/device_planner.cc
+++ b/src/relay/transforms/device_planner.cc
@@ -60,7 +60,7 @@
* 'result_virtual_device' function attributes we introduce below. This is
so the pass is
* idempotent and can be re-run to flow additional memory scope constraints.
*
- * We proceed in four phases:
+ * We proceed in five phases:
*
* Phase 0
* -------
@@ -77,6 +77,13 @@
*
* Phase 1
* -------
+ * We process the program and find nodes with conflicting virtual devices. If the virtual
+ * devices (\p d1 and \p d2) are joinable, the node is assigned their joined device \p d. If they
+ * are unjoinable, a "device_copy" CallNode is inserted to copy the node's output to the second
+ * device.
+ *
+ * Phase 2
+ * -------
* We flow constraints from the "on_device" and "device_copy" calls, PrimFunc
buffer memory scopes,
* and some special ops, to all other Relay sub-expressions.
*
@@ -109,7 +116,7 @@
* devices from their original Relay Function representations. However we know
all calls to those
* functions are device-consistent, thus no information is lost.
*
- * Phase 2
+ * Phase 3
* -------
* After flowing constraints we apply some defaulting heuristics (using a
global default \p
* VirtualDevice) to fix the device for any as-yet unconstrained
sub-expressions.
@@ -121,7 +128,7 @@
* This requires a formal notion of 'choicepoint' inside the compiler which
can integrate with
* automation.
*
- * Phase 3
+ * Phase 4
* -------
* Finally, the result of this analysis is reified into the result as:
* - Additional "param_virtual_devices" (an \p Array<VirtualDevice>) and
"result_virtual_device"
@@ -404,6 +411,201 @@ class RewriteOnDevices : public ExprMutator {
/* =============== Phase 1 =============== */
+/*!
+ * \brief Add "device_copy" calls for nodes that have conflicting virtual
devices.
+ *
+ * E.g. suppose an IRModule contains the following expression:
+ * \code
+ * %0 = add(%a, %b);
+ * %1 = on_device(%0, virtual_device=d1);
+ * %2 = add(%b, %c);
+ * %3 = on_device(%2, virtual_device=d2);
+ * \endcode
+ * In the above example, node %b has two possible virtual devices: \p d1 and
\p d2.
+ *
+ * - If \p d1 and \p d2 are joinable, replace \p d1 and \p d2 with the joined
device \p d:
+ * \code
+ * %0 = add(%a, %b);
+ * %1 = on_device(%0, virtual_device=d);
+ * %2 = add(%b, %c);
+ * %3 = on_device(%2, virtual_device=d);
+ * \endcode
+ *
+ * - If \p d1 and \p d2 are unjoinable, insert a "device_copy" CallNode to copy \p %b to \p d2:
+ * \code
+ * %0 = add(%a, %b);
+ * %1 = on_device(%0, virtual_device=d1);
+ * %2 = device_copy(%b, src_virtual_device=d1, dst_virtual_device=d2);
+ * %3 = add(%2, %c);
+ * %4 = on_device(%3, virtual_device=d2);
+ * \endcode
+ */
+struct DeviceContext {
+  /*!
+   * \brief Returns the virtual device recorded for \p expr. When nothing has been recorded yet,
+   * a fully unconstrained virtual device is recorded for \p expr and returned.
+   */
+  VirtualDevice VirtualDeviceFor(const ExprNode* expr) {
+    if (auto found = expr_to_device.find(expr); found != expr_to_device.end()) {
+      return found->second;
+    }
+    VirtualDevice unconstrained = VirtualDevice::FullyUnconstrained();
+    expr_to_device.emplace(expr, unconstrained);
+    return unconstrained;
+  }
+
+  /*!
+   * \brief Records \p dev for \p expr, joining it with any device already recorded for \p expr.
+   * \return false if the already-recorded device and \p dev cannot be joined (the recorded
+   * device is then left unchanged); true otherwise.
+   */
+  bool Update(const ExprNode* expr, VirtualDevice dev) {
+    auto [slot, inserted] = expr_to_device.emplace(expr, dev);
+    if (inserted) {
+      return true;
+    }
+    auto joined = VirtualDevice::Join(slot->second, dev);
+    if (joined == nullptr) {
+      return false;
+    }
+    slot->second = joined.value();
+    return true;
+  }
+
+  /*! \brief Returns true if \p expr has been recorded as conflicted. */
+  bool IsConflicted(const ExprNode* expr) { return conflicted_nodes.count(expr) > 0; }
+
+  /*! \brief Nodes for which two implied virtual devices could not be joined. */
+  std::unordered_set<const ExprNode*> conflicted_nodes;
+  /*! \brief The (possibly joined) virtual device recorded for each visited node. */
+  std::unordered_map<const ExprNode*, VirtualDevice> expr_to_device;
+};
+
+/*!
+ * \brief Flows the device constraints over the module and finds all the conflicted nodes, ie
+ * call arguments whose recorded virtual device cannot be joined with the device implied by a
+ * calling expression. The conflicted nodes only contain nodes that have no explicit
+ * constraints: "on_device" calls are removed from the result since "DeviceCapturer" will insert
+ * "device_copy" for them.
+ */
+class ConflictedNodeFinder : ExprVisitor {
+ public:
+  explicit ConflictedNodeFinder(IRModule mod)
+      : mod_(std::move(mod)), dev_ctx_(std::make_unique<DeviceContext>()) {}
+
+  /*!
+   * \brief Visits every optimizable function in the module and returns the collected context.
+   * Ownership of the context is moved out, so this must be called at most once.
+   */
+  std::unique_ptr<DeviceContext> Finder() {
+    VLOG_CONTEXT << "ConflictedNodeFinder";
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+        VisitExpr(GetRef<Function>(function_node));
+      }
+    }
+    // "DeviceCapturer" will insert "device_copy" for "on_device" calls. Therefore, "on_device"
+    // should not be considered as conflicted. Erase via the iterator returned by erase():
+    // erasing the element a range-for loop is positioned on invalidates its iterator.
+    auto& conflicted = dev_ctx_->conflicted_nodes;
+    for (auto itr = conflicted.begin(); itr != conflicted.end();) {
+      const ExprNode* node = *itr;
+      if (node->IsInstance<CallNode>() &&
+          static_cast<const CallNode*>(node)->op == OnDeviceOp()) {
+        itr = conflicted.erase(itr);
+      } else {
+        ++itr;
+      }
+    }
+    return std::move(dev_ctx_);
+  }
+
+ private:
+  /*!
+   * \brief Records the device implied for \p call_node and for each of its arguments, marks as
+   * conflicted every argument whose recorded device cannot be joined with the implied one, and
+   * then visits the arguments.
+   */
+  void VisitExpr_(const CallNode* call_node) final {
+    VLOG(2) << "Initial call node: " << std::endl << PrettyPrint(GetRef<Call>(call_node));
+    VirtualDevice call_dev = dev_ctx_->VirtualDeviceFor(call_node);
+    // Device implied for the arguments: the call's own device, except for "on_device" with
+    // constrain_body and for "device_copy" (whose arguments live on the source device).
+    VirtualDevice body_dev = call_dev;
+
+    if (call_node->op == OnDeviceOp()) {
+      auto on_dev_props = GetOnDeviceProps(call_node);
+      if (on_dev_props.constrain_body) {
+        body_dev = on_dev_props.virtual_device;
+      }
+      if (on_dev_props.constrain_result) {
+        call_dev = on_dev_props.virtual_device;
+      }
+    } else if (call_node->op == DeviceCopyOp()) {
+      auto dev_cp_props = GetDeviceCopyProps(call_node);
+      body_dev = dev_cp_props.src_virtual_device;
+      call_dev = dev_cp_props.dst_virtual_device;
+    }
+
+    // A mismatch on an "on_device" call itself is tolerated; any other call must agree.
+    if (!dev_ctx_->Update(call_node, call_dev) && call_node->op != OnDeviceOp()) {
+      LOG(FATAL) << "Mismatched device type after iterating args. Implied device: " << std::endl
+                 << PrettyPrint(call_dev) << std::endl
+                 << "and practical device:" << std::endl
+                 << PrettyPrint(dev_ctx_->VirtualDeviceFor(call_node)) << std::endl
+                 << "With CallNode: " << std::endl
+                 << PrettyPrint(GetRef<Call>(call_node));
+    }
+
+    // Every argument is recorded first; only then are the arguments visited.
+    for (const Expr& arg : call_node->args) {
+      VLOG(3) << "Handle call node arg: " << std::endl << PrettyPrint(arg);
+      if (!dev_ctx_->Update(arg.get(), body_dev)) {
+        VLOG(2) << "Conflicted node found:" << std::endl
+                << PrettyPrint(arg) << std::endl
+                << "With corresponding Callee:" << std::endl
+                << PrettyPrint(GetRef<Call>(call_node));
+        dev_ctx_->conflicted_nodes.emplace(arg.get());
+      }
+    }
+    for (const Expr& arg : call_node->args) {
+      VisitExpr(arg);
+    }
+  }
+
+  /*! \brief The module being analysed. */
+  IRModule mod_;
+  /*! \brief The collected device context; moved out by Finder(). */
+  std::unique_ptr<DeviceContext> dev_ctx_;
+};
+
+/*!
+ * \brief Insert "device_copy" CallNode for all the conflicted nodes found by \p
+ * ConflictedNodeFinder.
+ */
+class ConflictedNodeRewriter : ExprMutator {
+ public:
+  ConflictedNodeRewriter(IRModule mod, CompilationConfig config,
+                         std::unique_ptr<DeviceContext> dev_ctx)
+      : mod_(std::move(mod)), config_(std::move(config)), dev_ctx_(std::move(dev_ctx)) {}
+
+  /*!
+   * \brief Returns a new module in which every optimizable function has been rewritten; all
+   * other functions are copied over unchanged.
+   */
+  IRModule Rewrite() {
+    VLOG_CONTEXT << "ConflictedNodeRewriter";
+    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map,
+                    mod_->attrs);
+    for (const auto& kv : mod_->functions) {
+      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
+        auto func = Mutate(GetRef<Function>(function_node));
+        result->Add(kv.first, Downcast<Function>(func));
+      } else {
+        result->Add(kv.first, kv.second);
+      }
+    }
+
+    return result;
+  }
+
+ private:
+  /*!
+   * \brief Rewrites \p call_node, wrapping every conflicted argument in a "device_copy" from
+   * the argument's recorded device to the call's recorded device.
+   */
+  Expr VisitExpr_(const CallNode* call_node) final {
+    VLOG(3) << "Initial call node:" << std::endl << PrettyPrint(GetRef<Call>(call_node));
+    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
+    tvm::Array<Expr> call_args;
+    call_args.reserve(call_node->args.size());
+    for (size_t i = 0; i < call_node->args.size(); ++i) {
+      // The device context is keyed by the ORIGINAL argument nodes. The mutated argument is a
+      // different node whenever one of its own sub-expressions was rewritten, so look up the
+      // original and wrap the mutated one.
+      const ExprNode* orig_arg = call_node->args[i].get();
+      Expr new_arg = call->args[i];
+      if (dev_ctx_->IsConflicted(orig_arg)) {
+        auto src_dev = config_->CanonicalVirtualDevice(dev_ctx_->VirtualDeviceFor(orig_arg));
+        auto dst_dev = config_->CanonicalVirtualDevice(dev_ctx_->VirtualDeviceFor(call_node));
+        call_args.push_back(MaybeDeviceCopy(new_arg, src_dev, dst_dev));
+        VLOG(2) << "Adding DeviceCopy Op: " << std::endl << PrettyPrint(call_args.back());
+      } else {
+        call_args.push_back(new_arg);
+      }
+    }
+    // Build on the mutated call so rewrites of the callee are kept as well.
+    Call new_call = WithFields(call, call->op, call_args);
+    VLOG(3) << "Final call node:" << std::endl << PrettyPrint(new_call);
+    return new_call;
+  }
+
+  /*! \brief The module being rewritten. */
+  IRModule mod_;
+  /*! \brief Used to canonicalize the devices of inserted "device_copy" calls. */
+  CompilationConfig config_;
+  /*! \brief Conflicted nodes and recorded devices, as produced by ConflictedNodeFinder. */
+  std::unique_ptr<DeviceContext> dev_ctx_;
+};
+
+/* =============== Phase 2 =============== */
+
/*
* \brief Collects the system of device constraints for all sub-expressions in
a module.
* It is possible some devices remain free and will need to be defaulted by \p
DeviceDefaulter.
@@ -707,7 +909,7 @@ class DeviceAnalyzer : public MixedModeVisitor {
std::unique_ptr<DeviceDomains> domains_;
};
-/* =============== Phase 2 =============== */
+/* =============== Phase 3 =============== */
/*!
* \brief Calls to 'free' "on_device" annotations (ie where both
constrain_body=false and
@@ -865,7 +1067,7 @@ class DeviceDefaulter : public ExprVisitor {
std::unique_ptr<DeviceDomains> domains_;
};
-/* =============== Phase 3 =============== */
+/* =============== Phase 4 =============== */
/*!
* \brief Inserts missing "device_copy" CallNodes, and ensures the device type
of every
* sub-expression in a module can be easily recovered by a later
transformation using simple
@@ -1276,6 +1478,17 @@ tvm::transform::Pass Rewrite() {
return tvm::relay::transform::CreateFunctionPass(pass_func, 0,
"PlanDevicesRewrite", {});
}
+/*!
+ * \brief Finds the conflicted nodes of the module and inserts "device_copy" calls for them.
+ */
+tvm::transform::Pass Check(CompilationConfig config) {
+  auto pass_func = [config = std::move(config)](
+                       IRModule mod, tvm::transform::PassContext /*pass_ctx*/) -> IRModule {
+    std::unique_ptr<DeviceContext> dev_ctx = ConflictedNodeFinder(mod).Finder();
+    ConflictedNodeRewriter rewriter(mod, config, std::move(dev_ctx));
+    return rewriter.Rewrite();
+  };
+  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0,
+                                          "PlanDevicesCheckConflicts", {});
+}
+
+
/*! \brief Run the remaining phases. */
tvm::transform::Pass PlanDevicesCore(CompilationConfig config) {
return tvm::transform::CreateModulePass(
@@ -1308,7 +1521,9 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig
config) {
tvm::transform::Pass PlanDevices(CompilationConfig config) {
std::vector<Pass> passes;
passes.emplace_back(Rewrite());
- passes.emplace_back(PlanDevicesCore(std::move(config)));
+ passes.emplace_back(Check(config));  // Phase 1: "device_copy" for conflicted nodes.
+ passes.emplace_back(InferType());  // Presumably types the calls inserted by Check().
+ passes.emplace_back(PlanDevicesCore(config));  // The remaining phases; config is reused.
return tvm::transform::Sequential(passes, "PlanDevices");
}
diff --git a/tests/python/relay/test_pass_plan_devices.py
b/tests/python/relay/test_pass_plan_devices.py
index 3ff49389cb..937ece1f82 100644
--- a/tests/python/relay/test_pass_plan_devices.py
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -1830,5 +1830,52 @@ def test_primitive():
print(mod)
+def test_conflicated_inputs():
+ # add(%a, %b) is pinned to meta[VirtualDevice][0] (CPU) and add(%b, %c) to
+ # meta[VirtualDevice][1] (GPU), so the shared input %b is implied on both devices.
+ metatable = {"VirtualDevice": [CPU, GPU]}
+
+ def input():
+ return tvm.relay.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+ %c: Tensor[(5, 7), float32]) {
+ %0 = add(%a, %b);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
+ %2 = add(%b, %c);
+ %3 = on_device(%2, virtual_device=meta[VirtualDevice][1]);
+ subtract(%1, %3)
+ }
+ """,
+ "from_string",
+ None,
+ metatable,
+ )
+
+ # Expected: %b stays on CPU and is copied to GPU (%2) for add(%b, %c); %1 is copied to GPU
+ # (%3) so both subtract operands are on meta[VirtualDevice][1].
+ def expected():
+ return tvm.relay.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5,
7), float32],
+ %b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5,
7), float32],
+ %c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5,
7), float32]) {
+ %0 = add(%a, %b);
+ %1 = on_device(%0, virtual_device=meta[VirtualDevice][0],
constrain_result=True);
+ %2 = device_copy(%b,
src_virtual_device=meta[VirtualDevice][0],
dst_virtual_device=meta[VirtualDevice][1]);
+ %3 = device_copy(%1,
src_virtual_device=meta[VirtualDevice][0],
dst_virtual_device=meta[VirtualDevice][1]);
+ %4 = add(%2, %c);
+ subtract(%3, %4)
+ }
+ """,
+ "from_string",
+ None,
+ metatable,
+ )
+
+ # NumPy reference for the program's numerical result.
+ def ref(a, b, c):
+ return np.subtract(np.add(a, b), np.add(b, c))
+
+ exercise(input(), expected(), ref, rands((5, 7), 3))
+
+
if __name__ == "__main__":
tvm.testing.main()