This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch cargo-build in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit db245535a4ea5b8bed3cdddd26a94825d2fcc9b5 Author: Jared Roesch <roesch...@gmail.com> AuthorDate: Thu Oct 15 01:03:03 2020 -0700 Update CMake and delete old API --- CMakeLists.txt | 1 + cmake/modules/RustExt.cmake | 25 ++- include/tvm/parser/source_map.h | 2 - rust/compiler-ext/Cargo.toml | 5 +- rust/compiler-ext/src/lib.rs | 334 ++-------------------------------------- rust/tvm-rt/Cargo.toml | 15 +- rust/tvm-sys/Cargo.toml | 1 + rust/tvm-sys/build.rs | 1 + rust/tvm/Cargo.toml | 22 ++- rust/tvm/src/bin/tyck.rs | 1 - rust/tvm/src/ir/diagnostics.rs | 42 +++-- rust/tvm/src/ir/mod.rs | 2 +- rust/tvm/src/ir/relay/mod.rs | 3 +- rust/tvm/src/ir/source_map.rs | 61 ++++++++ rust/tvm/src/ir/span.rs | 95 +++++++++--- src/ir/expr.cc | 11 ++ src/parser/source_map.cc | 11 -- 17 files changed, 237 insertions(+), 395 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f82754..58bb2f7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -353,6 +353,7 @@ include(cmake/modules/contrib/ArmComputeLib.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/Git.cmake) include(cmake/modules/LibInfo.cmake) +include(cmake/modules/RustExt.cmake) include(CheckCXXCompilerFlag) if(NOT MSVC) diff --git a/cmake/modules/RustExt.cmake b/cmake/modules/RustExt.cmake index 45e46bd..2ad726e9 100644 --- a/cmake/modules/RustExt.cmake +++ b/cmake/modules/RustExt.cmake @@ -1,7 +1,14 @@ -if(USE_RUST_EXT) - set(RUST_SRC_DIR "rust") - set(CARGO_OUT_DIR "rust/target" - set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/target/release/libcompiler_ext.dylib") +if(USE_RUST_EXT AND NOT USE_RUST_EXT EQUAL OFF) + set(RUST_SRC_DIR "${CMAKE_SOURCE_DIR}/rust") + set(CARGO_OUT_DIR "${CMAKE_SOURCE_DIR}/rust/target") + + if(USE_RUST_EXT STREQUAL "STATIC") + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.a") + elseif(USE_RUST_EXT STREQUAL "DYNAMIC") + set(COMPILER_EXT_PATH "${CARGO_OUT_DIR}/release/libcompiler_ext.so") + else() + message(FATAL_ERROR "invalid setting for RUST_EXT") + endif() add_custom_command( OUTPUT "${COMPILER_EXT_PATH}" @@ -9,5 +16,11 @@ if(USE_RUST_EXT) MAIN_DEPENDENCY "${RUST_SRC_DIR}" WORKING_DIRECTORY "${RUST_SRC_DIR}/compiler-ext") - target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE) -endif(USE_RUST_EXT) + add_custom_target(rust_ext ALL DEPENDS "${COMPILER_EXT_PATH}") + + # TODO(@jroesch, @tkonolige): move this to CMake target + # target_link_libraries(tvm "${COMPILER_EXT_PATH}" PRIVATE) + list(APPEND TVM_LINKER_LIBS ${COMPILER_EXT_PATH}) + + add_definitions(-DRUST_COMPILER_EXT=1) +endif() diff --git a/include/tvm/parser/source_map.h b/include/tvm/parser/source_map.h index 424af5c..a160c22 100644 --- a/include/tvm/parser/source_map.h +++ b/include/tvm/parser/source_map.h @@ -103,8 +103,6 @@ class SourceMap : public ObjectRef { TVM_DLL SourceMap() : SourceMap(Map<SourceName, Source>()) {} - TVM_DLL static SourceMap Global(); - void Add(const Source& source); SourceMapNode* operator->() { diff --git a/rust/compiler-ext/Cargo.toml b/rust/compiler-ext/Cargo.toml index 76d10eb..3b13bc5 100644 --- a/rust/compiler-ext/Cargo.toml +++ b/rust/compiler-ext/Cargo.toml @@ -6,8 +6,11 @@ edition = "2018" # TODO(@jroesch): would be cool to figure out how to statically link instead. [lib] -crate-type = ["cdylib"] +crate-type = ["staticlib", "cdylib"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +tvm = { path = "../tvm", default-features = false, features = ["static-linking"] } +log = "*" +env_logger = "*" diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 58bdd0c..3e37d21 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -17,321 +17,19 @@ * under the License. */ - use std::os::raw::c_int; - use tvm::initialize; - use tvm::ir::{tir, PrimExpr}; - use tvm::runtime::function::register_override; - use tvm::runtime::map::Map; - use tvm::runtime::object::{IsObject, IsObjectRef}; - - use ordered_float::NotNan; - - mod interval; - mod math; - - use math::{BoundsMap, Expr, RecExpr}; - use tvm::ir::arith::ConstIntBound; - use tvm_rt::{ObjectRef, array::Array}; - - macro_rules! downcast_match { - ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => { - $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ - { $default } - } - } - - #[derive(Default)] - struct VarMap { - vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>, - objs: Vec<ObjectRef>, - } - - impl VarMap { - // FIXME this should eventually do the right thing for TVM variables - // right now it depends on them having unique names - fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol { - let sym = egg::Symbol::from(var.name_hint.as_str().unwrap()); - for (_, sym2) in &self.vars { - if sym == *sym2 { - return sym; - } - } - - self.vars.push((var, sym)); - sym - } - - fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var { - for (v, sym2) in &self.vars { - if sym == *sym2 { - return v.clone(); - } - } - panic!("Should have found a var") - } - - fn push_obj(&mut self, obj: impl IsObjectRef) -> usize { - let i = self.objs.len(); - self.objs.push(obj.upcast()); - i - } - - fn get_obj<T: IsObjectRef>(&self, i: usize) -> T { - self.objs[i].clone().downcast().expect("bad downcast") - } - } - - fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr { - fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id { - macro_rules! r { - ($e:expr) => { - build(vars, &$e, recexpr) - }; - } - - let dt = recexpr.add(Expr::DataType(p.datatype)); - let e = downcast_match!(p; { - tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]), - tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]), - tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]), - - tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]), - tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]), - tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]), - tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]), - - tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]), - tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]), - - tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]), - tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]), - - tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]), - tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]), - tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]), - tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]), - tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]), - tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]), - - tir::And => Expr::And([dt, r!(p.a), r!(p.b)]), - tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]), - tir::Not => Expr::Not([dt, r!(p.value)]), - - tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]), - - tir::Let => { - let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); - Expr::Let([dt, sym, r!(p.value), r!(p.body)]) - } - tir::Var => { - let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p))); - Expr::Var([dt, sym]) - } - tir::IntImm => { - let int = recexpr.add(Expr::Int(p.value)); - Expr::IntImm([dt, int]) - } - tir::FloatImm => { - let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap())); - Expr::FloatImm([dt, float]) - } - tir::Cast => Expr::Cast([dt, r!(p.value)]), - - tir::Call => { - let op = vars.push_obj(p.op.clone()); - let mut arg_ids = vec![dt]; - for i in 0..p.args.len() { - let arg: PrimExpr = p.args.get(i as isize).expect("array get fail"); - arg_ids.push(r!(arg)); - } - Expr::Call(op, arg_ids) - }, - tir::Load => { - let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); - Expr::Load([dt, sym, r!(p.index), r!(p.predicate)]) - }, - else => { - println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap()); - Expr::Object(vars.push_obj(p.clone())) - } - }); - - recexpr.add(e) - } - - let mut recexpr = Default::default(); - build(vars, prim, &mut recexpr); - recexpr - } - - fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr { - fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr { - let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]); - let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap(); - let prim: PrimExpr = match nodes.last().expect("cannot be empty") { - Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] { - Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), - n => panic!("Expected a symbol, got {:?}", n), - }, - Expr::IntImm([dt, v]) => { - let value = nodes[usize::from(*v)].to_int().unwrap(); - tir::IntImm::new(get_dt(dt), value).upcast() - } - Expr::FloatImm([dt, v]) => { - let value = nodes[usize::from(*v)].to_float().unwrap(); - tir::FloatImm::new(get_dt(dt), value).upcast() - } - Expr::Let([dt, s, value, body]) => { - let var = match &nodes[usize::from(*s)] { - Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), - n => panic!("Expected a symbol, got {:?}", n), - }; - tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast() - } - Expr::Load([dt, s, value, body]) => { - let var = match &nodes[usize::from(*s)] { - Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), - n => panic!("Expected a symbol, got {:?}", n), - }; - tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast() - } - - Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(), - - Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(), - Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(), - - Expr::Ramp([dt, a, b, c]) => { - let len = &nodes[usize::from(*c)]; - let i = len - .to_int() - .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len)); - tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast() - } - Expr::Broadcast([dt, val, lanes]) => { - let lanes = &nodes[usize::from(*lanes)]; - let lanes = lanes - .to_int() - .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes)); - println!("dt: {}", get_dt(dt)); - tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast() - } - - Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(), - Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(), - Expr::Call(expr, args) => { - let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect(); - let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args"); - tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast() - } - - Expr::Object(i) => vars.get_obj(*i), - node => panic!("I don't know how to extract {:?}", node), - }; - assert_ne!(prim.datatype.bits(), 0); - assert_ne!(prim.datatype.lanes(), 0); - prim - } - build(vars, recexpr.as_ref()) - } - - fn run( - input: PrimExpr, - expected: Option<PrimExpr>, - map: Map<PrimExpr, ConstIntBound>, - ) -> Result<PrimExpr, String> { - use egg::{CostFunction, Extractor}; - - let mut bounds = BoundsMap::default(); - for (k, v) in map { - if let Ok(var) = k.downcast_clone::<tir::Var>() { - let sym: egg::Symbol = var.name_hint.as_str().unwrap().into(); - bounds.insert(sym, (v.min_value, v.max_value)); - } else { - println!("Non var in bounds map: {}", tvm::ir::as_text(k)); - } - } - - let mut vars = VarMap::default(); - let expr = to_egg(&mut vars, &input); - let mut runner = math::default_runner(); - runner.egraph.analysis.bounds = bounds; - - let mut runner = runner.with_expr(&expr).run(&math::rules()); - // runner.print_report(); - let mut extractor = Extractor::new(&runner.egraph, math::CostFn); - let root = runner.egraph.find(runner.roots[0]); - let (cost, best) = extractor.find_best(root); - if let Some(expected) = expected { - let mut expected_vars = VarMap::default(); - let expected_expr = to_egg(&mut expected_vars, &expected); - let expected_root = runner.egraph.add_expr(&expected_expr); - if expected_root != root { - return Err(format!( - "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n", - expected_expr.pretty(40), - best.pretty(40) - )); - } - let expected_cost = math::CostFn.cost_rec(&expected_expr); - if expected_cost != cost { - let msg = format!( - "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n", - expected_cost, - expected_expr.pretty(40), - cost, - best.pretty(40) - ); - if cost < expected_cost { - println!("egg wins: {}", msg) - } else { - return Err(msg); - } - } - } - log::info!(" returning... {}", best.pretty(60)); - Ok(from_egg(&vars, &best)) - } - - fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> { - log::debug!("map: {:?}", map); - run(prim, None, map).map_err(tvm::Error::CallFailed) - } - - fn simplify_and_check( - prim: PrimExpr, - check: PrimExpr, - map: Map<PrimExpr, ConstIntBound>, - ) -> Result<PrimExpr, tvm::Error> { - log::debug!("check map: {:?}", map); - run(prim, Some(check), map).map_err(tvm::Error::CallFailed) - } - - initialize!({ - let _ = env_logger::try_init(); - // NOTE this print prevents a segfault (on Linux) for now... - println!("Initializing simplifier... "); - register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier"); - register_override(simplify_and_check, "egg.simplify_and_check", true) - .expect("failed to initialize simplifier"); - log::debug!("done!"); - }); - \ No newline at end of file +use env_logger; +use tvm; +use tvm::runtime::function::register_override; + +fn test_fn() -> Result<(), tvm::Error> { + println!("Hello from Rust!"); + Ok(()) +} + +#[no_mangle] +fn compiler_ext_initialize() -> i32 { + let _ = env_logger::try_init(); + register_override(test_fn, "rust_ext.test_fn", true).expect("failed to initialize simplifier"); + log::debug!("done!"); + return 0; +} diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index acece5a..9660943 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -28,19 +28,26 @@ categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" +[features] +default = ["dynamic-linking"] +dynamic-linking = ["tvm-sys/bindings"] +static-linking = [] +blas = ["ndarray/blas"] + [dependencies] thiserror = "^1.0" ndarray = "0.12" num-traits = "0.2" -tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } tvm-macros = { version = "0.1", path = "../tvm-macros" } paste = "0.1" mashup = "0.1" once_cell = "^1.3.1" memoffset = "0.5.6" +[dependencies.tvm-sys] +version = "0.1" +default-features = false +path = "../tvm-sys/" + [dev-dependencies] anyhow = "^1.0" - -[features] -blas = ["ndarray/blas"] diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index 4e3fc98..c25a5bf 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -23,6 +23,7 @@ license = "Apache-2.0" edition = "2018" [features] +default = ["bindings"] bindings = [] [dependencies] diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 05806c0..2d86c4b 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -60,6 +60,7 @@ fn main() -> Result<()> { if cfg!(feature = "bindings") { println!("cargo:rerun-if-env-changed=TVM_HOME"); println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-lib=dylib=llvm-10"); println!("cargo:rustc-link-search={}/build", tvm_home); } diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml index 71a4b93..153a195 100644 --- a/rust/tvm/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -28,14 +28,24 @@ categories = ["api-bindings", "science"] authors = ["TVM Contributors"] edition = "2018" +[features] +default = ["python", "dynamic-linking"] +dynamic-linking = ["tvm-rt/dynamic-linking"] +static-linking = ["tvm-rt/static-linking"] +blas = ["ndarray/blas"] +python = ["pyo3"] + +[dependencies.tvm-rt] +version = "0.1" +default-features = false +path = "../tvm-rt/" + [dependencies] thiserror = "^1.0" anyhow = "^1.0" lazy_static = "1.1" ndarray = "0.12" num-traits = "0.2" -tvm-rt = { version = "0.1", path = "../tvm-rt/" } -tvm-sys = { version = "0.1", path = "../tvm-sys/" } tvm-macros = { version = "*", path = "../tvm-macros/" } paste = "0.1" mashup = "0.1" @@ -44,8 +54,6 @@ pyo3 = { version = "0.11.1", optional = true } codespan-reporting = "0.9.5" structopt = { version = "0.3" } -[features] -default = ["python"] - -blas = ["ndarray/blas"] -python = ["pyo3"] +[[bin]] +name = "tyck" +required-features = ["dynamic-linking"] diff --git a/rust/tvm/src/bin/tyck.rs b/rust/tvm/src/bin/tyck.rs index 9300412..b869012 100644 --- a/rust/tvm/src/bin/tyck.rs +++ b/rust/tvm/src/bin/tyck.rs @@ -6,7 +6,6 @@ use structopt::StructOpt; use tvm::ir::diagnostics::codespan; use tvm::ir::IRModule; - #[derive(Debug, StructOpt)] #[structopt(name = "tyck", about = "Parse and type check a Relay program.")] struct Opt { diff --git a/rust/tvm/src/ir/diagnostics.rs b/rust/tvm/src/ir/diagnostics.rs index d306185..b76e43f 100644 --- a/rust/tvm/src/ir/diagnostics.rs +++ b/rust/tvm/src/ir/diagnostics.rs @@ -17,17 +17,20 @@ * under the License. */ +use super::module::IRModule; +use super::span::Span; +use crate::runtime::function::Result; +use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; +use crate::runtime::{ + array::Array, + function::{self, Function, ToFunction}, + string::String as TString, +}; /// The diagnostic interface to TVM, used for reporting and rendering /// diagnostic information by the compiler. This module exposes /// three key abstractions: a Diagnostic, the DiagnosticContext, /// and the DiagnosticRenderer. - -use tvm_macros::{Object, external}; -use super::module::IRModule; -use crate::runtime::{function::{self, Function, ToFunction}, array::Array, string::String as TString}; -use crate::runtime::object::{Object, ObjectPtr, ObjectRef}; -use crate::runtime::function::Result; -use super::span::Span; +use tvm_macros::{external, Object}; type SourceName = ObjectRef; @@ -134,7 +137,6 @@ pub struct DiagnosticRendererNode { // memory layout } - // def render(self, ctx): // """ // Render the provided context. @@ -169,7 +171,8 @@ pub struct DiagnosticContextNode { /// and contains a renderer. impl DiagnosticContext { pub fn new<F>(module: IRModule, render_func: F) -> DiagnosticContext - where F: Fn(DiagnosticContext) -> () + 'static + where + F: Fn(DiagnosticContext) -> () + 'static, { let renderer = diagnostic_renderer(render_func.to_function()).unwrap(); let node = DiagnosticContextNode { @@ -210,21 +213,16 @@ impl DiagnosticContext { // If the render_func is None it will remove the current custom renderer // and return to default behavior. fn override_renderer<F>(opt_func: Option<F>) -> Result<()> -where F: Fn(DiagnosticContext) -> () + 'static +where + F: Fn(DiagnosticContext) -> () + 'static, { - match opt_func { None => clear_renderer(), Some(func) => { let func = func.to_function(); - let render_factory = move || { - diagnostic_renderer(func.clone()).unwrap() - }; + let render_factory = move || diagnostic_renderer(func.clone()).unwrap(); - function::register_override( - render_factory, - "diagnostics.OverrideRenderer", - true)?; + function::register_override(render_factory, "diagnostics.OverrideRenderer", true)?; Ok(()) } @@ -243,9 +241,9 @@ pub mod codespan { End, } - struct SpanToBytes { - inner: HashMap<std::String, HashMap<usize, (StartOrEnd, - } + // struct SpanToBytes { + // inner: HashMap<std::String, HashMap<usize, (StartOrEnd, + // } struct ByteRange<FileId> { file_id: FileId, @@ -276,7 +274,7 @@ pub mod codespan { .with_message(message) .with_code("EXXX") .with_labels(vec![ - Label::primary(file_id, 328..331).with_message(inner_message), + Label::primary(file_id, 328..331).with_message(inner_message) ]); diagnostic diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 8450bd7..401b6c2 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -19,8 +19,8 @@ pub mod arith; pub mod attrs; -pub mod expr; pub mod diagnostics; +pub mod expr; pub mod function; pub mod module; pub mod op; diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index e539221..4b09128 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -28,6 +28,7 @@ use super::attrs::Attrs; use super::expr::BaseExprNode; use super::function::BaseFuncNode; use super::ty::{Type, TypeNode}; +use super::span::Span; use tvm_macros::Object; use tvm_rt::NDArray; @@ -51,7 +52,7 @@ impl ExprNode { span: ObjectRef::null(), checked_type: Type::from(TypeNode { base: Object::base_object::<TypeNode>(), - span: ObjectRef::null(), + span: Span::empty(), }), } } diff --git a/rust/tvm/src/ir/source_map.rs b/rust/tvm/src/ir/source_map.rs index e69de29..e6c0371 100644 --- a/rust/tvm/src/ir/source_map.rs +++ b/rust/tvm/src/ir/source_map.rs @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use crate::runtime::map::Map; +use crate::runtime::object::Object; + +/// A program source in any language. +/// +/// Could represent the source from an ML framework or a source of an IRModule. +#[repr(C)] +#[derive(Object)] +#[type_key = "Source"] +#[ref_key = "Source"] +struct SourceNode { + pub base: Object, + /*! \brief The source name. */ + SourceName source_name; + + /*! \brief The raw source. */ + String source; + + /*! \brief A mapping of line breaks into the raw source. */ + std::vector<std::pair<int, int>> line_map; +} + + +// class Source : public ObjectRef { +// public: +// TVM_DLL Source(SourceName src_name, std::string source); +// TVM_DLL tvm::String GetLine(int line); + +// TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Source, ObjectRef, SourceNode); +// }; + + +/// A mapping from a unique source name to source fragments. +#[repr(C)] +#[derive(Object)] +#[type_key = "SourceMap"] +#[ref_key = "SourceMap"] +struct SourceMapNode { + pub base: Object, + /// The source mapping. + pub source_map: Map<SourceName, Source>, +} diff --git a/rust/tvm/src/ir/span.rs b/rust/tvm/src/ir/span.rs index d2e19a2..c54fd51 100644 --- a/rust/tvm/src/ir/span.rs +++ b/rust/tvm/src/ir/span.rs @@ -1,22 +1,75 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use crate::runtime::ObjectRef; - -pub type Span = ObjectRef; +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the + +* specific language governing permissions and limitations +* under the License. +*/ + +use crate::runtime::{ObjectRef, Object, String as TString}; +use tvm_macros::Object; + +/// A source file name, contained in a Span. + +#[repr(C)] +#[derive(Object)] +#[type_key = "SourceName"] +#[ref_name = "SourceName"] +pub struct SourceNameNode { + pub base: Object, + pub name: TString, +} + +// /*! +// * \brief The source name of a file span. +// * \sa SourceNameNode, Span +// */ +// class SourceName : public ObjectRef { +// public: +// /*! +// * \brief Get an SourceName for a given operator name. +// * Will raise an error if the source name has not been registered. +// * \param name Name of the operator. +// * \return SourceName valid throughout program lifetime. +// */ +// TVM_DLL static SourceName Get(const String& name); + +// TVM_DEFINE_OBJECT_REF_METHODS(SourceName, ObjectRef, SourceNameNode); +// }; + +/// Span information for diagnostic purposes. +#[repr(C)] +#[derive(Object)] +#[type_key = "Span"] +#[ref_name = "Span"] +pub struct SpanNode { + pub base: Object, + /// The source name. + pub source_name: SourceName, + /// The line number. + pub line: i32, + /// The column offset. + pub column: i32, + /// The end line number. + pub end_line: i32, + /// The end column number. + pub end_column: i32, +} + +impl Span { + pub fn empty() -> Span { + todo!() + } +} diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 67e5cea..5110eef 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -192,4 +192,15 @@ TVM_REGISTER_GLOBAL("ir.DebugPrint").set_body_typed([](ObjectRef ref) { return ss.str(); }); + + } // namespace tvm + +#ifdef RUST_COMPILER_EXT + +extern "C" { + int compiler_ext_initialize(); + static int test = compiler_ext_initialize(); +} + +#endif diff --git a/src/parser/source_map.cc b/src/parser/source_map.cc index 7ac978c..7340f69 100644 --- a/src/parser/source_map.cc +++ b/src/parser/source_map.cc @@ -77,12 +77,6 @@ tvm::String Source::GetLine(int line) { return line_text; } -// TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -// .set_dispatch<SourceNameNode>([](const ObjectRef& ref, ReprPrinter* p) { -// auto* node = static_cast<const SourceNameNode*>(ref.get()); -// p->stream << "SourceName(" << node->name << ", " << node << ")"; -// }); - TVM_REGISTER_NODE_TYPE(SourceMapNode); SourceMap::SourceMap(Map<SourceName, Source> source_map) { @@ -91,11 +85,6 @@ SourceMap::SourceMap(Map<SourceName, Source> source_map) { data_ = std::move(n); } -// TODO(@jroesch): fix this -static SourceMap global_source_map = SourceMap(Map<SourceName, Source>()); - -SourceMap SourceMap::Global() { return global_source_map; } - void SourceMap::Add(const Source& source) { (*this)->source_map.Set(source->source_name, source); } TVM_REGISTER_GLOBAL("SourceMapAdd").set_body_typed([](SourceMap map, String name, String content) {