Implement export_* attributes for macros

This commit is contained in:
J Henry Waugh 2020-09-01 23:15:22 -05:00
parent 91b4f8a6bc
commit 3af49cec70
6 changed files with 274 additions and 41 deletions

View File

@ -1,5 +1,18 @@
use syn::{parse::ParseStream, parse::Parser, spanned::Spanned}; use syn::{parse::ParseStream, parse::Parser, spanned::Spanned};
#[derive(Debug)]
pub enum ExportScope {
PubOnly,
Prefix(String),
All,
}
impl Default for ExportScope {
fn default() -> ExportScope {
ExportScope::PubOnly
}
}
pub trait ExportedParams: Sized { pub trait ExportedParams: Sized {
fn parse_stream(args: ParseStream) -> syn::Result<Self>; fn parse_stream(args: ParseStream) -> syn::Result<Self>;
fn no_attrs() -> Self; fn no_attrs() -> Self;

View File

@ -10,10 +10,12 @@ use alloc::format;
#[cfg(not(no_std))] #[cfg(not(no_std))]
use std::format; use std::format;
use std::borrow::Cow;
use quote::{quote, quote_spanned}; use quote::{quote, quote_spanned};
use syn::{parse::Parse, parse::ParseStream, parse::Parser, spanned::Spanned}; use syn::{parse::Parse, parse::ParseStream, parse::Parser, spanned::Spanned};
use crate::attrs::{ExportInfo, ExportedParams}; use crate::attrs::{ExportInfo, ExportScope, ExportedParams};
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct ExportedFnParams { pub(crate) struct ExportedFnParams {
@ -222,10 +224,24 @@ impl ExportedFn {
&self.params &self.params
} }
pub(crate) fn update_scope(&mut self, parent_scope: &ExportScope) {
let keep = match (self.params.skip, parent_scope) {
(true, _) => false,
(_, ExportScope::PubOnly) => self.is_public,
(_, ExportScope::Prefix(s)) => self.exported_name().as_ref().starts_with(s),
(_, ExportScope::All) => true,
};
self.params.skip = !keep;
}
pub(crate) fn skipped(&self) -> bool { pub(crate) fn skipped(&self) -> bool {
self.params.skip self.params.skip
} }
pub(crate) fn signature(&self) -> &syn::Signature {
&self.signature
}
pub(crate) fn mutable_receiver(&self) -> bool { pub(crate) fn mutable_receiver(&self) -> bool {
self.mut_receiver self.mut_receiver
} }
@ -242,6 +258,14 @@ impl ExportedFn {
&self.signature.ident &self.signature.ident
} }
pub(crate) fn exported_name<'n>(&'n self) -> Cow<'n, str> {
if let Some(ref name) = self.params.name {
Cow::Borrowed(name.as_str())
} else {
Cow::Owned(self.signature.ident.to_string())
}
}
pub(crate) fn arg_list(&self) -> impl Iterator<Item = &syn::FnArg> { pub(crate) fn arg_list(&self) -> impl Iterator<Item = &syn::FnArg> {
self.signature.inputs.iter() self.signature.inputs.iter()
} }

View File

@ -124,10 +124,18 @@ pub fn export_fn(
#[proc_macro_attribute] #[proc_macro_attribute]
pub fn export_module( pub fn export_module(
_args: proc_macro::TokenStream, args: proc_macro::TokenStream,
input: proc_macro::TokenStream, input: proc_macro::TokenStream,
) -> proc_macro::TokenStream { ) -> proc_macro::TokenStream {
let module_def = parse_macro_input!(input as module::Module); let parsed_params = match crate::attrs::outer_item_attributes(args.into(), "export_module") {
Ok(args) => args,
Err(err) => return proc_macro::TokenStream::from(err.to_compile_error()),
};
let mut module_def = parse_macro_input!(input as module::Module);
if let Err(e) = module_def.set_params(parsed_params) {
return e.to_compile_error().into();
}
let tokens = module_def.generate(); let tokens = module_def.generate();
proc_macro::TokenStream::from(tokens) proc_macro::TokenStream::from(tokens)
} }

View File

@ -16,13 +16,14 @@ use std::mem;
use std::borrow::Cow; use std::borrow::Cow;
use crate::attrs::{AttrItem, ExportInfo, ExportedParams}; use crate::attrs::{AttrItem, ExportInfo, ExportScope, ExportedParams};
use crate::function::{ExportedFnParams}; use crate::function::ExportedFnParams;
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub(crate) struct ExportedModParams { pub(crate) struct ExportedModParams {
pub name: Option<String>, pub name: Option<String>,
pub skip: bool, skip: bool,
pub scope: ExportScope,
} }
impl Parse for ExportedModParams { impl Parse for ExportedModParams {
@ -50,6 +51,7 @@ impl ExportedParams for ExportedModParams {
let ExportInfo { items: attrs, .. } = info; let ExportInfo { items: attrs, .. } = info;
let mut name = None; let mut name = None;
let mut skip = false; let mut skip = false;
let mut scope = ExportScope::default();
for attr in attrs { for attr in attrs {
let AttrItem { key, value } = attr; let AttrItem { key, value } = attr;
match (key.to_string().as_ref(), value) { match (key.to_string().as_ref(), value) {
@ -57,6 +59,14 @@ impl ExportedParams for ExportedModParams {
("name", None) => return Err(syn::Error::new(key.span(), "requires value")), ("name", None) => return Err(syn::Error::new(key.span(), "requires value")),
("skip", None) => skip = true, ("skip", None) => skip = true,
("skip", Some(s)) => return Err(syn::Error::new(s.span(), "extraneous value")), ("skip", Some(s)) => return Err(syn::Error::new(s.span(), "extraneous value")),
("export_prefix", Some(s)) => scope = ExportScope::Prefix(s.value()),
("export_prefix", None) => {
return Err(syn::Error::new(key.span(), "requires value"))
}
("export_all", None) => scope = ExportScope::All,
("export_all", Some(s)) => {
return Err(syn::Error::new(s.span(), "extraneous value"))
}
(attr, _) => { (attr, _) => {
return Err(syn::Error::new( return Err(syn::Error::new(
key.span(), key.span(),
@ -69,6 +79,7 @@ impl ExportedParams for ExportedModParams {
Ok(ExportedModParams { Ok(ExportedModParams {
name, name,
skip, skip,
scope,
..Default::default() ..Default::default()
}) })
} }
@ -83,6 +94,13 @@ pub(crate) struct Module {
params: ExportedModParams, params: ExportedModParams,
} }
impl Module {
pub fn set_params(&mut self, params: ExportedModParams) -> syn::Result<()> {
self.params = params;
Ok(())
}
}
impl Parse for Module { impl Parse for Module {
fn parse(input: ParseStream) -> syn::Result<Self> { fn parse(input: ParseStream) -> syn::Result<Self> {
let mut mod_all: syn::ItemMod = input.parse()?; let mut mod_all: syn::ItemMod = input.parse()?;
@ -101,16 +119,11 @@ impl Parse for Module {
// #[cfg] attributes are not allowed on functions // #[cfg] attributes are not allowed on functions
crate::attrs::deny_cfg_attr(&itemfn.attrs)?; crate::attrs::deny_cfg_attr(&itemfn.attrs)?;
let mut params: ExportedFnParams = let params: ExportedFnParams =
match crate::attrs::inner_item_attributes(&mut itemfn.attrs, "rhai_fn") { match crate::attrs::inner_item_attributes(&mut itemfn.attrs, "rhai_fn") {
Ok(p) => p, Ok(p) => p,
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
params.skip = if let syn::Visibility::Public(_) = itemfn.vis {
params.skip
} else {
true
};
syn::parse2::<ExportedFn>(itemfn.to_token_stream()) syn::parse2::<ExportedFn>(itemfn.to_token_stream())
.and_then(|mut f| { .and_then(|mut f| {
f.set_params(params)?; f.set_params(params)?;
@ -150,20 +163,15 @@ impl Parse for Module {
syn::Item::Mod(m) => m, syn::Item::Mod(m) => m,
_ => unreachable!(), _ => unreachable!(),
}; };
let mut params: ExportedModParams = let params: ExportedModParams =
match crate::attrs::inner_item_attributes(&mut itemmod.attrs, "rhai_mod") { match crate::attrs::inner_item_attributes(&mut itemmod.attrs, "rhai_mod") {
Ok(p) => p, Ok(p) => p,
Err(e) => return Err(e), Err(e) => return Err(e),
}; };
params.skip = if let syn::Visibility::Public(_) = itemmod.vis {
params.skip
} else {
true
};
let module = let module =
syn::parse2::<Module>(itemmod.to_token_stream()).map(|mut f| { syn::parse2::<Module>(itemmod.to_token_stream()).and_then(|mut m| {
f.params = params; m.set_params(params)?;
f Ok(m)
})?; })?;
submodules.push(module); submodules.push(module);
} else { } else {
@ -200,6 +208,28 @@ impl Module {
} }
} }
pub fn update_scope(&mut self, parent_scope: &ExportScope) {
let keep = match (self.params.skip, parent_scope) {
(true, _) => false,
(_, ExportScope::PubOnly) => {
if let Some(ref mod_all) = self.mod_all {
matches!(mod_all.vis, syn::Visibility::Public(_))
} else {
false
}
}
(_, ExportScope::Prefix(s)) => {
if let Some(ref mod_all) = self.mod_all {
mod_all.ident.to_string().starts_with(s)
} else {
false
}
}
(_, ExportScope::All) => true,
};
self.params.skip = !keep;
}
pub fn skipped(&self) -> bool { pub fn skipped(&self) -> bool {
self.params.skip self.params.skip
} }
@ -218,7 +248,7 @@ impl Module {
// Extract the current structure of the module. // Extract the current structure of the module.
let Module { let Module {
mod_all, mod_all,
fns, mut fns,
consts, consts,
mut submodules, mut submodules,
params, params,
@ -233,7 +263,12 @@ impl Module {
// Generate new module items. // Generate new module items.
// //
// This is done before inner module recursive generation, because that is destructive. // This is done before inner module recursive generation, because that is destructive.
let mod_gen = crate::rhai_module::generate_body(&fns, &consts, &submodules); let mod_gen = crate::rhai_module::generate_body(
&mut fns,
&consts,
&mut submodules,
&params.scope,
);
// NB: submodules must have their new items for exporting generated in depth-first order // NB: submodules must have their new items for exporting generated in depth-first order
// to avoid issues caused by re-parsing them // to avoid issues caused by re-parsing them
@ -454,22 +489,6 @@ mod module_tests {
); );
} }
#[test]
fn one_private_fn_module() {
let input_tokens: TokenStream = quote! {
pub mod one_fn {
fn get_mystic_number() -> INT {
42
}
}
};
let item_mod = syn::parse2::<Module>(input_tokens).unwrap();
assert_eq!(item_mod.fns.len(), 1);
assert!(item_mod.fns[0].skipped());
assert!(item_mod.consts.is_empty());
}
#[test] #[test]
fn one_skipped_fn_module() { fn one_skipped_fn_module() {
let input_tokens: TokenStream = quote! { let input_tokens: TokenStream = quote! {

View File

@ -2,15 +2,17 @@ use std::collections::HashMap;
use quote::{quote, ToTokens}; use quote::{quote, ToTokens};
use crate::attrs::ExportScope;
use crate::function::ExportedFn; use crate::function::ExportedFn;
use crate::module::Module; use crate::module::Module;
pub(crate) type ExportedConst = (String, syn::Expr); pub(crate) type ExportedConst = (String, syn::Expr);
pub(crate) fn generate_body( pub(crate) fn generate_body(
fns: &[ExportedFn], fns: &mut [ExportedFn],
consts: &[ExportedConst], consts: &[ExportedConst],
submodules: &[Module], submodules: &mut [Module],
parent_scope: &ExportScope,
) -> proc_macro2::TokenStream { ) -> proc_macro2::TokenStream {
let mut set_fn_stmts: Vec<syn::Stmt> = Vec::new(); let mut set_fn_stmts: Vec<syn::Stmt> = Vec::new();
let mut set_const_stmts: Vec<syn::Stmt> = Vec::new(); let mut set_const_stmts: Vec<syn::Stmt> = Vec::new();
@ -28,6 +30,7 @@ pub(crate) fn generate_body(
} }
for itemmod in submodules { for itemmod in submodules {
itemmod.update_scope(&parent_scope);
if itemmod.skipped() { if itemmod.skipped() {
continue; continue;
} }
@ -56,6 +59,7 @@ pub(crate) fn generate_body(
// NB: these are token streams, because reparsing messes up "> >" vs ">>" // NB: these are token streams, because reparsing messes up "> >" vs ">>"
let mut gen_fn_tokens: Vec<proc_macro2::TokenStream> = Vec::new(); let mut gen_fn_tokens: Vec<proc_macro2::TokenStream> = Vec::new();
for function in fns { for function in fns {
function.update_scope(&parent_scope);
if function.skipped() { if function.skipped() {
continue; continue;
} }

View File

@ -221,3 +221,168 @@ fn duplicate_fn_rename_test() -> Result<(), Box<EvalAltResult>> {
assert_eq!(&output_array[1].as_int().unwrap(), &43); assert_eq!(&output_array[1].as_int().unwrap(), &43);
Ok(()) Ok(())
} }
mod export_by_prefix {
use rhai::plugin::*;
#[export_module(export_prefix = "foo_")]
pub mod my_adds {
use rhai::{FLOAT, INT};
#[rhai_fn(name = "foo_add_f")]
pub fn add_float(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
#[rhai_fn(name = "foo_add_i")]
fn add_int(i1: INT, i2: INT) -> INT {
i1 + i2
}
#[rhai_fn(name = "bar_add")]
pub fn add_float2(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
pub fn foo_m(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
fn foo_n(i1: INT, i2: INT) -> INT {
i1 + i2
}
pub fn bar_m(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
}
}
#[test]
fn export_by_prefix_test() -> Result<(), Box<EvalAltResult>> {
let mut engine = Engine::new();
let m = rhai::exported_module!(crate::export_by_prefix::my_adds);
let mut r = StaticModuleResolver::new();
r.insert("Math::Advanced".to_string(), m);
engine.set_module_resolver(Some(r));
let output_array = engine.eval::<Array>(
r#"import "Math::Advanced" as math;
let ex = 41.0;
let fx = math::foo_add_f(ex, 1.0);
let gx = math::foo_m(41.0, 1.0);
let ei = 41;
let fi = math::foo_add_i(ei, 1);
let gi = math::foo_n(41, 1);
[fx, gx, fi, gi]
"#,
)?;
assert_eq!(&output_array[0].as_float().unwrap(), &42.0);
assert_eq!(&output_array[1].as_float().unwrap(), &42.0);
assert_eq!(&output_array[2].as_int().unwrap(), &42);
assert_eq!(&output_array[3].as_int().unwrap(), &42);
assert!(matches!(*engine.eval::<FLOAT>(
r#"import "Math::Advanced" as math;
let ex = 41.0;
let fx = math::bar_add(ex, 1.0);
fx
"#).unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(s, p)
if s == "math::bar_add (f64, f64)"
&& p == rhai::Position::new(3, 23)));
assert!(matches!(*engine.eval::<FLOAT>(
r#"import "Math::Advanced" as math;
let ex = 41.0;
let fx = math::add_float2(ex, 1.0);
fx
"#).unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(s, p)
if s == "math::add_float2 (f64, f64)"
&& p == rhai::Position::new(3, 23)));
assert!(matches!(*engine.eval::<FLOAT>(
r#"import "Math::Advanced" as math;
let ex = 41.0;
let fx = math::bar_m(ex, 1.0);
fx
"#).unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(s, p)
if s == "math::bar_m (f64, f64)"
&& p == rhai::Position::new(3, 23)));
Ok(())
}
mod export_all {
use rhai::plugin::*;
#[export_module(export_all)]
pub mod my_adds {
use rhai::{FLOAT, INT};
#[rhai_fn(name = "foo_add_f")]
pub fn add_float(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
#[rhai_fn(name = "foo_add_i")]
fn add_int(i1: INT, i2: INT) -> INT {
i1 + i2
}
#[rhai_fn(skip)]
pub fn add_float2(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
pub fn foo_m(f1: FLOAT, f2: FLOAT) -> FLOAT {
f1 + f2
}
fn foo_n(i1: INT, i2: INT) -> INT {
i1 + i2
}
#[rhai_fn(skip)]
fn foo_p(i1: INT, i2: INT) -> INT {
i1 * i2
}
}
}
#[test]
fn export_all_test() -> Result<(), Box<EvalAltResult>> {
let mut engine = Engine::new();
let m = rhai::exported_module!(crate::export_all::my_adds);
let mut r = StaticModuleResolver::new();
r.insert("Math::Advanced".to_string(), m);
engine.set_module_resolver(Some(r));
let output_array = engine.eval::<Array>(
r#"import "Math::Advanced" as math;
let ex = 41.0;
let fx = math::foo_add_f(ex, 1.0);
let gx = math::foo_m(41.0, 1.0);
let ei = 41;
let fi = math::foo_add_i(ei, 1);
let gi = math::foo_n(41, 1);
[fx, gx, fi, gi]
"#,
)?;
assert_eq!(&output_array[0].as_float().unwrap(), &42.0);
assert_eq!(&output_array[1].as_float().unwrap(), &42.0);
assert_eq!(&output_array[2].as_int().unwrap(), &42);
assert_eq!(&output_array[3].as_int().unwrap(), &42);
assert!(matches!(*engine.eval::<INT>(
r#"import "Math::Advanced" as math;
let ex = 41;
let fx = math::foo_p(ex, 1);
fx
"#).unwrap_err(),
EvalAltResult::ErrorFunctionNotFound(s, p)
if s == "math::foo_p (i64, i64)"
&& p == rhai::Position::new(3, 23)));
Ok(())
}