This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 7aa2de3  [Hexagon] Initial support for Hexagon codegen (#6261)
7aa2de3 is described below

commit 7aa2de31157ef11b43c38ccba5bf9c9539406abf
Author: Krzysztof Parzyszek <kparz...@quicinc.com>
AuthorDate: Wed Aug 19 04:17:07 2020 -0500

    [Hexagon] Initial support for Hexagon codegen (#6261)
    
    * [Hexagon] Initial support for Hexagon codegen
    
    This commit does not support parallel execution or prefetch.
    LLVM 7 or later is required.
    
    * Set native_vector_bits_ based on target features
    
    * Initialize hvx_bytes
    
    * Remove commented out line
---
 python/tvm/contrib/hexagon.py                      | 211 ++++++
 python/tvm/target/target.py                        |  22 +-
 src/runtime/hexagon/hexagon_module.cc              |  22 +-
 src/runtime/module.cc                              |   2 +
 src/target/llvm/codegen_hexagon.cc                 | 812 +++++++++++++++++++++
 src/target/llvm/intrin_rule_hexagon.cc             |  65 ++
 src/target/target_kind.cc                          |   4 +
 .../python/unittest/test_target_codegen_hexagon.py |  95 +++
 8 files changed, 1225 insertions(+), 8 deletions(-)

diff --git a/python/tvm/contrib/hexagon.py b/python/tvm/contrib/hexagon.py
new file mode 100644
index 0000000..6870e2a
--- /dev/null
+++ b/python/tvm/contrib/hexagon.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+'''Utility for Hexagon backend'''
+
+import functools as ft
+import os
+import tvm
+import tvm.ir
+import tvm.contrib.cc as cc
+from .._ffi.registry import register_func
+
+
+# Linking Hexagon shared libraries.
+#
+#   link_shared(name-of-shared-library, list-of-objects, kw-args)
+#
+# To use a custom linker, define a function that returns the path to the
+# linker, and pass it to 'register_linker':
+#
+#   def custom_linker_path():
+#       return '/path/to/hexagon/linker'
+#
+#   register_linker(custom_linker_path)
+#
+# Subsequent calls to 'link_shared' will use the newly registered linker.
+
+hexagon_toolchain_root = os.environ.get('HEXAGON_TOOLCHAIN') or ''  # pylint: 
disable=invalid-name
+hexagon_link_master = os.path.join(                                 # pylint: 
disable=invalid-name
+    hexagon_toolchain_root, 'bin', 'hexagon-link')
+
+def register_linker(f):
+    """Register a function that will return the path to the Hexagon linker."""
+    return register_func('tvm.contrib.hexagon.hexagon_link', f, True)
+
+@register_func('tvm.contrib.hexagon.hexagon_link')
+def hexagon_link():
+    """Return path to the Hexagon linker."""
+    return hexagon_link_master
+
+@register_func('tvm.contrib.hexagon.link_shared')
+def link_shared(so_name, objs, **kwargs):
+    """Link shared library on Hexagon using the registered Hexagon linker.
+
+    Parameters
+    ----------
+    so_name : str
+        Name of the shared library file.
+    objs : list[str,StringImm]
+    kwargs : additional arguments:
+        'verbose' - print additional information
+
+    Returns
+    -------
+    ret_val : int
+        This function returns 0 at the moment.
+    """
+    # The list of object files can be passed as built-in Python strings,
+    # or as tvm.tir.StringImm's.
+    def to_str(s):
+        if isinstance(s, tvm.tir.StringImm):
+            return s.value
+        assert isinstance(s, str), 'argument "' + str(s) + '" should be a 
string or StrImm'
+        return s
+    objs = [to_str(s) for s in objs]
+
+    linker = tvm.get_global_func('tvm.contrib.hexagon.hexagon_link')()
+    if kwargs.get('verbose'):
+        print('tvm.contrib.hexagon.link_shared:')
+        print('  Using linker:', linker)
+        print('  Library name:', so_name)
+        print('  Object files:', objs)
+    if not os.access(linker, os.X_OK):
+        message = 'The linker "' + linker + '" does not exist or is not 
executable.'
+        if not os.environ.get('HEXAGON_TOOLCHAIN'):
+            message += ' The environment variable HEXAGON_TOOLCHAIN is unset. 
Please export ' + \
+                'HEXAGON_TOOLCHAIN in your environment, so that 
${HEXAGON_TOOLCHAIN}/bin/' + \
+                'hexagon-link exists.'
+        else:
+            message += ' Please verify the value of the HEXAGON_LINKER 
environment variable ' + \
+                '(currently set to "' + hexagon_toolchain_root + '").'
+        raise Exception(message)
+
+    libpath = os.path.join(
+        hexagon_toolchain_root, 'target', 'hexagon', 'lib', 'v66', 'G0')
+    cc.create_shared(
+        so_name, objs,
+        # pylint: disable=bad-whitespace
+        options = ['-Bdynamic', '-shared', '-export-dynamic',
+                   os.path.join(libpath, 'pic', 'libgcc.so')],
+        cc = linker)
+    return 0
+
+
+### VTCM
+
+vtcm_size = 4*1024*1024  # pylint: disable=invalid-name
+@register_func('tvm.info.mem.local.vtcm')
+def mem_info_vtcm():
+    # pylint: disable=bad-whitespace
+    return tvm.ir.make_node('MemoryInfo',
+                            unit_bits = 8,
+                            max_num_bits = vtcm_size*8,
+                            max_simd_bits = 128*8,
+                            head_address = tvm.runtime.const(100, 'uint32'))
+
+def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: 
disable=unused-argument
+    """Generic VTCM allocation
+
+    Parameters
+    ----------
+    get_alloc : function: tir.Allocate, int -> tir.expr (dtype='handle')
+      The VTCM allocation function. It takes an Allocate statement, and the 
required
+      alignment, and returns a pointer to the allocated VTCM buffer.
+    get_free : function: tir.expr (dtype='handle') -> None
+      The VTCM deallocation function. It takes the address of the allocated 
buffer
+      and frees it. It returns no value.
+    def_align : int
+      The default alignment that will be passed to the allocation function, if 
the
+      program does not specify the alignment via a 'storage_alignment' 
attribute.
+    func : tir.PrimFunc
+    mod : tvm.IRModule
+    ctx : transform.PassContext
+
+    Returns
+    -------
+    stmt : tvm.stmt
+        Transformed function body.
+    """
+
+    vtcm_buffers = []
+    alignments = {}
+
+    def buf_align(var):
+        """Determine the alignment of the buffer with variable 'var'."""
+        if var in alignments and alignments[var]:
+            return alignments[var][-1]
+        return def_align
+
+    def visit(stmt):
+        """Collect information about VTCM buffers and their alignments."""
+        if isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.attr_key == 'storage_scope' and stmt.value == 'local.vtcm':
+                vtcm_buffers.append(stmt.node)
+            elif stmt.attr_key == 'storage_alignment':
+                if not stmt.node in alignments:
+                    alignments[stmt.node] = []
+                alignments[stmt.node].append(stmt.value)
+
+    def mutate(stmt):
+        """Insert calls to VTCM allocation and deallocation routines."""
+        if isinstance(stmt, tvm.tir.AttrStmt):
+            if stmt.attr_key == 'storage_scope' and stmt.value == 'local.vtcm':
+                vtcm_buffers.pop()
+            elif stmt.attr_key == 'storage_alignment':
+                alignments[stmt.node].pop()
+            return stmt
+        if isinstance(stmt, tvm.tir.Allocate):
+            var = stmt.buffer_var
+            if var in vtcm_buffers:
+                is_null = tvm.tir.call_intrin('bool', 
tvm.ir.Op.get('tir.isnullptr'), var)
+                throw_error = \
+                    tvm.tir.call_intrin('int32', 
tvm.ir.Op.get('tir.tvm_throw_last_error'))
+                body_w_free = tvm.tir.SeqStmt([stmt.body, 
tvm.tir.Evaluate(get_free(var))])
+                body_w_check = \
+                    tvm.tir.IfThenElse(is_null, tvm.tir.Evaluate(throw_error), 
body_w_free)
+                return tvm.tir.LetStmt(stmt.buffer_var, get_alloc(stmt, 
buf_align(var)),
+                                       body_w_check)
+            return stmt
+        raise ValueError("Wrong argument type (" + type(stmt) + ") to 
'mutate'")
+
+    f = func.with_body(tvm.tir.stmt_functor.ir_transform(func.body, visit, 
mutate,
+                                                         ['tir.Allocate', 
'tir.AttrStmt']))
+    return f
+
+
+def ir_lower_vtcm():
+    """Create a VTCM lowering pass.
+
+    VTCM memory has to be allocated using special functions.
+    """
+    def get_alloc(stmt, align):
+        assert isinstance(stmt, tvm.tir.Allocate)
+        return tvm.tir.call_extern('handle', 'HexagonBackendAllocateVTCM',
+                                   ft.reduce(lambda x, y: x*y, stmt.extents, 
1), align)
+    def get_free(var):
+        return tvm.tir.call_extern('handle', 'HexagonBackendFreeVTCM', var)
+
+    # pylint: disable=bad-whitespace
+    @tvm.tir.transform.prim_func_pass(opt_level = 0, name = "Lower VTCM pass")
+    def transform(func, mod, ctx):
+        return lower_vtcm_(get_alloc, get_free, 2048, func, mod, ctx)
+
+    return transform
+
+def ir_lower_vtcm_pass():
+    return [(3, ir_lower_vtcm())]
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 1cde875..0ea19f8 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -236,7 +236,7 @@ def bifrost(model='unknown', options=None):
     return _ffi_api.TargetCreate("opencl", *opts)
 
 
-def hexagon(cpu_ver='v66', sim_args=None, hvx=128):
+def hexagon(cpu_ver='v66', sim_args=None, llvm_args=None, hvx=128):
     """Returns a Hexagon target.
 
     Parameters
@@ -249,6 +249,8 @@ def hexagon(cpu_ver='v66', sim_args=None, hvx=128):
         Otherwise, separate versions are used for codegen and sim. Not
         all allowed cpu strings will be valid, simulator will throw an
         error if invalid. Does not affect codegen.
+    llvm_args : str or list of str
+        User defined compiler arguments.
     hvx : int
         Size of hvx register. Value of 0 indicates disabled hvx.
     """
@@ -274,7 +276,7 @@ def hexagon(cpu_ver='v66', sim_args=None, hvx=128):
         # HVX enable
         if hvx:
             mattr = ' -mattr=+hvx' + cpu_ver + ',+hvx-length' + str(hvx) + 'b'
-        return 'llvm' + target + mcpu + mattr
+        return target + mcpu + mattr
 
     # Simulator string
     def create_sim(cpu_ver, sim_args):
@@ -325,12 +327,24 @@ def hexagon(cpu_ver='v66', sim_args=None, hvx=128):
 
         return sim_cpu + ' ' + validate_hvx_length(hvx, sim_args)
 
+    # LLVM string
+    def create_llvm(llvm_args):
+        # TVM's option parser doesn't allow '=' in values, but '=' can
+        # appear in LLVM flags. Replace it with '@', since it's unlikely
+        # that '@' will be used in another context.
+        if llvm_args is None or len(llvm_args.replace(' ', '')) == 0:
+            return ''
+        args = [s.replace('=', '@') for s in llvm_args.split()]
+        return '--llvm-options=' + ','.join(args)
+
     # Sim args
     os.environ['HEXAGON_SIM_ARGS'] = create_sim(cpu_ver, sim_args)
 
     target_str = create_target(cpu_ver)
-    args_list = target_str.split()
-    return _ffi_api.TargetCreate("hexagon", *args_list)
+    llvm_str = create_llvm(llvm_args)
+    args_list = target_str.split() + llvm_str.split()
+
+    return _ffi_api.TargetCreate('hexagon', *args_list)
 
 
 def create(target_str):
diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc
index 6b7ca1c..66e2a56 100644
--- a/src/runtime/hexagon/hexagon_module.cc
+++ b/src/runtime/hexagon/hexagon_module.cc
@@ -195,7 +195,8 @@ class HexagonModuleNode final : public runtime::ModuleNode {
                     std::unordered_map<std::string, FunctionInfo> fmap, 
std::string asm_str,
                     std::string obj_str, std::string ir_str, std::string 
bc_str,
                     const std::set<std::string>& packed_c_abi)
-      : hexagon_device_(hexagon::Device::Global()),
+      : hexagon_device_(),
+        dl_handle_(nullptr),
         data_(data),
         fmt_(fmt),
         fmap_(fmap),
@@ -203,9 +204,8 @@ class HexagonModuleNode final : public runtime::ModuleNode {
         obj_(obj_str),
         ir_(ir_str),
         bc_(bc_str),
-        packed_c_abi_funcs_(packed_c_abi) {
-    dl_handle_ = hexagon_device_->Load(data, fmt);
-  }
+        packed_c_abi_funcs_(packed_c_abi) {}
+
   ~HexagonModuleNode() {
     if (dl_handle_) {
       hexagon_device_->Unload(dl_handle_);
@@ -213,6 +213,7 @@ class HexagonModuleNode final : public runtime::ModuleNode {
   }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& 
sptr_to_self) final;
+  std::string GetSource(const std::string& format) final;
 
   const char* type_key() const final { return "hexagon"; }
 
@@ -333,6 +334,9 @@ PackedFunc HexagonModuleNode::GetFunction(const std::string& name,
   auto f = fmap_.find(name);
   if (f == fmap_.end()) return PackedFunc(nullptr);
 
+  if (!hexagon_device_) hexagon_device_ = hexagon::Device::Global();
+  if (!dl_handle_) dl_handle_ = hexagon_device_->Load(data_, fmt_);
+
   // Get function pointer from device.
   void* pf = hexagon_device_->Resolve(name);
   // The cast result and the original share ownership. Do the cast here
@@ -355,6 +359,16 @@ PackedFunc HexagonModuleNode::GetFunction(const std::string& name,
   }
 }
 
+// Return asm_ for the "s"/"asm" formats, ir_ for "ll", and an empty string
+// for any other format.
+std::string HexagonModuleNode::GetSource(const std::string& format) {
+  const bool want_asm = format == "s" || format == "asm";
+  if (want_asm) return asm_;
+  return format == "ll" ? ir_ : std::string();
+}
+
 void HexagonModuleNode::RemapArgs(const TVMArgs& args, std::vector<TVMValue>& 
values,
                                   std::vector<int>& type_codes,
                                   std::vector<void*>& remote_tensors) const {
diff --git a/src/runtime/module.cc b/src/runtime/module.cc
index 8052467..98b0b3a 100644
--- a/src/runtime/module.cc
+++ b/src/runtime/module.cc
@@ -138,6 +138,8 @@ bool RuntimeEnabled(const std::string& target) {
     f_name = "device_api.rpc";
   } else if (target == "micro_dev") {
     f_name = "device_api.micro_dev";
+  } else if (target == "hexagon") {
+    f_name = "device_api.hexagon";
   } else if (target.length() >= 5 && target.substr(0, 5) == "nvptx") {
     f_name = "device_api.gpu";
   } else if (target.length() >= 4 && target.substr(0, 4) == "rocm") {
diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc
new file mode 100644
index 0000000..eefd17c
--- /dev/null
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -0,0 +1,812 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#if defined(TVM_LLVM_VERSION) && TVM_LLVM_VERSION >= 70
+
+#include <llvm/Bitcode/BitcodeWriter.h>
+#if TVM_LLVM_VERSION <= 90
+#include <llvm/IR/Intrinsics.h>
+#else
+#include <llvm/IR/IntrinsicsHexagon.h>
+#endif
+#include <llvm/Support/CommandLine.h>
+#include <tvm/runtime/module.h>
+#include <tvm/target/codegen.h>
+#include <tvm/tir/analysis.h>
+
+#include <cstdio>
+#include <cstdlib>
+#include <map>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../../runtime/hexagon/hexagon_module.h"
+#include "../build_common.h"
+#include "codegen_llvm.h"
+
+namespace tvm {
+namespace codegen {
+
+// Return the kGlobalSymbol attribute of `f` as a std::string; fails a CHECK
+// when the attribute is missing.
+static std::string get_name(const PrimFunc& f) {
+  auto global_symbol = f->GetAttr<runtime::String>(tvm::attr::kGlobalSymbol);
+  CHECK(global_symbol.defined())
+      << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute";
+  std::string name(global_symbol.value());
+  return name;
+}
+
+// Hexagon code generation
+//
+// Extends CodeGenLLVM with Hexagon target setup (HVX vector width), overrides
+// for intrinsics, extern calls and assert statements, and its own packed-call
+// machinery (runtime struct types, module-context globals, MakeCallPacked).
+class CodeGenHexagon final : public CodeGenLLVM {
+ public:
+  // Derive native_vector_bits_ from the target's HVX length feature.
+  void InitTarget(llvm::TargetMachine* tm) final;
+  // Initialize the module plus the TVM runtime types and runtime function types.
+  void Init(const std::string& module_name, llvm::TargetMachine* tm, llvm::LLVMContext* ctx,
+            bool system_lib, bool dynamic_lookup, bool target_c_runtime) final;
+
+  void VisitStmt_(const AssertStmtNode* op) override;
+
+  llvm::Value* CreateIntrinsic(const CallNode* op) override;
+  llvm::Value* CreateCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
+                                bool skip_first_arg) override;
+  // Raw pointer to the module being generated; ownership stays with module_.
+  llvm::Module* GetModulePtr() const { return module_.get(); }
+
+ protected:
+  // meta data
+  // TBAA node attached to loads of module-context pointers (GetContextPtr).
+  llvm::MDNode* md_tbaa_ctx_ptr_{nullptr};
+  // LLVM function types of the runtime entry points used by generated code.
+  llvm::FunctionType* ftype_tvm_func_call_{nullptr};
+  llvm::FunctionType* ftype_tvm_get_func_from_env_{nullptr};
+  llvm::FunctionType* ftype_tvm_api_set_last_error_{nullptr};
+
+ private:
+  llvm::Value* CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index, int kind);
+
+  // Check if the call to packed function is successful
+  // if not directly finalize function and pass on return code.
+  // return the end block after the check
+  llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode);
+
+  // Get runtime functions
+  // Each returns the directly declared function if set, otherwise a load of
+  // the corresponding module-context function pointer.
+  llvm::Value* RuntimeTVMFuncCall();
+  llvm::Value* RuntimeTVMGetFuncFromEnv();
+  llvm::Value* RuntimeTVMAPISetLastError();
+
+  // Create the module-context global(s) used for runtime lookups.
+  void InitGlobalContext(bool dynamic_lookup);
+  // Create a zero-initialized, exported global of pointer type `type`.
+  llvm::GlobalVariable* InitContextPtr(llvm::Type* type, std::string name);
+  // Emit a load of the pointer stored in a context global.
+  llvm::Value* GetContextPtr(llvm::GlobalVariable* gv);
+  // (name, value) pairs to register with the runtime when building a system library.
+  std::vector<std::pair<std::string, llvm::Value*>> export_system_symbols_;
+  llvm::Value* GetPackedFuncHandle(const std::string& str);
+
+  // global to packed function handle
+  std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
+
+  // Make packed call.
+  llvm::BasicBlock* MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
+                                   llvm::Value** ret_tcode, const DataType& r_type,
+                                   const int64_t begin, const int64_t end);
+  // create call into tvm packed function.
+  llvm::Value* CreateCallPacked(const CallNode* op);
+  // Create trace call into tvm packed function.
+  llvm::Value* CreateCallTracePacked(const CallNode* op);
+
+  // Element type to use for each named stack allocation ("shape", "arg_value", ...).
+  std::map<std::string, llvm::Type*> types_for_alloca_;
+
+  // Type definitions.
+  llvm::Type* t_tvm_func_handle_{nullptr};
+  llvm::Type* t_tvm_value_{nullptr};
+  llvm::Type* t_tvm_shape_index_{nullptr};
+  llvm::Type* t_tvm_context_{nullptr};
+  llvm::Type* t_tvm_type_{nullptr};
+  llvm::Type* t_tvm_array_{nullptr};
+
+  // Context for injection lookup
+  llvm::GlobalVariable* gv_mod_ctx_{nullptr};
+  llvm::GlobalVariable* gv_tvm_func_call_{nullptr};
+  llvm::GlobalVariable* gv_tvm_get_func_from_env_{nullptr};
+  llvm::GlobalVariable* gv_tvm_api_set_last_error_{nullptr};
+  // Extern symbols called through a context global; a null value means the
+  // global has not been created yet (see CreateCallExtern).
+  std::unordered_map<std::string, llvm::GlobalVariable*> gv_func_map_;
+
+  // context for direct dynamic lookup
+  llvm::Function* f_tvm_func_call_{nullptr};
+  llvm::Function* f_tvm_get_func_from_env_{nullptr};
+  llvm::Function* f_tvm_api_set_last_error_{nullptr};
+  llvm::Function* f_tvm_register_system_symbol_{nullptr};
+};
+
+// Set native_vector_bits_ from the "+hvx-length{64|128}b" target feature when
+// it is present; otherwise keep 64-bit "scalar" vectors. Then defer to the
+// generic CodeGenLLVM::InitTarget.
+void CodeGenHexagon::InitTarget(llvm::TargetMachine* tm) {
+  native_vector_bits_ = 64;  // Assume "scalar" vectors at first.
+  llvm::StringRef fs = tm->getTargetFeatureString();
+  size_t npos = llvm::StringRef::npos;
+  // StringRef gives the prefix length via size(), avoiding std::strlen, whose
+  // header <cstring> is not included by this file.
+  const llvm::StringRef hvx_length_feature = "+hvx-length";  // +hvx-length{64|128}b
+  size_t len_begin = fs.find(hvx_length_feature);
+  size_t len_end = len_begin != npos ? fs.find('b', len_begin) : npos;
+  if (len_end != npos) {
+    int hvx_bytes = 0;
+    len_begin += hvx_length_feature.size();
+    CHECK(!fs.substr(len_begin, len_end - len_begin).getAsInteger(10, hvx_bytes))
+        << "invalid HVX length in feature string: " << fs.str();
+    CHECK(hvx_bytes == 64 || hvx_bytes == 128)
+        << "invalid HVX vector length: " << hvx_bytes << ", should be 64 or 128";
+    native_vector_bits_ = hvx_bytes * 8;
+  }
+  CodeGenLLVM::InitTarget(tm);
+}
+
+// Initialize code generation for a new module: run CodeGenLLVM::Init, build the
+// LLVM struct types mirroring TVM runtime structures and the function types of
+// the runtime entry points used by packed calls, optionally declare the
+// system-library registration function, then create the context globals.
+void CodeGenHexagon::Init(const std::string& module_name, llvm::TargetMachine* tm,
+                          llvm::LLVMContext* ctx, bool system_lib, bool dynamic_lookup,
+                          bool target_c_runtime) {
+  // NOTE(review): target_c_runtime is not forwarded; the base class always
+  // receives false here — confirm this is intended for Hexagon.
+  CodeGenLLVM::Init(module_name, tm, ctx, system_lib, dynamic_lookup, false);
+
+  func_handle_map_.clear();
+  // TVMValue is modeled as a struct holding one 64-bit float (presumably sized
+  // to match the runtime's TVMValue union).
+  t_tvm_value_ = llvm::StructType::create({t_float64_}, "t_tvm_value");
+  t_tvm_shape_index_ = llvm::Type::getIntNTy(*ctx, DataType::ShapeIndex().bits());
+  t_tvm_context_ = llvm::StructType::create({t_int_, t_int_}, "t_tvm_context");
+  t_tvm_type_ = llvm::StructType::create({t_int8_, t_int8_, t_int16_}, "t_tvm_type");
+  t_tvm_func_handle_ = t_void_p_;
+  // DLTensor
+  t_tvm_array_ = llvm::StructType::create(
+      {t_void_p_, t_tvm_context_, t_int_, t_tvm_type_, t_tvm_shape_index_->getPointerTo(),
+       t_tvm_shape_index_->getPointerTo(), t_int64_},
+      "t_tvm_array");
+
+  // Element type for each named stack allocation kind.
+  types_for_alloca_ = {
+      {"shape", t_tvm_shape_index_},
+      {"arg_value", t_tvm_value_},
+      {"arg_tcode", t_int_},
+      {"array", t_tvm_array_},
+  };
+
+  // Runtime functions.
+  // Signatures used for TVMFuncCall, TVMBackendGetFuncFromEnv and
+  // TVMAPISetLastError (cf. the "__"-prefixed globals in InitGlobalContext).
+  ftype_tvm_func_call_ = llvm::FunctionType::get(
+      t_int_,
+      {t_tvm_func_handle_, t_tvm_value_->getPointerTo(), t_int_->getPointerTo(), t_int_,
+       t_tvm_value_->getPointerTo(), t_int_->getPointerTo()},
+      false);
+  ftype_tvm_get_func_from_env_ = llvm::FunctionType::get(
+      t_int_, {t_void_p_, t_char_->getPointerTo(), t_tvm_func_handle_->getPointerTo()}, false);
+  ftype_tvm_api_set_last_error_ =
+      llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false);
+  // TBAA type attached to loads of context pointers in GetContextPtr.
+  md_tbaa_ctx_ptr_ = md_builder_->createTBAAScalarTypeNode("ctx_ptr", md_tbaa_root_);
+
+  // initialize TVM runtime API
+  if (system_lib) {
+    // We will need this in environment for backward registration.
+    f_tvm_register_system_symbol_ = llvm::Function::Create(
+        llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
+        llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
+  } else {
+    f_tvm_register_system_symbol_ = nullptr;
+  }
+  this->InitGlobalContext(dynamic_lookup);
+}
+
+// Emit a call to the external function `global_symbol` with args (skipping
+// args[0] when skip_first_arg is true). The callee type is built from ret_type
+// and the LLVM types of the generated argument values. Symbols listed in
+// gv_func_map_ are called through a "__"-prefixed context global created on
+// first use; all others are called directly, declaring the function in the
+// module if it does not exist yet.
+llvm::Value* CodeGenHexagon::CreateCallExtern(Type ret_type, String global_symbol,
+                                              const Array<PrimExpr>& args, bool skip_first_arg) {
+  // Generate argument values first (this emits IR for each argument, in order).
+  std::vector<llvm::Value*> arg_values;
+  for (size_t i = skip_first_arg; i < args.size(); ++i) {
+    arg_values.push_back(MakeValue(args[i]));
+  }
+  std::vector<llvm::Type*> arg_types;
+  for (llvm::Value* v : arg_values) {
+    arg_types.push_back(v->getType());
+  }
+  llvm::FunctionType* ftype = llvm::FunctionType::get(GetLLVMType(ret_type), arg_types, false);
+  // Check if it is available in global function table as injected function.
+  auto it = gv_func_map_.find(global_symbol);
+  if (it != gv_func_map_.end()) {
+    // A null entry marks a context function whose global is not created yet.
+    if (it->second == nullptr) {
+      gv_func_map_[global_symbol] = InitContextPtr(ftype->getPointerTo(), "__" + global_symbol);
+      it = gv_func_map_.find(global_symbol);
+    }
+#if TVM_LLVM_VERSION >= 90
+    auto ext_callee = llvm::FunctionCallee(ftype, GetContextPtr(it->second));
+#else
+    auto ext_callee = GetContextPtr(it->second);
+#endif
+    return builder_->CreateCall(ext_callee, arg_values);
+  } else {
+    // Direct call: reuse an existing declaration or declare it externally.
+    llvm::Function* f = module_->getFunction(global_symbol);
+    if (f == nullptr) {
+      f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage,
+                                 global_symbol.operator llvm::StringRef(), module_.get());
+    }
+#if TVM_LLVM_VERSION >= 90
+    auto ext_callee = llvm::FunctionCallee(f);
+#else
+    auto ext_callee = f;
+#endif
+    return builder_->CreateCall(ext_callee, arg_values);
+  }
+}
+
+// Create a DLL-exported, link-once global named `name` of pointer type
+// `p_type`, null-initialized and aligned to the type's allocation size.
+// Generated code reads it back through GetContextPtr.
+llvm::GlobalVariable* CodeGenHexagon::InitContextPtr(llvm::Type* p_type, std::string name) {
+  llvm::Constant* null_init = llvm::Constant::getNullValue(p_type);
+  auto* slot = new llvm::GlobalVariable(*module_, p_type, /*isConstant=*/false,
+                                        llvm::GlobalValue::LinkOnceAnyLinkage, null_init, name);
+#if TVM_LLVM_VERSION >= 100
+  slot->setAlignment(llvm::Align(data_layout_->getTypeAllocSize(p_type)));
+#else
+  slot->setAlignment(data_layout_->getTypeAllocSize(p_type));
+#endif
+  slot->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
+  return slot;
+}
+
+// Emit an aligned load of the pointer held in context global `gv`, tagged with
+// the "ctx_ptr" TBAA access tag (md_tbaa_ctx_ptr_).
+llvm::Value* CodeGenHexagon::GetContextPtr(llvm::GlobalVariable* gv) {
+  CHECK(gv != nullptr);
+  llvm::MDNode* tbaa_tag =
+      md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0);
+#if TVM_LLVM_VERSION >= 110
+  llvm::LoadInst* load = builder_->CreateAlignedLoad(gv, llvm::Align(gv->getAlignment()));
+#else
+  llvm::LoadInst* load = builder_->CreateAlignedLoad(gv, gv->getAlignment());
+#endif
+  load->setMetadata("tbaa", tbaa_tag);
+  return load;
+}
+
+// Create the module-context global. For a system library its location is
+// queued for registration; otherwise, unless dynamic_lookup is set, the runtime
+// API function-pointer globals are created and the workspace functions are
+// marked as context functions (called through "__" globals in CreateCallExtern).
+void CodeGenHexagon::InitGlobalContext(bool dynamic_lookup) {
+  // Module context
+  gv_mod_ctx_ = InitContextPtr(t_void_p_, tvm::runtime::symbol::tvm_module_ctx);
+  if (f_tvm_register_system_symbol_ != nullptr) {
+    // System library: register back the location of the module context.
+    export_system_symbols_.emplace_back(tvm::runtime::symbol::tvm_module_ctx, gv_mod_ctx_);
+    return;
+  }
+  if (dynamic_lookup) return;
+
+  gv_tvm_func_call_ = InitContextPtr(ftype_tvm_func_call_->getPointerTo(), "__TVMFuncCall");
+  gv_tvm_get_func_from_env_ = InitContextPtr(ftype_tvm_get_func_from_env_->getPointerTo(),
+                                             "__TVMBackendGetFuncFromEnv");
+  gv_tvm_api_set_last_error_ =
+      InitContextPtr(ftype_tvm_api_set_last_error_->getPointerTo(), "__TVMAPISetLastError");
+  // Null entries: the globals are created lazily on first call.
+  gv_func_map_["TVMBackendAllocWorkspace"] = nullptr;
+  gv_func_map_["TVMBackendFreeWorkspace"] = nullptr;
+}
+
+// Callee for TVMFuncCall: the direct declaration if one is set, otherwise a
+// load of the "__TVMFuncCall" context pointer.
+llvm::Value* CodeGenHexagon::RuntimeTVMFuncCall() {
+  return f_tvm_func_call_ != nullptr ? static_cast<llvm::Value*>(f_tvm_func_call_)
+                                     : GetContextPtr(gv_tvm_func_call_);
+}
+
+// Callee for TVMBackendGetFuncFromEnv: the direct declaration if one is set,
+// otherwise a load of the "__TVMBackendGetFuncFromEnv" context pointer.
+llvm::Value* CodeGenHexagon::RuntimeTVMGetFuncFromEnv() {
+  return f_tvm_get_func_from_env_ != nullptr
+             ? static_cast<llvm::Value*>(f_tvm_get_func_from_env_)
+             : GetContextPtr(gv_tvm_get_func_from_env_);
+}
+
+// Callee for TVMAPISetLastError: the direct declaration if one is set,
+// otherwise a load of the "__TVMAPISetLastError" context pointer.
+llvm::Value* CodeGenHexagon::RuntimeTVMAPISetLastError() {
+  return f_tvm_api_set_last_error_ != nullptr
+             ? static_cast<llvm::Value*>(f_tvm_api_set_last_error_)
+             : GetContextPtr(gv_tvm_api_set_last_error_);
+}
+
+// Emit a call to the packed function named by args[0] (a StringImm). args[1]
+// and args[2] are the value and type-code stacks; slots [begin, end) hold the
+// call arguments and slot `end` receives the return value and type code.
+// On exit *rvalue holds the return value cast to r_type and *ret_tcode points
+// at the returned type code. Returns the block following the success check.
+llvm::BasicBlock* CodeGenHexagon::MakeCallPacked(const Array<PrimExpr>& args, llvm::Value** rvalue,
+                                                 llvm::Value** ret_tcode, const DataType& r_type,
+                                                 const int64_t begin, const int64_t end) {
+  using llvm::BasicBlock;
+  // Validate before dereferencing: a non-StringImm args[0] would otherwise be
+  // a null-pointer dereference instead of a diagnosable error.
+  const StringImmNode* func_name_node = args[0].as<StringImmNode>();
+  CHECK(func_name_node != nullptr)
+      << "MakeCallPacked: expected the packed function name (args[0]) to be a StringImm";
+  std::string func_name = func_name_node->value;
+  llvm::Value* handle = GetPackedFuncHandle(func_name);
+  // call the function
+  int64_t nargs = end - begin;
+  CHECK_GE(nargs, 0);
+  llvm::Value* stack_value = MakeValue(args[1]);
+  llvm::Value* stack_tcode = MakeValue(args[2]);
+  llvm::Value* arg_value = builder_->CreateInBoundsGEP(
+      builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(begin));
+  llvm::Value* arg_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(begin));
+  llvm::Value* ret_value = builder_->CreateInBoundsGEP(
+      builder_->CreatePointerCast(stack_value, t_tvm_value_->getPointerTo()), ConstInt32(end));
+  *ret_tcode = CreateBufferPtr(DataType::Int(32), stack_tcode, ConstInt32(end));
+#if TVM_LLVM_VERSION >= 90
+  auto call_callee = llvm::FunctionCallee(ftype_tvm_func_call_, RuntimeTVMFuncCall());
+#else
+  auto call_callee = RuntimeTVMFuncCall();
+#endif
+  BasicBlock* end_block = CheckCallSuccess(builder_->CreateCall(
+      call_callee, {handle, arg_value, arg_tcode, ConstInt32(nargs), ret_value, *ret_tcode}));
+  // Read the result in its API type, then convert to the requested type.
+  DataType r_api_type = tir::APIType(r_type);
+#if TVM_LLVM_VERSION >= 110
+  *rvalue = builder_->CreateAlignedLoad(
+      builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()),
+      llvm::Align(8));
+#else
+  *rvalue = builder_->CreateAlignedLoad(
+      builder_->CreatePointerCast(ret_value, DTypeToLLVMType(r_api_type)->getPointerTo()), 8);
+#endif
+  *rvalue = CreateCast(r_api_type, r_type, *rvalue);
+  return end_block;
+}
+
+// Return the packed-function handle for `fname` as an llvm::Value.
+// Each name gets one internal global ".tvm_func.<fname>" (recorded in
+// func_handle_map_), initialized to null. The emitted code loads that global
+// and, if it is still null, branches to "handle_init". That block calls the
+// function returned by RuntimeTVMGetFuncFromEnv() with the module context and
+// checks its return code. It then stores the obtained handle into the global.
+// Both paths join in "handle_init_end" through a PHI node, which is returned.
+llvm::Value* CodeGenHexagon::GetPackedFuncHandle(const std::string& fname) {
+  using llvm::BasicBlock;
+  // We will store the packed function handle in global space.
+  // Initialize it during the first call.
+  llvm::DataLayout layout(module_.get());
+  // The handle's allocation size doubles as its alignment below.
+  uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
+  auto it = func_handle_map_.find(fname);
+
+  llvm::GlobalVariable* hptr;
+  if (it == func_handle_map_.end()) {
+    // create global location for the handle
+    // create the function handle
+    hptr =
+        new llvm::GlobalVariable(*module_, t_tvm_func_handle_, false,
+                                 llvm::GlobalValue::InternalLinkage, nullptr, ".tvm_func." + fname);
+#if TVM_LLVM_VERSION >= 100
+    hptr->setAlignment(llvm::Align(align));
+#else
+    hptr->setAlignment(align);
+#endif
+    hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
+    func_handle_map_[fname] = hptr;
+  } else {
+    hptr = it->second;
+  }
+  // Emit code that checks the cached handle and initializes it when null.
+  BasicBlock* pre_block = builder_->GetInsertBlock();
+  BasicBlock* init_block = BasicBlock::Create(*ctx_, "handle_init", function_);
+  BasicBlock* end_block = BasicBlock::Create(*ctx_, "handle_init_end", function_);
+#if TVM_LLVM_VERSION >= 110
+  llvm::Value* handle = builder_->CreateAlignedLoad(hptr, llvm::Align(align));
+#else
+  llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
+#endif
+  llvm::Value* handle_not_null =
+      builder_->CreateICmpNE(handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
+  builder_->CreateCondBr(handle_not_null, end_block, init_block, md_very_likely_branch_);
+  // Initialize the handle if needed.
+  builder_->SetInsertPoint(init_block);
+  // Output slot for the handle, allocated in the function entry block.
+  llvm::Value* out =
+      WithFunctionEntry([&]() { return builder_->CreateAlloca(t_tvm_func_handle_); });
+#if TVM_LLVM_VERSION >= 110
+  llvm::LoadInst* ctx =
+      builder_->CreateAlignedLoad(gv_mod_ctx_, llvm::Align(gv_mod_ctx_->getAlignment()));
+#else
+  llvm::LoadInst* ctx = builder_->CreateAlignedLoad(gv_mod_ctx_, gv_mod_ctx_->getAlignment());
+#endif
+  ctx->setMetadata("tbaa",
+                   md_builder_->createTBAAStructTagNode(md_tbaa_ctx_ptr_, md_tbaa_ctx_ptr_, 0));
+#if TVM_LLVM_VERSION >= 90
+  auto env_callee = llvm::FunctionCallee(ftype_tvm_get_func_from_env_, RuntimeTVMGetFuncFromEnv());
+#else
+  auto env_callee = RuntimeTVMGetFuncFromEnv();
+#endif
+  llvm::Value* retcode = builder_->CreateCall(env_callee, {ctx, GetConstString(fname), out});
+  // CheckCallSuccess moves the builder into a new block; that block (not the
+  // original init_block) is the PHI's predecessor on the init path.
+  init_block = CheckCallSuccess(retcode);
+#if TVM_LLVM_VERSION >= 110
+  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, llvm::Align(align));
+#else
+  llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
+#endif
+  // Store the handle
+  builder_->CreateStore(loaded_handle, hptr);
+  builder_->CreateBr(end_block);
+  // end block
+  builder_->SetInsertPoint(end_block);
+  llvm::PHINode* phi = builder_->CreatePHI(t_tvm_func_handle_, 2);
+  phi->addIncoming(handle, pre_block);
+  phi->addIncoming(loaded_handle, init_block);
+  return phi;
+}
+
+// Lower tvm_call_packed_lowered and return the call's result value.
+// A standalone op always contains a call to __tvm_set_device. Packed calls
+// need a Module object (or at least TVMBackendGetFuncFromEnv) to work, so
+// that particular call is not emitted; it is folded to the constant 0.
+llvm::Value* CodeGenHexagon::CreateCallPacked(const CallNode* op) {
+  const std::string& callee = op->args[0].as<StringImmNode>()->value;
+  if (callee == "__tvm_set_device") return ConstInt32(0);
+
+  // Argument layout: name, value stack, type-code stack, begin, end.
+  CHECK_EQ(op->args.size(), 5U);
+  const int64_t begin = op->args[3].as<IntImmNode>()->value;
+  const int64_t end = op->args[4].as<IntImmNode>()->value;
+  llvm::Value* result = nullptr;
+  llvm::Value* result_tcode = nullptr;
+  MakeCallPacked(op->args, &result, &result_tcode, op->dtype, begin, end);
+  return result;
+}
+
+// Lower tvm_call_trace_packed_lowered. Arguments 0-4 follow the
+// tvm_call_packed_lowered layout; args[5] is the traced value. If the callee
+// reports a return type code other than kTVMNullptr, its return value becomes
+// the result. Otherwise the traced value is passed through unchanged.
+llvm::Value* CodeGenHexagon::CreateCallTracePacked(const CallNode* op) {
+  using llvm::BasicBlock;
+  CHECK_EQ(op->args.size(), 6U);
+  llvm::Value* rvalue = nullptr;
+  llvm::Value* ret_tcode = nullptr;
+  MakeCallPacked(op->args, &rvalue, &ret_tcode, op->dtype, op->args[3].as<IntImmNode>()->value,
+                 op->args[4].as<IntImmNode>()->value);
+  // Get traced value.
+  llvm::Value* traced_value = MakeValue(op->args[5]);
+  // The update_block handles case when we need to update the return value.
+  BasicBlock* update_block = BasicBlock::Create(*ctx_, "update_block", function_);
+  // The continue_block handles case when we need to return original
+  // traced value.
+  BasicBlock* continue_block = BasicBlock::Create(*ctx_, "continue_block", function_);
+  // ret_tcode points at a 32-bit element of the type-code stack (index `end`,
+  // which may be odd), so only int32 alignment can be assumed. Overestimating
+  // the alignment of a load is undefined behavior in LLVM IR.
+#if TVM_LLVM_VERSION >= 110
+  llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, llvm::Align(4));
+#else
+  llvm::Value* ret_tcode_value = builder_->CreateAlignedLoad(ret_tcode, 4);
+#endif
+  // Check the ret_type_code and create cmp instruction.
+  llvm::Value* cmp =
+      builder_->CreateICmpNE(ret_tcode_value, llvm::ConstantInt::get(t_int_, kTVMNullptr));
+  // MakeValue(op->args[5]) may have emitted control flow of its own, so the
+  // block that branches to continue_block is the current insert block, not
+  // necessarily the block MakeCallPacked finished in.
+  BasicBlock* check_block = builder_->GetInsertBlock();
+  builder_->CreateCondBr(cmp, update_block, continue_block);
+  builder_->SetInsertPoint(update_block);
+  builder_->CreateBr(continue_block);
+  builder_->SetInsertPoint(continue_block);
+  // The return value depends on from what bb we come from.
+  llvm::PHINode* phi_rvalue = builder_->CreatePHI(traced_value->getType(), 2);
+  phi_rvalue->addIncoming(rvalue, update_block);
+  phi_rvalue->addIncoming(traced_value, check_block);
+  return phi_rvalue;
+}
+
+// Emit a check that `retcode` equals 0. If it does not, control goes to a new
+// "call_fail" block that returns `retcode` from the generated function.
+// Otherwise code generation continues in a new "call_end" block, which is
+// made the insert point and returned.
+llvm::BasicBlock* CodeGenHexagon::CheckCallSuccess(llvm::Value* retcode) {
+  // Branch on the return code, tagging success with md_very_likely_branch_.
+  using llvm::BasicBlock;
+  BasicBlock* fail_block = BasicBlock::Create(*ctx_, "call_fail", function_);
+  BasicBlock* end_block = BasicBlock::Create(*ctx_, "call_end", function_);
+  llvm::Value* succ = builder_->CreateICmpEQ(retcode, llvm::ConstantInt::get(t_int_, 0));
+  builder_->CreateCondBr(succ, end_block, fail_block, md_very_likely_branch_);
+  builder_->SetInsertPoint(fail_block);
+  // return the code.
+  builder_->CreateRet(retcode);
+  // otherwise set it to be new end.
+  builder_->SetInsertPoint(end_block);
+  return end_block;
+}
+
+// Lower an AssertStmt. If `condition` is false, the message
+// "Assert fail: <condition>[, <message>]" is passed to the function returned by
+// RuntimeTVMAPISetLastError(), and the generated function returns -1.
+// Otherwise code generation continues in "assert_end" with
+// CodeGenLLVM::VisitStmt_ for the same statement.
+void CodeGenHexagon::VisitStmt_(const AssertStmtNode* op) {
+  using llvm::BasicBlock;
+  llvm::Value* cond = MakeValue(op->condition);
+  // The message text is built at compile time; a non-string message is omitted.
+  std::ostringstream os;
+  os << "Assert fail: " << op->condition;
+  if (op->message.as<StringImmNode>()) {
+    os << ", " << op->message.as<StringImmNode>()->value;
+  }
+  llvm::Value* msg = GetConstString(os.str());
+  BasicBlock* fail_block = BasicBlock::Create(*ctx_, "assert_fail", function_);
+  BasicBlock* end_block = BasicBlock::Create(*ctx_, "assert_end", function_);
+  builder_->CreateCondBr(cond, end_block, fail_block, md_very_likely_branch_);
+  // fail condition.
+  builder_->SetInsertPoint(fail_block);
+#if TVM_LLVM_VERSION >= 90
+  auto err_callee =
+      llvm::FunctionCallee(ftype_tvm_api_set_last_error_, RuntimeTVMAPISetLastError());
+#else
+  auto err_callee = RuntimeTVMAPISetLastError();
+#endif
+  builder_->CreateCall(err_callee, {msg});
+  builder_->CreateRet(ConstInt32(-1));
+  // otherwise set it to be new end.
+  builder_->SetInsertPoint(end_block);
+  CodeGenLLVM::VisitStmt_(op);
+}
+
+// Lower the TIR builtins that need Hexagon-specific handling: lowered packed
+// calls (plain and traced), struct field get/set, stack allocas and
+// tvm_throw_last_error. All other intrinsics go to CodeGenLLVM::CreateIntrinsic.
+llvm::Value* CodeGenHexagon::CreateIntrinsic(const CallNode* op) {
+  if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
+    return CreateCallPacked(op);
+  } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) {
+    return CreateCallTracePacked(op);
+  } else if (op->op.same_as(builtin::tvm_struct_get())) {
+    // args: (struct pointer, index, field kind). kArrAddr yields the element
+    // address as void*; any other kind loads the field.
+    CHECK_EQ(op->args.size(), 3);
+    int kind = op->args[2].as<IntImmNode>()->value;
+    llvm::Value* ref =
+        CreateStructRefPtr(op->dtype, MakeValue(op->args[0]), MakeValue(op->args[1]), kind);
+    if (kind == builtin::kArrAddr) {
+      return builder_->CreatePointerCast(ref, t_void_p_);
+    }
+    return builder_->CreateLoad(ref);
+  } else if (op->op.same_as(builtin::tvm_struct_set())) {
+    // args: (struct pointer, index, field kind, value). kArrAddr is rejected:
+    // an element address is not a storable field.
+    CHECK_EQ(op->args.size(), 4);
+    int kind = op->args[2].as<IntImmNode>()->value;
+    CHECK(kind != builtin::kArrAddr);
+    llvm::Value* ref = CreateStructRefPtr(op->args[3].dtype(), MakeValue(op->args[0]),
+                                          MakeValue(op->args[1]), kind);
+    llvm::Value* value = MakeValue(op->args[3]);
+    // Pointer values are cast to the field's pointee type before the store.
+    if (value->getType()->isPointerTy()) {
+      value = builder_->CreatePointerCast(value, ref->getType()->getPointerElementType());
+    }
+    builder_->CreateStore(value, ref);
+    return ConstInt32(0);
+  } else if (op->op.same_as(builtin::tvm_stack_alloca())) {
+    // args: (type name, element count); the element type is looked up in
+    // types_for_alloca_ (at() throws for an unknown name).
+    CHECK_EQ(op->args.size(), 2);
+    const std::string& name = op->args[0].as<StringImmNode>()->value;
+    llvm::Value* size = ConstInt32(op->args[1].as<IntImmNode>()->value);
+    return builder_->CreateAlloca(types_for_alloca_.at(name), size);
+  } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
+    // Return -1 from the generated function. The current block is now
+    // terminated, so a fresh "cont" block is placed right after it to give the
+    // following code an insertion point (nothing here branches to it).
+    llvm::Value* neg_1 = ConstInt32(-1);
+    builder_->CreateRet(neg_1);
+    auto next_block = std::next(builder_->GetInsertBlock()->getIterator());
+    llvm::BasicBlock* new_bb = llvm::BasicBlock::Create(*ctx_, "cont", function_, &*next_block);
+    builder_->SetInsertPoint(new_bb);
+    return neg_1;
+  }
+
+  return CodeGenLLVM::CreateIntrinsic(op);
+}
+
+// Return a pointer to a field of a runtime struct.
+//  * kind < kArrKindBound_: `buf` points to DLTensor objects (given as void*
+//    or as the tensor-array pointer type); `index` selects the element and
+//    `kind` the member / sub-member. kArrAddr returns the element address.
+//  * kind == kTVMValueContent: `buf` points to TVMValue objects, and the
+//    result is typed according to `t` (64-bit int, 64-bit float or handle).
+// Any other kind is a fatal error.
+llvm::Value* CodeGenHexagon::CreateStructRefPtr(DataType t, llvm::Value* buf, llvm::Value* index,
+                                                int kind) {
+  // Member index of each DLTensor field (see the layout below) and, for the
+  // members that are structs themselves (ctx, dtype), the sub-field index.
+  static const std::map<int, int> field_index = {
+      {builtin::kArrData, 0},      {builtin::kArrDeviceType, 1}, {builtin::kArrDeviceId, 1},
+      {builtin::kArrNDim, 2},      {builtin::kArrTypeCode, 3},   {builtin::kArrTypeBits, 3},
+      {builtin::kArrTypeLanes, 3}, {builtin::kArrShape, 4},      {builtin::kArrStrides, 5},
+      {builtin::kArrByteOffset, 6}};
+  static const std::map<int, int> subfield_index = {
+      {builtin::kArrDeviceType, 0}, {builtin::kArrDeviceId, 1},  {builtin::kArrTypeCode, 0},
+      {builtin::kArrTypeBits, 1},   {builtin::kArrTypeLanes, 2},
+  };
+
+  if (kind < builtin::kArrKindBound_) {
+    if (buf->getType() == t_void_p_) {
+      buf = builder_->CreatePointerCast(buf, t_tvm_array_->getPointerTo());
+    } else {
+      CHECK_EQ(buf->getType(), t_tvm_array_->getPointerTo());
+    }
+    /* The following "kinds" are accessing the members of DLTensor:
+       typedef struct {
+         void* data;            kArrData
+         DLContext ctx;         kArrDeviceType (ctx.device_type)
+                                kArrDeviceId (ctx.device_id)
+         int ndim;              kArrNDim
+         DLDataType dtype;      kArrTypeCode (dtype.code)
+                                kArrTypeBits (dtype.bits)
+                                kArrTypeLanes (dtype.lanes)
+         int64_t* shape;        kArrShape
+         int64_t* strides;      kArrStrides
+         uint64_t byte_offset;  kArrByteOffset
+       } DLTensor;
+    */
+    llvm::Value* base_gep = builder_->CreateInBoundsGEP(buf, index, "base_gep");
+    if (kind == builtin::kArrAddr) {
+      return base_gep;
+    }
+    llvm::Value* field_gep = builder_->CreateInBoundsGEP(
+        base_gep, {ConstInt32(0), ConstInt32(field_index.at(kind))}, "field_gep");
+    switch (kind) {
+      // These fields have no sub-fields.
+      case builtin::kArrData:
+      case builtin::kArrNDim:
+      case builtin::kArrShape:
+      case builtin::kArrStrides:
+      case builtin::kArrByteOffset:
+        return field_gep;
+    }
+    return builder_->CreateInBoundsGEP(
+        field_gep, {ConstInt32(0), ConstInt32(subfield_index.at(kind))}, "subfield_gep");
+  }
+
+  if (kind == builtin::kTVMValueContent) {
+    /* TVMValue is a union:
+       typedef union {
+         int64_t v_int64;
+         double v_float64;
+         void* v_handle;
+         const char* v_str;
+         TVMType v_type;
+         TVMContext v_ctx;
+       } TVMValue;
+    */
+    CHECK_EQ(t.lanes(), 1);
+    CHECK(t.is_handle() || t.bits() == 64);
+    if (t.is_int()) {
+      buf = builder_->CreatePointerCast(buf, t_int64_->getPointerTo());
+      return builder_->CreateInBoundsGEP(buf, index);
+    } else if (t.is_float()) {
+      buf = builder_->CreatePointerCast(buf, t_float64_->getPointerTo());
+      return builder_->CreateInBoundsGEP(buf, index);
+    } else {
+      CHECK(t.is_handle());
+      buf = builder_->CreatePointerCast(buf, t_tvm_value_->getPointerTo());
+      buf = builder_->CreateInBoundsGEP(buf, index);
+      return builder_->CreatePointerCast(buf, t_void_p_->getPointerTo());
+    }
+  }
+
+  // assert() compiles away under NDEBUG, which would hand a nullptr back to
+  // the caller; fail loudly in every build mode instead.
+  LOG(FATAL) << "Unknown struct field kind: " << kind;
+  return nullptr;
+}
+
+namespace {
+// True when `f` carries a "calling_conv" attribute equal to
+// CallingConv::kCPackedFunc, i.e. it matches the TVMBackendPackedCFunc prototype.
+bool UsesExportABI(const PrimFunc& f) {
+  if (!f->attrs.defined()) return false;
+  auto found = f->attrs->dict.find("calling_conv");
+  if (found == f->attrs->dict.end()) return false;
+  return Downcast<Integer>((*found).second) == CallingConv::kCPackedFunc;
+}
+
+// Debugging aid: print the textual IR of an LLVM module to a std::ostream.
+__attribute__((unused)) std::ostream& operator<<(std::ostream& os, const llvm::Module& m) {
+  std::string text;
+  llvm::raw_string_ostream stream(text);
+  stream << m;
+  return os << stream.str();
+}
+
+// Hand `llvm_vec` to LLVM's command-line option parser as an argv array
+// (element 0 plays the role of the program name). No-op for an empty vector.
+void ProcessLLVMOptions(const std::vector<std::string>& llvm_vec) {
+  if (llvm_vec.empty()) return;
+
+  std::vector<const char*> argv;
+  argv.reserve(llvm_vec.size());
+  for (const std::string& option : llvm_vec) {
+    argv.push_back(option.c_str());
+  }
+
+  llvm::cl::ParseCommandLineOptions(llvm_vec.size(), argv.data());
+}
+
+}  // namespace
+
+// Compile every PrimFunc in `mod` for Hexagon and wrap the result in a runtime
+// module. First, the values of all "-llvm-options=" flags in `target_str` are
+// collected (',' separates options, '@' stands for '=') and applied to LLVM;
+// this happens only on the first call in the process. Next, LLVM IR is
+// generated with CodeGenHexagon, and assembly, object code, textual IR and
+// bitcode are emitted. Finally, the object is saved to a temporary file and
+// linked into a shared object through the "tvm.contrib.hexagon.link_shared"
+// packed function.
+runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
+  if (target_str.empty()) {
+    LOG(FATAL) << "Unknown or invalid target.";
+  }
+
+  // Make sure all targets are registered. InitializeLLVM can be called
+  // multiple times, after the first call all subsequent calls are no-ops.
+  InitializeLLVM();
+
+  // Split `str` at `delim`. Empty tokens can come from repeated or trailing
+  // delimiters, or from an empty "-llvm-options=" value. They are dropped so
+  // they never reach LLVM's option parser as bogus positional arguments.
+  auto split = [](const std::string& str, char delim = ' ') {
+    std::vector<std::string> vec;
+    std::string tmp;
+    for (std::istringstream iss(str); std::getline(iss, tmp, delim);) {
+      if (!tmp.empty()) vec.push_back(tmp);
+    }
+    return vec;
+  };
+  auto starts_with = [](const std::string& s, const std::string& p) {
+    return !s.compare(0, p.size(), p);
+  };
+
+  // Gather the values of all "-llvm-options=" flags into one comma-separated
+  // list, prefixed with the placeholder program name "llvm".
+  const std::string llvm_options_flag = "-llvm-options=";
+  std::string llvm_options_str = "llvm";
+  for (const auto& s : split(target_str)) {
+    if (starts_with(s, llvm_options_flag)) {
+      llvm_options_str += "," + s.substr(llvm_options_flag.size());
+    }
+  }
+
+  // Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '.
+  for (char& c : llvm_options_str) {
+    if (c == '@') {
+      c = '=';
+    } else if (c == ',') {
+      c = ' ';
+    }
+  }
+
+  // The vector of LLVM options is treated as "argv" from "main(argc, argv)". The entry at
+  // position 0 is the name of the executable, and is ignored by the LLVM cl::option parser.
+  // Make sure it's set to "llvm" (tvm.target.hexagon does that).
+  std::vector<std::string> llvm_options_vec = split(llvm_options_str);
+  assert(llvm_options_vec.size() >= 1 && llvm_options_vec[0] == "llvm");
+  llvm_options_vec.insert(std::next(llvm_options_vec.begin()),
+                          {"-hexagon-small-data-threshold=0",
+                           "-force-target-max-vector-interleave=1", "-hexagon-autohvx=1"});
+
+  // Process extra command line options for LLVM. Make sure it's only done
+  // once: the static initializer runs on the first call only, so options
+  // passed to later calls are not applied.
+  static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true);
+  (void)CallOnce;
+
+  std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target_str);
+  std::unique_ptr<CodeGenHexagon> cg(new CodeGenHexagon());
+  std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
+  cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
+  for (auto kv : mod->functions) {
+    CHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs";
+    auto f = Downcast<PrimFunc>(kv.second);
+    cg->AddFunction(f);
+  }
+  // Uncomment to get the LLVM module right out of codegen, before optimizations.
+  // std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n";
+  std::unique_ptr<llvm::Module> module = cg->Finish();
+
+  enum CodeGenFileType { Asm, Obj, IR, BC };
+
+  // Render `m` in the requested format. Asm/Obj are produced by the target
+  // machine from a clone of `m`, leaving `m` itself untouched.
+  auto EmitToString = [&tm](const llvm::Module& m, CodeGenFileType cgft) {
+    std::string out;
+
+    if (cgft == IR || cgft == BC) {
+      llvm::raw_string_ostream os(out);
+      if (cgft == IR)
+        m.print(os, nullptr);
+      else
+        llvm::WriteBitcodeToFile(m, os);
+    } else if (cgft == Asm || cgft == Obj) {
+      using namespace llvm;
+#if TVM_LLVM_VERSION <= 90
+      auto ft = cgft == Asm ? TargetMachine::CodeGenFileType::CGFT_AssemblyFile
+                            : TargetMachine::CodeGenFileType::CGFT_ObjectFile;
+#else
+      auto ft = cgft == Asm ? llvm::CGFT_AssemblyFile : llvm::CGFT_ObjectFile;
+#endif
+
+      SmallString<16384> ss;  // Will grow on demand.
+      llvm::raw_svector_ostream os(ss);
+      std::unique_ptr<llvm::Module> cm = CloneModule(m);
+      legacy::PassManager pass;
+      CHECK(tm->addPassesToEmitFile(pass, os, nullptr, ft) == 0) << "Cannot emit target code";
+      pass.run(*cm.get());
+      out.assign(ss.c_str(), ss.size());
+    }
+
+    return out;
+  };
+
+  auto SaveToFile = [](const std::string& data, const std::string& suffix) {
+    llvm::SmallString<64> file_name;
+    int fd;
+    std::error_code ec = llvm::sys::fs::createTemporaryFile("tvm", suffix, fd, file_name);
+    CHECK_EQ(static_cast<bool>(ec), false) << ec.message();
+    llvm::raw_fd_ostream file(fd, true);
+    file << data;
+    CHECK(!file.has_error()) << file.error().message();
+    // If there is an error, execution will never get here, but return
+    // {ec, name} anyway to allow caller to handle error conditions.
+    // This way the "CHECK" above can be removed with minimal effort.
+    return std::make_pair(file.error(), std::string(file_name.c_str()));
+  };
+
+  std::string asm_str = EmitToString(*module.get(), Asm);
+  std::string obj_str = EmitToString(*module.get(), Obj);
+  std::string ir_str = EmitToString(*module.get(), IR);
+  std::string bc_str = EmitToString(*module.get(), BC);
+
+  // The shared-object name is the object-file name with its trailing "o"
+  // replaced by "so".
+  std::string o_name = SaveToFile(obj_str, "o").second;
+  std::string so_name(o_name, 0, o_name.size() - 1);
+  so_name += "so";
+
+  const auto* link_shared = tvm::runtime::Registry::Get("tvm.contrib.hexagon.link_shared");
+  CHECK(link_shared != nullptr) << "tvm.contrib.hexagon.link_shared does not exist, "
+                                   "do import tvm.contrib.hexagon";
+
+  Array<PrimExpr> o_names = {StringImm(o_name)};
+  int rc = (*link_shared)(so_name, o_names);
+  CHECK(rc == 0) << "Failed to link " << so_name;
+
+  // Move it to ExtractFuncInfo?
+  std::set<std::string> export_abi;
+  for (auto kv : mod->functions) {
+    auto f = Downcast<PrimFunc>(kv.second);
+    if (UsesExportABI(f)) export_abi.insert(get_name(f));
+  }
+  return HexagonModuleCreate(so_name, "so", ExtractFuncInfo(mod), asm_str, obj_str, ir_str, bc_str,
+                             export_abi);
+}
+
+// Packed-function entry point "target.build.hexagon":
+// args[0] is the IRModule to build, args[1] the target string.
+TVM_REGISTER_GLOBAL("target.build.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
+  IRModule ir_module = args[0];
+  std::string target = args[1];
+  *rv = BuildHexagon(ir_module, target);
+});
+
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_LLVM_VERSION
diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc
new file mode 100644
index 0000000..d382251
--- /dev/null
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifdef TVM_LLVM_VERSION
+
+#include "intrin_rule_llvm.h"
+
+namespace tvm {
+namespace codegen {
+namespace llvm {
+
+// Lowering rules for the "hexagon" target. Each entry registers
+// DispatchLLVMPureIntrin (from intrin_rule_llvm.h) as the rule for one TVM
+// intrinsic. The template arguments are an LLVM intrinsic ID and the
+// signature count passed through to the dispatcher.
+
+// Exponential, logarithm, power and square root.
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.exp")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::exp, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.log")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::log, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.pow")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.sqrt")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::sqrt, 1>);
+
+// Fused multiply-add, mapped to llvm.fmuladd.
+// NOTE(review): this is the only rule here with a signature count of 3;
+// confirm against how CodeGenLLVM consumes that count.
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.fma")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>);
+
+// Rounding and absolute value.
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.floor")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::floor, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.ceil")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ceil, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.trunc")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::trunc, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.round")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.fabs")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::fabs, 1>);
+
+// Bit counting.
+TVM_REGISTER_GLOBAL("tvm.intrin.rule.hexagon.popcount")
+    .set_body(DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>);
+
+}  // namespace llvm
+}  // namespace codegen
+}  // namespace tvm
+
+#endif  // TVM_LLVM_VERSION
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 3e35e5b..40ade4d 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -193,6 +193,10 @@ TVM_REGISTER_TARGET_KIND("hexagon")
     .add_attr_option<Array<String>>("libs")
     .add_attr_option<String>("device")
     .add_attr_option<String>("model")
+    .add_attr_option<String>("mcpu")
+    .add_attr_option<Array<String>>("mattr")
+    .add_attr_option<String>("mtriple")
+    .add_attr_option<Array<String>>("llvm-options")
     .add_attr_option<Bool>("system-lib")
     .set_default_keys({"hexagon"})
     .set_device_type(kDLHexagon);
diff --git a/tests/python/unittest/test_target_codegen_hexagon.py b/tests/python/unittest/test_target_codegen_hexagon.py
new file mode 100644
index 0000000..1478e2e
--- /dev/null
+++ b/tests/python/unittest/test_target_codegen_hexagon.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import re
+import tvm
+import tvm.contrib.hexagon as hexagon
+
+
+def check_prereq_and_setup():
+    if tvm.target.codegen.llvm_version_major() <= 7:
+        print('Skipping test: need LLVM 7 or later for codegen')
+        return False
+    if os.name != 'posix':
+        print('Skipping test on non-POSIX platforms')
+        return False
+    if not tvm.runtime.enabled('hexagon'):
+        print('Hexagon runtime not enabled')
+        return False
+    # Register a phony linker, so that we can test codegen without a Hexagon 
toolchain.
+    hexagon.register_linker(lambda: '/bin/true')
+    return True
+
+
+def test_basic():
+    if not check_prereq_and_setup():
+        return
+    target = tvm.target.hexagon('v66', hvx=128)
+
+    def check_add(offload):
+        A = tvm.te.placeholder((128,), dtype='uint8', name='A')
+        B = tvm.te.placeholder((128,), dtype='uint8', name='A')
+        C = tvm.te.compute((128,), lambda i: A[i] + B[i], name='C')
+        s = tvm.te.create_schedule(C.op)
+
+        if offload:
+            xo, xi = s[C].split(s[C].op.axis[0], nparts=1)
+            s[C].bind(xo, tvm.te.thread_axis('pipeline'))
+            m = tvm.build(s, [C, A, B], target=target, name='offload_add')
+            hexm = m.imported_modules[0]
+        else:
+            hexm = tvm.build(s, [C, A, B], target=target, target_host=target, 
name='native_add')
+
+        asm = hexm.get_source('s')
+        vadds = re.findall(r'v[0-9]+.b = vadd\(v[0-9]+.b,v[0-9]+.b\)', asm)
+        assert vadds  # Check that it's non-empty
+
+    check_add(True)
+    check_add(False)
+
+
+def test_alloc_vtcm():
+    if not check_prereq_and_setup():
+        return
+    target = tvm.target.hexagon('v66')
+
+    buf_len = 2048
+    A = tvm.te.placeholder((buf_len,), name='A', dtype='int8')
+    B = tvm.te.placeholder((buf_len,), name='B', dtype='int8')
+
+    A_buf = tvm.te.compute((buf_len,), lambda *i: A(*i), 'A_buf')
+    B_buf = tvm.te.compute((buf_len,), lambda *i: B(*i), 'B_buf')
+    C = tvm.te.compute((buf_len,), lambda *i: A_buf(*i) + B_buf(*i), name='C')
+    s = tvm.te.create_schedule(C.op)
+
+    # Use VTCM for each buffer.
+    s[A_buf].set_scope("local.vtcm")
+    s[B_buf].set_scope("local.vtcm")
+
+    config = {'tir.add_lower_pass': hexagon.ir_lower_vtcm_pass()}
+    with tvm.transform.PassContext(config = config):
+        irmod = tvm.lower(s, [A, B, C], name = 'alloc_vtcm')
+
+    calls = re.findall('HexagonBackend[A-Za-z]*VTCM', str(irmod['alloc_vtcm']))
+    assert 'HexagonBackendAllocateVTCM' in calls
+    assert 'HexagonBackendFreeVTCM' in calls
+
+
+if __name__ == '__main__':
+    test_basic()
+    test_alloc_vtcm()

Reply via email to