This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch rust-stablize in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 5939ce700ca24f10dab71fa0d0c5fab19fd2be4f Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Sat Mar 21 22:30:29 2020 -0700 Fix up the final pieces --- rust/Cargo.toml | 1 - rust/frontend/tests/callback/Cargo.toml | 1 + rust/frontend/tests/callback/src/bin/error.rs | 5 +- rust/macros/Cargo.toml | 9 +- rust/macros/src/lib.rs | 123 ++++++++++++++++++++-- rust/macros_raw/Cargo.toml | 36 ------- rust/macros_raw/src/lib.rs | 141 -------------------------- rust/runtime/tests/test_nn/build.rs | 3 +- rust/runtime/tests/test_tvm_basic/build.rs | 16 ++- rust/runtime/tests/test_tvm_basic/src/main.rs | 2 +- 10 files changed, 141 insertions(+), 196 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 190a6eb..8467f6a 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -19,7 +19,6 @@ members = [ "common", "macros", - "macros_raw", "runtime", "runtime/tests/test_tvm_basic", "runtime/tests/test_tvm_dso", diff --git a/rust/frontend/tests/callback/Cargo.toml b/rust/frontend/tests/callback/Cargo.toml index a452572..dfe80cc 100644 --- a/rust/frontend/tests/callback/Cargo.toml +++ b/rust/frontend/tests/callback/Cargo.toml @@ -19,6 +19,7 @@ name = "callback" version = "0.0.0" authors = ["TVM Contributors"] +edition = "2018" [dependencies] ndarray = "0.12" diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/frontend/tests/callback/src/bin/error.rs index 29bfd9a..c9f9a6f 100644 --- a/rust/frontend/tests/callback/src/bin/error.rs +++ b/rust/frontend/tests/callback/src/bin/error.rs @@ -19,10 +19,7 @@ use std::panic; -#[macro_use] -extern crate tvm_frontend as tvm; - -use tvm::{errors::Error, *}; +use tvm_frontend::{errors::Error, *}; fn main() { register_global_func! { diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml index ff4f7d8..784b35e 100644 --- a/rust/macros/Cargo.toml +++ b/rust/macros/Cargo.toml @@ -19,13 +19,18 @@ name = "tvm-macros" version = "0.1.1" license = "Apache-2.0" -description = "Proc macros used by the TVM crates." +description = "Procedural macros of the TVM crate." repository = "https://github.com/apache/incubator-tvm" readme = "README.md" keywords = ["tvm"] authors = ["TVM Contributors"] edition = "2018" +[lib] +proc-macro = true [dependencies] -tvm-macros-raw = { path = "../macros_raw" } +goblin = "0.0.24" +proc-macro2 = "^1.0" +quote = "1.0" +syn = "1.0" diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs index efd85d0..d1d86b6 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/lib.rs @@ -17,12 +17,123 @@ * under the License. */ -#[macro_use] -extern crate tvm_macros_raw; +extern crate proc_macro; -#[macro_export] -macro_rules! import_module { - ($module_path:literal) => { - $crate::import_module_raw!(file!(), $module_path); +use std::{fs::File, io::Read}; +use syn::parse::{Parse, ParseStream, Result}; +use syn::{LitStr}; +use quote::quote; + +use std::path::PathBuf; + +struct ImportModule { + importing_file: LitStr, +} + +impl Parse for ImportModule { + fn parse(input: ParseStream) -> Result<Self> { + let importing_file: LitStr = input.parse()?; + Ok(ImportModule { + importing_file, + }) + } +} + +#[proc_macro] +pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let import_module_args = syn::parse_macro_input!(input as ImportModule); + + let manifest = std::env::var("CARGO_MANIFEST_DIR") + .expect("variable should always be set by Cargo."); + + let mut path = PathBuf::new(); + path.push(manifest); + path = path.join(import_module_args.importing_file.value()); + + let mut fd = File::open(&path) + .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); + let mut buffer = Vec::new(); + fd.read_to_end(&mut buffer).unwrap(); + + let fn_names = match goblin::Object::parse(&buffer).unwrap() { + goblin::Object::Elf(elf) => elf + .syms + .iter() + .filter_map(|s| { + if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { + return None; + } + match elf.strtab.get(s.st_name) { + Some(Ok(name)) if name != "" => { + Some(syn::Ident::new(name, proc_macro2::Span::call_site())) + } + _ => None, + } + }) + .collect::<Vec<_>>(), + goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { + obj.symbols() + .filter_map(|s| match s { + Ok((name, ref nlist)) + if nlist.is_global() + && nlist.n_sect != 0 + && !name.ends_with("tvm_module_ctx") => + { + Some(syn::Ident::new( + if name.starts_with('_') { + // Mach objects prepend a _ to globals. + &name[1..] + } else { + &name + }, + proc_macro2::Span::call_site(), + )) + } + _ => None, + }) + .collect::<Vec<_>>() + } + _ => panic!("Unsupported object format."), + }; + + let extern_fns = quote! { + mod ext { + extern "C" { + #( + pub(super) fn #fn_names( + args: *const tvm_runtime::ffi::TVMValue, + type_codes: *const std::os::raw::c_int, + num_args: std::os::raw::c_int + ) -> std::os::raw::c_int; + )* + } + } }; + + let fns = quote! { + use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; + #extern_fns + + #( + pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> { + let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args + .into_iter() + .map(|arg| { + let (val, code) = arg.to_tvm_value(); + (val, code as i32) + }) + .unzip(); + let exit_code = unsafe { + ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) + }; + if exit_code == 0 { + Ok(TVMRetValue::default()) + } else { + Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) + } + } + )* + }; + + proc_macro::TokenStream::from(fns) } diff --git a/rust/macros_raw/Cargo.toml b/rust/macros_raw/Cargo.toml deleted file mode 100644 index 9b3d3e9..0000000 --- a/rust/macros_raw/Cargo.toml +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "tvm-macros-raw" -version = "0.1.1" -license = "Apache-2.0" -description = "Proc macros used by the TVM crates." -repository = "https://github.com/apache/incubator-tvm" -readme = "README.md" -keywords = ["tvm"] -authors = ["TVM Contributors"] -edition = "2018" - -[lib] -proc-macro = true - -[dependencies] -goblin = "0.0.24" -proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" diff --git a/rust/macros_raw/src/lib.rs b/rust/macros_raw/src/lib.rs deleted file mode 100644 index f518f88..0000000 --- a/rust/macros_raw/src/lib.rs +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern crate proc_macro; - -use std::{fs::File, io::Read}; -use syn::parse::{Parse, ParseStream, Result}; -use syn::{Token, LitStr}; -use quote::quote; - -use std::path::PathBuf; - -struct ImportModule { - importing_file: LitStr, - module_path: LitStr, -} - -impl Parse for ImportModule { - fn parse(input: ParseStream) -> Result<Self> { - let importing_file: LitStr = input.parse()?; - input.parse::<Token![,]>()?; - let module_path: LitStr = input.parse()?; - Ok(ImportModule { - importing_file, - module_path, - }) - } -} - -#[proc_macro] -pub fn import_module_raw(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let import_module_args = syn::parse_macro_input!(input as ImportModule); - - let mut path = PathBuf::new(); - path = path.join(import_module_args.importing_file.value()); - path.pop(); // remove the filename - path.push(import_module_args.module_path.value()); - - let mut fd = File::open(&path) - .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); - let mut buffer = Vec::new(); - fd.read_to_end(&mut buffer).unwrap(); - - let fn_names = match goblin::Object::parse(&buffer).unwrap() { - goblin::Object::Elf(elf) => elf - .syms - .iter() - .filter_map(|s| { - if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { - return None; - } - match elf.strtab.get(s.st_name) { - Some(Ok(name)) if name != "" => { - Some(syn::Ident::new(name, proc_macro2::Span::call_site())) - } - _ => None, - } - }) - .collect::<Vec<_>>(), - goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { - obj.symbols() - .filter_map(|s| match s { - Ok((name, ref nlist)) - if nlist.is_global() - && nlist.n_sect != 0 - && !name.ends_with("tvm_module_ctx") => - { - Some(syn::Ident::new( - if name.starts_with('_') { - // Mach objects prepend a _ to globals. - &name[1..] - } else { - &name - }, - proc_macro2::Span::call_site(), - )) - } - _ => None, - }) - .collect::<Vec<_>>() - } - _ => panic!("Unsupported object format."), - }; - - let extern_fns = quote! { - mod ext { - extern "C" { - #( - pub(super) fn #fn_names( - args: *const tvm_runtime::ffi::TVMValue, - type_codes: *const std::os::raw::c_int, - num_args: std::os::raw::c_int - ) -> std::os::raw::c_int; - )* - } - } - }; - - let fns = quote! { - use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; - #extern_fns - - #( - pub fn #fn_names(args: &[TVMArgValue]) -> Result<TVMRetValue, FuncCallError> { - let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = unsafe { - ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) - }; - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) - } - } - )* - }; - - proc_macro::TokenStream::from(fns) -} diff --git a/rust/runtime/tests/test_nn/build.rs b/rust/runtime/tests/test_nn/build.rs index 2d0b066..f072a90 100644 --- a/rust/runtime/tests/test_nn/build.rs +++ b/rust/runtime/tests/test_nn/build.rs @@ -44,7 +44,8 @@ fn main() { .unwrap_or("") ); - let mut builder = Builder::new(File::create(format!("{}/libgraph.a", out_dir)).unwrap()); + let file = File::create(format!("{}/libtestnn.a", out_dir)).unwrap(); + let mut builder = Builder::new(file); builder.append_path(format!("{}/graph.o", out_dir)).unwrap(); println!("cargo:rustc-link-lib=static=graph"); diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/runtime/tests/test_tvm_basic/build.rs index 3439f9c..ade9e02 100644 --- a/rust/runtime/tests/test_tvm_basic/build.rs +++ b/rust/runtime/tests/test_tvm_basic/build.rs @@ -33,7 +33,7 @@ fn main() { } let obj_file = out_dir.join("test.o"); - let lib_file = out_dir.join("libtest.a"); + let lib_file = out_dir.join("libtest_basic.a"); let output = Command::new(concat!( env!("CARGO_MANIFEST_DIR"), @@ -53,9 +53,17 @@ fn main() { .unwrap_or("") ); - let mut builder = Builder::new(File::create(lib_file).unwrap()); - builder.append_path(obj_file).unwrap(); + let mut builder = Builder::new(File::create(&lib_file).unwrap()); + builder.append_path(&obj_file).unwrap(); + drop(builder); - println!("cargo:rustc-link-lib=static=test"); + let status = Command::new("ranlib") + .arg(&lib_file) + .status() + .expect("fdjlksafjdsa"); + + assert!(status.success()); + + println!("cargo:rustc-link-lib=static=test_basic"); println!("cargo:rustc-link-search=native={}", out_dir.display()); } diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/runtime/tests/test_tvm_basic/src/main.rs index a83078e..653cb43 100644 --- a/rust/runtime/tests/test_tvm_basic/src/main.rs +++ b/rust/runtime/tests/test_tvm_basic/src/main.rs @@ -25,7 +25,7 @@ use ndarray::Array; use tvm_runtime::{DLTensor, Module as _, SystemLibModule}; mod tvm_mod { - import_module!("../lib/test.o"); + import_module!("lib/test.o"); } fn main() {