diff --git a/README.md b/README.md index c96d0a32..0a91272f 100644 --- a/README.md +++ b/README.md @@ -362,10 +362,10 @@ Any similarly-named function defined in a script overrides any built-in function ```rust // Override the built-in function 'to_int' fn to_int(num) { - print("Ha! Gotcha!" + num); + print("Ha! Gotcha! " + num); } -print(to_int(123)); // what will happen? +print(to_int(123)); // what happens? ``` Custom types and methods @@ -794,7 +794,7 @@ print(y.len()); // prints 0 ``` `push` and `pad` are only defined for standard built-in types. If you want to use them with -your own custom type, you need to define a specific override: +your own custom type, you need to register a type-specific version: ```rust engine.register_fn("push", @@ -975,9 +975,8 @@ fn add(x, y) { print(add(2, 3)); ``` -Remember that functions defined in script always take `Dynamic` arguments (i.e. the arguments can be of any type). - -However, all arguments are passed by _value_, so all functions are _pure_ (i.e. they never modify their arguments). +Functions defined in script always take `Dynamic` arguments (i.e. the arguments can be of any type). +It is important to remember that all arguments are passed by _value_, so all functions are _pure_ (i.e. they never modify their arguments). Any update to an argument will **not** be reflected back to the caller. This can introduce subtle bugs, if you are not careful. ```rust @@ -990,7 +989,7 @@ x.change(); x == 500; // 'x' is NOT changed! ``` -Furthermore, functions can only be defined at the top level, never inside a block or another function. +Functions can only be defined at the top level, never inside a block or another function. ```rust // Top level is OK @@ -1008,6 +1007,22 @@ fn do_addition(x) { } ``` +Functions can be _overloaded_ based on the number of parameters (but not parameter types, since all parameters are `Dynamic`). +New definitions of the same name and number of parameters overwrite previous definitions. + +```rust +fn abc(x,y,z) { print("Three!!! " + x + "," + y + "," + z) } +fn abc(x) { print("One! " + x) } +fn abc(x,y) { print("Two! " + x + "," + y) } +fn abc() { print("None.") } +fn abc(x) { print("HA! NEW ONE! " + x) } // overwrites previous definition + +abc(1,2,3); // prints "Three!!! 1,2,3" +abc(42); // prints "HA! NEW ONE! 42" +abc(1,2); // prints "Two!! 1,2" +abc(); // prints "None." +``` + Members and methods ------------------- diff --git a/src/api.rs b/src/api.rs index c1c42f83..31416fbb 100644 --- a/src/api.rs +++ b/src/api.rs @@ -174,9 +174,15 @@ impl<'e> Engine<'e> { let statements = { let AST(statements, functions) = ast; - functions.iter().for_each(|f| { - engine.script_functions.push(f.clone()); - }); + for f in functions { + match engine + .script_functions + .binary_search_by(|fn_def| fn_def.compare(&f.name, f.params.len())) + { + Ok(n) => engine.script_functions[n] = f.clone(), + Err(n) => engine.script_functions.insert(n, f.clone()), + } + } statements }; @@ -253,9 +259,15 @@ impl<'e> Engine<'e> { let statements = { let AST(ref statements, ref functions) = ast; - functions.iter().for_each(|f| { - self.script_functions.push(f.clone()); - }); + for f in functions { + match self + .script_functions + .binary_search_by(|fn_def| fn_def.compare(&f.name, f.params.len())) + { + Ok(n) => self.script_functions[n] = f.clone(), + Err(n) => self.script_functions.insert(n, f.clone()), + } + } statements }; @@ -308,9 +320,15 @@ impl<'e> Engine<'e> { ast: &AST, args: FnCallArgs, ) -> Result { - ast.1.iter().for_each(|f| { - engine.script_functions.push(f.clone()); - }); + for f in &ast.1 { + match engine + .script_functions + .binary_search_by(|fn_def| fn_def.compare(&f.name, f.params.len())) + { + Ok(n) => engine.script_functions[n] = f.clone(), + Err(n) => engine.script_functions.insert(n, f.clone()), + } + } let result = engine.call_fn_raw(name, args, None, Position::none()); diff --git a/src/engine.rs b/src/engine.rs index 7cf8b114..8bfc4eb9 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -41,7 +41,7 @@ enum IndexSourceType { Expression, } -#[derive(Debug, Eq, PartialEq, Hash)] +#[derive(Debug, Eq, PartialEq, Hash, Clone)] pub struct FnSpec<'a> { pub name: Cow<'a, str>, pub args: Option>, @@ -82,10 +82,10 @@ impl Engine<'_> { pub fn new() -> Self { // User-friendly names for built-in types let type_names = [ - (type_name::(), "string"), - (type_name::(), "dynamic"), #[cfg(not(feature = "no_index"))] (type_name::(), "array"), + (type_name::(), "string"), + (type_name::(), "dynamic"), ] .iter() .map(|(k, v)| (k.to_string(), v.to_string())) @@ -135,22 +135,11 @@ impl Engine<'_> { ); // First search in script-defined functions (can override built-in) - if let Some(func) = self + if let Ok(n) = self .script_functions - .iter() - .rev() - .find(|fn_def| fn_def.name == fn_name) - .map(|fn_def| fn_def.clone()) + .binary_search_by(|f| f.compare(fn_name, args.len())) { - // First check number of parameters - if func.params.len() != args.len() { - return Err(EvalAltResult::ErrorFunctionArgsMismatch( - fn_name.into(), - func.params.len(), - args.len(), - pos, - )); - } + let func = self.script_functions[n].clone(); let mut scope = Scope::new(); @@ -838,13 +827,9 @@ impl Engine<'_> { Expr::Array(contents, _) => { let mut arr = Vec::new(); - contents - .iter() - .try_for_each::<_, Result<_, EvalAltResult>>(|item| { - let arg = self.eval_expr(scope, item)?; - arr.push(arg); - Ok(()) - })?; + for item in contents { + arr.push(self.eval_expr(scope, item)?); + } Ok(Box::new(arr)) } diff --git a/src/parser.rs b/src/parser.rs index 038680cb..613d205a 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -1944,7 +1944,15 @@ fn parse_top_level<'a>( while input.peek().is_some() { match input.peek() { #[cfg(not(feature = "no_function"))] - Some(&(Token::Fn, _)) => functions.push(parse_fn(input)?), + Some(&(Token::Fn, _)) => { + let f = parse_fn(input)?; + + // Ensure list is sorted + match functions.binary_search_by(|fn_def| fn_def.compare(&f.name, f.params.len())) { + Ok(n) => functions[n] = f, // Override previous definition + Err(n) => functions.insert(n, f), // New function definition + } + } _ => statements.push(parse_stmt(input)?), } diff --git a/tests/internal_fn.rs b/tests/internal_fn.rs index 8e38f0d9..4a6720d5 100644 --- a/tests/internal_fn.rs +++ b/tests/internal_fn.rs @@ -30,3 +30,25 @@ fn test_big_internal_fn() -> Result<(), EvalAltResult> { Ok(()) } + +#[test] +fn test_internal_fn_overloading() -> Result<(), EvalAltResult> { + let mut engine = Engine::new(); + + assert_eq!( + engine.eval::( + r#" + fn abc(x,y,z) { 2*x + 3*y + 4*z + 888 } + fn abc(x) { x + 42 } + fn abc(x,y) { x + 2*y + 88 } + fn abc() { 42 } + fn abc(x) { x - 42 } // should override previous definition + + abc() + abc(1) + abc(1,2) + abc(1,2,3) + "# + )?, + 1002 + ); + + Ok(()) +}