comaniac commented on a change in pull request #5915:
URL: https://github.com/apache/incubator-tvm/pull/5915#discussion_r454026302



##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/arm_compute_lib/codegen_acl.cc
+ * \brief Implementation of the Relay -> ACL JSON serializer.
+ */
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/type.h>
+
+#include "../../utils.h"
+#include "codegen_acl.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace arm_compute_lib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const CallNode* 
cn) {
+  Expr expr = GetRef<Expr>(cn);
+  std::string name;
+  std::shared_ptr<JSONGraphNode> json_node;
+
+  if (cn->op.as<OpNode>()) {
+    json_node = CreateOp(cn);
+  } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+    auto comp = fn->GetAttr<String>(attr::kComposite);
+    CHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports 
composite functions.";
+    name = comp.value();
+    if (name == "arm_compute_lib.conv2d") {
+      json_node = CreateCompositeConvolution(cn);
+    } else {
+      LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
+    }
+  } else {
+    LOG(FATAL) << "Arm Compute Library JSON runtime does not support calls to "
+               << cn->op->GetTypeKey();
+  }
+
+  return AddNode(json_node, GetRef<Expr>(cn));
+}
+
+std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const 
ConstantNode* cn) {
+  this->constants_.push_back(cn->data);
+  return JSONSerializer::VisitExpr_(cn);
+}
+
+std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateOp(const CallNode* cn) 
{
+  const auto* op = cn->op.as<OpNode>();
+  CHECK(op);
+  const std::string name = op->name;
+  // Collect inputs
+  std::vector<JSONGraphNodeEntry> inputs;
+  for (const auto& arg : cn->args) {
+    auto res = VisitExpr(arg);
+    inputs.insert(inputs.end(), res.begin(), res.end());
+  }
+  // Create JSON op
+  auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+  SetCallNodeAttribute(json_node, cn);
+  return json_node;
+}
+
+std::shared_ptr<JSONGraphNode> 
ACLJSONSerializer::CreateCompositeConvolution(const CallNode* cn) {

Review comment:
       Ditto. s/CreateCompositeConvolution/CreateCompositeConvJSONNode/ ?

##########
File path: python/tvm/relay/op/contrib/arm_compute_lib.py
##########
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""ACL library supported operators."""
+import tvm
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
+from ...dataflow_pattern import wildcard, is_op, is_constant
+from .register import register_pattern_table
+
+
+def is_arm_compute_runtime_present():

Review comment:
       `is_arm_compute_runtime_enabled` seems better to me.

##########
File path: python/tvm/relay/op/contrib/arm_compute_lib.py
##########
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""ACL library supported operators."""
+import tvm
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
+from ...dataflow_pattern import wildcard, is_op, is_constant
+from .register import register_pattern_table
+
+
+def is_arm_compute_runtime_present():
+    """Check if the ACL graph runtime is present.
+
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    return tvm.get_global_func("relay.op.is_arm_compute_runtime_enabled", True)
+
+
+def partition_for_arm_compute_lib(mod, params=None):
+    """Partition the graph greedily offloading supported
+    operators to Arm Compute Library.
+
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : dict[str, NDArray]

Review comment:
       The right type for `params` should be `Optional[Dict[str, NDArray]]` 
(capital D for Dict) as it might be `None`.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_utils.cc
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/arm_compute_lib/acl_utils.cc
+ * \brief Utils and common functions for the interface.
+ */
+
+#include "acl_utils.h"
+
+#include <arm_compute/runtime/OffsetLifetimeManager.h>
+#include <arm_compute/runtime/PoolManager.h>
+#include <tvm/runtime/data_type.h>
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+namespace arm_compute_lib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+
+void CheckACLError(const arm_compute::Status& status) {
+  CHECK(status.error_code() == arm_compute::ErrorCode::OK) << "ACL: " << 
status.error_description();
+}
+
+arm_compute::Tensor MakeTensor(const JSONGraphNode& tensor_rep, void* data) {
+  CHECK(tensor_rep.GetOpType() == "input" || tensor_rep.GetOpType() == 
"const");
+  arm_compute::Tensor tensor;
+  arm_compute::TensorInfo info = MakeTensorInfo(tensor_rep.GetOpShape()[0]);
+  tensor.allocator()->init(info);
+  if (data != nullptr) {
+    CheckACLError(tensor.allocator()->import_memory(data));
+  }
+  return tensor;
+}
+
+arm_compute::Tensor MakeOutputTensor(const std::vector<int64_t>& shape) {
+  arm_compute::Tensor tensor;
+  tensor.allocator()->init(MakeTensorInfo(shape));
+  return tensor;
+}
+
+arm_compute::TensorInfo MakeTensorInfo(const std::vector<int64_t>& shape) {
+  arm_compute::TensorShape acl_shape = MakeTensorShape(shape);
+  return arm_compute::TensorInfo(acl_shape, 1, arm_compute::DataType::F32,
+                                 arm_compute::DataLayout::NHWC);
+}
+
+arm_compute::TensorShape MakeTensorShape(const std::vector<int64_t>& shape) {
+  arm_compute::TensorShape acl_shape;
+  for (unsigned int i = shape.size(); i > 0; --i) {
+    acl_shape.set(shape.size() - i, shape[i - 1]);
+  }
+  return acl_shape;
+}
+
+std::shared_ptr<arm_compute::MemoryManagerOnDemand> MakeMemoryManager() {
+  auto lifetime_mgr = std::make_shared<arm_compute::OffsetLifetimeManager>();
+  auto pool_mgr = std::make_shared<arm_compute::PoolManager>();
+  return std::make_shared<arm_compute::MemoryManagerOnDemand>(lifetime_mgr, 
pool_mgr);
+}
+
+arm_compute::PadStrideInfo ToACLPadStride(const std::vector<std::string>& pad,
+                                          const std::vector<std::string>& 
stride) {
+  int pad_0, pad_1, pad_2, pad_3;
+
+  size_t size = pad.size();
+  if (size == 1) {
+    int pad_v = std::stoi(pad[0]);
+    pad_0 = pad_v;
+    pad_1 = pad_v;
+    pad_2 = pad_v;
+    pad_3 = pad_v;
+  } else if (size == 2) {
+    // TVM: height, width -> ACL: left, right, top, bottom
+    int pad_h = std::stoi(pad[0]);
+    int pad_w = std::stoi(pad[1]);
+    pad_0 = pad_w;
+    pad_1 = pad_w;
+    pad_2 = pad_h;
+    pad_3 = pad_h;
+  } else if (size == 4) {
+    // TVM: top, left, bottom, right -> ACL: left, right, top, bottom
+    pad_0 = std::stoi(pad[1]);
+    pad_1 = std::stoi(pad[3]);
+    pad_2 = std::stoi(pad[0]);
+    pad_3 = std::stoi(pad[2]);
+  } else {
+    LOG(FATAL) << "Unsupported padding dimensions";
+    return arm_compute::PadStrideInfo();

Review comment:
       `LOG(FATAL)` will throw an exception and crash the execution, so you 
don't need this return.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -0,0 +1,399 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/arm_compute_lib/acl_runtime.cc
+ * \brief A simple JSON runtime for Arm Compute Library.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../file_util.h"
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+#include <arm_compute/core/Types.h>
+#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEPoolingLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEReshapeLayer.h>
+
+#include "acl_allocator.h"
+#include "acl_utils.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+using namespace arm_compute_lib;
+
+/*!
+ * \brief ACL objects we cache in order to avoid needing to construct
+ * a new layer each time.
+ */
+struct CachedLayer {
+  std::shared_ptr<arm_compute::IFunction> function;
+  std::vector<arm_compute::Tensor> inputs;
+  std::vector<arm_compute::Tensor> const_inputs;
+  std::vector<arm_compute::Tensor> outputs;
+};
+#endif
+
+class ACLRuntime : public JSONRuntimeBase {
+ public:
+  /*!
+   * \brief The ACL runtime module. Deserialize the provided functions
+   * on creation and store in the layer cache.
+   *
+   * \param symbol_name The name of the function.
+   * \param graph_json serialized JSON representation of a sub-graph.
+   * \param const_names The names of each constant in the sub-graph.
+   * \params consts An array of constants pre-transposed to the correct layout 
expected by ACL.
+   */
+  explicit ACLRuntime(const std::string& symbol_name, const std::string& 
graph_json,
+                      const Array<String>& const_names, const Array<NDArray>& 
consts)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {
+    this->constants_ = consts;
+  }
+
+  /*!
+   * \brief Get a packed function.
+   *
+   * \param name The name/symbol of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The packed function.
+   */
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) override {
+    if (name == "get_symbol") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->symbol_name_; });
+    } else if (name == "get_const_vars") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->const_names_; });
+    } else if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK(this->initialized_) << "The module has not been initialized";
+
+        // Bind argument tensors to data entries.
+        this->SetInputOutputBuffers(args);
+        // Execute the subgraph.
+        this->Run();
+      });
+    } else if ("__init_" + this->symbol_name_ == name) {
+      // The function to initialize constant tensors.
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        this->Init();
+        this->initialized_ = true;
+        *rv = 0;
+      });
+    } else {
+      return PackedFunc(nullptr);
+    }
+  }
+
+  /*!
+   * \brief Save a compiled network to a binary stream, which can then be
+   * serialized to disk.
+   *
+   * \param stream The stream to save the binary.
+   */
+  void SaveToBinary(dmlc::Stream* stream) override {
+    // Save the symbol
+    stream->Write(symbol_name_);
+    // Save the graph
+    stream->Write(graph_json_);
+    // Save the required const names
+    std::vector<std::string> const_names;
+    for (const auto& it : const_names_) {
+      const_names.push_back(it);
+    }
+    stream->Write(const_names);
+    // Save the required constant data
+    stream->Write(constants_.size());
+    for (const auto& it : constants_) {
+      it.Save(stream);
+    }
+  }
+
+  /*!
+   * \brief Load a compiled network from stream.
+   *
+   * \param strm The binary stream to load.
+   * \return The created ACL module.
+   */
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string symbol;
+    std::string graph_json;
+    std::vector<std::string> consts;
+    // Load the symbol
+    CHECK(stream->Read(&symbol)) << "Loading symbol name failed";
+    CHECK(stream->Read(&graph_json)) << "Loading graph json failed";
+    CHECK(stream->Read(&consts)) << "Loading the const name list failed";
+    Array<String> const_names;
+    for (const auto& it : consts) {
+      const_names.push_back(it);
+    }
+    size_t const_data_count;
+    CHECK(stream->Read(&const_data_count));
+    Array<NDArray> const_data;
+    for (size_t i = 0; i < const_data_count; ++i) {
+      runtime::NDArray temp;
+      CHECK(temp.Load(stream)) << "Failed to load constant";
+      const_data.push_back(temp);
+    }
+    auto n = make_object<ACLRuntime>(symbol, graph_json, const_names, 
const_data);
+    return Module(n);
+  }
+
+  /*!
+   * \brief The type key of the module.
+   *
+   * \return module type key.
+   */
+  const char* type_key() const override { return "arm_compute_lib"; }
+
+  /*!
+   * \brief Initialize runtime. Create ACL layer from JSON
+   * representation.
+   */
+  void Init() {
+    CHECK_EQ(this->constants_.size(), const_idx_.size())
+        << "The number of input constants must match the number expected.";
+    this->SetupConstants(this->constants_);
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+    BuildEngine();
+#endif
+  }
+
+  // Do not accept constants from MetadataModule as they should be transposed
+  // by the ACL codegen so they have the correct expected layout.
+  void Init(const Array<NDArray>& constants) override { LOG(FATAL) << "Not 
implemented."; }
+
+  /*!
+   * \brief Unpack inputs and outputs and run inference on a given layer.
+   *
+   * \param args Access inputs and outputs.
+   * \param function The layer to execute inference on.
+   * \return Status of inference.
+   */
+  void Run() override {
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      auto nid = input_nodes_[i];
+      uint32_t eid = EntryID(nid, 0);
+      if (nodes_[nid].GetOpType() == "input") {
+        void* data = data_entry_[eid]->data;
+        CheckACLError(layer_.inputs[i].allocator()->import_memory(data));
+      }
+    }
+
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      uint32_t eid = EntryID(outputs_[i]);
+      void* data = data_entry_[eid]->data;
+      CheckACLError(layer_.outputs[i].allocator()->import_memory(data));
+    }
+
+    this->layer_.function->run();
+#else
+    LOG(FATAL) << "Cannot call run on Arm Compute Library module without 
runtime enabled. "
+               << "Please build with USE_ACL_GRAPH_RUNTIME.";
+#endif
+  }
+
+  /*!
+   * \brief Get the JSON generated by codegen.
+   *
+   * \param format the format to return (only JSON for the time being)
+   * \return A string of JSON.
+   */
+  std::string GetSource(const std::string& format) override {
+    if (format == "json") {
+      return graph_json_;
+    }
+    LOG(FATAL) << "Format not supported by Arm Compute Library runtime.";
+    return "";
+  }
+
+ private:
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+  /*!
+   * \brief Build ACL layer from JSON representation and cache.
+   *
+   * \note For the time being only one layer or operator is supported
+   * per engine.
+   */
+  void BuildEngine() {
+    std::shared_ptr<arm_compute::MemoryManagerOnDemand> mm = 
MakeMemoryManager();
+    int num_pools = 0;
+
+    for (size_t i = 0; i < input_nodes_.size(); ++i) {
+      uint32_t nid = input_nodes_[i];
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() == "input") {
+        layer_.inputs.push_back(MakeTensor(node));
+      } else if (node.GetOpType() == "const") {
+        uint32_t eid = EntryID(nid, 0);
+        void* data = data_entry_[eid]->data;
+        layer_.const_inputs.push_back(MakeTensor(node, data));
+      }
+    }
+
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if (node.GetOpType() == "kernel") {
+        auto op_name = node.GetOpName();
+        if ("nn.conv2d" == op_name || "arm_compute_lib.conv2d" == op_name) {
+          CreateConvolution2DLayer(&layer_, node, mm);
+          num_pools++;
+        } else if ("nn.max_pool2d" == op_name) {
+          CreatePoolingLayer(&layer_, node);
+        } else if ("reshape" == op_name) {
+          CreateReshapeLayer(&layer_, node);
+        } else {
+          LOG(FATAL) << "Unsupported op: " << op_name;
+        }
+        // Only expect one op for the time being
+        break;

Review comment:
       I didn't catch this comment. Do you mean the current ACL runtime only 
supports one kernel node? If so, you should error out if there is more than one 
kernel node; otherwise the functionality is incorrect.

##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen_acl.h
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/arm_compute_lib/codegen_acl.h
+ * \brief The Relay -> ACL JSON schema compiler.
+ */
+
+#ifndef TVM_RELAY_BACKEND_CONTRIB_ARM_COMPUTE_LIB_CODEGEN_ACL_H_
+#define TVM_RELAY_BACKEND_CONTRIB_ARM_COMPUTE_LIB_CODEGEN_ACL_H_
+
+#include <tvm/relay/expr_functor.h>
+
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "../../../../runtime/contrib/json/json_node.h"
+#include "../codegen_json/codegen_json.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace arm_compute_lib {
+
+/*!
+ * \brief Generates an ACLModule from a relay expression. This "compilation"
+ * does not require ACL since the actual conversion using ACL APIs is
+ * deferred until creation of the runtime. This step simply serializes the
+ * relay program into a JSON string.
+ */
+class ACLJSONSerializer : public backend::contrib::JSONSerializer {
+  using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+  using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+ public:
+  ACLJSONSerializer(const std::string& symbol, const Expr& expr) : 
JSONSerializer(symbol, expr) {}
+
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const CallNode* cn) override;
+  std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* cn) override;
+
+  /*!
+   * \brief Get the constant data transposed when pre-processing the
+   * input function.
+   *
+   * \return An array of constants
+   */
+  Array<runtime::NDArray> GetParamsData();
+
+ private:
+  /*!
+   * \brief Create a JSON representation of an operator.
+   *
+   * \param call The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateOp(const CallNode* cn);
+  std::shared_ptr<JSONGraphNode> CreateCompositeConvolution(const CallNode* 
cn);
+
+  /* \brief Transposed constant tensors to serialize. Arm Compute Library 
expects constant tensors
+   * in OHWI format. */
+  Array<runtime::NDArray> constants_;
+};
+
+/*!
+ * \brief Pre-process a module containing functions ready for ACL codegen.
+ *
+ * For now we enforce OHWI kernel layout and fold the transforms away.
+ *
+ * \param mod The module to be pre-processed.
+ * \return The processed module.
+ */
+IRModule PreProcessModule(const IRModule& mod);
+
+/*!
+ * \brief Create a runtime module for ACL.
+ *
+ * This consists of a series of "serialized functions" which each represent a
+ * sub-graph to be computed by ACL and will each be executed independently from
+ * one another. Each function consists of serialized JSON describing the 
sub-graph
+ * and serialized constant tensors.
+ *
+ * \note The ACL runtime module only currently supports a single operator per
+ * sub-graph currently.
+ *
+ * \param ref The ext_func Relay expression/module to be executed using extern 
ops.
+ * \return A runtime module.
+ */
+runtime::Module ACLCompiler(const ObjectRef& ref);
+
+/*!
+ * \brief Get the external symbol of the Relay function name.
+ *
+ * \param func The provided function.
+ *
+ * \return An external symbol.
+ */
+std::string GetExtSymbol(const Function& func) {

Review comment:
       This is implemented in `CSourceModuleCodegenBase` already. Maybe we can 
move this function out of `CSourceModuleCodegenBase` so that it can be used by 
all backends under contrib.
   
   cc @zhiics 

##########
File path: python/tvm/relay/op/contrib/arm_compute_lib.py
##########
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument
+"""ACL library supported operators."""
+import tvm
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
+from ...dataflow_pattern import wildcard, is_op, is_constant
+from .register import register_pattern_table
+
+
+def is_arm_compute_runtime_present():
+    """Check if the ACL graph runtime is present.
+
+    Returns
+    -------
+    ret: bool
+        True if present, False if not.
+    """
+    return tvm.get_global_func("relay.op.is_arm_compute_runtime_enabled", True)
+
+
+def partition_for_arm_compute_lib(mod, params=None):
+    """Partition the graph greedily offloading supported
+    operators to Arm Compute Library.
+
+    Parameters
+    ----------
+    mod : Module
+        The module to run passes on.
+    params : dict[str, NDArray]
+        Constant input parameters.
+
+    Returns
+    -------
+    ret : annotated and partitioned module.
+    """
+    if params:
+        mod['main'] = bind_params_by_name(mod['main'], params)
+
+    seq = tvm.transform.Sequential([transform.MergeComposite(pattern_table()),
+                                    
transform.AnnotateTarget('arm_compute_lib'),
+                                    transform.PartitionGraph()])
+
+    return seq(mod)
+
+
+@register_pattern_table("arm_compute_lib")
+def pattern_table():

Review comment:
       It's better to improve the naming of this function, as it is called 
directly in this file in addition to being registered in the global pattern 
table. Something like `arm_compute_lib_pattern_table` or 
`get_arm_compute_lib_pattern_table` would be better to me.

##########
File path: src/runtime/contrib/arm_compute_lib/acl_runtime.cc
##########
@@ -0,0 +1,399 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/arm_compute_lib/acl_runtime.cc
+ * \brief A simple JSON runtime for Arm Compute Library.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include "../../file_util.h"
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+#include <arm_compute/core/Types.h>
+#include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEPoolingLayer.h>
+#include <arm_compute/runtime/NEON/functions/NEReshapeLayer.h>
+
+#include "acl_allocator.h"
+#include "acl_utils.h"
+#endif
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime::json;
+
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+using namespace arm_compute_lib;
+
+/*!
+ * \brief ACL objects we cache in order to avoid needing to construct
+ * a new layer each time.
+ */
+struct CachedLayer {
+  std::shared_ptr<arm_compute::IFunction> function;
+  std::vector<arm_compute::Tensor> inputs;
+  std::vector<arm_compute::Tensor> const_inputs;
+  std::vector<arm_compute::Tensor> outputs;
+};
+#endif
+
+class ACLRuntime : public JSONRuntimeBase {
+ public:
+  /*!
+   * \brief The ACL runtime module. Deserialize the provided functions
+   * on creation and store in the layer cache.
+   *
+   * \param symbol_name The name of the function.
+   * \param graph_json serialized JSON representation of a sub-graph.
+   * \param const_names The names of each constant in the sub-graph.
+   * \params consts An array of constants pre-transposed to the correct layout 
expected by ACL.
+   */
+  explicit ACLRuntime(const std::string& symbol_name, const std::string& 
graph_json,
+                      const Array<String>& const_names, const Array<NDArray>& 
consts)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {
+    this->constants_ = consts;
+  }
+
+  /*!
+   * \brief Get a packed function.
+   *
+   * \param name The name/symbol of the function.
+   * \param sptr_to_self The pointer to the module node.
+   * \return The packed function.
+   */
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) override {
+    if (name == "get_symbol") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->symbol_name_; });
+    } else if (name == "get_const_vars") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->const_names_; });
+    } else if (this->symbol_name_ == name) {
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        CHECK(this->initialized_) << "The module has not been initialized";
+
+        // Bind argument tensors to data entries.
+        this->SetInputOutputBuffers(args);
+        // Execute the subgraph.
+        this->Run();
+      });
+    } else if ("__init_" + this->symbol_name_ == name) {
+      // The function to initialize constant tensors.
+      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+        this->Init();
+        this->initialized_ = true;
+        *rv = 0;
+      });
+    } else {
+      return PackedFunc(nullptr);
+    }
+  }
+
+  /*!
+   * \brief Save a compiled network to a binary stream, which can then be
+   * serialized to disk.
+   *
+   * \param stream The stream to save the binary.
+   */
+  void SaveToBinary(dmlc::Stream* stream) override {
+    // Save the symbol
+    stream->Write(symbol_name_);
+    // Save the graph
+    stream->Write(graph_json_);
+    // Save the required const names
+    std::vector<std::string> const_names;
+    for (const auto& it : const_names_) {
+      const_names.push_back(it);
+    }
+    stream->Write(const_names);
+    // Save the required constant data
+    stream->Write(constants_.size());
+    for (const auto& it : constants_) {
+      it.Save(stream);
+    }
+  }
+
+  /*!
+   * \brief Load a compiled network from stream.
+   *
+   * \param strm The binary stream to load.
+   * \return The created ACL module.
+   */
+  static Module LoadFromBinary(void* strm) {
+    dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
+    std::string symbol;
+    std::string graph_json;
+    std::vector<std::string> consts;
+    // Load the symbol
+    CHECK(stream->Read(&symbol)) << "Loading symbol name failed";
+    CHECK(stream->Read(&graph_json)) << "Loading graph json failed";
+    CHECK(stream->Read(&consts)) << "Loading the const name list failed";
+    Array<String> const_names;
+    for (const auto& it : consts) {
+      const_names.push_back(it);
+    }
+    size_t const_data_count;
+    CHECK(stream->Read(&const_data_count));
+    Array<NDArray> const_data;
+    for (size_t i = 0; i < const_data_count; ++i) {
+      runtime::NDArray temp;
+      CHECK(temp.Load(stream)) << "Failed to load constant";
+      const_data.push_back(temp);
+    }
+    auto n = make_object<ACLRuntime>(symbol, graph_json, const_names, 
const_data);
+    return Module(n);
+  }
+
+  /*!
+   * \brief The type key of the module.
+   *
+   * \return module type key.
+   */
+  const char* type_key() const override { return "arm_compute_lib"; }
+
+  /*!
+   * \brief Initialize runtime. Create ACL layer from JSON
+   * representation.
+   */
+  void Init() {
+    CHECK_EQ(this->constants_.size(), const_idx_.size())
+        << "The number of input constants must match the number expected.";
+    this->SetupConstants(this->constants_);
+#ifdef TVM_GRAPH_RUNTIME_ARM_COMPUTE_LIB
+    BuildEngine();
+#endif
+  }
+
+  // Do not accept constants from MetadataModule as they should be transposed
+  // by the ACL codegen so they have the correct expected layout.
+  void Init(const Array<NDArray>& constants) override { LOG(FATAL) << "Not 
implemented."; }

Review comment:
       It seems to me that even if the constants are transposed by the 
preprocess sequence, you can still leave them to MetadataModule? If that's the 
case, you can reuse `Init(const Array<NDArray>& constants)`, and you don't have 
to save/load constants from binary.

##########
File path: src/relay/backend/contrib/arm_compute_lib/codegen.cc
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/backend/contrib/arm_compute_lib/codegen_acl.cc
+ * \brief Implementation of the Relay -> ACL JSON serializer.
+ */
+#include <tvm/ir/module.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/type.h>
+
+#include "../../utils.h"
+#include "codegen_acl.h"
+
+namespace tvm {
+namespace relay {
+namespace contrib {
+namespace arm_compute_lib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphNodeEntry = tvm::runtime::json::JSONGraphNodeEntry;
+
+std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const CallNode* 
cn) {
+  Expr expr = GetRef<Expr>(cn);
+  std::string name;
+  std::shared_ptr<JSONGraphNode> json_node;
+
+  if (cn->op.as<OpNode>()) {
+    json_node = CreateOp(cn);
+  } else if (const auto* fn = cn->op.as<FunctionNode>()) {
+    auto comp = fn->GetAttr<String>(attr::kComposite);
+    CHECK(comp.defined()) << "Arm Compute Library JSON runtime only supports 
composite functions.";
+    name = comp.value();
+    if (name == "arm_compute_lib.conv2d") {
+      json_node = CreateCompositeConvolution(cn);
+    } else {
+      LOG(FATAL) << "Unrecognized Arm Compute Library pattern: " << name;
+    }
+  } else {
+    LOG(FATAL) << "Arm Compute Library JSON runtime does not support calls to "
+               << cn->op->GetTypeKey();
+  }
+
+  return AddNode(json_node, GetRef<Expr>(cn));
+}
+
+std::vector<JSONGraphNodeEntry> ACLJSONSerializer::VisitExpr_(const 
ConstantNode* cn) {
+  this->constants_.push_back(cn->data);
+  return JSONSerializer::VisitExpr_(cn);
+}
+
+std::shared_ptr<JSONGraphNode> ACLJSONSerializer::CreateOp(const CallNode* cn) 
{

Review comment:
       Would s/CreateOp/CreateOpJSONNode/ be better?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to