diff --git a/RELEASES.md b/RELEASES.md index e85c6e56..7158e27c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ Breaking changes * For plugin functions, constants passed to methods (i.e. `&mut` parameter) now raise an error unless the functions are marked with `#[rhai_fn(pure)]`. * Visibility (i.e. `pub` or not) for generated _plugin_ modules now follow the visibility of the underlying module. +* Comparison operators between the sames types now throw errors when they're not defined instead of returning the default. Only comparing between _different_ types will return the default. * Default stack-overflow and top-level expression nesting limits for release builds are lowered to 64 from 128. New features diff --git a/src/fn_call.rs b/src/fn_call.rs index c39d90cd..78c77702 100644 --- a/src/fn_call.rs +++ b/src/fn_call.rs @@ -21,7 +21,6 @@ use crate::stdlib::{ string::ToString, vec::Vec, }; -use crate::token::is_assignment_operator; use crate::utils::combine_hashes; use crate::{ calc_native_fn_hash, calc_script_fn_hash, Dynamic, Engine, EvalAltResult, FnPtr, @@ -174,6 +173,7 @@ impl Engine { hash_fn: NonZeroU64, args: &mut FnCallArgs, is_ref: bool, + is_op_assignment: bool, pos: Position, ) -> Result<(Dynamic, bool), Box> { self.inc_operations(state, pos)?; @@ -286,15 +286,16 @@ impl Engine { // See if it is built in. if args.len() == 2 && !args[0].is_variant() && !args[1].is_variant() { - if is_assignment_operator(fn_name) { - if is_ref { - // Op-assignment - let (first, second) = args.split_first_mut().unwrap(); + // Op-assignment? + if is_op_assignment { + if !is_ref { + unreachable!("op-assignments must have ref argument"); + } + let (first, second) = args.split_first_mut().unwrap(); - match run_builtin_op_assignment(fn_name, first, second[0])? { - Some(_) => return Ok((Dynamic::UNIT, false)), - None => (), - } + match run_builtin_op_assignment(fn_name, first, second[0])? { + Some(_) => return Ok((Dynamic::UNIT, false)), + None => (), } } else { match run_builtin_binary_op(fn_name, args[0], args[1])? { @@ -756,12 +757,14 @@ impl Engine { Ok((result, false)) } else { // Native function call - self.call_native_fn(mods, state, lib, fn_name, hash_fn, args, is_ref, pos) + self.call_native_fn( + mods, state, lib, fn_name, hash_fn, args, is_ref, false, pos, + ) } } // Native function call - _ => self.call_native_fn(mods, state, lib, fn_name, hash_fn, args, is_ref, pos), + _ => self.call_native_fn(mods, state, lib, fn_name, hash_fn, args, is_ref, false, pos), } } @@ -1368,17 +1371,29 @@ pub fn run_builtin_binary_op( x: &Dynamic, y: &Dynamic, ) -> Result, Box> { - let first_type = x.type_id(); - let second_type = y.type_id(); + let type1 = x.type_id(); + let type2 = y.type_id(); - let type_id = (first_type, second_type); + if x.is_variant() || y.is_variant() { + // One of the operands is a custom type, so it is never built-in + return Ok(match op { + "!=" if type1 != type2 => Some(Dynamic::TRUE), + "==" | ">" | ">=" | "<" | "<=" if type1 != type2 => Some(Dynamic::FALSE), + _ => None, + }); + } + + let types_pair = (type1, type2); #[cfg(not(feature = "no_float"))] - if let Some((x, y)) = if type_id == (TypeId::of::(), TypeId::of::()) { + if let Some((x, y)) = if types_pair == (TypeId::of::(), TypeId::of::()) { + // FLOAT op FLOAT Some((x.clone().cast::(), y.clone().cast::())) - } else if type_id == (TypeId::of::(), TypeId::of::()) { + } else if types_pair == (TypeId::of::(), TypeId::of::()) { + // FLOAT op INT Some((x.clone().cast::(), y.clone().cast::() as FLOAT)) - } else if type_id == (TypeId::of::(), TypeId::of::()) { + } else if types_pair == (TypeId::of::(), TypeId::of::()) { + // INT op FLOAT Some((x.clone().cast::() as FLOAT, y.clone().cast::())) } else { None @@ -1402,16 +1417,19 @@ pub fn run_builtin_binary_op( #[cfg(feature = "decimal")] if let Some((x, y)) = if type_id == (TypeId::of::(), TypeId::of::()) { + // Decimal op Decimal Some(( *x.read_lock::().unwrap(), *y.read_lock::().unwrap(), )) } else if type_id == (TypeId::of::(), TypeId::of::()) { + // Decimal op INT Some(( *x.read_lock::().unwrap(), y.clone().cast::().into(), )) } else if type_id == (TypeId::of::(), TypeId::of::()) { + // INT op Decimal Some(( x.clone().cast::().into(), *y.read_lock::().unwrap(), @@ -1452,8 +1470,9 @@ pub fn run_builtin_binary_op( } } - if second_type != first_type { - if type_id == (TypeId::of::(), TypeId::of::()) { + if type2 != type1 { + // char op string + if types_pair == (TypeId::of::(), TypeId::of::()) { let x = x.clone().cast::(); let y = &*y.read_lock::().unwrap(); @@ -1462,8 +1481,8 @@ pub fn run_builtin_binary_op( _ => return Ok(None), } } - - if type_id == (TypeId::of::(), TypeId::of::()) { + // string op char + if types_pair == (TypeId::of::(), TypeId::of::()) { let x = &*x.read_lock::().unwrap(); let y = y.clone().cast::(); @@ -1472,7 +1491,7 @@ pub fn run_builtin_binary_op( _ => return Ok(None), } } - + // Default comparison operators for different types return Ok(match op { "!=" => Some(Dynamic::TRUE), "==" | ">" | ">=" | "<" | "<=" => Some(Dynamic::FALSE), @@ -1480,7 +1499,9 @@ pub fn run_builtin_binary_op( }); } - if first_type == TypeId::of::() { + // Beyond here, type1 == type2 + + if type1 == TypeId::of::() { let x = x.clone().cast::(); let y = y.clone().cast::(); @@ -1526,7 +1547,7 @@ pub fn run_builtin_binary_op( } } - if first_type == TypeId::of::() { + if type1 == TypeId::of::() { let x = x.clone().cast::(); let y = y.clone().cast::(); @@ -1540,7 +1561,7 @@ pub fn run_builtin_binary_op( } } - if first_type == TypeId::of::() { + if type1 == TypeId::of::() { let x = &*x.read_lock::().unwrap(); let y = &*y.read_lock::().unwrap(); @@ -1556,7 +1577,7 @@ pub fn run_builtin_binary_op( } } - if first_type == TypeId::of::() { + if type1 == TypeId::of::() { let x = x.clone().cast::(); let y = y.clone().cast::(); @@ -1572,7 +1593,7 @@ pub fn run_builtin_binary_op( } } - if first_type == TypeId::of::<()>() { + if type1 == TypeId::of::<()>() { match op { "==" => return Ok(Some(true.into())), "!=" | ">" | ">=" | "<" | "<=" => return Ok(Some(false.into())), @@ -1580,11 +1601,7 @@ pub fn run_builtin_binary_op( } } - Ok(match op { - "!=" => Some(Dynamic::TRUE), - "==" | ">" | ">=" | "<" | "<=" => Some(Dynamic::FALSE), - _ => None, - }) + Ok(None) } /// Build in common operator assignment implementations to avoid the cost of calling a registered function. @@ -1593,16 +1610,18 @@ pub fn run_builtin_op_assignment( x: &mut Dynamic, y: &Dynamic, ) -> Result, Box> { - let first_type = x.type_id(); - let second_type = y.type_id(); + let type1 = x.type_id(); + let type2 = y.type_id(); - let type_id = (first_type, second_type); + let types_pair = (type1, type2); #[cfg(not(feature = "no_float"))] - if let Some((mut x, y)) = if type_id == (TypeId::of::(), TypeId::of::()) { + if let Some((mut x, y)) = if types_pair == (TypeId::of::(), TypeId::of::()) { + // FLOAT op= FLOAT let y = y.clone().cast::(); Some((x.write_lock::().unwrap(), y)) - } else if type_id == (TypeId::of::(), TypeId::of::()) { + } else if types_pair == (TypeId::of::(), TypeId::of::()) { + // FLOAT op= INT let y = y.clone().cast::() as FLOAT; Some((x.write_lock::().unwrap(), y)) } else { @@ -1621,9 +1640,11 @@ pub fn run_builtin_op_assignment( #[cfg(feature = "decimal")] if let Some((mut x, y)) = if type_id == (TypeId::of::(), TypeId::of::()) { + // Decimal op= Decimal let y = *y.read_lock::().unwrap(); Some((x.write_lock::().unwrap(), y)) } else if type_id == (TypeId::of::(), TypeId::of::()) { + // Decimal op= INT let y = y.clone().cast::().into(); Some((x.write_lock::().unwrap(), y)) } else { @@ -1652,8 +1673,8 @@ pub fn run_builtin_op_assignment( } } - if second_type != first_type { - if type_id == (TypeId::of::(), TypeId::of::()) { + if type2 != type1 { + if types_pair == (TypeId::of::(), TypeId::of::()) { let y = y.read_lock::().unwrap().deref().clone(); let mut x = x.write_lock::().unwrap(); @@ -1666,7 +1687,9 @@ pub fn run_builtin_op_assignment( return Ok(None); } - if first_type == TypeId::of::() { + // Beyond here, type1 == type2 + + if type1 == TypeId::of::() { let y = y.clone().cast::(); let mut x = x.write_lock::().unwrap(); @@ -1706,7 +1729,7 @@ pub fn run_builtin_op_assignment( } } - if first_type == TypeId::of::() { + if type1 == TypeId::of::() { let y = y.clone().cast::(); let mut x = x.write_lock::().unwrap(); @@ -1717,7 +1740,7 @@ pub fn run_builtin_op_assignment( } } - if first_type == TypeId::of::() { + if type1 == TypeId::of::() { let y = y.read_lock::().unwrap().deref().clone(); let mut x = x.write_lock::().unwrap(); @@ -1727,7 +1750,7 @@ pub fn run_builtin_op_assignment( } } - if first_type == TypeId::of::() { + if type1 == TypeId::of::() { let y = y.read_lock::().unwrap().deref().clone(); let mut x = x.write_lock::().unwrap(); diff --git a/tests/mismatched_op.rs b/tests/mismatched_op.rs index 5ef401f0..eea4f693 100644 --- a/tests/mismatched_op.rs +++ b/tests/mismatched_op.rs @@ -12,7 +12,7 @@ fn test_mismatched_op() { #[test] #[cfg(not(feature = "no_object"))] -fn test_mismatched_op_custom_type() { +fn test_mismatched_op_custom_type() -> Result<(), Box> { #[derive(Debug, Clone)] struct TestStruct { x: INT, @@ -30,9 +30,18 @@ fn test_mismatched_op_custom_type() { .register_type_with_name::("TestStruct") .register_fn("new_ts", TestStruct::new); + assert!(matches!(*engine.eval::(r" + let x = new_ts(); + let y = new_ts(); + x == y + ").expect_err("should error"), + EvalAltResult::ErrorFunctionNotFound(f, _) if f == "== (TestStruct, TestStruct)")); + + assert!(!engine.eval::("new_ts() == 42")?); + assert!(matches!( *engine.eval::("60 + new_ts()").expect_err("should error"), - EvalAltResult::ErrorFunctionNotFound(err, _) if err == format!("+ ({}, TestStruct)", std::any::type_name::()) + EvalAltResult::ErrorFunctionNotFound(f, _) if f == format!("+ ({}, TestStruct)", std::any::type_name::()) )); assert!(matches!( @@ -40,4 +49,6 @@ fn test_mismatched_op_custom_type() { EvalAltResult::ErrorMismatchOutputType(need, actual, _) if need == "TestStruct" && actual == std::any::type_name::() )); + + Ok(()) }