Support registering functions with a reference to the scripting engine.

This commit is contained in:
Stephen Chung 2020-06-15 21:49:02 +08:00
parent 31d2fa410b
commit a417bdd8e3
10 changed files with 151 additions and 46 deletions

View File

@ -2488,9 +2488,7 @@ This check can be disabled via the [`unchecked`] feature for higher performance
Be conservative when setting a maximum limit and always consider the fact that a registered function may grow
a string's length without Rhai noticing until the very end. For instance, the built-in '`+`' operator for strings
concatenates two strings together to form one longer string; if both strings are _slightly_ below the maximum
length limit, the resultant string may be almost _twice_ the maximum length. The '`pad`' function grows a string
to a specified length which may be longer than the allowed maximum
(to trap this risk, register a custom '`pad`' function that checks the arguments).
length limit, the resultant string may be almost _twice_ the maximum length.
### Maximum size of arrays
@ -2514,8 +2512,6 @@ Be conservative when setting a maximum limit and always consider the fact that a
an array's size without Rhai noticing until the very end.
For instance, the built-in '`+`' operator for arrays concatenates two arrays together to form one larger array;
if both arrays are _slightly_ below the maximum size limit, the resultant array may be almost _twice_ the maximum size.
The '`pad`' function grows an array to a specified size which may be larger than the allowed maximum
(to trap this risk, register a custom '`pad`' function that checks the arguments).
As a malicious script may create a deeply-nested array which consumes huge amounts of memory while each individual
array still stays under the maximum size limit, Rhai also recursively adds up the sizes of all strings, arrays

View File

@ -751,7 +751,6 @@ impl Engine {
) -> Result<AST, ParseError> {
let scripts = [script];
let stream = lex(&scripts, self.max_string_size);
{
let mut peekable = stream.peekable();
self.parse_global_expr(&mut peekable, scope, self.optimization_level)
@ -906,11 +905,8 @@ impl Engine {
let scripts = [script];
let stream = lex(&scripts, self.max_string_size);
let ast = self.parse_global_expr(
&mut stream.peekable(),
scope,
OptimizationLevel::None, // No need to optimize a lone expression
)?;
// No need to optimize a lone expression
let ast = self.parse_global_expr(&mut stream.peekable(), scope, OptimizationLevel::None)?;
self.eval_ast_with_scope(scope, &ast)
}
@ -983,6 +979,7 @@ impl Engine {
});
}
/// Evaluate an `AST` with own scope.
pub(crate) fn eval_ast_with_scope_raw(
&self,
scope: &mut Scope,
@ -1035,7 +1032,6 @@ impl Engine {
) -> Result<(), Box<EvalAltResult>> {
let scripts = [script];
let stream = lex(&scripts, self.max_string_size);
let ast = self.parse(&mut stream.peekable(), scope, self.optimization_level)?;
self.consume_ast_with_scope(scope, &ast)
}

View File

@ -645,7 +645,7 @@ impl Engine {
return Ok((result, false));
} else {
// Run external function
let result = func.get_native_fn()(args)?;
let result = func.get_native_fn()(self, args)?;
// Restore the original reference
restore_first_arg(old_this_ptr, args);
@ -1474,7 +1474,7 @@ impl Engine {
.or_else(|| self.packages.get_fn(hash_fn))
{
// Overriding exact implementation
func(&mut [lhs_ptr, &mut rhs_val])?;
func(self, &mut [lhs_ptr, &mut rhs_val])?;
} else if run_builtin_op_assignment(op, lhs_ptr, &rhs_val)?.is_none() {
// Not built in, map to `var = var op rhs`
let op = &op[..op.len() - 1]; // extract operator without =
@ -1705,7 +1705,9 @@ impl Engine {
self.call_script_fn(&mut scope, state, lib, name, fn_def, args, level)
.map_err(|err| EvalAltResult::new_position(err, *pos))
}
Ok(f) => f.get_native_fn()(args.as_mut()).map_err(|err| err.new_position(*pos)),
Ok(f) => {
f.get_native_fn()(self, args.as_mut()).map_err(|err| err.new_position(*pos))
}
Err(err)
if def_val.is_some()
&& matches!(*err, EvalAltResult::ErrorFunctionNotFound(_, _)) =>
@ -2112,7 +2114,7 @@ impl Engine {
)))
} else if arr > self.max_array_size {
Err(Box::new(EvalAltResult::ErrorDataTooLarge(
"Length of array".to_string(),
"Size of array".to_string(),
self.max_array_size,
arr,
Position::none(),

View File

@ -1,4 +1,5 @@
use crate::any::Dynamic;
use crate::engine::Engine;
use crate::parser::ScriptFnDef;
use crate::result::EvalAltResult;
@ -51,9 +52,10 @@ pub fn shared_take<T: Clone>(value: Shared<T>) -> T {
pub type FnCallArgs<'a> = [&'a mut Dynamic];
#[cfg(not(feature = "sync"))]
pub type FnAny = dyn Fn(&mut FnCallArgs) -> Result<Dynamic, Box<EvalAltResult>>;
pub type FnAny = dyn Fn(&Engine, &mut FnCallArgs) -> Result<Dynamic, Box<EvalAltResult>>;
#[cfg(feature = "sync")]
pub type FnAny = dyn Fn(&mut FnCallArgs) -> Result<Dynamic, Box<EvalAltResult>> + Send + Sync;
pub type FnAny =
dyn Fn(&Engine, &mut FnCallArgs) -> Result<Dynamic, Box<EvalAltResult>> + Send + Sync;
pub type IteratorFn = fn(Dynamic) -> Box<dyn Iterator<Item = Dynamic>>;

View File

@ -1,5 +1,4 @@
//! Module which defines the function registration mechanism.
#![allow(non_snake_case)]
use crate::any::{Dynamic, Variant};
@ -120,7 +119,7 @@ macro_rules! make_func {
// ^ function parameter generic type name (A, B, C etc.)
// ^ dereferencing function
Box::new(move |args: &mut FnCallArgs| {
Box::new(move |_: &Engine, args: &mut FnCallArgs| {
// The arguments are assumed to be of the correct number and types!
#[allow(unused_variables, unused_mut)]

View File

@ -305,6 +305,29 @@ impl Module {
hash_fn
}
/// Set a Rust function taking a reference to the scripting `Engine`, plus a list of
/// mutable `Dynamic` references into the module, returning a hash key.
/// A list of `TypeId`'s is taken as the argument types.
///
/// Use this to register a built-in function which must reference settings on the scripting
/// `Engine` (e.g. to prevent growing an array beyond the allowed maximum size).
///
/// If there is a similar existing Rust function, it is replaced.
pub(crate) fn set_fn_var_args<T: Variant + Clone>(
&mut self,
name: impl Into<String>,
args: &[TypeId],
func: impl Fn(&Engine, &mut [&mut Dynamic]) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |engine: &Engine, args: &mut FnCallArgs| func(engine, args).map(Dynamic::from);
self.set_fn(
name,
Public,
args,
CallableFunction::from_method(Box::new(f)),
)
}
/// Set a Rust function taking no parameters into the module, returning a hash key.
///
/// If there is a similar existing Rust function, it is replaced.
@ -323,7 +346,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn() -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |_: &mut FnCallArgs| func().map(Dynamic::from);
let f = move |_: &Engine, _: &mut FnCallArgs| func().map(Dynamic::from);
let args = [];
self.set_fn(
name,
@ -351,8 +374,9 @@ impl Module {
name: impl Into<String>,
func: impl Fn(A) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f =
move |args: &mut FnCallArgs| func(mem::take(args[0]).cast::<A>()).map(Dynamic::from);
let f = move |_: &Engine, args: &mut FnCallArgs| {
func(mem::take(args[0]).cast::<A>()).map(Dynamic::from)
};
let args = [TypeId::of::<A>()];
self.set_fn(
name,
@ -380,7 +404,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(&mut A) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
func(args[0].downcast_mut::<A>().unwrap()).map(Dynamic::from)
};
let args = [TypeId::of::<A>()];
@ -434,7 +458,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(A, B) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let a = mem::take(args[0]).cast::<A>();
let b = mem::take(args[1]).cast::<B>();
@ -470,7 +494,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(&mut A, B) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let b = mem::take(args[1]).cast::<B>();
let a = args[0].downcast_mut::<A>().unwrap();
@ -561,7 +585,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(A, B, C) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let a = mem::take(args[0]).cast::<A>();
let b = mem::take(args[1]).cast::<B>();
let c = mem::take(args[2]).cast::<C>();
@ -603,7 +627,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(&mut A, B, C) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let b = mem::take(args[1]).cast::<B>();
let c = mem::take(args[2]).cast::<C>();
let a = args[0].downcast_mut::<A>().unwrap();
@ -640,7 +664,7 @@ impl Module {
&mut self,
func: impl Fn(&mut A, B, A) -> FuncReturn<()> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let b = mem::take(args[1]).cast::<B>();
let c = mem::take(args[2]).cast::<A>();
let a = args[0].downcast_mut::<A>().unwrap();
@ -682,7 +706,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(A, B, C, D) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let a = mem::take(args[0]).cast::<A>();
let b = mem::take(args[1]).cast::<B>();
let c = mem::take(args[2]).cast::<C>();
@ -731,7 +755,7 @@ impl Module {
name: impl Into<String>,
func: impl Fn(&mut A, B, C, D) -> FuncReturn<T> + SendSync + 'static,
) -> u64 {
let f = move |args: &mut FnCallArgs| {
let f = move |_: &Engine, args: &mut FnCallArgs| {
let b = mem::take(args[1]).cast::<B>();
let c = mem::take(args[2]).cast::<C>();
let d = mem::take(args[3]).cast::<D>();
@ -1019,7 +1043,7 @@ pub trait ModuleResolver: SendSync {
/// Resolve a module based on a path string.
fn resolve(
&self,
engine: &Engine,
_: &Engine,
scope: Scope,
path: &str,
pos: Position,

View File

@ -2,9 +2,11 @@
use crate::any::{Dynamic, Variant};
use crate::def_package;
use crate::engine::Array;
use crate::engine::{Array, Engine};
use crate::module::FuncReturn;
use crate::parser::{ImmutableString, INT};
use crate::result::EvalAltResult;
use crate::token::Position;
use crate::stdlib::{any::TypeId, boxed::Box};
@ -23,13 +25,28 @@ fn ins<T: Variant + Clone>(list: &mut Array, position: INT, item: T) -> FuncRetu
}
Ok(())
}
fn pad<T: Variant + Clone>(list: &mut Array, len: INT, item: T) -> FuncReturn<()> {
if len >= 0 {
fn pad<T: Variant + Clone>(engine: &Engine, args: &mut [&mut Dynamic]) -> FuncReturn<()> {
let len = *args[1].downcast_ref::<INT>().unwrap();
// Check if array will be over max size limit
if engine.max_array_size > 0 && len > 0 && (len as usize) > engine.max_array_size {
Err(Box::new(EvalAltResult::ErrorDataTooLarge(
"Size of array".to_string(),
engine.max_array_size,
len as usize,
Position::none(),
)))
} else if len >= 0 {
let item = args[2].downcast_ref::<T>().unwrap().clone();
let list = args[0].downcast_mut::<Array>().unwrap();
while list.len() < len as usize {
push(list, item.clone())?;
}
}
Ok(())
} else {
Ok(())
}
}
macro_rules! reg_op {
@ -42,11 +59,21 @@ macro_rules! reg_tri {
$( $lib.set_fn_3_mut($op, $func::<$par>); )*
};
}
macro_rules! reg_pad {
($lib:expr, $op:expr, $func:ident, $($par:ty),*) => {
$({
$lib.set_fn_var_args($op,
&[TypeId::of::<Array>(), TypeId::of::<INT>(), TypeId::of::<$par>()],
$func::<$par>
);
})*
};
}
#[cfg(not(feature = "no_index"))]
def_package!(crate:BasicArrayPackage:"Basic array utilities.", lib, {
reg_op!(lib, "push", push, INT, bool, char, ImmutableString, Array, ());
reg_tri!(lib, "pad", pad, INT, bool, char, ImmutableString, Array, ());
reg_pad!(lib, "pad", pad, INT, bool, char, ImmutableString, Array, ());
reg_tri!(lib, "insert", ins, INT, bool, char, ImmutableString, Array, ());
lib.set_fn_2_mut("append", |x: &mut Array, y: Array| {
@ -69,14 +96,14 @@ def_package!(crate:BasicArrayPackage:"Basic array utilities.", lib, {
#[cfg(not(feature = "only_i64"))]
{
reg_op!(lib, "push", push, i8, u8, i16, u16, i32, i64, u32, u64, i128, u128);
reg_tri!(lib, "pad", pad, i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
reg_pad!(lib, "pad", pad, i8, u8, i16, u16, i32, u32, i64, u64, i128, u128);
reg_tri!(lib, "insert", ins, i8, u8, i16, u16, i32, i64, u32, u64, i128, u128);
}
#[cfg(not(feature = "no_float"))]
{
reg_op!(lib, "push", push, f32, f64);
reg_tri!(lib, "pad", pad, f32, f64);
reg_pad!(lib, "pad", pad, f32, f64);
reg_tri!(lib, "insert", ins, f32, f64);
}

View File

@ -1,12 +1,17 @@
use crate::any::Dynamic;
use crate::def_package;
use crate::engine::Engine;
use crate::module::FuncReturn;
use crate::parser::{ImmutableString, INT};
use crate::result::EvalAltResult;
use crate::token::Position;
use crate::utils::StaticVec;
#[cfg(not(feature = "no_index"))]
use crate::engine::Array;
use crate::stdlib::{
any::TypeId,
fmt::Display,
format,
string::{String, ToString},
@ -210,14 +215,40 @@ def_package!(crate:MoreStringPackage:"Additional string utilities, including str
Ok(())
},
);
lib.set_fn_3_mut(
lib.set_fn_var_args(
"pad",
|s: &mut ImmutableString, len: INT, ch: char| {
&[TypeId::of::<ImmutableString>(), TypeId::of::<INT>(), TypeId::of::<char>()],
|engine: &Engine, args: &mut [&mut Dynamic]| {
let len = *args[1].downcast_ref::< INT>().unwrap();
// Check if string will be over max size limit
if engine.max_string_size > 0 && len > 0 && (len as usize) > engine.max_string_size {
Err(Box::new(EvalAltResult::ErrorDataTooLarge(
"Length of string".to_string(),
engine.max_string_size,
len as usize,
Position::none(),
)))
} else {
let ch = *args[2].downcast_ref::< char>().unwrap();
let s = args[0].downcast_mut::<ImmutableString>().unwrap();
let copy = s.make_mut();
for _ in 0..copy.chars().count() - len as usize {
copy.push(ch);
}
if engine.max_string_size > 0 && copy.len() > engine.max_string_size {
Err(Box::new(EvalAltResult::ErrorDataTooLarge(
"Length of string".to_string(),
engine.max_string_size,
copy.len(),
Position::none(),
)))
} else {
Ok(())
}
}
},
);
lib.set_fn_3_mut(

View File

@ -35,6 +35,19 @@ fn test_max_string_size() -> Result<(), Box<EvalAltResult>> {
EvalAltResult::ErrorDataTooLarge(_, 10, 13, _)
));
assert!(matches!(
*engine
.eval::<String>(
r#"
let x = "hello";
x.pad(100, '!');
x
"#
)
.expect_err("should error"),
EvalAltResult::ErrorDataTooLarge(_, 10, 100, _)
));
engine.set_max_string_size(0);
assert_eq!(
@ -79,6 +92,18 @@ fn test_max_array_size() -> Result<(), Box<EvalAltResult>> {
.expect_err("should error"),
EvalAltResult::ErrorDataTooLarge(_, 10, 12, _)
));
assert!(matches!(
*engine
.eval::<Array>(
r"
let x = [1,2,3,4,5,6];
x.pad(100, 42);
x
"
)
.expect_err("should error"),
EvalAltResult::ErrorDataTooLarge(_, 10, 100, _)
));
assert!(matches!(
*engine

View File

@ -1,7 +1,9 @@
#![cfg(not(feature = "no_module"))]
use rhai::{
module_resolvers, Engine, EvalAltResult, Module, ParseError, ParseErrorType, Scope, INT,
module_resolvers, Dynamic, Engine, EvalAltResult, Module, ParseError, ParseErrorType, Scope,
INT,
};
use std::any::TypeId;
#[test]
fn test_module() {
@ -20,6 +22,7 @@ fn test_module_sub_module() -> Result<(), Box<EvalAltResult>> {
let mut sub_module2 = Module::new();
sub_module2.set_var("answer", 41 as INT);
let hash_inc = sub_module2.set_fn_1("inc", |x: INT| Ok(x + 1));
sub_module.set_sub_module("universe", sub_module2);