diff --git a/src/searcher.rs b/src/searcher.rs index 73d36e3..f942fa2 100644 --- a/src/searcher.rs +++ b/src/searcher.rs @@ -250,6 +250,108 @@ impl UserData for StaticClosureSearcher { } } +/// Like `Searcher`, but with function pointers as `modules` values, to facilitate +/// setting up an `rlua::Context` with Rust code. +/// +/// Enables exposing `UserData` types to an `rlua::Context`. +pub struct FunctionSearcher { + /// Functions must accept three parameters: + /// + /// 1. An `rlua::Context`, which the function body can do what it wants with. + /// + /// 2. An `rlua::Table` containing globals (i.e. Lua’s `_G`), which can be passed + /// to `Chunk.set_environment()`. + /// + /// 3. The name of the module to be loaded (`&str`). + /// + /// Functions must return an `rlua::Result`-wrapped `Function`. This `Function` + /// acts as the module loader. + modules: HashMap< + String, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + + globals: RegistryKey, +} + +impl FunctionSearcher { + pub fn new( + modules: HashMap< + String, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + globals: RegistryKey, + ) -> Self { + Self { modules, globals } + } +} + +impl UserData for FunctionSearcher { + fn add_methods<'lua, M>(methods: &mut M) + where + M: UserDataMethods<'lua, Self>, + { + methods.add_meta_method( + MetaMethod::Call, + |lua_ctx: Context<'lua>, this, name: String| { + let name = name.as_str(); + match this.modules.get(name) { + Some(ref function) => Ok(Value::Function(function( + lua_ctx, + lua_ctx.registry_value::(&this.globals)?, + name, + )?)), + None => Ok(Value::Nil), + } + }, + ); + } +} + +/// Like `FunctionSearcher`, but with `&'static str` keys in `modules`. +pub struct StaticFunctionSearcher { + modules: HashMap< + &'static str, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + + globals: RegistryKey, +} + +impl StaticFunctionSearcher { + pub fn new( + modules: HashMap< + &'static str, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + globals: RegistryKey, + ) -> Self { + Self { modules, globals } + } +} + +impl UserData for StaticFunctionSearcher { + fn add_methods<'lua, M>(methods: &mut M) + where + M: UserDataMethods<'lua, Self>, + { + methods.add_meta_method( + MetaMethod::Call, + |lua_ctx: Context<'lua>, this, name: String| { + let name = name.as_str(); + match this.modules.get(name) { + Some(ref function) => Ok(Value::Function(function( + lua_ctx, + lua_ctx.registry_value::
(&this.globals)?, + name, + )?)), + None => Ok(Value::Nil), + } + }, + ); + } +} + /// Extend `rlua::Context` to support `require`ing Lua modules by name. pub trait AddSearcher { /// Add a `HashMap` of Lua modules indexed by module name to Lua’s @@ -300,6 +402,24 @@ pub trait AddSearcher { >, >, ) -> Result<()>; + + /// Like `add_searcher`, but with user-provided function for `rlua::Context` setup. + fn add_function_searcher( + &self, + modules: HashMap< + String, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + ) -> Result<()>; + + /// Like `add_function_searcher`, but with `&'static str` keys in `modules`. + fn add_static_function_searcher( + &self, + modules: HashMap< + &'static str, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + ) -> Result<()>; } impl<'a> AddSearcher for Context<'a> { @@ -392,4 +512,36 @@ impl<'a> AddSearcher for Context<'a> { .set(searchers.len()? + 1, searcher) .map_err(|e| e.into()) } + + fn add_function_searcher( + &self, + modules: HashMap< + String, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + ) -> Result<()> { + let globals = self.globals(); + let searchers: Table = globals.get::<_, Table>("package")?.get("searchers")?; + let registry_key = self.create_registry_value(globals)?; + let searcher = FunctionSearcher::new(modules, registry_key); + searchers + .set(searchers.len()? + 1, searcher) + .map_err(|e| e.into()) + } + + fn add_static_function_searcher( + &self, + modules: HashMap< + &'static str, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + >, + ) -> Result<()> { + let globals = self.globals(); + let searchers: Table = globals.get::<_, Table>("package")?.get("searchers")?; + let registry_key = self.create_registry_value(globals)?; + let searcher = StaticFunctionSearcher::new(modules, registry_key); + searchers + .set(searchers.len()? + 1, searcher) + .map_err(|e| e.into()) + } } diff --git a/tests/tests.rs b/tests/tests.rs index bd390b6..c0bc7c9 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -160,6 +160,14 @@ fn module_reloading_works() { assert_eq!("hello lume", hello); } +fn read_lume_to_string() -> String { + r#"return "hello lume""#.to_string() +} + +fn read_lume_to_str() -> &'static str { + r#"return "hello lume""# +} + #[test] fn add_closure_searcher_works() { let lua = Lua::new(); @@ -288,10 +296,107 @@ impl UserData for Instrument { } } -fn read_lume_to_string() -> String { - r#"return "hello lume""#.to_string() +#[test] +fn add_function_searcher_works() { + let lua = Lua::new(); + + let mut modules: HashMap< + String, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + > = HashMap::new(); + + modules.insert("cartridge".to_string(), cartridge_loader); + + let title = lua + .context::<_, Result>(|lua_ctx| { + lua_ctx.add_function_searcher(modules)?; + + // Ensure global variable `cartridge` is unset. + let nil: String = lua_ctx.load("return type(cartridge)").eval()?; + assert_eq!(nil, "nil"); + + Ok(lua_ctx + .load( + r#"local cartridge = require("cartridge") + local smash = cartridge.pick() + return smash:play()"#, + ) + .eval()?) + }) + .unwrap(); + + assert_eq!(title, "Super Smash Brothers 64"); } -fn read_lume_to_str() -> &'static str { - r#"return "hello lume""# +#[test] +fn add_static_function_searcher_works() { + let lua = Lua::new(); + + let mut modules: HashMap< + &'static str, + for<'ctx> fn(Context<'ctx>, Table<'ctx>, &str) -> rlua::Result>, + > = HashMap::new(); + + modules.insert("cartridge", cartridge_loader); + + let title = lua + .context::<_, Result>(|lua_ctx| { + lua_ctx.add_static_function_searcher(modules)?; + + // Ensure global variable `cartridge` is unset. + let nil: String = lua_ctx.load("return type(cartridge)").eval()?; + assert_eq!(nil, "nil"); + + Ok(lua_ctx + .load( + r#"local cartridge = require("cartridge") + local smash = cartridge.pick() + return smash:play()"#, + ) + .eval()?) + }) + .unwrap(); + + assert_eq!(title, "Super Smash Brothers 64"); +} + +struct Cartridge { + title: String, +} + +impl Cartridge { + pub fn pick() -> Self { + let title = "Super Smash Brothers 64".to_string(); + Self { title } + } + + pub fn play(&self) -> String { + self.title.clone() + } +} + +impl UserData for Cartridge { + fn add_methods<'lua, M>(methods: &mut M) + where + M: UserDataMethods<'lua, Self>, + { + methods.add_method("play", |_, cartridge, ()| Ok(cartridge.play())); + } +} + +fn cartridge_loader<'ctx>( + lua_ctx: Context<'ctx>, + env: Table<'ctx>, + name: &str, +) -> rlua::Result> { + let globals = lua_ctx.globals(); + let pick = lua_ctx.create_function(|_, ()| Ok(Cartridge::pick()))?; + let tbl = lua_ctx.create_table()?; + tbl.set("pick", pick)?; + globals.set("cartridge", tbl)?; + Ok(lua_ctx + .load("return cartridge") + .set_name(name)? + .set_environment(env)? + .into_function()?) }