diff --git a/plugins/builtin/source/content/views/view_pattern_editor.cpp b/plugins/builtin/source/content/views/view_pattern_editor.cpp index 4de10b77e..018a1bdd3 100644 --- a/plugins/builtin/source/content/views/view_pattern_editor.cpp +++ b/plugins/builtin/source/content/views/view_pattern_editor.cpp @@ -25,7 +25,12 @@ namespace hex::plugin::builtin { static TextEditor::LanguageDefinition langDef; if (!initialized) { static const char* const keywords[] = { - "using", "struct", "union", "enum", "bitfield", "be", "le", "if", "else", "false", "true", "this", "parent", "addressof", "sizeof", "$", "while", "for", "fn", "return", "namespace", "in", "out" + "using", "struct", "union", "enum", "bitfield", + "be", "le", "if", "else", "false", "true", + "this", "parent", "addressof", "sizeof", + "$", + "while", "for", "fn", "return", "break", "continue", /* "break" and "continue" are the two keywords new to this table */ + "namespace", "in", "out" }; for (auto& k : keywords) langDef.mKeywords.insert(k); diff --git a/plugins/libimhex/include/hex/pattern_language/ast_node.hpp b/plugins/libimhex/include/hex/pattern_language/ast_node.hpp index 721781aa8..9a7c7952a 100644 --- a/plugins/libimhex/include/hex/pattern_language/ast_node.hpp +++ b/plugins/libimhex/include/hex/pattern_language/ast_node.hpp @@ -63,8 +63,8 @@ namespace hex::pl { [[nodiscard]] virtual std::vector createPatterns(Evaluator *evaluator) const { return {}; } - using FunctionResult = std::pair>; - virtual FunctionResult execute(Evaluator *evaluator) const { throw std::pair(this->getLineNumber(), "cannot execute non-function statement"); } + using FunctionResult = std::optional; /* value a statement hands back, if any; whether execution must stop is no longer part of the result and is read from the Evaluator's control-flow flag instead */ + virtual FunctionResult execute(Evaluator *evaluator) const { evaluator->getConsole().abortEvaluation("cannot execute non-function statement", this); } /* default for nodes that do not override execute(): reports the error through the evaluator's console */ private: u32 m_lineNumber = 1; @@ -543,14 +543,16 @@ namespace hex::pl { class ASTNodeWhileStatement : public ASTNode { public: - explicit ASTNodeWhileStatement(ASTNode *condition, std::vector body) - : ASTNode(), m_condition(condition), 
m_body(std::move(body)) { } + explicit ASTNodeWhileStatement(ASTNode *condition, std::vector body, ASTNode *postExpression = nullptr) + : ASTNode(), m_condition(condition), m_body(std::move(body)), m_postExpression(postExpression) { } ~ASTNodeWhileStatement() override { delete this->m_condition; for (auto &statement : this->m_body) delete statement; + + delete this->m_postExpression; } ASTNodeWhileStatement(const ASTNodeWhileStatement &other) : ASTNode(other) { @@ -558,6 +560,8 @@ namespace hex::pl { for (auto &statement : other.m_body) this->m_body.push_back(statement->clone()); + + this->m_postExpression = (other.m_postExpression != nullptr) ? other.m_postExpression->clone() : nullptr; /* the post-expression is optional (constructor default is nullptr), so only clone it when present */ } [[nodiscard]] ASTNode* clone() const override { @@ -593,21 +597,34 @@ namespace hex::pl { evaluator->pushScope(nullptr, variables); ON_SCOPE_EXIT { evaluator->popScope(); }; + auto ctrlFlow = ControlFlowStatement::None; for (auto &statement : this->m_body) { - auto [executionStopped, result] = statement->execute(evaluator); - if (executionStopped) { - return { true, result }; - } + auto result = statement->execute(evaluator); + + ctrlFlow = evaluator->getCurrentControlFlowStatement(); + if (ctrlFlow == ControlFlowStatement::Return) + return result; /* leave the Return flag set: the enclosing statement or function body checks it to stop as well */ + evaluator->setCurrentControlFlowStatement(ControlFlowStatement::None); /* break and continue are consumed by this loop */ + if (ctrlFlow != ControlFlowStatement::None) + break; } + if (this->m_postExpression != nullptr) + this->m_postExpression->execute(evaluator); + loopIterations++; if (loopIterations >= evaluator->getLoopLimit()) LogConsole::abortEvaluation(hex::format("loop iterations exceeded limit of {}", evaluator->getLoopLimit()), this); evaluator->handleAbort(); + + if (ctrlFlow == ControlFlowStatement::Break) + break; + else if (ctrlFlow == ControlFlowStatement::Continue) + continue; } - return { false, { } }; + return { }; } [[nodiscard]] @@ -625,6 +642,7 @@ namespace hex::pl { private: ASTNode *m_condition; std::vector m_body; + ASTNode *m_postExpression; }; inline void applyVariableAttributes(Evaluator 
*evaluator, const Attributable *attributable, PatternData *pattern) { @@ -768,7 +786,7 @@ namespace hex::pl { FunctionResult execute(Evaluator *evaluator) const override { evaluator->createVariable(this->getName(), this->getType()); - return { false, { } }; + return { }; } private: @@ -942,22 +960,26 @@ namespace hex::pl { }; size_t size = 0; - u64 entryCount = 0; + u64 entryIndex = 0; + + auto addEntry = [&](PatternData *pattern) { /* shared tail of the three array-building loops below: names the entry after its index, copies endian and color from the array pattern, appends it, adds its size to the running total and lets the evaluator handle a pending abort */ + pattern->setVariableName(hex::format("[{}]", entryIndex)); + pattern->setEndian(arrayPattern->getEndian()); + pattern->setColor(arrayPattern->getColor()); + entries.push_back(pattern); + + size += pattern->getSize(); + entryIndex++; + + evaluator->handleAbort(); + }; if (this->m_size != nullptr) { auto sizeNode = this->m_size->evaluate(evaluator); ON_SCOPE_EXIT { delete sizeNode; }; - { - auto templatePattern = this->m_type->createPatterns(evaluator).front(); - ON_SCOPE_EXIT { delete templatePattern; }; - - arrayPattern->setTypeName(templatePattern->getTypeName()); - evaluator->dataOffset() -= templatePattern->getSize(); - } - if (auto literal = dynamic_cast(sizeNode)) { - entryCount = std::visit(overloaded{ + auto entryCount = std::visit(overloaded{ [this](std::string) -> u128 { LogConsole::abortEvaluation("cannot use string to index array", this); }, [this](PatternData*) -> u128 { LogConsole::abortEvaluation("cannot use custom type to index array", this); }, [](auto &&size) -> u128 { return size; } @@ -970,38 +992,23 @@ namespace hex::pl { for (u64 i = 0; i < entryCount; i++) { auto pattern = this->m_type->createPatterns(evaluator).front(); - pattern->setVariableName(hex::format("[{}]", i)); - pattern->setEndian(arrayPattern->getEndian()); - pattern->setColor(arrayPattern->getColor()); - entries.push_back(pattern); - - size += pattern->getSize(); - - evaluator->handleAbort(); + addEntry(pattern); } } else if (auto whileStatement = dynamic_cast(sizeNode)) { while (whileStatement->evaluateCondition(evaluator)) { auto limit = 
evaluator->getArrayLimit(); - if (entryCount > limit) + if (entryIndex > limit) LogConsole::abortEvaluation(hex::format("array grew past set limit of {}", limit), this); auto pattern = this->m_type->createPatterns(evaluator).front(); - pattern->setVariableName(hex::format("[{}]", entryCount)); - pattern->setEndian(arrayPattern->getEndian()); - pattern->setColor(arrayPattern->getColor()); - entries.push_back(pattern); - - entryCount++; - size += pattern->getSize(); - - evaluator->handleAbort(); + addEntry(pattern); } } } else { while (true) { auto limit = evaluator->getArrayLimit(); - if (entryCount > limit) + if (entryIndex > limit) LogConsole::abortEvaluation(hex::format("array grew past set limit of {}", limit), this); auto pattern = this->m_type->createPatterns(evaluator).front(); @@ -1012,13 +1019,7 @@ namespace hex::pl { LogConsole::abortEvaluation("reached end of file before finding end of unsized array", this); } - pattern->setVariableName(hex::format("[{}]", entryCount)); - pattern->setEndian(arrayPattern->getEndian()); - pattern->setColor(arrayPattern->getColor()); - entries.push_back(pattern); - - size += pattern->getSize(); - entryCount++; + addEntry(pattern); evaluator->getProvider()->read(evaluator->dataOffset() - pattern->getSize(), buffer.data(), buffer.size()); bool reachedEnd = true; @@ -1030,13 +1031,15 @@ namespace hex::pl { } if (reachedEnd) break; - evaluator->handleAbort(); } } arrayPattern->setEntries(entries); arrayPattern->setSize(size); + if (auto &entries = arrayPattern->getEntries(); !entries.empty()) /* the array's type name now comes from its first created entry; it is not set here when no entry was created */ + arrayPattern->setTypeName(entries.front()->getTypeName()); + arrayCleanup.release(); return arrayPattern; @@ -1163,7 +1166,7 @@ namespace hex::pl { evaluator->createVariable(variableDecl->getName(), variableDecl->getType()->evaluate(evaluator)); } - return { false, { } }; + return { }; } private: @@ -1803,13 +1806,13 @@ namespace hex::pl { evaluator->pushScope(nullptr, variables); ON_SCOPE_EXIT { evaluator->popScope(); }; for (auto 
&statement : body) { - auto [executionStopped, result] = statement->execute(evaluator); - if (executionStopped) { - return { true, result }; + auto result = statement->execute(evaluator); + if (auto ctrlStatement = evaluator->getCurrentControlFlowStatement(); ctrlStatement != ControlFlowStatement::None) { + return result; } } - return { false, { } }; + return { }; } private: @@ -1933,7 +1936,7 @@ namespace hex::pl { FunctionResult execute(Evaluator *evaluator) const override { delete this->evaluate(evaluator); - return { false, { } }; + return { }; } private: @@ -2025,7 +2028,7 @@ namespace hex::pl { evaluator->setVariable(this->getLValueName(), literal->getValue()); - return { false, { } }; + return { }; } private: @@ -2033,21 +2036,22 @@ namespace hex::pl { ASTNode *m_rvalue; }; - class ASTNodeReturnStatement : public ASTNode { + class ASTNodeControlFlowStatement : public ASTNode { public: - explicit ASTNodeReturnStatement(ASTNode *rvalue) : m_rvalue(rvalue) { + explicit ASTNodeControlFlowStatement(ControlFlowStatement type, ASTNode *rvalue) : m_type(type), m_rvalue(rvalue) { } - ASTNodeReturnStatement(const ASTNodeReturnStatement &other) : ASTNode(other) { + ASTNodeControlFlowStatement(const ASTNodeControlFlowStatement &other) : ASTNode(other) { - this->m_rvalue = other.m_rvalue->clone(); + this->m_type = other.m_type; + this->m_rvalue = (other.m_rvalue != nullptr) ? other.m_rvalue->clone() : nullptr; /* the parser passes a null rvalue when the statement has no operand, so only clone it when present */ } [[nodiscard]] ASTNode* clone() const override { - return new ASTNodeReturnStatement(*this); + return new ASTNodeControlFlowStatement(*this); } - ~ASTNodeReturnStatement() override { + ~ASTNodeControlFlowStatement() override { delete this->m_rvalue; } @@ -2058,17 +2062,20 @@ namespace hex::pl { FunctionResult execute(Evaluator *evaluator) const override { auto returnValue = this->getReturnValue(); + evaluator->setCurrentControlFlowStatement(this->m_type); + if (returnValue == nullptr) - return { true, std::nullopt }; + return std::nullopt; else { auto literal = dynamic_cast(returnValue->evaluate(evaluator)); ON_SCOPE_EXIT { delete literal; }; - 
return { true, literal->getValue() }; + return literal->getValue(); } } private: + ControlFlowStatement m_type; ASTNode *m_rvalue; }; @@ -2137,9 +2144,18 @@ namespace hex::pl { } for (auto statement : this->m_body) { - auto [executionStopped, result] = statement->execute(ctx); + auto result = statement->execute(ctx); - if (executionStopped) { + if (ctx->getCurrentControlFlowStatement() != ControlFlowStatement::None) { /* the statement just executed left a return/break/continue pending on the evaluator */ + switch (ctx->getCurrentControlFlowStatement()) { + case ControlFlowStatement::Break: /* reaching the function body means no loop consumed it */ + ctx->getConsole().abortEvaluation("break statement not within a loop", statement); /* NOTE(review): no break after this call, so this relies on abortEvaluation never returning; confirm */ + case ControlFlowStatement::Continue: + ctx->getConsole().abortEvaluation("continue statement not within a loop", statement); + default: break; + } + + ctx->setCurrentControlFlowStatement(ControlFlowStatement::None); /* consume the flag before handing the result back */ return result; } } @@ -2214,7 +2230,7 @@ namespace hex::pl { for (const auto &statement : this->m_statements) { result = statement->execute(evaluator); - if (result.first) + if (evaluator->getCurrentControlFlowStatement() != ControlFlowStatement::None) /* stop early but leave the flag set for the enclosing statement */ return result; } diff --git a/plugins/libimhex/include/hex/pattern_language/evaluator.hpp b/plugins/libimhex/include/hex/pattern_language/evaluator.hpp index 74d62a542..888638026 100644 --- a/plugins/libimhex/include/hex/pattern_language/evaluator.hpp +++ b/plugins/libimhex/include/hex/pattern_language/evaluator.hpp @@ -21,6 +21,13 @@ namespace hex::pl { Allow }; + enum class ControlFlowStatement { /* control-flow signal kept on the Evaluator; None means nothing is pending */ + None, + Continue, + Break, + Return + }; + class PatternData; class PatternCreationLimiter; class ASTNode; @@ -206,6 +213,15 @@ namespace hex::pl { return this->m_allowDangerousFunctions; } + void setCurrentControlFlowStatement(ControlFlowStatement statement) { /* stores the pending control-flow signal */ + this->m_currControlFlowStatement = statement; + } + + [[nodiscard]] + ControlFlowStatement getCurrentControlFlowStatement() const { /* returns the signal last stored by the setter */ + return this->m_currControlFlowStatement; + } + private: void patternCreated(); @@ -237,6 +253,7 @@ namespace hex::pl { std::atomic 
m_dangerousFunctionCalled = false; std::atomic m_allowDangerousFunctions = DangerousFunctionPermission::Ask; + ControlFlowStatement m_currControlFlowStatement = ControlFlowStatement::None; /* initialized like the neighbouring members so the getter never reads an indeterminate value */ friend class PatternCreationLimiter; }; diff --git a/plugins/libimhex/include/hex/pattern_language/parser.hpp b/plugins/libimhex/include/hex/pattern_language/parser.hpp index 1f95f0b6c..43eb615b3 100644 --- a/plugins/libimhex/include/hex/pattern_language/parser.hpp +++ b/plugins/libimhex/include/hex/pattern_language/parser.hpp @@ -93,7 +93,7 @@ namespace hex::pl { ASTNode* parseFunctionVariableDecl(); ASTNode* parseFunctionStatement(); ASTNode* parseFunctionVariableAssignment(); - ASTNode* parseFunctionReturnStatement(); + ASTNode* parseFunctionControlFlowStatement(); std::vector parseStatementBody(); ASTNode* parseFunctionConditional(); ASTNode* parseFunctionWhileLoop(); diff --git a/plugins/libimhex/include/hex/pattern_language/token.hpp b/plugins/libimhex/include/hex/pattern_language/token.hpp index dcfde4125..98a4d8355 100644 --- a/plugins/libimhex/include/hex/pattern_language/token.hpp +++ b/plugins/libimhex/include/hex/pattern_language/token.hpp @@ -42,7 +42,9 @@ namespace hex::pl { Return, Namespace, In, - Out + Out, + Break, + Continue }; enum class Operator { @@ -289,6 +291,8 @@ namespace hex::pl { #define KEYWORD_NAMESPACE COMPONENT(Keyword, Namespace) #define KEYWORD_IN COMPONENT(Keyword, In) #define KEYWORD_OUT COMPONENT(Keyword, Out) +#define KEYWORD_BREAK COMPONENT(Keyword, Break) +#define KEYWORD_CONTINUE COMPONENT(Keyword, Continue) #define INTEGER hex::pl::Token::Type::Integer, hex::pl::Token::Literal(u128(0)) #define IDENTIFIER hex::pl::Token::Type::Identifier, "" diff --git a/plugins/libimhex/source/pattern_language/evaluator.cpp b/plugins/libimhex/source/pattern_language/evaluator.cpp index 56f728c28..7b9dd638a 100644 --- a/plugins/libimhex/source/pattern_language/evaluator.cpp +++ b/plugins/libimhex/source/pattern_language/evaluator.cpp @@ -146,6 +146,7 @@ namespace hex::pl { std::vector 
patterns; try { + this->setCurrentControlFlowStatement(ControlFlowStatement::None); /* begin each evaluation with no return/break/continue pending */ pushScope(nullptr, patterns); for (auto node : ast) { diff --git a/plugins/libimhex/source/pattern_language/lexer.cpp b/plugins/libimhex/source/pattern_language/lexer.cpp index d153837b5..45f16f872 100644 --- a/plugins/libimhex/source/pattern_language/lexer.cpp +++ b/plugins/libimhex/source/pattern_language/lexer.cpp @@ -420,6 +420,10 @@ namespace hex::pl { tokens.emplace_back(TOKEN(Keyword, In)); else if (identifier == "out") tokens.emplace_back(TOKEN(Keyword, Out)); + else if (identifier == "break") + tokens.emplace_back(TOKEN(Keyword, Break)); + else if (identifier == "continue") + tokens.emplace_back(TOKEN(Keyword, Continue)); // Check for built-in types else if (identifier == "u8") diff --git a/plugins/libimhex/source/pattern_language/parser.cpp b/plugins/libimhex/source/pattern_language/parser.cpp index 495c049a3..8c7b7bbca 100644 --- a/plugins/libimhex/source/pattern_language/parser.cpp +++ b/plugins/libimhex/source/pattern_language/parser.cpp @@ -496,8 +496,8 @@ namespace hex::pl { if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) statement = parseFunctionVariableAssignment(); - else if (MATCHES(sequence(KEYWORD_RETURN))) - statement = parseFunctionReturnStatement(); + else if (MATCHES(oneOf(KEYWORD_RETURN, KEYWORD_BREAK, KEYWORD_CONTINUE))) + statement = parseFunctionControlFlowStatement(); else if (MATCHES(sequence(KEYWORD_IF, SEPARATOR_ROUNDBRACKETOPEN))) { statement = parseFunctionConditional(); needsSemicolon = false; @@ -546,11 +546,21 @@ namespace hex::pl { return create(new ASTNodeAssignment(lvalue, rvalue)); } - ASTNode* Parser::parseFunctionReturnStatement() { - if (peek(SEPARATOR_ENDOFEXPRESSION)) - return create(new ASTNodeReturnStatement(nullptr)); + ASTNode* Parser::parseFunctionControlFlowStatement() { /* builds the AST node for a return, break or continue statement */ + ControlFlowStatement type; + if (peek(KEYWORD_RETURN, -1)) /* inspects the token at offset -1, presumably the keyword MATCHES just consumed in the caller; confirm */ + type = ControlFlowStatement::Return; + else if (peek(KEYWORD_BREAK, -1)) + type = 
ControlFlowStatement::Break; + else if (peek(KEYWORD_CONTINUE, -1)) + type = ControlFlowStatement::Continue; else - return create(new ASTNodeReturnStatement(this->parseMathematicalExpression())); + throwParseError("invalid control flow statement. Expected 'return', 'break' or 'continue'"); /* NOTE(review): type stays unset on this path, so this relies on throwParseError not returning; confirm */ + + if (peek(SEPARATOR_ENDOFEXPRESSION)) /* no operand follows: the node is created with a null rvalue */ + return create(new ASTNodeControlFlowStatement(type, nullptr)); + else + return create(new ASTNodeControlFlowStatement(type, this->parseMathematicalExpression())); /* an operand is parsed for any of the three keywords, not only return */ } std::vector Parser::parseStatementBody() { @@ -650,14 +660,12 @@ namespace hex::pl { body = parseStatementBody(); - body.push_back(postExpression); - variableCleanup.release(); conditionCleanup.release(); postExpressionCleanup.release(); bodyCleanup.release(); - return create(new ASTNodeCompoundStatement({ variable, create(new ASTNodeWhileStatement(condition, body)) }, true)); + return create(new ASTNodeCompoundStatement({ variable, create(new ASTNodeWhileStatement(condition, body, postExpression)) }, true)); /* the post-expression is now passed to the while node instead of being appended to the loop body */ } /* Control flow */