This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch rust-tvm in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 43325abcdf9b1aa8d405a9769aa43dc12f66112d Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Wed Jun 10 02:30:10 2020 -0700 Add tvm crate --- rust/tvm/.gitignore | 7 + rust/tvm/.travis.yml | 22 +++ rust/tvm/Cargo.toml | 45 +++++ rust/tvm/README.md | 235 +++++++++++++++++++++++++++ rust/tvm/examples/resnet/Cargo.toml | 29 ++++ rust/tvm/examples/resnet/README.md | 45 +++++ rust/tvm/examples/resnet/build.rs | 42 +++++ rust/tvm/examples/resnet/src/build_resnet.py | 134 +++++++++++++++ rust/tvm/examples/resnet/src/main.rs | 160 ++++++++++++++++++ rust/tvm/src/ir/array.rs | 74 +++++++++ rust/tvm/src/ir/mod.rs | 17 ++ rust/tvm/src/ir/relay/mod.rs | 232 ++++++++++++++++++++++++++ rust/tvm/src/lib.rs | 47 ++++++ rust/tvm/src/runtime/mod.rs | 1 + rust/tvm/src/transform.rs | 42 +++++ rust/tvm/tests/basics/.gitignore | 7 + rust/tvm/tests/basics/Cargo.toml | 32 ++++ rust/tvm/tests/basics/build.rs | 46 ++++++ rust/tvm/tests/basics/src/main.rs | 55 +++++++ rust/tvm/tests/basics/src/tvm_add.py | 50 ++++++ rust/tvm/tests/callback/Cargo.toml | 26 +++ rust/tvm/tests/callback/src/bin/array.rs | 72 ++++++++ rust/tvm/tests/callback/src/bin/error.rs | 56 +++++++ rust/tvm/tests/callback/src/bin/float.rs | 50 ++++++ rust/tvm/tests/callback/src/bin/int.rs | 49 ++++++ rust/tvm/tests/callback/src/bin/string.rs | 54 ++++++ rust/tvm/tests/test_ir.rs | 37 +++++ 27 files changed, 1666 insertions(+) diff --git a/rust/tvm/.gitignore b/rust/tvm/.gitignore new file mode 100644 index 0000000..2430329 --- /dev/null +++ b/rust/tvm/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm/.travis.yml b/rust/tvm/.travis.yml new file mode 100644 index 0000000..e963b7c --- /dev/null +++ b/rust/tvm/.travis.yml @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml new file mode 100644 index 0000000..ebfb5e6 --- /dev/null +++ b/rust/tvm/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "tvm" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust frontend support for TVM" +repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +lazy_static = "1.1" +ndarray = "0.12" +num-traits = "0.2" +tvm-rt = { version = "0.1", path = "../tvm-rt/" } +tvm-sys = { version = "0.1", path = "../tvm-sys/" } +tvm-macros = { version = "*", path = "../tvm-macros/" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm/README.md b/rust/tvm/README.md new file mode 100644 index 0000000..01e088f --- /dev/null +++ b/rust/tvm/README.md @@ -0,0 +1,235 @@ +<!--- Licensed to the Apache Software Foundation (ASF) under one --> +<!--- or more contributor license agreements. See the NOTICE file --> +<!--- distributed with this work for additional information --> +<!--- regarding copyright ownership. The ASF licenses this file --> +<!--- to you under the Apache License, Version 2.0 (the --> +<!--- "License"); you may not use this file except in compliance --> +<!--- with the License. You may obtain a copy of the License at --> + +<!--- http://www.apache.org/licenses/LICENSE-2.0 --> + +<!--- Unless required by applicable law or agreed to in writing, --> +<!--- software distributed under the License is distributed on an --> +<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> +<!--- KIND, either express or implied. See the License for the --> +<!--- specific language governing permissions and limitations --> +<!--- under the License. --> + +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime frontend. 
Currently this requires **Nightly Rust** and is tested on `rustc 1.32.0-nightly` + +## What Does This Crate Offer? + +Here is a major workflow + +1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) +2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please check out [examples/resnet](examples/resnet) for the complete end-to-end example. + +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) +# compile the model +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) +# save the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) +``` + +Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demonstrated in the following Rust snippet + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let 
runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec<u8> = fs::read("deploy_param.params")?; + let barr = TVMByteArray::from(&params); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec<f32> + let output = output.to_vec::<f32>()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installations + +Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. 
+ +## Supported TVM Functionalities + +### Use TVM to Generate Shared Library + +One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. + +```python +import os +import tvm +from tvm import te +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.runtime.enabled("cuda"): + print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) + return + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? 
+ .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the tests and examples custom `build.rs` for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use `register_global_func!` macro to convert and register a Rust +function of type `fn(&[TVMArgValue]) -> Result<TVMRetValue>` to a global TVM **packed function** as follows + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue> { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD<f32> = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = TVMRetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); +} +``` diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml new file mode 100644 index 0000000..e1a474e --- /dev/null +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "resnet" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } +image = "0.20" +csv = "1.1" diff --git a/rust/tvm/examples/resnet/README.md b/rust/tvm/examples/resnet/README.md new file mode 100644 index 0000000..d6e32f7 --- /dev/null +++ b/rust/tvm/examples/resnet/README.md @@ -0,0 +1,45 @@ +<!--- Licensed to the Apache Software Foundation (ASF) under one --> +<!--- or more contributor license agreements. See the NOTICE file --> +<!--- distributed with this work for additional information --> +<!--- regarding copyright ownership. The ASF licenses this file --> +<!--- to you under the Apache License, Version 2.0 (the --> +<!--- "License"); you may not use this file except in compliance --> +<!--- with the License. You may obtain a copy of the License at --> + +<!--- http://www.apache.org/licenses/LICENSE-2.0 --> + +<!--- Unless required by applicable law or agreed to in writing, --> +<!--- software distributed under the License is distributed on an --> +<!--- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --> +<!--- KIND, either express or implied. See the License for the --> +<!--- specific language governing permissions and limitations --> +<!--- under the License. 
--> + +## Resnet example + +This end-to-end example shows how to: +* build `Resnet 18` with `tvm` from Python +* use the provided Rust frontend API to test for an input image + +To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). + +* **Build the example**: `cargo build` + +To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with +`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. + +* **Run the example**: `cargo run` + +Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with + +``` +let output = Command::new("python") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .arg(&format!("--pretrained")) + .output() + .expect("Failed to execute command"); +``` + +Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`! diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs new file mode 100644 index 0000000..b9a3c4c --- /dev/null +++ b/rust/tvm/examples/resnet/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::Path, process::Command}; + +fn main() { + let output = Command::new("python3") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), + "Could not prepare demo: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + println!( + "cargo:rustc-link-search=native={}", + env!("CARGO_MANIFEST_DIR") + ); +} diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py new file mode 100644 index 0000000..49c67bf --- /dev/null +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import csv +import logging +from os import path as osp +import sys + +import numpy as np + +import tvm +from tvm import te +from tvm import relay +from tvm.relay import testing +from tvm.contrib import graph_runtime, cc + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser(description='Resnet build example') +aa = parser.add_argument +aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') +aa('--pretrained', action='store_true', help='use a pretrained resnet') +aa('--batch-size', type=int, default=1, help='input image batch size') +aa('--opt-level', type=int, default=3, + help='level of optimization. 0 is unoptimized and 3 is the highest level') +aa('--target', type=str, default='llvm', help='target context for compilation') +aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') +aa('--image-name', type=str, default='cat.png', help='name of input image to download') +args = parser.parse_args() + +build_dir = args.build_dir +batch_size = args.batch_size +opt_level = args.opt_level +target = tvm.target.create(args.target) +image_shape = tuple(map(int, args.image_shape.split(","))) +data_shape = (batch_size,) + image_shape + +def build(target_dir): + """ Compiles resnet18 with TVM""" + deploy_lib = osp.join(target_dir, 'deploy_lib.o') + if osp.exists(deploy_lib): + return + + if args.pretrained: + # needs mxnet installed + from mxnet.gluon.model_zoo.vision import get_model + + # if `--pretrained` is enabled, it downloads a pretrained + # resnet18 trained on imagenet1k dataset for image classification task + block = get_model('resnet18_v1', pretrained=True) + net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) + # we want a probability so add a 
softmax operator + net = relay.Function(net.params, relay.nn.softmax(net.body), + None, net.type_params, net.attrs) + else: + # use random weights from relay.testing + net, params = relay.testing.resnet.get_workload( + num_layers=18, batch_size=batch_size, image_shape=image_shape) + + # compile the model + with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build_module.build(net, target, params=params) + + # save the model artifacts + lib.save(deploy_lib) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), + [osp.join(target_dir, "deploy_lib.o")]) + + with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph) + + with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) + +def download_img_labels(): + """ Download an image and imagenet1k class labels for test""" + from mxnet.gluon.utils import download + + img_name = 'cat.png' + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'synset.txt' + download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) + download(synset_url, synset_name) + + with open(synset_name) as fin: + synset = eval(fin.read()) + + with open("synset.csv", "w") as fout: + w = csv.writer(fout) + w.writerows(synset.items()) + +def test_build(build_dir): + """ Sanity check with random input""" + graph = open(osp.join(build_dir, "deploy_graph.json")).read() + lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so")) + params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) + input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.load_params(params) + module.run(data=input_data) + out = module.get_output(0).asnumpy() + + +if 
__name__ == '__main__': + logger.info("building the model") + build(build_dir) + logger.info("build was successful") + logger.info("test the build artifacts") + test_build(build_dir) + logger.info("test was successful") + if args.pretrained: + download_img_labels() + logger.info("image and synset downloads are successful") diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs new file mode 100644 index 0000000..0aed72b --- /dev/null +++ b/rust/tvm/examples/resnet/src/main.rs @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +extern crate csv; +extern crate image; +extern crate ndarray; +extern crate tvm_frontend as tvm; + +use std::{ + collections::HashMap, + convert::TryInto, + fs::{self, File}, + path::Path, + str::FromStr, +}; + +use image::{FilterType, GenericImageView}; +use ndarray::{Array, ArrayD, Axis}; + +use tvm::*; + +fn main() { + let ctx = TVMContext::cpu(0); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + println!("original image dimensions: {:?}", img.dimensions()); + // for bigger size images, one needs to first resize to 256x256 + // with `img.resize_exact` method and then `image.crop` to 224x224 + let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); + println!("resized image dimensions: {:?}", img.dimensions()); + let mut pixels: Vec<f32> = vec![]; + for pixel in img.pixels() { + let tmp = pixel.data; + // normalize the RGB channels using mean, std of imagenet1k + let tmp = [ + (tmp[0] as f32 - 123.0) / 58.395, // R + (tmp[1] as f32 - 117.0) / 57.12, // G + (tmp[2] as f32 - 104.0) / 57.375, // B + ]; + for e in &tmp { + pixels.push(*e); + } + } + + let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr: ArrayD<f32> = arr.permuted_axes([2, 0, 1]).into_dyn(); + // make arr shape as [1, 3, 224, 224] acceptable to resnet + let arr = arr.insert_axis(Axis(0)); + // create input tensor from rust's ndarray + let input = NDArray::from_rust_ndarray( + &arr, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ) + .unwrap(); + println!( + "input size is {:?}", + input.shape().expect("cannot get the input shape") + ); + let graph = + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + // load the built module + let lib = Module::load(&Path::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/deploy_lib.so" + ))) + .unwrap(); + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); + let 
runtime_create_fn_ret = call_packed!( + runtime_create_fn, + graph, + &lib, + &ctx.device_type, + &ctx.device_id + ) + .unwrap(); + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap(); + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec<u8> = + fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + let barr = TVMByteArray::from(&params); + // load the parameters + call_packed!(load_param_fn, &barr).unwrap(); + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data".to_string(), &input).unwrap(); + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. 
Note that it has no argument + call_packed!(run_fn,).unwrap(); + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = NDArray::empty( + output_shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output).unwrap(); + // flatten the output as Vec<f32> + let output = output.to_vec::<f32>().unwrap(); + // find the maximum entry in the output and its index + let mut argmax = -1; + let mut max_prob = 0.; + for i in 0..output.len() { + if output[i] > max_prob { + max_prob = output[i]; + argmax = i as i32; + } + } + // create a hash map of (class id, class name) + let mut synset: HashMap<i32, String> = HashMap::new(); + let file = File::open("synset.csv").unwrap(); + let mut rdr = csv::ReaderBuilder::new() + .has_headers(true) + .from_reader(file); + + for result in rdr.records() { + let record = result.unwrap(); + let id: i32 = record[0].parse().unwrap(); + let cls = record[1].to_string(); + synset.insert(id, cls); + } + + println!( + "input image belongs to the class `{}` with probability {}", + synset + .get(&argmax) + .expect("cannot find the class id for argmax"), + max_prob + ); +} diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs new file mode 100644 index 0000000..2b5a23b --- /dev/null +++ b/rust/tvm/src/ir/array.rs @@ -0,0 +1,74 @@ +use std::convert::TryFrom; +use std::marker::PhantomData; + +use crate::runtime::object::{ObjectRef, ToObjectRef}; + +use tvm_rt::external; +use tvm_rt::RetValue; + +use anyhow::Result; + +#[derive(Clone)] +pub struct Array<T: ToObjectRef> { + object: ObjectRef, + _data: PhantomData<T>, +} + +external! 
{ + #[name("node.ArrayGetItem")] + fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; +} + +impl<T: ToObjectRef> Array<T> { + pub fn from_vec(data: Vec<T>) -> Result<Array<T>> { + unimplemented!() + // let iter = data.iter().map(|element| element.to_object_ref()); + + // let array_data = Builder::default() + // .get_function("node.Array") + // .args(iter) + // .invoke()? + // .try_into()?; + + // Ok(Array { + // object: array_data, + // _data: PhantomData, + // }) + } + + pub fn get(&self, index: isize) -> Result<T> + where + T: TryFrom<RetValue, Error = anyhow::Error>, + { + unimplemented!() + // // TODO(@jroesch): why do we used a signed index here? + // let element: T = Builder::default() + // .get_function("node.ArrayGetItem") + // .arg(self.object.clone()) + // .arg(index) + // .invoke()? + // .try_into()?; + + // Ok(element) + } +} + +#[cfg(test)] +mod tests { + use super::Array; + use crate::ir::relay::Var; + use crate::runtime::object::ObjectRef; + use anyhow::Result; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec = vec![ + Var::new("foo".into(), ObjectRef::null()), + Var::new("bar".into(), ObjectRef::null()), + ]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); + assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); + Ok(()) + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs new file mode 100644 index 0000000..bc667fd --- /dev/null +++ b/rust/tvm/src/ir/mod.rs @@ -0,0 +1,17 @@ +use crate::runtime::Object; +use crate::DataType; + +pub mod array; +pub mod relay; + +#[repr(C)] +pub struct PrimExprNode { + pub base: Object, + pub dtype: DataType, +} + +#[repr(C)] +pub struct IntImmNode { + pub base: PrimExprNode, + pub value: i64, +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs new file mode 100644 index 0000000..ac7b707 --- /dev/null +++ b/rust/tvm/src/ir/relay/mod.rs @@ -0,0 +1,232 @@ +use super::array::Array; +use 
crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, ToObjectRef}; +use crate::DataType; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Id"] +#[type_key = "relay.Id"] +pub struct IdNode { + pub base: Object, + pub name_hint: TString, +} + +impl Id { + fn new(name_hint: TString) -> Id { + let node = IdNode { + base: Object::base_object::<IdNode>(), + name_hint: name_hint, + }; + Id(Some(ObjectPtr::new(node))) + } +} + +// define_ref!(Id, IdNode); + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { + pub base: Object, +} + +#[repr(C)] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl BaseExprNode { + fn base<T: IsObject>() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::<T>(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Expr"] +#[type_key = "relay.Expr"] +pub struct RelayExpr { + pub base: BaseExprNode, + pub span: ObjectRef, + pub checked_type: ObjectRef, +} + +impl RelayExpr { + fn base<T: IsObject>() -> RelayExpr { + RelayExpr { + base: BaseExprNode::base::<T>(), + span: ObjectRef::null(), + checked_type: ObjectRef::null(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalVar"] +#[type_key = "relay.GlobalVar"] +pub struct GlobalVarNode { + pub base: RelayExpr, + pub name_hint: TString, +} + +impl GlobalVar { + pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + let node = GlobalVarNode { + base: RelayExpr::base::<GlobalVarNode>(), + // span: span, + // checked_type: ObjectRef(None),, + name_hint: TString::new(name_hint).unwrap(), + }; + GlobalVar(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Constant"] +#[type_key = "relay.Constant"] +pub struct ConstantNode { + pub base: RelayExpr, + pub data: ObjectRef, // make this NDArray. 
+} + +impl Constant { + pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { + let node = ConstantNode { + base: RelayExpr::base::<ConstantNode>(), + data: data, + }; + Constant(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Var"] +#[type_key = "relay.Var"] +pub struct VarNode { + pub base: RelayExpr, + pub vid: Id, + pub type_annotation: ObjectRef, +} + +impl Var { + pub fn new(name_hint: String, _span: ObjectRef) -> Var { + let node = VarNode { + base: RelayExpr::base::<VarNode>(), + vid: Id::new(TString::new(name_hint.to_string()).unwrap()), + type_annotation: ObjectRef::null(), + }; + Var(Some(ObjectPtr::new(node))) + } + + pub fn name_hint(&self) -> &TString { + &self.vid.0.as_ref().unwrap().name_hint + } + + pub fn to_expr(self) -> Expr { + unsafe { Expr(std::mem::transmute(self.0)) } + } +} + +pub type Type = ObjectRef; +pub type Attrs = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Call"] +#[type_key = "relay.Call"] +pub struct CallNode { + pub base: RelayExpr, + pub op: Expr, + pub args: Array<Expr>, + pub attrs: ObjectRef, + pub type_args: Array<ObjectRef>, +} + +impl Call { + pub fn new( + op: Expr, + args: Array<Expr>, + attrs: Attrs, + type_args: Array<ObjectRef>, + _span: ObjectRef, + ) -> Call { + let node = CallNode { + base: RelayExpr::base::<VarNode>(), + op: op, + args: args, + attrs: attrs, + type_args: type_args, + }; + Call(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Function"] +#[type_key = "relay.Function"] +pub struct FunctionNode { + pub base: RelayExpr, + pub params: Array<Var>, + pub body: Expr, + pub ret_type: Type, + pub type_params: Array<Type>, +} + +impl Function { + pub fn new( + params: Array<Var>, + body: Expr, + ret_type: Type, + type_params: Array<Type>, + ) -> Function { + let node = FunctionNode { + base: RelayExpr::base::<FunctionNode>(), + params: params, + body: body, + ret_type: ret_type, + type_params: type_params, + }; + 
Function(Some(ObjectPtr::new(node))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::{as_text, String as TString}; + use anyhow::Result; + + #[test] + fn test_id() -> Result<()> { + let string = TString::new("foo".to_string()).expect("bar"); + let id = Id::new(string); + let cstr = as_text(&id.upcast())?; + assert!(cstr.into_string()?.contains("relay.Id")); + Ok(()) + } + + #[test] + fn test_global() -> Result<()> { + let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); + let cstr = as_text(&gv.upcast())?; + assert!(cstr.into_string()?.contains("@main")); + Ok(()) + } + + #[test] + fn test_var() -> Result<()> { + let var = Var::new("local".to_string(), ObjectRef::null()); + let cstr = as_text(&var.upcast())?; + assert!(cstr.into_string()?.contains("%local")); + Ok(()) + } +} diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs new file mode 100644 index 0000000..64252a4 --- /dev/null +++ b/rust/tvm/src/lib.rs @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime frontend. +//! +//! 
One particular use case is that given optimized deep learning model artifacts, +//! (compiled with TVM) which include a shared library +//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them +//! in Rust idiomatically to create a TVM Graph Runtime and +//! run the model for some inputs and get the +//! desired predictions *all in Rust*. +//! +//! Check out the `examples` repository for more details. + +pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray}; + +pub use tvm_rt::{Context, DataType, DeviceType}; + +pub use tvm_rt::context; +pub use tvm_rt::errors; +pub use tvm_rt::function; +pub use tvm_rt::module; +pub use tvm_rt::ndarray; +pub use tvm_rt::value; +pub mod ir; +pub mod runtime; +pub mod transform; + +pub use runtime::version; diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs new file mode 100644 index 0000000..57d43ee --- /dev/null +++ b/rust/tvm/src/runtime/mod.rs @@ -0,0 +1 @@ +pub use tvm_rt::*; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs new file mode 100644 index 0000000..0f10ca3 --- /dev/null +++ b/rust/tvm/src/transform.rs @@ -0,0 +1,42 @@ +use crate::ir::array::Array; +use crate::runtime::{external, Function, String as TString}; +use crate::runtime::{Object, ObjectPtr, ObjectRef}; +use tvm_macros::Object; + +type Pass = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PassInfo"] +#[type_key = "transform.PassInfo"] +pub struct PassInfoNode { + pub base: Object, + pub opt_level: i32, + pub name: TString, + pub required: Array<TString>, +} + +impl PassInfo { + pub fn new(opt_level: i32, name: String, required: Vec<String>) -> anyhow::Result<PassInfo> { + let required: Result<_, _> = required + .into_iter() + .map(|name| TString::new(name)) + .collect(); + + let required = Array::from_vec(required?)?; + + let node = PassInfoNode { + base: Object::base_object::<PassInfoNode>(), + opt_level, + name: TString::new(name).unwrap(), + required, + }; + + 
Ok(PassInfo(Some(ObjectPtr::new(node)))) + } +} + +external! { + #[name("relay._transform.MakeFunctionPass")] + fn create_func_pass(func: Function, pass_info: PassInfo) -> Pass; +} diff --git a/rust/tvm/tests/basics/.gitignore b/rust/tvm/tests/basics/.gitignore new file mode 100644 index 0000000..10a4b22 --- /dev/null +++ b/rust/tvm/tests/basics/.gitignore @@ -0,0 +1,7 @@ +/target +**/*.rs.bk +Cargo.lock +*.o +*.so +*.ptx +*.json diff --git a/rust/tvm/tests/basics/Cargo.toml b/rust/tvm/tests/basics/Cargo.toml new file mode 100644 index 0000000..0b059da --- /dev/null +++ b/rust/tvm/tests/basics/Cargo.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "basics" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } + +[features] +default = ["cpu"] +cpu = [] +gpu = [] diff --git a/rust/tvm/tests/basics/build.rs b/rust/tvm/tests/basics/build.rs new file mode 100644 index 0000000..77a3bae --- /dev/null +++ b/rust/tvm/tests/basics/build.rs @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +fn main() { + let out_dir = std::env::var("OUT_DIR").unwrap(); + + let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py")) + .args(&[ + if cfg!(feature = "cpu") { + "llvm" + } else { + "cuda" + }, + &std::env::var("OUT_DIR").unwrap(), + ]) + .output() + .expect("Failed to execute command"); + assert!( + std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs new file mode 100644 index 0000000..ca53dcf --- /dev/null +++ b/rust/tvm/tests/basics/src/main.rs @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray as rust_ndarray; +extern crate tvm_frontend as tvm; + +use std::str::FromStr; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + + let (ctx, ctx_name) = if cfg!(feature = "cpu") { + (TVMContext::cpu(0), "cpu") + } else { + (TVMContext::gpu(0), "gpu") + }; + let dtype = DLDataType::from_str("float32").unwrap(); + let mut arr = NDArray::empty(shape, ctx, dtype); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = NDArray::empty(shape, ctx, dtype); + let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); + if !fadd.enabled(ctx_name) { + return; + } + if cfg!(feature = "gpu") { + fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); + } + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .arg(&mut ret) + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::<f32>().unwrap(), vec![6f32, 8.0]); +} diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py new file mode 100755 index 0000000..3911d40 --- /dev/null +++ b/rust/tvm/tests/basics/src/tvm_add.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os.path as osp +import sys + +import tvm +from tvm import te +from tvm.contrib import cc + + +def main(target, out_dir): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name='C') + s = te.create_schedule(C.op) + + if target == 'cuda': + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, te.thread_axis('blockIdx.x')) + s[C].bind(tx, te.thread_axis('threadIdx.x')) + + fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd') + + fadd.save(osp.join(out_dir, 'test_add.o')) + if target == 'cuda': + fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx')) + cc.create_shared( + osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')]) + + +if __name__ == '__main__': + main(sys.argv[1], sys.argv[2]) + diff --git a/rust/tvm/tests/callback/Cargo.toml b/rust/tvm/tests/callback/Cargo.toml new file mode 100644 index 0000000..5c89d2a --- /dev/null +++ b/rust/tvm/tests/callback/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "callback" +version = "0.0.0" +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs new file mode 100644 index 0000000..cb4a822 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +extern crate ndarray as rust_ndarray; +#[macro_use] +extern crate tvm_frontend as tvm; + +use rust_ndarray::ArrayD; +use std::{ + convert::{TryFrom, TryInto}, + str::FromStr, +}; + +use tvm::{errors::Error, *}; + +fn main() { + register_global_func! 
{ + fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = NDArray::empty( + shape, TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap() + ); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e)?; + let rnd: ArrayD<f32> = ArrayD::try_from(&arr)?; + ret += rnd.scalar_sum(); + } + Ok(TVMRetValue::from(ret)) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = NDArray::empty( + shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + arr.copy_from_buffer(data.as_mut_slice()); + + let mut registered = function::Builder::default(); + let ret: f32 = registered + .get_function("sum") + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 7f32); +} diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs new file mode 100644 index 0000000..c9f9a6f --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::panic; + +use tvm_frontend::{errors::Error, *}; + +fn main() { + register_global_func! 
{ + fn error(_args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { + Err(errors::TypeMismatchError{ + expected: "i64".to_string(), + actual: "f64".to_string(), + }.into()) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("error"); + assert!(registered.func.is_some()); + registered.args(&[10, 20]); + + println!("expected error message is:"); + panic::set_hook(Box::new(|panic_info| { + // if let Some(msg) = panic_info.message() { + // println!("{:?}", msg); + // } + if let Some(location) = panic_info.location() { + println!( + "panic occurred in file '{}' at line {}", + location.file(), + location.line() + ); + } else { + println!("panic occurred but can't get location information"); + } + })); + + let _result = registered.invoke(); +} diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs new file mode 100644 index 0000000..7111e28 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { + let mut ret = 0.0; + for arg in args.into_iter() { + let val: f64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("sum"); + assert!(registered.func.is_some()); + let ret: f64 = registered + .args(&[10.0f64, 20.0, 30.0]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60f64); +} diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs new file mode 100644 index 0000000..23910a3 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#![allow(unused_imports)] + +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +fn main() { + fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { + let mut ret = 0i64; + for arg in args.iter() { + let val: i64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + + tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); + + let mut registered = function::Builder::default(); + registered.get_function("mysum"); + assert!(registered.func.is_some()); + let ret: i64 = registered + .args(&[10, 20, 30]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60); +} diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs new file mode 100644 index 0000000..9ead587 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +// FIXME +fn main() { + register_global_func! 
{ + fn concate_str(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> { + let mut ret = "".to_string(); + for arg in args.iter() { + let val: &str = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + } + let a = std::ffi::CString::new("a").unwrap(); + let b = std::ffi::CString::new("b").unwrap(); + let c = std::ffi::CString::new("c").unwrap(); + let mut registered = function::Builder::default(); + registered.get_function("concate_str"); + assert!(registered.func.is_some()); + let ret: String = registered + .arg(a.as_c_str()) + .arg(b.as_c_str()) + .arg(c.as_c_str()) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, "abc".to_owned()); +} diff --git a/rust/tvm/tests/test_ir.rs b/rust/tvm/tests/test_ir.rs new file mode 100644 index 0000000..a43f27e --- /dev/null +++ b/rust/tvm/tests/test_ir.rs @@ -0,0 +1,37 @@ +use std::convert::TryInto; +use std::str::FromStr; +use tvm::ir::IntImmNode; +use tvm::runtime::String as TString; +use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef}; +use tvm_rt::{call_packed, DLDataType, Function}; +use tvm_sys::TVMRetValue; + +#[test] +fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::<Object>(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) +} + +#[test] +fn test_new_string() -> anyhow::Result<()> { + let string = TString::new("hello world!".to_string())?; + Ok(()) +} + +#[test] +fn test_obj_build() -> anyhow::Result<()> { + let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found."); + + let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert."); + + let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337) + .expect("foo") + .try_into() + .unwrap(); + + debug_print(&ret_val); + + Ok(()) +}