diff --git a/src/optimize.rs b/src/optimize.rs index d087f584..7c863747 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -88,11 +88,6 @@ impl<'a> State<'a> { optimization_level, } } - /// Reset the state from dirty to clean. - #[inline(always)] - pub fn reset(&mut self) { - self.changed = false; - } /// Set the [`AST`] state to be dirty (i.e. changed). #[inline(always)] pub fn set_dirty(&mut self) { @@ -179,6 +174,8 @@ fn optimize_stmt_block( mut statements: Vec, state: &mut State, preserve_result: bool, + is_internal: bool, + reduce_return: bool, ) -> Vec { if statements.is_empty() { return statements; @@ -186,6 +183,12 @@ fn optimize_stmt_block( let mut is_dirty = state.is_dirty(); + let is_pure = if is_internal { + Stmt::is_internally_pure + } else { + Stmt::is_pure + }; + loop { state.clear_dirty(); @@ -228,18 +231,89 @@ fn optimize_stmt_block( } }); + // Remove all pure statements except the last one + let mut index = 0; + let mut first_non_constant = statements + .iter() + .rev() + .enumerate() + .find_map(|(i, stmt)| match stmt { + stmt if !is_pure(stmt) => Some(i), + + Stmt::Noop(_) | Stmt::Return(_, None, _) | Stmt::Export(_, _) | Stmt::Share(_) => { + None + } + + Stmt::Let(e, _, _, _) + | Stmt::Const(e, _, _, _) + | Stmt::Expr(e) + | Stmt::Return(_, Some(e), _) + | Stmt::Import(e, _, _) + if e.is_constant() => + { + None + } + + _ => Some(i), + }) + .map_or(0, |n| statements.len() - n); + + while index < statements.len() { + if preserve_result && index >= statements.len() - 1 { + break; + } else { + match &statements[index] { + stmt if is_pure(stmt) && index >= first_non_constant => { + state.set_dirty(); + statements.remove(index); + } + stmt if stmt.is_pure() => { + state.set_dirty(); + if index < first_non_constant { + first_non_constant -= 1; + } + statements.remove(index); + } + _ => index += 1, + } + } + } + // Remove all pure statements that do not return values at the end of a block. // We cannot remove anything for non-pure statements due to potential side-effects. if preserve_result { loop { - match &statements[..] { - [stmt] if !stmt.returns_value() && stmt.is_internally_pure() => { + match &mut statements[..] { + // { return; } -> {} + [Stmt::Return(crate::ast::ReturnType::Return, None, _)] if reduce_return => { state.set_dirty(); statements.clear(); } + [stmt] if !stmt.returns_value() && is_pure(stmt) => { + state.set_dirty(); + statements.clear(); + } + // { ...; return; } -> { ... } + [.., last_stmt, Stmt::Return(crate::ast::ReturnType::Return, None, _)] + if reduce_return && !last_stmt.returns_value() => + { + state.set_dirty(); + statements.pop().unwrap(); + } + // { ...; return val; } -> { ...; val } + [.., Stmt::Return(crate::ast::ReturnType::Return, expr, pos)] + if reduce_return => + { + state.set_dirty(); + *statements.last_mut().unwrap() = if let Some(expr) = expr { + Stmt::Expr(mem::take(expr)) + } else { + Stmt::Noop(*pos) + }; + } [.., second_last_stmt, Stmt::Noop(_)] if second_last_stmt.returns_value() => {} [.., second_last_stmt, last_stmt] - if !last_stmt.returns_value() && last_stmt.is_internally_pure() => + if !last_stmt.returns_value() && is_pure(last_stmt) => { state.set_dirty(); if second_last_stmt.returns_value() { @@ -254,11 +328,25 @@ fn optimize_stmt_block( } else { loop { match &statements[..] { - [stmt] if stmt.is_internally_pure() => { + [stmt] if is_pure(stmt) => { state.set_dirty(); statements.clear(); } - [.., last_stmt] if last_stmt.is_internally_pure() => { + // { ...; return; } -> { ... } + [.., Stmt::Return(crate::ast::ReturnType::Return, None, _)] + if reduce_return => + { + state.set_dirty(); + statements.pop().unwrap(); + } + // { ...; return pure_val; } -> { ... } + [.., Stmt::Return(crate::ast::ReturnType::Return, Some(expr), _)] + if reduce_return && expr.is_pure() => + { + state.set_dirty(); + statements.pop().unwrap(); + } + [.., last_stmt] if is_pure(last_stmt) => { state.set_dirty(); statements.pop().unwrap(); } @@ -282,6 +370,9 @@ fn optimize_stmt_block( state.set_dirty(); } + println!("{:?}", statements); + + statements.shrink_to_fit(); statements } @@ -321,11 +412,8 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { // if false { if_block } else { else_block } -> else_block Stmt::If(Expr::BoolConstant(false, _), x, _) => { state.set_dirty(); - *stmt = match optimize_stmt_block( - mem::take(&mut x.1.statements).into_vec(), - state, - preserve_result, - ) { + let else_block = mem::take(&mut x.1.statements).into_vec(); + *stmt = match optimize_stmt_block(else_block, state, preserve_result, true, false) { statements if statements.is_empty() => Stmt::Noop(x.1.pos), statements => Stmt::Block(statements, x.1.pos), } @@ -333,11 +421,8 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { // if true { if_block } else { else_block } -> if_block Stmt::If(Expr::BoolConstant(true, _), x, _) => { state.set_dirty(); - *stmt = match optimize_stmt_block( - mem::take(&mut x.0.statements).into_vec(), - state, - preserve_result, - ) { + let if_block = mem::take(&mut x.0.statements).into_vec(); + *stmt = match optimize_stmt_block(if_block, state, preserve_result, true, false) { statements if statements.is_empty() => Stmt::Noop(x.0.pos), statements => Stmt::Block(statements, x.0.pos), } @@ -345,18 +430,12 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { // if expr { if_block } else { else_block } Stmt::If(condition, x, _) => { optimize_expr(condition, state); - x.0.statements = optimize_stmt_block( - mem::take(&mut x.0.statements).into_vec(), - state, - preserve_result, - ) - .into(); - x.1.statements = optimize_stmt_block( - mem::take(&mut x.1.statements).into_vec(), - state, - preserve_result, - ) - .into(); + let if_block = mem::take(&mut x.0.statements).into_vec(); + x.0.statements = + optimize_stmt_block(if_block, state, preserve_result, true, false).into(); + let else_block = mem::take(&mut x.1.statements).into_vec(); + x.1.statements = + optimize_stmt_block(else_block, state, preserve_result, true, false).into(); } // switch const { ... } @@ -371,15 +450,15 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { let table = &mut x.0; let (statements, new_pos) = if let Some(block) = table.get_mut(&hash) { + let match_block = mem::take(&mut block.statements).into_vec(); ( - optimize_stmt_block(mem::take(&mut block.statements).into_vec(), state, true) - .into(), + optimize_stmt_block(match_block, state, true, true, false).into(), block.pos, ) } else { + let def_block = mem::take(&mut x.1.statements).into_vec(); ( - optimize_stmt_block(mem::take(&mut x.1.statements).into_vec(), state, true) - .into(), + optimize_stmt_block(def_block, state, true, true, false).into(), if x.1.pos.is_none() { *pos } else { x.1.pos }, ) }; @@ -393,19 +472,13 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { Stmt::Switch(expr, x, _) => { optimize_expr(expr, state); x.0.values_mut().for_each(|block| { - block.statements = optimize_stmt_block( - mem::take(&mut block.statements).into_vec(), - state, - preserve_result, - ) - .into() + let match_block = mem::take(&mut block.statements).into_vec(); + block.statements = + optimize_stmt_block(match_block, state, preserve_result, true, false).into() }); - x.1.statements = optimize_stmt_block( - mem::take(&mut x.1.statements).into_vec(), - state, - preserve_result, - ) - .into() + let def_block = mem::take(&mut x.1.statements).into_vec(); + x.1.statements = + optimize_stmt_block(def_block, state, preserve_result, true, false).into() } // while false { block } -> Noop @@ -414,15 +487,14 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { *stmt = Stmt::Noop(*pos) } // while expr { block } - Stmt::While(condition, block, _) => { + Stmt::While(condition, body, _) => { optimize_expr(condition, state); - block.statements = - optimize_stmt_block(mem::take(&mut block.statements).into_vec(), state, false) - .into(); + let block = mem::take(&mut body.statements).into_vec(); + body.statements = optimize_stmt_block(block, state, false, true, false).into(); - if block.len() == 1 { - match block.statements[0] { + if body.len() == 1 { + match body.statements[0] { // while expr { break; } -> { expr; } Stmt::Break(pos) => { // Only a single break statement - turn into running the guard expression once @@ -442,26 +514,26 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { } } // do { block } while false | do { block } until true -> { block } - Stmt::Do(block, Expr::BoolConstant(true, _), false, _) - | Stmt::Do(block, Expr::BoolConstant(false, _), true, _) => { + Stmt::Do(body, Expr::BoolConstant(true, _), false, _) + | Stmt::Do(body, Expr::BoolConstant(false, _), true, _) => { state.set_dirty(); + let block = mem::take(&mut body.statements).into_vec(); *stmt = Stmt::Block( - optimize_stmt_block(mem::take(&mut block.statements).into_vec(), state, false), - block.pos, + optimize_stmt_block(block, state, false, true, false), + body.pos, ); } // do { block } while|until expr - Stmt::Do(block, condition, _, _) => { + Stmt::Do(body, condition, _, _) => { optimize_expr(condition, state); - block.statements = - optimize_stmt_block(mem::take(&mut block.statements).into_vec(), state, false) - .into(); + let block = mem::take(&mut body.statements).into_vec(); + body.statements = optimize_stmt_block(block, state, false, true, false).into(); } // for id in expr { block } Stmt::For(iterable, x, _) => { optimize_expr(iterable, state); - x.1.statements = - optimize_stmt_block(mem::take(&mut x.1.statements).into_vec(), state, false).into(); + let body = mem::take(&mut x.1.statements).into_vec(); + x.1.statements = optimize_stmt_block(body, state, false, true, false).into(); } // let id = expr; Stmt::Let(expr, _, _, _) => optimize_expr(expr, state), @@ -470,7 +542,8 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { Stmt::Import(expr, _, _) => optimize_expr(expr, state), // { block } Stmt::Block(statements, pos) => { - *stmt = match optimize_stmt_block(mem::take(statements), state, preserve_result) { + let block = mem::take(statements); + *stmt = match optimize_stmt_block(block, state, preserve_result, true, false) { statements if statements.is_empty() => { state.set_dirty(); Stmt::Noop(*pos) @@ -483,21 +556,22 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { statements => Stmt::Block(statements, *pos), }; } - // try { pure block } catch ( var ) { block } + // try { pure try_block } catch ( var ) { catch_block } -> try_block Stmt::TryCatch(x, _, _) if x.0.statements.iter().all(Stmt::is_pure) => { // If try block is pure, there will never be any exceptions state.set_dirty(); + let try_block = mem::take(&mut x.0.statements).into_vec(); *stmt = Stmt::Block( - optimize_stmt_block(mem::take(&mut x.0.statements).into_vec(), state, false), + optimize_stmt_block(try_block, state, false, true, false), x.0.pos, ); } - // try { block } catch ( var ) { block } + // try { try_block } catch ( var ) { catch_block } Stmt::TryCatch(x, _, _) => { - x.0.statements = - optimize_stmt_block(mem::take(&mut x.0.statements).into_vec(), state, false).into(); - x.2.statements = - optimize_stmt_block(mem::take(&mut x.2.statements).into_vec(), state, false).into(); + let try_block = mem::take(&mut x.0.statements).into_vec(); + x.0.statements = optimize_stmt_block(try_block, state, false, true, false).into(); + let catch_block = mem::take(&mut x.2.statements).into_vec(); + x.2.statements = optimize_stmt_block(catch_block, state, false, true, false).into(); } // {} Stmt::Expr(Expr::Stmt(x)) if x.statements.is_empty() => { @@ -532,7 +606,7 @@ fn optimize_expr(expr: &mut Expr, state: &mut State) { // {} Expr::Stmt(x) if x.statements.is_empty() => { state.set_dirty(); *expr = Expr::Unit(x.pos) } // { stmt; ... } - do not count promotion as dirty because it gets turned back into an array - Expr::Stmt(x) => x.statements = optimize_stmt_block(mem::take(&mut x.statements).into_vec(), state, true).into(), + Expr::Stmt(x) => x.statements = optimize_stmt_block(mem::take(&mut x.statements).into_vec(), state, true, true, false).into(), // lhs.rhs #[cfg(not(feature = "no_object"))] Expr::Dot(x, _) => match (&mut x.lhs, &mut x.rhs) { @@ -785,62 +859,7 @@ fn optimize_top_level( } }); - let orig_constants_len = state.variables.len(); - - // Optimization loop - loop { - state.reset(); - state.restore_var(orig_constants_len); - - let num_statements = statements.len(); - - statements.iter_mut().enumerate().for_each(|(i, stmt)| { - match stmt { - Stmt::Const(value_expr, Ident { name, .. }, _, _) => { - // Load constants - optimize_expr(value_expr, &mut state); - - if value_expr.is_constant() { - state.push_var(name, AccessMode::ReadOnly, value_expr.clone()); - } - } - Stmt::Let(value_expr, Ident { name, pos, .. }, _, _) => { - optimize_expr(value_expr, &mut state); - state.push_var(name, AccessMode::ReadWrite, Expr::Unit(*pos)); - } - _ => { - // Keep all variable declarations at this level - // and always keep the last return value - let keep = match stmt { - Stmt::Let(_, _, _, _) | Stmt::Const(_, _, _, _) => true, - #[cfg(not(feature = "no_module"))] - Stmt::Import(_, _, _) => true, - _ => i >= num_statements - 1, - }; - optimize_stmt(stmt, &mut state, keep); - } - } - }); - - if !state.is_dirty() { - break; - } - } - - // Eliminate code that is pure but always keep the last statement - let last_stmt = statements.pop(); - - // Remove all pure statements at global level - statements.retain(|stmt| !stmt.is_pure()); - - // Add back the last statement unless it is a lone No-op - if let Some(stmt) = last_stmt { - if !statements.is_empty() || !stmt.is_noop() { - statements.push(stmt); - } - } - - statements.shrink_to_fit(); + statements = optimize_stmt_block(statements, &mut state, true, false, true); statements } @@ -895,34 +914,10 @@ pub fn optimize_into_ast( let mut body = fn_def.body.statements.into_vec(); - loop { - // Optimize the function body - let state = &mut State::new(engine, lib2, level); + // Optimize the function body + let state = &mut State::new(engine, lib2, level); - body = optimize_stmt_block(body, state, true); - - match &mut body[..] { - // { return; } -> {} - [Stmt::Return(crate::ast::ReturnType::Return, None, _)] => { - body.clear(); - } - // { ...; return; } -> { ... } - [.., last_stmt, Stmt::Return(crate::ast::ReturnType::Return, None, _)] - if !last_stmt.returns_value() => - { - body.pop().unwrap(); - } - // { ...; return val; } -> { ...; val } - [.., Stmt::Return(crate::ast::ReturnType::Return, expr, pos)] => { - *body.last_mut().unwrap() = if let Some(expr) = expr { - Stmt::Expr(mem::take(expr)) - } else { - Stmt::Noop(*pos) - }; - } - _ => break, - } - } + body = optimize_stmt_block(body, state, true, true, true); fn_def.body = StmtBlock { statements: body.into(), diff --git a/tests/optimizer.rs b/tests/optimizer.rs index 4c7cfb49..7f9d2bd5 100644 --- a/tests/optimizer.rs +++ b/tests/optimizer.rs @@ -55,7 +55,13 @@ fn test_optimizer_parse() -> Result<(), Box> { let ast = engine.compile("{ const DECISION = false; if DECISION { 42 } else { 123 } }")?; - assert!(format!("{:?}", ast).starts_with(r#"AST { source: None, body: [Block([Const(BoolConstant(false, 1:20), Ident("DECISION" @ 1:9), false, 1:3), Expr(IntegerConstant(123, 1:53))], 1:1)], functions: Module("#)); + assert!(format!("{:?}", ast).starts_with( + r#"AST { source: None, body: [Expr(IntegerConstant(123, 1:53))], functions: Module("# + )); + + let ast = engine.compile("const DECISION = false; if DECISION { 42 } else { 123 }")?; + + assert!(format!("{:?}", ast).starts_with(r#"AST { source: None, body: [Const(BoolConstant(false, 1:18), Ident("DECISION" @ 1:7), false, 1:1), Expr(IntegerConstant(123, 1:51))], functions: Module("#)); let ast = engine.compile("if 1 == 2 { 42 }")?;