mbrookhart commented on a change in pull request #6337:
URL: https://github.com/apache/incubator-tvm/pull/6337#discussion_r478661229



##########
File path: src/relay/analysis/context_analysis.cc
##########
@@ -0,0 +1,697 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/context_analysis.cc
+ * \brief A pass for analyzing device attribute of each IR node.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relay {
+
+using PackedAnalysisResultMap = Map<Expr, Array<Integer>>;
+using AnalysisResultMap =
+    std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, 
runtime::ObjectPtrEqual>;
+
+namespace analysis {
+
+// Cache ops
+static const Op& device_copy_op = Op::Get("device_copy");
+static const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+static const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+static const Op& shape_of_op = Op::Get("vm.shape_of");
+static const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+static const Op& shape_func_of = Op::Get("vm.shape_func");
+static const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/*
+ * \brief A class to represent the device of a domain, i.e. a segment of relay 
program.
+ */
+class DeviceDomain {
+ public:
+  // Construct an empty domain.
+  DeviceDomain() {
+    ctx_.device_type = static_cast<DLDeviceType>(-1);
+    ctx_.device_id = -1;
+  }
+
+  // Construct a domain based on a given context.
+  explicit DeviceDomain(const TVMContext& ctx) : ctx_(ctx) {}
+
+  // Check if the current domain is empty.
+  bool IsEmptyDomain() const {
+    return static_cast<int>(ctx_.device_type) == -1 && ctx_.device_id == -1;
+  }
+
+  // Check if the current domain equals the other one.
+  bool operator==(const DeviceDomain& other) const {
+    return ctx_.device_type == other.ctx_.device_type && ctx_.device_id == 
other.ctx_.device_id;
+  }
+
+  bool operator!=(const DeviceDomain& other) const { return !(*this == other); 
}
+
+ private:
+  // Create a hash for a domain.
+  struct Hash {
+    size_t operator()(const DeviceDomainPtr& domain) const {
+      if (domain->IsEmptyDomain()) {
+        return (size_t)(domain.get());
+      } else {
+        size_t const 
h1(std::hash<int>()(static_cast<int>(domain->ctx_.device_type)));
+        size_t const h2(std::hash<int>()(domain->ctx_.device_id));
+        return h1 ^ (h2 << 1);
+      }
+    }
+  };
+
+  // Create an equality for domains.
+  struct Equal {
+   public:
+    bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) 
const {
+      // We compare the pointer for empty domains.
+      if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) return lhs.get() == 
rhs.get();
+
+      // Otherwise device type and id are used to check equality.
+      return (*lhs.get() == *rhs.get());
+    }
+  };
+
+  /* \brief The device to be assigned to the current domain. */
+  TVMContext ctx_;
+
+  friend DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const 
DeviceDomainPtr& rhs);
+  friend class ContextAnalyzer;
+};
+
+// Join two domains.
+DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+  if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) {
+    return lhs;
+  } else if (lhs->IsEmptyDomain()) {
+    return rhs;
+  } else if (rhs->IsEmptyDomain()) {
+    return lhs;
+  } else {
+    CHECK(*lhs.get() == *rhs.get()) << "All expressions must have a singular 
device to unify";
+    return lhs;
+  }
+}
+
+/*
+ * \brief Compute on which device each sub-expression will execute. A union 
find
+ * algorithm is used to assign and merge the context domains.
+ */
+class ContextAnalyzer : public ExprVisitor {

Review comment:
       Does this visitor need to be recursive? If not, can we use the 
MixedModeVisitor instead, to avoid stack overflows caused by deep recursion?

##########
File path: python/tvm/relay/backend/vm.py
##########
@@ -261,12 +260,6 @@ def _make_executor(self, expr=None):
 
         def _vm_wrapper(*args, **kwargs):
             args = self._convert_args(main, args, kwargs)
-            ret_type = self.mod["main"].checked_type.ret_type
-            if is_dynamic(ret_type) and "llvm" not in str(self.target) and 
"arm" not in str(
-                    self.target):
-                raise ValueError(
-                    "Virtual Machine only supports dynamic graphs on CPU, got 
output type",
-                    ret_type, "on target", self.target)

Review comment:
       :smile: 

##########
File path: python/tvm/relay/transform/memory_alloc.py
##########
@@ -66,7 +85,7 @@ def is_reshape_only(func):
 class ManifestAllocPass(ExprMutator):

Review comment:
       Definitely out of scope for this PR, but in the longer term this class 
should probably be implemented in C++, especially if we want portability of the VM.

##########
File path: python/tvm/relay/transform/memory_alloc.py
##########
@@ -75,8 +94,22 @@ def __init__(self, target_host):
         self.target_host = target_host
         self.default_context = cpu(0)
         self.compute_dtype = "int64"
+        self.context_analysis = context_analysis
         super().__init__()
 
+    def get_context(self, exp):
+        """Get the context of a given expression"""
+        assert exp in self.context_analysis, exp.astext(False)
+        val = self.context_analysis[exp]

Review comment:
       Nitpick: I don't love this name, given that the other 
context_analysis is a function, not a map. How about self.context_analysis_map?

##########
File path: src/relay/analysis/context_analysis.cc
##########
@@ -0,0 +1,697 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/analysis/context_analysis.cc
+ * \brief A pass for analyzing device attribute of each IR node.
+ */
+
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/attrs/transform.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/op_attr_types.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/object.h>
+
+namespace tvm {
+namespace relay {
+
+using PackedAnalysisResultMap = Map<Expr, Array<Integer>>;
+using AnalysisResultMap =
+    std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, 
runtime::ObjectPtrEqual>;
+
+namespace analysis {
+
+// Cache ops
+static const Op& device_copy_op = Op::Get("device_copy");
+static const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+static const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+static const Op& shape_of_op = Op::Get("vm.shape_of");
+static const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+static const Op& shape_func_of = Op::Get("vm.shape_func");
+static const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/*
+ * \brief A class to represent the device of a domain, i.e. a segment of relay 
program.
+ */
+class DeviceDomain {
+ public:
+  // Construct an empty domain.
+  DeviceDomain() {
+    ctx_.device_type = static_cast<DLDeviceType>(-1);
+    ctx_.device_id = -1;
+  }
+
+  // Construct a domain based on a given context.
+  explicit DeviceDomain(const TVMContext& ctx) : ctx_(ctx) {}
+
+  // Check if the current domain is empty.
+  bool IsEmptyDomain() const {
+    return static_cast<int>(ctx_.device_type) == -1 && ctx_.device_id == -1;
+  }
+
+  // Check if the current domain equals the other one.
+  bool operator==(const DeviceDomain& other) const {
+    return ctx_.device_type == other.ctx_.device_type && ctx_.device_id == 
other.ctx_.device_id;
+  }
+
+  bool operator!=(const DeviceDomain& other) const { return !(*this == other); 
}
+
+ private:
+  // Create a hash for a domain.
+  struct Hash {
+    size_t operator()(const DeviceDomainPtr& domain) const {
+      if (domain->IsEmptyDomain()) {
+        return (size_t)(domain.get());
+      } else {
+        size_t const 
h1(std::hash<int>()(static_cast<int>(domain->ctx_.device_type)));
+        size_t const h2(std::hash<int>()(domain->ctx_.device_id));
+        return h1 ^ (h2 << 1);
+      }
+    }
+  };
+
+  // Create an equality for domains.
+  struct Equal {
+   public:
+    bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) 
const {
+      // We compare the pointer for empty domains.
+      if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) return lhs.get() == 
rhs.get();
+
+      // Otherwise device type and id are used to check equality.
+      return (*lhs.get() == *rhs.get());
+    }
+  };
+
+  /* \brief The device to be assigned to the current domain. */
+  TVMContext ctx_;
+
+  friend DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const 
DeviceDomainPtr& rhs);
+  friend class ContextAnalyzer;
+};
+
+// Join two domains.
+DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+  if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) {
+    return lhs;
+  } else if (lhs->IsEmptyDomain()) {
+    return rhs;
+  } else if (rhs->IsEmptyDomain()) {
+    return lhs;
+  } else {
+    CHECK(*lhs.get() == *rhs.get()) << "All expressions must have a singular 
device to unify";
+    return lhs;
+  }
+}
+
+/*
+ * \brief Compute on which device each sub-expression will execute. A union 
find
+ * algorithm is used to assign and merge the context domains.
+ */

Review comment:
       This is a complicated class. Could we get a more detailed doc string to 
explain what it's doing?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to