Simplify switch condition.

This commit is contained in:
Stephen Chung 2022-04-19 16:20:43 +08:00
parent 40c4906336
commit 770b2e04cc
4 changed files with 131 additions and 136 deletions

View File

@ -19,7 +19,7 @@ use std::{
/// Exported under the `internals` feature only. /// Exported under the `internals` feature only.
/// ///
/// This type may hold a straight assignment (i.e. not an op-assignment). /// This type may hold a straight assignment (i.e. not an op-assignment).
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] #[derive(Clone, Copy, Eq, PartialEq, Hash)]
pub struct OpAssignment<'a> { pub struct OpAssignment<'a> {
/// Hash of the op-assignment call. /// Hash of the op-assignment call.
pub hash_op_assign: u64, pub hash_op_assign: u64,
@ -106,11 +106,27 @@ impl OpAssignment<'_> {
} }
} }
impl fmt::Debug for OpAssignment<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.is_op_assignment() {
f.debug_struct("OpAssignment")
.field("hash_op_assign", &self.hash_op_assign)
.field("hash_op", &self.hash_op)
.field("op_assign", &self.op_assign)
.field("op", &self.op)
.field("pos", &self.pos)
.finish()
} else {
fmt::Debug::fmt(&self.pos, f)
}
}
}
/// A statements block with an optional condition. /// A statements block with an optional condition.
#[derive(Debug, Clone, Hash)] #[derive(Debug, Clone, Hash)]
pub struct ConditionalStmtBlock { pub struct ConditionalStmtBlock {
/// Optional condition. /// Condition.
pub condition: Option<Expr>, pub condition: Expr,
/// Statements block. /// Statements block.
pub statements: StmtBlock, pub statements: StmtBlock,
} }
@ -119,7 +135,7 @@ impl<B: Into<StmtBlock>> From<B> for ConditionalStmtBlock {
#[inline(always)] #[inline(always)]
fn from(value: B) -> Self { fn from(value: B) -> Self {
Self { Self {
condition: None, condition: Expr::BoolConstant(true, Position::NONE),
statements: value.into(), statements: value.into(),
} }
} }
@ -128,16 +144,6 @@ impl<B: Into<StmtBlock>> From<B> for ConditionalStmtBlock {
impl<B: Into<StmtBlock>> From<(Expr, B)> for ConditionalStmtBlock { impl<B: Into<StmtBlock>> From<(Expr, B)> for ConditionalStmtBlock {
#[inline(always)] #[inline(always)]
fn from(value: (Expr, B)) -> Self { fn from(value: (Expr, B)) -> Self {
Self {
condition: Some(value.0),
statements: value.1.into(),
}
}
}
impl<B: Into<StmtBlock>> From<(Option<Expr>, B)> for ConditionalStmtBlock {
#[inline(always)]
fn from(value: (Option<Expr>, B)) -> Self {
Self { Self {
condition: value.0, condition: value.0,
statements: value.1.into(), statements: value.1.into(),
@ -145,15 +151,6 @@ impl<B: Into<StmtBlock>> From<(Option<Expr>, B)> for ConditionalStmtBlock {
} }
} }
impl ConditionalStmtBlock {
/// Does the condition exist?
#[inline(always)]
#[must_use]
pub const fn has_condition(&self) -> bool {
self.condition.is_some()
}
}
/// _(internals)_ A type containing all cases for a `switch` statement. /// _(internals)_ A type containing all cases for a `switch` statement.
/// Exported under the `internals` feature only. /// Exported under the `internals` feature only.
#[derive(Debug, Clone, Hash)] #[derive(Debug, Clone, Hash)]
@ -621,12 +618,10 @@ impl Stmt {
Self::Switch(x, ..) => { Self::Switch(x, ..) => {
x.0.is_pure() x.0.is_pure()
&& x.1.cases.values().all(|block| { && x.1.cases.values().all(|block| {
block.condition.as_ref().map(Expr::is_pure).unwrap_or(true) block.condition.is_pure() && block.statements.iter().all(Stmt::is_pure)
&& block.statements.iter().all(Stmt::is_pure)
}) })
&& x.1.ranges.iter().all(|(.., block)| { && x.1.ranges.iter().all(|(.., block)| {
block.condition.as_ref().map(Expr::is_pure).unwrap_or(true) block.condition.is_pure() && block.statements.iter().all(Stmt::is_pure)
&& block.statements.iter().all(Stmt::is_pure)
}) })
&& x.1.def_case.iter().all(Stmt::is_pure) && x.1.def_case.iter().all(Stmt::is_pure)
} }
@ -768,12 +763,7 @@ impl Stmt {
return false; return false;
} }
for b in x.1.cases.values() { for b in x.1.cases.values() {
if !b if !b.condition.walk(path, on_node) {
.condition
.as_ref()
.map(|e| e.walk(path, on_node))
.unwrap_or(true)
{
return false; return false;
} }
for s in b.statements.iter() { for s in b.statements.iter() {
@ -783,12 +773,7 @@ impl Stmt {
} }
} }
for (.., b) in &x.1.ranges { for (.., b) in &x.1.ranges {
if !b if !b.condition.walk(path, on_node) {
.condition
.as_ref()
.map(|e| e.walk(path, on_node))
.unwrap_or(true)
{
return false; return false;
} }
for s in b.statements.iter() { for s in b.statements.iter() {

View File

@ -392,23 +392,16 @@ impl Engine {
// First check hashes // First check hashes
if let Some(case_block) = cases.get(&hash) { if let Some(case_block) = cases.get(&hash) {
let cond_result = case_block let cond_result = match case_block.condition {
.condition Expr::BoolConstant(b, ..) => Ok(b),
.as_ref() ref c => self
.map(|cond| { .eval_expr(scope, global, caches, lib, this_ptr, c, level)
self.eval_expr(
scope, global, caches, lib, this_ptr, cond, level,
)
.and_then(|v| { .and_then(|v| {
v.as_bool().map_err(|typ| { v.as_bool().map_err(|typ| {
self.make_type_mismatch_err::<bool>( self.make_type_mismatch_err::<bool>(typ, c.position())
typ,
cond.position(),
)
}) })
}) }),
}) };
.unwrap_or(Ok(true));
match cond_result { match cond_result {
Ok(true) => Ok(Some(&case_block.statements)), Ok(true) => Ok(Some(&case_block.statements)),
@ -426,23 +419,19 @@ impl Engine {
|| (inclusive && (start..=end).contains(&value)) || (inclusive && (start..=end).contains(&value))
}) })
{ {
let cond_result = block let cond_result = match block.condition {
.condition Expr::BoolConstant(b, ..) => Ok(b),
.as_ref() ref c => self
.map(|cond| { .eval_expr(scope, global, caches, lib, this_ptr, c, level)
self.eval_expr(
scope, global, caches, lib, this_ptr, cond, level,
)
.and_then(|v| { .and_then(|v| {
v.as_bool().map_err(|typ| { v.as_bool().map_err(|typ| {
self.make_type_mismatch_err::<bool>( self.make_type_mismatch_err::<bool>(
typ, typ,
cond.position(), c.position(),
) )
}) })
}) }),
}) };
.unwrap_or(Ok(true));
match cond_result { match cond_result {
Ok(true) => result = Ok(Some(&block.statements)), Ok(true) => result = Ok(Some(&block.statements)),

View File

@ -531,35 +531,38 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
// First check hashes // First check hashes
if let Some(block) = cases.get_mut(&hash) { if let Some(block) = cases.get_mut(&hash) {
if let Some(mut condition) = mem::take(&mut block.condition) { match mem::take(&mut block.condition) {
// switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def } Expr::BoolConstant(true, ..) => {
optimize_expr(&mut condition, state, false); // Promote the matched case
let statements = optimize_stmt_block(
let def_stmt =
optimize_stmt_block(mem::take(def_case), state, true, true, false);
*stmt = Stmt::If(
(
condition,
mem::take(&mut block.statements), mem::take(&mut block.statements),
StmtBlock::new_with_span( state,
def_stmt, true,
def_case.span_or_else(*pos, Position::NONE), true,
), false,
) );
.into(), *stmt = (statements, block.statements.span()).into();
match_expr.start_position(), }
); mut condition => {
} else { // switch const { case if condition => stmt, _ => def } => if condition { stmt } else { def }
// Promote the matched case optimize_expr(&mut condition, state, false);
let statements = optimize_stmt_block(
mem::take(&mut block.statements), let def_stmt =
state, optimize_stmt_block(mem::take(def_case), state, true, true, false);
true,
true, *stmt = Stmt::If(
false, (
); condition,
*stmt = (statements, block.statements.span()).into(); mem::take(&mut block.statements),
StmtBlock::new_with_span(
def_stmt,
def_case.span_or_else(*pos, Position::NONE),
),
)
.into(),
match_expr.start_position(),
);
}
} }
state.set_dirty(); state.set_dirty();
@ -571,7 +574,11 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
let value = value.as_int().expect("`INT`"); let value = value.as_int().expect("`INT`");
// Only one range or all ranges without conditions // Only one range or all ranges without conditions
if ranges.len() == 1 || ranges.iter().all(|(.., c)| !c.has_condition()) { if ranges.len() == 1
|| ranges
.iter()
.all(|(.., c)| matches!(c.condition, Expr::BoolConstant(true, ..)))
{
for (.., block) in for (.., block) in
ranges ranges
.iter_mut() .iter_mut()
@ -580,30 +587,38 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
|| (inclusive && (start..=end).contains(&value)) || (inclusive && (start..=end).contains(&value))
}) })
{ {
if let Some(mut condition) = mem::take(&mut block.condition) { match mem::take(&mut block.condition) {
// switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def } Expr::BoolConstant(true, ..) => {
optimize_expr(&mut condition, state, false); // Promote the matched case
let statements = mem::take(&mut *block.statements);
let statements =
optimize_stmt_block(statements, state, true, true, false);
*stmt = (statements, block.statements.span()).into();
}
mut condition => {
// switch const { range if condition => stmt, _ => def } => if condition { stmt } else { def }
optimize_expr(&mut condition, state, false);
let def_stmt = let def_stmt = optimize_stmt_block(
optimize_stmt_block(mem::take(def_case), state, true, true, false); mem::take(def_case),
*stmt = Stmt::If( state,
( true,
condition, true,
mem::take(&mut block.statements), false,
StmtBlock::new_with_span( );
def_stmt, *stmt = Stmt::If(
def_case.span_or_else(*pos, Position::NONE), (
), condition,
) mem::take(&mut block.statements),
.into(), StmtBlock::new_with_span(
match_expr.start_position(), def_stmt,
); def_case.span_or_else(*pos, Position::NONE),
} else { ),
// Promote the matched case )
let statements = mem::take(&mut *block.statements); .into(),
let statements = match_expr.start_position(),
optimize_stmt_block(statements, state, true, true, false); );
*stmt = (statements, block.statements.span()).into(); }
} }
state.set_dirty(); state.set_dirty();
@ -632,12 +647,14 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
*block.statements = *block.statements =
optimize_stmt_block(statements, state, preserve_result, true, false); optimize_stmt_block(statements, state, preserve_result, true, false);
if let Some(mut condition) = mem::take(&mut block.condition) { optimize_expr(&mut block.condition, state, false);
optimize_expr(&mut condition, state, false);
match condition { match block.condition {
Expr::Unit(..) | Expr::BoolConstant(true, ..) => state.set_dirty(), Expr::Unit(pos) => {
_ => block.condition = Some(condition), block.condition = Expr::BoolConstant(true, pos);
state.set_dirty()
} }
_ => (),
} }
} }
return; return;
@ -669,18 +686,20 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
*block.statements = *block.statements =
optimize_stmt_block(statements, state, preserve_result, true, false); optimize_stmt_block(statements, state, preserve_result, true, false);
if let Some(mut condition) = mem::take(&mut block.condition) { optimize_expr(&mut block.condition, state, false);
optimize_expr(&mut condition, state, false);
match condition { match block.condition {
Expr::Unit(..) | Expr::BoolConstant(true, ..) => state.set_dirty(), Expr::Unit(pos) => {
_ => block.condition = Some(condition), block.condition = Expr::BoolConstant(true, pos);
state.set_dirty();
} }
_ => (),
} }
} }
// Remove false cases // Remove false cases
cases.retain(|_, block| match block.condition { cases.retain(|_, block| match block.condition {
Some(Expr::BoolConstant(false, ..)) => { Expr::BoolConstant(false, ..) => {
state.set_dirty(); state.set_dirty();
false false
} }
@ -693,18 +712,20 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b
*block.statements = *block.statements =
optimize_stmt_block(statements, state, preserve_result, true, false); optimize_stmt_block(statements, state, preserve_result, true, false);
if let Some(mut condition) = mem::take(&mut block.condition) { optimize_expr(&mut block.condition, state, false);
optimize_expr(&mut condition, state, false);
match condition { match block.condition {
Expr::Unit(..) | Expr::BoolConstant(true, ..) => state.set_dirty(), Expr::Unit(pos) => {
_ => block.condition = Some(condition), block.condition = Expr::BoolConstant(true, pos);
state.set_dirty();
} }
_ => (),
} }
} }
// Remove false ranges // Remove false ranges
ranges.retain(|(.., block)| match block.condition { ranges.retain(|(.., block)| match block.condition {
Some(Expr::BoolConstant(false, ..)) => { Expr::BoolConstant(false, ..) => {
state.set_dirty(); state.set_dirty();
false false
} }

View File

@ -1041,7 +1041,7 @@ impl Engine {
return Err(PERR::WrongSwitchCaseCondition.into_err(if_pos)); return Err(PERR::WrongSwitchCaseCondition.into_err(if_pos));
} }
(None, None) (None, Expr::BoolConstant(true, Position::NONE))
} }
(Token::Underscore, pos) => return Err(PERR::DuplicatedSwitchCase.into_err(*pos)), (Token::Underscore, pos) => return Err(PERR::DuplicatedSwitchCase.into_err(*pos)),
@ -1054,9 +1054,9 @@ impl Engine {
Some(self.parse_expr(input, state, lib, settings.level_up())?); Some(self.parse_expr(input, state, lib, settings.level_up())?);
let condition = if match_token(input, Token::If).0 { let condition = if match_token(input, Token::If).0 {
Some(self.parse_expr(input, state, lib, settings.level_up())?) self.parse_expr(input, state, lib, settings.level_up())?
} else { } else {
None Expr::BoolConstant(true, Position::NONE)
}; };
(case_expr, condition) (case_expr, condition)
} }