From 22a505b57b358c37e2d0234030d6b5b378600ce0 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Mon, 2 Mar 2020 12:08:03 +0800 Subject: [PATCH] Short-curcuit boolean operators. --- README.md | 16 ++++++++- src/builtin.rs | 4 +-- src/engine.rs | 89 +++++++++++++++++++++++++++++++++------------- src/fn_register.rs | 22 +++++++----- src/parser.rs | 8 +++-- tests/bool_op.rs | 88 ++++++++++++++++++++++++++++++++++++++++++++- 6 files changed, 187 insertions(+), 40 deletions(-) diff --git a/README.md b/README.md index 4096dc82..cced83b6 100644 --- a/README.md +++ b/README.md @@ -387,12 +387,26 @@ fn main() { let x = 3; ``` -## Operators +## Numeric operators ```rust let x = (1 + 2) * (6 - 4) / 2; ``` +## Boolean operators + +Double boolean operators `&&` and `||` _short-circuit_, meaning that the second operand will not be evaluated if the first one already proves the condition wrong. + +Single boolean operators `&` and `|` always evaluate both operands. + +```rust +this() || that(); // that() is not evaluated if this() is true +this() && that(); // that() is not evaluated if this() is false + +this() | that(); // both this() and that() are evaluated +this() & that(); // both this() and that() are evaluated +``` + ## If ```rust diff --git a/src/builtin.rs b/src/builtin.rs index 0a6f0cfd..7b928754 100644 --- a/src/builtin.rs +++ b/src/builtin.rs @@ -149,8 +149,8 @@ impl Engine { reg_cmp!(self, "==", eq, i32, i64, u32, u64, bool, String, char, f32, f64); reg_cmp!(self, "!=", ne, i32, i64, u32, u64, bool, String, char, f32, f64); - reg_op!(self, "||", or, bool); - reg_op!(self, "&&", and, bool); + //reg_op!(self, "||", or, bool); + //reg_op!(self, "&&", and, bool); reg_op!(self, "|", binary_or, i32, i64, u32, u64); reg_op!(self, "|", or, bool); reg_op!(self, "&", binary_and, i32, i64, u32, u64); diff --git a/src/engine.rs b/src/engine.rs index 11229517..87132e4f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -14,9 +14,10 @@ pub type FnCallArgs<'a> = Vec<&'a mut Variant>; #[derive(Debug, Clone)] pub enum EvalAltResult { - ErrorParseError(ParseError), + ErrorParsing(ParseError), ErrorFunctionNotFound(String), - ErrorFunctionArgMismatch, + ErrorFunctionArgsMismatch(String, usize), + ErrorBooleanArgMismatch(String), ErrorArrayBounds(usize, i64), ErrorStringBounds(usize, i64), ErrorIndexing, @@ -28,6 +29,7 @@ pub enum EvalAltResult { ErrorMismatchOutputType(String), ErrorCantOpenScriptFile(String), ErrorDotExpr, + ErrorArithmetic(String), LoopBreak, Return(Dynamic), } @@ -35,10 +37,12 @@ pub enum EvalAltResult { impl EvalAltResult { fn as_str(&self) -> Option<&str> { Some(match self { - EvalAltResult::ErrorCantOpenScriptFile(ref s) - | EvalAltResult::ErrorVariableNotFound(ref s) - | EvalAltResult::ErrorFunctionNotFound(ref s) - | EvalAltResult::ErrorMismatchOutputType(ref s) => s, + Self::ErrorCantOpenScriptFile(s) + | Self::ErrorVariableNotFound(s) + | Self::ErrorFunctionNotFound(s) + | Self::ErrorMismatchOutputType(s) + | Self::ErrorCantOpenScriptFile(s) + | Self::ErrorArithmetic(s) => s, _ => return None, }) } @@ -49,9 +53,12 @@ impl PartialEq for EvalAltResult { use EvalAltResult::*; match (self, other) { - (ErrorParseError(ref a), ErrorParseError(ref b)) => a == b, - (ErrorFunctionNotFound(ref a), ErrorFunctionNotFound(ref b)) => a == b, - (ErrorFunctionArgMismatch, ErrorFunctionArgMismatch) => true, + (ErrorParsing(a), ErrorParsing(b)) => a == b, + (ErrorFunctionNotFound(a), ErrorFunctionNotFound(b)) => a == b, + (ErrorFunctionArgsMismatch(f1, n1), ErrorFunctionArgsMismatch(f2, n2)) => { + f1 == f2 && *n1 == *n2 + } + (ErrorBooleanArgMismatch(a), ErrorBooleanArgMismatch(b)) => a == b, (ErrorIndexExpr, ErrorIndexExpr) => true, (ErrorIndexing, ErrorIndexing) => true, (ErrorArrayBounds(max1, index1), ErrorArrayBounds(max2, index2)) => { @@ -62,11 +69,12 @@ impl PartialEq for EvalAltResult { } (ErrorIfGuard, ErrorIfGuard) => true, (ErrorFor, ErrorFor) => true, - (ErrorVariableNotFound(ref a), ErrorVariableNotFound(ref b)) => a == b, + (ErrorVariableNotFound(a), ErrorVariableNotFound(b)) => a == b, (ErrorAssignmentToUnknownLHS, ErrorAssignmentToUnknownLHS) => true, - (ErrorMismatchOutputType(ref a), ErrorMismatchOutputType(ref b)) => a == b, - (ErrorCantOpenScriptFile(ref a), ErrorCantOpenScriptFile(ref b)) => a == b, + (ErrorMismatchOutputType(a), ErrorMismatchOutputType(b)) => a == b, + (ErrorCantOpenScriptFile(a), ErrorCantOpenScriptFile(b)) => a == b, (ErrorDotExpr, ErrorDotExpr) => true, + (ErrorArithmetic(a), ErrorArithmetic(b)) => a == b, (LoopBreak, LoopBreak) => true, _ => false, } @@ -76,20 +84,21 @@ impl PartialEq for EvalAltResult { impl Error for EvalAltResult { fn description(&self) -> &str { match self { - Self::ErrorParseError(ref p) => p.description(), + Self::ErrorParsing(p) => p.description(), Self::ErrorFunctionNotFound(_) => "Function not found", - Self::ErrorFunctionArgMismatch => "Function argument types do not match", + Self::ErrorFunctionArgsMismatch(_, _) => "Function call with wrong number of arguments", + Self::ErrorBooleanArgMismatch(_) => "Boolean operator expects boolean operands", Self::ErrorIndexExpr => "Indexing into an array or string expects an integer index", Self::ErrorIndexing => "Indexing can only be performed on an array or a string", - Self::ErrorArrayBounds(_, ref index) if *index < 0 => { + Self::ErrorArrayBounds(_, index) if *index < 0 => { "Array access expects non-negative index" } - Self::ErrorArrayBounds(ref max, _) if *max == 0 => "Access of empty array", + Self::ErrorArrayBounds(max, _) if *max == 0 => "Access of empty array", Self::ErrorArrayBounds(_, _) => "Array index out of bounds", - Self::ErrorStringBounds(_, ref index) if *index < 0 => { + Self::ErrorStringBounds(_, index) if *index < 0 => { "Indexing a string expects a non-negative index" } - Self::ErrorStringBounds(ref max, _) if *max == 0 => "Indexing of empty string", + Self::ErrorStringBounds(max, _) if *max == 0 => "Indexing of empty string", Self::ErrorStringBounds(_, _) => "String index out of bounds", Self::ErrorIfGuard => "If guards expect boolean expression", Self::ErrorFor => "For loops expect array", @@ -100,6 +109,7 @@ impl Error for EvalAltResult { Self::ErrorMismatchOutputType(_) => "Output type is incorrect", Self::ErrorCantOpenScriptFile(_) => "Cannot open script file", Self::ErrorDotExpr => "Malformed dot expression", + Self::ErrorArithmetic(_) => "Arithmetic error", Self::LoopBreak => "[Not Error] Breaks out of loop", Self::Return(_) => "[Not Error] Function returns value", } @@ -116,7 +126,13 @@ impl std::fmt::Display for EvalAltResult { write!(f, "{}: {}", self.description(), s) } else { match self { - EvalAltResult::ErrorParseError(ref p) => write!(f, "Syntax error: {}", p), + EvalAltResult::ErrorParsing(p) => write!(f, "Syntax error: {}", p), + EvalAltResult::ErrorFunctionArgsMismatch(fun, n) => { + write!(f, "Function '{}' expects {} argument(s)", fun, n) + } + EvalAltResult::ErrorBooleanArgMismatch(op) => { + write!(f, "Boolean {} operator expects boolean operands", op) + } EvalAltResult::ErrorArrayBounds(_, index) if *index < 0 => { write!(f, "{}: {} < 0", self.description(), index) } @@ -704,9 +720,11 @@ impl Engine { Err(EvalAltResult::ErrorIndexExpr) } } + Expr::Dot(ref dot_lhs, ref dot_rhs) => { self.set_dot_val(scope, dot_lhs, dot_rhs, rhs_val) } + _ => Err(EvalAltResult::ErrorAssignmentToUnknownLHS), } } @@ -728,13 +746,35 @@ impl Engine { Expr::FunctionCall(fn_name, args) => self.call_fn_raw( fn_name.to_owned(), args.iter() - .map(|ex| self.eval_expr(scope, ex)) + .map(|expr| self.eval_expr(scope, expr)) .collect::>()? .iter_mut() .map(|b| b.as_mut()) .collect(), ), + Expr::And(lhs, rhs) => Ok(Box::new( + *self + .eval_expr(scope, &*lhs)? + .downcast::() + .map_err(|_| EvalAltResult::ErrorBooleanArgMismatch("AND".into()))? + && *self + .eval_expr(scope, &*rhs)? + .downcast::() + .map_err(|_| EvalAltResult::ErrorBooleanArgMismatch("AND".into()))?, + )), + + Expr::Or(lhs, rhs) => Ok(Box::new( + *self + .eval_expr(scope, &*lhs)? + .downcast::() + .map_err(|_| EvalAltResult::ErrorBooleanArgMismatch("OR".into()))? + || *self + .eval_expr(scope, &*rhs)? + .downcast::() + .map_err(|_| EvalAltResult::ErrorBooleanArgMismatch("OR".into()))?, + )), + Expr::True => Ok(Box::new(true)), Expr::False => Ok(Box::new(false)), Expr::Unit => Ok(Box::new(())), @@ -878,7 +918,7 @@ impl Engine { let mut contents = String::new(); if f.read_to_string(&mut contents).is_ok() { - Self::compile(&contents).map_err(|err| EvalAltResult::ErrorParseError(err)) + Self::compile(&contents).map_err(|err| EvalAltResult::ErrorParsing(err)) } else { Err(EvalAltResult::ErrorCantOpenScriptFile(filename.to_owned())) } @@ -917,7 +957,7 @@ impl Engine { scope: &mut Scope, input: &str, ) -> Result { - let ast = Self::compile(input).map_err(|err| EvalAltResult::ErrorParseError(err))?; + let ast = Self::compile(input).map_err(|err| EvalAltResult::ErrorParsing(err))?; self.eval_ast_with_scope(scope, &ast) } @@ -1005,9 +1045,8 @@ impl Engine { let tokens = lex(input); let mut peekables = tokens.peekable(); - let tree = parse(&mut peekables); - match tree { + match parse(&mut peekables) { Ok(AST(ref os, ref fns)) => { for f in fns { if f.params.len() > 6 { @@ -1032,7 +1071,7 @@ impl Engine { Ok(()) } - Err(_) => Err(EvalAltResult::ErrorFunctionArgMismatch), + Err(err) => Err(EvalAltResult::ErrorParsing(err)), } } diff --git a/src/fn_register.rs b/src/fn_register.rs index 2aab7623..c23794de 100644 --- a/src/fn_register.rs +++ b/src/fn_register.rs @@ -30,19 +30,22 @@ macro_rules! def_register { > RegisterFn for Engine { fn register_fn(&mut self, name: &str, f: FN) { + let fn_name = name.to_string(); + let fun = move |mut args: FnCallArgs| { // Check for length at the beginning to avoid // per-element bound checks. - if args.len() != count_args!($($par)*) { - return Err(EvalAltResult::ErrorFunctionArgMismatch); + const NUM_ARGS: usize = count_args!($($par)*); + + if args.len() != NUM_ARGS { + return Err(EvalAltResult::ErrorFunctionArgsMismatch(fn_name.clone(), NUM_ARGS)); } #[allow(unused_variables, unused_mut)] let mut drain = args.drain(..); $( // Downcast every element, return in case of a type mismatch - let $par = ((*drain.next().unwrap()).downcast_mut() as Option<&mut $par>) - .ok_or(EvalAltResult::ErrorFunctionArgMismatch)?; + let $par = ((*drain.next().unwrap()).downcast_mut() as Option<&mut $par>).unwrap(); )* // Call the user-supplied function using ($clone) to @@ -60,19 +63,22 @@ macro_rules! def_register { > RegisterDynamicFn for Engine { fn register_dynamic_fn(&mut self, name: &str, f: FN) { + let fn_name = name.to_string(); + let fun = move |mut args: FnCallArgs| { // Check for length at the beginning to avoid // per-element bound checks. - if args.len() != count_args!($($par)*) { - return Err(EvalAltResult::ErrorFunctionArgMismatch); + const NUM_ARGS: usize = count_args!($($par)*); + + if args.len() != NUM_ARGS { + return Err(EvalAltResult::ErrorFunctionArgsMismatch(fn_name.clone(), NUM_ARGS)); } #[allow(unused_variables, unused_mut)] let mut drain = args.drain(..); $( // Downcast every element, return in case of a type mismatch - let $par = ((*drain.next().unwrap()).downcast_mut() as Option<&mut $par>) - .ok_or(EvalAltResult::ErrorFunctionArgMismatch)?; + let $par = ((*drain.next().unwrap()).downcast_mut() as Option<&mut $par>).unwrap(); )* // Call the user-supplied function using ($clone) to diff --git a/src/parser.rs b/src/parser.rs index d36ae1c7..fccd1b30 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -179,6 +179,8 @@ pub enum Expr { Dot(Box, Box), Index(String, Box), Array(Vec), + And(Box, Box), + Or(Box, Box), True, False, Unit, @@ -618,7 +620,7 @@ impl<'a> TokenIterator<'a> { let out: String = result.iter().cloned().collect(); return Some(( - match out.as_ref() { + match out.as_str() { "true" => Token::True, "false" => Token::False, "let" => Token::Let, @@ -1189,8 +1191,8 @@ fn parse_binop<'a>( Token::LessThanEqualsTo => Expr::FunctionCall("<=".into(), vec![lhs_curr, rhs]), Token::GreaterThan => Expr::FunctionCall(">".into(), vec![lhs_curr, rhs]), Token::GreaterThanEqualsTo => Expr::FunctionCall(">=".into(), vec![lhs_curr, rhs]), - Token::Or => Expr::FunctionCall("||".into(), vec![lhs_curr, rhs]), - Token::And => Expr::FunctionCall("&&".into(), vec![lhs_curr, rhs]), + Token::Or => Expr::Or(Box::new(lhs_curr), Box::new(rhs)), + Token::And => Expr::And(Box::new(lhs_curr), Box::new(rhs)), Token::XOr => Expr::FunctionCall("^".into(), vec![lhs_curr, rhs]), Token::OrAssign => { let lhs_copy = lhs_curr.clone(); diff --git a/tests/bool_op.rs b/tests/bool_op.rs index 9565e8ee..ffd90746 100644 --- a/tests/bool_op.rs +++ b/tests/bool_op.rs @@ -1,10 +1,11 @@ -use rhai::Engine; +use rhai::{Engine, EvalAltResult}; #[test] fn test_bool_op1() { let mut engine = Engine::new(); assert_eq!(engine.eval::("true && (false || true)"), Ok(true)); + assert_eq!(engine.eval::("true & (false | true)"), Ok(true)); } #[test] @@ -12,4 +13,89 @@ fn test_bool_op2() { let mut engine = Engine::new(); assert_eq!(engine.eval::("false && (false || true)"), Ok(false)); + assert_eq!(engine.eval::("false & (false | true)"), Ok(false)); +} + +#[test] +fn test_bool_op3() { + let mut engine = Engine::new(); + + assert_eq!( + engine.eval::("true && (false || 123)"), + Err(EvalAltResult::ErrorBooleanArgMismatch("OR".into())) + ); + + assert_eq!(engine.eval::("true && (true || 123)"), Ok(true)); + + assert_eq!( + engine.eval::("123 && (false || true)"), + Err(EvalAltResult::ErrorBooleanArgMismatch("AND".into())) + ); + + assert_eq!(engine.eval::("false && (true || 123)"), Ok(false)); +} + +#[test] +fn test_bool_op_short_circuit() { + let mut engine = Engine::new(); + + assert_eq!( + engine.eval::( + r" + fn this() { true } + fn that() { 9/0 } + + this() || that(); + " + ), + Ok(true) + ); + + assert_eq!( + engine.eval::( + r" + fn this() { false } + fn that() { 9/0 } + + this() && that(); + " + ), + Ok(false) + ); +} + +#[test] +#[should_panic] +fn test_bool_op_no_short_circuit1() { + let mut engine = Engine::new(); + + assert_eq!( + engine.eval::( + r" + fn this() { false } + fn that() { 9/0 } + + this() | that(); + " + ), + Ok(false) + ); +} + +#[test] +#[should_panic] +fn test_bool_op_no_short_circuit2() { + let mut engine = Engine::new(); + + assert_eq!( + engine.eval::( + r" + fn this() { false } + fn that() { 9/0 } + + this() & that(); + " + ), + Ok(false) + ); }