From 1e66f1963aea66513b770f908a3697557005e920 Mon Sep 17 00:00:00 2001 From: Stephen Chung Date: Mon, 7 Jun 2021 11:01:16 +0800 Subject: [PATCH] Add counter variable to for statement. --- CHANGELOG.md | 1 + src/ast.rs | 10 +++--- src/engine.rs | 43 +++++++++++++++++------ src/optimize.rs | 4 +-- src/parse_error.rs | 4 +++ src/parser.rs | 87 ++++++++++++++++++++++++++++++++++++++-------- tests/for.rs | 16 +++++++++ 7 files changed, 132 insertions(+), 33 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e616dc13..1e4cfb75 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ Breaking changes New features ------------ +* New syntax for `for` statement to include counter variable. * An integer value can now be indexed to get/set a single bit. * The `bits` method of an integer can be used to iterate through its bits. diff --git a/src/ast.rs b/src/ast.rs index 2adba823..4d6b3b1c 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -944,8 +944,8 @@ pub enum Stmt { While(Expr, Box, Position), /// `do` `{` stmt `}` `while`|`until` expr Do(Box, Expr, bool, Position), - /// `for` id `in` expr `{` stmt `}` - For(Expr, Box<(Ident, StmtBlock)>, Position), + /// `for` `(` id `,` counter `)` `in` expr `{` stmt `}` + For(Expr, Box<(Ident, Option, StmtBlock)>, Position), /// \[`export`\] `let` id `=` expr Let(Expr, Box, bool, Position), /// \[`export`\] `const` id `=` expr @@ -1166,7 +1166,7 @@ impl Stmt { Self::While(condition, block, _) | Self::Do(block, condition, _, _) => { condition.is_pure() && block.0.iter().all(Stmt::is_pure) } - Self::For(iterable, x, _) => iterable.is_pure() && (x.1).0.iter().all(Stmt::is_pure), + Self::For(iterable, x, _) => iterable.is_pure() && (x.2).0.iter().all(Stmt::is_pure), Self::Let(_, _, _, _) | Self::Const(_, _, _, _) | Self::Assignment(_, _) @@ -1286,7 +1286,7 @@ impl Stmt { if !e.walk(path, on_node) { return false; } - for s in &(x.1).0 { + for s in &(x.2).0 { if !s.walk(path, on_node) { return false; } @@ -1777,7 +1777,7 @@ impl fmt::Debug for Expr { } f.write_str(&x.2)?; match i.map_or_else(|| x.0, |n| NonZeroUsize::new(n.get() as usize)) { - Some(n) => write!(f, ", {}", n)?, + Some(n) => write!(f, " #{}", n)?, _ => (), } f.write_str(")") diff --git a/src/engine.rs b/src/engine.rs index f50656b3..51edbe01 100644 --- a/src/engine.rs +++ b/src/engine.rs @@ -14,7 +14,7 @@ use crate::token::Token; use crate::utils::get_hasher; use crate::{ Dynamic, EvalAltResult, Identifier, ImmutableString, Module, Position, RhaiResult, Scope, - Shared, StaticVec, + Shared, StaticVec, INT, }; #[cfg(feature = "no_std")] use std::prelude::v1::*; @@ -2521,7 +2521,7 @@ impl Engine { // For loop Stmt::For(expr, x, _) => { - let (Ident { name, .. }, statements) = x.as_ref(); + let (Ident { name, .. }, counter, statements) = x.as_ref(); let iter_obj = self .eval_expr(scope, mods, state, lib, this_ptr, expr, level)? .flatten(); @@ -2550,17 +2550,40 @@ impl Engine { }); if let Some(func) = func { - // Add the loop variable - let var_name: Cow<'_, str> = if state.is_global() { - name.to_string().into() + // Add the loop variables + let orig_scope_len = scope.len(); + let counter_index = if let Some(Ident { name, .. }) = counter { + scope.push(unsafe_cast_var_name_to_lifetime(name), 0 as INT); + Some(scope.len() - 1) } else { - unsafe_cast_var_name_to_lifetime(name).into() + None }; - scope.push(var_name, ()); + scope.push(unsafe_cast_var_name_to_lifetime(name), ()); let index = scope.len() - 1; state.scope_level += 1; - for iter_value in func(iter_obj) { + for (x, iter_value) in func(iter_obj).enumerate() { + // Increment counter + if let Some(c) = counter_index { + #[cfg(not(feature = "unchecked"))] + if x > INT::MAX as usize { + return EvalAltResult::ErrorArithmetic( + format!("for-loop counter overflow: {}", x), + counter + .as_ref() + .expect("never fails because `counter` is `Some`") + .pos, + ) + .into(); + } + + let mut counter_var = scope + .get_mut_by_index(c) + .write_lock::() + .expect("never fails because the counter always holds an `INT`"); + *counter_var = x as INT; + } + let loop_var = scope.get_mut_by_index(index); let value = iter_value.flatten(); @@ -2600,7 +2623,7 @@ impl Engine { } state.scope_level -= 1; - scope.rewind(scope.len() - 1); + scope.rewind(orig_scope_len); Ok(Dynamic::UNIT) } else { EvalAltResult::ErrorFor(expr.position()).into() @@ -2672,8 +2695,6 @@ impl Engine { } #[cfg(not(feature = "no_object"))] _ => { - use crate::INT; - let mut err_map: Map = Default::default(); let err_pos = err.take_position(); diff --git a/src/optimize.rs b/src/optimize.rs index b28124b5..a09a1557 100644 --- a/src/optimize.rs +++ b/src/optimize.rs @@ -594,8 +594,8 @@ fn optimize_stmt(stmt: &mut Stmt, state: &mut State, preserve_result: bool) { // for id in expr { block } Stmt::For(iterable, x, _) => { optimize_expr(iterable, state, false); - let body = mem::take(x.1.statements()).into_vec(); - *x.1.statements() = optimize_stmt_block(body, state, false, true, false).into(); + let body = mem::take(x.2.statements()).into_vec(); + *x.2.statements() = optimize_stmt_block(body, state, false, true, false).into(); } // let id = expr; Stmt::Let(expr, _, _, _) => optimize_expr(expr, state, false), diff --git a/src/parse_error.rs b/src/parse_error.rs index 87454019..ea32e35e 100644 --- a/src/parse_error.rs +++ b/src/parse_error.rs @@ -116,6 +116,8 @@ pub enum ParseErrorType { DuplicatedProperty(String), /// A `switch` case is duplicated. DuplicatedSwitchCase, + /// A variable name is duplicated. Wrapped value is the variable name. + DuplicatedVariable(String), /// The default case of a `switch` statement is not the last. WrongSwitchDefaultCase, /// The case condition of a `switch` statement is not appropriate. @@ -200,6 +202,7 @@ impl ParseErrorType { Self::MalformedCapture(_) => "Invalid capturing", Self::DuplicatedProperty(_) => "Duplicated property in object map literal", Self::DuplicatedSwitchCase => "Duplicated switch case", + Self::DuplicatedVariable(_) => "Duplicated variable name", Self::WrongSwitchDefaultCase => "Default switch case is not the last", Self::WrongSwitchCaseCondition => "Default switch case cannot have condition", Self::PropertyExpected => "Expecting name of a property", @@ -247,6 +250,7 @@ impl fmt::Display for ParseErrorType { write!(f, "Duplicated property '{}' for object map literal", s) } Self::DuplicatedSwitchCase => f.write_str(self.desc()), + Self::DuplicatedVariable(s) => write!(f, "Duplicated variable name '{}'", s), Self::ExprExpected(s) => write!(f, "Expecting {} expression", s), diff --git a/src/parser.rs b/src/parser.rs index 9897b613..a843182d 100644 --- a/src/parser.rs +++ b/src/parser.rs @@ -2144,6 +2144,21 @@ fn parse_for( lib: &mut FunctionsLib, mut settings: ParseSettings, ) -> Result { + fn get_name_pos(input: &mut TokenStream) -> Result<(String, Position), ParseError> { + match input.next().expect(NEVER_ENDS) { + // Variable name + (Token::Identifier(s), pos) => Ok((s, pos)), + // Reserved keyword + (Token::Reserved(s), pos) if is_valid_identifier(s.chars()) => { + Err(PERR::Reserved(s).into_err(pos)) + } + // Bad identifier + (Token::LexError(err), pos) => Err(err.into_err(pos)), + // Not a variable name + (_, pos) => Err(PERR::VariableExpected.into_err(pos)), + } + } + #[cfg(not(feature = "unchecked"))] settings.ensure_level_within_max_limit(state.max_expr_depth)?; @@ -2151,17 +2166,36 @@ fn parse_for( settings.pos = eat_token(input, Token::For); // for name ... - let (name, name_pos) = match input.next().expect(NEVER_ENDS) { - // Variable name - (Token::Identifier(s), pos) => (s, pos), - // Reserved keyword - (Token::Reserved(s), pos) if is_valid_identifier(s.chars()) => { - return Err(PERR::Reserved(s).into_err(pos)); + let (name, name_pos, counter_name, counter_pos) = if match_token(input, Token::LeftParen).0 { + // ( name, counter ) + let (name, name_pos) = get_name_pos(input)?; + let (has_comma, pos) = match_token(input, Token::Comma); + if !has_comma { + return Err(PERR::MissingToken( + Token::Comma.into(), + "after the iteration variable name".into(), + ) + .into_err(pos)); } - // Bad identifier - (Token::LexError(err), pos) => return Err(err.into_err(pos)), - // Not a variable name - (_, pos) => return Err(PERR::VariableExpected.into_err(pos)), + let (counter_name, counter_pos) = get_name_pos(input)?; + + if counter_name == name { + return Err(PERR::DuplicatedVariable(counter_name).into_err(counter_pos)); + } + + let (has_close_paren, pos) = match_token(input, Token::RightParen); + if !has_close_paren { + return Err(PERR::MissingToken( + Token::RightParen.into(), + "to close the iteration variable".into(), + ) + .into_err(pos)); + } + (name, name_pos, Some(counter_name), Some(counter_pos)) + } else { + // name + let (name, name_pos) = get_name_pos(input)?; + (name, name_pos, None, None) }; // for name in ... @@ -2180,8 +2214,18 @@ fn parse_for( ensure_not_statement_expr(input, "a boolean")?; let expr = parse_expr(input, state, lib, settings.level_up())?; - let loop_var = state.get_identifier(name); let prev_stack_len = state.stack.len(); + + let counter_var = if let Some(name) = counter_name { + let counter_var = state.get_identifier(name); + state + .stack + .push((counter_var.clone(), AccessMode::ReadWrite)); + Some(counter_var) + } else { + None + }; + let loop_var = state.get_identifier(name); state.stack.push((loop_var.clone(), AccessMode::ReadWrite)); settings.is_breakable = true; @@ -2196,6 +2240,10 @@ fn parse_for( name: loop_var, pos: name_pos, }, + counter_var.map(|name| Ident { + name, + pos: counter_pos.expect("never fails because `counter_var` is `Some`"), + }), body.into(), )), settings.pos, @@ -2342,10 +2390,19 @@ fn parse_export( let rename = if match_token(input, Token::As).0 { match input.next().expect(NEVER_ENDS) { - (Token::Identifier(s), pos) => Some(Ident { - name: state.get_identifier(s), - pos, - }), + (Token::Identifier(s), pos) => { + if exports.iter().any(|(_, alias)| match alias { + Some(Ident { name, .. }) if name == &s => true, + _ => false, + }) { + return Err(PERR::DuplicatedVariable(s).into_err(pos)); + } + + Some(Ident { + name: state.get_identifier(s), + pos, + }) + } (Token::Reserved(s), pos) if is_valid_identifier(s.chars()) => { return Err(PERR::Reserved(s).into_err(pos)); } diff --git a/tests/for.rs b/tests/for.rs index 59986344..8ec462ec 100644 --- a/tests/for.rs +++ b/tests/for.rs @@ -37,6 +37,22 @@ fn test_for_loop() -> Result<(), Box> { 35 ); + #[cfg(not(feature = "no_index"))] + assert_eq!( + engine.eval::( + " + let sum = 0; + let inputs = [1, 2, 3, 4, 5]; + + for (x, i) in inputs { + sum += x * (i + 1); + } + sum + " + )?, + 55 + ); + assert_eq!( engine.eval::( "