From bf245a798b57b9776eacffba8926b55f248391af Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Sat, 19 Sep 2020 18:18:40 +0800 Subject: [PATCH] Enable String parameters. --- codegen/src/function.rs | 23 ++++++++--- codegen/src/rhai_module.rs | 6 +++ codegen/src/test/function.rs | 12 ++---- codegen/src/test/module.rs | 76 ++++++++++++++++++++++++++++++++---- tests/plugins.rs | 5 +++ 5 files changed, 100 insertions(+), 22 deletions(-) diff --git a/codegen/src/function.rs b/codegen/src/function.rs index aa059013..82e76214 100644 --- a/codegen/src/function.rs +++ b/codegen/src/function.rs @@ -638,11 +638,13 @@ impl ExportedFn { // Handle the rest of the arguments, which all are passed by value. // // The only exception is strings, which need to be downcast to ImmutableString to enable a - // zero-copy conversion to &str by reference. + // zero-copy conversion to &str by reference, or a cloned String. let str_type_path = syn::parse2::(quote! { str }).unwrap(); + let string_type_path = syn::parse2::(quote! { String }).unwrap(); for (i, arg) in self.arg_list().enumerate().skip(skip_first_arg as usize) { let var = syn::Ident::new(&format!("arg{}", i), proc_macro2::Span::call_site()); - let is_str_ref; + let is_string; + let is_ref; match arg { syn::FnArg::Typed(pattern) => { let arg_type: &syn::Type = pattern.ty.as_ref(); @@ -653,15 +655,24 @@ impl ExportedFn { .. }) => match elem.as_ref() { &syn::Type::Path(ref p) if p.path == str_type_path => { - is_str_ref = true; + is_string = true; + is_ref = true; quote_spanned!(arg_type.span()=> mem::take(args[#i]) .clone().cast::()) } _ => panic!("internal error: why wasn't this found earlier!?"), }, + &syn::Type::Path(ref p) if p.path == string_type_path => { + is_string = true; + is_ref = false; + quote_spanned!(arg_type.span()=> + mem::take(args[#i]) + .clone().cast::<#arg_type>()) + } _ => { - is_str_ref = false; + is_string = false; + is_ref = false; quote_spanned!(arg_type.span()=> mem::take(args[#i]).clone().cast::<#arg_type>()) } @@ -673,7 +684,7 @@ impl ExportedFn { }) .unwrap(), ); - if !is_str_ref { + if !is_string { input_type_exprs.push( syn::parse2::(quote_spanned!( arg_type.span()=> TypeId::of::<#arg_type>() @@ -691,7 +702,7 @@ impl ExportedFn { } syn::FnArg::Receiver(_) => panic!("internal error: how did this happen!?"), } - if !is_str_ref { + if !is_ref { unpack_exprs.push(syn::parse2::(quote! { #var }).unwrap()); } else { unpack_exprs.push(syn::parse2::(quote! { &#var }).unwrap()); diff --git a/codegen/src/rhai_module.rs b/codegen/src/rhai_module.rs index 1859ba42..c25c9484 100644 --- a/codegen/src/rhai_module.rs +++ b/codegen/src/rhai_module.rs @@ -19,6 +19,7 @@ pub(crate) fn generate_body( let mut add_mod_blocks: Vec = Vec::new(); let mut set_flattened_mod_blocks: Vec = Vec::new(); let str_type_path = syn::parse2::(quote! { str }).unwrap(); + let string_type_path = syn::parse2::(quote! { String }).unwrap(); for (const_name, _, _) in consts { let const_literal = syn::LitStr::new(&const_name, proc_macro2::Span::call_site()); @@ -97,6 +98,11 @@ pub(crate) fn generate_body( } _ => panic!("internal error: non-string shared reference!?"), }, + syn::Type::Path(ref p) if p.path == string_type_path => { + syn::parse2::(quote! { + ImmutableString }) + .unwrap() + } syn::Type::Reference(syn::TypeReference { mutability: Some(_), ref elem, diff --git a/codegen/src/test/function.rs b/codegen/src/test/function.rs index 61c1f942..b6d8e894 100644 --- a/codegen/src/test/function.rs +++ b/codegen/src/test/function.rs @@ -245,8 +245,7 @@ mod generate_tests { false } }); - /* - let (actual_diff, expected_diff) = { + let (_actual_diff, _expected_diff) = { let mut actual_diff = String::new(); let mut expected_diff = String::new(); for (a, e) in _iter.take(50) { @@ -255,13 +254,10 @@ mod generate_tests { } (actual_diff, expected_diff) }; - */ eprintln!("actual != expected, diverge at char {}", counter); - /* - eprintln!(" actual: {}", actual_diff); - eprintln!("expected: {}", expected_diff); - assert!(false); - */ + // eprintln!(" actual: {}", _actual_diff); + // eprintln!("expected: {}", _expected_diff); + // assert!(false); } assert_eq!(actual, expected); } diff --git a/codegen/src/test/module.rs b/codegen/src/test/module.rs index 59e58ce7..a2ab7eab 100644 --- a/codegen/src/test/module.rs +++ b/codegen/src/test/module.rs @@ -226,8 +226,7 @@ mod generate_tests { false } }); - /* - let (actual_diff, expected_diff) = { + let (_actual_diff, _expected_diff) = { let mut actual_diff = String::new(); let mut expected_diff = String::new(); for (a, e) in _iter.take(50) { @@ -236,13 +235,10 @@ mod generate_tests { } (actual_diff, expected_diff) }; - */ eprintln!("actual != expected, diverge at char {}", counter); - /* - eprintln!(" actual: {}", actual_diff); - eprintln!("expected: {}", expected_diff); - assert!(false); - */ + // eprintln!(" actual: {}", _actual_diff); + // eprintln!("expected: {}", _expected_diff); + // assert!(false); } assert_eq!(actual, expected); } @@ -978,6 +974,70 @@ mod generate_tests { assert_streams_eq(item_mod.generate(), expected_tokens); } + #[test] + fn one_string_arg_fn_module() { + let input_tokens: TokenStream = quote! { + pub mod str_fn { + pub fn print_out_to(x: String) { + x + 1 + } + } + }; + + let expected_tokens = quote! { + pub mod str_fn { + pub fn print_out_to(x: String) { + x + 1 + } + #[allow(unused_imports)] + use super::*; + + pub fn rhai_module_generate() -> Module { + let mut m = Module::new(); + rhai_generate_into_module(&mut m, false); + m + } + #[allow(unused_mut)] + pub fn rhai_generate_into_module(m: &mut Module, flatten: bool) { + m.set_fn("print_out_to", FnAccess::Public, + &[core::any::TypeId::of::()], + CallableFunction::from_plugin(print_out_to_token())); + if flatten {} else {} + } + #[allow(non_camel_case_types)] + struct print_out_to_token(); + impl PluginFunction for print_out_to_token { + fn call(&self, + args: &mut [&mut Dynamic] + ) -> Result> { + debug_assert_eq!(args.len(), 1usize, + "wrong arg count: {} != {}", args.len(), 1usize); + let arg0 = mem::take(args[0usize]).clone().cast::(); + Ok(Dynamic::from(print_out_to(arg0))) + } + + fn is_method_call(&self) -> bool { false } + fn is_varadic(&self) -> bool { false } + fn clone_boxed(&self) -> Box { + Box::new(print_out_to_token()) + } + fn input_types(&self) -> Box<[TypeId]> { + new_vec![TypeId::of::()].into_boxed_slice() + } + } + pub fn print_out_to_token_callable() -> CallableFunction { + CallableFunction::from_plugin(print_out_to_token()) + } + pub fn print_out_to_token_input_types() -> Box<[TypeId]> { + print_out_to_token().input_types() + } + } + }; + + let item_mod = syn::parse2::(input_tokens).unwrap(); + assert_streams_eq(item_mod.generate(), expected_tokens); + } + #[test] fn one_mut_ref_fn_module() { let input_tokens: TokenStream = quote! { diff --git a/tests/plugins.rs b/tests/plugins.rs index 2bee84e4..ce403dd6 100644 --- a/tests/plugins.rs +++ b/tests/plugins.rs @@ -24,6 +24,10 @@ mod test { } } + pub fn hash(_text: String) -> INT { + 42 + } + #[rhai_fn(name = "test", name = "hi")] #[inline(always)] pub fn len(array: &mut Array, mul: INT) -> INT { @@ -77,6 +81,7 @@ fn test_plugins_package() -> Result<(), Box> { #[cfg(not(feature = "no_object"))] assert_eq!(engine.eval::("let a = [1, 2, 3]; a.foo")?, 1); + assert_eq!(engine.eval::(r#"hash("hello")"#)?, 42); assert_eq!(engine.eval::("let a = [1, 2, 3]; test(a, 2)")?, 6); assert_eq!(engine.eval::("let a = [1, 2, 3]; hi(a, 2)")?, 6); assert_eq!(engine.eval::("let a = [1, 2, 3]; test(a, 2)")?, 6);