diff --git a/README.md b/README.md index ae2807b5..5de326f1 100644 --- a/README.md +++ b/README.md @@ -579,6 +579,7 @@ A number of traits, under the `rhai::` module namespace, provide additional func | `RegisterDynamicFn` | Trait for registering functions returning [`Dynamic`] | `register_dynamic_fn` | | `RegisterResultFn` | Trait for registering fallible functions returning `Result<`_T_`, Box>` | `register_result_fn` | | `Func` | Trait for creating anonymous functions from script | `create_from_ast`, `create_from_script` | +| `ModuleResolver` | Trait implemented by module resolution services | `resolve` | Working with functions ---------------------- @@ -2104,6 +2105,21 @@ engine.eval_expression_with_scope::(&scope, "question::answer + 1")? == 42; engine.eval_expression_with_scope::(&scope, "question::inc(question::answer)")? == 42; ``` +### Module resolvers + +When encountering an `import` statement, Rhai attempts to _resolve_ the module based on the path string. +_Module Resolvers_ are service types that implement the [`ModuleResolver`](#traits) trait. +There are a number of standard resolvers built into Rhai, the default being the `FileModuleResolver` +which simply loads a script file based on the path (with `.rhai` extension attached) and execute it to form a module. + +Built-in module resolvers are grouped under the `rhai::module_resolvers` module namespace. + +| Module Resolver | Description | +| ---------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `FileModuleResolver` | The default module resolution service, not available under the [`no_std`] feature. Loads a script file (based off the current directory) with `.rhai` extension.
The base directory can be changed via the `FileModuleResolver::new_with_path()` constructor function. | +| `StaticModuleResolver` | Loads modules that are statically added. This can be used when the [`no_std`] feature is turned on. | +| `NullModuleResolver` | The default module resolution service under the [`no_std`] feature. Always returns an `EvalAltResult::ErrorModuleNotFound` error. | + Script optimization =================== diff --git a/src/engine.rs b/src/engine.rs index d3702339..c68efd95 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -3,7 +3,7 @@ use crate::any::{Dynamic, Union}; use crate::calc_fn_hash; use crate::error::ParseErrorType; -use crate::module::Module; +use crate::module::{resolvers, Module, ModuleResolver}; use crate::optimize::OptimizationLevel; use crate::packages::{CorePackage, Package, PackageLibrary, StandardPackage}; use crate::parser::{Expr, FnDef, ModuleRef, ReturnType, Stmt, AST}; @@ -260,6 +260,10 @@ pub struct Engine { /// A hashmap containing all iterators known to the engine. pub(crate) type_iterators: HashMap>, + + /// A module resolution service. + pub(crate) module_resolver: Box, + /// A hashmap mapping type names to pretty-print names. pub(crate) type_names: HashMap, @@ -293,6 +297,13 @@ impl Default for Engine { packages: Default::default(), functions: HashMap::with_capacity(FUNCTIONS_COUNT), type_iterators: Default::default(), + + #[cfg(not(feature = "no_module"))] + #[cfg(not(feature = "no_std"))] + module_resolver: Box::new(resolvers::FileModuleResolver::new()), + #[cfg(any(feature = "no_std", feature = "no_module"))] + module_resolver: Box::new(resolvers::NullModuleResolver::new()), + type_names: Default::default(), // default print/debug implementations @@ -427,6 +438,13 @@ impl Engine { packages: Default::default(), functions: HashMap::with_capacity(FUNCTIONS_COUNT / 2), type_iterators: Default::default(), + + #[cfg(not(feature = "no_module"))] + #[cfg(not(feature = "no_std"))] + module_resolver: Box::new(resolvers::FileModuleResolver::new()), + #[cfg(any(feature = "no_std", feature = "no_module"))] + module_resolver: Box::new(resolvers::NullModuleResolver::new()), + type_names: Default::default(), print: Box::new(|_| {}), debug: Box::new(|_| {}), @@ -455,7 +473,7 @@ impl Engine { self.packages.insert(0, package); } - /// Control whether and how the `Engine` will optimize an AST after compilation + /// Control whether and how the `Engine` will optimize an AST after compilation. /// /// Not available under the `no_optimize` feature. #[cfg(not(feature = "no_optimize"))] @@ -469,7 +487,15 @@ impl Engine { self.max_call_stack_depth = levels } - /// Universal method for calling functions either registered with the `Engine` or written in Rhai + /// Set the module resolution service used by the `Engine`. + /// + /// Not available under the `no_module` feature. + #[cfg(not(feature = "no_module"))] + pub fn set_module_resolver(&mut self, resolver: impl ModuleResolver + 'static) { + self.module_resolver = Box::new(resolver); + } + + /// Universal method for calling functions either registered with the `Engine` or written in Rhai. pub(crate) fn call_fn_raw( &self, scope: Option<&mut Scope>, @@ -1220,18 +1246,29 @@ impl Engine { if let Some(modules) = modules { // Module-qualified function call - let hash = calc_fn_hash(fn_name, args.iter().map(|a| a.type_id())); + let modules = modules.as_ref(); let (id, root_pos) = modules.get(0); // First module let module = scope.find_module(id).ok_or_else(|| { Box::new(EvalAltResult::ErrorModuleNotFound(id.into(), *root_pos)) })?; - match module.get_qualified_fn(fn_name, hash, modules.as_ref(), *pos) { - Ok(func) => func(&mut args, *pos) - .map_err(|err| EvalAltResult::set_position(err, *pos)), - Err(_) if def_val.is_some() => Ok(def_val.as_deref().unwrap().clone()), - Err(err) => Err(err), + + // First search in script-defined functions (can override built-in) + if let Some(fn_def) = + module.get_qualified_fn_lib(fn_name, args.len(), modules)? + { + self.call_fn_from_lib(None, fn_lib, fn_def, &mut args, *pos, level) + } else { + // Then search in Rust functions + let hash = calc_fn_hash(fn_name, args.iter().map(|a| a.type_id())); + + match module.get_qualified_fn(fn_name, hash, modules, *pos) { + Ok(func) => func(&mut args, *pos) + .map_err(|err| EvalAltResult::set_position(err, *pos)), + Err(_) if def_val.is_some() => Ok(def_val.as_deref().unwrap().clone()), + Err(err) => Err(err), + } } } else if fn_name.as_ref() == KEYWORD_EVAL && args.len() == 1 @@ -1486,14 +1523,14 @@ impl Engine { // Import statement Stmt::Import(expr, name, _) => { + #[cfg(feature = "no_module")] + unreachable!(); + if let Some(path) = self .eval_expr(scope, state, fn_lib, expr, level)? .try_cast::() { - let mut module = Module::new(); - module.set_var("kitty", "foo".to_string()); - module.set_var("path", path); - module.set_fn_1_mut("calc", |x: &mut String| Ok(x.len() as crate::parser::INT)); + let module = self.module_resolver.resolve(self, &path)?; // TODO - avoid copying module name in inner block? let mod_name = name.as_ref().clone(); diff --git a/src/lib.rs b/src/lib.rs index ddcdee28..1bdf8fd7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -110,7 +110,12 @@ pub use engine::Map; pub use parser::FLOAT; #[cfg(not(feature = "no_module"))] -pub use module::Module; +pub use module::{Module, ModuleResolver}; + +#[cfg(not(feature = "no_module"))] +pub mod module_resolvers { + pub use crate::module::resolvers::*; +} #[cfg(not(feature = "no_optimize"))] pub use optimize::OptimizationLevel; diff --git a/src/module.rs b/src/module.rs index e7b2a51b..32d27395 100644 --- a/src/module.rs +++ b/src/module.rs @@ -2,28 +2,44 @@ use crate::any::{Dynamic, Variant}; use crate::calc_fn_hash; -use crate::engine::{FnAny, FnCallArgs, FunctionsLib}; +use crate::engine::{Engine, FnAny, FnCallArgs, FunctionsLib}; +use crate::parser::FnDef; use crate::result::EvalAltResult; +use crate::scope::{EntryType as ScopeEntryType, Scope}; use crate::token::Position; use crate::token::Token; use crate::utils::StaticVec; -use crate::stdlib::{any::TypeId, collections::HashMap, fmt, iter::empty, mem, string::String}; +use crate::stdlib::{ + any::TypeId, collections::HashMap, fmt, iter::empty, mem, rc::Rc, string::String, sync::Arc, +}; + +/// A trait that encapsulates a module resolution service. +pub trait ModuleResolver { + /// Resolve a module based on a path string. + fn resolve(&self, engine: &Engine, path: &str) -> Result>; +} /// An imported module, which may contain variables, sub-modules, /// external Rust functions, and script-defined functions. /// /// Not available under the `no_module` feature. -#[derive(Default)] +#[derive(Default, Clone)] pub struct Module { /// Sub-modules. modules: HashMap, /// Module variables, including sub-modules. variables: HashMap, + /// External Rust functions. - functions: HashMap>, + #[cfg(not(feature = "sync"))] + functions: HashMap>>, + /// External Rust functions. + #[cfg(feature = "sync")] + functions: HashMap>>, + /// Script-defined functions. - lib: FunctionsLib, + fn_lib: FunctionsLib, } impl fmt::Debug for Module { @@ -33,19 +49,11 @@ impl fmt::Debug for Module { "", self.variables, self.functions.len(), - self.lib.len() + self.fn_lib.len() ) } } -impl Clone for Module { - fn clone(&self) -> Self { - // `Module` implements `Clone` so it can fit inside a `Dynamic` - // but we should never actually clone it. - unimplemented!() - } -} - impl Module { /// Create a new module. pub fn new() -> Self { @@ -147,7 +155,13 @@ impl Module { /// If there is an existing Rust function of the same hash, it is replaced. pub fn set_fn(&mut self, fn_name: &str, params: &[TypeId], func: Box) -> u64 { let hash = calc_fn_hash(fn_name, params.iter().cloned()); - self.functions.insert(hash, func); + + #[cfg(not(feature = "sync"))] + self.functions.insert(hash, Rc::new(func)); + + #[cfg(feature = "sync")] + self.functions.insert(hash, Arc::new(func)); + hash } @@ -163,10 +177,8 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash(fn_name, empty()); let f = move |_: &mut FnCallArgs, _: Position| func().map(|v| v.into()); - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn(fn_name, &[], Box::new(f)) } /// Set a Rust function taking one parameter into the module, returning a hash key. @@ -181,13 +193,10 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash(fn_name, [TypeId::of::()].iter().cloned()); - let f = move |args: &mut FnCallArgs, _: Position| { func(mem::take(args[0]).cast::()).map(|v| v.into()) }; - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn(fn_name, &[TypeId::of::()], Box::new(f)) } /// Set a Rust function taking one mutable parameter into the module, returning a hash key. @@ -202,13 +211,10 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash(fn_name, [TypeId::of::()].iter().cloned()); - let f = move |args: &mut FnCallArgs, _: Position| { func(args[0].downcast_mut::().unwrap()).map(|v| v.into()) }; - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn(fn_name, &[TypeId::of::()], Box::new(f)) } /// Set a Rust function taking two parameters into the module, returning a hash key. @@ -223,19 +229,17 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash( - fn_name, - [TypeId::of::(), TypeId::of::()].iter().cloned(), - ); - let f = move |args: &mut FnCallArgs, _: Position| { let a = mem::take(args[0]).cast::(); let b = mem::take(args[1]).cast::(); func(a, b).map(|v| v.into()) }; - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn( + fn_name, + &[TypeId::of::(), TypeId::of::()], + Box::new(f), + ) } /// Set a Rust function taking two parameters (the first one mutable) into the module, @@ -252,19 +256,17 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash( - fn_name, - [TypeId::of::(), TypeId::of::()].iter().cloned(), - ); - let f = move |args: &mut FnCallArgs, _: Position| { let b = mem::take(args[1]).cast::(); let a = args[0].downcast_mut::().unwrap(); func(a, b).map(|v| v.into()) }; - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn( + fn_name, + &[TypeId::of::(), TypeId::of::()], + Box::new(f), + ) } /// Set a Rust function taking three parameters into the module, returning a hash key. @@ -284,13 +286,6 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash( - fn_name, - [TypeId::of::(), TypeId::of::(), TypeId::of::()] - .iter() - .cloned(), - ); - let f = move |args: &mut FnCallArgs, _: Position| { let a = mem::take(args[0]).cast::(); let b = mem::take(args[1]).cast::(); @@ -298,8 +293,11 @@ impl Module { func(a, b, c).map(|v| v.into()) }; - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn( + fn_name, + &[TypeId::of::(), TypeId::of::(), TypeId::of::()], + Box::new(f), + ) } /// Set a Rust function taking three parameters (the first one mutable) into the module, @@ -321,13 +319,6 @@ impl Module { + Sync + 'static, ) -> u64 { - let hash = calc_fn_hash( - fn_name, - [TypeId::of::(), TypeId::of::(), TypeId::of::()] - .iter() - .cloned(), - ); - let f = move |args: &mut FnCallArgs, _: Position| { let b = mem::take(args[1]).cast::(); let c = mem::take(args[2]).cast::(); @@ -335,8 +326,11 @@ impl Module { func(a, b, c).map(|v| v.into()) }; - self.functions.insert(hash, Box::new(f)); - hash + self.set_fn( + fn_name, + &[TypeId::of::(), TypeId::of::(), TypeId::of::()], + Box::new(f), + ) } /// Get a Rust function. @@ -344,7 +338,7 @@ impl Module { /// The `u64` hash is calculated by the function `crate::calc_fn_hash`. /// It is also returned by the `set_fn_XXX` calls. pub fn get_fn(&self, hash: u64) -> Option<&Box> { - self.functions.get(&hash) + self.functions.get(&hash).map(|v| v.as_ref()) } /// Get a modules-qualified function. @@ -374,4 +368,141 @@ impl Module { Box::new(EvalAltResult::ErrorFunctionNotFound(fn_name, pos)) })?) } + + /// Get a script-defined function. + pub fn get_fn_lib(&self) -> &FunctionsLib { + &self.fn_lib + } + + /// Get a modules-qualified functions library. + pub(crate) fn get_qualified_fn_lib( + &mut self, + name: &str, + args: usize, + modules: &StaticVec<(String, Position)>, + ) -> Result, Box> { + Ok(self + .get_qualified_module_mut(modules)? + .fn_lib + .get_function(name, args)) + } +} + +pub mod resolvers { + use super::*; + + #[cfg(not(feature = "no_std"))] + use crate::stdlib::path::PathBuf; + + /// A module resolution service that loads module script files (assumed `.rhai` extension). + #[cfg(not(feature = "no_std"))] + #[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)] + pub struct FileModuleResolver(PathBuf); + + #[cfg(not(feature = "no_std"))] + impl FileModuleResolver { + /// Create a new `FileModuleResolver` with a specific base path. + pub fn new_with_path(path: PathBuf) -> Self { + Self(path) + } + /// Create a new `FileModuleResolver` with the current directory as base path. + pub fn new() -> Self { + Default::default() + } + } + + #[cfg(not(feature = "no_std"))] + impl Default for FileModuleResolver { + fn default() -> Self { + Self::new_with_path(".".into()) + } + } + + #[cfg(not(feature = "no_std"))] + impl ModuleResolver for FileModuleResolver { + fn resolve(&self, engine: &Engine, path: &str) -> Result> { + // Load the script file (attaching `.rhai`) + let mut file_path = self.0.clone(); + file_path.push(path); + file_path.set_extension("rhai"); + + // Compile it + let ast = engine.compile_file(file_path)?; + + // Use new scope + let mut scope = Scope::new(); + + // Run the script + engine.eval_ast_with_scope_raw(&mut scope, &ast)?; + + // Create new module + let mut module = Module::new(); + + // Variables left in the scope become module variables + for entry in scope.into_iter() { + match entry.typ { + ScopeEntryType::Normal | ScopeEntryType::Constant => { + module + .variables + .insert(entry.name.into_owned(), entry.value); + } + ScopeEntryType::Module => { + module + .modules + .insert(entry.name.into_owned(), entry.value.cast::()); + } + } + } + + module.fn_lib = FunctionsLib::new().merge(ast.fn_lib()); + + Ok(module) + } + } + + /// A module resolution service that serves modules added into it. + #[derive(Debug, Clone, Default)] + pub struct StaticModuleResolver(HashMap); + + impl StaticModuleResolver { + /// Create a new `StaticModuleResolver`. + pub fn new() -> Self { + Default::default() + } + /// Add a named module. + pub fn add_module(&mut self, name: &str, module: Module) { + self.0.insert(name.to_string(), module); + } + } + + impl ModuleResolver for StaticModuleResolver { + fn resolve(&self, _: &Engine, path: &str) -> Result> { + self.0.get(path).cloned().ok_or_else(|| { + Box::new(EvalAltResult::ErrorModuleNotFound( + path.to_string(), + Position::none(), + )) + }) + } + } + + /// A module resolution service that always returns a not-found error. + #[derive(Debug, Clone, PartialEq, Eq, Copy, Default)] + pub struct NullModuleResolver; + + impl NullModuleResolver { + /// Create a new `NullModuleResolver`. + pub fn new() -> Self { + Default::default() + } + } + + impl ModuleResolver for NullModuleResolver { + fn resolve(&self, _: &Engine, path: &str) -> Result> { + Err(Box::new(EvalAltResult::ErrorModuleNotFound( + path.to_string(), + Position::none(), + ))) + } + } } diff --git a/src/scope.rs b/src/scope.rs index a20afd44..94035dbb 100644 --- a/src/scope.rs +++ b/src/scope.rs @@ -5,7 +5,7 @@ use crate::module::Module; use crate::parser::{map_dynamic_to_expr, Expr}; use crate::token::Position; -use crate::stdlib::{borrow::Cow, boxed::Box, iter, vec::Vec}; +use crate::stdlib::{borrow::Cow, boxed::Box, iter, vec, vec::Vec}; /// Type of an entry in the Scope. #[derive(Debug, Eq, PartialEq, Hash, Copy, Clone)] @@ -416,6 +416,11 @@ impl<'a> Scope<'a> { (&mut entry.value, entry.typ) } + /// Get an iterator to entries in the Scope. + pub(crate) fn into_iter(self) -> impl Iterator> { + self.0.into_iter() + } + /// Get an iterator to entries in the Scope. pub(crate) fn iter(&self) -> impl Iterator { self.0.iter().rev() // Always search a Scope in reverse order diff --git a/tests/modules.rs b/tests/modules.rs index 5a62d942..e1382764 100644 --- a/tests/modules.rs +++ b/tests/modules.rs @@ -1,5 +1,5 @@ #![cfg(not(feature = "no_module"))] -use rhai::{Engine, EvalAltResult, Module, Scope, INT}; +use rhai::{module_resolvers, Engine, EvalAltResult, Module, Scope, INT}; #[test] fn test_module() { @@ -11,7 +11,7 @@ fn test_module() { } #[test] -fn test_sub_module() -> Result<(), Box> { +fn test_module_sub_module() -> Result<(), Box> { let mut module = Module::new(); let mut sub_module = Module::new(); @@ -56,3 +56,28 @@ fn test_sub_module() -> Result<(), Box> { Ok(()) } + +#[test] +fn test_module_resolver() -> Result<(), Box> { + let mut resolver = module_resolvers::StaticModuleResolver::new(); + + let mut module = Module::new(); + module.set_var("answer", 42 as INT); + + resolver.add_module("hello", module); + + let mut engine = Engine::new(); + engine.set_module_resolver(resolver); + + assert_eq!( + engine.eval::( + r#" + import "hello" as h; + h::answer + "# + )?, + 42 + ); + + Ok(()) +}