diff --git a/RELEASES.md b/RELEASES.md index 494db8b5..6f38bb72 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -14,6 +14,7 @@ Breaking changes The method-call style will bind the object to the `this` parameter instead of consuming the first parameter. * Imported modules are no longer stored in the `Scope`. `Scope::push_module` is removed. Therefore, cannot rely on module imports to persist across invocations using a `Scope`. +* `AST::retain_functions` is used for another purpose. The old `AST::retain_functions` is renamed to `AST::clear_statements`. New features ------------ @@ -21,6 +22,7 @@ New features * Support for _function pointers_ via `Fn(name)` and `Fn.call(...)` syntax - a poor man's first-class function. * Support for calling script-defined functions in method-call style with `this` binding to the object. * Special support in object maps for OOP. +* Expanded the `AST` API for fine-tuned manipulation of functions. Enhancements ------------ diff --git a/doc/src/appendix/keywords.md b/doc/src/appendix/keywords.md index 6e13178f..660a605c 100644 --- a/doc/src/appendix/keywords.md +++ b/doc/src/appendix/keywords.md @@ -24,8 +24,8 @@ Keywords List | `as` | Alias for variable export | [`no_module`] | | `private` | Mark function private | [`no_function`] | | `fn` (lower-case `f`) | Function definition | [`no_function`] | -| `Fn` (capital `F`) | Function to create a [function pointer] | [`no_function`] | -| `call` | Call a [function pointer] | [`no_function`] | +| `Fn` (capital `F`) | Function to create a [function pointer] | | +| `call` | Call a [function pointer] | | | `this` | Reference to base object for method call | [`no_function`] | | `type_of` | Get type name of value | | | `print` | Print value | | diff --git a/doc/src/language/keywords.md b/doc/src/language/keywords.md index 82eddce4..836e9b0c 100644 --- a/doc/src/language/keywords.md +++ b/doc/src/language/keywords.md @@ -15,6 +15,7 @@ The following are reserved keywords in Rhai: | `return` | Return values | | | `throw` | throw exceptions | | | `import`, `export`, `as` | Modules | [`no_module`] | +| `Fn`, `call` | Function pointers | | | `type_of`, `print`, `debug`, `eval` | Special functions | | Keywords cannot be the name of a [function] or [variable], unless the relevant feature is enabled. diff --git a/doc/src/language/values-and-types.md b/doc/src/language/values-and-types.md index 5a5ae2e0..70c276a9 100644 --- a/doc/src/language/values-and-types.md +++ b/doc/src/language/values-and-types.md @@ -15,7 +15,7 @@ The following primitive types are supported natively: | **[`Array`]** (disabled with [`no_index`]) | `rhai::Array` | `"array"` | `"[ ?, ?, ? ]"` | | **[Object map]** (disabled with [`no_object`]) | `rhai::Map` | `"map"` | `"#{ "a": 1, "b": 2 }"` | | **[Timestamp]** (implemented in the [`BasicTimePackage`][packages], disabled with [`no_std`]) | `std::time::Instant` ([`instant::Instant`] if not [WASM] build) | `"timestamp"` | _not supported_ | -| **[Function pointer]** (disabled with [`no_function`]) | _None_ | `Fn` | `"Fn(foo)"` | +| **[Function pointer]** | _None_ | `Fn` | `"Fn(foo)"` | | **[`Dynamic`] value** (i.e. can be anything) | `rhai::Dynamic` | _the actual type_ | _actual value_ | | **System integer** (current configuration) | `rhai::INT` (`i32` or `i64`) | `"i32"` or `"i64"` | `"42"`, `"123"` etc. | | **System floating-point** (current configuration, disabled with [`no_float`]) | `rhai::FLOAT` (`f32` or `f64`) | `"f32"` or `"f64"` | `"123.456"` etc. | diff --git a/examples/repl.rs b/examples/repl.rs index 81f42158..cf255986 100644 --- a/examples/repl.rs +++ b/examples/repl.rs @@ -166,7 +166,6 @@ fn main() { } // Throw away all the statements, leaving only the functions - #[cfg(not(feature = "no_function"))] - main_ast.retain_functions(); + main_ast.clear_statements(); } } diff --git a/src/any.rs b/src/any.rs index bc148956..7e8b8072 100644 --- a/src/any.rs +++ b/src/any.rs @@ -137,7 +137,6 @@ pub enum Union { Array(Box), #[cfg(not(feature = "no_object"))] Map(Box), - #[cfg(not(feature = "no_function"))] FnPtr(FnPtr), Variant(Box>), } @@ -175,7 +174,6 @@ impl Dynamic { Union::Array(_) => TypeId::of::(), #[cfg(not(feature = "no_object"))] Union::Map(_) => TypeId::of::(), - #[cfg(not(feature = "no_function"))] Union::FnPtr(_) => TypeId::of::(), Union::Variant(value) => (***value).type_id(), } @@ -195,7 +193,6 @@ impl Dynamic { Union::Array(_) => "array", #[cfg(not(feature = "no_object"))] Union::Map(_) => "map", - #[cfg(not(feature = "no_function"))] Union::FnPtr(_) => "Fn", #[cfg(not(feature = "no_std"))] @@ -220,7 +217,6 @@ impl fmt::Display for Dynamic { Union::Array(value) => fmt::Debug::fmt(value, f), #[cfg(not(feature = "no_object"))] Union::Map(value) => write!(f, "#{:?}", value), - #[cfg(not(feature = "no_function"))] Union::FnPtr(value) => fmt::Display::fmt(value, f), #[cfg(not(feature = "no_std"))] @@ -245,7 +241,6 @@ impl fmt::Debug for Dynamic { Union::Array(value) => fmt::Debug::fmt(value, f), #[cfg(not(feature = "no_object"))] Union::Map(value) => write!(f, "#{:?}", value), - #[cfg(not(feature = "no_function"))] Union::FnPtr(value) => fmt::Display::fmt(value, f), #[cfg(not(feature = "no_std"))] @@ -270,7 +265,6 @@ impl Clone for Dynamic { Union::Array(ref value) => Self(Union::Array(value.clone())), #[cfg(not(feature = "no_object"))] Union::Map(ref value) => Self(Union::Map(value.clone())), - #[cfg(not(feature = "no_function"))] Union::FnPtr(ref value) => Self(Union::FnPtr(value.clone())), Union::Variant(ref value) => (***value).clone_into_dynamic(), } @@ -400,7 +394,6 @@ impl Dynamic { Union::Array(value) => unsafe_cast_box::<_, T>(value).ok().map(|v| *v), #[cfg(not(feature = "no_object"))] Union::Map(value) => unsafe_cast_box::<_, T>(value).ok().map(|v| *v), - #[cfg(not(feature = "no_function"))] Union::FnPtr(value) => unsafe_try_cast(value), Union::Variant(value) => (*value).as_box_any().downcast().map(|x| *x).ok(), } @@ -444,7 +437,6 @@ impl Dynamic { Union::Array(value) => *unsafe_cast_box::<_, T>(value).unwrap(), #[cfg(not(feature = "no_object"))] Union::Map(value) => *unsafe_cast_box::<_, T>(value).unwrap(), - #[cfg(not(feature = "no_function"))] Union::FnPtr(value) => unsafe_try_cast(value).unwrap(), Union::Variant(value) => (*value).as_box_any().downcast().map(|x| *x).unwrap(), } @@ -471,7 +463,6 @@ impl Dynamic { Union::Array(value) => ::downcast_ref::(value.as_ref()), #[cfg(not(feature = "no_object"))] Union::Map(value) => ::downcast_ref::(value.as_ref()), - #[cfg(not(feature = "no_function"))] Union::FnPtr(value) => ::downcast_ref::(value), Union::Variant(value) => value.as_ref().as_ref().as_any().downcast_ref::(), } @@ -497,7 +488,6 @@ impl Dynamic { Union::Array(value) => ::downcast_mut::(value.as_mut()), #[cfg(not(feature = "no_object"))] Union::Map(value) => ::downcast_mut::(value.as_mut()), - #[cfg(not(feature = "no_function"))] Union::FnPtr(value) => ::downcast_mut::(value), Union::Variant(value) => value.as_mut().as_mut_any().downcast_mut::(), } @@ -626,7 +616,6 @@ impl, T: Variant + Clone> From> for Dynam ))) } } -#[cfg(not(feature = "no_function"))] impl From for Dynamic { fn from(value: FnPtr) -> Self { Self(Union::FnPtr(value)) diff --git a/src/lib.rs b/src/lib.rs index 4c37149c..aa205dec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -103,6 +103,9 @@ pub use scope::Scope; pub use token::Position; pub use utils::calc_fn_spec as calc_fn_hash; +#[cfg(not(feature = "no_function"))] +pub use parser::FnAccess; + #[cfg(not(feature = "no_function"))] pub use fn_func::Func; @@ -135,7 +138,7 @@ pub use token::{get_next_token, parse_string_literal, InputStream, Token, Tokeni #[cfg(feature = "internals")] #[deprecated(note = "this type is volatile and may change")] -pub use parser::{Expr, FnAccess, ReturnType, ScriptFnDef, Stmt}; +pub use parser::{Expr, ReturnType, ScriptFnDef, Stmt}; #[cfg(feature = "internals")] #[deprecated(note = "this type is volatile and may change")] diff --git a/src/module.rs b/src/module.rs index 043d14dc..3cdeb711 100644 --- a/src/module.rs +++ b/src/module.rs @@ -859,12 +859,53 @@ impl Module { /// Merge another module into this module. pub fn merge(&mut self, other: &Self) { + self.merge_filtered(other, |_, _, _| true) + } + + /// Merge another module into this module, with only selected functions based on a filter predicate. + pub(crate) fn merge_filtered( + &mut self, + other: &Self, + filter: impl Fn(FnAccess, &str, usize) -> bool, + ) { self.variables .extend(other.variables.iter().map(|(k, v)| (k.clone(), v.clone()))); - self.functions - .extend(other.functions.iter().map(|(&k, v)| (k, v.clone()))); + + self.functions.extend( + other + .functions + .iter() + .filter(|(_, (_, _, _, v))| match v { + CallableFunction::Pure(_) + | CallableFunction::Method(_) + | CallableFunction::Iterator(_) => true, + CallableFunction::Script(ref f) => { + filter(f.access, f.name.as_str(), f.params.len()) + } + }) + .map(|(&k, v)| (k, v.clone())), + ); + self.type_iterators .extend(other.type_iterators.iter().map(|(&k, v)| (k, v.clone()))); + + self.all_functions.clear(); + self.all_variables.clear(); + self.indexed = false; + } + + /// Filter out the functions, retaining only some based on a filter predicate. + pub(crate) fn retain_functions(&mut self, filter: impl Fn(FnAccess, &str, usize) -> bool) { + self.functions.retain(|_, (_, _, _, v)| match v { + CallableFunction::Pure(_) + | CallableFunction::Method(_) + | CallableFunction::Iterator(_) => true, + CallableFunction::Script(ref f) => filter(f.access, f.name.as_str(), f.params.len()), + }); + + self.all_functions.clear(); + self.all_variables.clear(); + self.indexed = false; } /// Get the number of variables in the module. diff --git a/src/parser.rs b/src/parser.rs index 69289de0..906d99de 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -116,8 +116,15 @@ impl AST { /// /// let engine = Engine::new(); /// - /// let ast1 = engine.compile(r#"fn foo(x) { 42 + x } foo(1)"#)?; - /// let ast2 = engine.compile(r#"fn foo(n) { "hello" + n } foo("!")"#)?; + /// let ast1 = engine.compile(r#" + /// fn foo(x) { 42 + x } + /// foo(1) + /// "#)?; + /// + /// let ast2 = engine.compile(r#" + /// fn foo(n) { "hello" + n } + /// foo("!") + /// "#)?; /// /// let ast = ast1.merge(&ast2); // Merge 'ast2' into 'ast1' /// @@ -138,6 +145,65 @@ impl AST { /// # } /// ``` pub fn merge(&self, other: &Self) -> Self { + self.merge_filtered(other, |_, _, _| true) + } + + /// Merge two [`AST`] into one. Both [`AST`]'s are untouched and a new, merged, version + /// is returned. + /// + /// The second [`AST`] is simply appended to the end of the first _without any processing_. + /// Thus, the return value of the first [`AST`] (if using expression-statement syntax) is buried. + /// Of course, if the first [`AST`] uses a `return` statement at the end, then + /// the second [`AST`] will essentially be dead code. + /// + /// All script-defined functions in the second [`AST`] are first selected based on a filter + /// predicate, then overwrite similarly-named functions in the first [`AST`] with the + /// same number of parameters. + /// + /// # Example + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// # #[cfg(not(feature = "no_function"))] + /// # { + /// use rhai::Engine; + /// + /// let engine = Engine::new(); + /// + /// let ast1 = engine.compile(r#" + /// fn foo(x) { 42 + x } + /// foo(1) + /// "#)?; + /// + /// let ast2 = engine.compile(r#" + /// fn foo(n) { "hello" + n } + /// fn error() { 0 } + /// foo("!") + /// "#)?; + /// + /// // Merge 'ast2', picking only 'error()' but not 'foo(_)', into 'ast1' + /// let ast = ast1.merge_filtered(&ast2, |_, name, params| name == "error" && params == 0); + /// + /// // 'ast' is essentially: + /// // + /// // fn foo(n) { 42 + n } // <- definition of 'ast1::foo' is not overwritten + /// // // because 'ast2::foo' is filtered away + /// // foo(1) // <- notice this will be 43 instead of "hello1", + /// // // but it is no longer the return value + /// // fn error() { 0 } // <- this function passes the filter and is merged + /// // foo("!") // <- returns "42!" + /// + /// // Evaluate it + /// assert_eq!(engine.eval_ast::(&ast)?, "42!"); + /// # } + /// # Ok(()) + /// # } + /// ``` + pub fn merge_filtered( + &self, + other: &Self, + filter: impl Fn(FnAccess, &str, usize) -> bool, + ) -> Self { let Self(statements, functions) = self; let ast = match (statements.is_empty(), other.0.is_empty()) { @@ -152,11 +218,39 @@ impl AST { }; let mut functions = functions.clone(); - functions.merge(&other.1); + functions.merge_filtered(&other.1, filter); Self::new(ast, functions) } + /// Filter out the functions, retaining only some based on a filter predicate. + /// + /// # Example + /// + /// ``` + /// # fn main() -> Result<(), Box> { + /// # #[cfg(not(feature = "no_function"))] + /// # { + /// use rhai::Engine; + /// + /// let engine = Engine::new(); + /// + /// let mut ast = engine.compile(r#" + /// fn foo(n) { n + 1 } + /// fn bar() { print("hello"); } + /// "#)?; + /// + /// // Remove all functions except 'foo(_)' + /// ast.retain_functions(|_, name, params| name == "foo" && params == 1); + /// # } + /// # Ok(()) + /// # } + /// ``` + #[cfg(not(feature = "no_function"))] + pub fn retain_functions(&mut self, filter: impl Fn(FnAccess, &str, usize) -> bool) { + self.1.retain_functions(filter); + } + /// Clear all function definitions in the [`AST`]. #[cfg(not(feature = "no_function"))] pub fn clear_functions(&mut self) { @@ -164,8 +258,7 @@ impl AST { } /// Clear all statements in the [`AST`], leaving only function definitions. - #[cfg(not(feature = "no_function"))] - pub fn retain_functions(&mut self) { + pub fn clear_statements(&mut self) { self.0 = vec![]; } } @@ -187,6 +280,15 @@ pub enum FnAccess { Public, } +impl fmt::Display for FnAccess { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Private => write!(f, "private"), + Self::Public => write!(f, "public"), + } + } +} + /// A scripted function definition. #[derive(Debug, Clone)] pub struct ScriptFnDef { diff --git a/tests/modules.rs b/tests/modules.rs index b110fa8e..0e1eb3ba 100644 --- a/tests/modules.rs +++ b/tests/modules.rs @@ -250,6 +250,7 @@ fn test_module_export() -> Result<(), Box> { ParseError(x, _) if *x == ParseErrorType::WrongExport )); + #[cfg(not(feature = "no_function"))] assert!(matches!( engine.compile(r"fn abc(x) { export x; }").expect_err("should error"), ParseError(x, _) if *x == ParseErrorType::WrongExport