From 7fd0d87d563c3068b7bd463cf22b7711f9540049 Mon Sep 17 00:00:00 2001 From: WerWolv Date: Wed, 6 Jan 2021 16:28:41 +0100 Subject: [PATCH] Allow enum entries to be accessed via the scope resolution operator --- include/lang/ast_node.hpp | 36 ++++++++++++++++++------------------ include/lang/evaluator.hpp | 3 ++- source/lang/evaluator.cpp | 26 +++++++++++++++++++++++++- source/lang/parser.cpp | 5 ++++- 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/include/lang/ast_node.hpp b/include/lang/ast_node.hpp index cedc58e99..d290967c5 100644 --- a/include/lang/ast_node.hpp +++ b/include/lang/ast_node.hpp @@ -18,7 +18,7 @@ namespace hex::lang { [[nodiscard]] constexpr u32 getLineNumber() const { return this->m_lineNumber; } constexpr void setLineNumber(u32 lineNumber) { this->m_lineNumber = lineNumber; } - virtual ASTNode* clone() = 0; + virtual ASTNode* clone() const = 0; private: u32 m_lineNumber = 1; @@ -30,7 +30,7 @@ namespace hex::lang { ASTNodeIntegerLiteral(const ASTNodeIntegerLiteral&) = default; - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeIntegerLiteral(*this); } @@ -62,7 +62,7 @@ namespace hex::lang { this->m_right = other.m_right->clone(); } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeNumericExpression(*this); } @@ -82,7 +82,7 @@ namespace hex::lang { [[nodiscard]] constexpr const auto& getType() const { return this->m_type; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeBuiltinType(*this); } @@ -105,7 +105,7 @@ namespace hex::lang { delete this->m_type; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeTypeDecl(*this); } @@ -138,7 +138,7 @@ namespace hex::lang { delete this->m_type; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeVariableDecl(*this); } @@ -173,7 +173,7 @@ namespace hex::lang { delete this->m_size; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeArrayVariableDecl(*this); } @@ -209,7 +209,7 @@ namespace hex::lang { delete this->m_type; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodePointerVariableDecl(*this); } @@ -239,7 +239,7 @@ namespace hex::lang { delete member; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeStruct(*this); } @@ -264,7 +264,7 @@ namespace hex::lang { delete member; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeUnion(*this); } @@ -281,7 +281,7 @@ namespace hex::lang { ASTNodeEnum(const ASTNodeEnum &other) : ASTNode(other) { for (const auto &[name, entry] : other.getEntries()) - this->m_entries.emplace_back(name, entry->clone()); + this->m_entries.insert({ name, entry->clone() }); this->m_underlyingType = other.m_underlyingType->clone(); } @@ -291,17 +291,17 @@ namespace hex::lang { delete this->m_underlyingType; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeEnum(*this); } - [[nodiscard]] const std::vector>& getEntries() const { return this->m_entries; } - void addEntry(const std::string &name, ASTNode* expression) { this->m_entries.emplace_back(name, expression); } + [[nodiscard]] const std::unordered_map& getEntries() const { return this->m_entries; } + void addEntry(const std::string &name, ASTNode* expression) { this->m_entries.insert({ name, expression }); } [[nodiscard]] const ASTNode *getUnderlyingType() const { return this->m_underlyingType; } private: - std::vector> m_entries; + std::unordered_map m_entries; ASTNode *m_underlyingType; }; @@ -319,7 +319,7 @@ namespace hex::lang { delete expr; } - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeBitfield(*this); } @@ -336,7 +336,7 @@ namespace hex::lang { ASTNodeRValue(const ASTNodeRValue&) = default; - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeRValue(*this); } @@ -354,7 +354,7 @@ namespace hex::lang { ASTNodeScopeResolution(const ASTNodeScopeResolution&) = default; - ASTNode* clone() override { + ASTNode* clone() const override { return new ASTNodeScopeResolution(*this); } diff --git a/include/lang/evaluator.hpp b/include/lang/evaluator.hpp index 80a1a7e2b..1dc80403a 100644 --- a/include/lang/evaluator.hpp +++ b/include/lang/evaluator.hpp @@ -22,7 +22,7 @@ namespace hex::lang { const std::pair& getError() { return this->m_error; } private: - std::unordered_map m_types; + std::map m_types; prv::Provider* &m_provider; std::endian m_defaultDataEndian; u64 m_currOffset = 0; @@ -41,6 +41,7 @@ namespace hex::lang { return this->m_currEndian.value_or(this->m_defaultDataEndian); } + ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node); ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node); ASTNodeIntegerLiteral* evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op); ASTNodeIntegerLiteral* evaluateMathematicalExpression(ASTNodeNumericExpression *node); diff --git a/source/lang/evaluator.cpp b/source/lang/evaluator.cpp index ce2a3327e..d9c297bfb 100644 --- a/source/lang/evaluator.cpp +++ b/source/lang/evaluator.cpp @@ -4,7 +4,6 @@ #include #include -#include #include @@ -14,6 +13,25 @@ namespace hex::lang { : m_provider(provider), m_defaultDataEndian(defaultDataEndian) { } + ASTNodeIntegerLiteral* Evaluator::evaluateScopeResolution(ASTNodeScopeResolution *node) { + ASTNode *currScope = nullptr; + for (const auto &identifier : node->getPath()) { + if (currScope == nullptr) { + if (!this->m_types.contains(identifier)) + break; + + currScope = this->m_types[identifier.data()]; + } else if (auto enumNode = dynamic_cast(currScope); enumNode != nullptr) { + if (!enumNode->getEntries().contains(identifier)) + break; + else + return evaluateMathematicalExpression(static_cast(enumNode->getEntries().at(identifier))); + } + } + + throwEvaluateError("failed to find identifier", node->getLineNumber()); + } + ASTNodeIntegerLiteral* Evaluator::evaluateRValue(ASTNodeRValue *node) { const std::vector* currMembers = this->m_currMembers.back(); @@ -159,6 +177,8 @@ namespace hex::lang { leftInteger = evaluateMathematicalExpression(leftExprExpression); else if (auto leftExprRvalue = dynamic_cast(node->getLeftOperand()); leftExprRvalue != nullptr) leftInteger = evaluateRValue(leftExprRvalue); + else if (auto leftExprScopeResolution = dynamic_cast(node->getLeftOperand()); leftExprScopeResolution != nullptr) + leftInteger = evaluateScopeResolution(leftExprScopeResolution); else throwEvaluateError("invalid expression. Expected integer literal", node->getLineNumber()); @@ -168,6 +188,8 @@ namespace hex::lang { rightInteger = evaluateMathematicalExpression(rightExprExpression); else if (auto rightExprRvalue = dynamic_cast(node->getRightOperand()); rightExprRvalue != nullptr) rightInteger = evaluateRValue(rightExprRvalue); + else if (auto rightExprScopeResolution = dynamic_cast(node->getRightOperand()); rightExprScopeResolution != nullptr) + rightInteger = evaluateScopeResolution(rightExprScopeResolution); else throwEvaluateError("invalid expression. Expected integer literal", node->getLineNumber()); @@ -486,6 +508,8 @@ namespace hex::lang { patterns.push_back(this->evaluateArray(arrayDeclNode)); } else if (auto pointerDeclNode = dynamic_cast(node); pointerDeclNode != nullptr) { patterns.push_back(this->evaluatePointer(pointerDeclNode)); + } else if (auto typeDeclNode = dynamic_cast(node); typeDeclNode != nullptr) { + this->m_types[typeDeclNode->getName().data()] = typeDeclNode->getType(); } } diff --git a/source/lang/parser.cpp b/source/lang/parser.cpp index a18c57c08..152c35266 100644 --- a/source/lang/parser.cpp +++ b/source/lang/parser.cpp @@ -307,6 +307,7 @@ namespace hex::lang { const auto enumNode = new ASTNodeEnum(underlyingType); ScopeExit enumGuard([&]{ delete enumNode; }); + ASTNode *lastEntry = nullptr; while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) { if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) { auto name = getValue(-2); @@ -352,9 +353,11 @@ namespace hex::lang { default: throwParseError("invalid enum underlying type", -1); } + + lastEntry = valueExpr; } else - valueExpr = new ASTNodeNumericExpression(enumNode->getEntries().back().second, new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, s32(1) }), Token::Operator::Plus); + valueExpr = new ASTNodeNumericExpression(lastEntry, new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, s32(1) }), Token::Operator::Plus); enumNode->addEntry(name, valueExpr); }