Unroll switch ranges if possible.

This commit is contained in:
Stephen Chung 2022-07-18 08:54:10 +08:00
parent 107193e35f
commit 4b760d1d0f
5 changed files with 132 additions and 60 deletions

View File

@ -17,6 +17,7 @@ Enhancements
* `switch` cases can now include multiple values separated by `|`. * `switch` cases can now include multiple values separated by `|`.
* `EvalContext::eval_expression_tree_raw` and `Expression::eval_with_context_raw` are added to allow for not rewinding the `Scope` at the end of a statements block. * `EvalContext::eval_expression_tree_raw` and `Expression::eval_with_context_raw` are added to allow for not rewinding the `Scope` at the end of a statements block.
* A new `range` function variant that takes an exclusive range with a step. * A new `range` function variant that takes an exclusive range with a step.
* Ranges in `switch` statements that are small (currently no more than 16 items) are unrolled if possible.
Version 1.8.0 Version 1.8.0

View File

@ -175,7 +175,6 @@ impl fmt::Debug for RangeCase {
impl From<Range<INT>> for RangeCase { impl From<Range<INT>> for RangeCase {
#[inline(always)] #[inline(always)]
#[must_use]
fn from(value: Range<INT>) -> Self { fn from(value: Range<INT>) -> Self {
Self::ExclusiveInt(value, 0) Self::ExclusiveInt(value, 0)
} }
@ -183,12 +182,24 @@ impl From<Range<INT>> for RangeCase {
impl From<RangeInclusive<INT>> for RangeCase { impl From<RangeInclusive<INT>> for RangeCase {
#[inline(always)] #[inline(always)]
#[must_use]
fn from(value: RangeInclusive<INT>) -> Self { fn from(value: RangeInclusive<INT>) -> Self {
Self::InclusiveInt(value, 0) Self::InclusiveInt(value, 0)
} }
} }
impl IntoIterator for RangeCase {
type Item = INT;
type IntoIter = Box<dyn Iterator<Item = Self::Item>>;
#[inline(always)]
fn into_iter(self) -> Self::IntoIter {
match self {
Self::ExclusiveInt(r, ..) => Box::new(r.into_iter()),
Self::InclusiveInt(r, ..) => Box::new(r.into_iter()),
}
}
}
impl RangeCase { impl RangeCase {
/// Is the range empty? /// Is the range empty?
#[inline(always)] #[inline(always)]
@ -199,6 +210,17 @@ impl RangeCase {
Self::InclusiveInt(r, ..) => r.is_empty(), Self::InclusiveInt(r, ..) => r.is_empty(),
} }
} }
/// Size of the range.
#[inline(always)]
#[must_use]
pub fn len(&self) -> usize {
match self {
Self::ExclusiveInt(r, ..) if r.is_empty() => 0,
Self::ExclusiveInt(r, ..) => (r.end - r.start) as usize,
Self::InclusiveInt(r, ..) if r.is_empty() => 0,
Self::InclusiveInt(r, ..) => (*r.end() - *r.start()) as usize,
}
}
/// Is the specified number within this range? /// Is the specified number within this range?
#[inline(always)] #[inline(always)]
#[must_use] #[must_use]
@ -208,19 +230,6 @@ impl RangeCase {
Self::InclusiveInt(r, ..) => r.contains(&n), Self::InclusiveInt(r, ..) => r.contains(&n),
} }
} }
/// If the range contains only of a single [`INT`], return it;
/// otherwise return [`None`].
#[inline(always)]
#[must_use]
pub fn single_int(&self) -> Option<INT> {
match self {
Self::ExclusiveInt(r, ..) if r.end.checked_sub(r.start) == Some(1) => Some(r.start),
Self::InclusiveInt(r, ..) if r.end().checked_sub(*r.start()) == Some(0) => {
Some(*r.start())
}
_ => None,
}
}
/// Is the specified range inclusive? /// Is the specified range inclusive?
#[inline(always)] #[inline(always)]
#[must_use] #[must_use]

View File

@ -525,7 +525,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let ( let (
match_expr, match_expr,
SwitchCases { SwitchCases {
blocks: blocks_list, blocks,
cases, cases,
ranges, ranges,
def_case, def_case,
@ -538,29 +538,29 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let hash = hasher.finish(); let hash = hasher.finish();
// First check hashes // First check hashes
if let Some(block) = cases.remove(&hash) { if let Some(b) = cases.remove(&hash) {
let mut block = mem::take(&mut blocks_list[block]); let mut b = mem::take(&mut blocks[b]);
cases.clear(); cases.clear();
match block.condition { match b.condition {
Expr::BoolConstant(true, ..) => { Expr::BoolConstant(true, ..) => {
// Promote the matched case // Promote the matched case
let statements: StmtBlockContainer = mem::take(&mut block.statements); let statements: StmtBlockContainer = mem::take(&mut b.statements);
let statements = optimize_stmt_block(statements, state, true, true, false); let statements = optimize_stmt_block(statements, state, true, true, false);
*stmt = (statements, block.statements.span()).into(); *stmt = (statements, b.statements.span()).into();
} }
ref mut condition => { ref mut condition => {
// switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def }
optimize_expr(condition, state, false); optimize_expr(condition, state, false);
let def_case = &mut blocks_list[*def_case].statements; let def_case = &mut blocks[*def_case].statements;
let def_span = def_case.span_or_else(*pos, Position::NONE); let def_span = def_case.span_or_else(*pos, Position::NONE);
let def_case: StmtBlockContainer = mem::take(def_case); let def_case: StmtBlockContainer = mem::take(def_case);
let def_stmt = optimize_stmt_block(def_case, state, true, true, false); let def_stmt = optimize_stmt_block(def_case, state, true, true, false);
*stmt = Stmt::If( *stmt = Stmt::If(
( (
mem::take(condition), mem::take(condition),
mem::take(&mut block.statements), mem::take(&mut b.statements),
StmtBlock::new_with_span(def_stmt, def_span), StmtBlock::new_with_span(def_stmt, def_span),
) )
.into(), .into(),
@ -580,19 +580,16 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// Only one range or all ranges without conditions // Only one range or all ranges without conditions
if ranges.len() == 1 if ranges.len() == 1
|| ranges.iter().all(|r| { || ranges.iter().all(|r| {
matches!( matches!(blocks[r.index()].condition, Expr::BoolConstant(true, ..))
blocks_list[r.index()].condition,
Expr::BoolConstant(true, ..)
)
}) })
{ {
for r in ranges.iter().filter(|r| r.contains(value)) { for r in ranges.iter().filter(|r| r.contains(value)) {
let condition = mem::take(&mut blocks_list[r.index()].condition); let condition = mem::take(&mut blocks[r.index()].condition);
match condition { match condition {
Expr::BoolConstant(true, ..) => { Expr::BoolConstant(true, ..) => {
// Promote the matched case // Promote the matched case
let block = &mut blocks_list[r.index()]; let block = &mut blocks[r.index()];
let statements = mem::take(&mut *block.statements); let statements = mem::take(&mut *block.statements);
let statements = let statements =
optimize_stmt_block(statements, state, true, true, false); optimize_stmt_block(statements, state, true, true, false);
@ -602,13 +599,13 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } // switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def }
optimize_expr(&mut condition, state, false); optimize_expr(&mut condition, state, false);
let def_case = &mut blocks_list[*def_case].statements; let def_case = &mut blocks[*def_case].statements;
let def_span = def_case.span_or_else(*pos, Position::NONE); let def_span = def_case.span_or_else(*pos, Position::NONE);
let def_case: StmtBlockContainer = mem::take(def_case); let def_case: StmtBlockContainer = mem::take(def_case);
let def_stmt = let def_stmt =
optimize_stmt_block(def_case, state, true, true, false); optimize_stmt_block(def_case, state, true, true, false);
let statements = mem::take(&mut blocks_list[r.index()].statements); let statements = mem::take(&mut blocks[r.index()].statements);
*stmt = Stmt::If( *stmt = Stmt::If(
( (
@ -641,16 +638,16 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
} }
for r in &*ranges { for r in &*ranges {
let block = &mut blocks_list[r.index()]; let b = &mut blocks[r.index()];
let statements = mem::take(&mut *block.statements); let statements = mem::take(&mut *b.statements);
*block.statements = *b.statements =
optimize_stmt_block(statements, state, preserve_result, true, false); optimize_stmt_block(statements, state, preserve_result, true, false);
optimize_expr(&mut block.condition, state, false); optimize_expr(&mut b.condition, state, false);
match block.condition { match b.condition {
Expr::Unit(pos) => { Expr::Unit(pos) => {
block.condition = Expr::BoolConstant(true, pos); b.condition = Expr::BoolConstant(true, pos);
state.set_dirty() state.set_dirty()
} }
_ => (), _ => (),
@ -662,7 +659,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// Promote the default case // Promote the default case
state.set_dirty(); state.set_dirty();
let def_case = &mut blocks_list[*def_case].statements; let def_case = &mut blocks[*def_case].statements;
let def_span = def_case.span_or_else(*pos, Position::NONE); let def_span = def_case.span_or_else(*pos, Position::NONE);
let def_case: StmtBlockContainer = mem::take(def_case); let def_case: StmtBlockContainer = mem::take(def_case);
let def_stmt = optimize_stmt_block(def_case, state, true, true, false); let def_stmt = optimize_stmt_block(def_case, state, true, true, false);
@ -673,7 +670,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let ( let (
match_expr, match_expr,
SwitchCases { SwitchCases {
blocks: blocks_list, blocks,
cases, cases,
ranges, ranges,
def_case, def_case,
@ -684,21 +681,21 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
optimize_expr(match_expr, state, false); optimize_expr(match_expr, state, false);
// Optimize blocks // Optimize blocks
for block in blocks_list.iter_mut() { for b in blocks.iter_mut() {
let statements = mem::take(&mut *block.statements); let statements = mem::take(&mut *b.statements);
*block.statements = *b.statements =
optimize_stmt_block(statements, state, preserve_result, true, false); optimize_stmt_block(statements, state, preserve_result, true, false);
optimize_expr(&mut block.condition, state, false); optimize_expr(&mut b.condition, state, false);
match block.condition { match b.condition {
Expr::Unit(pos) => { Expr::Unit(pos) => {
block.condition = Expr::BoolConstant(true, pos); b.condition = Expr::BoolConstant(true, pos);
state.set_dirty(); state.set_dirty();
} }
Expr::BoolConstant(false, ..) => { Expr::BoolConstant(false, ..) => {
if !block.statements.is_empty() { if !b.statements.is_empty() {
block.statements = StmtBlock::NONE; b.statements = StmtBlock::NONE;
state.set_dirty(); state.set_dirty();
} }
} }
@ -707,7 +704,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
} }
// Remove false cases // Remove false cases
cases.retain(|_, &mut block| match blocks_list[block].condition { cases.retain(|_, &mut block| match blocks[block].condition {
Expr::BoolConstant(false, ..) => { Expr::BoolConstant(false, ..) => {
state.set_dirty(); state.set_dirty();
false false
@ -715,7 +712,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
_ => true, _ => true,
}); });
// Remove false ranges // Remove false ranges
ranges.retain(|r| match blocks_list[r.index()].condition { ranges.retain(|r| match blocks[r.index()].condition {
Expr::BoolConstant(false, ..) => { Expr::BoolConstant(false, ..) => {
state.set_dirty(); state.set_dirty();
false false
@ -723,9 +720,26 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
_ => true, _ => true,
}); });
let def_case = &mut blocks_list[*def_case].statements; let def_stmt_block = &mut blocks[*def_case].statements;
let def_block = mem::take(&mut **def_case); let def_block = mem::take(&mut **def_stmt_block);
**def_case = optimize_stmt_block(def_block, state, preserve_result, true, false); **def_stmt_block = optimize_stmt_block(def_block, state, preserve_result, true, false);
// Remove unused block statements
for index in 0..blocks.len() {
if *def_case == index
|| cases.values().any(|&n| n == index)
|| ranges.iter().any(|r| r.index() == index)
{
continue;
}
let b = &mut blocks[index];
if !b.statements.is_empty() {
b.statements = StmtBlock::NONE;
state.set_dirty();
}
}
} }
// while false { block } -> Noop // while false { block } -> Noop

View File

@ -39,6 +39,9 @@ const SCOPE_SEARCH_BARRIER_MARKER: &str = "$ BARRIER $";
/// The message: `TokenStream` never ends /// The message: `TokenStream` never ends
const NEVER_ENDS: &str = "`Token`"; const NEVER_ENDS: &str = "`Token`";
/// Unroll `switch` ranges no larger than this.
const SMALL_SWITCH_RANGE: usize = 16;
/// _(internals)_ A type that encapsulates the current state of the parser. /// _(internals)_ A type that encapsulates the current state of the parser.
/// Exported under the `internals` feature only. /// Exported under the `internals` feature only.
pub struct ParseState<'e> { pub struct ParseState<'e> {
@ -1138,6 +1141,7 @@ impl Engine {
let stmt = self.parse_stmt(input, state, lib, settings.level_up())?; let stmt = self.parse_stmt(input, state, lib, settings.level_up())?;
let need_comma = !stmt.is_self_terminated(); let need_comma = !stmt.is_self_terminated();
let has_condition = !matches!(condition, Expr::BoolConstant(true, ..));
blocks.push((condition, stmt).into()); blocks.push((condition, stmt).into());
let index = blocks.len() - 1; let index = blocks.len() - 1;
@ -1159,14 +1163,15 @@ impl Engine {
if let Some(mut r) = range_value { if let Some(mut r) = range_value {
if !r.is_empty() { if !r.is_empty() {
if let Some(n) = r.single_int() { // Do not unroll ranges if there are previous non-unrolled ranges
// Unroll single range if !has_condition && ranges.is_empty() && r.len() <= SMALL_SWITCH_RANGE
let value = Dynamic::from_int(n); {
let hasher = &mut get_hasher(); // Unroll small range
value.hash(hasher); for n in r {
let hash = hasher.finish(); let hasher = &mut get_hasher();
Dynamic::from_int(n).hash(hasher);
cases.entry(hash).or_insert(index); cases.entry(hasher.finish()).or_insert(index);
}
} else { } else {
// Other range // Other range
r.set_index(index); r.set_index(index);

View File

@ -290,6 +290,49 @@ fn test_switch_ranges() -> Result<(), Box<EvalAltResult>> {
)?, )?,
'x' 'x'
); );
assert_eq!(
engine.eval_with_scope::<INT>(
&mut scope,
"
switch 5 {
'a' => true,
0..10 => 123,
2..12 => 'z',
_ => 'x'
}
"
)?,
123
);
assert_eq!(
engine.eval_with_scope::<INT>(
&mut scope,
"
switch 5 {
'a' => true,
4 | 5 | 6 => 42,
0..10 => 123,
2..12 => 'z',
_ => 'x'
}
"
)?,
42
);
assert_eq!(
engine.eval_with_scope::<char>(
&mut scope,
"
switch 5 {
'a' => true,
2..12 => 'z',
0..10 if x+2==1+2 => print(40+2),
_ => 'x'
}
"
)?,
'z'
);
assert_eq!( assert_eq!(
engine.eval_with_scope::<char>( engine.eval_with_scope::<char>(
&mut scope, &mut scope,