This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch cargo-build in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit 49246bff342eb59757cecc34a9a9465a2e3c063d Author: Jared Roesch <roesch...@gmail.com> AuthorDate: Wed Oct 21 14:09:37 2020 -0700 WIP --- rust/tvm-macros/src/external.rs | 5 ++- rust/tvm-macros/src/lib.rs | 1 + rust/tvm/src/ir/module.rs | 67 +++++++++++++---------------------------- 3 files changed, 26 insertions(+), 47 deletions(-) diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 802d7ae..de8ada3 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -17,6 +17,7 @@ * under the License. */ use proc_macro2::Span; +use proc_macro_error::abort; use quote::quote; use syn::parse::{Parse, ParseStream, Result}; @@ -109,7 +110,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { .iter() .map(|ty_param| match ty_param { syn::GenericParam::Type(param) => param.clone(), - _ => panic!(), + _ => abort! { ty_param, + "Only supports type parameters." + } }) .collect(); diff --git a/rust/tvm-macros/src/lib.rs b/rust/tvm-macros/src/lib.rs index 603e1ce..ab75c92 100644 --- a/rust/tvm-macros/src/lib.rs +++ b/rust/tvm-macros/src/lib.rs @@ -35,6 +35,7 @@ pub fn macro_impl(input: TokenStream) -> TokenStream { TokenStream::from(object::macro_impl(input)) } +#[proc_macro_error] #[proc_macro] pub fn external(input: TokenStream) -> TokenStream { external::macro_impl(input) diff --git a/rust/tvm/src/ir/module.rs b/rust/tvm/src/ir/module.rs index 443915f..8918bdc 100644 --- a/rust/tvm/src/ir/module.rs +++ b/rust/tvm/src/ir/module.rs @@ -63,6 +63,8 @@ external! { #[name("parser.ParseExpr")] fn parse_expression(file_name: TVMString, source: TVMString) -> IRModule; // Module methods + #[name("ir.Module_Add")] + fn module_add_def(module: IRModule, type_name: GlobalVar, expr: relay::Expr, update: bool) -> (); #[name("ir.Module_AddDef")] fn module_add_def(module: IRModule, type_name: GlobalTypeVar, type_data: TypeData, update: bool) -> (); #[name("ir.Module_GetGlobalVar")] @@ -73,55 +75,28 @@ external! { fn module_lookup(module: IRModule, var: GlobalVar) -> BaseFunc; #[name("ir.Module_Lookup_str")] fn module_lookup_str(module: IRModule, name: TVMString) -> BaseFunc; + #[name("ir.Module_GetGlobalTypeVars")] + fn module_get_global_type_vars() -> Array<GlobalTypeVar>; + #[name("ir.Module_ContainGlobalVar")] + fn module_get_global_var(name: TVMString) -> bool; + #[name("ir.Module_ContainGlobalTypeVar")] + fn module_get_global_type_var(name: TVMString) -> bool; + #[name("ir.Module_LookupDef")] + fn module_lookup_def(module: IRModule, global: GlobalTypeVar) -> TypeDef; + #[name("ir.Module_LookupDef_str")] + fn module_lookup_def_str(module: IRModule, global: GlobalTypeVar) -> TypeDef; + #[name("ir.Module_LookupTag")] + fn module_lookup_tag(module: IRModule, tag: i32) -> Constructor; + #[name("ir.Module_FromExpr")] + fn module_from_expr(expr: relay::Expr, funcs: Map<GlobalVar, BaseFunc>, types: Map<GlobalTypeVar, TypeData>) -> IRModule; + #[name("ir.Module_Import")] + fn module_import(module: IRModule, path: TVMString); + #[name("ir.Module_ImportFromStd")] + fn module_import_from_std(module: IRModule, path: TVMString); } -// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVars") -// .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVars); +// Note: we don't expose update here as update is going to be removed. -// TVM_REGISTER_GLOBAL("ir.Module_ContainGlobalVar") -// .set_body_method<IRModule>(&IRModuleNode::ContainGlobalVar); - -// TVM_REGISTER_GLOBAL("ir.Module_GetGlobalTypeVar") -// .set_body_method<IRModule>(&IRModuleNode::GetGlobalTypeVar); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupDef").set_body_typed([](IRModule mod, GlobalTypeVar var) { -// return mod->LookupTypeDef(var); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupDef_str").set_body_typed([](IRModule mod, String var) { -// return mod->LookupTypeDef(var); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_LookupTag").set_body_typed([](IRModule mod, int32_t tag) { -// return mod->LookupTag(tag); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_FromExpr") -// .set_body_typed([](RelayExpr e, tvm::Map<GlobalVar, BaseFunc> funcs, -// tvm::Map<GlobalTypeVar, TypeData> type_defs) { -// return IRModule::FromExpr(e, funcs, type_defs); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule from) { -// mod->Update(from); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") -// .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); - -// TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { -// mod->Import(path); -// }); - -// TVM_REGISTER_GLOBAL("ir.Module_ImportFromStd").set_body_typed([](IRModule mod, String path) { -// mod->ImportFromStd(path); -// }); - -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch<IRModuleNode>([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast<const IRModuleNode*>(ref.get()); -// p->stream << "IRModuleNode( " << node->functions << ")"; -// }); impl IRModule { pub fn parse<N, S>(file_name: N, source: S) -> Result<IRModule>