diff --git a/plugins/libimhex/include/hex/lang/ast_node.hpp b/plugins/libimhex/include/hex/lang/ast_node.hpp index d69d6631f..8a6782424 100644 --- a/plugins/libimhex/include/hex/lang/ast_node.hpp +++ b/plugins/libimhex/include/hex/lang/ast_node.hpp @@ -480,10 +480,14 @@ namespace hex::lang { class ASTNodeWhileStatement : public ASTNode { public: - explicit ASTNodeWhileStatement(ASTNode *condition) : ASTNode(), m_condition(condition) { } + explicit ASTNodeWhileStatement(ASTNode *condition, std::vector body) + : ASTNode(), m_condition(condition), m_body(std::move(body)) { } ~ASTNodeWhileStatement() override { delete this->m_condition; + + for (auto &statement : this->m_body) + delete statement; } ASTNodeWhileStatement(const ASTNodeWhileStatement &other) : ASTNode(other) { @@ -498,8 +502,13 @@ namespace hex::lang { return this->m_condition; } + [[nodiscard]] const std::vector& getBody() { + return this->m_body; + } + private: ASTNode *m_condition; + std::vector m_body; }; class ASTNodeFunctionCall : public ASTNode { diff --git a/plugins/libimhex/include/hex/lang/parser.hpp b/plugins/libimhex/include/hex/lang/parser.hpp index d05342581..8b4b3a102 100644 --- a/plugins/libimhex/include/hex/lang/parser.hpp +++ b/plugins/libimhex/include/hex/lang/parser.hpp @@ -71,12 +71,14 @@ namespace hex::lang { ASTNode* parseTernaryConditional(); ASTNode* parseMathematicalExpression(); - void parseAttribute(Attributable *currNode); ASTNode* parseFunctionDefintion(); ASTNode* parseFunctionStatement(); ASTNode* parseFunctionVariableAssignment(); ASTNode* parseFunctionReturnStatement(); ASTNode* parseFunctionConditional(); + ASTNode* parseFunctionWhileLoop(); + + void parseAttribute(Attributable *currNode); ASTNode* parseConditional(); ASTNode* parseWhileStatement(); ASTNode* parseType(s32 startIndex); diff --git a/plugins/libimhex/source/lang/evaluator.cpp b/plugins/libimhex/source/lang/evaluator.cpp index 45f3a950c..a86f77be5 100644 --- a/plugins/libimhex/source/lang/evaluator.cpp +++ b/plugins/libimhex/source/lang/evaluator.cpp @@ -534,11 +534,34 @@ namespace hex::lang { else returnResult = this->evaluateFunctionBody(conditionalNode->getFalseBody()); - for (u32 i = localVariableStartCount; i < this->m_localVariables.size(); i++) + for (u32 i = localVariableStartCount; i < this->m_localVariables.back()->size(); i++) delete (*this->m_localVariables.back())[i]; this->m_localVariables.back()->resize(localVariableStartCount); this->m_localStack.resize(localVariableStackStartSize); + } else { + this->getConsole().abortEvaluation("invalid rvalue used in return statement"); + } + } else if (auto whileLoopNode = dynamic_cast(statement); whileLoopNode != nullptr) { + if (auto numericExpressionNode = dynamic_cast(whileLoopNode->getCondition()); numericExpressionNode != nullptr) { + auto condition = this->evaluateMathematicalExpression(numericExpressionNode); + + while (std::visit([](auto &&value) { return value != 0; }, condition->getValue())) { + u32 localVariableStartCount = this->m_localVariables.back()->size(); + u32 localVariableStackStartSize = this->m_localStack.size(); + + returnResult = this->evaluateFunctionBody(whileLoopNode->getBody()); + if (returnResult.has_value()) + break; + + for (u32 i = localVariableStartCount; i < this->m_localVariables.back()->size(); i++) + delete (*this->m_localVariables.back())[i]; + this->m_localVariables.back()->resize(localVariableStartCount); + this->m_localStack.resize(localVariableStackStartSize); + + condition = this->evaluateMathematicalExpression(numericExpressionNode); + } + } else { this->getConsole().abortEvaluation("invalid rvalue used in return statement"); } diff --git a/plugins/libimhex/source/lang/parser.cpp b/plugins/libimhex/source/lang/parser.cpp index 7be7919bb..f25534554 100644 --- a/plugins/libimhex/source/lang/parser.cpp +++ b/plugins/libimhex/source/lang/parser.cpp @@ -410,8 +410,10 @@ namespace hex::lang { else if (MATCHES(sequence(KEYWORD_IF, SEPARATOR_ROUNDBRACKETOPEN))) { statement = parseFunctionConditional(); needsSemicolon = false; - } - else + } else if (MATCHES(sequence(KEYWORD_WHILE, SEPARATOR_ROUNDBRACKETOPEN))) { + statement = parseFunctionWhileLoop(); + needsSemicolon = false; + } else throwParseError("invalid sequence", 0); if (needsSemicolon && !MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) { @@ -471,6 +473,30 @@ namespace hex::lang { return new ASTNodeConditionalStatement(condition, trueBody, falseBody); } + ASTNode* Parser::parseFunctionWhileLoop() { + auto condition = parseMathematicalExpression(); + std::vector body; + + auto cleanup = SCOPE_GUARD { + delete condition; + for (auto &statement : body) + delete statement; + }; + + if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE, SEPARATOR_CURLYBRACKETOPEN))) { + while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { + body.push_back(parseFunctionStatement()); + } + } else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) { + body.push_back(parseFunctionStatement()); + } else + throwParseError("expected body of conditional statement"); + + cleanup.release(); + + return new ASTNodeWhileStatement(condition, body); + } + /* Control flow */ // if ((parseMathematicalExpression)) { (parseMember) } @@ -521,7 +547,7 @@ namespace hex::lang { cleanup.release(); - return new ASTNodeWhileStatement(condition); + return new ASTNodeWhileStatement(condition, { }); } /* Type declarations */