comaniac commented on a change in pull request #5919:
URL: https://github.com/apache/incubator-tvm/pull/5919#discussion_r445653492



##########
File path: src/runtime/contrib/json/json_runtime.h
##########
@@ -0,0 +1,267 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/json/json_runtime.h
+ * \brief Utilities for json runtime.
+ */
+
+#ifndef TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
+#define TVM_RUNTIME_CONTRIB_JSON_JSON_RUNTIME_H_
+
+#include <tvm/runtime/container.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/ndarray.h>
+
+#include <cstddef>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "json_node.h"
+
+namespace tvm {
+namespace runtime {
+namespace json {
+
+/*!
+ * \brief A json runtime that executes a graph serialized in JSON format. This
+ * runtime can be extended by user-defined runtimes that implement execution.
+ */
+class JSONRuntimeBase : public ModuleNode {
+ public:
+  /*!
+   * \brief Construct the runtime and load the serialized graph.
+   * \param symbol_name The symbol this module answers to; GetFunction
+   *        dispatches execution when asked for this name.
+   * \param graph_json The graph serialized as a JSON string.
+   * \param const_names The names of the constants required by this module.
+   */
+  JSONRuntimeBase(const std::string& symbol_name,
+                  const std::string& graph_json,
+                  const Array<String>& const_names)
+      : symbol_name_(symbol_name),
+        graph_json_(graph_json),
+        const_names_(const_names) {
+    // Load from the stored member copy of the JSON, not the parameter.
+    LoadGraph(graph_json_);
+  }
+
+  /*! \brief The type key of this module; always the literal "json". */
+  const char* type_key() const {
+    return "json";
+  }
+
+  /*! \brief Initialize a specific json runtime. */
+  virtual void Init(const Array<NDArray>& consts) = 0;
+
+  /*! \brief Invoke the execution engine to interpret a specific json runtime. */
+  virtual void Run() = 0;
+
+  /*!
+   * \brief Look up one of the packed functions exposed by this module.
+   *
+   * Names are matched in this order:
+   *  - "get_symbol": returns the symbol name.
+   *  - "get_const_vars": returns the required constant names.
+   *  - the symbol name itself: binds the argument tensors and calls Run();
+   *    fails a CHECK if the module has not been initialized.
+   *  - "__init_" + symbol name: calls Init() with the single argument,
+   *    marks the module initialized and returns 0.
+   * Any other name yields a null PackedFunc.
+   *
+   * \param name The name/symbol of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The packed function.
+   */
+  virtual PackedFunc GetFunction(const std::string& name,
+                                 const ObjectPtr<Object>& sptr_to_self) {
+    // Each closure captures sptr_to_self alongside `this`.
+    if (name == "get_symbol") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->symbol_name_;
+      });
+    }
+
+    if (name == "get_const_vars") {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        *rv = this->const_names_;
+      });
+    }
+
+    if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK(this->initialized_) << "The module has not been initialized";
+
+        // Bind argument tensors to data entries, then execute the subgraph.
+        this->SetInputOutputBuffers(args);
+        this->Run();
+      });
+    }
+
+    if ("__init_" + this->symbol_name_ == name) {
+      // The function to initialize constant tensors.
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK_EQ(args.size(), 1U);
+        this->Init(args[0]);
+        this->initialized_ = true;
+        *rv = 0;
+      });
+    }
+
+    // Unknown name: a null function.
+    return PackedFunc(nullptr);
+  }
+
+  /*!
+   * \brief Serialize this module to a binary stream.
+   *
+   * Writes, in order: the symbol name, the graph JSON, and the list of
+   * required constant names (as a vector of std::string).
+   *
+   * \param stream The stream to write to.
+   */
+  virtual void SaveToBinary(dmlc::Stream* stream) {
+    // Save the symbol
+    stream->Write(symbol_name_);
+    // Save the graph
+    stream->Write(graph_json_);
+    // Save the required const names, copied into a std::string vector.
+    std::vector<std::string> consts;
+    // The final size is known up front; avoid repeated reallocation.
+    consts.reserve(const_names_.size());
+    for (const auto& const_name : const_names_) {
+      consts.push_back(const_name);
+    }
+    stream->Write(consts);
+  }
+
+  template <typename T,
+            typename = typename 
std::enable_if<std::is_base_of<JSONRuntimeBase, T>::value>::type>
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string symbol;
+    std::string graph_json;
+    std::vector<std::string> consts;
+    // Load the symbol
+    CHECK(stream->Read(&symbol)) << "Loading symbol name failed";

Review comment:
       We had a discussion with @tqchen, and our conclusion was that the
overhead is negligible. Looking into the implementation, the module
initialization process is required for each subgraph, and the potential
overhead of copying shared tensors from one metadata module to many customized
modules has been avoided by zero-copy.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to