diff --git a/src/ast/stmt.rs b/src/ast/stmt.rs index d855e32b..3b73a91a 100644 --- a/src/ast/stmt.rs +++ b/src/ast/stmt.rs @@ -19,7 +19,7 @@ use std::{ /// Exported under the `internals` feature only. /// /// This type may hold a straight assignment (i.e. not an op-assignment). -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +#[derive(Clone, Copy, Eq, PartialEq, Hash)] pub struct OpAssignment<'a> { /// Hash of the op-assignment call. pub hash_op_assign: u64, @@ -106,11 +106,27 @@ impl OpAssignment<'_> { } } +impl fmt::Debug for OpAssignment<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + if self.is_op_assignment() { + f.debug_struct("OpAssignment") + .field("hash_op_assign", &self.hash_op_assign) + .field("hash_op", &self.hash_op) + .field("op_assign", &self.op_assign) + .field("op", &self.op) + .field("pos", &self.pos) + .finish() + } else { + fmt::Debug::fmt(&self.pos, f) + } + } +} + /// A statements block with an optional condition. #[derive(Debug, Clone, Hash)] pub struct ConditionalStmtBlock { - /// Optional condition. - pub condition: Option, + /// Condition. + pub condition: Expr, /// Statements block. pub statements: StmtBlock, } @@ -119,7 +135,7 @@ impl> From for ConditionalStmtBlock { #[inline(always)] fn from(value: B) -> Self { Self { - condition: None, + condition: Expr::BoolConstant(true, Position::NONE), statements: value.into(), } } @@ -128,16 +144,6 @@ impl> From for ConditionalStmtBlock { impl> From<(Expr, B)> for ConditionalStmtBlock { #[inline(always)] fn from(value: (Expr, B)) -> Self { - Self { - condition: Some(value.0), - statements: value.1.into(), - } - } -} - -impl> From<(Option, B)> for ConditionalStmtBlock { - #[inline(always)] - fn from(value: (Option, B)) -> Self { Self { condition: value.0, statements: value.1.into(), @@ -145,15 +151,6 @@ impl> From<(Option, B)> for ConditionalStmtBlock { } } -impl ConditionalStmtBlock { - /// Does the condition exist? - #[inline(always)] - #[must_use] - pub const fn has_condition(&self) -> bool { - self.condition.is_some() - } -} - /// _(internals)_ A type containing all cases for a `switch` statement. /// Exported under the `internals` feature only. #[derive(Debug, Clone, Hash)] @@ -621,12 +618,10 @@ impl Stmt { Self::Switch(x, ..) => { x.0.is_pure() && x.1.cases.values().all(|block| { - block.condition.as_ref().map(Expr::is_pure).unwrap_or(true) - && block.statements.iter().all(Stmt::is_pure) + block.condition.is_pure() && block.statements.iter().all(Stmt::is_pure) }) && x.1.ranges.iter().all(|(.., block)| { - block.condition.as_ref().map(Expr::is_pure).unwrap_or(true) - && block.statements.iter().all(Stmt::is_pure) + block.condition.is_pure() && block.statements.iter().all(Stmt::is_pure) }) && x.1.def_case.iter().all(Stmt::is_pure) } @@ -768,12 +763,7 @@ impl Stmt { return false; } for b in x.1.cases.values() { - if !b - .condition - .as_ref() - .map(|e| e.walk(path, on_node)) - .unwrap_or(true) - { + if !b.condition.walk(path, on_node) { return false; } for s in b.statements.iter() { @@ -783,12 +773,7 @@ impl Stmt { } } for (.., b) in &x.1.ranges { - if !b - .condition - .as_ref() - .map(|e| e.walk(path, on_node)) - .unwrap_or(true) - { + if !b.condition.walk(path, on_node) { return false; } for s in b.statements.iter() { diff --git a/src/eval/stmt.rs b/src/eval/stmt.rs index 5e8bc8f0..bbdf5f41 100644 --- a/src/eval/stmt.rs +++ b/src/eval/stmt.rs @@ -392,23 +392,16 @@ impl Engine { // First check hashes if let Some(case_block) = cases.get(&hash) { - let cond_result = case_block - .condition - .as_ref() - .map(|cond| { - self.eval_expr( - scope, global, caches, lib, this_ptr, cond, level, - ) + let cond_result = match case_block.condition { + Expr::BoolConstant(b, ..) => Ok(b), + ref c => self + .eval_expr(scope, global, caches, lib, this_ptr, c, level) .and_then(|v| { v.as_bool().map_err(|typ| { - self.make_type_mismatch_err::( - typ, - cond.position(), - ) + self.make_type_mismatch_err::(typ, c.position()) }) - }) - }) - .unwrap_or(Ok(true)); + }), + }; match cond_result { Ok(true) => Ok(Some(&case_block.statements)), @@ -426,23 +419,19 @@ impl Engine { || (inclusive && (start..=end).contains(&value)) }) { - let cond_result = block - .condition - .as_ref() - .map(|cond| { - self.eval_expr( - scope, global, caches, lib, this_ptr, cond, level, - ) + let cond_result = match block.condition { + Expr::BoolConstant(b, ..) => Ok(b), + ref c => self + .eval_expr(scope, global, caches, lib, this_ptr, c, level) .and_then(|v| { v.as_bool().map_err(|typ| { self.make_type_mismatch_err::( typ, - cond.position(), + c.position(), ) }) - }) - }) - .unwrap_or(Ok(true)); + }), + }; match cond_result { Ok(true) => result = Ok(Some(&block.statements)), diff --git a/src/optimizer.rs b/src/optimizer.rs index 3936b91c..403b1564 100644 --- a/src/optimizer.rs +++ b/src/optimizer.rs @@ -531,35 +531,38 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b // First check hashes if let Some(block) = cases.get_mut(&hash) { - if let Some(mut condition) = mem::take(&mut block.condition) { - // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } - optimize_expr(&mut condition, state, false); - - let def_stmt = - optimize_stmt_block(mem::take(def_case), state, true, true, false); - - *stmt = Stmt::If( - ( - condition, + match mem::take(&mut block.condition) { + Expr::BoolConstant(true, ..) => { + // Promote the matched case + let statements = optimize_stmt_block( mem::take(&mut block.statements), - StmtBlock::new_with_span( - def_stmt, - def_case.span_or_else(*pos, Position::NONE), - ), - ) - .into(), - match_expr.start_position(), - ); - } else { - // Promote the matched case - let statements = optimize_stmt_block( - mem::take(&mut block.statements), - state, - true, - true, - false, - ); - *stmt = (statements, block.statements.span()).into(); + state, + true, + true, + false, + ); + *stmt = (statements, block.statements.span()).into(); + } + mut condition => { + // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } + optimize_expr(&mut condition, state, false); + + let def_stmt = + optimize_stmt_block(mem::take(def_case), state, true, true, false); + + *stmt = Stmt::If( + ( + condition, + mem::take(&mut block.statements), + StmtBlock::new_with_span( + def_stmt, + def_case.span_or_else(*pos, Position::NONE), + ), + ) + .into(), + match_expr.start_position(), + ); + } } state.set_dirty(); @@ -571,7 +574,11 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b let value = value.as_int().expect("`INT`"); // Only one range or all ranges without conditions - if ranges.len() == 1 || ranges.iter().all(|(.., c)| !c.has_condition()) { + if ranges.len() == 1 + || ranges + .iter() + .all(|(.., c)| matches!(c.condition, Expr::BoolConstant(true, ..))) + { for (.., block) in ranges .iter_mut() @@ -580,30 +587,38 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b || (inclusive && (start..=end).contains(&value)) }) { - if let Some(mut condition) = mem::take(&mut block.condition) { - // switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } - optimize_expr(&mut condition, state, false); + match mem::take(&mut block.condition) { + Expr::BoolConstant(true, ..) => { + // Promote the matched case + let statements = mem::take(&mut *block.statements); + let statements = + optimize_stmt_block(statements, state, true, true, false); + *stmt = (statements, block.statements.span()).into(); + } + mut condition => { + // switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } + optimize_expr(&mut condition, state, false); - let def_stmt = - optimize_stmt_block(mem::take(def_case), state, true, true, false); - *stmt = Stmt::If( - ( - condition, - mem::take(&mut block.statements), - StmtBlock::new_with_span( - def_stmt, - def_case.span_or_else(*pos, Position::NONE), - ), - ) - .into(), - match_expr.start_position(), - ); - } else { - // Promote the matched case - let statements = mem::take(&mut *block.statements); - let statements = - optimize_stmt_block(statements, state, true, true, false); - *stmt = (statements, block.statements.span()).into(); + let def_stmt = optimize_stmt_block( + mem::take(def_case), + state, + true, + true, + false, + ); + *stmt = Stmt::If( + ( + condition, + mem::take(&mut block.statements), + StmtBlock::new_with_span( + def_stmt, + def_case.span_or_else(*pos, Position::NONE), + ), + ) + .into(), + match_expr.start_position(), + ); + } } state.set_dirty(); @@ -632,12 +647,14 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b *block.statements = optimize_stmt_block(statements, state, preserve_result, true, false); - if let Some(mut condition) = mem::take(&mut block.condition) { - optimize_expr(&mut condition, state, false); - match condition { - Expr::Unit(..) | Expr::BoolConstant(true, ..) => state.set_dirty(), - _ => block.condition = Some(condition), + optimize_expr(&mut block.condition, state, false); + + match block.condition { + Expr::Unit(pos) => { + block.condition = Expr::BoolConstant(true, pos); + state.set_dirty() } + _ => (), } } return; @@ -669,18 +686,20 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b *block.statements = optimize_stmt_block(statements, state, preserve_result, true, false); - if let Some(mut condition) = mem::take(&mut block.condition) { - optimize_expr(&mut condition, state, false); - match condition { - Expr::Unit(..) | Expr::BoolConstant(true, ..) => state.set_dirty(), - _ => block.condition = Some(condition), + optimize_expr(&mut block.condition, state, false); + + match block.condition { + Expr::Unit(pos) => { + block.condition = Expr::BoolConstant(true, pos); + state.set_dirty(); } + _ => (), } } // Remove false cases cases.retain(|_, block| match block.condition { - Some(Expr::BoolConstant(false, ..)) => { + Expr::BoolConstant(false, ..) => { state.set_dirty(); false } @@ -693,18 +712,20 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b *block.statements = optimize_stmt_block(statements, state, preserve_result, true, false); - if let Some(mut condition) = mem::take(&mut block.condition) { - optimize_expr(&mut condition, state, false); - match condition { - Expr::Unit(..) | Expr::BoolConstant(true, ..) => state.set_dirty(), - _ => block.condition = Some(condition), + optimize_expr(&mut block.condition, state, false); + + match block.condition { + Expr::Unit(pos) => { + block.condition = Expr::BoolConstant(true, pos); + state.set_dirty(); } + _ => (), } } // Remove false ranges ranges.retain(|(.., block)| match block.condition { - Some(Expr::BoolConstant(false, ..)) => { + Expr::BoolConstant(false, ..) => { state.set_dirty(); false } diff --git a/src/parser.rs b/src/parser.rs index 1d684960..33c315f8 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1041,7 +1041,7 @@ impl Engine { return Err(PERR::WrongSwitchCaseCondition.into_err(if_pos)); } - (None, None) + (None, Expr::BoolConstant(true, Position::NONE)) } (Token::Underscore, pos) => return Err(PERR::DuplicatedSwitchCase.into_err(*pos)), @@ -1054,9 +1054,9 @@ impl Engine { Some(self.parse_expr(input, state, lib, settings.level_up())?); let condition = if match_token(input, Token::If).0 { - Some(self.parse_expr(input, state, lib, settings.level_up())?) + self.parse_expr(input, state, lib, settings.level_up())? } else { - None + Expr::BoolConstant(true, Position::NONE) }; (case_expr, condition) }