Fix bug in optimize_ast skipping Stmt::FnCall.

This commit is contained in:
Stephen Chung 2023-04-10 13:11:24 +08:00
parent a82bb7b2ef
commit 120ff91074
3 changed files with 66 additions and 27 deletions

View File

@ -6,7 +6,8 @@ use crate::ast::{
SwitchCasesCollection, SwitchCasesCollection,
}; };
use crate::engine::{ use crate::engine::{
KEYWORD_DEBUG, KEYWORD_EVAL, KEYWORD_FN_PTR, KEYWORD_PRINT, KEYWORD_TYPE_OF, OP_NOT, KEYWORD_DEBUG, KEYWORD_EVAL, KEYWORD_FN_PTR, KEYWORD_FN_PTR_CURRY, KEYWORD_PRINT,
KEYWORD_TYPE_OF, OP_NOT,
}; };
use crate::eval::{Caches, GlobalRuntimeState}; use crate::eval::{Caches, GlobalRuntimeState};
use crate::func::builtin::get_builtin_binary_op_fn; use crate::func::builtin::get_builtin_binary_op_fn;
@ -816,21 +817,28 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
} }
} }
Stmt::Expr(expr) => { // expr(func())
optimize_expr(expr, state, false); Stmt::Expr(expr) if matches!(**expr, Expr::FnCall(..)) => {
// Do not promote until the expression is fully optimized
if !state.is_dirty() && matches!(**expr, Expr::FnCall(..) | Expr::Stmt(..)) {
*stmt = match *mem::take(expr) {
// func(...);
Expr::FnCall(x, pos) => Stmt::FnCall(x, pos),
// {};
Expr::Stmt(x) if x.is_empty() => Stmt::Noop(x.position()),
// {...};
Expr::Stmt(x) => (*x).into(),
_ => unreachable!(),
};
state.set_dirty(); state.set_dirty();
match mem::take(expr.as_mut()) {
Expr::FnCall(x, pos) => *stmt = Stmt::FnCall(x, pos),
_ => unreachable!(),
}
}
Stmt::Expr(expr) => optimize_expr(expr, state, false),
// func(...)
Stmt::FnCall(..) => {
if let Stmt::FnCall(x, pos) = mem::take(stmt) {
let mut expr = Expr::FnCall(x, pos);
optimize_expr(&mut expr, state, false);
*stmt = match expr {
Expr::FnCall(x, pos) => Stmt::FnCall(x, pos),
_ => Stmt::Expr(expr.into()),
}
} else {
unreachable!();
} }
} }
@ -1122,9 +1130,9 @@ fn optimize_expr(expr: &mut Expr, state: &mut OptimizerState, _chaining: bool) {
} }
} }
// Do not call some special keywords // Do not call some special keywords that may have side effects
Expr::FnCall(x, ..) if DONT_EVAL_KEYWORDS.contains(&x.name.as_str()) => { Expr::FnCall(x, ..) if DONT_EVAL_KEYWORDS.contains(&x.name.as_str()) => {
x.args.iter_mut().for_each(|a| optimize_expr(a, state, false)); x.args.iter_mut().for_each(|arg_expr| optimize_expr(arg_expr, state, false));
} }
// Call built-in operators // Call built-in operators
@ -1133,7 +1141,7 @@ fn optimize_expr(expr: &mut Expr, state: &mut OptimizerState, _chaining: bool) {
&& state.optimization_level == OptimizationLevel::Simple // simple optimizations && state.optimization_level == OptimizationLevel::Simple // simple optimizations
&& x.constant_args() // all arguments are constants && x.constant_args() // all arguments are constants
=> { => {
let arg_values = &mut x.args.iter().map(|e| e.get_literal_value().unwrap()).collect::<StaticVec<_>>(); let arg_values = &mut x.args.iter().map(|arg_expr| arg_expr.get_literal_value().unwrap()).collect::<StaticVec<_>>();
let arg_types: StaticVec<_> = arg_values.iter().map(Dynamic::type_id).collect(); let arg_types: StaticVec<_> = arg_values.iter().map(Dynamic::type_id).collect();
match x.name.as_str() { match x.name.as_str() {
@ -1165,10 +1173,10 @@ fn optimize_expr(expr: &mut Expr, state: &mut OptimizerState, _chaining: bool) {
_ => () _ => ()
} }
x.args.iter_mut().for_each(|a| optimize_expr(a, state, false)); x.args.iter_mut().for_each(|arg_expr| optimize_expr(arg_expr, state, false));
// Move constant arguments // Move constant arguments
x.args.iter_mut().for_each(|arg| match arg { x.args.iter_mut().for_each(|arg_expr| match arg_expr {
Expr::DynamicConstant(..) | Expr::Unit(..) Expr::DynamicConstant(..) | Expr::Unit(..)
| Expr::StringConstant(..) | Expr::CharConstant(..) | Expr::StringConstant(..) | Expr::CharConstant(..)
| Expr::BoolConstant(..) | Expr::IntegerConstant(..) => (), | Expr::BoolConstant(..) | Expr::IntegerConstant(..) => (),
@ -1176,9 +1184,9 @@ fn optimize_expr(expr: &mut Expr, state: &mut OptimizerState, _chaining: bool) {
#[cfg(not(feature = "no_float"))] #[cfg(not(feature = "no_float"))]
Expr:: FloatConstant(..) => (), Expr:: FloatConstant(..) => (),
_ => if let Some(value) = arg.get_literal_value() { _ => if let Some(value) = arg_expr.get_literal_value() {
state.set_dirty(); state.set_dirty();
*arg = Expr::DynamicConstant(value.into(), arg.start_position()); *arg_expr = Expr::DynamicConstant(value.into(), arg_expr.start_position());
}, },
}); });
} }
@ -1216,11 +1224,11 @@ fn optimize_expr(expr: &mut Expr, state: &mut OptimizerState, _chaining: bool) {
} }
// id(args ..) or xxx.id(args ..) -> optimize function call arguments // id(args ..) or xxx.id(args ..) -> optimize function call arguments
Expr::FnCall(x, ..) | Expr::MethodCall(x, ..) => x.args.iter_mut().for_each(|arg| { Expr::FnCall(x, ..) | Expr::MethodCall(x, ..) => x.args.iter_mut().for_each(|arg_expr| {
optimize_expr(arg, state, false); optimize_expr(arg_expr, state, false);
// Move constant arguments // Move constant arguments
match arg { match arg_expr {
Expr::DynamicConstant(..) | Expr::Unit(..) Expr::DynamicConstant(..) | Expr::Unit(..)
| Expr::StringConstant(..) | Expr::CharConstant(..) | Expr::StringConstant(..) | Expr::CharConstant(..)
| Expr::BoolConstant(..) | Expr::IntegerConstant(..) => (), | Expr::BoolConstant(..) | Expr::IntegerConstant(..) => (),
@ -1228,9 +1236,9 @@ fn optimize_expr(expr: &mut Expr, state: &mut OptimizerState, _chaining: bool) {
#[cfg(not(feature = "no_float"))] #[cfg(not(feature = "no_float"))]
Expr:: FloatConstant(..) => (), Expr:: FloatConstant(..) => (),
_ => if let Some(value) = arg.get_literal_value() { _ => if let Some(value) = arg_expr.get_literal_value() {
state.set_dirty(); state.set_dirty();
*arg = Expr::DynamicConstant(value.into(), arg.start_position()); *arg_expr = Expr::DynamicConstant(value.into(), arg_expr.start_position());
}, },
} }
}), }),
@ -1378,7 +1386,6 @@ impl Engine {
functions.into_iter().for_each(|fn_def| { functions.into_iter().for_each(|fn_def| {
let mut fn_def = crate::func::shared_take_or_clone(fn_def); let mut fn_def = crate::func::shared_take_or_clone(fn_def);
// Optimize the function body // Optimize the function body
let body = mem::take(&mut *fn_def.body); let body = mem::take(&mut *fn_def.body);

View File

@ -581,3 +581,10 @@ impl IndexMut<usize> for FnPtr {
self.curry.index_mut(index) self.curry.index_mut(index)
} }
} }
impl Extend<Dynamic> for FnPtr {
#[inline(always)]
fn extend<T: IntoIterator<Item = Dynamic>>(&mut self, iter: T) {
self.curry.extend(iter)
}
}

View File

@ -163,3 +163,28 @@ fn test_optimizer_scope() -> Result<(), Box<EvalAltResult>> {
Ok(()) Ok(())
} }
#[cfg(not(feature = "no_function"))]
#[cfg(not(feature = "no_closure"))]
#[test]
fn test_optimizer_reoptimize() -> Result<(), Box<EvalAltResult>> {
const SCRIPT: &str = "
const FOO = 42;
fn foo() {
let f = || FOO * 2;
f.call()
}
foo()
";
let engine = Engine::new();
let ast = engine.compile(SCRIPT)?;
let scope: Scope = ast.iter_literal_variables(true, false).collect();
let ast = engine.optimize_ast(&scope, ast, OptimizationLevel::Simple);
println!("{ast:#?}");
assert_eq!(engine.eval_ast::<INT>(&ast)?, 84);
Ok(())
}