This is an automated email from the ASF dual-hosted git repository. jroesch pushed a commit to branch cargo-build in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
commit b2b59c229e9b8c2002d8c8cd520748df6b38e074 Author: Jared Roesch <jroe...@octoml.ai> AuthorDate: Tue Oct 13 15:26:54 2020 -0700 Borrow code from Egg --- rust/compiler-ext/src/lib.rs | 344 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 337 insertions(+), 7 deletions(-) diff --git a/rust/compiler-ext/src/lib.rs b/rust/compiler-ext/src/lib.rs index 31e1bb2..58bdd0c 100644 --- a/rust/compiler-ext/src/lib.rs +++ b/rust/compiler-ext/src/lib.rs @@ -1,7 +1,337 @@ -#[cfg(test)] -mod tests { - #[test] - fn it_works() { - assert_eq!(2 + 2, 4); - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + use std::os::raw::c_int; + use tvm::initialize; + use tvm::ir::{tir, PrimExpr}; + use tvm::runtime::function::register_override; + use tvm::runtime::map::Map; + use tvm::runtime::object::{IsObject, IsObjectRef}; + + use ordered_float::NotNan; + + mod interval; + mod math; + + use math::{BoundsMap, Expr, RecExpr}; + use tvm::ir::arith::ConstIntBound; + use tvm_rt::{ObjectRef, array::Array}; + + macro_rules! downcast_match { + ($id:ident; { $($t:ty => $arm:expr $(,)? )+ , else => $default:expr }) => { + $( if let Ok($id) = $id.downcast_clone::<$t>() { $arm } else )+ + { $default } + } + } + + #[derive(Default)] + struct VarMap { + vars: Vec<(tvm::ir::tir::Var, egg::Symbol)>, + objs: Vec<ObjectRef>, + } + + impl VarMap { + // FIXME this should eventually do the right thing for TVM variables + // right now it depends on them having unique names + fn make_symbol(&mut self, var: tvm::ir::tir::Var) -> egg::Symbol { + let sym = egg::Symbol::from(var.name_hint.as_str().unwrap()); + for (_, sym2) in &self.vars { + if sym == *sym2 { + return sym; + } + } + + self.vars.push((var, sym)); + sym + } + + fn get_symbol(&self, sym: egg::Symbol) -> tvm::ir::tir::Var { + for (v, sym2) in &self.vars { + if sym == *sym2 { + return v.clone(); + } + } + panic!("Should have found a var") + } + + fn push_obj(&mut self, obj: impl IsObjectRef) -> usize { + let i = self.objs.len(); + self.objs.push(obj.upcast()); + i + } + + fn get_obj<T: IsObjectRef>(&self, i: usize) -> T { + self.objs[i].clone().downcast().expect("bad downcast") + } + } + + fn to_egg(vars: &mut VarMap, prim: &PrimExpr) -> RecExpr { + fn build(vars: &mut VarMap, p: &PrimExpr, recexpr: &mut RecExpr) -> egg::Id { + macro_rules! r { + ($e:expr) => { + build(vars, &$e, recexpr) + }; + } + + let dt = recexpr.add(Expr::DataType(p.datatype)); + let e = downcast_match!(p; { + tir::Add => Expr::Add([dt, r!(p.a), r!(p.b)]), + tir::Sub => Expr::Sub([dt, r!(p.a), r!(p.b)]), + tir::Mul => Expr::Mul([dt, r!(p.a), r!(p.b)]), + + tir::Div => Expr::Div([dt, r!(p.a), r!(p.b)]), + tir::Mod => Expr::Mod([dt, r!(p.a), r!(p.b)]), + tir::FloorDiv => Expr::FloorDiv([dt, r!(p.a), r!(p.b)]), + tir::FloorMod => Expr::FloorMod([dt, r!(p.a), r!(p.b)]), + + tir::Min => Expr::Min([dt, r!(p.a), r!(p.b)]), + tir::Max => Expr::Max([dt, r!(p.a), r!(p.b)]), + + tir::Ramp => Expr::Ramp([dt, r!(p.start), r!(p.stride), recexpr.add(Expr::Int(p.lanes.into()))]), + tir::Select => Expr::Select([dt, r!(p.condition), r!(p.true_value), r!(p.false_value)]), + + tir::Eq => Expr::Equal([dt, r!(p.a), r!(p.b)]), + tir::Ne => Expr::NotEqual([dt, r!(p.a), r!(p.b)]), + tir::Lt => Expr::Less([dt, r!(p.a), r!(p.b)]), + tir::Le => Expr::LessEqual([dt, r!(p.a), r!(p.b)]), + tir::Gt => Expr::Greater([dt, r!(p.a), r!(p.b)]), + tir::Ge => Expr::GreaterEqual([dt, r!(p.a), r!(p.b)]), + + tir::And => Expr::And([dt, r!(p.a), r!(p.b)]), + tir::Or => Expr::Or([dt, r!(p.a), r!(p.b)]), + tir::Not => Expr::Not([dt, r!(p.value)]), + + tir::Broadcast => Expr::Broadcast([dt, r!(p.value), recexpr.add(Expr::Int(p.lanes.into()))]), + + tir::Let => { + let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); + Expr::Let([dt, sym, r!(p.value), r!(p.body)]) + } + tir::Var => { + let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p))); + Expr::Var([dt, sym]) + } + tir::IntImm => { + let int = recexpr.add(Expr::Int(p.value)); + Expr::IntImm([dt, int]) + } + tir::FloatImm => { + let float = recexpr.add(Expr::Float(NotNan::new(p.value).unwrap())); + Expr::FloatImm([dt, float]) + } + tir::Cast => Expr::Cast([dt, r!(p.value)]), + + tir::Call => { + let op = vars.push_obj(p.op.clone()); + let mut arg_ids = vec![dt]; + for i in 0..p.args.len() { + let arg: PrimExpr = p.args.get(i as isize).expect("array get fail"); + arg_ids.push(r!(arg)); + } + Expr::Call(op, arg_ids) + }, + tir::Load => { + let sym = recexpr.add(Expr::Symbol(vars.make_symbol(p.var.clone()))); + Expr::Load([dt, sym, r!(p.index), r!(p.predicate)]) + }, + else => { + println!("Failed to downcast type '{}': {}", p.type_key(), tvm::runtime::debug_print(p.clone().upcast()).unwrap().to_str().unwrap()); + Expr::Object(vars.push_obj(p.clone())) + } + }); + + recexpr.add(e) + } + + let mut recexpr = Default::default(); + build(vars, prim, &mut recexpr); + recexpr + } + + fn from_egg(vars: &VarMap, recexpr: &RecExpr) -> PrimExpr { + fn build(vars: &VarMap, nodes: &[Expr]) -> PrimExpr { + let go = |i: &egg::Id| build(vars, &nodes[..usize::from(*i) + 1]); + let get_dt = |i: &egg::Id| nodes[usize::from(*i)].to_dtype().unwrap(); + let prim: PrimExpr = match nodes.last().expect("cannot be empty") { + Expr::Var([_dt, s]) => match &nodes[usize::from(*s)] { + Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), + n => panic!("Expected a symbol, got {:?}", n), + }, + Expr::IntImm([dt, v]) => { + let value = nodes[usize::from(*v)].to_int().unwrap(); + tir::IntImm::new(get_dt(dt), value).upcast() + } + Expr::FloatImm([dt, v]) => { + let value = nodes[usize::from(*v)].to_float().unwrap(); + tir::FloatImm::new(get_dt(dt), value).upcast() + } + Expr::Let([dt, s, value, body]) => { + let var = match &nodes[usize::from(*s)] { + Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), + n => panic!("Expected a symbol, got {:?}", n), + }; + tir::Let::new(get_dt(dt), var, go(value), go(body)).upcast() + } + Expr::Load([dt, s, value, body]) => { + let var = match &nodes[usize::from(*s)] { + Expr::Symbol(sym) => vars.get_symbol(*sym).upcast(), + n => panic!("Expected a symbol, got {:?}", n), + }; + tir::Load::new(get_dt(dt), var, go(value), go(body)).upcast() + } + + Expr::Add([dt, a, b]) => tir::Add::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Sub([dt, a, b]) => tir::Sub::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Mul([dt, a, b]) => tir::Mul::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::Div([dt, a, b]) => tir::Div::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Mod([dt, a, b]) => tir::Mod::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::FloorDiv([dt, a, b]) => tir::FloorDiv::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::FloorMod([dt, a, b]) => tir::FloorMod::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::Min([dt, a, b]) => tir::Min::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Max([dt, a, b]) => tir::Max::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::Equal([dt, a, b]) => tir::Eq::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::NotEqual([dt, a, b]) => tir::Ne::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Less([dt, a, b]) => tir::Lt::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::LessEqual([dt, a, b]) => tir::Le::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Greater([dt, a, b]) => tir::Gt::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::GreaterEqual([dt, a, b]) => tir::Ge::new(get_dt(dt), go(a), go(b)).upcast(), + + Expr::And([dt, a, b]) => tir::And::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Or([dt, a, b]) => tir::Or::new(get_dt(dt), go(a), go(b)).upcast(), + Expr::Not([dt, a]) => tir::Not::new(get_dt(dt), go(a)).upcast(), + + Expr::Ramp([dt, a, b, c]) => { + let len = &nodes[usize::from(*c)]; + let i = len + .to_int() + .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", len)); + tir::Ramp::new(get_dt(dt), go(a), go(b), i as i32).upcast() + } + Expr::Broadcast([dt, val, lanes]) => { + let lanes = &nodes[usize::from(*lanes)]; + let lanes = lanes + .to_int() + .unwrap_or_else(|| panic!("Ramp lanes must be an int, got {:?}", lanes)); + println!("dt: {}", get_dt(dt)); + tir::Broadcast::new(get_dt(dt), go(val), lanes as i32).upcast() + } + + Expr::Select([dt, a, b, c]) => tir::Select::new(get_dt(dt), go(a), go(b), go(c)).upcast(), + Expr::Cast([dt, a]) => tir::Cast::new(get_dt(dt), go(a)).upcast(), + Expr::Call(expr, args) => { + let arg_exprs: Vec<PrimExpr> = args[1..].iter().map(go).collect(); + let arg_exprs = Array::from_vec(arg_exprs).expect("failed to convert args"); + tir::Call::new(get_dt(&args[0]), vars.get_obj(*expr), arg_exprs).upcast() + } + + Expr::Object(i) => vars.get_obj(*i), + node => panic!("I don't know how to extract {:?}", node), + }; + assert_ne!(prim.datatype.bits(), 0); + assert_ne!(prim.datatype.lanes(), 0); + prim + } + build(vars, recexpr.as_ref()) + } + + fn run( + input: PrimExpr, + expected: Option<PrimExpr>, + map: Map<PrimExpr, ConstIntBound>, + ) -> Result<PrimExpr, String> { + use egg::{CostFunction, Extractor}; + + let mut bounds = BoundsMap::default(); + for (k, v) in map { + if let Ok(var) = k.downcast_clone::<tir::Var>() { + let sym: egg::Symbol = var.name_hint.as_str().unwrap().into(); + bounds.insert(sym, (v.min_value, v.max_value)); + } else { + println!("Non var in bounds map: {}", tvm::ir::as_text(k)); + } + } + + let mut vars = VarMap::default(); + let expr = to_egg(&mut vars, &input); + let mut runner = math::default_runner(); + runner.egraph.analysis.bounds = bounds; + + let mut runner = runner.with_expr(&expr).run(&math::rules()); + // runner.print_report(); + let mut extractor = Extractor::new(&runner.egraph, math::CostFn); + let root = runner.egraph.find(runner.roots[0]); + let (cost, best) = extractor.find_best(root); + if let Some(expected) = expected { + let mut expected_vars = VarMap::default(); + let expected_expr = to_egg(&mut expected_vars, &expected); + let expected_root = runner.egraph.add_expr(&expected_expr); + if expected_root != root { + return Err(format!( + "\n\nFailed to prove them equal!\nExpected:\n{}\nFound:\n{}\n", + expected_expr.pretty(40), + best.pretty(40) + )); + } + let expected_cost = math::CostFn.cost_rec(&expected_expr); + if expected_cost != cost { + let msg = format!( + "\n\nCosts not equal: Expected {}:\n{}\nFound {}:\n{}\n", + expected_cost, + expected_expr.pretty(40), + cost, + best.pretty(40) + ); + if cost < expected_cost { + println!("egg wins: {}", msg) + } else { + return Err(msg); + } + } + } + log::info!(" returning... {}", best.pretty(60)); + Ok(from_egg(&vars, &best)) + } + + fn simplify(prim: PrimExpr, map: Map<PrimExpr, ConstIntBound>) -> Result<PrimExpr, tvm::Error> { + log::debug!("map: {:?}", map); + run(prim, None, map).map_err(tvm::Error::CallFailed) + } + + fn simplify_and_check( + prim: PrimExpr, + check: PrimExpr, + map: Map<PrimExpr, ConstIntBound>, + ) -> Result<PrimExpr, tvm::Error> { + log::debug!("check map: {:?}", map); + run(prim, Some(check), map).map_err(tvm::Error::CallFailed) + } + + initialize!({ + let _ = env_logger::try_init(); + // NOTE this print prevents a segfault (on Linux) for now... + println!("Initializing simplifier... "); + register_override(simplify, "egg.simplify", true).expect("failed to initialize simplifier"); + register_override(simplify_and_check, "egg.simplify_and_check", true) + .expect("failed to initialize simplifier"); + log::debug!("done!"); + }); + \ No newline at end of file