diff --git a/codegen/src/attrs.rs b/codegen/src/attrs.rs index 1850d294..5d1e42d7 100644 --- a/codegen/src/attrs.rs +++ b/codegen/src/attrs.rs @@ -1,5 +1,18 @@ use syn::{parse::ParseStream, parse::Parser, spanned::Spanned}; +#[derive(Debug)] +pub enum ExportScope { + PubOnly, + Prefix(String), + All, +} + +impl Default for ExportScope { + fn default() -> ExportScope { + ExportScope::PubOnly + } +} + pub trait ExportedParams: Sized { fn parse_stream(args: ParseStream) -> syn::Result; fn no_attrs() -> Self; diff --git a/codegen/src/function.rs b/codegen/src/function.rs index 337eccab..f5bb6989 100644 --- a/codegen/src/function.rs +++ b/codegen/src/function.rs @@ -10,10 +10,12 @@ use alloc::format; #[cfg(not(no_std))] use std::format; +use std::borrow::Cow; + use quote::{quote, quote_spanned}; use syn::{parse::Parse, parse::ParseStream, parse::Parser, spanned::Spanned}; -use crate::attrs::{ExportInfo, ExportedParams}; +use crate::attrs::{ExportInfo, ExportScope, ExportedParams}; #[derive(Debug, Default)] pub(crate) struct ExportedFnParams { @@ -222,10 +224,24 @@ impl ExportedFn { &self.params } + pub(crate) fn update_scope(&mut self, parent_scope: &ExportScope) { + let keep = match (self.params.skip, parent_scope) { + (true, _) => false, + (_, ExportScope::PubOnly) => self.is_public, + (_, ExportScope::Prefix(s)) => self.exported_name().as_ref().starts_with(s), + (_, ExportScope::All) => true, + }; + self.params.skip = !keep; + } + pub(crate) fn skipped(&self) -> bool { self.params.skip } + pub(crate) fn signature(&self) -> &syn::Signature { + &self.signature + } + pub(crate) fn mutable_receiver(&self) -> bool { self.mut_receiver } @@ -242,6 +258,14 @@ impl ExportedFn { &self.signature.ident } + pub(crate) fn exported_name<'n>(&'n self) -> Cow<'n, str> { + if let Some(ref name) = self.params.name { + Cow::Borrowed(name.as_str()) + } else { + Cow::Owned(self.signature.ident.to_string()) + } + } + pub(crate) fn arg_list(&self) -> impl Iterator { self.signature.inputs.iter() } diff --git a/codegen/src/lib.rs b/codegen/src/lib.rs index 7054cbe5..9e47a831 100644 --- a/codegen/src/lib.rs +++ b/codegen/src/lib.rs @@ -124,10 +124,18 @@ pub fn export_fn( #[proc_macro_attribute] pub fn export_module( - _args: proc_macro::TokenStream, + args: proc_macro::TokenStream, input: proc_macro::TokenStream, ) -> proc_macro::TokenStream { - let module_def = parse_macro_input!(input as module::Module); + let parsed_params = match crate::attrs::outer_item_attributes(args.into(), "export_module") { + Ok(args) => args, + Err(err) => return proc_macro::TokenStream::from(err.to_compile_error()), + }; + let mut module_def = parse_macro_input!(input as module::Module); + if let Err(e) = module_def.set_params(parsed_params) { + return e.to_compile_error().into(); + } + let tokens = module_def.generate(); proc_macro::TokenStream::from(tokens) } diff --git a/codegen/src/module.rs b/codegen/src/module.rs index b1a76f29..79014b8d 100644 --- a/codegen/src/module.rs +++ b/codegen/src/module.rs @@ -16,13 +16,14 @@ use std::mem; use std::borrow::Cow; -use crate::attrs::{AttrItem, ExportInfo, ExportedParams}; -use crate::function::{ExportedFnParams}; +use crate::attrs::{AttrItem, ExportInfo, ExportScope, ExportedParams}; +use crate::function::ExportedFnParams; #[derive(Debug, Default)] pub(crate) struct ExportedModParams { pub name: Option, - pub skip: bool, + skip: bool, + pub scope: ExportScope, } impl Parse for ExportedModParams { @@ -50,6 +51,7 @@ impl ExportedParams for ExportedModParams { let ExportInfo { items: attrs, .. } = info; let mut name = None; let mut skip = false; + let mut scope = ExportScope::default(); for attr in attrs { let AttrItem { key, value } = attr; match (key.to_string().as_ref(), value) { @@ -57,6 +59,14 @@ impl ExportedParams for ExportedModParams { ("name", None) => return Err(syn::Error::new(key.span(), "requires value")), ("skip", None) => skip = true, ("skip", Some(s)) => return Err(syn::Error::new(s.span(), "extraneous value")), + ("export_prefix", Some(s)) => scope = ExportScope::Prefix(s.value()), + ("export_prefix", None) => { + return Err(syn::Error::new(key.span(), "requires value")) + } + ("export_all", None) => scope = ExportScope::All, + ("export_all", Some(s)) => { + return Err(syn::Error::new(s.span(), "extraneous value")) + } (attr, _) => { return Err(syn::Error::new( key.span(), @@ -69,6 +79,7 @@ impl ExportedParams for ExportedModParams { Ok(ExportedModParams { name, skip, + scope, ..Default::default() }) } @@ -83,6 +94,13 @@ pub(crate) struct Module { params: ExportedModParams, } +impl Module { + pub fn set_params(&mut self, params: ExportedModParams) -> syn::Result<()> { + self.params = params; + Ok(()) + } +} + impl Parse for Module { fn parse(input: ParseStream) -> syn::Result { let mut mod_all: syn::ItemMod = input.parse()?; @@ -101,16 +119,11 @@ impl Parse for Module { // #[cfg] attributes are not allowed on functions crate::attrs::deny_cfg_attr(&itemfn.attrs)?; - let mut params: ExportedFnParams = + let params: ExportedFnParams = match crate::attrs::inner_item_attributes(&mut itemfn.attrs, "rhai_fn") { Ok(p) => p, Err(e) => return Err(e), }; - params.skip = if let syn::Visibility::Public(_) = itemfn.vis { - params.skip - } else { - true - }; syn::parse2::(itemfn.to_token_stream()) .and_then(|mut f| { f.set_params(params)?; @@ -150,20 +163,15 @@ impl Parse for Module { syn::Item::Mod(m) => m, _ => unreachable!(), }; - let mut params: ExportedModParams = + let params: ExportedModParams = match crate::attrs::inner_item_attributes(&mut itemmod.attrs, "rhai_mod") { Ok(p) => p, Err(e) => return Err(e), }; - params.skip = if let syn::Visibility::Public(_) = itemmod.vis { - params.skip - } else { - true - }; let module = - syn::parse2::(itemmod.to_token_stream()).map(|mut f| { - f.params = params; - f + syn::parse2::(itemmod.to_token_stream()).and_then(|mut m| { + m.set_params(params)?; + Ok(m) })?; submodules.push(module); } else { @@ -200,6 +208,28 @@ impl Module { } } + pub fn update_scope(&mut self, parent_scope: &ExportScope) { + let keep = match (self.params.skip, parent_scope) { + (true, _) => false, + (_, ExportScope::PubOnly) => { + if let Some(ref mod_all) = self.mod_all { + matches!(mod_all.vis, syn::Visibility::Public(_)) + } else { + false + } + } + (_, ExportScope::Prefix(s)) => { + if let Some(ref mod_all) = self.mod_all { + mod_all.ident.to_string().starts_with(s) + } else { + false + } + } + (_, ExportScope::All) => true, + }; + self.params.skip = !keep; + } + pub fn skipped(&self) -> bool { self.params.skip } @@ -218,7 +248,7 @@ impl Module { // Extract the current structure of the module. let Module { mod_all, - fns, + mut fns, consts, mut submodules, params, @@ -233,7 +263,12 @@ impl Module { // Generate new module items. // // This is done before inner module recursive generation, because that is destructive. - let mod_gen = crate::rhai_module::generate_body(&fns, &consts, &submodules); + let mod_gen = crate::rhai_module::generate_body( + &mut fns, + &consts, + &mut submodules, + ¶ms.scope, + ); // NB: submodules must have their new items for exporting generated in depth-first order // to avoid issues caused by re-parsing them @@ -454,22 +489,6 @@ mod module_tests { ); } - #[test] - fn one_private_fn_module() { - let input_tokens: TokenStream = quote! { - pub mod one_fn { - fn get_mystic_number() -> INT { - 42 - } - } - }; - - let item_mod = syn::parse2::(input_tokens).unwrap(); - assert_eq!(item_mod.fns.len(), 1); - assert!(item_mod.fns[0].skipped()); - assert!(item_mod.consts.is_empty()); - } - #[test] fn one_skipped_fn_module() { let input_tokens: TokenStream = quote! { diff --git a/codegen/src/rhai_module.rs b/codegen/src/rhai_module.rs index 7e1aa106..40485adb 100644 --- a/codegen/src/rhai_module.rs +++ b/codegen/src/rhai_module.rs @@ -2,15 +2,17 @@ use std::collections::HashMap; use quote::{quote, ToTokens}; +use crate::attrs::ExportScope; use crate::function::ExportedFn; use crate::module::Module; pub(crate) type ExportedConst = (String, syn::Expr); pub(crate) fn generate_body( - fns: &[ExportedFn], + fns: &mut [ExportedFn], consts: &[ExportedConst], - submodules: &[Module], + submodules: &mut [Module], + parent_scope: &ExportScope, ) -> proc_macro2::TokenStream { let mut set_fn_stmts: Vec = Vec::new(); let mut set_const_stmts: Vec = Vec::new(); @@ -28,6 +30,7 @@ pub(crate) fn generate_body( } for itemmod in submodules { + itemmod.update_scope(&parent_scope); if itemmod.skipped() { continue; } @@ -56,6 +59,7 @@ pub(crate) fn generate_body( // NB: these are token streams, because reparsing messes up "> >" vs ">>" let mut gen_fn_tokens: Vec = Vec::new(); for function in fns { + function.update_scope(&parent_scope); if function.skipped() { continue; } diff --git a/codegen/tests/test_modules.rs b/codegen/tests/test_modules.rs index dd5acd05..3ce1bf37 100644 --- a/codegen/tests/test_modules.rs +++ b/codegen/tests/test_modules.rs @@ -221,3 +221,168 @@ fn duplicate_fn_rename_test() -> Result<(), Box> { assert_eq!(&output_array[1].as_int().unwrap(), &43); Ok(()) } + +mod export_by_prefix { + use rhai::plugin::*; + #[export_module(export_prefix = "foo_")] + pub mod my_adds { + use rhai::{FLOAT, INT}; + + #[rhai_fn(name = "foo_add_f")] + pub fn add_float(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + + #[rhai_fn(name = "foo_add_i")] + fn add_int(i1: INT, i2: INT) -> INT { + i1 + i2 + } + + #[rhai_fn(name = "bar_add")] + pub fn add_float2(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + + pub fn foo_m(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + + fn foo_n(i1: INT, i2: INT) -> INT { + i1 + i2 + } + + pub fn bar_m(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + } +} + +#[test] +fn export_by_prefix_test() -> Result<(), Box> { + let mut engine = Engine::new(); + let m = rhai::exported_module!(crate::export_by_prefix::my_adds); + let mut r = StaticModuleResolver::new(); + r.insert("Math::Advanced".to_string(), m); + engine.set_module_resolver(Some(r)); + + let output_array = engine.eval::( + r#"import "Math::Advanced" as math; + let ex = 41.0; + let fx = math::foo_add_f(ex, 1.0); + let gx = math::foo_m(41.0, 1.0); + let ei = 41; + let fi = math::foo_add_i(ei, 1); + let gi = math::foo_n(41, 1); + [fx, gx, fi, gi] + "#, + )?; + assert_eq!(&output_array[0].as_float().unwrap(), &42.0); + assert_eq!(&output_array[1].as_float().unwrap(), &42.0); + assert_eq!(&output_array[2].as_int().unwrap(), &42); + assert_eq!(&output_array[3].as_int().unwrap(), &42); + + assert!(matches!(*engine.eval::( + r#"import "Math::Advanced" as math; + let ex = 41.0; + let fx = math::bar_add(ex, 1.0); + fx + "#).unwrap_err(), + EvalAltResult::ErrorFunctionNotFound(s, p) + if s == "math::bar_add (f64, f64)" + && p == rhai::Position::new(3, 23))); + + assert!(matches!(*engine.eval::( + r#"import "Math::Advanced" as math; + let ex = 41.0; + let fx = math::add_float2(ex, 1.0); + fx + "#).unwrap_err(), + EvalAltResult::ErrorFunctionNotFound(s, p) + if s == "math::add_float2 (f64, f64)" + && p == rhai::Position::new(3, 23))); + + assert!(matches!(*engine.eval::( + r#"import "Math::Advanced" as math; + let ex = 41.0; + let fx = math::bar_m(ex, 1.0); + fx + "#).unwrap_err(), + EvalAltResult::ErrorFunctionNotFound(s, p) + if s == "math::bar_m (f64, f64)" + && p == rhai::Position::new(3, 23))); + + Ok(()) +} + +mod export_all { + use rhai::plugin::*; + #[export_module(export_all)] + pub mod my_adds { + use rhai::{FLOAT, INT}; + + #[rhai_fn(name = "foo_add_f")] + pub fn add_float(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + + #[rhai_fn(name = "foo_add_i")] + fn add_int(i1: INT, i2: INT) -> INT { + i1 + i2 + } + + #[rhai_fn(skip)] + pub fn add_float2(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + + pub fn foo_m(f1: FLOAT, f2: FLOAT) -> FLOAT { + f1 + f2 + } + + fn foo_n(i1: INT, i2: INT) -> INT { + i1 + i2 + } + + #[rhai_fn(skip)] + fn foo_p(i1: INT, i2: INT) -> INT { + i1 * i2 + } + } +} + +#[test] +fn export_all_test() -> Result<(), Box> { + let mut engine = Engine::new(); + let m = rhai::exported_module!(crate::export_all::my_adds); + let mut r = StaticModuleResolver::new(); + r.insert("Math::Advanced".to_string(), m); + engine.set_module_resolver(Some(r)); + + let output_array = engine.eval::( + r#"import "Math::Advanced" as math; + let ex = 41.0; + let fx = math::foo_add_f(ex, 1.0); + let gx = math::foo_m(41.0, 1.0); + let ei = 41; + let fi = math::foo_add_i(ei, 1); + let gi = math::foo_n(41, 1); + [fx, gx, fi, gi] + "#, + )?; + assert_eq!(&output_array[0].as_float().unwrap(), &42.0); + assert_eq!(&output_array[1].as_float().unwrap(), &42.0); + assert_eq!(&output_array[2].as_int().unwrap(), &42); + assert_eq!(&output_array[3].as_int().unwrap(), &42); + + assert!(matches!(*engine.eval::( + r#"import "Math::Advanced" as math; + let ex = 41; + let fx = math::foo_p(ex, 1); + fx + "#).unwrap_err(), + EvalAltResult::ErrorFunctionNotFound(s, p) + if s == "math::foo_p (i64, i64)" + && p == rhai::Position::new(3, 23))); + + Ok(()) +}