This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch rust-tvm-rt in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit d8d3147e304bdf5336acf22fb7616a9e59c2ad71 Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Wed May 6 02:22:08 2020 -0700 Add tvm-rt --- include/tvm/ir/expr.h | 5 +- python/tvm/runtime/object_generic.py | 2 +- rust/macros/Cargo.toml | 4 +- rust/macros/src/{lib.rs => import_module.rs} | 12 +- rust/macros/src/lib.rs | 124 +----- rust/macros/src/object.rs | 171 ++++++++ rust/tvm-rt/.gitignore | 7 + rust/{macros/Cargo.toml => tvm-rt/.travis.yml} | 24 +- rust/{macros => tvm-rt}/Cargo.toml | 30 +- rust/tvm-rt/README.md | 235 +++++++++++ rust/{macros => tvm-rt/examples/resnet}/Cargo.toml | 23 +- rust/tvm-rt/examples/resnet/README.md | 45 +++ rust/tvm-rt/examples/resnet/build.rs | 42 ++ rust/tvm-rt/examples/resnet/src/build_resnet.py | 134 +++++++ rust/tvm-rt/examples/resnet/src/main.rs | 160 ++++++++ rust/tvm-rt/src/context.rs | 76 ++++ rust/tvm-rt/src/errors.rs | 45 +++ rust/tvm-rt/src/function.rs | 340 ++++++++++++++++ rust/tvm-rt/src/lib.rs | 124 ++++++ rust/tvm-rt/src/module.rs | 130 +++++++ rust/tvm-rt/src/ndarray.rs | 431 +++++++++++++++++++++ rust/tvm-rt/src/object/mod.rs | 99 +++++ rust/tvm-rt/src/object/object_ptr.rs | 283 ++++++++++++++ rust/tvm-rt/src/string.rs | 72 ++++ rust/tvm-rt/src/to_boxed_fn.rs | 222 +++++++++++ rust/tvm-rt/src/to_function.rs | 377 ++++++++++++++++++ rust/tvm-rt/src/value.rs | 166 ++++++++ rust/tvm-rt/tests/test_ir.rs | 36 ++ src/ir/expr.cc | 11 +- src/printer/relay_text_printer.cc | 15 +- src/relay/transforms/to_cps.cc | 2 +- src/runtime/object.cc | 14 + src/runtime/object_internal.h | 9 + 33 files changed, 3294 insertions(+), 176 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index fba35a9..82689bd 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -27,6 +27,7 @@ #include <tvm/runtime/object.h> #include <tvm/node/node.h> #include <tvm/node/container.h> +#include <tvm/runtime/container.h> #include <tvm/ir/span.h> #include <tvm/ir/type.h> #include <string> @@ -36,6 +37,8 @@ namespace tvm { 
+using tvm::runtime::String; + /*! * \brief Base type of all the expressions. * \sa Expr @@ -189,7 +192,7 @@ class GlobalVar; class GlobalVarNode : public RelayExprNode { public: /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; + String name_hint; void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index cc21450..8f559ae 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,7 +38,7 @@ ObjectTypes = (ObjectBase, NDArrayBase, Module, ObjectRValueRef, PyNativeObject) def convert_to_object(value): - """Convert a python value to corresponding object type. + """Convert a Python value to corresponding object type. Parameters ---------- diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml index 784b35e..7abc9ae 100644 --- a/rust/macros/Cargo.toml +++ b/rust/macros/Cargo.toml @@ -32,5 +32,5 @@ proc-macro = true [dependencies] goblin = "0.0.24" proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +quote = "^1.0" +syn = { version = "1.0.17", features = ["full", "extra-traits"] } diff --git a/rust/macros/src/lib.rs b/rust/macros/src/import_module.rs similarity index 92% copy from rust/macros/src/lib.rs copy to rust/macros/src/import_module.rs index 9f28c74..6b059ae 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/import_module.rs @@ -16,9 +16,6 @@ * specific language governing permissions and limitations * under the License. 
*/ - -extern crate proc_macro; - use quote::quote; use std::{fs::File, io::Read}; use syn::parse::{Parse, ParseStream, Result}; @@ -37,8 +34,7 @@ impl Parse for ImportModule { } } -#[proc_macro] -pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { +pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let import_module_args = syn::parse_macro_input!(input as ImportModule); let manifest = @@ -109,11 +105,11 @@ pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream }; let fns = quote! { - use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; + use tvm_runtime::{ffi::TVMValue, ArgValue, RetValue, FuncCallError}; #extern_fns #( - pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> { + pub fn #fn_names(args: &[ArgValue]) -> Result<RetValue, FuncCallError> { let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args .into_iter() .map(|arg| { @@ -125,7 +121,7 @@ pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) }; if exit_code == 0 { - Ok(TVMRetValue::default()) + Ok(RetValue::default()) } else { Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) } diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs index 9f28c74..e9ddc25 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/lib.rs @@ -17,121 +17,17 @@ * under the License. 
*/ -extern crate proc_macro; - -use quote::quote; -use std::{fs::File, io::Read}; -use syn::parse::{Parse, ParseStream, Result}; -use syn::LitStr; - -use std::path::PathBuf; - -struct ImportModule { - importing_file: LitStr, -} - -impl Parse for ImportModule { - fn parse(input: ParseStream) -> Result<Self> { - let importing_file: LitStr = input.parse()?; - Ok(ImportModule { importing_file }) - } -} +use proc_macro::TokenStream; +mod import_module; +mod object; #[proc_macro] -pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let import_module_args = syn::parse_macro_input!(input as ImportModule); - - let manifest = - std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); - - let mut path = PathBuf::new(); - path.push(manifest); - path = path.join(import_module_args.importing_file.value()); - - let mut fd = File::open(&path) - .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); - let mut buffer = Vec::new(); - fd.read_to_end(&mut buffer).unwrap(); - - let fn_names = match goblin::Object::parse(&buffer).unwrap() { - goblin::Object::Elf(elf) => elf - .syms - .iter() - .filter_map(|s| { - if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { - return None; - } - match elf.strtab.get(s.st_name) { - Some(Ok(name)) if name != "" => { - Some(syn::Ident::new(name, proc_macro2::Span::call_site())) - } - _ => None, - } - }) - .collect::<Vec<_>>(), - goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { - obj.symbols() - .filter_map(|s| match s { - Ok((name, ref nlist)) - if nlist.is_global() - && nlist.n_sect != 0 - && !name.ends_with("tvm_module_ctx") => - { - Some(syn::Ident::new( - if name.starts_with('_') { - // Mach objects prepend a _ to globals. - &name[1..] 
- } else { - &name - }, - proc_macro2::Span::call_site(), - )) - } - _ => None, - }) - .collect::<Vec<_>>() - } - _ => panic!("Unsupported object format."), - }; - - let extern_fns = quote! { - mod ext { - extern "C" { - #( - pub(super) fn #fn_names( - args: *const tvm_runtime::ffi::TVMValue, - type_codes: *const std::os::raw::c_int, - num_args: std::os::raw::c_int - ) -> std::os::raw::c_int; - )* - } - } - }; - - let fns = quote! { - use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; - #extern_fns - - #( - pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> { - let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = unsafe { - ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) - }; - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) - } - } - )* - }; +pub fn import_module(input: TokenStream) -> TokenStream { + import_module::macro_impl(input) +} - proc_macro::TokenStream::from(fns) +#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +pub fn macro_impl(input: TokenStream) -> TokenStream { + // let input = proc_macro2::TokenStream::from(input); + TokenStream::from(object::macro_impl(input)) } diff --git a/rust/macros/src/object.rs b/rust/macros/src/object.rs new file mode 100644 index 0000000..96a86dd --- /dev/null +++ b/rust/macros/src/object.rs @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use proc_macro::TokenStream; +use proc_macro2::Span; +use quote::quote; +use syn::DeriveInput; +use syn::Ident; + +pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { + let derive_input = syn::parse_macro_input!(input as DeriveInput); + let payload_id = derive_input.ident; + + let mut type_key = None; + let mut ref_name = None; + let base = Some(Ident::new("base", Span::call_site())); + + for attr in derive_input.attrs { + if attr.path.is_ident("type_key") { + type_key = Some(attr.parse_meta().expect("foo")) + } + + if attr.path.is_ident("ref_name") { + ref_name = Some(attr.parse_meta().expect("foo")) + } + } + + let type_key = if let Some(syn::Meta::NameValue(name_value)) = type_key { + match name_value.lit { + syn::Lit::Str(type_key) => type_key, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_name = if let Some(syn::Meta::NameValue(name_value)) = ref_name { + match name_value.lit { + syn::Lit::Str(ref_name) => ref_name, + _ => panic!("foo"), + } + } else { + panic!("bar"); + }; + + let ref_id = Ident::new(&ref_name.value(), Span::call_site()); + let base = base.expect("should be present"); + + let expanded = quote! 
{ + unsafe impl tvm_rt::object::IsObject for #payload_id { + const TYPE_KEY: &'static str = #type_key; + + fn as_object<'s>(&'s self) -> &'s Object { + &self.#base.as_object() + } + } + + #[derive(Clone)] + pub struct #ref_id(Option<tvm_rt::object::ObjectPtr<#payload_id>>); + + impl tvm_rt::object::ToObjectRef for #ref_id { + fn to_object_ref(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } + } + + impl std::ops::Deref for #ref_id { + type Target = #payload_id; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().unwrap() + } + } + + impl std::convert::TryFrom<tvm_rt::RetValue> for #ref_id { + type Error = ::anyhow::Error; + + fn try_from(ret_val: tvm_rt::RetValue) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let oref: ObjectRef = ret_val.try_into()?; + let ptr = oref.0.ok_or(anyhow::anyhow!("null ptr"))?; + let ptr = ptr.downcast::<#payload_id>()?; + Ok(#ref_id(Some(ptr))) + } + } + + impl<'a> From<#ref_id> for tvm_rt::ArgValue<'a> { + fn from(object_ref: #ref_id) -> tvm_rt::ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + tvm_rt::ArgValue:: + ObjectHandle(std::ptr::null::<c_void>() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + impl<'a> From<&#ref_id> for tvm_rt::ArgValue<'a> { + fn from(object_ref: &#ref_id) -> tvm_rt::ArgValue<'a> { + let oref: #ref_id = object_ref.clone(); + tvm_rt::ArgValue::<'a>::from(oref) + } + } + + impl<'a> std::convert::TryFrom<tvm_rt::ArgValue<'a>> for #ref_id { + type Error = anyhow::Error; + + fn try_from(arg_value: tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl<'a> std::convert::TryFrom<&tvm_rt::ArgValue<'a>> for #ref_id { + type Error = anyhow::Error; + + fn try_from(arg_value: &tvm_rt::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + use std::convert::TryInto; + let optr = 
arg_value.try_into()?; + Ok(#ref_id(Some(optr))) + } + } + + impl From<#ref_id> for tvm_rt::RetValue { + fn from(object_ref: #ref_id) -> tvm_rt::RetValue { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => { + tvm_rt::RetValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void) + } + Some(value) => value.clone().into() + } + } + } + + }; + + TokenStream::from(expanded) +} + +// impl TryFrom<RetValue> for Var { +// type Error = anyhow::Error; + +// fn try_from(ret_val: RetValue) -> Result<Var, Self::Error> { +// let oref: ObjectRef = ret_val.try_into()?; +// let var_ptr = oref.0.ok_or(anyhow!("null ptr"))?; +// let var_ptr = var_ptr.downcast::<VarNode>()?; +// Ok(Var(Some(var_ptr))) +// } +// } diff --git a/rust/tvm-rt/.gitignore b/rust/tvm-rt/.gitignore new file mode 100644 index 0000000..2430329 --- /dev/null +++ b/rust/tvm-rt/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/macros/Cargo.toml b/rust/tvm-rt/.travis.yml similarity index 67% copy from rust/macros/Cargo.toml copy to rust/tvm-rt/.travis.yml index 784b35e..e963b7c 100644 --- a/rust/macros/Cargo.toml +++ b/rust/tvm-rt/.travis.yml @@ -15,22 +15,8 @@ # specific language governing permissions and limitations # under the License. -[package] -name = "tvm-macros" -version = "0.1.1" -license = "Apache-2.0" -description = "Procedural macros of the TVM crate." 
-repository = "https://github.com/apache/incubator-tvm" -readme = "README.md" -keywords = ["tvm"] -authors = ["TVM Contributors"] -edition = "2018" - -[lib] -proc-macro = true - -[dependencies] -goblin = "0.0.24" -proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/macros/Cargo.toml b/rust/tvm-rt/Cargo.toml similarity index 65% copy from rust/macros/Cargo.toml copy to rust/tvm-rt/Cargo.toml index 784b35e..417f256 100644 --- a/rust/macros/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -16,21 +16,29 @@ # under the License. [package] -name = "tvm-macros" -version = "0.1.1" +name = "tvm-rt" +version = "0.1.0" license = "Apache-2.0" -description = "Procedural macros of the TVM crate." +description = "Rust bindings for the TVM runtime API." repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" readme = "README.md" -keywords = ["tvm"] +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" -[lib] -proc-macro = true - [dependencies] -goblin = "0.0.24" -proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +thiserror = "^1.0" +anyhow = "^1.0" +lazy_static = "1.1" +ndarray = "0.12" +num-traits = "0.2" +tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } +tvm-macros = { version = "0.1", path = "../macros" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm-rt/README.md b/rust/tvm-rt/README.md new file mode 100644 index 0000000..fff3b56 --- /dev/null +++ b/rust/tvm-rt/README.md @@ -0,0 +1,235 @@ +<!--- Licensed to the Apache Software Foundation (ASF) under one --> +<!--- or more contributor license agreements. See the NOTICE file --> +<!--- distributed with this work for additional information --> +<!--- regarding copyright ownership. 
The ASF licenses this file --> +<!--- to you under the Apache License, Version 2.0 (the --> +<!--- "License"); you may not use this file except in compliance --> +<!--- with the License. You may obtain a copy of the License at --> + +<!--- http://www.apache.org/licenses/LICENSE-2.0 --> + +<!--- Unless required by applicable law or agreed to in writing, --> +<!--- software distributed under the License is distributed on an --> +<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> +<!--- KIND, either express or implied. See the License for the --> +<!--- specific language governing permissions and limitations --> +<!--- under the License. --> + +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for the [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently it requires **Nightly Rust** and has been tested on `rustc 1.32.0-nightly`. + +## What Does This Crate Offer? + +Here is the typical workflow: + +1. Train your **Deep Learning** model using any major framework, such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/). +2. Use **TVM** to build optimized model artifacts for a supported context, such as CPU, GPU, OpenCL or specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please check out [examples/resnet](examples/resnet) for the complete end-to-end example.
+ +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM: + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) +# compile the model +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) +# save the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) +``` + +Next, we load these artifacts to create and run the *Graph Runtime*, which classifies our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demonstrated in the following Rust snippet: + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to ByteArray + let params: Vec<u8> = fs::read("deploy_param.params")?; + let barr = ByteArray::from(&params); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); +
+ call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec<f32> + let output = output.to_vec::<f32>()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installation + +Please follow the TVM [installation guide](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH`; this happens automatically when they are installed individually into an Anaconda environment. + +## Supported TVM Functionalities + +### Use TVM to Generate a Shared Library + +One can use the following Python snippet to generate `add_gpu.so`, which adds two vectors on a GPU.
+ +```python +import os +import tvm +from tvm import te +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.runtime.enabled("cuda"): + print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) + return + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? 
+ .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** `rustc` must be instructed to link against the generated `add_gpu.so` at runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the custom `build.rs` files of the tests and examples for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use the `register_global_func!` macro to convert and register a Rust +function of type `fn(&[ArgValue]) -> Result<RetValue>` as a global TVM **packed function**, as follows: + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[ArgValue]) -> Result<RetValue> { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = RetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); +} +``` diff --git a/rust/macros/Cargo.toml b/rust/tvm-rt/examples/resnet/Cargo.toml similarity index 74% copy from rust/macros/Cargo.toml copy to rust/tvm-rt/examples/resnet/Cargo.toml index 784b35e..dbf59f3 100644 --- a/rust/macros/Cargo.toml +++ b/rust/tvm-rt/examples/resnet/Cargo.toml @@ -16,21 +16,14 @@ # under the License. [package] -name = "tvm-macros" -version = "0.1.1" -license = "Apache-2.0" -description = "Procedural macros of the TVM crate."
-repository = "https://github.com/apache/incubator-tvm" -readme = "README.md" -keywords = ["tvm"] +name = "resnet" +version = "0.0.0" authors = ["TVM Contributors"] -edition = "2018" - -[lib] -proc-macro = true +license = "Apache-2.0" +build = "build.rs" [dependencies] -goblin = "0.0.24" -proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +ndarray = "0.12" +tvm-frontend = { path = "../../" } +image = "0.20" +csv = "1.1" diff --git a/rust/tvm-rt/examples/resnet/README.md b/rust/tvm-rt/examples/resnet/README.md new file mode 100644 index 0000000..d6e32f7 --- /dev/null +++ b/rust/tvm-rt/examples/resnet/README.md @@ -0,0 +1,45 @@ +<!--- Licensed to the Apache Software Foundation (ASF) under one --> +<!--- or more contributor license agreements. See the NOTICE file --> +<!--- distributed with this work for additional information --> +<!--- regarding copyright ownership. The ASF licenses this file --> +<!--- to you under the Apache License, Version 2.0 (the --> +<!--- "License"); you may not use this file except in compliance --> +<!--- with the License. You may obtain a copy of the License at --> + +<!--- http://www.apache.org/licenses/LICENSE-2.0 --> + +<!--- Unless required by applicable law or agreed to in writing, --> +<!--- software distributed under the License is distributed on an --> +<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> +<!--- KIND, either express or implied. See the License for the --> +<!--- specific language governing permissions and limitations --> +<!--- under the License. --> + +## Resnet example + +This end-to-end example shows how to: +* build `Resnet 18` with `tvm` from Python +* use the provided Rust frontend API to test for an input image + +To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). 
+ +* **Build the example**: `cargo build` + +For the build to succeed, the Rust compiler must be instructed to link against the compiled shared library, for example with +`println!("cargo:rustc-link-search=native={}", build_path)`. See `build.rs` for more details. + +* **Run the example**: `cargo run` + +Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with + +``` +let output = Command::new("python") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .arg(&format!("--pretrained")) + .output() + .expect("Failed to execute command"); +``` + +Otherwise, *random weights* are used and the prediction will therefore be `limpkin, Aramus pictus`! diff --git a/rust/tvm-rt/examples/resnet/build.rs b/rust/tvm-rt/examples/resnet/build.rs new file mode 100644 index 0000000..b9a3c4c --- /dev/null +++ b/rust/tvm-rt/examples/resnet/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ + +use std::{path::Path, process::Command}; + +fn main() { + let output = Command::new("python3") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), + "Could not prepare demo: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + println!( + "cargo:rustc-link-search=native={}", + env!("CARGO_MANIFEST_DIR") + ); +} diff --git a/rust/tvm-rt/examples/resnet/src/build_resnet.py b/rust/tvm-rt/examples/resnet/src/build_resnet.py new file mode 100644 index 0000000..49c67bf --- /dev/null +++ b/rust/tvm-rt/examples/resnet/src/build_resnet.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import argparse +import csv +import logging +from os import path as osp +import sys + +import numpy as np + +import tvm +from tvm import te +from tvm import relay +from tvm.relay import testing +from tvm.contrib import graph_runtime, cc + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser(description='Resnet build example') +aa = parser.add_argument +aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') +aa('--pretrained', action='store_true', help='use a pretrained resnet') +aa('--batch-size', type=int, default=1, help='input image batch size') +aa('--opt-level', type=int, default=3, + help='level of optimization. 0 is unoptimized and 3 is the highest level') +aa('--target', type=str, default='llvm', help='target context for compilation') +aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') +aa('--image-name', type=str, default='cat.png', help='name of input image to download') +args = parser.parse_args() + +build_dir = args.build_dir +batch_size = args.batch_size +opt_level = args.opt_level +target = tvm.target.create(args.target) +image_shape = tuple(map(int, args.image_shape.split(","))) +data_shape = (batch_size,) + image_shape + +def build(target_dir): + """ Compiles resnet18 with TVM""" + deploy_lib = osp.join(target_dir, 'deploy_lib.o') + if osp.exists(deploy_lib): + return + + if args.pretrained: + # needs mxnet installed + from mxnet.gluon.model_zoo.vision import get_model + + # if `--pretrained` is enabled, it downloads a pretrained + # resnet18 trained on imagenet1k dataset for image classification task + block = get_model('resnet18_v1', pretrained=True) + net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) + # we want a probability so add a softmax operator + net = relay.Function(net.params, relay.nn.softmax(net.body), + None, net.type_params, 
net.attrs) + else: + # use random weights from relay.testing + net, params = relay.testing.resnet.get_workload( + num_layers=18, batch_size=batch_size, image_shape=image_shape) + + # compile the model + with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build_module.build(net, target, params=params) + + # save the model artifacts + lib.save(deploy_lib) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), + [osp.join(target_dir, "deploy_lib.o")]) + + with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph) + + with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) + +def download_img_labels(): + """ Download an image and imagenet1k class labels for test""" + from mxnet.gluon.utils import download + + img_name = 'cat.png' + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'synset.txt' + download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) + download(synset_url, synset_name) + + with open(synset_name) as fin: + synset = eval(fin.read()) + + with open("synset.csv", "w") as fout: + w = csv.writer(fout) + w.writerows(synset.items()) + +def test_build(build_dir): + """ Sanity check with random input""" + graph = open(osp.join(build_dir, "deploy_graph.json")).read() + lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so")) + params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) + input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.load_params(params) + module.run(data=input_data) + out = module.get_output(0).asnumpy() + + +if __name__ == '__main__': + logger.info("building the model") + build(build_dir) + logger.info("build was 
successful") + logger.info("test the build artifacts") + test_build(build_dir) + logger.info("test was successful") + if args.pretrained: + download_img_labels() + logger.info("image and synset downloads are successful") diff --git a/rust/tvm-rt/examples/resnet/src/main.rs b/rust/tvm-rt/examples/resnet/src/main.rs new file mode 100644 index 0000000..8b74b65 --- /dev/null +++ b/rust/tvm-rt/examples/resnet/src/main.rs @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +extern crate csv; +extern crate image; +extern crate ndarray; +extern crate tvm_frontend as tvm; + +use std::{ + collections::HashMap, + convert::TryInto, + fs::{self, File}, + path::Path, + str::FromStr, +}; + +use image::{FilterType, GenericImageView}; +use ndarray::{Array, ArrayD, Axis}; + +use tvm::*; + +fn main() { + let ctx = TVMContext::cpu(0); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + println!("original image dimensions: {:?}", img.dimensions()); + // for bigger size images, one needs to first resize to 256x256 + // with `img.resize_exact` method and then `image.crop` to 224x224 + let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); + println!("resized image dimensions: {:?}", img.dimensions()); + let mut pixels: Vec<f32> = vec![]; + for pixel in img.pixels() { + let tmp = pixel.data; + // normalize the RGB channels using mean, std of imagenet1k + let tmp = [ + (tmp[0] as f32 - 123.0) / 58.395, // R + (tmp[1] as f32 - 117.0) / 57.12, // G + (tmp[2] as f32 - 104.0) / 57.375, // B + ]; + for e in &tmp { + pixels.push(*e); + } + } + + let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn(); + // make arr shape as [1, 3, 224, 224] acceptable to resnet + let arr = arr.insert_axis(Axis(0)); + // create input tensor from rust's ndarray + let input = NDArray::from_rust_ndarray( + &arr, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ) + .unwrap(); + println!( + "input size is {:?}", + input.shape().expect("cannot get the input shape") + ); + let graph = + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + // load the built module + let lib = Module::load(&Path::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/deploy_lib.so" + ))) + .unwrap(); + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); + let 
runtime_create_fn_ret = call_packed!( + runtime_create_fn, + graph, + &lib, + &ctx.device_type, + &ctx.device_id + ) + .unwrap(); + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap(); + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to ByteArray + let params: Vec<u8> = + fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + let barr = ByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr).unwrap(); + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data".to_string(), &input).unwrap(); + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. 
Note that it has no argument + call_packed!(run_fn,).unwrap(); + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = NDArray::empty( + output_shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output).unwrap(); + // flatten the output as Vec<f32> + let output = output.to_vec::<f32>().unwrap(); + // find the maximum entry in the output and its index + let mut argmax = -1; + let mut max_prob = 0.; + for i in 0..output.len() { + if output[i] > max_prob { + max_prob = output[i]; + argmax = i as i32; + } + } + // create a hash map of (class id, class name) + let mut synset: HashMap<i32, String> = HashMap::new(); + let file = File::open("synset.csv").unwrap(); + let mut rdr = csv::ReaderBuilder::new() + .has_headers(true) + .from_reader(file); + + for result in rdr.records() { + let record = result.unwrap(); + let id: i32 = record[0].parse().unwrap(); + let cls = record[1].to_string(); + synset.insert(id, cls); + } + + println!( + "input image belongs to the class `{}` with probability {}", + synset + .get(&argmax) + .expect("cannot find the class id for argmax"), + max_prob + ); +} diff --git a/rust/tvm-rt/src/context.rs b/rust/tvm-rt/src/context.rs new file mode 100644 index 0000000..bceae5e --- /dev/null +++ b/rust/tvm-rt/src/context.rs @@ -0,0 +1,76 @@ +use tvm_sys::ffi; +pub use tvm_sys::context::*; + +use std::os::raw::c_void; +use std::ptr; + +trait ContextExt { + /// Checks whether the context exists or not. 
+ fn exist(&self) -> bool; + fn sync(&self) -> anyhow::Result<()>; + fn max_threads_per_block(&self) -> isize; + fn warp_size(&self) -> isize; + fn max_shared_memory_per_block(&self) -> isize; + fn compute_version(&self) -> isize; + fn device_name(&self) -> isize; + fn max_clock_rate(&self) -> isize; + fn multi_processor_count(&self) -> isize; + fn max_thread_dimensions(&self) -> isize; +} + +macro_rules! impl_device_attrs { + ($(($attr_name:ident, $attr_kind:expr));+) => { + $( + fn $attr_name(&self) -> isize { + get_device_attr(self.device_type.0 as i32, self.device_id as i32, 0) + .expect("should not fail") as isize + } + + )+ + }; +} + +external_func! { + fn get_device_attr(device_type: i32, device_id: i32, device_kind: i32) -> i32 as "runtime.GetDeviceAttr"; +} + + +impl ContextExt for Context { + fn exist(&self) -> bool { + let exists = get_device_attr(self.device_type.0 as i32, self.device_id as i32, 0) + .expect("should not fail"); + + exists != 0 + } + + /// Synchronize the context stream. + fn sync(&self) -> anyhow::Result<()> { + check_call!(ffi::TVMSynchronize( + self.device_type.0 as i32, + self.device_id as i32, + ptr::null_mut() as *mut c_void + )); + Ok(()) + } + + impl_device_attrs!((max_threads_per_block, 1); + (warp_size, 2); + (max_shared_memory_per_block, 3); + (compute_version, 4); + (device_name, 5); + (max_clock_rate, 6); + (multi_processor_count, 7); + (max_thread_dimensions, 8)); +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sync() { + let ctx = Context::cpu(0); + assert!(ctx.sync().is_ok()) + } +} diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs new file mode 100644 index 0000000..77dbba7 --- /dev/null +++ b/rust/tvm-rt/src/errors.rs @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use thiserror::Error; + +#[derive(Debug, Error)] +#[error("Cannot convert from an empty array.")] +pub struct EmptyArrayError; + +#[derive(Debug, Error)] +#[error("Handle `{name}` is null.")] +pub struct NullHandleError { + pub name: String, +} + +#[derive(Debug, Error)] +#[error("Function was not set in `function::Builder`")] +pub struct FunctionNotFoundError; + +#[derive(Debug, Error)] +#[error("Expected type `{expected}` but found `{actual}`")] +pub struct TypeMismatchError { + pub expected: String, + pub actual: String, +} + +#[derive(Debug, Error)] +#[error("Missing NDArray shape.")] +pub struct MissingShapeError; diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs new file mode 100644 index 0000000..739c7a0 --- /dev/null +++ b/rust/tvm-rt/src/function.rs @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::{ + collections::BTreeMap, + ffi::{CStr, CString}, + mem::{self, MaybeUninit}, + os::raw::{c_char, c_int}, + ptr, slice, str, + sync::Mutex, +}; + +use anyhow::{Result}; +use lazy_static::lazy_static; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use super::to_function::{ToFunction, Typed}; +use super::to_boxed_fn::ToBoxedFn; + +lazy_static! 
{ + static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<String, Option<Function>>> = { + let mut out_size = 0 as c_int; + let mut names_ptr = ptr::null_mut() as *mut *const c_char; + check_call!(ffi::TVMFuncListGlobalNames( + &mut out_size as *mut _, + &mut names_ptr as *mut _, + )); + let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) }; + + let names_list: Vec<String> = + names_list + .iter() + .map(|&p| unsafe { CStr::from_ptr(p).to_str().unwrap().into() }) + .collect(); + + // println!("{:?}", &names_list); + + let names_list = names_list + .into_iter() + .map(|p| (p, None)) + .collect(); + + Mutex::new(names_list) + }; +} + +/// Wrapper around TVM function handle which includes `is_global` +/// indicating whether the function is global or not, and `is_cloned` showing +/// not to drop a cloned function from Rust side. +/// The value of these fields can be accessed through their respective methods. +#[derive(Debug, Hash)] +pub struct Function { + pub(crate) handle: ffi::TVMFunctionHandle, + // whether the registered function is global or not. + is_global: bool, + // whether the function has been cloned from frontend or not. + is_cloned: bool, +} + +unsafe impl Send for Function {} +unsafe impl Sync for Function {} + +impl Function { + pub(crate) fn new(handle: ffi::TVMFunctionHandle) -> Self { + Function { + handle, + is_global: false, + is_cloned: false, + } + } + + /// For a given function, it returns a function by name. 
+ pub fn get<S: AsRef<str>>(name: S) -> Option<&'static Function> { + let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); + globals.get_mut(name.as_ref()).and_then(|maybe_func| { + if maybe_func.is_none() { + let name = CString::new(name.as_ref()).unwrap(); + let mut handle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMFuncGetGlobal( + name.as_ptr() as *const c_char, + &mut handle as *mut _ + )); + maybe_func.replace(Function { + handle, + is_global: true, + is_cloned: false, + }); + } + + unsafe { + mem::transmute::<Option<&Function>, Option<&'static Function>>(maybe_func.as_ref()) + } + }) + } + + /// Returns the underlying TVM function handle. + pub fn handle(&self) -> ffi::TVMFunctionHandle { + self.handle + } + + /// Returns `true` if the underlying TVM function is global and `false` otherwise. + pub fn is_global(&self) -> bool { + self.is_global + } + + /// Returns `true` if the underlying TVM function has been cloned + /// from the frontend and `false` otherwise. + pub fn is_cloned(&self) -> bool { + self.is_cloned + } + + /// Calls the function that created from `Builder`. 
+ pub fn invoke<'a>(&self, arg_buf: Vec<ArgValue<'a>>) -> Result<RetValue> { + let num_args = arg_buf.len(); + let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = + arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip(); + + let mut ret_val = unsafe { MaybeUninit::uninit().assume_init() }; + let mut ret_type_code = 0i32; + check_call!(ffi::TVMFuncCall( + self.handle, + values.as_mut_ptr(), + type_codes.as_mut_ptr() as *mut i32, + num_args as c_int, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _ + )); + + Ok(RetValue::from_tvm_value(ret_val, ret_type_code as u32)) + } + + pub fn to_boxed_fn<F: ?Sized>(&'static self) -> Box<F> where F: ToBoxedFn { + F::to_boxed_fn(self) + } +} + +impl Clone for Function { + fn clone(&self) -> Function { + Self { + handle: self.handle, + is_global: self.is_global, + is_cloned: true, + } + } +} + +impl Drop for Function { + fn drop(&mut self) { + if !self.is_global && !self.is_cloned { + check_call!(ffi::TVMFuncFree(self.handle)); + } + } +} + +/// Registers a Rust function with signature +/// `fn(&[ArgValue]) -> Result<RetValue, Error>` +/// as a **global TVM packed function** from frontend to TVM backend. +/// +/// Use [`register_global_func`] if overriding an existing global TVM function +/// is not required. 
+/// +/// ## Example +/// +/// ``` +/// # use tvm_rt::{ArgValue, function, RetValue}; +/// # use tvm_rt::function::Builder; +/// # use anyhow::Error; +/// use std::convert::TryInto; +/// +/// fn sum(args: &[ArgValue]) -> Result<RetValue, Error> { +/// let mut ret = 0i64; +/// for arg in args.iter() { +/// let arg: i64 = arg.try_into()?; +/// ret += arg; +/// } +/// let ret_val = RetValue::from(ret); +/// Ok(ret_val) +/// } +/// +/// function::register(sum, "mysum".to_owned()).unwrap(); +/// let mut registered = Builder::default(); +/// registered.get_function("mysum"); +/// assert!(registered.func.is_some()); +/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register<F, I, O, S: Into<String>>(f: F, name: S) -> Result<()> +where + F: ToFunction<I, O>, + F: Typed<I, O>, +{ + register_override(f, name, false) +} + +/// Registers a Rust function with signature +/// `fn(&[ArgValue]) -> Result<RetValue, Error>` +/// as a **global TVM packed function** from frontend to TVM backend. +/// +/// Use [`register_global_func`] if overriding an existing global TVM function +/// is not required. 
+/// +/// ## Example +/// +/// ``` +/// # use tvm_rt::{ArgValue, function, RetValue}; +/// # use tvm_rt::function::Builder; +/// # use anyhow::Error; +/// use std::convert::TryInto; +/// +/// fn sum(args: &[ArgValue]) -> Result<RetValue, Error> { +/// let mut ret = 0i64; +/// for arg in args.iter() { +/// let arg: i64 = arg.try_into()?; +/// ret += arg; +/// } +/// let ret_val = RetValue::from(ret); +/// Ok(ret_val) +/// } +/// +/// function::register_override(sum, "mysum".to_owned(), false).unwrap(); +/// let mut registered = Builder::default(); +/// registered.get_function("mysum"); +/// assert!(registered.func.is_some()); +/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap(); +/// assert_eq!(ret, 60); +/// ``` +pub fn register_override<F, I, O, S: Into<String>>(f: F, name: S, override_: bool) -> Result<()> +where + F: ToFunction<I, O>, + F: Typed<I, O>, +{ + let func = f.to_function(); + let name = name.into(); + let mut globals = GLOBAL_FUNCTIONS.lock().unwrap(); + // Not sure about this code + let handle = func.handle(); + globals.insert(name.clone(), Some(func)); + let name= CString::new(name)?; + check_call!(ffi::TVMFuncRegisterGlobal( + name.into_raw(), + handle, + override_ as c_int + )); + + Ok(()) +} + +#[macro_export] +macro_rules! external_func { + (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { + ::paste::item! { + #[allow(non_upper_case_globals)] + static [<global_ $name>]: ::once_cell::sync::Lazy<&'static $crate::Function> = + ::once_cell::sync::Lazy::new(|| { + $crate::Function::get($ext_name) + .expect(concat!("unable to load external function", stringify!($ext_name), "from TVM registry.")) + }); + } + + pub fn $name($($arg : $ty),*) -> Result<$ret_type, anyhow::Error> { + let func_ref: &$crate::Function = ::paste::expr! 
{ &*[<global_ $name>] }; + let func_ref: Box<dyn Fn($($ty),*) -> anyhow::Result<$ret_type>> = func_ref.to_boxed_fn(); + let res: $ret_type = func_ref($($arg),*)?; + Ok(res) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::function::{Function}; + + static CANARY: &str = "runtime.ModuleLoadFromFile"; + + // #[test] + // fn list_global_func() { + // assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY)); + // } + + #[test] + fn get_fn() { + assert!(Function::get(CANARY).is_some()); + assert!(Function::get("does not exists!").is_none()); + } + + #[test] + fn register_and_call_closure0() { + use crate::function; + + fn constfn() -> i64 { + return 10; + } + + function::register_override(constfn, "constfn".to_owned(), true).unwrap(); + let func = Function::get("constfn").unwrap(); + let func = func.to_boxed_fn::<dyn Fn() -> Result<i32>>(); + let ret = func().unwrap(); + assert_eq!(ret, 10); + } + + // #[test] + // fn register_and_call_closure1() { + // use crate::function::{self}; + + // fn ident(x: i64) -> i64 { + // return x; + // } + + // function::register_override(ident, "ident".to_owned(), false).unwrap(); + // let func = Function::get("ident").unwrap(); + // let func = func.to_boxed_fn::<dyn Fn(i32) -> Result<i32>>(); + // assert_eq!(func(60).unwrap(), 60); + // } +} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs new file mode 100644 index 0000000..e9ae02f --- /dev/null +++ b/rust/tvm-rt/src/lib.rs @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems.
//!
//! This crate provides an idiomatic Rust API for TVM runtime frontend.
//!
//! One particular use case is that given optimized deep learning model artifacts,
//! (compiled with TVM) which include a shared library
//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them
//! in Rust idiomatically to create a TVM Graph Runtime and
//! run the model for some inputs and get the
//! desired predictions *all in Rust*.
//!
//! Check out the `examples` directory for more details.

extern crate ndarray as rust_ndarray;

pub use crate as tvm_rt;

pub mod object;
pub mod string;

pub use object::*;
pub use string::*;

use std::{
    ffi::{CStr, CString},
    str,
};

use anyhow::Error;

pub use crate::{
    context::{Context, TVMDeviceType},
    errors::*,
    function::Function,
    module::Module,
    ndarray::NDArray,
};

pub use function::{ArgValue, RetValue};
pub use tvm_sys::byte_array::ByteArray;
pub use tvm_sys::datatype::DataType;

use tvm_sys::ffi;

/// Checks the status code returned by a call into the TVM runtime C API.
///
/// Evaluates `$e` inside an `unsafe` block and panics with the message from
/// [`get_last_error`] when the status is non-zero.
#[macro_export]
macro_rules! check_call {
    ($e:expr) => {{
        if unsafe { $e } != 0 {
            panic!("{}", $crate::get_last_error());
        }
    }};
}

/// Gets the last error message reported by TVM.
///
/// Returns `"Invalid UTF-8 message"` when the message is not valid UTF-8.
/// NOTE(review): the string borrows TVM's last-error buffer; the `'static` lifetime
/// presumably only holds until the next TVM error, so copy it before making further
/// TVM calls if it must be kept.
pub fn get_last_error() -> &'static str {
    let message = unsafe { CStr::from_ptr(ffi::TVMGetLastError()) };
    message.to_str().unwrap_or("Invalid UTF-8 message")
}

/// Stores `err`'s message as TVM's last error.
///
/// Interior NUL bytes cannot be represented in a C string, so they are removed
/// instead of panicking while an error is being reported.
pub(crate) fn set_last_error(err: &Error) {
    let message = err.to_string().replace('\0', "");
    let c_string = CString::new(message).expect("interior NUL bytes were removed");
    unsafe {
        ffi::TVMAPISetLastError(c_string.as_ptr());
    }
}

#[macro_use]
pub mod function;
pub mod context;
pub mod errors;
pub mod module;
pub mod ndarray;
pub mod to_function;
pub mod to_boxed_fn;
pub mod value;

/// Outputs the current TVM version.
///
/// Any trailing NUL terminator in `ffi::TVM_VERSION` is stripped. Generated C string
/// constants typically include it.
pub fn version() -> &'static str {
    match str::from_utf8(ffi::TVM_VERSION) {
        Ok(s) => s.trim_end_matches('\0'),
        Err(_) => "Invalid UTF-8 string",
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn print_version() {
        println!("TVM version: {}", version());
        assert!(!version().contains('\0'));
    }

    #[test]
    fn set_error() {
        let err = errors::EmptyArrayError;
        set_last_error(&err.into());
        assert_eq!(get_last_error().trim(), errors::EmptyArrayError.to_string());
    }
}
diff --git a/rust/tvm-rt/src/module.rs b/rust/tvm-rt/src/module.rs
new file mode 100644
index 0000000..f9b49d9
--- /dev/null
+++ b/rust/tvm-rt/src/module.rs
@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

//! 
Provides the [`Module`] type and methods for working with runtime TVM modules. + +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, + path::Path, + ptr, +}; + +use anyhow::{anyhow, ensure, Error}; +use tvm_sys::ffi; + +use crate::{errors, function::Function}; + +const ENTRY_FUNC: &str = "__tvm_main__"; + +/// Wrapper around TVM module handle which contains an entry function. +/// The entry function can be applied to an imported module through [`entry_func`]. +/// +/// [`entry_func`]:struct.Module.html#method.entry_func +#[derive(Debug, Clone)] +pub struct Module { + pub(crate) handle: ffi::TVMModuleHandle, + entry_func: Option<Function>, +} + + +external_func! { + fn runtime_enabled(target: CString) -> i32 as "runtime.RuntimeEnabled"; +} + +external_func! { + fn load_from_file(file_name: CString, format: CString) -> Module as "runtime.ModuleLoadFromFile"; +} + + +impl Module { + pub(crate) fn new(handle: ffi::TVMModuleHandle) -> Self { + Self { + handle, + entry_func: None, + } + } + + pub fn entry(&mut self) -> Option<&Function> { + if self.entry_func.is_none() { + self.entry_func = self.get_function(ENTRY_FUNC, false).ok(); + } + self.entry_func.as_ref() + } + + /// Gets a function by name from a registered module. + pub fn get_function(&self, name: &str, query_import: bool) -> Result<Function, Error> { + let name = CString::new(name)?; + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + check_call!(ffi::TVMModGetFunction( + self.handle, + name.as_ptr() as *const c_char, + query_import as c_int, + &mut fhandle as *mut _ + )); + ensure!( + !fhandle.is_null(), + errors::NullHandleError { + name: name.into_string()?.to_string() + } + ); + Ok(Function::new(fhandle)) + } + + /// Imports a dependent module such as `.ptx` for gpu. + pub fn import_module(&self, dependent_module: Module) { + check_call!(ffi::TVMModImport(self.handle, dependent_module.handle)) + } + + /// Loads a module shared library from path. 
+ pub fn load<P: AsRef<Path>>(path: &P) -> Result<Module, Error> { + let ext = CString::new( + path.as_ref() + .extension() + .unwrap_or_else(|| std::ffi::OsStr::new("")) + .to_str() + .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?, + )?; + let cpath = CString::new( + path.as_ref() + .to_str() + .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?, + )?; + let module = load_from_file(cpath, ext)?; + Ok(module) + } + + /// Checks if a target device is enabled for a module. + pub fn enabled(&self, target: &str) -> bool { + let target = CString::new(target).unwrap(); + let enabled = runtime_enabled(target).unwrap(); + enabled != 0 + } + + /// Returns the underlying module handle. + pub fn handle(&self) -> ffi::TVMModuleHandle { + self.handle + } +} + +impl Drop for Module { + fn drop(&mut self) { + check_call!(ffi::TVMModFree(self.handle)); + } +} diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs new file mode 100644 index 0000000..4653117 --- /dev/null +++ b/rust/tvm-rt/src/ndarray.rs @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! 
This module implements the [`NDArray`] type for working with *TVM tensors* or +//! coverting from a Rust's ndarray to TVM `NDArray`. +//! +//! One can create an empty NDArray given the shape, context and dtype using [`empty`]. +//! To create an NDArray from a mutable buffer in cpu use [`copy_from_buffer`]. +//! To copy an NDArray to different context use [`copy_to_ctx`]. +//! +//! Given a [`Rust's dynamic ndarray`], one can convert it to TVM NDArray as follows: +//! +//! # Example +//! +//! ``` +//! # use tvm_rt::{NDArray, Context, DataType}; +//! # use ndarray::{Array, ArrayD}; +//! # use std::str::FromStr; +//! use std::convert::TryFrom; +//! +//! let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) +//! .unwrap() +//! .into_dyn(); // Rust's ndarray +//! let nd = NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()).unwrap(); +//! assert_eq!(nd.shape(), Some(&mut [2, 2][..])); +//! let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); +//! assert!(rnd.all_close(&a, 1e-8f32)); +//! ``` +//! +//! [`Rust's dynamic ndarray`]:https://docs.rs/ndarray/0.12.1/ndarray/ +//! [`copy_from_buffer`]:struct.NDArray.html#method.copy_from_buffer +//! [`copy_to_ctx`]:struct.NDArray.html#method.copy_to_ctx + +use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; + +use crate::errors; +use anyhow::{bail, ensure, Result}; +use num_traits::Num; +use rust_ndarray::{Array, ArrayD}; +use std::convert::TryInto; +use std::ffi::c_void; +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, ByteArray, Context, DataType}; + +/// See the [`module-level documentation`](../ndarray/index.html) for more details. +/// +/// Wrapper around TVM array handle. 
+#[derive(Debug)] +pub enum NDArray { + Borrowed { handle: ffi::TVMArrayHandle }, + Owned { handle: *mut c_void }, +} + +impl NDArray { + pub(crate) fn new(handle: ffi::TVMArrayHandle) -> Self { + NDArray::Borrowed { handle } + } + + pub(crate) fn from_ndarray_handle(handle: *mut c_void) -> Self { + NDArray::Owned { handle } + } + + pub fn as_dltensor(&self) -> &DLTensor { + unsafe { + match self { + NDArray::Borrowed { ref handle } => std::mem::transmute(*handle), + NDArray::Owned { ref handle } => std::mem::transmute(*handle), + } + } + } + + pub(crate) fn as_raw_dltensor(&self) -> *mut DLTensor { + unsafe { + match self { + NDArray::Borrowed { ref handle } => std::mem::transmute(*handle), + NDArray::Owned { ref handle } => std::mem::transmute(*handle), + } + } + } + + pub fn is_view(&self) -> bool { + if let &NDArray::Borrowed { .. } = self { + true + } else { + false + } + } + + /// Returns the shape of the NDArray. + pub fn shape(&self) -> Option<&mut [usize]> { + let arr = self.as_dltensor(); + if arr.shape.is_null() || arr.data.is_null() { + return None; + }; + let slc = unsafe { slice::from_raw_parts_mut(arr.shape as *mut usize, arr.ndim as usize) }; + Some(slc) + } + + /// Returns the total number of entries of the NDArray. + pub fn size(&self) -> Option<usize> { + self.shape().map(|v| v.iter().product()) + } + + /// Returns the context which the NDArray was defined. + pub fn ctx(&self) -> Context { + self.as_dltensor().ctx.into() + } + + /// Returns the type of the entries of the NDArray. + pub fn dtype(&self) -> DataType { + self.as_dltensor().dtype.into() + } + + /// Returns the number of dimensions of the NDArray. + pub fn ndim(&self) -> usize { + self.as_dltensor() + .ndim + .try_into() + .expect("number of dimensions must always be positive") + } + + /// Returns the strides of the underlying NDArray. 
+ pub fn strides(&self) -> Option<&[usize]> { + unsafe { + let sz = self.ndim() * mem::size_of::<usize>(); + let strides_ptr = self.as_dltensor().strides as *const usize; + let slc = slice::from_raw_parts(strides_ptr, sz); + Some(slc) + } + } + + /// Shows whether the underlying ndarray is contiguous in memory or not. + pub fn is_contiguous(&self) -> Result<bool> { + Ok(match self.strides() { + None => true, + Some(strides) => { + // errors::MissingShapeError in case shape is not determined + self.shape() + .ok_or(errors::MissingShapeError)? + .iter() + .zip(strides) + .rfold( + (true, 1), + |(is_contig, expected_stride), (shape, stride)| { + ( + is_contig && *stride == expected_stride, + expected_stride * (*shape as usize), + ) + }, + ) + .0 + } + }) + } + + pub fn byte_offset(&self) -> isize { + self.as_dltensor().byte_offset as isize + } + + /// Flattens the NDArray to a `Vec` of the same type in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let mut shape = [4]; + /// let mut data = vec![1i32, 2, 3, 4]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(&mut shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); + /// assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); + /// ``` + pub fn to_vec<T>(&self) -> Result<Vec<T>> { + ensure!(self.shape().is_some(), errors::EmptyArrayError); + let earr = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + Context::cpu(0), + self.dtype(), + ); + let target = self.copy_to_ndarray(earr)?; + let arr = target.as_dltensor(); + let sz = self.size().ok_or(errors::MissingShapeError)?; + let mut v: Vec<T> = Vec::with_capacity(sz * mem::size_of::<T>()); + unsafe { + v.as_mut_ptr() + .copy_from_nonoverlapping(arr.data as *const T, sz); + v.set_len(sz); + } + Ok(v) + } + + /// Converts the NDArray to [`ByteArray`]. 
+ pub fn to_bytearray(&self) -> Result<ByteArray> { + let v = self.to_vec::<u8>()?; + Ok(ByteArray::from(v)) + } + + /// Creates an NDArray from a mutable buffer of types i32, u32 or f32 in cpu. + /// + /// ## Example + /// + /// ``` + /// # use tvm_rt::{Context, DataType, NDArray}; + /// # use std::str::FromStr; + /// let shape = &mut [2]; + /// let mut data = vec![1f32, 2.0]; + /// let ctx = Context::cpu(0); + /// let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + /// ndarray.copy_from_buffer(&mut data); + /// ``` + /// + /// *Note*: if something goes wrong during the copy, it will panic + /// from TVM side. See `TVMArrayCopyFromBytes` in `include/tvm/runtime/c_runtime_api.h`. + pub fn copy_from_buffer<T: Num32>(&mut self, data: &mut [T]) { + check_call!(ffi::TVMArrayCopyFromBytes( + self.as_raw_dltensor(), + data.as_ptr() as *mut _, + data.len() * mem::size_of::<T>() + )); + } + + /// Copies the NDArray to another target NDArray. + pub fn copy_to_ndarray(&self, target: NDArray) -> Result<NDArray> { + if self.dtype() != target.dtype() { + bail!( + "{}", + errors::TypeMismatchError { + expected: self.dtype().to_string(), + actual: target.dtype().to_string(), + } + ); + } + check_call!(ffi::TVMArrayCopyFromTo( + self.as_raw_dltensor(), + target.as_raw_dltensor(), + ptr::null_mut() as ffi::TVMStreamHandle + )); + Ok(target) + } + + /// Copies the NDArray to a target context. + pub fn copy_to_ctx(&self, target: &Context) -> Result<NDArray> { + let tmp = NDArray::empty( + self.shape().ok_or(errors::MissingShapeError)?, + *target, + self.dtype(), + ); + let copy = self.copy_to_ndarray(tmp)?; + Ok(copy) + } + + /// Converts a Rust's ndarray to TVM NDArray. 
+ pub fn from_rust_ndarray<T: Num32 + Copy>( + rnd: &ArrayD<T>, + ctx: Context, + dtype: DataType, + ) -> Result<Self> { + let shape = rnd.shape().to_vec(); + let mut nd = NDArray::empty(&shape, ctx, dtype); + let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); + nd.copy_from_buffer( + buf.as_slice_mut() + .expect("Array from iter must be contiguous."), + ); + Ok(nd) + } + + /// Allocates and creates an empty NDArray given the shape, context and dtype. + pub fn empty(shape: &[usize], ctx: Context, dtype: DataType) -> NDArray { + let mut handle = ptr::null_mut() as ffi::TVMArrayHandle; + check_call!(ffi::TVMArrayAlloc( + shape.as_ptr() as *const i64, + shape.len() as c_int, + i32::from(dtype.code) as c_int, + i32::from(dtype.bits) as c_int, + i32::from(dtype.lanes) as c_int, + ctx.device_type.0 as c_int, + ctx.device_id as c_int, + &mut handle as *mut _, + )); + NDArray::Borrowed { handle: handle } + } +} + +macro_rules! impl_from_ndarray_rustndarray { + ($type:ty, $type_name:tt) => { + impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { + type Error = anyhow::Error; + fn try_from(nd: &NDArray) -> Result<ArrayD<$type>> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) + } + } + + impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { + type Error = anyhow::Error; + fn try_from(nd: &mut NDArray) -> Result<ArrayD<$type>> { + ensure!(nd.shape().is_some(), errors::MissingShapeError); + assert_eq!(nd.dtype(), DataType::from_str($type_name)?, "Type mismatch"); + Ok(Array::from_shape_vec( + &*nd.shape().ok_or(errors::MissingShapeError)?, + nd.to_vec::<$type>()?, + )?) 
+ } + } + }; +} + +impl_from_ndarray_rustndarray!(i32, "int"); +impl_from_ndarray_rustndarray!(u32, "uint"); +impl_from_ndarray_rustndarray!(f32, "float"); + +impl Drop for NDArray { + fn drop(&mut self) { + if let &mut NDArray::Owned { .. } = self { + check_call!(ffi::TVMArrayFree(self.as_raw_dltensor())); + } + } +} + +mod sealed { + /// Private trait to prevent other traits from being implemeneted in downstream crates. + pub trait Sealed {} +} + +/// A trait for the supported 32-bits numerical types in frontend. +pub trait Num32: Num + sealed::Sealed { + const BITS: u8 = 32; +} + +macro_rules! impl_num32 { + ($($type:ty),+) => { + $( + impl sealed::Sealed for $type {} + impl Num32 for $type {} + )+ + }; +} + +impl_num32!(i32, u32, f32); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn basics() { + let shape = &mut [1, 2, 3]; + let ctx = Context::cpu(0); + let ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!( + ndarray.size().unwrap(), + shape.to_vec().into_iter().product() + ); + assert_eq!(ndarray.ndim(), 3); + assert!(ndarray.strides().is_none()); + assert_eq!(ndarray.byte_offset(), 0); + } + + #[test] + fn copy() { + let shape = &mut [4]; + let mut data = vec![1i32, 2, 3, 4]; + let ctx = Context::cpu(0); + let mut ndarray = NDArray::empty(shape, ctx, DataType::from_str("int32").unwrap()); + assert!(ndarray.to_vec::<i32>().is_ok()); + ndarray.copy_from_buffer(&mut data); + assert_eq!(ndarray.shape().unwrap(), shape); + assert_eq!(ndarray.to_vec::<i32>().unwrap(), data); + assert_eq!(ndarray.ndim(), 1); + assert!(ndarray.is_contiguous().is_ok()); + assert_eq!(ndarray.byte_offset(), 0); + let shape = vec![4]; + let e = NDArray::empty( + &shape, + Context::cpu(0), + DataType::from_str("int32").unwrap(), + ); + let nd = ndarray.copy_to_ndarray(e); + assert!(nd.is_ok()); + assert_eq!(nd.unwrap().to_vec::<i32>().unwrap(), data); + } + + #[test] + 
#[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + fn copy_wrong_dtype() { + let shape = vec![4]; + let mut data = vec![1f32, 2., 3., 4.]; + let ctx = Context::cpu(0); + let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + nd_float.copy_from_buffer(&mut data); + let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + nd_float.copy_to_ndarray(empty_int).unwrap(); + } + + #[test] + fn rust_ndarray() { + let a = Array::from_shape_vec((2, 2), vec![1f32, 2., 3., 4.]) + .unwrap() + .into_dyn(); + let nd = + NDArray::from_rust_ndarray(&a, Context::cpu(0), DataType::from_str("float32").unwrap()) + .unwrap(); + assert_eq!(nd.shape().unwrap(), &mut [2, 2]); + let rnd: ArrayD<f32> = ArrayD::try_from(&nd).unwrap(); + assert!(rnd.all_close(&a, 1e-8f32)); + } +} diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs new file mode 100644 index 0000000..8d8efdf --- /dev/null +++ b/rust/tvm-rt/src/object/mod.rs @@ -0,0 +1,99 @@ +use std::convert::TryFrom; +use std::convert::TryInto; +use std::ffi::CString; +use tvm_sys::{ArgValue, RetValue}; +use crate::external_func; + +mod object_ptr; + +pub use object_ptr::{IsObject, Object, ObjectPtr}; + +#[derive(Clone)] +pub struct ObjectRef(pub Option<ObjectPtr<Object>>); + +impl ObjectRef { + pub fn null() -> ObjectRef { + ObjectRef(None) + } +} + +pub trait ToObjectRef { + fn to_object_ref(&self) -> ObjectRef; +} + +impl ToObjectRef for ObjectRef { + fn to_object_ref(&self) -> ObjectRef { + self.clone() + } +} + +// impl<T: ToObjectRef> ToObjectRef for &T { +// fn to_object_ref(&self) -> ObjectRef { +// (*self).to_object_ref() +// } +// } + +impl TryFrom<RetValue> for ObjectRef { + type Error = anyhow::Error; + + fn try_from(ret_val: RetValue) -> Result<ObjectRef, Self::Error> { + let optr = ret_val.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl From<ObjectRef> for RetValue { + fn from(object_ref: ObjectRef) -> RetValue { + use 
std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => RetValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> std::convert::TryFrom<ArgValue<'a>> for ObjectRef { + type Error = anyhow::Error; + + fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectRef, Self::Error> { + let optr = arg_value.try_into()?; + Ok(ObjectRef(Some(optr))) + } +} + +impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { + type Error = anyhow::Error; + + fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectRef, Self::Error> { + // TODO(@jroesch): remove the clone + let value: ArgValue<'a> = arg_value.clone(); + ObjectRef::try_from(value) + } +} + +impl<'a> From<ObjectRef> for ArgValue<'a> { + fn from(object_ref: ObjectRef) -> ArgValue<'a> { + use std::ffi::c_void; + let object_ptr = &object_ref.0; + match object_ptr { + None => ArgValue::ObjectHandle(std::ptr::null::<c_void>() as *mut c_void), + Some(value) => value.clone().into(), + } + } +} + +impl<'a> From<&ObjectRef> for ArgValue<'a> { + fn from(object_ref: &ObjectRef) -> ArgValue<'a> { + let oref: ObjectRef = object_ref.clone(); + ArgValue::<'a>::from(oref) + } +} + +external_func! { + fn debug_print(object: ObjectRef) -> CString as "ir.DebugPrinter"; +} + +external_func! 
{ + fn as_text(object: ObjectRef) -> CString as "ir.TextPrinter"; +} diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs new file mode 100644 index 0000000..c716c05 --- /dev/null +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -0,0 +1,283 @@ +use anyhow::Context; +use std::convert::TryFrom; +use std::ffi::CString; +use std::ptr::NonNull; +use tvm_sys::ffi::{self, /* TVMObjectFree, */ TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::{ArgValue, RetValue}; + +type Deleter<T> = unsafe extern "C" fn(object: *mut T) -> (); + +#[derive(Debug)] +#[repr(C)] +pub struct Object { + pub type_index: u32, + pub ref_count: i32, + pub fdeleter: Deleter<Object>, +} + +unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) { + let typed_object: *mut T = std::mem::transmute(object); + T::typed_delete(typed_object); +} + +fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { + let mut is_derived = 0; + crate::check_call!(ffi::TVMObjectDerivedFrom( + child_type_index, + parent_type_index, + &mut is_derived + )); + if is_derived == 0 { + false + } else { + true + } +} + +impl Object { + fn new(type_index: u32, deleter: Deleter<Object>) -> Object { + Object { + type_index, + // Note: do not touch this field directly again, this is + // a critical section, we write a 1 to the atomic which will now + // be managed by the C++ atomics. + // In the future we should probably use C-atomcis. 
+ ref_count: 1, + fdeleter: deleter, + } + } + + fn get_type_index<T: IsObject>() -> u32 { + let type_key = T::TYPE_KEY; + let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { + return 0; + } else { + let mut index = 0; + unsafe { + let index_ptr = std::mem::transmute(&mut index); + if TVMObjectTypeKey2Index(cstring.as_ptr(), index_ptr) != 0 { + panic!(crate::get_last_error()) + } + } + return index; + } + } + + pub fn base_object<T: IsObject>() -> Object { + let index = Object::get_type_index::<T>(); + Object::new(index, delete::<T>) + } +} + +pub unsafe trait IsObject { + const TYPE_KEY: &'static str; + + fn as_object<'s>(&'s self) -> &'s Object; + + unsafe extern "C" fn typed_delete(_object: *mut Self) { + // let object = Box::from_raw(object); + // drop(object) + } +} + +unsafe impl IsObject for Object { + const TYPE_KEY: &'static str = "Object"; + + fn as_object<'s>(&'s self) -> &'s Object { + self + } +} + +#[repr(C)] +pub struct ObjectPtr<T> { + pub ptr: NonNull<T>, +} + +impl ObjectPtr<Object> { + fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> { + println!("{:?}", object_ptr); + let non_null = NonNull::new(object_ptr); + non_null.map(|ptr| ObjectPtr { ptr }) + } +} + +impl<T> Clone for ObjectPtr<T> { + fn clone(&self) -> Self { + unsafe { + let raw_ptr = std::mem::transmute(self.ptr); + assert_eq!(TVMObjectRetain(raw_ptr), 0); + ObjectPtr { ptr: self.ptr } + } + } +} + +// impl<T> Drop for ObjectPtr<T> { +// fn drop(&mut self) { +// unsafe { +// let raw_ptr = std::mem::transmute(self.ptr); +// assert_eq!(TVMObjectFree(raw_ptr), 0) +// } +// } +// } + +impl<T: IsObject> ObjectPtr<T> { + pub fn new(object: T) -> ObjectPtr<T> { + let object_ptr = Box::new(object); + let ptr = NonNull::from(Box::leak(object_ptr)); + ObjectPtr { ptr } + } + + pub fn count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. 
+ self.as_object().ref_count + } + + fn as_object<'s>(&'s self) -> &'s Object { + unsafe { self.ptr.as_ref().as_object() } + } + + pub fn upcast(&self) -> ObjectPtr<Object> { + ObjectPtr { + ptr: self.ptr.cast(), + } + } + + pub fn downcast<U: IsObject>(&self) -> anyhow::Result<ObjectPtr<U>> { + let child_index = Object::get_type_index::<U>(); + let object_index = self.as_object().type_index; + + let is_derived = if child_index == object_index { + true + } else { + // TODO(@jroesch): write tests + derived_from(object_index, child_index) + }; + + if is_derived { + Ok(ObjectPtr { + ptr: self.ptr.cast(), + }) + } else { + Err(anyhow::anyhow!("failed to downcast to object subtype")) + } + } +} + +impl<T> std::ops::Deref for ObjectPtr<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { self.ptr.as_ref() } + } +} + +impl<'a, T: IsObject> From<ObjectPtr<T>> for RetValue { + fn from(object_ptr: ObjectPtr<T>) -> RetValue { + let raw_object_ptr = object_ptr.ptr.as_ptr(); + // Should be able to hide this unsafety in raw bindings. + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + RetValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom<RetValue> for ObjectPtr<T> { + type Error = anyhow::Error; + + fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> { + match ret_value { + RetValue::ObjectHandle(handle) => { + let handle: *mut Object = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + optr.downcast() + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> { + fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> { + let raw_object_ptr = object_ptr.ptr.as_ptr(); + // Should be able to hide this unsafety in raw bindings. 
+ let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; + ArgValue::ObjectHandle(void_ptr) + } +} + +impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> { + type Error = anyhow::Error; + fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + optr.downcast() + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr<T> { + type Error = anyhow::Error; + fn try_from(arg_value: &ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> { + match arg_value { + ArgValue::ObjectHandle(handle) => { + let handle = unsafe { std::mem::transmute(handle) }; + let optr = ObjectPtr::from_raw(handle).context("unable to convert nullptr")?; + optr.downcast() + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +#[cfg(test)] +mod tests { + use super::{Object, ObjectPtr}; + use anyhow::{ensure, Result}; + use std::convert::TryInto; + use tvm_sys::{ArgValue, RetValue}; + + #[test] + fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::<Object>(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) + } + + #[test] + fn roundtrip_retvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::<Object>()); + let ret_value: RetValue = ptr.clone().into(); + let ptr2: ObjectPtr<Object> = ret_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } + + #[test] + fn roundtrip_argvalue() -> Result<()> { + let ptr = ObjectPtr::new(Object::base_object::<Object>()); + let arg_value: ArgValue = ptr.clone().into(); + let ptr2: ObjectPtr<Object> = 
arg_value.try_into()?; + ensure!( + ptr.type_index == ptr2.type_index, + "type indices do not match" + ); + ensure!( + ptr.fdeleter == ptr2.fdeleter, + "objects have different deleters" + ); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs new file mode 100644 index 0000000..ac80625 --- /dev/null +++ b/rust/tvm-rt/src/string.rs @@ -0,0 +1,72 @@ +use std::ffi::{CString, NulError}; +use std::os::raw::c_char; + +use super::{Object, ObjectPtr, ObjectRef}; +use crate as tvm_rt; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "String"] +#[type_key = "runtime.String"] +pub struct StringObj { + base: Object, + data: *const c_char, + size: u64, +} + +impl String { + pub fn new(string: std::string::String) -> Result<String, NulError> { + let cstring = CString::new(string)?; + + // The string is being corrupted. + // why is this wrong + let length = cstring.as_bytes().len(); + + let string_obj = StringObj { + base: Object::base_object::<StringObj>(), + data: cstring.into_raw(), + size: length as u64, + }; + + let object_ptr = ObjectPtr::new(string_obj); + Ok(String(Some(object_ptr))) + } + + pub fn to_cstring(&self) -> Result<std::ffi::CString, NulError> { + use std::slice; + let ptr = self.0.as_ref().unwrap().data; + let size = self.0.as_ref().unwrap().size; + unsafe { + let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize); + CString::new(slice) + } + } + + pub fn to_string(&self) -> anyhow::Result<std::string::String> { + let string = self.to_cstring()?.into_string()?; + Ok(string) + } +} + +#[cfg(test)] +mod tests { + use super::String; + use crate::object::debug_print; + use crate::ToObjectRef; + use anyhow::{ensure, Result}; + + #[test] + fn test_string_debug() -> Result<()> { + let s = String::new("foo".to_string()).unwrap(); + let object_ref = s.to_object_ref(); + println!("about to call"); + let string = debug_print(object_ref)?; + println!("after call"); + ensure!( + 
string.into_string().expect("is cstring").contains("foo"), + "string content is invalid" + ); + Ok(()) + } +} diff --git a/rust/tvm-rt/src/to_boxed_fn.rs b/rust/tvm-rt/src/to_boxed_fn.rs new file mode 100644 index 0000000..7a560b6 --- /dev/null +++ b/rust/tvm-rt/src/to_boxed_fn.rs @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. 
+ +use anyhow::Result; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use crate::{Module}; + +use super::function::Function; + +pub trait ToBoxedFn { + fn to_boxed_fn(func: &'static Function) -> Box<Self>; +} + +use std::convert::{TryInto, TryFrom}; + +impl<E, O> ToBoxedFn for dyn Fn() -> Result<O> + where E: std::error::Error + Send + Sync + 'static, + O: TryFrom<RetValue, Error=E>, { + fn to_boxed_fn(func: &'static Function) -> Box<Self> { + Box::new(move || { + let mut builder = Builder::default(); + builder.func = Some(func); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl<E, A, O> ToBoxedFn for dyn Fn(A) -> Result<O> + where E: std::error::Error + Send + Sync + 'static, + A: Into<ArgValue<'static>>, + O: TryFrom<RetValue, Error=E>, { + fn to_boxed_fn(func: &'static Function) -> Box<Self> { + Box::new(move |a: A| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl<E, A, B, O> ToBoxedFn for dyn Fn(A, B) -> Result<O> + where E: std::error::Error + Send + Sync + 'static, + A: Into<ArgValue<'static>>, + B: Into<ArgValue<'static>>, + O: TryFrom<RetValue, Error=E>, { + fn to_boxed_fn(func: &'static Function) -> Box<Self> { + Box::new(move |a: A, b: B| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + builder.arg(b.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl<E, A, B, C, O> ToBoxedFn for dyn Fn(A, B, C) -> Result<O> + where E: std::error::Error + Send + Sync + 'static, + A: Into<ArgValue<'static>>, + B: Into<ArgValue<'static>>, + C: Into<ArgValue<'static>>, + O: TryFrom<RetValue, Error=E>, { + fn to_boxed_fn(func: &'static Function) -> Box<Self> { + Box::new(move |a: A, b: B, c: C| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + let res = 
builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +impl<E, A, B, C, D, O> ToBoxedFn for dyn Fn(A, B, C, D) -> Result<O> + where E: std::error::Error + Send + Sync + 'static, + A: Into<ArgValue<'static>>, + B: Into<ArgValue<'static>>, + C: Into<ArgValue<'static>>, + D: Into<ArgValue<'static>>, + O: TryFrom<RetValue, Error=E>, { + fn to_boxed_fn(func: &'static Function) -> Box<Self> { + Box::new(move |a: A, b: B, c: C, d: D| { + let mut builder = Builder::default(); + builder.func = Some(func); + builder.arg(a.into()); + builder.arg(b.into()); + builder.arg(c.into()); + builder.arg(d.into()); + let res = builder.invoke()?.try_into()?; + Ok(res) + }) + } +} + +/// Function builder in order to create and call functions. +/// +/// *Note:* Currently TVM functions accept *at most* one return value. +#[derive(Default)] +pub struct Builder<'a, 'm> { + pub func: Option<&'m Function>, + pub arg_buf: Vec<ArgValue<'a>>, + pub ret_buf: Option<RetValue>, +} + +impl<'a, 'm> Builder<'a, 'm> { + pub fn new( + func: Option<&'m Function>, + arg_buf: Vec<ArgValue<'a>>, + ret_buf: Option<RetValue>, + ) -> Self { + Self { + func, + arg_buf, + ret_buf, + } + } + + pub fn get_function(&mut self, name: &'m str) -> &mut Self { + self.func = Function::get(name); + self + } + + /// Pushes a [`ArgValue`] into the function argument buffer. + pub fn arg<T: 'a>(&mut self, arg: T) -> &mut Self + where + ArgValue<'a>: From<T>, + { + self.arg_buf.push(arg.into()); + self + } + + /// Pushes multiple [`ArgValue`]s into the function argument buffer. + pub fn args<T: 'a, I>(&mut self, args: I) -> &mut Self + where + I: IntoIterator<Item = T>, + ArgValue<'a>: From<T>, + { + args.into_iter().for_each(|arg| { + self.arg(arg); + }); + self + } + + /// Sets an output for a function that requirs a mutable output to be provided. + /// See the `basics` in tests for an example. 
+ pub fn set_output<T>(&mut self, ret: T) -> &mut Self + where + RetValue: From<T>, + { + self.ret_buf = Some(ret.into()); + self + } + + pub fn invoke(self) -> Result<RetValue> { + self.func.unwrap().invoke(self.arg_buf) + } + +} + +/// Converts a [`Function`] to builder. Currently, this is the best way to work with +/// TVM functions. +impl<'a, 'm> From<&'m Function> for Builder<'a, 'm> { + fn from(func: &'m Function) -> Self { + Builder::new(Some(func), Vec::new(), None) + } +} + +/// Converts a mutable reference of a [`Module`] to [`Builder`]. +impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> { + fn from(module: &'m mut Module) -> Self { + Builder::new(module.entry(), Vec::new(), None) + } +} +#[cfg(test)] +mod tests { + use anyhow::Result; + use crate::function::{self, Function}; + + #[test] + fn to_boxed_fn0() { + fn boxed0() -> i64 { + return 10; + } + + function::register_override(boxed0, "boxed0".to_owned(), true).unwrap(); + let func = Function::get("boxed0").unwrap(); + let typed_func: Box<dyn Fn() -> Result<i64>> = func.to_boxed_fn(); + assert_eq!(typed_func().unwrap(), 10); + } +} diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs new file mode 100644 index 0000000..6954650 --- /dev/null +++ b/rust/tvm-rt/src/to_function.rs @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! This module provides an idiomatic Rust API for creating and working with TVM functions. +//! +//! For calling an already registered TVM function use [`function::Builder`] +//! To register a TVM packed function from Rust side either +//! use [`function::register`] or the macro [`register_global_func`]. +//! +//! See the tests and examples repository for more examples. + +use std::{ + mem::MaybeUninit, + os::raw::{c_int, c_void}, + ptr, slice, +}; + +use anyhow::Result; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; + +use super::Function; +use std::convert::{TryFrom, TryInto}; + +/// A trait representing whether the function arguments +/// and return type can be assigned to a TVM packed function. +/// +/// By splitting the conversion to function into two traits +/// we are able to improve error reporting, by splitting the +/// conversion of inputs and outputs to this trait. +/// +/// And the implementation of it to `ToFunction`. 
+pub trait Typed<I, O> { + fn args(i: &[ArgValue<'static>]) -> anyhow::Result<I>; + fn ret(o: O) -> RetValue; +} + +impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F +where + F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue>, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<&'a [ArgValue<'static>]> { + // this is BAD but just hacking for time being + Ok(unsafe { std::mem::transmute(args) }) + } + + fn ret(ret_value: anyhow::Result<RetValue>) -> RetValue { + ret_value.unwrap() + } +} + +impl<F, O: Into<RetValue>> Typed<(), O> for F +where + F: Fn() -> O, +{ + fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<()> { + debug_assert!(_args.len() == 0); + Ok(()) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl<F, A, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,), O> for F +where + F: Fn(A) -> O, + E: std::error::Error + Send + Sync + 'static, + A: TryFrom<ArgValue<'static>, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + Ok((a,)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl<F, A, B, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B), O> for F +where + F: Fn(A, B) -> O, + E: std::error::Error + Send + Sync + 'static, + A: TryFrom<ArgValue<'static>, Error = E>, + B: TryFrom<ArgValue<'static>, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A, B)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + Ok((a, b)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +impl<F, A, B, C, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A, B, C), O> for F +where + F: Fn(A, B, C) -> O, + E: std::error::Error + Send + Sync + 'static, + A: TryFrom<ArgValue<'static>, Error = E>, + B: TryFrom<ArgValue<'static>, Error = E>, + C: TryFrom<ArgValue<'static>, Error = E>, +{ + fn args(args: &[ArgValue<'static>]) 
-> anyhow::Result<(A, B, C)> { + debug_assert!(args.len() == 1); + let a: A = args[0].clone().try_into()?; + let b: B = args[1].clone().try_into()?; + let c: C = args[2].clone().try_into()?; + Ok((a, b, c)) + } + + fn ret(o: O) -> RetValue { + o.into() + } +} + +pub trait ToFunction<I, O>: Sized { + type Handle; + + fn into_raw(self) -> *mut Self::Handle; + + fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> + where + Self: Typed<I, O>; + + fn drop(handle: *mut Self::Handle); + + fn to_function(self) -> Function + where + Self: Typed<I, O>, + { + let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; + let resource_handle = self.into_raw(); + check_call!(ffi::TVMFuncCreateFromCFunc( + Some(Self::tvm_callback), + resource_handle as *mut _, + Some(Self::tvm_finalizer), + &mut fhandle as *mut _ + )); + println!("fnhandle: {:?}", fhandle); + Function::new(fhandle) + } + + /// The callback function which is wrapped converted by TVM + /// into a packed function stored in fhandle. 
+ unsafe extern "C" fn tvm_callback( + args: *mut ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: ffi::TVMRetValueHandle, + fhandle: *mut c_void, + ) -> c_int + where + Self: Typed<I, O>, + { + // turning off the incorrect linter complaints + #![allow(unused_assignments, unused_unsafe)] + println!("here"); + let len = num_args as usize; + let args_list = slice::from_raw_parts_mut(args, len); + let type_codes_list = slice::from_raw_parts_mut(type_codes, len); + let mut local_args: Vec<ArgValue> = Vec::new(); + let mut value = MaybeUninit::uninit().assume_init(); + let mut tcode = MaybeUninit::uninit().assume_init(); + let rust_fn = fhandle as *mut Self::Handle; + for i in 0..len { + value = args_list[i]; + println!("{:?}", value.v_handle); + tcode = type_codes_list[i]; + if tcode == ffi::TVMTypeCode_kTVMObjectHandle as c_int + || tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int + || tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int + { + check_call!(ffi::TVMCbArgToReturn( + &mut value as *mut _, + &mut tcode as *mut _ + )); + println!("{:?}", value.v_handle); + } + let arg_value = ArgValue::from_tvm_value(value, tcode as u32); + println!("{:?}", arg_value); + local_args.push(arg_value); + } + println!("before call"); + let rv = match Self::call(rust_fn, local_args.as_slice()) { + Ok(v) => v, + Err(msg) => { + crate::set_last_error(&msg); + return -1; + } + }; + println!("after call"); + + let (mut ret_val, ret_tcode) = rv.to_tvm_value(); + let mut ret_type_code = ret_tcode as c_int; + check_call!(ffi::TVMCFuncSetReturn( + ret, + &mut ret_val as *mut _, + &mut ret_type_code as *mut _, + 1 as c_int + )); + 0 + } + + /// The finalizer which is invoked when the packed function's + /// reference count is zero. + unsafe extern "C" fn tvm_finalizer(fhandle: *mut c_void) { + let handle = std::mem::transmute(fhandle); + Self::drop(handle) + } +} + +// /// A wrapper that is used to work around inference issues for bare functions. 
+// ///
+// /// Used to implement `register_untyped`.
+// pub(self) struct RawFunction {
+//     fn_ptr: for<'a> fn (&'a [ArgValue<'static>]) -> Result<RetValue>
+// }
+
+// impl RawFunction {
+//     fn new(fn_ptr: for<'a> fn (&'a [ArgValue<'static>]) -> Result<RetValue>) -> RawFunction {
+//         RawFunction { fn_ptr: fn_ptr }
+//     }
+// }
+
+// impl Typed<&[ArgValue<'static>], ()> for RawFunction {
+//     fn args(i: &[ArgValue<'static>]) -> anyhow::Result<&[ArgValue<'static>]> {
+//         Ok(i)
+//     }
+
+//     fn ret(o: O) -> RetValue;
+// }
+
+// impl ToFunction<(), ()> for RawFunction
+// {
+//     type Handle = fn(&[ArgValue<'static>]) -> Result<RetValue>;
+
+//     fn into_raw(self) -> *mut Self::Handle {
+//         self.fn_ptr as *mut Self::Handle
+//     }
+
+//     fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> {
+//         let handle: Self::Handle = unsafe { std::mem::transmute(handle) };
+//         let r = handle(args);
+//         println!("afters");
+//         r
+//     }
+
+//     // Function's don't need de-allocation because the pointers are into the code section of memory.
+//     fn drop(_: *mut Self::Handle) {}
+// }
+
+/// Exposes a zero-argument Rust closure as a TVM packed function.
+///
+/// The closure is boxed twice so that the handle given to TVM is a thin
+/// pointer (`*mut Box<dyn Fn>`) that round-trips through `*mut c_void`.
+impl<O, F> ToFunction<(), O> for F
+where
+    F: Fn() -> O + 'static,
+{
+    type Handle = Box<dyn Fn() -> O + 'static>;
+
+    fn into_raw(self) -> *mut Self::Handle {
+        let ptr: Box<Self::Handle> = Box::new(Box::new(self));
+        Box::into_raw(ptr)
+    }
+
+    fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue>
+    where
+        F: Typed<(), O>,
+    {
+        let out = unsafe { (*handle)() };
+        Ok(F::ret(out))
+    }
+
+    /// Reclaims the box leaked by `into_raw`, freeing the closure.
+    fn drop(handle: *mut Self::Handle) {
+        // SAFETY: `handle` came from `Box::into_raw` in `into_raw`, and the
+        // finalizer only runs once the packed function's refcount reaches zero.
+        std::mem::drop(unsafe { Box::from_raw(handle) });
+    }
+}
+
+/// Implements [`ToFunction`] for closures of each listed arity. Each
+/// `($param, $index)` pair names an argument type and its position in the
+/// tuple produced by `Typed::args`.
+macro_rules! to_function_instance {
+    ($(($param:ident,$index:tt),)+) => {
+        impl<F, $($param,)+ O> ToFunction<($($param,)+), O> for F
+        where
+            F: Fn($($param,)+) -> O + 'static,
+        {
+            type Handle = Box<dyn Fn($($param,)+) -> O + 'static>;
+
+            fn into_raw(self) -> *mut Self::Handle {
+                let ptr: Box<Self::Handle> = Box::new(Box::new(self));
+                Box::into_raw(ptr)
+            }
+
+            fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue>
+            where
+                F: Typed<($($param,)+), O>,
+            {
+                let args = F::args(args)?;
+                let out = unsafe { (*handle)($(args.$index),+) };
+                Ok(F::ret(out))
+            }
+
+            /// Reclaims the box leaked by `into_raw`, freeing the closure.
+            fn drop(handle: *mut Self::Handle) {
+                // SAFETY: `handle` came from `Box::into_raw` in `into_raw`, and the
+                // finalizer only runs once the packed function's refcount reaches zero.
+                std::mem::drop(unsafe { Box::from_raw(handle) });
+            }
+        }
+    }
+}
+
+to_function_instance!((A, 0),);
+to_function_instance!((A, 0), (B, 1),);
+to_function_instance!((A, 0), (B, 1), (C, 2),);
+to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),);
+
+#[cfg(test)]
+mod tests {
+    // use super::RawFunction;
+    use super::{Function, ToFunction, Typed};
+
+    fn zero() -> i32 {
+        10
+    }
+
+    fn helper<F, I, O>(f: F) -> Function
+    where
+        F: ToFunction<I, O>,
+        F: Typed<I, O>,
+    {
+        f.to_function()
+    }
+
+    // fn func_args(args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> {
+    //     Ok(10.into())
+    // }
+
+    // #[test]
+    // fn test_fn_ptr() {
+    //     let raw_fn = RawFunction::new(func_args);
+    //     raw_fn.to_function();
+    // }
+
+    #[test]
+    fn test_to_function0() {
+        helper(zero);
+    }
+
+    fn one_arg(i: i32) -> i32 {
+        i
+    }
+
+    #[test]
+    fn test_to_function1() {
+        helper(one_arg);
+    }
+
+    fn two_arg(i: i32, j: i32) -> i32 {
+        i + j
+    }
+
+    #[test]
+    fn test_to_function2() {
+        helper(two_arg);
+    }
+}
diff --git a/rust/tvm-rt/src/value.rs b/rust/tvm-rt/src/value.rs
new file mode 100644
index 0000000..a9355e0
--- /dev/null
+++ b/rust/tvm-rt/src/value.rs
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+//! Conversions between the FFI value types [`ArgValue`] / [`RetValue`] and
+//! the runtime handle types [`Function`], [`Module`] and [`NDArray`].
+//! `RetValue` is the owned version of `TVMPODValue`.
+
+use std::convert::TryFrom;
+
+use crate::{ArgValue, Function, Module, NDArray, RetValue};
+use tvm_sys::{
+    errors::ValueDowncastError,
+    ffi::{TVMFunctionHandle, TVMModuleHandle},
+    try_downcast,
+};
+
+/// Implements `ArgValue` / `RetValue` conversions for a handle-backed type.
+///
+/// `$variant` is the value variant carrying the raw handle, `$inner_type` is
+/// the FFI handle type it is cast to, and `$ctor` rebuilds `$type` from it.
+macro_rules! impl_handle_val {
+    ($type:ty, $variant:ident, $inner_type:ty, $ctor:path) => {
+        impl<'a> From<&'a $type> for ArgValue<'a> {
+            fn from(arg: &'a $type) -> Self {
+                ArgValue::$variant(arg.handle() as $inner_type)
+            }
+        }
+
+        impl<'a> From<&'a mut $type> for ArgValue<'a> {
+            fn from(arg: &'a mut $type) -> Self {
+                ArgValue::$variant(arg.handle() as $inner_type)
+            }
+        }
+
+        impl<'a> TryFrom<ArgValue<'a>> for $type {
+            type Error = ValueDowncastError;
+            fn try_from(val: ArgValue<'a>) -> Result<$type, Self::Error> {
+                try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(val) })
+            }
+        }
+
+        impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for $type {
+            type Error = ValueDowncastError;
+            fn try_from(val: &'a ArgValue<'v>) -> Result<$type, Self::Error> {
+                try_downcast!(val -> $type, |ArgValue::$variant(val)| { $ctor(*val) })
+            }
+        }
+
+        impl From<$type> for RetValue {
+            fn from(val: $type) -> RetValue {
+                RetValue::$variant(val.handle() as $inner_type)
+            }
+        }
+
+        impl TryFrom<RetValue> for $type {
+            type Error = ValueDowncastError;
+            fn try_from(val: RetValue) -> Result<$type, Self::Error> {
+                try_downcast!(val -> $type, |RetValue::$variant(val)| { $ctor(val) })
+            }
+        }
+    };
+}
+
+impl_handle_val!(Function, FuncHandle, TVMFunctionHandle, Function::new);
+impl_handle_val!(Module, ModuleHandle, TVMModuleHandle, Module::new);
+
+/// Borrowed arrays are passed as `ArrayHandle`, owned ones as `NDArrayHandle`.
+impl<'a> From<&'a NDArray> for ArgValue<'a> {
+    fn from(arg: &'a NDArray) -> Self {
+        match *arg {
+            NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle),
+            NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle),
+        }
+    }
+}
+
+impl<'a> From<&'a mut NDArray> for ArgValue<'a> {
+    fn from(arg: &'a mut NDArray) -> Self {
+        match *arg {
+            NDArray::Borrowed { handle } => ArgValue::ArrayHandle(handle),
+            NDArray::Owned { handle } => ArgValue::NDArrayHandle(handle),
+        }
+    }
+}
+
+impl<'a> TryFrom<ArgValue<'a>> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: ArgValue<'a>) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
+            |ArgValue::ArrayHandle(val)| { NDArray::new(val) })
+    }
+}
+
+impl<'a, 'v> TryFrom<&'a ArgValue<'v>> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: &'a ArgValue<'v>) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |ArgValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(*val) },
+            |ArgValue::ArrayHandle(val)| { NDArray::new(*val) })
+    }
+}
+
+/// Converts an owned `NDArray` into a `RetValue::NDArrayHandle`.
+///
+/// # Panics
+///
+/// Panics for `NDArray::Borrowed`, whose conversion is not implemented yet.
+// NOTE(review): if `NDArray`'s `Drop` releases owned handles, `val` is dropped
+// when this returns and the handle would dangle — confirm ownership transfer.
+impl From<NDArray> for RetValue {
+    fn from(val: NDArray) -> RetValue {
+        match val {
+            NDArray::Owned { handle } => RetValue::NDArrayHandle(handle),
+            NDArray::Borrowed { .. } => panic!("NYI"),
+        }
+    }
+}
+
+impl TryFrom<RetValue> for NDArray {
+    type Error = ValueDowncastError;
+    fn try_from(val: RetValue) -> Result<NDArray, Self::Error> {
+        try_downcast!(val -> NDArray,
+            |RetValue::NDArrayHandle(val)| { NDArray::from_ndarray_handle(val) },
+            |RetValue::ArrayHandle(val)| { NDArray::new(val) })
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::{convert::TryInto, str::FromStr};
+
+    use crate::{ByteArray, Context, DataType};
+
+    use super::*;
+
+    #[test]
+    fn bytearray() {
+        let w = vec![1u8, 2, 3, 4, 5];
+        let v = ByteArray::from(w.as_slice());
+        let tvm: ByteArray = RetValue::from(v).try_into().unwrap();
+        assert_eq!(tvm.data(), w.as_slice());
+    }
+
+    #[test]
+    fn ty() {
+        let t = DataType::from_str("int32").unwrap();
+        let tvm: DataType = RetValue::from(t).try_into().unwrap();
+        assert_eq!(tvm, t);
+    }
+
+    #[test]
+    fn ctx() {
+        let c = Context::from_str("gpu").unwrap();
+        let tvm: Context = RetValue::from(c).try_into().unwrap();
+        assert_eq!(tvm, c);
+    }
+}
diff --git a/rust/tvm-rt/tests/test_ir.rs b/rust/tvm-rt/tests/test_ir.rs
new file mode 100644
index 0000000..7d9e475
--- /dev/null
+++ b/rust/tvm-rt/tests/test_ir.rs
@@ -0,0 +1,36 @@
+// use std::convert::TryInto;
+// use std::str::FromStr;
+// use tvm_rt::string::String as TString;
+// use
tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef};
+// use tvm::{call_packed, DLDataType, Function};
+// use tvm_sys::RetValue;
+
+// #[test]
+// fn test_new_object() -> anyhow::Result<()> {
+//     let object = Object::base_object::<Object>();
+//     let ptr = ObjectPtr::new(object);
+//     assert_eq!(ptr.count(), 1);
+//     Ok(())
+// }
+
+// #[test]
+// fn test_new_string() -> anyhow::Result<()> {
+//     let string = TString::new("hello world!".to_string())?;
+//     Ok(())
+// }
+
+// #[test]
+// fn test_obj_build() -> anyhow::Result<()> {
+//     let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found.");
+
+//     let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert.");
+
+//     let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337)
+//         .expect("foo")
+//         .try_into()
+//         .unwrap();
+
+//     debug_print(&ret_val);
+
+//     Ok(())
+// }
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 7272213..b322388 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -162,7 +162,7 @@ GlobalVar::GlobalVar(std::string name_hint) {
 TVM_REGISTER_NODE_TYPE(GlobalVarNode);
 
 TVM_REGISTER_GLOBAL("ir.GlobalVar")
-.set_body_typed([](std::string name){
+.set_body_typed([](String name){
   return GlobalVar(name);
 });
 
@@ -214,4 +214,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
   }
   p->stream << '}';
 });
+
+TVM_REGISTER_GLOBAL("ir.DebugPrinter")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+  ObjectRef ref = args[0];
+  std::stringstream ss;
+  ss << ref;
+  *ret = ss.str();
+});
+
 }  // namespace tvm
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index bda997a..fc9546a 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -193,8 +193,7 @@ class RelayTextPrinter :
       case kTypeData:
         return Doc::Text("TypeData");
       default:
-        LOG(ERROR) << "Unknown Kind";
-        throw;
+        CHECK(false) << "Unknown Kind";
     }
   }
   /*!
@@ -479,7 +478,8 @@ class RelayTextPrinter :
   }
 
   Doc VisitExpr_(const GlobalVarNode* op) final {
-    return Doc::Text('@' + op->name_hint);
+    std::string name_hint = op->name_hint;
+    return Doc::Text('@' + name_hint);
   }
 
   Doc VisitExpr_(const OpNode* op) final {
@@ -939,4 +939,10 @@ TVM_REGISTER_GLOBAL("ir.PrettyPrint")
 
 TVM_REGISTER_GLOBAL("ir.AsText")
 .set_body_typed(AsText);
+
+// Single-argument convenience wrapper around AsText (no meta data, no annotation).
+TVM_REGISTER_GLOBAL("ir.TextPrinter")
+.set_body_typed([](ObjectRef node) {
+  return AsText(node, false, nullptr);
+});
+
 }  // namespace tvm
diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc
index e6c8392..65ee57f 100644
--- a/src/relay/transforms/to_cps.cc
+++ b/src/relay/transforms/to_cps.cc
@@ -164,7 +164,7 @@ Function ToCPS(const Function& f,
         // only look unfold non-external calls.
         BaseFunc base_func = m->Lookup(gv);
         if (auto* n = base_func.as<FunctionNode>()) {
-          auto cps_gv = GlobalVar(gv->name_hint + "_cps");
+          auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps");
           cm->insert({gv, cps_gv});
           m->Add(cps_gv, ToCPS(GetRef<Function>(n), m, cm));
         } else {
diff --git a/src/runtime/object.cc b/src/runtime/object.cc
index 0301200..5496159 100644
--- a/src/runtime/object.cc
+++ b/src/runtime/object.cc
@@ -244,12 +244,26 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) {
   API_END();
 }
 
+int TVMObjectRetain(TVMObjectHandle obj) {
+  API_BEGIN();
+  tvm::runtime::ObjectInternal::ObjectRetain(obj);
+  API_END();
+}
+
 int TVMObjectFree(TVMObjectHandle obj) {
   API_BEGIN();
   tvm::runtime::ObjectInternal::ObjectFree(obj);
   API_END();
 }
 
+
+int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) {
+  API_BEGIN();
+  *is_derived = tvm::runtime::TypeContext::Global()->
+      DerivedFrom(child_type_index, parent_type_index);
+  API_END();
+}
+
 int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) {
   API_BEGIN();
   out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index(
diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h
index 7955130..ab48802 100644
--- a/src/runtime/object_internal.h
+++ b/src/runtime/object_internal.h
@@ -38,6 +38,15 @@ namespace runtime {
 class ObjectInternal {
  public:
   /*!
+   * \brief Retain an object handle.
+   */
+  static void ObjectRetain(TVMObjectHandle obj) {
+    if (obj != nullptr) {
+      static_cast<Object*>(obj)->IncRef();
+    }
+  }
+
+  /*!
    * \brief Free an object handle.
    */
   static void ObjectFree(TVMObjectHandle obj) {