This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch cargo-build in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit a9ee3cb34c020a4debe75fc9a194303f22d00892 Author: Jared Roesch <roesch...@gmail.com> AuthorDate: Thu Oct 22 11:48:34 2020 -0700 WIP --- rust/tvm-macros/Cargo.toml | 2 +- rust/tvm-macros/src/external.rs | 43 +++++++++++++++++++++++++++++++++-------- rust/tvm-macros/src/lib.rs | 1 + rust/tvm-rt/src/object/mod.rs | 2 +- rust/tvm/src/ir/module.rs | 16 +++++++++++---- 5 files changed, 50 insertions(+), 14 deletions(-) diff --git a/rust/tvm-macros/Cargo.toml b/rust/tvm-macros/Cargo.toml index 63b8472..8e97d3b 100644 --- a/rust/tvm-macros/Cargo.toml +++ b/rust/tvm-macros/Cargo.toml @@ -33,5 +33,5 @@ proc-macro = true goblin = "^0.2" proc-macro2 = "^1.0" quote = "^1.0" -syn = { version = "1.0.17", features = ["full", "extra-traits"] } +syn = { version = "^1.0", features = ["full", "parsing", "extra-traits"] } proc-macro-error = "^1.0" diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index de8ada3..44a242c 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -21,9 +21,28 @@ use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; -use syn::{FnArg, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; +use syn::{Token, FnArg, Signature, Attribute, token::Semi, Visibility, Generics, Ident, Lit, Meta, NestedMeta, Pat, ReturnType, TraitItemMethod, Type}; + +struct ExternalItem { + attrs: Vec<Attribute>, + visibility: Visibility, + sig: Signature, +} + +impl Parse for ExternalItem { + fn parse(input: ParseStream) -> Result<Self> { + let item = ExternalItem { + attrs: input.call(Attribute::parse_outer)?, + visibility: input.parse()?, + sig: input.parse()?, + }; + let _semi: Semi = input.parse()?; + Ok(item) + } +} struct External { + visibility: Visibility, tvm_name: String, ident: Ident, generics: Generics, @@ -33,7 +52,8 @@ struct External { impl Parse for External { fn parse(input: ParseStream) -> Result<Self> { - let method: TraitItemMethod = input.parse()?; + let method: ExternalItem = input.parse()?; + let visibility = method.visibility; assert_eq!(method.attrs.len(), 1); let sig = method.sig; let tvm_name = method.attrs[0].parse_meta()?; @@ -48,8 +68,7 @@ impl Parse for External { } _ => panic!(), }; - assert_eq!(method.default, None); - assert!(method.semi_token != None); + let ident = sig.ident; let generics = sig.generics; let inputs = sig @@ -61,6 +80,7 @@ impl Parse for External { let ret_type = sig.output; Ok(External { + visibility, tvm_name, ident, generics, @@ -99,6 +119,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut items = Vec::new(); for external in &ext_input.externs { + let visibility = &external.visibility; let name = &external.ident; let global_name = format!("global_{}", external.ident); let global_name = Ident::new(&global_name, Span::call_site()); @@ -127,15 +148,21 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let ty: Type = *pat_type.ty.clone(); (ident, ty) } - _ => panic!(), + _ => abort! { pat_type, + "Only supports type parameters." + } }, - _ => panic!(), + pat => abort! { + pat, "invalid pattern type for function"; + + note = "{:?} is not allowed here", pat; + } }) .unzip(); let ret_type = match &external.ret_type { ReturnType::Type(_, rtype) => *rtype.clone(), - _ => panic!(), + ReturnType::Default => syn::parse_str::<Type>("()").unwrap(), }; let global = quote! { @@ -150,7 +177,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { items.push(global); let wrapper = quote! { - pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { + #visibility fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { let func_ref: #tvm_rt_crate::Function = #global_name.clone(); let func_ref: Box<dyn Fn(#(#tys),*) -> #result_type<#ret_type>> = func_ref.into(); let res: #ret_type = func_ref(#(#args),*)?; diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index ab75c92..32f2839 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -18,6 +18,7 @@ */ use proc_macro::TokenStream; +use proc_macro_error::proc_macro_error; mod external; mod import_module; diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index 46e0342..e48c017 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -88,7 +88,7 @@ pub trait IsObjectRef: external! { #[name("ir.DebugPrint")] - fn debug_print(object: ObjectRef) -> CString; + pub fn debug_print(object: ObjectRef) -> CString; #[name("node.StructuralHash")] fn structural_hash(object: ObjectRef, map_free_vars: bool) -> ObjectRef; #[name("node.StructuralEqual")] diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 8918bdc..3b60b0c 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -31,8 +31,10 @@ use crate::runtime::{external, Object, ObjectRef}; use super::expr::GlobalVar; use super::function::BaseFunc; use super::source_map::SourceMap; +use super::{ty::GlobalTypeVar, relay}; // TODO(@jroesch): define type + type TypeData = ObjectRef; type GlobalTypeVar = ObjectRef; @@ -64,7 +66,7 @@ external! { fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; // Module methods #[name("ir.Module_Add")] - fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); + fn module_add(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -78,15 +80,15 @@ external! { #[name("ir.Module_GetGlobalTypeVars")] fn module_get_global_type_vars() -> Array<GlobalTypeVar>; #[name("ir.Module_ContainGlobalVar")] - fn module_get_global_var(name: TVMString) -> bool; + fn module_contains_global_var(name: TVMString) -> bool; #[name("ir.Module_ContainGlobalTypeVar")] - fn module_get_global_type_var(name: TVMString) -> bool; + fn module_contains_global_type_var(name: TVMString) -> bool; #[name("ir.Module_LookupDef")] fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef; #[name("ir.Module_LookupDef_str")] fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef; #[name("ir.Module_LookupTag")] - fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor; + fn module_lookup_tag(module: IRModule, tag: i32) -> relay::Constructor; #[name("ir.Module_FromExpr")] fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule; #[name("ir.Module_Import")] @@ -145,3 +147,9 @@ impl IRModule { module_lookup_str(self.clone(), name.into()) } } + +#[cfg(test)] +mod tests { + // #[test] + // fn +}