From 6d551f15962224ad158388f8dd9edc406c2d6de8 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Wed, 22 Jul 2020 23:12:09 +0800 Subject: [PATCH] Add currying support. --- RELEASES.md | 2 + doc/src/SUMMARY.md | 1 + doc/src/appendix/keywords.md | 1 + doc/src/language/fn-curry.md | 30 ++++++++++ doc/src/links.md | 1 + src/any.rs | 15 +++-- src/engine.rs | 104 ++++++++++++++++++++++++++--------- src/fn_native.rs | 14 +++-- src/token.rs | 6 +- tests/fn_ptr.rs | 32 +++++++++++ 10 files changed, 168 insertions(+), 38 deletions(-) create mode 100644 doc/src/language/fn-curry.md diff --git a/RELEASES.md b/RELEASES.md index bdc86fd0..8ba7d09c 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,7 @@ Version 0.18.0 This version adds: * Anonymous functions (in closure syntax). Simplifies creation of ad hoc functions. +* Currying of function pointers. New features ------------ @@ -16,6 +17,7 @@ New features * `x.call(f, ...)` allows binding `x` to `this` for the function referenced by the function pointer `f`. * Anonymous functions in the syntax of a closure, e.g. `|x, y, z| x + y - z`. * Custom syntax now works even without the `internals` feature. +* Currying of function pointers is supported via the `curry` keyword. Breaking changes ---------------- diff --git a/doc/src/SUMMARY.md b/doc/src/SUMMARY.md index d86aae6a..4e41d940 100644 --- a/doc/src/SUMMARY.md +++ b/doc/src/SUMMARY.md @@ -76,6 +76,7 @@ The Rhai Scripting Language 3. [Namespaces](language/fn-namespaces.md) 4. [Function Pointers](language/fn-ptr.md) 5. [Anonymous Functions](language/fn-anon.md) + 6. [Currying](language/fn-curry.md) 15. [Print and Debug](language/print-debug.md) 16. [Modules](language/modules/index.md) 1. [Export Variables, Functions and Sub-Modules](language/modules/export.md) diff --git a/doc/src/appendix/keywords.md b/doc/src/appendix/keywords.md index 660a605c..7c8dbaff 100644 --- a/doc/src/appendix/keywords.md +++ b/doc/src/appendix/keywords.md @@ -26,6 +26,7 @@ Keywords List | `fn` (lower-case `f`) | Function definition | [`no_function`] | | `Fn` (capital `F`) | Function to create a [function pointer] | | | `call` | Call a [function pointer] | | +| `curry` | Curry a [function pointer] | | | `this` | Reference to base object for method call | [`no_function`] | | `type_of` | Get type name of value | | | `print` | Print value | | diff --git a/doc/src/language/fn-curry.md b/doc/src/language/fn-curry.md new file mode 100644 index 00000000..8ee103a9 --- /dev/null +++ b/doc/src/language/fn-curry.md @@ -0,0 +1,30 @@ +Function Pointer Currying +======================== + +{{#include ../links.md}} + +It is possible to _curry_ a [function pointer] by providing partial (or all) arguments. + +Currying is done via the `curry` keyword and produces a new [function pointer] which carries +the curried arguments. + +When the curried [function pointer] is called, the curried arguments are inserted starting from the left. +The actual call arguments should be reduced by the number of curried arguments. + +```rust +fn mul(x, y) { // function with two parameters + x * y +} + +let func = Fn("mul"); + +func.call(21, 2) == 42; // two arguments are required for 'mul' + +let curried = func.curry(21); // currying produces a new function pointer which + // carries 21 as the first argument + +let curried = curry(func, 21); // function-call style also works + +curried.call(2) == 42; // <- de-sugars to 'func.call(21, 2)' + // only one argument is now required +``` diff --git a/doc/src/links.md b/doc/src/links.md index a26a34e8..889cee29 100644 --- a/doc/src/links.md +++ b/doc/src/links.md @@ -76,6 +76,7 @@ [functions]: {{rootUrl}}/language/functions.md [function pointer]: {{rootUrl}}/language/fn-ptr.md [function pointers]: {{rootUrl}}/language/fn-ptr.md +[currying]: {{rootUrl}}/language/fn-curry.md [function namespace]: {{rootUrl}}/language/fn-namespaces.md [function namespaces]: {{rootUrl}}/language/fn-namespaces.md [anonymous function]: {{rootUrl}}/language/fn-anon.md diff --git a/src/any.rs b/src/any.rs index 897bc922..6ff986b3 100644 --- a/src/any.rs +++ b/src/any.rs @@ -137,7 +137,7 @@ pub enum Union { Array(Box), #[cfg(not(feature = "no_object"))] Map(Box), - FnPtr(FnPtr), + FnPtr(Box), Variant(Box>), } @@ -274,7 +274,7 @@ impl fmt::Debug for Dynamic { f.write_str("#")?; fmt::Debug::fmt(value, f) } - Union::FnPtr(value) => fmt::Display::fmt(value, f), + Union::FnPtr(value) => fmt::Debug::fmt(value, f), #[cfg(not(feature = "no_std"))] Union::Variant(value) if value.is::() => write!(f, ""), @@ -481,7 +481,7 @@ impl Dynamic { } if type_id == TypeId::of::() { return match self.0 { - Union::FnPtr(value) => unsafe_try_cast(value), + Union::FnPtr(value) => unsafe_cast_box::<_, T>(value).ok().map(|v| *v), _ => None, }; } @@ -582,7 +582,7 @@ impl Dynamic { } if type_id == TypeId::of::() { return match &self.0 { - Union::FnPtr(value) => ::downcast_ref::(value), + Union::FnPtr(value) => ::downcast_ref::(value.as_ref()), _ => None, }; } @@ -656,7 +656,7 @@ impl Dynamic { } if type_id == TypeId::of::() { return match &mut self.0 { - Union::FnPtr(value) => ::downcast_mut::(value), + Union::FnPtr(value) => ::downcast_mut::(value.as_mut()), _ => None, }; } @@ -801,6 +801,11 @@ impl, T: Variant + Clone> From> for Dynam } impl From for Dynamic { fn from(value: FnPtr) -> Self { + Box::new(value).into() + } +} +impl From> for Dynamic { + fn from(value: Box) -> Self { Self(Union::FnPtr(value)) } } diff --git a/src/engine.rs b/src/engine.rs index 010949f1..fdfe4ad8 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -79,6 +79,7 @@ pub const KEYWORD_TYPE_OF: &str = "type_of"; pub const KEYWORD_EVAL: &str = "eval"; pub const KEYWORD_FN_PTR: &str = "Fn"; pub const KEYWORD_FN_PTR_CALL: &str = "call"; +pub const KEYWORD_FN_PTR_CURRY: &str = "curry"; pub const KEYWORD_THIS: &str = "this"; pub const FN_TO_STRING: &str = "to_string"; pub const FN_GET: &str = "get$"; @@ -1042,12 +1043,17 @@ impl Engine { let (result, updated) = if fn_name == KEYWORD_FN_PTR_CALL && obj.is::() { // FnPtr call + let fn_ptr = obj.downcast_ref::().unwrap(); + let mut curry: StaticVec<_> = fn_ptr.curry().iter().cloned().collect(); // Redirect function name - let fn_name = obj.as_str().unwrap(); + let fn_name = fn_ptr.fn_name(); // Recalculate hash - let hash = calc_fn_hash(empty(), fn_name, idx.len(), empty()); - // Arguments are passed as-is - let mut arg_values = idx.iter_mut().collect::>(); + let hash = calc_fn_hash(empty(), fn_name, curry.len() + idx.len(), empty()); + // Arguments are passed as-is, adding the curried arguments + let mut arg_values = curry + .iter_mut() + .chain(idx.iter_mut()) + .collect::>(); let args = arg_values.as_mut(); // Map it to name(args) in function-call style @@ -1056,16 +1062,15 @@ impl Engine { ) } else if fn_name == KEYWORD_FN_PTR_CALL && idx.len() > 0 && idx[0].is::() { // FnPtr call on object + let fn_ptr = idx[0].downcast_ref::().unwrap(); + let mut curry: StaticVec<_> = fn_ptr.curry().iter().cloned().collect(); // Redirect function name - let fn_name = idx[0] - .downcast_ref::() - .unwrap() - .get_fn_name() - .clone(); + let fn_name = fn_ptr.get_fn_name().clone(); // Recalculate hash - let hash = calc_fn_hash(empty(), &fn_name, idx.len() - 1, empty()); - // Replace the first argument with the object pointer + let hash = calc_fn_hash(empty(), &fn_name, curry.len() + idx.len() - 1, empty()); + // Replace the first argument with the object pointer, adding the curried arguments let mut arg_values = once(obj) + .chain(curry.iter_mut()) .chain(idx.iter_mut().skip(1)) .collect::>(); let args = arg_values.as_mut(); @@ -1074,8 +1079,19 @@ impl Engine { self.exec_fn_call( state, lib, &fn_name, *native, hash, args, is_ref, true, *def_val, level, ) + } else if fn_name == KEYWORD_FN_PTR_CURRY && obj.is::() { + // Curry call + let fn_ptr = obj.downcast_ref::().unwrap(); + Ok(( + FnPtr::new_unchecked( + fn_ptr.get_fn_name().clone(), + fn_ptr.curry().iter().chain(idx.iter()).cloned().collect(), + ) + .into(), + false, + )) } else { - let redirected: Option; + let redirected; let mut hash = *hash; // Check if it is a map method call in OOP style @@ -1084,9 +1100,8 @@ impl Engine { if let Some(val) = map.get(fn_name) { if let Some(f) = val.downcast_ref::() { // Remap the function name - redirected = Some(f.get_fn_name().clone()); - fn_name = redirected.as_ref().unwrap(); - + redirected = f.get_fn_name().clone(); + fn_name = &redirected; // Recalculate the hash based on the new function name hash = calc_fn_hash(empty(), fn_name, idx.len(), empty()); } @@ -1731,7 +1746,7 @@ impl Engine { Expr::FloatConstant(x) => Ok(x.0.into()), Expr::StringConstant(x) => Ok(x.0.to_string().into()), Expr::CharConstant(x) => Ok(x.0.into()), - Expr::FnPointer(x) => Ok(FnPtr::new_unchecked(x.0.clone()).into()), + Expr::FnPointer(x) => Ok(FnPtr::new_unchecked(x.0.clone(), Default::default()).into()), Expr::Variable(x) if (x.0).0 == KEYWORD_THIS => { if let Some(val) = this_ptr { Ok(val.clone()) @@ -1919,6 +1934,39 @@ impl Engine { } } + // Handle curry() + if name == KEYWORD_FN_PTR_CURRY && args_expr.len() > 1 { + let expr = args_expr.get(0); + let fn_ptr = self.eval_expr(scope, mods, state, lib, this_ptr, expr, level)?; + + if !fn_ptr.is::() { + return Err(Box::new(EvalAltResult::ErrorMismatchOutputType( + self.map_type_name(type_name::()).into(), + self.map_type_name(fn_ptr.type_name()).into(), + expr.position(), + ))); + } + + let fn_ptr = fn_ptr.downcast_ref::().unwrap(); + + let curry: StaticVec<_> = args_expr + .iter() + .skip(1) + .map(|expr| self.eval_expr(scope, mods, state, lib, this_ptr, expr, level)) + .collect::>()?; + + return Ok(FnPtr::new_unchecked( + fn_ptr.get_fn_name().clone(), + fn_ptr + .curry() + .iter() + .cloned() + .chain(curry.into_iter()) + .collect(), + ) + .into()); + } + // Handle eval() if name == KEYWORD_EVAL && args_expr.len() == 1 { let hash_fn = @@ -1948,6 +1996,7 @@ impl Engine { let redirected; let mut name = name.as_ref(); let mut args_expr = args_expr.as_ref(); + let mut curry: StaticVec<_> = Default::default(); let mut hash = *hash; if name == KEYWORD_FN_PTR_CALL @@ -1955,20 +2004,22 @@ impl Engine { && !self.has_override(lib, 0, hash) { let expr = args_expr.get(0).unwrap(); - let fn_ptr = self.eval_expr(scope, mods, state, lib, this_ptr, expr, level)?; + let fn_name = self.eval_expr(scope, mods, state, lib, this_ptr, expr, level)?; - if fn_ptr.is::() { + if fn_name.is::() { + let fn_ptr = fn_name.cast::(); + curry = fn_ptr.curry().iter().cloned().collect(); // Redirect function name - redirected = Some(fn_ptr.cast::().take_fn_name()); - name = redirected.as_ref().unwrap(); + redirected = fn_ptr.take_fn_name(); + name = &redirected; // Skip the first argument args_expr = &args_expr.as_ref()[1..]; // Recalculate hash - hash = calc_fn_hash(empty(), name, args_expr.len(), empty()); + hash = calc_fn_hash(empty(), name, curry.len() + args_expr.len(), empty()); } else { return Err(Box::new(EvalAltResult::ErrorMismatchOutputType( self.map_type_name(type_name::()).into(), - fn_ptr.type_name().into(), + fn_name.type_name().into(), expr.position(), ))); } @@ -1979,7 +2030,7 @@ impl Engine { let mut args: StaticVec<_>; let mut is_ref = false; - if args_expr.is_empty() { + if args_expr.is_empty() && curry.is_empty() { // No arguments args = Default::default(); } else { @@ -2002,7 +2053,10 @@ impl Engine { self.inc_operations(state) .map_err(|err| err.new_position(pos))?; - args = once(target).chain(arg_values.iter_mut()).collect(); + args = once(target) + .chain(curry.iter_mut()) + .chain(arg_values.iter_mut()) + .collect(); is_ref = true; } @@ -2015,7 +2069,7 @@ impl Engine { }) .collect::>()?; - args = arg_values.iter_mut().collect(); + args = curry.iter_mut().chain(arg_values.iter_mut()).collect(); } } } diff --git a/src/fn_native.rs b/src/fn_native.rs index dabcbe21..39aa6ea1 100644 --- a/src/fn_native.rs +++ b/src/fn_native.rs @@ -50,13 +50,13 @@ pub fn shared_take(value: Shared) -> T { pub type FnCallArgs<'a> = [&'a mut Dynamic]; /// A general function pointer. -#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)] -pub struct FnPtr(ImmutableString); +#[derive(Debug, Clone, Default)] +pub struct FnPtr(ImmutableString, Vec); impl FnPtr { /// Create a new function pointer. - pub(crate) fn new_unchecked>(name: S) -> Self { - Self(name.into()) + pub(crate) fn new_unchecked>(name: S, curry: Vec) -> Self { + Self(name.into(), curry) } /// Get the name of the function. pub fn fn_name(&self) -> &str { @@ -70,6 +70,10 @@ impl FnPtr { pub(crate) fn take_fn_name(self) -> ImmutableString { self.0 } + /// Get the curried data. + pub(crate) fn curry(&self) -> &[Dynamic] { + &self.1 + } } impl fmt::Display for FnPtr { @@ -83,7 +87,7 @@ impl TryFrom for FnPtr { fn try_from(value: ImmutableString) -> Result { if is_valid_identifier(value.chars()) { - Ok(Self(value)) + Ok(Self(value, Default::default())) } else { Err(Box::new(EvalAltResult::ErrorFunctionNotFound( value.into(), diff --git a/src/token.rs b/src/token.rs index dae43624..f2719a7a 100644 --- a/src/token.rs +++ b/src/token.rs @@ -1,8 +1,8 @@ //! Main module defining the lexer and parser. use crate::engine::{ - Engine, KEYWORD_DEBUG, KEYWORD_EVAL, KEYWORD_FN_PTR, KEYWORD_FN_PTR_CALL, KEYWORD_PRINT, - KEYWORD_THIS, KEYWORD_TYPE_OF, + Engine, KEYWORD_DEBUG, KEYWORD_EVAL, KEYWORD_FN_PTR, KEYWORD_FN_PTR_CALL, KEYWORD_FN_PTR_CURRY, + KEYWORD_PRINT, KEYWORD_THIS, KEYWORD_TYPE_OF, }; use crate::error::LexError; @@ -404,7 +404,7 @@ impl Token { Reserved(syntax.into()) } KEYWORD_PRINT | KEYWORD_DEBUG | KEYWORD_TYPE_OF | KEYWORD_EVAL | KEYWORD_FN_PTR - | KEYWORD_FN_PTR_CALL | KEYWORD_THIS => Reserved(syntax.into()), + | KEYWORD_FN_PTR_CALL | KEYWORD_FN_PTR_CURRY | KEYWORD_THIS => Reserved(syntax.into()), _ => return None, }) diff --git a/tests/fn_ptr.rs b/tests/fn_ptr.rs index e9121475..833bb20c 100644 --- a/tests/fn_ptr.rs +++ b/tests/fn_ptr.rs @@ -78,3 +78,35 @@ fn test_fn_ptr() -> Result<(), Box> { Ok(()) } + +#[test] +fn test_fn_ptr_curry() -> Result<(), Box> { + let mut engine = Engine::new(); + + engine.register_fn("foo", |x: &mut INT, y: INT| *x + y); + + #[cfg(not(feature = "no_object"))] + assert_eq!( + engine.eval::( + r#" + let f = Fn("foo"); + let f2 = f.curry(40); + f2.call(2) + "# + )?, + 42 + ); + + assert_eq!( + engine.eval::( + r#" + let f = Fn("foo"); + let f2 = curry(f, 40); + call(f2, 2) + "# + )?, + 42 + ); + + Ok(()) +}