From b1b25d3043e194ca6e74a40ba6c20aabe864d7a4 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Sun, 8 Mar 2020 22:47:13 +0800 Subject: [PATCH] Add fallible functions support and replace most arithmetic operations with checked versions. --- Cargo.toml | 3 + README.md | 40 +++++++++- src/builtin.rs | 187 ++++++++++++++++++++++++++++++++++++++++----- src/error.rs | 3 + src/fn_register.rs | 68 ++++++++++++++++- src/lib.rs | 2 +- src/parser.rs | 43 ++++++++--- src/result.rs | 41 +++++++++- tests/math.rs | 40 ++++++++++ 9 files changed, 387 insertions(+), 40 deletions(-) create mode 100644 tests/math.rs diff --git a/Cargo.toml b/Cargo.toml index f586d1f3..e39a0115 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,9 @@ include = [ "Cargo.toml" ] +[dependencies] +num-traits = "*" + [features] debug_msgs = [] no_stdlib = [] \ No newline at end of file diff --git a/README.md b/README.md index 1c7a3223..9eaacd38 100644 --- a/README.md +++ b/README.md @@ -192,7 +192,9 @@ if z.type_of() == "string" { Rhai's scripting engine is very lightweight. It gets its ability from the functions in your program. To call these functions, you need to register them with the scripting engine. ```rust -use rhai::{Dynamic, Engine, RegisterFn}; +use rhai::Engine; +use rhai::RegisterFn; // include the `RegisterFn` trait to use `register_fn` +use rhai::{Dynamic, RegisterDynamicFn}; // include the `RegisterDynamicFn` trait to use `register_dynamic_fn` // Normal function fn add(x: i64, y: i64) -> i64 { @@ -234,7 +236,7 @@ fn decide(yes_no: bool) -> Dynamic { } ``` -# Working with generic functions +# Generic functions Generic functions can be used in Rhai, but you'll need to register separate instances for each concrete type: @@ -258,7 +260,39 @@ fn main() { You can also see in this example how you can register multiple functions (or in this case multiple instances of the same function) to the same name in script. This gives you a way to overload functions and call the correct one, based on the types of the arguments, from your script. -# Override built-in functions +# Fallible functions + +If your function is _fallible_ (i.e. it returns a `Result<_, Error>`), you can register it with `register_result_fn` (using the `RegisterResultFn` trait). + +Your function must return `Result<_, EvalAltResult>`. `EvalAltResult` implements `From<&str>` and `From` etc. and the error text gets converted into `EvalAltResult::ErrorRuntime`. + +```rust +use rhai::{Engine, EvalAltResult, Position}; +use rhai::RegisterResultFn; // include the `RegisterResultFn` trait to use `register_result_fn` + +// Function that may fail +fn safe_divide(x: i64, y: i64) -> Result { + if y == 0 { + // Return an error if y is zero + Err("Division by zero detected!".into()) // short-cut to create EvalAltResult + } else { + Ok(x / y) + } +} + +fn main() { + let mut engine = Engine::new(); + + // Fallible functions that return Result values must use register_result_fn() + engine.register_result_fn("divide", safe_divide); + + if let Err(error) = engine.eval::("divide(40, 0)") { + println!("Error: {:?}", error); // prints ErrorRuntime("Division by zero detected!", (1, 1)") + } +} +``` + +# Overriding built-in functions Any similarly-named function defined in a script overrides any built-in function. diff --git a/src/builtin.rs b/src/builtin.rs index 66789ef3..5a24db0a 100644 --- a/src/builtin.rs +++ b/src/builtin.rs @@ -3,9 +3,15 @@ use crate::any::Any; use crate::engine::{Array, Engine}; -use crate::fn_register::RegisterFn; +use crate::fn_register::{RegisterFn, RegisterResultFn}; +use crate::parser::Position; +use crate::result::EvalAltResult; +use num_traits::{ + CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedShl, CheckedShr, CheckedSub, +}; +use std::convert::TryFrom; use std::fmt::{Debug, Display}; -use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Range, Rem, Shl, Shr, Sub}; +use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Range, Rem, Sub}; macro_rules! reg_op { ($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => ( @@ -15,6 +21,22 @@ macro_rules! reg_op { ) } +macro_rules! reg_op_result { + ($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => ( + $( + $self.register_result_fn($x, $op as fn(x: $y, y: $y)->Result<$y,EvalAltResult>); + )* + ) +} + +macro_rules! reg_op_result1 { + ($self:expr, $x:expr, $op:expr, $v:ty, $( $y:ty ),*) => ( + $( + $self.register_result_fn($x, $op as fn(x: $y, y: $v)->Result<$y,EvalAltResult>); + )* + ) +} + macro_rules! reg_un { ($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => ( $( @@ -23,6 +45,13 @@ macro_rules! reg_un { ) } +macro_rules! reg_un_result { + ($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => ( + $( + $self.register_result_fn($x, $op as fn(x: $y)->Result<$y,EvalAltResult>); + )* + ) +} macro_rules! reg_cmp { ($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => ( $( @@ -69,21 +98,96 @@ macro_rules! reg_func3 { impl Engine<'_> { /// Register the core built-in library. pub(crate) fn register_core_lib(&mut self) { - fn add(x: T, y: T) -> ::Output { + fn add(x: T, y: T) -> Result { + x.checked_add(&y).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Addition overflow: {} + {}", x, y), + Position::none(), + ) + }) + } + fn sub(x: T, y: T) -> Result { + x.checked_sub(&y).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Subtraction underflow: {} - {}", x, y), + Position::none(), + ) + }) + } + fn mul(x: T, y: T) -> Result { + x.checked_mul(&y).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Multiplication overflow: {} * {}", x, y), + Position::none(), + ) + }) + } + fn div(x: T, y: T) -> Result + where + T: Display + CheckedDiv + PartialEq + TryFrom, + { + if y == >::try_from(0) + .map_err(|_| ()) + .expect("zero should always succeed") + { + return Err(EvalAltResult::ErrorArithmetic( + format!("Division by zero: {} / {}", x, y), + Position::none(), + )); + } + + x.checked_div(&y).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Division overflow: {} / {}", x, y), + Position::none(), + ) + }) + } + fn neg(x: T) -> Result { + x.checked_neg().ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Negation overflow: -{}", x), + Position::none(), + ) + }) + } + fn abs>(x: T) -> Result { + if x >= 0.into() { + Ok(x) + } else { + x.checked_neg().ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Negation overflow: -{}", x), + Position::none(), + ) + }) + } + } + fn add_unchecked(x: T, y: T) -> ::Output { x + y } - fn sub(x: T, y: T) -> ::Output { + fn sub_unchecked(x: T, y: T) -> ::Output { x - y } - fn mul(x: T, y: T) -> ::Output { + fn mul_unchecked(x: T, y: T) -> ::Output { x * y } - fn div(x: T, y: T) -> ::Output { + fn div_unchecked(x: T, y: T) -> ::Output { x / y } - fn neg(x: T) -> ::Output { + fn neg_unchecked(x: T) -> ::Output { -x } + fn abs_unchecked>(x: T) -> T + where + ::Output: Into, + { + if x < 0.into() { + (-x).into() + } else { + x + } + } fn lt(x: T, y: T) -> bool { x < y } @@ -120,13 +224,45 @@ impl Engine<'_> { fn binary_xor(x: T, y: T) -> ::Output { x ^ y } - fn left_shift>(x: T, y: T) -> >::Output { - x.shl(y) + fn left_shift(x: T, y: i64) -> Result { + if y < 0 { + return Err(EvalAltResult::ErrorArithmetic( + format!("Left-shift by a negative number: {} << {}", x, y), + Position::none(), + )); + } + + CheckedShl::checked_shl(&x, y as u32).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Left-shift overflow: {} << {}", x, y), + Position::none(), + ) + }) } - fn right_shift>(x: T, y: T) -> >::Output { - x.shr(y) + fn right_shift(x: T, y: i64) -> Result { + if y < 0 { + return Err(EvalAltResult::ErrorArithmetic( + format!("Right-shift by a negative number: {} >> {}", x, y), + Position::none(), + )); + } + + CheckedShr::checked_shr(&x, y as u32).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Right-shift overflow: {} % {}", x, y), + Position::none(), + ) + }) } - fn modulo>(x: T, y: T) -> >::Output { + fn modulo(x: T, y: T) -> Result { + x.checked_rem(&y).ok_or_else(|| { + EvalAltResult::ErrorArithmetic( + format!("Modulo division overflow: {} % {}", x, y), + Position::none(), + ) + }) + } + fn modulo_unchecked(x: T, y: T) -> ::Output { x % y } fn pow_i64_i64(x: i64, y: i64) -> i64 { @@ -139,10 +275,15 @@ impl Engine<'_> { x.powi(y as i32) } - reg_op!(self, "+", add, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64); - reg_op!(self, "-", sub, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64); - reg_op!(self, "*", mul, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64); - reg_op!(self, "/", div, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64); + reg_op_result!(self, "+", add, i8, u8, i16, u16, i32, i64, u32, u64); + reg_op_result!(self, "-", sub, i8, u8, i16, u16, i32, i64, u32, u64); + reg_op_result!(self, "*", mul, i8, u8, i16, u16, i32, i64, u32, u64); + reg_op_result!(self, "/", div, i8, u8, i16, u16, i32, i64, u32, u64); + + reg_op!(self, "+", add_unchecked, f32, f64); + reg_op!(self, "-", sub_unchecked, f32, f64); + reg_op!(self, "*", mul_unchecked, f32, f64); + reg_op!(self, "/", div_unchecked, f32, f64); reg_cmp!(self, "<", lt, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64, String, char); reg_cmp!(self, "<=", lte, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64, String, char); @@ -162,15 +303,19 @@ impl Engine<'_> { reg_op!(self, "&", binary_and, i8, u8, i16, u16, i32, i64, u32, u64); reg_op!(self, "&", and, bool); reg_op!(self, "^", binary_xor, i8, u8, i16, u16, i32, i64, u32, u64); - reg_op!(self, "<<", left_shift, i8, u8, i16, u16, i32, i64, u32, u64); - reg_op!(self, ">>", right_shift, i8, u8, i16, u16); - reg_op!(self, ">>", right_shift, i32, i64, u32, u64); - reg_op!(self, "%", modulo, i8, u8, i16, u16, i32, i64, u32, u64); + reg_op_result1!(self, "<<", left_shift, i64, i8, u8, i16, u16, i32, i64, u32, u64); + reg_op_result1!(self, ">>", right_shift, i64, i8, u8, i16, u16); + reg_op_result1!(self, ">>", right_shift, i64, i32, i64, u32, u64); + reg_op_result!(self, "%", modulo, i8, u8, i16, u16, i32, i64, u32, u64); + reg_op!(self, "%", modulo_unchecked, f32, f64); self.register_fn("~", pow_i64_i64); self.register_fn("~", pow_f64_f64); self.register_fn("~", pow_f64_i64); - reg_un!(self, "-", neg, i8, i16, i32, i64, f32, f64); + reg_un_result!(self, "-", neg, i8, i16, i32, i64); + reg_un!(self, "-", neg_unchecked, f32, f64); + reg_un_result!(self, "abs", abs, i8, i16, i32, i64); + reg_un!(self, "abs", abs_unchecked, f32, f64); reg_un!(self, "!", not, bool); self.register_fn("+", |x: String, y: String| x + &y); // String + String diff --git a/src/error.rs b/src/error.rs index 41a793c3..b631c4cf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -143,6 +143,9 @@ impl fmt::Display for ParseError { if !self.1.is_eof() { write!(f, " ({})", self.1) + } else if !self.1.is_none() { + // Do not write any position if None + Ok(()) } else { write!(f, " at the end of the script but there is no more input") } diff --git a/src/fn_register.rs b/src/fn_register.rs index 535ffc83..a213474a 100644 --- a/src/fn_register.rs +++ b/src/fn_register.rs @@ -58,6 +58,32 @@ pub trait RegisterDynamicFn { fn register_dynamic_fn(&mut self, name: &str, f: FN); } +/// A trait to register fallible custom functions returning Result<_, EvalAltResult> with the `Engine`. +/// +/// # Example +/// +/// ```rust +/// use rhai::{Engine, RegisterFn}; +/// +/// // Normal function +/// fn add(x: i64, y: i64) -> i64 { +/// x + y +/// } +/// +/// let mut engine = Engine::new(); +/// +/// // You must use the trait rhai::RegisterFn to get this method. +/// engine.register_fn("add", add); +/// +/// if let Ok(result) = engine.eval::("add(40, 2)") { +/// println!("Answer: {}", result); // prints 42 +/// } +/// ``` +pub trait RegisterResultFn { + /// Register a custom function with the `Engine`. + fn register_result_fn(&mut self, name: &str, f: FN); +} + pub struct Ref(A); pub struct Mut(A); @@ -91,7 +117,7 @@ macro_rules! def_register { let mut drain = args.drain(..); $( // Downcast every element, return in case of a type mismatch - let $par = (drain.next().unwrap().downcast_mut() as Option<&mut $par>).unwrap(); + let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap(); )* // Call the user-supplied function using ($clone) to @@ -123,7 +149,7 @@ macro_rules! def_register { let mut drain = args.drain(..); $( // Downcast every element, return in case of a type mismatch - let $par = (drain.next().unwrap().downcast_mut() as Option<&mut $par>).unwrap(); + let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap(); )* // Call the user-supplied function using ($clone) to @@ -135,6 +161,44 @@ macro_rules! def_register { } } + impl< + $($par: Any + Clone,)* + FN: Fn($($param),*) -> Result + 'static, + RET: Any + > RegisterResultFn for Engine<'_> + { + fn register_result_fn(&mut self, name: &str, f: FN) { + let fn_name = name.to_string(); + + let fun = move |mut args: FnCallArgs, pos: Position| { + // Check for length at the beginning to avoid per-element bound checks. + const NUM_ARGS: usize = count_args!($($par)*); + + if args.len() != NUM_ARGS { + Err(EvalAltResult::ErrorFunctionArgsMismatch(fn_name.clone(), NUM_ARGS, args.len(), pos)) + } else { + #[allow(unused_variables, unused_mut)] + let mut drain = args.drain(..); + $( + // Downcast every element, return in case of a type mismatch + let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap(); + )* + + // Call the user-supplied function using ($clone) to + // potentially clone the value, otherwise pass the reference. + match f($(($clone)($par)),*) { + Ok(r) => Ok(Box::new(r) as Dynamic), + Err(mut err) => { + err.set_position(pos); + Err(err) + } + } + } + }; + self.register_fn_raw(name, Some(vec![$(TypeId::of::<$par>()),*]), Box::new(fun)); + } + } + //def_register!(imp_pop $($par => $mark => $param),*); }; ($p0:ident $(, $p:ident)*) => { diff --git a/src/lib.rs b/src/lib.rs index 5c947e95..541963c1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -69,7 +69,7 @@ pub use any::{Any, AnyExt, Dynamic, Variant}; pub use call::FuncArgs; pub use engine::{Array, Engine}; pub use error::{ParseError, ParseErrorType}; -pub use fn_register::{RegisterDynamicFn, RegisterFn}; +pub use fn_register::{RegisterDynamicFn, RegisterFn, RegisterResultFn}; pub use parser::{Position, AST}; pub use result::EvalAltResult; pub use scope::Scope; diff --git a/src/parser.rs b/src/parser.rs index ad714187..13afb96b 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2,7 +2,7 @@ use crate::any::Dynamic; use crate::error::{LexError, ParseError, ParseErrorType}; -use std::{borrow::Cow, char, fmt, iter::Peekable, str::Chars}; +use std::{borrow::Cow, char, fmt, iter::Peekable, str::Chars, usize}; type LERR = LexError; type PERR = ParseErrorType; @@ -17,25 +17,33 @@ pub struct Position { impl Position { /// Create a new `Position`. pub fn new(line: usize, position: usize) -> Self { + if line == 0 || (line == usize::MAX && position == usize::MAX) { + panic!("invalid position: ({}, {})", line, position); + } + Self { line, pos: position, } } - /// Get the line number (1-based), or `None` if EOF. + /// Get the line number (1-based), or `None` if no position or EOF. pub fn line(&self) -> Option { - match self.line { - 0 => None, - x => Some(x), + if self.is_none() || self.is_eof() { + None + } else { + Some(self.line) } } /// Get the character position (1-based), or `None` if at beginning of a line. pub fn position(&self) -> Option { - match self.pos { - 0 => None, - x => Some(x), + if self.is_none() || self.is_eof() { + None + } else if self.pos == 0 { + None + } else { + Some(self.pos) } } @@ -61,14 +69,27 @@ impl Position { self.pos = 0; } + /// Create a `Position` representing no position. + pub(crate) fn none() -> Self { + Self { line: 0, pos: 0 } + } + /// Create a `Position` at EOF. pub(crate) fn eof() -> Self { - Self { line: 0, pos: 0 } + Self { + line: usize::MAX, + pos: usize::MAX, + } + } + + /// Is there no `Position`? + pub fn is_none(&self) -> bool { + self.line == 0 && self.pos == 0 } /// Is the `Position` at EOF? pub fn is_eof(&self) -> bool { - self.line == 0 + self.line == usize::MAX && self.pos == usize::MAX } } @@ -82,6 +103,8 @@ impl fmt::Display for Position { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if self.is_eof() { write!(f, "EOF") + } else if self.is_none() { + write!(f, "none") } else { write!(f, "line {}, position {}", self.line, self.pos) } diff --git a/src/result.rs b/src/result.rs index 26272030..95b5ee12 100644 --- a/src/result.rs +++ b/src/result.rs @@ -118,9 +118,10 @@ impl fmt::Display for EvalAltResult { Self::ErrorMismatchOutputType(s, pos) => write!(f, "{}: {} ({})", desc, s, pos), Self::ErrorDotExpr(s, pos) if !s.is_empty() => write!(f, "{} {} ({})", desc, s, pos), Self::ErrorDotExpr(_, pos) => write!(f, "{} ({})", desc, pos), - Self::ErrorArithmetic(s, pos) => write!(f, "{}: {} ({})", desc, s, pos), - Self::ErrorRuntime(s, pos) if s.is_empty() => write!(f, "{} ({})", desc, pos), - Self::ErrorRuntime(s, pos) => write!(f, "{}: {} ({})", desc, s, pos), + Self::ErrorArithmetic(s, pos) => write!(f, "{} ({})", s, pos), + Self::ErrorRuntime(s, pos) => { + write!(f, "{} ({})", if s.is_empty() { desc } else { s }, pos) + } Self::LoopBreak => write!(f, "{}", desc), Self::Return(_, pos) => write!(f, "{} ({})", desc, pos), Self::ErrorReadingScriptFile(filename, err) => { @@ -171,3 +172,37 @@ impl From for EvalAltResult { Self::ErrorParsing(err) } } + +impl EvalAltResult { + pub(crate) fn set_position(&mut self, new_position: Position) { + match self { + EvalAltResult::ErrorReadingScriptFile(_, _) + | EvalAltResult::LoopBreak + | EvalAltResult::ErrorParsing(_) => (), + + EvalAltResult::ErrorFunctionNotFound(_, ref mut pos) + | EvalAltResult::ErrorFunctionArgsMismatch(_, _, _, ref mut pos) + | EvalAltResult::ErrorBooleanArgMismatch(_, ref mut pos) + | EvalAltResult::ErrorCharMismatch(ref mut pos) + | EvalAltResult::ErrorArrayBounds(_, _, ref mut pos) + | EvalAltResult::ErrorStringBounds(_, _, ref mut pos) + | EvalAltResult::ErrorIndexingType(_, ref mut pos) + | EvalAltResult::ErrorIndexExpr(ref mut pos) + | EvalAltResult::ErrorIfGuard(ref mut pos) + | EvalAltResult::ErrorFor(ref mut pos) + | EvalAltResult::ErrorVariableNotFound(_, ref mut pos) + | EvalAltResult::ErrorAssignmentToUnknownLHS(ref mut pos) + | EvalAltResult::ErrorMismatchOutputType(_, ref mut pos) + | EvalAltResult::ErrorDotExpr(_, ref mut pos) + | EvalAltResult::ErrorArithmetic(_, ref mut pos) + | EvalAltResult::ErrorRuntime(_, ref mut pos) + | EvalAltResult::Return(_, ref mut pos) => *pos = new_position, + } + } +} + +impl> From for EvalAltResult { + fn from(err: T) -> Self { + Self::ErrorRuntime(err.as_ref().to_string(), Position::none()) + } +} diff --git a/tests/math.rs b/tests/math.rs new file mode 100644 index 00000000..bc79cbb0 --- /dev/null +++ b/tests/math.rs @@ -0,0 +1,40 @@ +use rhai::{Engine, EvalAltResult}; + +#[test] +fn test_math() -> Result<(), EvalAltResult> { + let mut engine = Engine::new(); + + assert_eq!(engine.eval::("1 + 2")?, 3); + assert_eq!(engine.eval::("1 - 2")?, -1); + assert_eq!(engine.eval::("2 * 3")?, 6); + assert_eq!(engine.eval::("1 / 2")?, 0); + assert_eq!(engine.eval::("3 % 2")?, 1); + assert_eq!( + engine.eval::("(-9223372036854775807).abs()")?, + 9223372036854775807 + ); + + // Overflow/underflow/division-by-zero errors + match engine.eval::("9223372036854775807 + 1") { + Err(EvalAltResult::ErrorArithmetic(_, _)) => (), + r => panic!("should return overflow error: {:?}", r), + } + match engine.eval::("(-9223372036854775807) - 2") { + Err(EvalAltResult::ErrorArithmetic(_, _)) => (), + r => panic!("should return underflow error: {:?}", r), + } + match engine.eval::("9223372036854775807 * 9223372036854775807") { + Err(EvalAltResult::ErrorArithmetic(_, _)) => (), + r => panic!("should return overflow error: {:?}", r), + } + match engine.eval::("9223372036854775807 / 0") { + Err(EvalAltResult::ErrorArithmetic(_, _)) => (), + r => panic!("should return division by zero error: {:?}", r), + } + match engine.eval::("9223372036854775807 % 0") { + Err(EvalAltResult::ErrorArithmetic(_, _)) => (), + r => panic!("should return division by zero error: {:?}", r), + } + + Ok(()) +}