diff --git a/src/module.rs b/src/module.rs index 4d70b169..51d5539c 100644 --- a/src/module.rs +++ b/src/module.rs @@ -4,10 +4,11 @@ use crate::any::{Dynamic, Variant}; use crate::calc_fn_hash; use crate::engine::Engine; use crate::fn_native::{CallableFunction as Func, FnCallArgs, IteratorFn, SendSync}; +use crate::fn_register::by_value as cast_arg; use crate::parser::{FnAccess, FnAccess::Public, ScriptFnDef}; use crate::result::EvalAltResult; use crate::token::{Position, Token}; -use crate::utils::{StaticVec, StraightHasherBuilder}; +use crate::utils::{ImmutableString, StaticVec, StraightHasherBuilder}; #[cfg(not(feature = "no_function"))] use crate::fn_native::Shared; @@ -32,7 +33,6 @@ use crate::stdlib::{ collections::HashMap, fmt, format, iter::empty, - mem, num::NonZeroUsize, ops::{Deref, DerefMut}, string::{String, ToString}, @@ -396,9 +396,21 @@ impl Module { arg_types.len() }; - let hash_fn = calc_fn_hash(empty(), &name, args_len, arg_types.iter().cloned()); + let params = arg_types + .into_iter() + .cloned() + .map(|id| { + if id == TypeId::of::<&str>() { + TypeId::of::() + } else if id == TypeId::of::() { + TypeId::of::() + } else { + id + } + }) + .collect(); - let params = arg_types.into_iter().cloned().collect(); + let hash_fn = calc_fn_hash(empty(), &name, args_len, arg_types.iter().cloned()); self.functions .insert(hash_fn, (name, access, params, func.into())); @@ -518,7 +530,7 @@ impl Module { func: impl Fn(A) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - func(mem::take(args[0]).cast::()).map(Dynamic::from) + func(cast_arg::(&mut args[0])).map(Dynamic::from) }; let arg_types = [TypeId::of::()]; self.set_fn(name, Public, &arg_types, Func::from_pure(Box::new(f))) @@ -592,8 +604,8 @@ impl Module { func: impl Fn(A, B) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let a = mem::take(args[0]).cast::(); - let b = mem::take(args[1]).cast::(); + let a = cast_arg::(&mut args[0]); + let b = cast_arg::(&mut args[1]); func(a, b).map(Dynamic::from) }; @@ -623,7 +635,7 @@ impl Module { func: impl Fn(&mut A, B) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let b = mem::take(args[1]).cast::(); + let b = cast_arg::(&mut args[1]); let a = &mut args[0].write_lock::().unwrap(); func(a, b).map(Dynamic::from) @@ -709,9 +721,9 @@ impl Module { func: impl Fn(A, B, C) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let a = mem::take(args[0]).cast::(); - let b = mem::take(args[1]).cast::(); - let c = mem::take(args[2]).cast::(); + let a = cast_arg::(&mut args[0]); + let b = cast_arg::(&mut args[1]); + let c = cast_arg::(&mut args[2]); func(a, b, c).map(Dynamic::from) }; @@ -746,8 +758,8 @@ impl Module { func: impl Fn(&mut A, B, C) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let b = mem::take(args[1]).cast::(); - let c = mem::take(args[2]).cast::(); + let b = cast_arg::(&mut args[2]); + let c = cast_arg::(&mut args[3]); let a = &mut args[0].write_lock::().unwrap(); func(a, b, c).map(Dynamic::from) @@ -780,8 +792,8 @@ impl Module { func: impl Fn(&mut A, B, C) -> FuncReturn<()> + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let b = mem::take(args[1]).cast::(); - let c = mem::take(args[2]).cast::(); + let b = cast_arg::(&mut args[1]); + let c = cast_arg::(&mut args[2]); let a = &mut args[0].write_lock::().unwrap(); func(a, b, c).map(Dynamic::from) @@ -858,10 +870,10 @@ impl Module { func: impl Fn(A, B, C, D) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let a = mem::take(args[0]).cast::(); - let b = mem::take(args[1]).cast::(); - let c = mem::take(args[2]).cast::(); - let d = mem::take(args[3]).cast::(); + let a = cast_arg::(&mut args[0]); + let b = cast_arg::(&mut args[1]); + let c = cast_arg::(&mut args[2]); + let d = cast_arg::(&mut args[3]); func(a, b, c, d).map(Dynamic::from) }; @@ -902,9 +914,9 @@ impl Module { func: impl Fn(&mut A, B, C, D) -> FuncReturn + SendSync + 'static, ) -> u64 { let f = move |_: &Engine, _: &Module, args: &mut FnCallArgs| { - let b = mem::take(args[1]).cast::(); - let c = mem::take(args[2]).cast::(); - let d = mem::take(args[3]).cast::(); + let b = cast_arg::(&mut args[1]); + let c = cast_arg::(&mut args[2]); + let d = cast_arg::(&mut args[3]); let a = &mut args[0].write_lock::().unwrap(); func(a, b, c, d).map(Dynamic::from) diff --git a/tests/modules.rs b/tests/modules.rs index 0744fd8c..e3475584 100644 --- a/tests/modules.rs +++ b/tests/modules.rs @@ -1,7 +1,7 @@ #![cfg(not(feature = "no_module"))] use rhai::{ - module_resolvers::StaticModuleResolver, Dynamic, Engine, EvalAltResult, Module, ParseError, - ParseErrorType, Scope, INT, + module_resolvers::StaticModuleResolver, Dynamic, Engine, EvalAltResult, ImmutableString, + Module, ParseError, ParseErrorType, Scope, INT, }; #[test] @@ -79,12 +79,12 @@ fn test_module_resolver() -> Result<(), Box> { }); #[cfg(not(feature = "no_float"))] - module.set_fn_4_mut( + module.set_fn_4_mut( "sum_of_three_args".to_string(), |target: &mut INT, a: INT, b: INT, c: f64| { *target = a + b + c as INT; Ok(()) - } + }, ); resolver.insert("hello", module); @@ -316,3 +316,41 @@ fn test_module_export() -> Result<(), Box> { Ok(()) } + +#[test] +fn test_module_str() -> Result<(), Box> { + fn test_fn(_input: ImmutableString) -> Result> { + Ok(42) + } + fn test_fn2(_input: &str) -> Result> { + Ok(42) + } + fn test_fn3(_input: String) -> Result> { + Ok(42) + } + + let mut engine = rhai::Engine::new(); + let mut module = Module::new(); + module.set_fn_1("test", test_fn); + module.set_fn_1("test2", test_fn2); + module.set_fn_1("test3", test_fn3); + + let mut static_modules = rhai::module_resolvers::StaticModuleResolver::new(); + static_modules.insert("test", module); + engine.set_module_resolver(Some(static_modules)); + + assert_eq!( + engine.eval::(r#"import "test" as test; test::test("test");"#)?, + 42 + ); + assert_eq!( + engine.eval::(r#"import "test" as test; test::test2("test");"#)?, + 42 + ); + assert_eq!( + engine.eval::(r#"import "test" as test; test::test3("test");"#)?, + 42 + ); + + Ok(()) +}