diff --git a/src/engine.rs b/src/engine.rs index 0455c6f1..1c39120f 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -855,7 +855,7 @@ impl Engine { fn_name, args.iter() .map(|name| if name.is::() { - "&str | ImmutableString" + "&str | ImmutableString | String" } else { self.map_type_name((*name).type_name()) }) diff --git a/src/fn_register.rs b/src/fn_register.rs index 7cf7be74..23d8043f 100644 --- a/src/fn_register.rs +++ b/src/fn_register.rs @@ -6,6 +6,7 @@ use crate::engine::Engine; use crate::fn_native::{CallableFunction, FnAny, FnCallArgs, SendSync}; use crate::module::Module; use crate::parser::FnAccess; +use crate::r#unsafe::unsafe_cast_box; use crate::result::EvalAltResult; use crate::utils::ImmutableString; @@ -105,6 +106,9 @@ pub fn by_value(data: &mut Dynamic) -> T { let ref_str = data.as_str().unwrap(); let ref_T = unsafe { mem::transmute::<_, &T>(&ref_str) }; ref_T.clone() + } else if TypeId::of::() == TypeId::of::() { + // If T is String, data must be ImmutableString, so map directly to it + *unsafe_cast_box(Box::new(data.as_str().unwrap().to_string())).unwrap() } else { // We consume the argument and then replace it with () - the argument is not supposed to be used again. // This way, we avoid having to clone the argument again, because it is already a clone when passed here. @@ -154,13 +158,15 @@ pub fn map_result( data } -/// Remap `&str` to `ImmutableString`. +/// Remap `&str` | `String` to `ImmutableString`. #[inline(always)] fn map_type_id() -> TypeId { let id = TypeId::of::(); if id == TypeId::of::<&str>() { TypeId::of::() + } else if id == TypeId::of::() { + TypeId::of::() } else { id } diff --git a/tests/decrement.rs b/tests/decrement.rs index 1b74dfe9..0691eddf 100644 --- a/tests/decrement.rs +++ b/tests/decrement.rs @@ -8,7 +8,7 @@ fn test_decrement() -> Result<(), Box> { assert!(matches!( *engine.eval::(r#"let s = "test"; s -= "ing"; s"#).expect_err("expects error"), - EvalAltResult::ErrorFunctionNotFound(err, _) if err == "- (&str | ImmutableString, &str | ImmutableString)" + EvalAltResult::ErrorFunctionNotFound(err, _) if err == "- (&str | ImmutableString | String, &str | ImmutableString | String)" )); Ok(()) diff --git a/tests/string.rs b/tests/string.rs index e6b77841..908b4c6c 100644 --- a/tests/string.rs +++ b/tests/string.rs @@ -173,17 +173,16 @@ fn test_string_fn() -> Result<(), Box> { "foo" ); - engine.register_fn("foo1", |s: &str| s.len() as INT); - engine.register_fn("foo2", |s: ImmutableString| s.len() as INT); - engine.register_fn("foo3", |s: String| s.len() as INT); + engine + .register_fn("foo1", |s: &str| s.len() as INT) + .register_fn("foo2", |s: ImmutableString| s.len() as INT) + .register_fn("foo3", |s: String| s.len() as INT) + .register_fn("foo4", |s: &mut ImmutableString| s.len() as INT); assert_eq!(engine.eval::(r#"foo1("hello")"#)?, 5); assert_eq!(engine.eval::(r#"foo2("hello")"#)?, 5); - - assert!(matches!( - *engine.eval::(r#"foo3("hello")"#).expect_err("should error"), - EvalAltResult::ErrorFunctionNotFound(err, _) if err == "foo3 (&str | ImmutableString)" - )); + assert_eq!(engine.eval::(r#"foo3("hello")"#)?, 5); + assert_eq!(engine.eval::(r#"foo4("hello")"#)?, 5); Ok(()) }