From 10089c5cb0ff4e4b11cb56e709588e33df46fed7 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Sun, 12 Feb 2023 23:20:14 +0800 Subject: [PATCH] Support switch range cases for floating-point values. --- CHANGELOG.md | 1 + src/ast/stmt.rs | 56 ++++++++++++++++++++++++++++++++++++++------ src/eval/stmt.rs | 17 +++++++------- src/func/hashing.rs | 1 + src/optimizer.rs | 8 +++---- src/parser.rs | 26 ++++---------------- tests/expressions.rs | 4 ++-- 7 files changed, 68 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3126de3d..1e9732b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Enhancements ------------ * The functions `min` and `max` are added for numbers. +* Range cases in `switch` statements now also match floating-point and decimal values. In order to support this, however, small numeric ranges cases are no longer unrolled. Version 1.12.0 diff --git a/src/ast/stmt.rs b/src/ast/stmt.rs index cd26df7f..9c67bcab 100644 --- a/src/ast/stmt.rs +++ b/src/ast/stmt.rs @@ -4,8 +4,9 @@ use super::{ASTFlags, ASTNode, BinaryExpr, Expr, FnCallExpr, Ident}; use crate::engine::{KEYWORD_EVAL, OP_EQUALS}; use crate::func::StraightHashMap; use crate::tokenizer::Token; +use crate::types::dynamic::Union; use crate::types::Span; -use crate::{calc_fn_hash, Position, StaticVec, INT}; +use crate::{calc_fn_hash, Dynamic, Position, StaticVec, INT}; #[cfg(feature = "no_std")] use std::prelude::v1::*; use std::{ @@ -257,7 +258,7 @@ impl IntoIterator for RangeCase { type Item = INT; type IntoIter = Box>; - #[inline(always)] + #[inline] #[must_use] fn into_iter(self) -> Self::IntoIter { match self { @@ -269,7 +270,7 @@ impl IntoIterator for RangeCase { impl RangeCase { /// Returns `true` if the range contains no items. - #[inline(always)] + #[inline] #[must_use] pub fn is_empty(&self) -> bool { match self { @@ -278,7 +279,7 @@ impl RangeCase { } } /// Size of the range. - #[inline(always)] + #[inline] #[must_use] pub fn len(&self) -> INT { match self { @@ -288,15 +289,56 @@ impl RangeCase { Self::InclusiveInt(r, ..) => *r.end() - *r.start() + 1, } } - /// Is the specified number within this range? - #[inline(always)] + /// Is the specified value within this range? + #[inline] #[must_use] - pub fn contains(&self, n: INT) -> bool { + pub fn contains(&self, value: &Dynamic) -> bool { + match value { + Dynamic(Union::Int(v, ..)) => self.contains_int(*v), + #[cfg(not(feature = "no_float"))] + Dynamic(Union::Float(v, ..)) => self.contains_float(**v), + #[cfg(feature = "decimal")] + Dynamic(Union::Decimal(v, ..)) => self.contains_decimal(**v), + _ => false, + } + } + /// Is the specified number within this range? + #[inline] + #[must_use] + pub fn contains_int(&self, n: INT) -> bool { match self { Self::ExclusiveInt(r, ..) => r.contains(&n), Self::InclusiveInt(r, ..) => r.contains(&n), } } + /// Is the specified floating-point number within this range? + #[cfg(not(feature = "no_float"))] + #[inline] + #[must_use] + pub fn contains_float(&self, n: crate::FLOAT) -> bool { + use crate::FLOAT; + + match self { + Self::ExclusiveInt(r, ..) => ((r.start as FLOAT)..(r.end as FLOAT)).contains(&n), + Self::InclusiveInt(r, ..) => ((*r.start() as FLOAT)..=(*r.end() as FLOAT)).contains(&n), + } + } + /// Is the specified decimal number within this range? + #[cfg(feature = "decimal")] + #[inline] + #[must_use] + pub fn contains_decimal(&self, n: rust_decimal::Decimal) -> bool { + use rust_decimal::Decimal; + + match self { + Self::ExclusiveInt(r, ..) => { + (Into::::into(r.start)..Into::::into(r.end)).contains(&n) + } + Self::InclusiveInt(r, ..) => { + (Into::::into(*r.start())..=Into::::into(*r.end())).contains(&n) + } + } + } /// Is the specified range inclusive? #[inline(always)] #[must_use] diff --git a/src/eval/stmt.rs b/src/eval/stmt.rs index d6e5229e..194951c9 100644 --- a/src/eval/stmt.rs +++ b/src/eval/stmt.rs @@ -3,7 +3,8 @@ use super::{Caches, EvalContext, GlobalRuntimeState, Target}; use crate::api::events::VarDefInfo; use crate::ast::{ - ASTFlags, BinaryExpr, Expr, FlowControl, OpAssignment, Stmt, SwitchCasesCollection, + ASTFlags, BinaryExpr, ConditionalExpr, Expr, FlowControl, OpAssignment, Stmt, + SwitchCasesCollection, }; use crate::func::{get_builtin_op_assignment_fn, get_hasher}; use crate::types::dynamic::{AccessMode, Union}; @@ -359,15 +360,13 @@ impl Engine { break; } } - } else if value.is_int() && !ranges.is_empty() { + } else if !ranges.is_empty() { // Then check integer ranges - let value = value.as_int().expect("`INT`"); + for r in ranges.iter().filter(|r| r.contains(&value)) { + let ConditionalExpr { condition, expr } = &expressions[r.index()]; - for r in ranges.iter().filter(|r| r.contains(value)) { - let block = &expressions[r.index()]; - - let cond_result = match block.condition { - Expr::BoolConstant(b, ..) => b, + let cond_result = match condition { + Expr::BoolConstant(b, ..) => *b, ref c => self .eval_expr(global, caches, scope, this_ptr.as_deref_mut(), c)? .as_bool() @@ -377,7 +376,7 @@ impl Engine { }; if cond_result { - result = Some(&block.expr); + result = Some(expr); break; } } diff --git a/src/func/hashing.rs b/src/func/hashing.rs index ac0f1b27..58c4202b 100644 --- a/src/func/hashing.rs +++ b/src/func/hashing.rs @@ -28,6 +28,7 @@ impl Hasher for StraightHasher { self.0 } #[cold] + #[inline(never)] fn write(&mut self, _bytes: &[u8]) { panic!("StraightHasher can only hash u64 values"); } diff --git a/src/optimizer.rs b/src/optimizer.rs index d01f98c3..a502a6f4 100644 --- a/src/optimizer.rs +++ b/src/optimizer.rs @@ -563,16 +563,14 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b } // Then check ranges - if value.is_int() && !ranges.is_empty() { - let value = value.as_int().unwrap(); - + if !ranges.is_empty() { // Only one range or all ranges without conditions if ranges.len() == 1 || ranges .iter() .all(|r| expressions[r.index()].is_always_true()) { - if let Some(r) = ranges.iter().find(|r| r.contains(value)) { + if let Some(r) = ranges.iter().find(|r| r.contains(&value)) { let range_block = &mut expressions[r.index()]; if range_block.is_always_true() { @@ -619,7 +617,7 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut OptimizerState, preserve_result: b let old_ranges_len = ranges.len(); - ranges.retain(|r| r.contains(value)); + ranges.retain(|r| r.contains(&value)); if ranges.len() != old_ranges_len { state.set_dirty(); diff --git a/src/parser.rs b/src/parser.rs index ac296d15..67807e87 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -19,7 +19,7 @@ use crate::types::StringsInterner; use crate::{ calc_fn_hash, Dynamic, Engine, EvalAltResult, EvalContext, ExclusiveRange, FnArgsVec, Identifier, ImmutableString, InclusiveRange, LexError, OptimizationLevel, ParseError, Position, - Scope, Shared, SmartString, StaticVec, AST, INT, PERR, + Scope, Shared, SmartString, StaticVec, AST, PERR, }; use bitflags::bitflags; #[cfg(feature = "no_std")] @@ -42,9 +42,6 @@ const SCOPE_SEARCH_BARRIER_MARKER: &str = "$ BARRIER $"; /// The message: `TokenStream` never ends const NEVER_ENDS: &str = "`Token`"; -/// Unroll `switch` ranges no larger than this. -const SMALL_SWITCH_RANGE: INT = 16; - /// _(internals)_ A type that encapsulates the current state of the parser. /// Exported under the `internals` feature only. pub struct ParseState<'e, 's> { @@ -1216,7 +1213,6 @@ impl Engine { let stmt_block: StmtBlock = stmt.into(); (Expr::Stmt(stmt_block.into()), need_comma) }; - let has_condition = !matches!(condition, Expr::BoolConstant(true, ..)); expressions.push((condition, action_expr).into()); let index = expressions.len() - 1; @@ -1240,23 +1236,9 @@ impl Engine { if let Some(mut r) = range_value { if !r.is_empty() { - // Do not unroll ranges if there are previous non-unrolled ranges - if !has_condition && ranges.is_empty() && r.len() <= SMALL_SWITCH_RANGE - { - // Unroll small range - r.into_iter().for_each(|n| { - let hasher = &mut get_hasher(); - Dynamic::from_int(n).hash(hasher); - cases - .entry(hasher.finish()) - .and_modify(|cases| cases.push(index)) - .or_insert_with(|| [index].into()); - }); - } else { - // Other range - r.set_index(index); - ranges.push(r); - } + // Other range + r.set_index(index); + ranges.push(r); } continue; } diff --git a/tests/expressions.rs b/tests/expressions.rs index 182f9fa9..e15008e8 100644 --- a/tests/expressions.rs +++ b/tests/expressions.rs @@ -50,8 +50,8 @@ fn test_expressions() -> Result<(), Box> { " switch x { 0 => 1, - 1..10 => 123, 10 => 42, + 1..10 => 123, } " )?, @@ -63,11 +63,11 @@ fn test_expressions() -> Result<(), Box> { " switch x { 0 => 1, + 10 => 42, 1..10 => { let y = 123; y } - 10 => 42, } " )