diff --git a/codegen/src/attrs.rs b/codegen/src/attrs.rs new file mode 100644 index 00000000..1850d294 --- /dev/null +++ b/codegen/src/attrs.rs @@ -0,0 +1,117 @@ +use syn::{parse::ParseStream, parse::Parser, spanned::Spanned}; + +pub trait ExportedParams: Sized { + fn parse_stream(args: ParseStream) -> syn::Result; + fn no_attrs() -> Self; + fn from_info(info: ExportInfo) -> syn::Result; +} + +pub struct AttrItem { + pub key: proc_macro2::Ident, + pub value: Option, +} + +pub struct ExportInfo { + pub item_span: proc_macro2::Span, + pub items: Vec, +} + +pub fn parse_attr_items(args: ParseStream) -> syn::Result { + if args.is_empty() { + return Ok(ExportInfo { item_span: args.span(), items: Vec::new()}); + } + let arg_list = args + .call(syn::punctuated::Punctuated::::parse_separated_nonempty)?; + + parse_punctuated_items(arg_list) +} + +pub fn parse_punctuated_items( + arg_list: syn::punctuated::Punctuated, +) -> syn::Result { + let list_span = arg_list.span(); + + let mut attrs: Vec = Vec::new(); + for arg in arg_list { + let (key, value) = match arg { + syn::Expr::Assign(syn::ExprAssign { + ref left, + ref right, + .. + }) => { + let attr_name: syn::Ident = match left.as_ref() { + syn::Expr::Path(syn::ExprPath { + path: attr_path, .. + }) => attr_path.get_ident().cloned().ok_or_else(|| { + syn::Error::new(attr_path.span(), "expecting attribute name") + })?, + x => return Err(syn::Error::new(x.span(), "expecting attribute name")), + }; + let attr_value = match right.as_ref() { + syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(string), + .. + }) => string.clone(), + x => return Err(syn::Error::new(x.span(), "expecting string literal")), + }; + (attr_name, Some(attr_value)) + } + syn::Expr::Path(syn::ExprPath { + path: attr_path, .. + }) => attr_path + .get_ident() + .cloned() + .map(|a| (a, None)) + .ok_or_else(|| syn::Error::new(attr_path.span(), "expecting attribute name"))?, + x => return Err(syn::Error::new(x.span(), "expecting identifier")), + }; + attrs.push(AttrItem { key, value }); + } + + Ok(ExportInfo { item_span: list_span, items: attrs }) +} + +pub(crate) fn outer_item_attributes( + args: proc_macro2::TokenStream, + _attr_name: &str, +) -> syn::Result { + if args.is_empty() { + return Ok(T::no_attrs()); + } + + let parser = syn::punctuated::Punctuated::::parse_separated_nonempty; + let arg_list = parser.parse2(args)?; + + let export_info = parse_punctuated_items(arg_list)?; + T::from_info(export_info) +} + +pub(crate) fn inner_item_attributes( + attrs: &mut Vec, + attr_name: &str, +) -> syn::Result { + // Find the #[rhai_fn] attribute which will turn be read for the function parameters. + if let Some(rhai_fn_idx) = attrs + .iter() + .position(|a| a.path.get_ident().map(|i| *i == attr_name).unwrap_or(false)) + { + let rhai_fn_attr = attrs.remove(rhai_fn_idx); + rhai_fn_attr.parse_args_with(T::parse_stream) + } else { + Ok(T::no_attrs()) + } +} + +pub(crate) fn deny_cfg_attr(attrs: &Vec) -> syn::Result<()> { + if let Some(cfg_attr) = attrs + .iter() + .find(|a| a.path.get_ident().map(|i| *i == "cfg").unwrap_or(false)) + { + Err(syn::Error::new( + cfg_attr.span(), + "cfg attributes not allowed on this item", + )) + } else { + Ok(()) + } +} diff --git a/codegen/src/function.rs b/codegen/src/function.rs index 89358274..b02a657d 100644 --- a/codegen/src/function.rs +++ b/codegen/src/function.rs @@ -2,17 +2,19 @@ #[cfg(no_std)] use core::mem; +#[cfg(not(no_std))] +use std::mem; #[cfg(no_std)] use alloc::format; #[cfg(not(no_std))] use std::format; -use std::collections::HashMap; - use quote::{quote, quote_spanned}; use syn::{parse::Parse, parse::ParseStream, parse::Parser, spanned::Spanned}; +use crate::attrs::{ExportInfo, ExportedParams}; + #[derive(Debug, Default)] pub(crate) struct ExportedFnParams { pub name: Option, @@ -21,14 +23,6 @@ pub(crate) struct ExportedFnParams { pub span: Option, } -impl ExportedFnParams { - pub fn skip() -> ExportedFnParams { - let mut skip = ExportedFnParams::default(); - skip.skip = true; - skip - } -} - pub const FN_IDX_GET: &str = "index$get$"; pub const FN_IDX_SET: &str = "index$set$"; @@ -45,58 +39,44 @@ impl Parse for ExportedFnParams { return Ok(ExportedFnParams::default()); } - let arg_list = args.call( - syn::punctuated::Punctuated::::parse_separated_nonempty, - )?; - let span = arg_list.span(); + let info = crate::attrs::parse_attr_items(args)?; + Self::from_info(info) + } +} - let mut attrs: HashMap> = HashMap::new(); - for arg in arg_list { - let (left, right) = match arg { - syn::Expr::Assign(syn::ExprAssign { - ref left, - ref right, - .. - }) => { - let attr_name: syn::Ident = match left.as_ref() { - syn::Expr::Path(syn::ExprPath { - path: attr_path, .. - }) => attr_path.get_ident().cloned().ok_or_else(|| { - syn::Error::new(attr_path.span(), "expecting attribute name") - })?, - x => return Err(syn::Error::new(x.span(), "expecting attribute name")), - }; - let attr_value = match right.as_ref() { - syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(string), - .. - }) => string.clone(), - x => return Err(syn::Error::new(x.span(), "expecting string literal")), - }; - (attr_name, Some(attr_value)) - } - syn::Expr::Path(syn::ExprPath { - path: attr_path, .. - }) => attr_path - .get_ident() - .cloned() - .map(|a| (a, None)) - .ok_or_else(|| syn::Error::new(attr_path.span(), "expecting attribute name"))?, - x => return Err(syn::Error::new(x.span(), "expecting identifier")), - }; - attrs.insert(left, right); - } +impl ExportedParams for ExportedFnParams { + fn parse_stream(args: ParseStream) -> syn::Result { + Self::parse(args) + } + fn no_attrs() -> Self { + Default::default() + } + + fn from_info( + info: crate::attrs::ExportInfo, + ) -> syn::Result { + let ExportInfo { item_span: span, items: attrs } = info; let mut name = None; let mut return_raw = false; let mut skip = false; - for (ident, value) in attrs.drain() { - match (ident.to_string().as_ref(), value) { - ("name", Some(s)) => name = Some(s.value()), + for attr in attrs { + let crate::attrs::AttrItem { key, value } = attr; + match (key.to_string().as_ref(), value) { + ("name", Some(s)) => { + // check validity of name + if s.value().contains('.') { + return Err(syn::Error::new( + s.span(), + "Rhai function names may not contain dot", + )); + } + name = Some(s.value()) + } ("get", Some(s)) => name = Some(make_getter(&s.value())), ("set", Some(s)) => name = Some(make_setter(&s.value())), ("get", None) | ("set", None) | ("name", None) => { - return Err(syn::Error::new(ident.span(), "requires value")) + return Err(syn::Error::new(key.span(), "requires value")) } ("index_get", None) => name = Some(FN_IDX_GET.to_string()), ("index_set", None) => name = Some(FN_IDX_SET.to_string()), @@ -108,21 +88,13 @@ impl Parse for ExportedFnParams { ("skip", Some(s)) => return Err(syn::Error::new(s.span(), "extraneous value")), (attr, _) => { return Err(syn::Error::new( - ident.span(), + key.span(), format!("unknown attribute '{}'", attr), )) } } } - // Check validity of name, if present. - if name.as_ref().filter(|n| n.contains('.')).is_some() { - return Err(syn::Error::new( - span, - "Rhai function names may not contain dot" - )) - } - Ok(ExportedFnParams { name, return_raw, @@ -149,20 +121,10 @@ impl Parse for ExportedFn { let str_type_path = syn::parse2::(quote! { str }).unwrap(); // #[cfg] attributes are not allowed on functions due to what is generated for them - if let Some(cfg_attr) = fn_all.attrs.iter().find(|a| { - a.path - .get_ident() - .map(|i| i.to_string() == "cfg") - .unwrap_or(false) - }) { - return Err(syn::Error::new(cfg_attr.span(), "cfg attributes not allowed on this item")); - } + crate::attrs::deny_cfg_attr(&fn_all.attrs)?; // Determine if the function is public. - let is_public = match fn_all.vis { - syn::Visibility::Public(_) => true, - _ => false, - }; + let is_public = matches!(fn_all.vis, syn::Visibility::Public(_)); // Determine whether function generates a special calling convention for a mutable // reciever. let mut_receiver = { @@ -214,10 +176,7 @@ impl Parse for ExportedFn { mutability: None, ref elem, .. - }) => match elem.as_ref() { - &syn::Type::Path(ref p) if p.path == str_type_path => true, - _ => false, - }, + }) => matches!(elem.as_ref(), &syn::Type::Path(ref p) if p.path == str_type_path), &syn::Type::Verbatim(_) => false, _ => true, }; @@ -291,12 +250,22 @@ impl ExportedFn { } } - pub fn generate_with_params( - mut self, - mut params: ExportedFnParams, - ) -> proc_macro2::TokenStream { + pub fn set_params(&mut self, mut params: ExportedFnParams) -> syn::Result<()> { + // Do not allow non-returning raw functions. + // + // This is caught now to avoid issues with diagnostics later. + if params.return_raw + && mem::discriminant(&self.signature.output) + == mem::discriminant(&syn::ReturnType::Default) + { + return Err(syn::Error::new( + self.signature.span(), + "return_raw functions must return Result", + )); + } + self.params = params; - self.generate() + Ok(()) } pub fn generate(self) -> proc_macro2::TokenStream { @@ -353,7 +322,7 @@ impl ExportedFn { } } } else { - quote! { + quote_spanned! { self.return_type().unwrap().span()=> type EvalBox = Box; pub #dynamic_signature { super::#name(#(#arguments),*) @@ -520,7 +489,7 @@ impl ExportedFn { Ok(Dynamic::from(#sig_name(#(#unpack_exprs),*))) } } else { - quote! { + quote_spanned! { self.return_type().unwrap().span()=> #sig_name(#(#unpack_exprs),*) } }; @@ -607,7 +576,7 @@ mod function_tests { &syn::parse2::(quote! { x: usize }).unwrap() ); assert_eq!( - item_fn.arg_list().skip(1).next().unwrap(), + item_fn.arg_list().nth(1).unwrap(), &syn::parse2::(quote! { y: f32 }).unwrap() ); } @@ -728,7 +697,7 @@ mod function_tests { &syn::parse2::(quote! { level: usize }).unwrap() ); assert_eq!( - item_fn.arg_list().skip(1).next().unwrap(), + item_fn.arg_list().nth(1).unwrap(), &syn::parse2::(quote! { message: &str }).unwrap() ); } diff --git a/codegen/src/lib.rs b/codegen/src/lib.rs index 9675ac66..7054cbe5 100644 --- a/codegen/src/lib.rs +++ b/codegen/src/lib.rs @@ -96,6 +96,7 @@ use quote::quote; use syn::parse_macro_input; +mod attrs; mod function; mod module; mod register; @@ -108,10 +109,16 @@ pub fn export_fn( ) -> proc_macro::TokenStream { let mut output = proc_macro2::TokenStream::from(input.clone()); - let parsed_params = parse_macro_input!(args as function::ExportedFnParams); - let function_def = parse_macro_input!(input as function::ExportedFn); + let parsed_params = match crate::attrs::outer_item_attributes(args.into(), "export_fn") { + Ok(args) => args, + Err(err) => return proc_macro::TokenStream::from(err.to_compile_error()), + }; + let mut function_def = parse_macro_input!(input as function::ExportedFn); + if let Err(e) = function_def.set_params(parsed_params) { + return e.to_compile_error().into(); + } - output.extend(function_def.generate_with_params(parsed_params)); + output.extend(function_def.generate()); proc_macro::TokenStream::from(output) } diff --git a/codegen/src/module.rs b/codegen/src/module.rs index 34b67876..535ea9d7 100644 --- a/codegen/src/module.rs +++ b/codegen/src/module.rs @@ -1,7 +1,7 @@ use quote::{quote, ToTokens}; -use syn::{parse::Parse, parse::ParseStream, spanned::Spanned}; +use syn::{parse::Parse, parse::ParseStream}; -use crate::function::{ExportedFn, ExportedFnParams}; +use crate::function::ExportedFn; use crate::rhai_module::ExportedConst; #[cfg(no_std)] @@ -15,92 +15,9 @@ use core::mem; use std::mem; use std::borrow::Cow; -use std::collections::HashMap; -fn inner_fn_attributes(f: &mut syn::ItemFn) -> syn::Result { - // #[cfg] attributes are not allowed on objects - if let Some(cfg_attr) = f.attrs.iter().find(|a| { - a.path - .get_ident() - .map(|i| i.to_string() == "cfg") - .unwrap_or(false) - }) { - return Err(syn::Error::new(cfg_attr.span(), "cfg attributes not allowed on this item")); - } - - // Find the #[rhai_fn] attribute which will turn be read for the function parameters. - if let Some(rhai_fn_idx) = f.attrs.iter().position(|a| { - a.path - .get_ident() - .map(|i| i.to_string() == "rhai_fn") - .unwrap_or(false) - }) { - let rhai_fn_attr = f.attrs.remove(rhai_fn_idx); - rhai_fn_attr.parse_args() - } else if let syn::Visibility::Public(_) = f.vis { - Ok(ExportedFnParams::default()) - } else { - Ok(ExportedFnParams::skip()) - } -} - -fn check_rename_collisions(fns: &Vec) -> Result<(), syn::Error> { - let mut renames = HashMap::::new(); - let mut names = HashMap::::new(); - for itemfn in fns.iter() { - if let Some(ref name) = itemfn.params.name { - let current_span = itemfn.params.span.as_ref().unwrap(); - let key = itemfn.arg_list().fold(name.clone(), |mut argstr, fnarg| { - let type_string: String = match fnarg { - syn::FnArg::Receiver(_) => unimplemented!("receiver rhai_fns not implemented"), - syn::FnArg::Typed(syn::PatType { ref ty, .. }) => - ty.as_ref().to_token_stream().to_string(), - }; - argstr.push('.'); - argstr.extend(type_string.chars()); - argstr - }); - if let Some(other_span) = renames.insert(key, - current_span.clone()) { - let mut err = syn::Error::new(current_span.clone(), - format!("duplicate Rhai signature for '{}'", &name)); - err.combine(syn::Error::new(other_span, - format!("duplicated function renamed '{}'", &name))); - return Err(err); - } - } else { - let ident = itemfn.name(); - names.insert(ident.to_string(), ident.span()); - } - } - for (new_name, attr_span) in renames.drain() { - let new_name = new_name.split('.').next().unwrap(); - if let Some(fn_span) = names.get(new_name) { - let mut err = syn::Error::new(attr_span, - format!("duplicate Rhai signature for '{}'", &new_name)); - err.combine(syn::Error::new(fn_span.clone(), - format!("duplicated function '{}'", &new_name))); - return Err(err); - } - } - Ok(()) -} - -fn inner_mod_attributes(f: &mut syn::ItemMod) -> syn::Result { - if let Some(rhai_mod_idx) = f.attrs.iter().position(|a| { - a.path - .get_ident() - .map(|i| i.to_string() == "rhai_mod") - .unwrap_or(false) - }) { - let rhai_mod_attr = f.attrs.remove(rhai_mod_idx); - rhai_mod_attr.parse_args() - } else if let syn::Visibility::Public(_) = f.vis { - Ok(ExportedModParams::default()) - } else { - Ok(ExportedModParams::skip()) - } -} +use crate::attrs::{AttrItem, ExportInfo, ExportedParams}; +use crate::function::{ExportedFnParams}; #[derive(Debug, Default)] pub(crate) struct ExportedModParams { @@ -108,81 +25,52 @@ pub(crate) struct ExportedModParams { pub skip: bool, } -impl ExportedModParams { - pub fn skip() -> ExportedModParams { - let mut skip = ExportedModParams::default(); - skip.skip = true; - skip - } -} - impl Parse for ExportedModParams { fn parse(args: ParseStream) -> syn::Result { if args.is_empty() { return Ok(ExportedModParams::default()); } - let arg_list = args.call( - syn::punctuated::Punctuated::::parse_separated_nonempty, - )?; + let info = crate::attrs::parse_attr_items(args)?; - let mut attrs: HashMap> = HashMap::new(); - for arg in arg_list { - let (left, right) = match arg { - syn::Expr::Assign(syn::ExprAssign { - ref left, - ref right, - .. - }) => { - let attr_name: syn::Ident = match left.as_ref() { - syn::Expr::Path(syn::ExprPath { - path: attr_path, .. - }) => attr_path.get_ident().cloned().ok_or_else(|| { - syn::Error::new(attr_path.span(), "expecting attribute name") - })?, - x => return Err(syn::Error::new(x.span(), "expecting attribute name")), - }; - let attr_value = match right.as_ref() { - syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(string), - .. - }) => string.clone(), - x => return Err(syn::Error::new(x.span(), "expecting string literal")), - }; - (attr_name, Some(attr_value)) - } - syn::Expr::Path(syn::ExprPath { - path: attr_path, .. - }) => attr_path - .get_ident() - .cloned() - .map(|a| (a, None)) - .ok_or_else(|| syn::Error::new(attr_path.span(), "expecting attribute name"))?, - x => return Err(syn::Error::new(x.span(), "expecting identifier")), - }; - attrs.insert(left, right); - } + Self::from_info(info) + } +} +impl ExportedParams for ExportedModParams { + fn parse_stream(args: ParseStream) -> syn::Result { + Self::parse(args) + } + + fn no_attrs() -> Self { + Default::default() + } + + fn from_info(info: ExportInfo) -> syn::Result { + let ExportInfo { items: attrs, .. } = info; let mut name = None; let mut skip = false; - for (ident, value) in attrs.drain() { - match (ident.to_string().as_ref(), value) { + for attr in attrs { + let AttrItem { key, value } = attr; + match (key.to_string().as_ref(), value) { ("name", Some(s)) => name = Some(s.value()), - ("name", None) => return Err(syn::Error::new(ident.span(), "requires value")), + ("name", None) => return Err(syn::Error::new(key.span(), "requires value")), ("skip", None) => skip = true, - ("skip", Some(s)) => { - return Err(syn::Error::new(s.span(), "extraneous value")) - } + ("skip", Some(s)) => return Err(syn::Error::new(s.span(), "extraneous value")), (attr, _) => { return Err(syn::Error::new( - ident.span(), + key.span(), format!("unknown attribute '{}'", attr), )) } } } - Ok(ExportedModParams { name, skip, ..Default::default() }) + Ok(ExportedModParams { + name, + skip, + ..Default::default() + }) } } @@ -209,17 +97,26 @@ impl Parse for Module { syn::Item::Fn(f) => Some(f), _ => None, }) - .try_fold(Vec::new(), |mut vec, mut itemfn| { - let params = match inner_fn_attributes(&mut itemfn) { - Ok(p) => p, - Err(e) => return Err(e), + .try_fold(Vec::new(), |mut vec, itemfn| { + // #[cfg] attributes are not allowed on functions + crate::attrs::deny_cfg_attr(&itemfn.attrs)?; + + let mut params: ExportedFnParams = + match crate::attrs::inner_item_attributes(&mut itemfn.attrs, "rhai_fn") { + Ok(p) => p, + Err(e) => return Err(e), + }; + params.skip = if let syn::Visibility::Public(_) = itemfn.vis { + params.skip + } else { + true }; syn::parse2::(itemfn.to_token_stream()) .map(|mut f| { f.params = params; f }) - .map(|f| if !f.params.skip { vec.push(f) }) + .map(|f| vec.push(f)) .map(|_| vec) })?; // Gather and parse constants definitions. @@ -233,23 +130,14 @@ impl Parse for Module { .. }) => { // #[cfg] attributes are not allowed on const declarations - if let Some(cfg_attr) = attrs.iter().find(|a| { - a.path - .get_ident() - .map(|i| i.to_string() == "cfg") - .unwrap_or(false) - }) { - return Err(syn::Error::new( - cfg_attr.span(), - "cfg attributes not allowed on this item")); - } + crate::attrs::deny_cfg_attr(&attrs)?; if let syn::Visibility::Public(_) = vis { consts.push((ident.to_string(), expr.as_ref().clone())); } - }, - _ => {}, + } + _ => {} } - }; + } // Gather and parse submodule definitions. // // They are actually removed from the module's body, because they will need @@ -257,23 +145,27 @@ impl Parse for Module { submodules.reserve(content.len() - fns.len() - consts.len()); let mut i = 0; while i < content.len() { - if let syn::Item::Mod(_) = &content[i] { + if let syn::Item::Mod(_) = &content[i] { let mut itemmod = match content.remove(i) { syn::Item::Mod(m) => m, _ => unreachable!(), }; - let params = match inner_mod_attributes(&mut itemmod) { - Ok(p) => p, - Err(e) => return Err(e), + let mut params: ExportedModParams = + match crate::attrs::inner_item_attributes(&mut itemmod.attrs, "rhai_mod") { + Ok(p) => p, + Err(e) => return Err(e), + }; + params.skip = if let syn::Visibility::Public(_) = itemmod.vis { + params.skip + } else { + true }; - let module = syn::parse2::(itemmod.to_token_stream()) - .map(|mut f| { + let module = + syn::parse2::(itemmod.to_token_stream()).map(|mut f| { f.params = params; f })?; - if !module.params.skip { - submodules.push(module); - } + submodules.push(module); } else { i += 1; } @@ -308,6 +200,10 @@ impl Module { } } + pub fn skipped(&self) -> bool { + self.params.skip + } + pub fn generate(self) -> proc_macro2::TokenStream { match self.generate_inner() { Ok(tokens) => tokens, @@ -315,39 +211,55 @@ impl Module { } } - fn generate_inner(mut self) -> Result { + fn generate_inner(self) -> Result { // Check for collisions if the "name" attribute was used on inner functions. - check_rename_collisions(&self.fns)?; + crate::rhai_module::check_rename_collisions(&self.fns)?; - // Generate new module items. - // - // This is done before inner module recursive generation, because that is destructive. - let mod_gen = crate::rhai_module::generate_body(&self.fns, &self.consts, &self.submodules); - - // NB: submodules must have their new items for exporting generated in depth-first order to - // avoid issues with reparsing them. - let inner_modules: Vec = self.submodules.drain(..) - .try_fold::, _, - Result, syn::Error>>( - Vec::new(), |mut acc, m| { acc.push(m.generate_inner()?); Ok(acc) })?; - - // Generate new module items for exporting functions and constant. - - // Rebuild the structure of the module, with the new content added. - let Module { mod_all, .. } = self; + // Extract the current structure of the module. + let Module { + mod_all, + fns, + consts, + mut submodules, + params, + .. + } = self; let mut mod_all = mod_all.unwrap(); let mod_name = mod_all.ident.clone(); let (_, orig_content) = mod_all.content.take().unwrap(); let mod_attrs = mem::replace(&mut mod_all.attrs, Vec::with_capacity(0)); - Ok(quote! { - #(#mod_attrs)* - pub mod #mod_name { - #(#orig_content)* - #(#inner_modules)* - #mod_gen - } - }) + if !params.skip { + // Generate new module items. + // + // This is done before inner module recursive generation, because that is destructive. + let mod_gen = crate::rhai_module::generate_body(&fns, &consts, &submodules); + + // NB: submodules must have their new items for exporting generated in depth-first order + // to avoid issues caused by re-parsing them + let inner_modules: Vec = submodules.drain(..) + .try_fold::, _, + Result, syn::Error>>( + Vec::new(), |mut acc, m| { acc.push(m.generate_inner()?); Ok(acc) })?; + + // Regenerate the module with the new content added. + Ok(quote! { + #(#mod_attrs)* + pub mod #mod_name { + #(#orig_content)* + #(#inner_modules)* + #mod_gen + } + }) + } else { + // Regenerate the original module as-is. + Ok(quote! { + #(#mod_attrs)* + pub mod #mod_name { + #(#orig_content)* + } + }) + } } pub fn name(&self) -> Option<&syn::Ident> { @@ -356,7 +268,10 @@ impl Module { pub fn content(&self) -> Option<&Vec> { match self.mod_all { - Some(syn::ItemMod { content: Some((_, ref vec)), .. }) => Some(vec), + Some(syn::ItemMod { + content: Some((_, ref vec)), + .. + }) => Some(vec), _ => None, } } @@ -495,7 +410,8 @@ mod module_tests { assert!(item_mod.fns.is_empty()); assert!(item_mod.consts.is_empty()); assert_eq!(item_mod.submodules.len(), 1); - assert!(item_mod.submodules[0].fns.is_empty()); + assert_eq!(item_mod.submodules[0].fns.len(), 1); + assert!(item_mod.submodules[0].fns[0].params.skip); assert!(item_mod.submodules[0].consts.is_empty()); assert!(item_mod.submodules[0].submodules.is_empty()); } @@ -516,7 +432,8 @@ mod module_tests { let item_mod = syn::parse2::(input_tokens).unwrap(); assert!(item_mod.fns.is_empty()); assert!(item_mod.consts.is_empty()); - assert!(item_mod.submodules.is_empty()); + assert_eq!(item_mod.submodules.len(), 1); + assert!(item_mod.submodules[0].params.skip); } #[test] @@ -548,7 +465,8 @@ mod module_tests { }; let item_mod = syn::parse2::(input_tokens).unwrap(); - assert!(item_mod.fns.is_empty()); + assert_eq!(item_mod.fns.len(), 1); + assert!(item_mod.fns[0].params.skip); assert!(item_mod.consts.is_empty()); } @@ -564,7 +482,8 @@ mod module_tests { }; let item_mod = syn::parse2::(input_tokens).unwrap(); - assert!(item_mod.fns.is_empty()); + assert_eq!(item_mod.fns.len(), 1); + assert!(item_mod.fns[0].params.skip); assert!(item_mod.consts.is_empty()); } @@ -599,7 +518,7 @@ mod generate_tests { .zip(expected.chars()) .inspect(|_| counter += 1) .skip_while(|(a, e)| *a == *e); - let (actual_diff, expected_diff) = { + let (_actual_diff, _expected_diff) = { let mut actual_diff = String::new(); let mut expected_diff = String::new(); for (a, e) in iter.take(50) { @@ -750,6 +669,109 @@ mod generate_tests { assert_streams_eq(item_mod.generate(), expected_tokens); } + #[test] + fn two_fn_overload_module() { + let input_tokens: TokenStream = quote! { + pub mod two_fns { + #[rhai_fn(name = "add_n")] + pub fn add_one_to(x: INT) -> INT { + x + 1 + } + + #[rhai_fn(name = "add_n")] + pub fn add_n_to(x: INT, y: INT) -> INT { + x + y + } + } + }; + + let expected_tokens = quote! { + pub mod two_fns { + pub fn add_one_to(x: INT) -> INT { + x + 1 + } + + pub fn add_n_to(x: INT, y: INT) -> INT { + x + y + } + + #[allow(unused_imports)] + use super::*; + #[allow(unused_mut)] + pub fn rhai_module_generate() -> Module { + let mut m = Module::new(); + m.set_fn("add_n", FnAccess::Public, &[core::any::TypeId::of::()], + CallableFunction::from_plugin(add_one_to_token())); + m.set_fn("add_n", FnAccess::Public, &[core::any::TypeId::of::(), + core::any::TypeId::of::()], + CallableFunction::from_plugin(add_n_to_token())); + m + } + + #[allow(non_camel_case_types)] + struct add_one_to_token(); + impl PluginFunction for add_one_to_token { + fn call(&self, + args: &mut [&mut Dynamic], pos: Position + ) -> Result> { + debug_assert_eq!(args.len(), 1usize, + "wrong arg count: {} != {}", args.len(), 1usize); + let arg0 = mem::take(args[0usize]).clone().cast::(); + Ok(Dynamic::from(add_one_to(arg0))) + } + + fn is_method_call(&self) -> bool { false } + fn is_varadic(&self) -> bool { false } + fn clone_boxed(&self) -> Box { + Box::new(add_one_to_token()) + } + fn input_types(&self) -> Box<[TypeId]> { + new_vec![TypeId::of::()].into_boxed_slice() + } + } + pub fn add_one_to_token_callable() -> CallableFunction { + CallableFunction::from_plugin(add_one_to_token()) + } + pub fn add_one_to_token_input_types() -> Box<[TypeId]> { + add_one_to_token().input_types() + } + + #[allow(non_camel_case_types)] + struct add_n_to_token(); + impl PluginFunction for add_n_to_token { + fn call(&self, + args: &mut [&mut Dynamic], pos: Position + ) -> Result> { + debug_assert_eq!(args.len(), 2usize, + "wrong arg count: {} != {}", args.len(), 2usize); + let arg0 = mem::take(args[0usize]).clone().cast::(); + let arg1 = mem::take(args[1usize]).clone().cast::(); + Ok(Dynamic::from(add_n_to(arg0, arg1))) + } + + fn is_method_call(&self) -> bool { false } + fn is_varadic(&self) -> bool { false } + fn clone_boxed(&self) -> Box { + Box::new(add_n_to_token()) + } + fn input_types(&self) -> Box<[TypeId]> { + new_vec![TypeId::of::(), + TypeId::of::()].into_boxed_slice() + } + } + pub fn add_n_to_token_callable() -> CallableFunction { + CallableFunction::from_plugin(add_n_to_token()) + } + pub fn add_n_to_token_input_types() -> Box<[TypeId]> { + add_n_to_token().input_types() + } + } + }; + + let item_mod = syn::parse2::(input_tokens).unwrap(); + assert_streams_eq(item_mod.generate(), expected_tokens); + } + #[test] fn one_double_arg_fn_module() { let input_tokens: TokenStream = quote! { @@ -924,6 +946,70 @@ mod generate_tests { assert_streams_eq(item_mod.generate(), expected_tokens); } + #[test] + fn one_skipped_submodule() { + let input_tokens: TokenStream = quote! { + pub mod one_fn { + pub fn get_mystic_number() -> INT { + 42 + } + #[rhai_mod(skip)] + pub mod inner_secrets { + pub const SECRET_NUMBER: INT = 86; + } + } + }; + + let expected_tokens = quote! { + pub mod one_fn { + pub fn get_mystic_number() -> INT { + 42 + } + pub mod inner_secrets { + pub const SECRET_NUMBER: INT = 86; + } + #[allow(unused_imports)] + use super::*; + #[allow(unused_mut)] + pub fn rhai_module_generate() -> Module { + let mut m = Module::new(); + m.set_fn("get_mystic_number", FnAccess::Public, &[], + CallableFunction::from_plugin(get_mystic_number_token())); + m + } + #[allow(non_camel_case_types)] + struct get_mystic_number_token(); + impl PluginFunction for get_mystic_number_token { + fn call(&self, + args: &mut [&mut Dynamic], pos: Position + ) -> Result> { + debug_assert_eq!(args.len(), 0usize, + "wrong arg count: {} != {}", args.len(), 0usize); + Ok(Dynamic::from(get_mystic_number())) + } + + fn is_method_call(&self) -> bool { false } + fn is_varadic(&self) -> bool { false } + fn clone_boxed(&self) -> Box { + Box::new(get_mystic_number_token()) + } + fn input_types(&self) -> Box<[TypeId]> { + new_vec![].into_boxed_slice() + } + } + pub fn get_mystic_number_token_callable() -> CallableFunction { + CallableFunction::from_plugin(get_mystic_number_token()) + } + pub fn get_mystic_number_token_input_types() -> Box<[TypeId]> { + get_mystic_number_token().input_types() + } + } + }; + + let item_mod = syn::parse2::(input_tokens).unwrap(); + assert_streams_eq(item_mod.generate(), expected_tokens); + } + #[test] fn one_private_constant_module() { let input_tokens: TokenStream = quote! { diff --git a/codegen/src/rhai_module.rs b/codegen/src/rhai_module.rs index c327c71f..74144081 100644 --- a/codegen/src/rhai_module.rs +++ b/codegen/src/rhai_module.rs @@ -1,4 +1,6 @@ -use quote::quote; +use std::collections::HashMap; + +use quote::{quote, ToTokens}; use crate::function::ExportedFn; use crate::module::Module; @@ -6,9 +8,9 @@ use crate::module::Module; pub(crate) type ExportedConst = (String, syn::Expr); pub(crate) fn generate_body( - fns: &Vec, - consts: &Vec, - submodules: &Vec, + fns: &[ExportedFn], + consts: &[ExportedConst], + submodules: &[Module], ) -> proc_macro2::TokenStream { let mut set_fn_stmts: Vec = Vec::new(); let mut set_const_stmts: Vec = Vec::new(); @@ -26,17 +28,21 @@ pub(crate) fn generate_body( } for itemmod in submodules { + if itemmod.skipped() { + continue; + } let module_name: &syn::Ident = itemmod.module_name().unwrap(); let exported_name: syn::LitStr = if let Some(name) = itemmod.exported_name() { syn::LitStr::new(&name, proc_macro2::Span::call_site()) } else { syn::LitStr::new(&module_name.to_string(), proc_macro2::Span::call_site()) }; - let cfg_attrs: Vec<&syn::Attribute> = itemmod.attrs().unwrap().iter().filter(|&a| { - a.path.get_ident() - .map(|i| i.to_string() == "cfg") - .unwrap_or(false) - }).collect(); + let cfg_attrs: Vec<&syn::Attribute> = itemmod + .attrs() + .unwrap() + .iter() + .filter(|&a| a.path.get_ident().map(|i| *i == "cfg").unwrap_or(false)) + .collect(); add_mod_blocks.push( syn::parse2::(quote! { #(#cfg_attrs)* { @@ -47,10 +53,12 @@ pub(crate) fn generate_body( ); } - // NB: these are token streams, because reparsing messes up "> >" vs ">>" let mut gen_fn_tokens: Vec = Vec::new(); for function in fns { + if function.params.skip { + continue; + } let fn_token_name = syn::Ident::new( &format!("{}_token", function.name().to_string()), function.name().span(), @@ -67,24 +75,24 @@ pub(crate) fn generate_body( syn::FnArg::Receiver(_) => panic!("internal error: receiver fn outside impl!?"), syn::FnArg::Typed(syn::PatType { ref ty, .. }) => { let arg_type = match ty.as_ref() { - &syn::Type::Reference(syn::TypeReference { + syn::Type::Reference(syn::TypeReference { mutability: None, ref elem, .. }) => match elem.as_ref() { - &syn::Type::Path(ref p) if p.path == str_type_path => { + syn::Type::Path(ref p) if p.path == str_type_path => { syn::parse2::(quote! { ImmutableString }) .unwrap() } _ => panic!("internal error: non-string shared reference!?"), }, - &syn::Type::Reference(syn::TypeReference { + syn::Type::Reference(syn::TypeReference { mutability: Some(_), ref elem, .. }) => match elem.as_ref() { - &syn::Type::Path(ref p) => syn::parse2::(quote! { + syn::Type::Path(ref p) => syn::parse2::(quote! { #p }) .unwrap(), _ => panic!("internal error: non-string shared reference!?"), @@ -138,3 +146,53 @@ pub(crate) fn generate_body( #(#gen_fn_tokens)* } } + +pub(crate) fn check_rename_collisions(fns: &Vec) -> Result<(), syn::Error> { + let mut renames = HashMap::::new(); + let mut names = HashMap::::new(); + for itemfn in fns.iter() { + if let Some(ref name) = itemfn.params.name { + let current_span = itemfn.params.span.as_ref().unwrap(); + let key = itemfn.arg_list().fold(name.clone(), |mut argstr, fnarg| { + let type_string: String = match fnarg { + syn::FnArg::Receiver(_) => unimplemented!("receiver rhai_fns not implemented"), + syn::FnArg::Typed(syn::PatType { ref ty, .. }) => { + ty.as_ref().to_token_stream().to_string() + } + }; + argstr.push('.'); + argstr.push_str(&type_string); + argstr + }); + if let Some(other_span) = renames.insert(key, *current_span) { + let mut err = syn::Error::new( + *current_span, + format!("duplicate Rhai signature for '{}'", &name), + ); + err.combine(syn::Error::new( + other_span, + format!("duplicated function renamed '{}'", &name), + )); + return Err(err); + } + } else { + let ident = itemfn.name(); + names.insert(ident.to_string(), ident.span()); + } + } + for (new_name, attr_span) in renames.drain() { + let new_name = new_name.split('.').next().unwrap(); + if let Some(fn_span) = names.get(new_name) { + let mut err = syn::Error::new( + attr_span, + format!("duplicate Rhai signature for '{}'", &new_name), + ); + err.combine(syn::Error::new( + *fn_span, + format!("duplicated function '{}'", &new_name), + )); + return Err(err); + } + } + Ok(()) +} diff --git a/codegen/ui_tests/export_fn_raw_noreturn.rs b/codegen/ui_tests/export_fn_raw_noreturn.rs new file mode 100644 index 00000000..7c8b42e0 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_noreturn.rs @@ -0,0 +1,25 @@ +use rhai::plugin::*; + +#[derive(Clone)] +struct Point { + x: f32, + y: f32, +} + +#[export_fn(return_raw)] +pub fn test_fn(input: &mut Point) { + input.x += 1.0; +} + +fn main() { + let n = Point { + x: 0.0, + y: 10.0, + }; + test_fn(&mut n); + if n.x >= 10.0 { + println!("yes"); + } else { + println!("no"); + } +} diff --git a/codegen/ui_tests/export_fn_raw_noreturn.stderr b/codegen/ui_tests/export_fn_raw_noreturn.stderr new file mode 100644 index 00000000..0687c8c6 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_noreturn.stderr @@ -0,0 +1,11 @@ +error: return_raw functions must return Result + --> $DIR/export_fn_raw_noreturn.rs:10:5 + | +10 | pub fn test_fn(input: &mut Point) { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error[E0425]: cannot find function `test_fn` in this scope + --> $DIR/export_fn_raw_noreturn.rs:19:5 + | +19 | test_fn(&mut n); + | ^^^^^^^ not found in this scope diff --git a/codegen/ui_tests/export_fn_raw_return.rs b/codegen/ui_tests/export_fn_raw_return.rs new file mode 100644 index 00000000..9df99549 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_return.rs @@ -0,0 +1,24 @@ +use rhai::plugin::*; + +#[derive(Clone)] +struct Point { + x: f32, + y: f32, +} + +#[export_fn(return_raw)] +pub fn test_fn(input: Point) -> bool { + input.x > input.y +} + +fn main() { + let n = Point { + x: 0.0, + y: 10.0, + }; + if test_fn(n) { + println!("yes"); + } else { + println!("no"); + } +} diff --git a/codegen/ui_tests/export_fn_raw_return.stderr b/codegen/ui_tests/export_fn_raw_return.stderr new file mode 100644 index 00000000..f570fda9 --- /dev/null +++ b/codegen/ui_tests/export_fn_raw_return.stderr @@ -0,0 +1,21 @@ +error[E0308]: mismatched types + --> $DIR/export_fn_raw_return.rs:10:8 + | +9 | #[export_fn(return_raw)] + | ------------------------ expected `std::result::Result>` because of return type +10 | pub fn test_fn(input: Point) -> bool { + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected enum `std::result::Result`, found `bool` + | + = note: expected enum `std::result::Result>` + found type `bool` + +error[E0308]: mismatched types + --> $DIR/export_fn_raw_return.rs:10:33 + | +9 | #[export_fn(return_raw)] + | ------------------------ expected `std::result::Result>` because of return type +10 | pub fn test_fn(input: Point) -> bool { + | ^^^^ expected enum `std::result::Result`, found `bool` + | + = note: expected enum `std::result::Result>` + found type `bool` diff --git a/codegen/ui_tests/rhai_fn_rename_dot.stderr b/codegen/ui_tests/rhai_fn_rename_dot.stderr index f650a081..61299e8b 100644 --- a/codegen/ui_tests/rhai_fn_rename_dot.stderr +++ b/codegen/ui_tests/rhai_fn_rename_dot.stderr @@ -1,8 +1,8 @@ error: Rhai function names may not contain dot - --> $DIR/rhai_fn_rename_dot.rs:12:15 + --> $DIR/rhai_fn_rename_dot.rs:12:22 | 12 | #[rhai_fn(name = "foo.bar")] - | ^^^^^^^^^^^^^^^^ + | ^^^^^^^^^ error[E0433]: failed to resolve: use of undeclared type or module `test_module` --> $DIR/rhai_fn_rename_dot.rs:23:8 diff --git a/src/packages/string_more.rs b/src/packages/string_more.rs index fbcb2acd..03704f5a 100644 --- a/src/packages/string_more.rs +++ b/src/packages/string_more.rs @@ -292,7 +292,7 @@ mod string_functions { } #[rhai_fn(name = "crop")] - fn crop_string(s: &mut ImmutableString, start: INT, len: INT) { + pub fn crop_string(s: &mut ImmutableString, start: INT, len: INT) { let offset = if s.is_empty() || len <= 0 { s.make_mut().clear(); return;