From 281e94fc6237b847841003dcfefc44745418ac4c Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Mon, 18 Jul 2022 23:28:12 +0800 Subject: [PATCH] Switch case condition that is constant () no longer optimizes to false. --- CHANGELOG.md | 5 ++ src/ast/stmt.rs | 21 ++++++ src/optimizer.rs | 185 ++++++++++++++++++++--------------------------- 3 files changed, 105 insertions(+), 106 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc16d51d..88db90a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,11 @@ Version 1.9.0 The minimum Rust version is now `1.60.0` in order to use the `dep:` syntax for dependencies. +Bug fixes +--------- + +* `switch` cases with conditions that evaluate to constant `()` no longer optimize to `false` (should raise a type error during runtime). + New features ------------ diff --git a/src/ast/stmt.rs b/src/ast/stmt.rs index d214f95c..49e3b0ee 100644 --- a/src/ast/stmt.rs +++ b/src/ast/stmt.rs @@ -153,6 +153,27 @@ impl> From<(Expr, B)> for ConditionalStmtBlock { } } +impl ConditionalStmtBlock { + /// Is this conditional statements block always `true`? + #[inline(always)] + #[must_use] + pub fn is_always_true(&self) -> bool { + match self.condition { + Expr::BoolConstant(true, ..) => true, + _ => false, + } + } + /// Is this conditional statements block always `false`? + #[inline(always)] + #[must_use] + pub fn is_always_false(&self) -> bool { + match self.condition { + Expr::BoolConstant(false, ..) => true, + _ => false, + } + } +} + /// _(internals)_ A type containing a range case for a `switch` statement. /// Exported under the `internals` feature only. #[derive(Clone, Hash)] diff --git a/src/optimizer.rs b/src/optimizer.rs index 1ed3b0b1..350ef1c5 100644 --- a/src/optimizer.rs +++ b/src/optimizer.rs @@ -547,39 +547,36 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b let mut b = mem::take(&mut case_blocks[*index]); cases.clear(); - match b.condition { - Expr::BoolConstant(true, ..) => { - // Promote the matched case - let statements: StmtBlockContainer = mem::take(&mut b.statements); - let statements = - optimize_stmt_block(statements, state, true, true, false); - *stmt = (statements, b.statements.span()).into(); - } - ref mut condition => { - // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } - optimize_expr(condition, state, false); + if b.is_always_true() { + // Promote the matched case + let statements: StmtBlockContainer = mem::take(&mut b.statements); + let statements = + optimize_stmt_block(statements, state, true, true, false); + *stmt = (statements, b.statements.span()).into(); + } else { + // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } + optimize_expr(&mut b.condition, state, false); - let else_stmt = if let Some(index) = def_case { - let def_case = &mut case_blocks[*index].statements; - let def_span = def_case.span_or_else(*pos, Position::NONE); - let def_case: StmtBlockContainer = mem::take(def_case); - let def_stmt = - optimize_stmt_block(def_case, state, true, true, false); - StmtBlock::new_with_span(def_stmt, def_span) - } else { - StmtBlock::NONE - }; + let else_stmt = if let Some(index) = def_case { + let def_case = &mut case_blocks[*index].statements; + let def_span = def_case.span_or_else(*pos, Position::NONE); + let def_case: StmtBlockContainer = mem::take(def_case); + let def_stmt = + optimize_stmt_block(def_case, state, true, true, false); + StmtBlock::new_with_span(def_stmt, def_span) + } else { + StmtBlock::NONE + }; - *stmt = Stmt::If( - ( - mem::take(condition), - mem::take(&mut b.statements), - else_stmt, - ) - .into(), - match_expr.start_position(), - ); - } + *stmt = Stmt::If( + ( + mem::take(&mut b.condition), + mem::take(&mut b.statements), + else_stmt, + ) + .into(), + match_expr.start_position(), + ); } state.set_dirty(); @@ -589,18 +586,14 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b for &index in case_blocks_list { let mut b = mem::take(&mut case_blocks[index]); - match b.condition { - Expr::BoolConstant(true, ..) => { - // Promote the matched case - let statements: StmtBlockContainer = - mem::take(&mut b.statements); - let statements = - optimize_stmt_block(statements, state, true, true, false); - *stmt = (statements, b.statements.span()).into(); - state.set_dirty(); - return; - } - _ => (), + if b.is_always_true() { + // Promote the matched case + let statements: StmtBlockContainer = mem::take(&mut b.statements); + let statements = + optimize_stmt_block(statements, state, true, true, false); + *stmt = (statements, b.statements.span()).into(); + state.set_dirty(); + return; } } } @@ -613,47 +606,43 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b // Only one range or all ranges without conditions if ranges.len() == 1 - || ranges.iter().all(|r| { - matches!( - case_blocks[r.index()].condition, - Expr::BoolConstant(true, ..) - ) - }) + || ranges + .iter() + .all(|r| case_blocks[r.index()].is_always_true()) { for r in ranges.iter().filter(|r| r.contains(value)) { - let condition = mem::take(&mut case_blocks[r.index()].condition); + let range_block = &mut case_blocks[r.index()]; - match condition { - Expr::BoolConstant(true, ..) => { - // Promote the matched case - let block = &mut case_blocks[r.index()]; - let statements = mem::take(&mut *block.statements); - let statements = - optimize_stmt_block(statements, state, true, true, false); - *stmt = (statements, block.statements.span()).into(); - } - mut condition => { - // switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } - optimize_expr(&mut condition, state, false); + if range_block.is_always_true() { + // Promote the matched case + let block = &mut case_blocks[r.index()]; + let statements = mem::take(&mut *block.statements); + let statements = + optimize_stmt_block(statements, state, true, true, false); + *stmt = (statements, block.statements.span()).into(); + } else { + let mut condition = mem::take(&mut range_block.condition); - let else_stmt = if let Some(index) = def_case { - let def_case = &mut case_blocks[*index].statements; - let def_span = def_case.span_or_else(*pos, Position::NONE); - let def_case: StmtBlockContainer = mem::take(def_case); - let def_stmt = - optimize_stmt_block(def_case, state, true, true, false); - StmtBlock::new_with_span(def_stmt, def_span) - } else { - StmtBlock::NONE - }; + // switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } + optimize_expr(&mut condition, state, false); - let statements = mem::take(&mut case_blocks[r.index()].statements); + let else_stmt = if let Some(index) = def_case { + let def_case = &mut case_blocks[*index].statements; + let def_span = def_case.span_or_else(*pos, Position::NONE); + let def_case: StmtBlockContainer = mem::take(def_case); + let def_stmt = + optimize_stmt_block(def_case, state, true, true, false); + StmtBlock::new_with_span(def_stmt, def_span) + } else { + StmtBlock::NONE + }; - *stmt = Stmt::If( - (condition, statements, else_stmt).into(), - match_expr.start_position(), - ); - } + let statements = mem::take(&mut case_blocks[r.index()].statements); + + *stmt = Stmt::If( + (condition, statements, else_stmt).into(), + match_expr.start_position(), + ); } state.set_dirty(); @@ -681,14 +670,6 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b optimize_stmt_block(statements, state, preserve_result, true, false); optimize_expr(&mut b.condition, state, false); - - match b.condition { - Expr::Unit(pos) => { - b.condition = Expr::BoolConstant(true, pos); - state.set_dirty() - } - _ => (), - } } return; } @@ -730,38 +711,29 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b optimize_expr(&mut b.condition, state, false); - match b.condition { - Expr::Unit(pos) => { - b.condition = Expr::BoolConstant(true, pos); + if b.is_always_false() { + if !b.statements.is_empty() { + b.statements = StmtBlock::NONE; state.set_dirty(); } - Expr::BoolConstant(false, ..) => { - if !b.statements.is_empty() { - b.statements = StmtBlock::NONE; - state.set_dirty(); - } - } - _ => (), } } // Remove false cases cases.retain(|_, list| { // Remove all entries that have false conditions - list.retain(|index| match case_blocks[*index].condition { - Expr::BoolConstant(false, ..) => { + list.retain(|index| { + if case_blocks[*index].is_always_false() { state.set_dirty(); false + } else { + true } - _ => true, }); // Remove all entries after a `true` condition if let Some(n) = list .iter() - .find(|&&index| match case_blocks[index].condition { - Expr::BoolConstant(true, ..) => true, - _ => false, - }) + .find(|&&index| case_blocks[index].is_always_true()) { if n + 1 < list.len() { state.set_dirty(); @@ -779,12 +751,13 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b }); // Remove false ranges - ranges.retain(|r| match case_blocks[r.index()].condition { - Expr::BoolConstant(false, ..) => { + ranges.retain(|r| { + if case_blocks[r.index()].is_always_false() { state.set_dirty(); false + } else { + true } - _ => true, }); if let Some(index) = def_case {