This is an automated email from the ASF dual-hosted git repository. jwfromm pushed a commit to branch checkpoint in repository https://gitbox.apache.org/repos/asf/tvm.git
commit df61188d938af20063c07ace40314a80b7f2dc32 Author: mei-ye <meiandm...@yahoo.com> AuthorDate: Thu Aug 20 22:56:52 2020 -0700 Initial commit for AMD proposal of ONNXRT<>TVM --- include/tvm/driver/jit_interface.h | 10 +++++++ src/driver/driver_api.cc | 58 ++++++++++++++++++++++++++++++++++++++ src/relay/backend/build_module.cc | 43 ++++++++++++++++++++++++++++ 3 files changed, 111 insertions(+) diff --git a/include/tvm/driver/jit_interface.h b/include/tvm/driver/jit_interface.h new file mode 100644 index 0000000..966d5a8 --- /dev/null +++ b/include/tvm/driver/jit_interface.h @@ -0,0 +1,10 @@ +#define EXPORT_DLL __attribute__((visibility("default"))) + +#ifdef __cplusplus +extern "C" { + EXPORT_DLL tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level); + EXPORT_DLL void TVMRun(tvm::runtime::Module& mod, const std::string& name, tvm::runtime::TVMArgs& args, tvm::runtime::TVMRetValue* ret); + + +} // TVM_EXTERN_C +#endif diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index bbbb7e3..758f019 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -23,13 +23,25 @@ */ #include <dmlc/thread_local.h> #include <tvm/driver/driver_api.h> +#include <tvm/driver/jit_interface.h> +#include <tvm/ir/module.h> #include <tvm/ir/transform.h> +#include <tvm/relay/analysis.h> +#include <tvm/relay/expr.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/relay/op_strategy.h> +#include <tvm/relay/transform.h> +#include <tvm/relay/type.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/packed_func.h> #include <tvm/runtime/container.h> #include <tvm/runtime/registry.h> #include <tvm/target/codegen.h> #include <tvm/te/operation.h> #include <tvm/tir/analysis.h> #include <tvm/tir/transform.h> +#include <topi/generic/injective.h> +#include <tvm/target/generic_func.h> #include <algorithm> #include <mutex> @@ -324,3 +336,49 @@ runtime::Module build(const IRModule& funcs, 
const Target& target, const Target& } } // namespace tvm + + +tvm::runtime::Module TVMCompile(const std::string& onnx_txt, const std::string& target, const std::string& target_host, int opt_level) +{ + auto tensor_type = tvm::relay::TensorType({1, 6}, tvm::runtime::DataType::Float(32)); + auto X1 = tvm::relay::Var("X1", tensor_type); + auto mul_op = tvm::relay::Op::Get("multiply"); + auto mul1 = tvm::relay::Call(mul_op, {X1, X1}, tvm::Attrs(), {}); + auto mul2 = tvm::relay::Call(mul_op, {X1, mul1}, tvm::Attrs(), {}); + auto mul3 = tvm::relay::Call(mul_op, {X1, mul2}, tvm::Attrs(), {}); + auto Y4 = tvm::relay::Call(mul_op, {X1, mul3}, tvm::Attrs(), {}); + auto func = tvm::relay::Function(tvm::relay::FreeVars(Y4), Y4, tvm::relay::Type(), {}); + + auto reg = tvm::runtime::Registry::Get("ir.RegisterOpAttr"); + if (!reg) + LOG(FATAL) << "no _Register"; + + auto fs = tvm::runtime::Registry::Get("jit.strategy"); + if (!fs) + LOG(FATAL) << "No jit strategy registered."; + + auto fgeneric = tvm::GenericFunc::Get("jit.strategy_generic").set_default(*fs); + (*reg)("multiply", "FTVMStrategy", fgeneric, 10); + (*reg)("multiply", "TShapeDataDependant", false, 10); + + auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule"); + tvm::runtime::Module build_mod = (*pfb)(); + auto build_f = build_mod.GetFunction("build", false); + auto mod_f = build_mod.GetFunction("get_module", false); + auto relay_mod = tvm::IRModule::FromExpr(func); + tvm::Map<tvm::Integer, tvm::Target> targets; + // tvm::Target tgt = tvm::Target::Create(target); + tvm::Target tgt = tvm::Target::Create("llvm"); + targets.Set(0, tgt); + // tvm::Target host = (target == target_host) ? 
tgt : tvm::Target::Create(target_host); + build_f(relay_mod, targets, tgt); + tvm::runtime::Module mod = mod_f(); + return mod; +} + +void TVMRun(tvm::runtime::Module& mod, const std::string& name, tvm::runtime::TVMArgs& args, tvm::runtime::TVMRetValue* ret) +{ + mod.GetFunction(name).CallPacked(args, ret); + // process return value, refer to TVMFuncCall in c_runtime_api.cc + +} diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 0884692..3c047e1 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -28,6 +28,10 @@ #include <tvm/relay/qnn/transform.h> #include <tvm/relay/transform.h> #include <tvm/runtime/device_api.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/relay/op_strategy.h> +#include <topi/broadcast.h> +#include <topi/generic/injective.h> #include <memory> @@ -553,6 +557,45 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } +#if 1 +TVM_REGISTER_GLOBAL("jit.strategy") + .set_body_typed([](const Attrs& attrs, const Array<te::Tensor>& inputs, const Type& out_type, + const Target& target) { + FTVMCompute fcompute = [](const Attrs& attrs, const Array<te::Tensor>& inputs, + const Type& out_type) -> Array<te::Tensor> { + CHECK_EQ(inputs.size(), 2U); + return {topi::multiply(inputs[0], inputs[1])}; + }; + FTVMSchedule fschedule = [](const Attrs& attrs, const Array<te::Tensor>& outs, + const Target& target) { + With<Target> target_scope(target); + return topi::generic::schedule_injective(target, outs); + }; + + auto n = make_object<OpStrategyNode>(); + auto strategy = relay::OpStrategy(std::move(n)); + strategy.AddImplementation(fcompute, fschedule, "jit.strategy", 10); + return strategy; +}); + + +TVM_REGISTER_GLOBAL("relay.backend.lower_call") + .set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs, + const Target& target) { + static auto fstrategy = Op::GetAttrMap<relay::FTVMStrategy>("FTVMStrategy"); + Op op = Downcast<Op>(call->op); + 
auto out_type = call->checked_type(); + OpStrategy strategy = fstrategy[op](call->attrs, inputs, out_type, target); + auto impl = strategy->specializations[0]->implementations[0]; + auto outs = impl.Compute(call->attrs, inputs, out_type); + auto f = runtime::Registry::Get("relay.backend._make_LoweredOutput"); + if (!f) { + LOG(FATAL) << "relay.backend._make_LoweredOutput is not registered"; + } + return (*f)(outs, impl); +}); +#endif + TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); });