Add fallible functions support and replace most arithmetic operations with checked versions.

This commit is contained in:
Stephen Chung 2020-03-08 22:47:13 +08:00
parent 3e7adc2e51
commit b1b25d3043
9 changed files with 387 additions and 40 deletions

View File

@ -14,6 +14,9 @@ include = [
"Cargo.toml"
]
[dependencies]
num-traits = "*"
[features]
debug_msgs = []
no_stdlib = []

View File

@ -192,7 +192,9 @@ if z.type_of() == "string" {
Rhai's scripting engine is very lightweight. It gets its ability from the functions in your program. To call these functions, you need to register them with the scripting engine.
```rust
use rhai::{Dynamic, Engine, RegisterFn};
use rhai::Engine;
use rhai::RegisterFn; // include the `RegisterFn` trait to use `register_fn`
use rhai::{Dynamic, RegisterDynamicFn}; // include the `RegisterDynamicFn` trait to use `register_dynamic_fn`
// Normal function
fn add(x: i64, y: i64) -> i64 {
@ -234,7 +236,7 @@ fn decide(yes_no: bool) -> Dynamic {
}
```
# Working with generic functions
# Generic functions
Generic functions can be used in Rhai, but you'll need to register separate instances for each concrete type:
@ -258,7 +260,39 @@ fn main() {
You can also see in this example how you can register multiple functions (or in this case multiple instances of the same function) to the same name in script. This gives you a way to overload functions and call the correct one, based on the types of the arguments, from your script.
# Override built-in functions
# Fallible functions
If your function is _fallible_ (i.e. it returns a `Result<_, Error>`), you can register it with `register_result_fn` (using the `RegisterResultFn` trait).
Your function must return `Result<_, EvalAltResult>`. `EvalAltResult` implements `From<&str>` and `From<String>` etc. and the error text gets converted into `EvalAltResult::ErrorRuntime`.
```rust
use rhai::{Engine, EvalAltResult, Position};
use rhai::RegisterResultFn; // include the `RegisterResultFn` trait to use `register_result_fn`
// Function that may fail
fn safe_divide(x: i64, y: i64) -> Result<i64, EvalAltResult> {
if y == 0 {
// Return an error if y is zero
Err("Division by zero detected!".into()) // short-cut to create EvalAltResult
} else {
Ok(x / y)
}
}
fn main() {
let mut engine = Engine::new();
// Fallible functions that return Result values must use register_result_fn()
engine.register_result_fn("divide", safe_divide);
if let Err(error) = engine.eval::<i64>("divide(40, 0)") {
println!("Error: {:?}", error); // prints ErrorRuntime("Division by zero detected!", (1, 1)")
}
}
```
# Overriding built-in functions
Any similarly-named function defined in a script overrides any built-in function.

View File

@ -3,9 +3,15 @@
use crate::any::Any;
use crate::engine::{Array, Engine};
use crate::fn_register::RegisterFn;
use crate::fn_register::{RegisterFn, RegisterResultFn};
use crate::parser::Position;
use crate::result::EvalAltResult;
use num_traits::{
CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedShl, CheckedShr, CheckedSub,
};
use std::convert::TryFrom;
use std::fmt::{Debug, Display};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Range, Rem, Shl, Shr, Sub};
use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Range, Rem, Sub};
macro_rules! reg_op {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
@ -15,6 +21,22 @@ macro_rules! reg_op {
)
}
macro_rules! reg_op_result {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
$self.register_result_fn($x, $op as fn(x: $y, y: $y)->Result<$y,EvalAltResult>);
)*
)
}
macro_rules! reg_op_result1 {
($self:expr, $x:expr, $op:expr, $v:ty, $( $y:ty ),*) => (
$(
$self.register_result_fn($x, $op as fn(x: $y, y: $v)->Result<$y,EvalAltResult>);
)*
)
}
macro_rules! reg_un {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
@ -23,6 +45,13 @@ macro_rules! reg_un {
)
}
macro_rules! reg_un_result {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
$self.register_result_fn($x, $op as fn(x: $y)->Result<$y,EvalAltResult>);
)*
)
}
macro_rules! reg_cmp {
($self:expr, $x:expr, $op:expr, $( $y:ty ),*) => (
$(
@ -69,21 +98,96 @@ macro_rules! reg_func3 {
impl Engine<'_> {
/// Register the core built-in library.
pub(crate) fn register_core_lib(&mut self) {
fn add<T: Add>(x: T, y: T) -> <T as Add>::Output {
fn add<T: Display + CheckedAdd>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_add(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Addition overflow: {} + {}", x, y),
Position::none(),
)
})
}
fn sub<T: Display + CheckedSub>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_sub(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Subtraction underflow: {} - {}", x, y),
Position::none(),
)
})
}
fn mul<T: Display + CheckedMul>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_mul(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Multiplication overflow: {} * {}", x, y),
Position::none(),
)
})
}
fn div<T>(x: T, y: T) -> Result<T, EvalAltResult>
where
T: Display + CheckedDiv + PartialEq + TryFrom<i8>,
{
if y == <T as TryFrom<i8>>::try_from(0)
.map_err(|_| ())
.expect("zero should always succeed")
{
return Err(EvalAltResult::ErrorArithmetic(
format!("Division by zero: {} / {}", x, y),
Position::none(),
));
}
x.checked_div(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Division overflow: {} / {}", x, y),
Position::none(),
)
})
}
fn neg<T: Display + CheckedNeg>(x: T) -> Result<T, EvalAltResult> {
x.checked_neg().ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Negation overflow: -{}", x),
Position::none(),
)
})
}
fn abs<T: Display + CheckedNeg + PartialOrd + From<i8>>(x: T) -> Result<T, EvalAltResult> {
if x >= 0.into() {
Ok(x)
} else {
x.checked_neg().ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Negation overflow: -{}", x),
Position::none(),
)
})
}
}
fn add_unchecked<T: Add>(x: T, y: T) -> <T as Add>::Output {
x + y
}
fn sub<T: Sub>(x: T, y: T) -> <T as Sub>::Output {
fn sub_unchecked<T: Sub>(x: T, y: T) -> <T as Sub>::Output {
x - y
}
fn mul<T: Mul>(x: T, y: T) -> <T as Mul>::Output {
fn mul_unchecked<T: Mul>(x: T, y: T) -> <T as Mul>::Output {
x * y
}
fn div<T: Div>(x: T, y: T) -> <T as Div>::Output {
fn div_unchecked<T: Div>(x: T, y: T) -> <T as Div>::Output {
x / y
}
fn neg<T: Neg>(x: T) -> <T as Neg>::Output {
fn neg_unchecked<T: Neg>(x: T) -> <T as Neg>::Output {
-x
}
fn abs_unchecked<T: Neg + PartialOrd + From<i8>>(x: T) -> T
where
<T as Neg>::Output: Into<T>,
{
if x < 0.into() {
(-x).into()
} else {
x
}
}
fn lt<T: PartialOrd>(x: T, y: T) -> bool {
x < y
}
@ -120,13 +224,45 @@ impl Engine<'_> {
fn binary_xor<T: BitXor>(x: T, y: T) -> <T as BitXor>::Output {
x ^ y
}
fn left_shift<T: Shl<T>>(x: T, y: T) -> <T as Shl<T>>::Output {
x.shl(y)
fn left_shift<T: Display + CheckedShl>(x: T, y: i64) -> Result<T, EvalAltResult> {
if y < 0 {
return Err(EvalAltResult::ErrorArithmetic(
format!("Left-shift by a negative number: {} << {}", x, y),
Position::none(),
));
}
CheckedShl::checked_shl(&x, y as u32).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Left-shift overflow: {} << {}", x, y),
Position::none(),
)
})
}
fn right_shift<T: Shr<T>>(x: T, y: T) -> <T as Shr<T>>::Output {
x.shr(y)
fn right_shift<T: Display + CheckedShr>(x: T, y: i64) -> Result<T, EvalAltResult> {
if y < 0 {
return Err(EvalAltResult::ErrorArithmetic(
format!("Right-shift by a negative number: {} >> {}", x, y),
Position::none(),
));
}
CheckedShr::checked_shr(&x, y as u32).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Right-shift overflow: {} % {}", x, y),
Position::none(),
)
})
}
fn modulo<T: Rem<T>>(x: T, y: T) -> <T as Rem<T>>::Output {
fn modulo<T: Display + CheckedRem>(x: T, y: T) -> Result<T, EvalAltResult> {
x.checked_rem(&y).ok_or_else(|| {
EvalAltResult::ErrorArithmetic(
format!("Modulo division overflow: {} % {}", x, y),
Position::none(),
)
})
}
fn modulo_unchecked<T: Rem>(x: T, y: T) -> <T as Rem>::Output {
x % y
}
fn pow_i64_i64(x: i64, y: i64) -> i64 {
@ -139,10 +275,15 @@ impl Engine<'_> {
x.powi(y as i32)
}
reg_op!(self, "+", add, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op!(self, "-", sub, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op!(self, "*", mul, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op!(self, "/", div, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64);
reg_op_result!(self, "+", add, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result!(self, "-", sub, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result!(self, "*", mul, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result!(self, "/", div, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "+", add_unchecked, f32, f64);
reg_op!(self, "-", sub_unchecked, f32, f64);
reg_op!(self, "*", mul_unchecked, f32, f64);
reg_op!(self, "/", div_unchecked, f32, f64);
reg_cmp!(self, "<", lt, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64, String, char);
reg_cmp!(self, "<=", lte, i8, u8, i16, u16, i32, i64, u32, u64, f32, f64, String, char);
@ -162,15 +303,19 @@ impl Engine<'_> {
reg_op!(self, "&", binary_and, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "&", and, bool);
reg_op!(self, "^", binary_xor, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "<<", left_shift, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, ">>", right_shift, i8, u8, i16, u16);
reg_op!(self, ">>", right_shift, i32, i64, u32, u64);
reg_op!(self, "%", modulo, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result1!(self, "<<", left_shift, i64, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op_result1!(self, ">>", right_shift, i64, i8, u8, i16, u16);
reg_op_result1!(self, ">>", right_shift, i64, i32, i64, u32, u64);
reg_op_result!(self, "%", modulo, i8, u8, i16, u16, i32, i64, u32, u64);
reg_op!(self, "%", modulo_unchecked, f32, f64);
self.register_fn("~", pow_i64_i64);
self.register_fn("~", pow_f64_f64);
self.register_fn("~", pow_f64_i64);
reg_un!(self, "-", neg, i8, i16, i32, i64, f32, f64);
reg_un_result!(self, "-", neg, i8, i16, i32, i64);
reg_un!(self, "-", neg_unchecked, f32, f64);
reg_un_result!(self, "abs", abs, i8, i16, i32, i64);
reg_un!(self, "abs", abs_unchecked, f32, f64);
reg_un!(self, "!", not, bool);
self.register_fn("+", |x: String, y: String| x + &y); // String + String

View File

@ -143,6 +143,9 @@ impl fmt::Display for ParseError {
if !self.1.is_eof() {
write!(f, " ({})", self.1)
} else if !self.1.is_none() {
// Do not write any position if None
Ok(())
} else {
write!(f, " at the end of the script but there is no more input")
}

View File

@ -58,6 +58,32 @@ pub trait RegisterDynamicFn<FN, ARGS> {
fn register_dynamic_fn(&mut self, name: &str, f: FN);
}
/// A trait to register fallible custom functions returning Result<_, EvalAltResult> with the `Engine`.
///
/// # Example
///
/// ```rust
/// use rhai::{Engine, RegisterFn};
///
/// // Normal function
/// fn add(x: i64, y: i64) -> i64 {
/// x + y
/// }
///
/// let mut engine = Engine::new();
///
/// // You must use the trait rhai::RegisterFn to get this method.
/// engine.register_fn("add", add);
///
/// if let Ok(result) = engine.eval::<i64>("add(40, 2)") {
/// println!("Answer: {}", result); // prints 42
/// }
/// ```
pub trait RegisterResultFn<FN, ARGS, RET> {
/// Register a custom function with the `Engine`.
fn register_result_fn(&mut self, name: &str, f: FN);
}
pub struct Ref<A>(A);
pub struct Mut<A>(A);
@ -91,7 +117,7 @@ macro_rules! def_register {
let mut drain = args.drain(..);
$(
// Downcast every element, return in case of a type mismatch
let $par = (drain.next().unwrap().downcast_mut() as Option<&mut $par>).unwrap();
let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap();
)*
// Call the user-supplied function using ($clone) to
@ -123,7 +149,7 @@ macro_rules! def_register {
let mut drain = args.drain(..);
$(
// Downcast every element, return in case of a type mismatch
let $par = (drain.next().unwrap().downcast_mut() as Option<&mut $par>).unwrap();
let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap();
)*
// Call the user-supplied function using ($clone) to
@ -135,6 +161,44 @@ macro_rules! def_register {
}
}
impl<
$($par: Any + Clone,)*
FN: Fn($($param),*) -> Result<RET, EvalAltResult> + 'static,
RET: Any
> RegisterResultFn<FN, ($($mark,)*), RET> for Engine<'_>
{
fn register_result_fn(&mut self, name: &str, f: FN) {
let fn_name = name.to_string();
let fun = move |mut args: FnCallArgs, pos: Position| {
// Check for length at the beginning to avoid per-element bound checks.
const NUM_ARGS: usize = count_args!($($par)*);
if args.len() != NUM_ARGS {
Err(EvalAltResult::ErrorFunctionArgsMismatch(fn_name.clone(), NUM_ARGS, args.len(), pos))
} else {
#[allow(unused_variables, unused_mut)]
let mut drain = args.drain(..);
$(
// Downcast every element, return in case of a type mismatch
let $par = drain.next().unwrap().downcast_mut::<$par>().unwrap();
)*
// Call the user-supplied function using ($clone) to
// potentially clone the value, otherwise pass the reference.
match f($(($clone)($par)),*) {
Ok(r) => Ok(Box::new(r) as Dynamic),
Err(mut err) => {
err.set_position(pos);
Err(err)
}
}
}
};
self.register_fn_raw(name, Some(vec![$(TypeId::of::<$par>()),*]), Box::new(fun));
}
}
//def_register!(imp_pop $($par => $mark => $param),*);
};
($p0:ident $(, $p:ident)*) => {

View File

@ -69,7 +69,7 @@ pub use any::{Any, AnyExt, Dynamic, Variant};
pub use call::FuncArgs;
pub use engine::{Array, Engine};
pub use error::{ParseError, ParseErrorType};
pub use fn_register::{RegisterDynamicFn, RegisterFn};
pub use fn_register::{RegisterDynamicFn, RegisterFn, RegisterResultFn};
pub use parser::{Position, AST};
pub use result::EvalAltResult;
pub use scope::Scope;

View File

@ -2,7 +2,7 @@
use crate::any::Dynamic;
use crate::error::{LexError, ParseError, ParseErrorType};
use std::{borrow::Cow, char, fmt, iter::Peekable, str::Chars};
use std::{borrow::Cow, char, fmt, iter::Peekable, str::Chars, usize};
type LERR = LexError;
type PERR = ParseErrorType;
@ -17,25 +17,33 @@ pub struct Position {
impl Position {
/// Create a new `Position`.
pub fn new(line: usize, position: usize) -> Self {
if line == 0 || (line == usize::MAX && position == usize::MAX) {
panic!("invalid position: ({}, {})", line, position);
}
Self {
line,
pos: position,
}
}
/// Get the line number (1-based), or `None` if EOF.
/// Get the line number (1-based), or `None` if no position or EOF.
pub fn line(&self) -> Option<usize> {
match self.line {
0 => None,
x => Some(x),
if self.is_none() || self.is_eof() {
None
} else {
Some(self.line)
}
}
/// Get the character position (1-based), or `None` if at beginning of a line.
pub fn position(&self) -> Option<usize> {
match self.pos {
0 => None,
x => Some(x),
if self.is_none() || self.is_eof() {
None
} else if self.pos == 0 {
None
} else {
Some(self.pos)
}
}
@ -61,14 +69,27 @@ impl Position {
self.pos = 0;
}
/// Create a `Position` representing no position.
pub(crate) fn none() -> Self {
Self { line: 0, pos: 0 }
}
/// Create a `Position` at EOF.
pub(crate) fn eof() -> Self {
Self { line: 0, pos: 0 }
Self {
line: usize::MAX,
pos: usize::MAX,
}
}
/// Is there no `Position`?
pub fn is_none(&self) -> bool {
self.line == 0 && self.pos == 0
}
/// Is the `Position` at EOF?
pub fn is_eof(&self) -> bool {
self.line == 0
self.line == usize::MAX && self.pos == usize::MAX
}
}
@ -82,6 +103,8 @@ impl fmt::Display for Position {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if self.is_eof() {
write!(f, "EOF")
} else if self.is_none() {
write!(f, "none")
} else {
write!(f, "line {}, position {}", self.line, self.pos)
}

View File

@ -118,9 +118,10 @@ impl fmt::Display for EvalAltResult {
Self::ErrorMismatchOutputType(s, pos) => write!(f, "{}: {} ({})", desc, s, pos),
Self::ErrorDotExpr(s, pos) if !s.is_empty() => write!(f, "{} {} ({})", desc, s, pos),
Self::ErrorDotExpr(_, pos) => write!(f, "{} ({})", desc, pos),
Self::ErrorArithmetic(s, pos) => write!(f, "{}: {} ({})", desc, s, pos),
Self::ErrorRuntime(s, pos) if s.is_empty() => write!(f, "{} ({})", desc, pos),
Self::ErrorRuntime(s, pos) => write!(f, "{}: {} ({})", desc, s, pos),
Self::ErrorArithmetic(s, pos) => write!(f, "{} ({})", s, pos),
Self::ErrorRuntime(s, pos) => {
write!(f, "{} ({})", if s.is_empty() { desc } else { s }, pos)
}
Self::LoopBreak => write!(f, "{}", desc),
Self::Return(_, pos) => write!(f, "{} ({})", desc, pos),
Self::ErrorReadingScriptFile(filename, err) => {
@ -171,3 +172,37 @@ impl From<ParseError> for EvalAltResult {
Self::ErrorParsing(err)
}
}
impl EvalAltResult {
pub(crate) fn set_position(&mut self, new_position: Position) {
match self {
EvalAltResult::ErrorReadingScriptFile(_, _)
| EvalAltResult::LoopBreak
| EvalAltResult::ErrorParsing(_) => (),
EvalAltResult::ErrorFunctionNotFound(_, ref mut pos)
| EvalAltResult::ErrorFunctionArgsMismatch(_, _, _, ref mut pos)
| EvalAltResult::ErrorBooleanArgMismatch(_, ref mut pos)
| EvalAltResult::ErrorCharMismatch(ref mut pos)
| EvalAltResult::ErrorArrayBounds(_, _, ref mut pos)
| EvalAltResult::ErrorStringBounds(_, _, ref mut pos)
| EvalAltResult::ErrorIndexingType(_, ref mut pos)
| EvalAltResult::ErrorIndexExpr(ref mut pos)
| EvalAltResult::ErrorIfGuard(ref mut pos)
| EvalAltResult::ErrorFor(ref mut pos)
| EvalAltResult::ErrorVariableNotFound(_, ref mut pos)
| EvalAltResult::ErrorAssignmentToUnknownLHS(ref mut pos)
| EvalAltResult::ErrorMismatchOutputType(_, ref mut pos)
| EvalAltResult::ErrorDotExpr(_, ref mut pos)
| EvalAltResult::ErrorArithmetic(_, ref mut pos)
| EvalAltResult::ErrorRuntime(_, ref mut pos)
| EvalAltResult::Return(_, ref mut pos) => *pos = new_position,
}
}
}
impl<T: AsRef<str>> From<T> for EvalAltResult {
fn from(err: T) -> Self {
Self::ErrorRuntime(err.as_ref().to_string(), Position::none())
}
}

40
tests/math.rs Normal file
View File

@ -0,0 +1,40 @@
use rhai::{Engine, EvalAltResult};
#[test]
fn test_math() -> Result<(), EvalAltResult> {
let mut engine = Engine::new();
assert_eq!(engine.eval::<i64>("1 + 2")?, 3);
assert_eq!(engine.eval::<i64>("1 - 2")?, -1);
assert_eq!(engine.eval::<i64>("2 * 3")?, 6);
assert_eq!(engine.eval::<i64>("1 / 2")?, 0);
assert_eq!(engine.eval::<i64>("3 % 2")?, 1);
assert_eq!(
engine.eval::<i64>("(-9223372036854775807).abs()")?,
9223372036854775807
);
// Overflow/underflow/division-by-zero errors
match engine.eval::<i64>("9223372036854775807 + 1") {
Err(EvalAltResult::ErrorArithmetic(_, _)) => (),
r => panic!("should return overflow error: {:?}", r),
}
match engine.eval::<i64>("(-9223372036854775807) - 2") {
Err(EvalAltResult::ErrorArithmetic(_, _)) => (),
r => panic!("should return underflow error: {:?}", r),
}
match engine.eval::<i64>("9223372036854775807 * 9223372036854775807") {
Err(EvalAltResult::ErrorArithmetic(_, _)) => (),
r => panic!("should return overflow error: {:?}", r),
}
match engine.eval::<i64>("9223372036854775807 / 0") {
Err(EvalAltResult::ErrorArithmetic(_, _)) => (),
r => panic!("should return division by zero error: {:?}", r),
}
match engine.eval::<i64>("9223372036854775807 % 0") {
Err(EvalAltResult::ErrorArithmetic(_, _)) => (),
r => panic!("should return division by zero error: {:?}", r),
}
Ok(())
}