diff --git a/CHANGELOG.md b/CHANGELOG.md index 73f3a40c..e9e88417 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Enhancements * `switch` cases can now include multiple values separated by `|`. * `EvalContext::eval_expression_tree_raw` and `Expression::eval_with_context_raw` are added to allow for not rewinding the `Scope` at the end of a statements block. * A new `range` function variant that takes an exclusive range with a step. +* Ranges in `switch` statements that are small (currently no more than 16 items) are unrolled if possible. Version 1.8.0 diff --git a/src/ast/stmt.rs b/src/ast/stmt.rs index 5acbd892..0448be56 100644 --- a/src/ast/stmt.rs +++ b/src/ast/stmt.rs @@ -175,7 +175,6 @@ impl fmt::Debug for RangeCase { impl From> for RangeCase { #[inline(always)] - #[must_use] fn from(value: Range) -> Self { Self::ExclusiveInt(value, 0) } @@ -183,12 +182,24 @@ impl From> for RangeCase { impl From> for RangeCase { #[inline(always)] - #[must_use] fn from(value: RangeInclusive) -> Self { Self::InclusiveInt(value, 0) } } +impl IntoIterator for RangeCase { + type Item = INT; + type IntoIter = Box>; + + #[inline(always)] + fn into_iter(self) -> Self::IntoIter { + match self { + Self::ExclusiveInt(r, ..) => Box::new(r.into_iter()), + Self::InclusiveInt(r, ..) => Box::new(r.into_iter()), + } + } +} + impl RangeCase { /// Is the range empty? #[inline(always)] @@ -199,6 +210,17 @@ impl RangeCase { Self::InclusiveInt(r, ..) => r.is_empty(), } } + /// Size of the range. + #[inline(always)] + #[must_use] + pub fn len(&self) -> usize { + match self { + Self::ExclusiveInt(r, ..) if r.is_empty() => 0, + Self::ExclusiveInt(r, ..) => (r.end - r.start) as usize, + Self::InclusiveInt(r, ..) if r.is_empty() => 0, + Self::InclusiveInt(r, ..) => (*r.end() - *r.start()) as usize, + } + } /// Is the specified number within this range? #[inline(always)] #[must_use] @@ -208,19 +230,6 @@ impl RangeCase { Self::InclusiveInt(r, ..) => r.contains(&n), } } - /// If the range contains only of a single [`INT`], return it; - /// otherwise return [`None`]. - #[inline(always)] - #[must_use] - pub fn single_int(&self) -> Option { - match self { - Self::ExclusiveInt(r, ..) if r.end.checked_sub(r.start) == Some(1) => Some(r.start), - Self::InclusiveInt(r, ..) if r.end().checked_sub(*r.start()) == Some(0) => { - Some(*r.start()) - } - _ => None, - } - } /// Is the specified range inclusive? #[inline(always)] #[must_use] diff --git a/src/optimizer.rs b/src/optimizer.rs index 2a807055..3c55695e 100644 --- a/src/optimizer.rs +++ b/src/optimizer.rs @@ -525,7 +525,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b let ( match_expr, SwitchCases { - blocks: blocks_list, + blocks, cases, ranges, def_case, @@ -538,29 +538,29 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b let hash = hasher.finish(); // First check hashes - if let Some(block) = cases.remove(&hash) { - let mut block = mem::take(&mut blocks_list[block]); + if let Some(b) = cases.remove(&hash) { + let mut b = mem::take(&mut blocks[b]); cases.clear(); - match block.condition { + match b.condition { Expr::BoolConstant(true, ..) => { // Promote the matched case - let statements: StmtBlockContainer = mem::take(&mut block.statements); + let statements: StmtBlockContainer = mem::take(&mut b.statements); let statements = optimize_stmt_block(statements, state, true, true, false); - *stmt = (statements, block.statements.span()).into(); + *stmt = (statements, b.statements.span()).into(); } ref mut condition => { // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } optimize_expr(condition, state, false); - let def_case = &mut blocks_list[*def_case].statements; + let def_case = &mut blocks[*def_case].statements; let def_span = def_case.span_or_else(*pos, Position::NONE); let def_case: StmtBlockContainer = mem::take(def_case); let def_stmt = optimize_stmt_block(def_case, state, true, true, false); *stmt = Stmt::If( ( mem::take(condition), - mem::take(&mut block.statements), + mem::take(&mut b.statements), StmtBlock::new_with_span(def_stmt, def_span), ) .into(), @@ -580,19 +580,16 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b // Only one range or all ranges without conditions if ranges.len() == 1 || ranges.iter().all(|r| { - matches!( - blocks_list[r.index()].condition, - Expr::BoolConstant(true, ..) - ) + matches!(blocks[r.index()].condition, Expr::BoolConstant(true, ..)) }) { for r in ranges.iter().filter(|r| r.contains(value)) { - let condition = mem::take(&mut blocks_list[r.index()].condition); + let condition = mem::take(&mut blocks[r.index()].condition); match condition { Expr::BoolConstant(true, ..) => { // Promote the matched case - let block = &mut blocks_list[r.index()]; + let block = &mut blocks[r.index()]; let statements = mem::take(&mut *block.statements); let statements = optimize_stmt_block(statements, state, true, true, false); @@ -602,13 +599,13 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b // switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } optimize_expr(&mut condition, state, false); - let def_case = &mut blocks_list[*def_case].statements; + let def_case = &mut blocks[*def_case].statements; let def_span = def_case.span_or_else(*pos, Position::NONE); let def_case: StmtBlockContainer = mem::take(def_case); let def_stmt = optimize_stmt_block(def_case, state, true, true, false); - let statements = mem::take(&mut blocks_list[r.index()].statements); + let statements = mem::take(&mut blocks[r.index()].statements); *stmt = Stmt::If( ( @@ -641,16 +638,16 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b } for r in &*ranges { - let block = &mut blocks_list[r.index()]; - let statements = mem::take(&mut *block.statements); - *block.statements = + let b = &mut blocks[r.index()]; + let statements = mem::take(&mut *b.statements); + *b.statements = optimize_stmt_block(statements, state, preserve_result, true, false); - optimize_expr(&mut block.condition, state, false); + optimize_expr(&mut b.condition, state, false); - match block.condition { + match b.condition { Expr::Unit(pos) => { - block.condition = Expr::BoolConstant(true, pos); + b.condition = Expr::BoolConstant(true, pos); state.set_dirty() } _ => (), @@ -662,7 +659,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b // Promote the default case state.set_dirty(); - let def_case = &mut blocks_list[*def_case].statements; + let def_case = &mut blocks[*def_case].statements; let def_span = def_case.span_or_else(*pos, Position::NONE); let def_case: StmtBlockContainer = mem::take(def_case); let def_stmt = optimize_stmt_block(def_case, state, true, true, false); @@ -673,7 +670,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b let ( match_expr, SwitchCases { - blocks: blocks_list, + blocks, cases, ranges, def_case, @@ -684,21 +681,21 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b optimize_expr(match_expr, state, false); // Optimize blocks - for block in blocks_list.iter_mut() { - let statements = mem::take(&mut *block.statements); - *block.statements = + for b in blocks.iter_mut() { + let statements = mem::take(&mut *b.statements); + *b.statements = optimize_stmt_block(statements, state, preserve_result, true, false); - optimize_expr(&mut block.condition, state, false); + optimize_expr(&mut b.condition, state, false); - match block.condition { + match b.condition { Expr::Unit(pos) => { - block.condition = Expr::BoolConstant(true, pos); + b.condition = Expr::BoolConstant(true, pos); state.set_dirty(); } Expr::BoolConstant(false, ..) => { - if !block.statements.is_empty() { - block.statements = StmtBlock::NONE; + if !b.statements.is_empty() { + b.statements = StmtBlock::NONE; state.set_dirty(); } } @@ -707,7 +704,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b } // Remove false cases - cases.retain(|_, &mut block| match blocks_list[block].condition { + cases.retain(|_, &mut block| match blocks[block].condition { Expr::BoolConstant(false, ..) => { state.set_dirty(); false @@ -715,7 +712,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b _ => true, }); // Remove false ranges - ranges.retain(|r| match blocks_list[r.index()].condition { + ranges.retain(|r| match blocks[r.index()].condition { Expr::BoolConstant(false, ..) => { state.set_dirty(); false @@ -723,9 +720,26 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b _ => true, }); - let def_case = &mut blocks_list[*def_case].statements; - let def_block = mem::take(&mut **def_case); - **def_case = optimize_stmt_block(def_block, state, preserve_result, true, false); + let def_stmt_block = &mut blocks[*def_case].statements; + let def_block = mem::take(&mut **def_stmt_block); + **def_stmt_block = optimize_stmt_block(def_block, state, preserve_result, true, false); + + // Remove unused block statements + for index in 0..blocks.len() { + if *def_case == index + || cases.values().any(|&n| n == index) + || ranges.iter().any(|r| r.index() == index) + { + continue; + } + + let b = &mut blocks[index]; + + if !b.statements.is_empty() { + b.statements = StmtBlock::NONE; + state.set_dirty(); + } + } } // while false { block } -> Noop diff --git a/src/parser.rs b/src/parser.rs index 119f207d..a59a2e81 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -39,6 +39,9 @@ const SCOPE_SEARCH_BARRIER_MARKER: &str = "$ BARRIER $"; /// The message: `TokenStream` never ends const NEVER_ENDS: &str = "`Token`"; +/// Unroll `switch` ranges no larger than this. +const SMALL_SWITCH_RANGE: usize = 16; + /// _(internals)_ A type that encapsulates the current state of the parser. /// Exported under the `internals` feature only. pub struct ParseState<'e> { @@ -1138,6 +1141,7 @@ impl Engine { let stmt = self.parse_stmt(input, state, lib, settings.level_up())?; let need_comma = !stmt.is_self_terminated(); + let has_condition = !matches!(condition, Expr::BoolConstant(true, ..)); blocks.push((condition, stmt).into()); let index = blocks.len() - 1; @@ -1159,14 +1163,15 @@ impl Engine { if let Some(mut r) = range_value { if !r.is_empty() { - if let Some(n) = r.single_int() { - // Unroll single range - let value = Dynamic::from_int(n); - let hasher = &mut get_hasher(); - value.hash(hasher); - let hash = hasher.finish(); - - cases.entry(hash).or_insert(index); + // Do not unroll ranges if there are previous non-unrolled ranges + if !has_condition && ranges.is_empty() && r.len() <= SMALL_SWITCH_RANGE + { + // Unroll small range + for n in r { + let hasher = &mut get_hasher(); + Dynamic::from_int(n).hash(hasher); + cases.entry(hasher.finish()).or_insert(index); + } } else { // Other range r.set_index(index); diff --git a/tests/switch.rs b/tests/switch.rs index 301401a8..5d7f18fa 100644 --- a/tests/switch.rs +++ b/tests/switch.rs @@ -290,6 +290,49 @@ fn test_switch_ranges() -> Result<(), Box> { )?, 'x' ); + assert_eq!( + engine.eval_with_scope::( + &mut scope, + " + switch 5 { + 'a' => true, + 0..10 => 123, + 2..12 => 'z', + _ => 'x' + } + " + )?, + 123 + ); + assert_eq!( + engine.eval_with_scope::( + &mut scope, + " + switch 5 { + 'a' => true, + 4 | 5 | 6 => 42, + 0..10 => 123, + 2..12 => 'z', + _ => 'x' + } + " + )?, + 42 + ); + assert_eq!( + engine.eval_with_scope::( + &mut scope, + " + switch 5 { + 'a' => true, + 2..12 => 'z', + 0..10 if x+2==1+2 => print(40+2), + _ => 'x' + } + " + )?, + 'z' + ); assert_eq!( engine.eval_with_scope::( &mut scope,