diff --git a/include/lang/ast_node.hpp b/include/lang/ast_node.hpp index 570756b76..a293fa033 100644 --- a/include/lang/ast_node.hpp +++ b/include/lang/ast_node.hpp @@ -75,6 +75,38 @@ namespace hex::lang { Token::Operator m_operator; }; + class ASTNodeTernaryExpression : public ASTNode { + public: + ASTNodeTernaryExpression(ASTNode *first, ASTNode *second, ASTNode *third, Token::Operator op) + : ASTNode(), m_first(first), m_second(second), m_third(third), m_operator(op) { } + + ~ASTNodeTernaryExpression() override { + delete this->m_first; + delete this->m_second; + delete this->m_third; + } + + ASTNodeTernaryExpression(const ASTNodeTernaryExpression &other) : ASTNode(other) { + this->m_operator = other.m_operator; + this->m_first = other.m_first->clone(); + this->m_second = other.m_second->clone(); + this->m_third = other.m_third->clone(); + } + + ASTNode* clone() const override { + return new ASTNodeTernaryExpression(*this); + } + + ASTNode *getFirstOperand() { return this->m_first; } + ASTNode *getSecondOperand() { return this->m_second; } + ASTNode *getThirdOperand() { return this->m_third; } + Token::Operator getOperator() { return this->m_operator; } + + private: + ASTNode *m_first, *m_second, *m_third; + Token::Operator m_operator; + }; + class ASTNodeBuiltinType : public ASTNode { public: constexpr explicit ASTNodeBuiltinType(Token::ValueType type) diff --git a/include/lang/evaluator.hpp b/include/lang/evaluator.hpp index ca615eb3e..3ac937fc2 100644 --- a/include/lang/evaluator.hpp +++ b/include/lang/evaluator.hpp @@ -44,6 +44,8 @@ namespace hex::lang { ASTNodeIntegerLiteral* evaluateScopeResolution(ASTNodeScopeResolution *node); ASTNodeIntegerLiteral* evaluateRValue(ASTNodeRValue *node); ASTNodeIntegerLiteral* evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op); + ASTNodeIntegerLiteral* evaluateOperand(ASTNode *node); + ASTNodeIntegerLiteral* evaluateTernaryExpression(ASTNodeTernaryExpression *node); ASTNodeIntegerLiteral* evaluateMathematicalExpression(ASTNodeNumericExpression *node); PatternData* evaluateBuiltinType(ASTNodeBuiltinType *node); diff --git a/include/lang/parser.hpp b/include/lang/parser.hpp index 4f6d0dc09..d3e384da7 100644 --- a/include/lang/parser.hpp +++ b/include/lang/parser.hpp @@ -68,6 +68,7 @@ namespace hex::lang { ASTNode* parseBooleanAnd(); ASTNode* parseBooleanXor(); ASTNode* parseBooleanOr(); + ASTNode* parseTernaryConditional(); ASTNode* parseMathematicalExpression(); ASTNode* parseConditional(); diff --git a/include/lang/token.hpp b/include/lang/token.hpp index cd947a3b6..174ea0ece 100644 --- a/include/lang/token.hpp +++ b/include/lang/token.hpp @@ -55,7 +55,8 @@ namespace hex::lang { BoolAnd, BoolOr, BoolXor, - BoolNot + BoolNot, + TernaryConditional }; enum class ValueType { @@ -217,6 +218,7 @@ namespace hex::lang { #define OPERATOR_BOOLOR COMPONENT(Operator, BoolOr) #define OPERATOR_BOOLXOR COMPONENT(Operator, BoolXor) #define OPERATOR_BOOLNOT COMPONENT(Operator, BoolNot) +#define OPERATOR_TERNARYCONDITIONAL COMPONENT(Operator, TernaryConditional) #define VALUETYPE_CUSTOMTYPE COMPONENT(ValueType, CustomType) #define VALUETYPE_PADDING COMPONENT(ValueType, Padding) diff --git a/source/lang/evaluator.cpp b/source/lang/evaluator.cpp index 4917b4bc2..828a9323a 100644 --- a/source/lang/evaluator.cpp +++ b/source/lang/evaluator.cpp @@ -206,30 +206,40 @@ namespace hex::lang { } } + ASTNodeIntegerLiteral* Evaluator::evaluateOperand(ASTNode *node) { + if (auto exprLiteral = dynamic_cast(node); exprLiteral != nullptr) + return exprLiteral; + else if (auto exprExpression = dynamic_cast(node); exprExpression != nullptr) + return evaluateMathematicalExpression(exprExpression); + else if (auto exprRvalue = dynamic_cast(node); exprRvalue != nullptr) + return evaluateRValue(exprRvalue); + else if (auto exprScopeResolution = dynamic_cast(node); exprScopeResolution != nullptr) + return evaluateScopeResolution(exprScopeResolution); + else if (auto exprTernary = dynamic_cast(node); exprTernary != nullptr) + return evaluateTernaryExpression(exprTernary); + else + throwEvaluateError("invalid operand", node->getLineNumber()); + } + + ASTNodeIntegerLiteral* Evaluator::evaluateTernaryExpression(ASTNodeTernaryExpression *node) { + switch (node->getOperator()) { + case Token::Operator::TernaryConditional: { + auto condition = this->evaluateOperand(node->getFirstOperand()); + SCOPE_EXIT( delete condition; ); + + if (std::visit([](auto &&value){ return value != 0; }, condition->getValue())) + return this->evaluateOperand(node->getSecondOperand()); + else + return this->evaluateOperand(node->getThirdOperand()); + } + default: + throwEvaluateError("invalid operator used in ternary expression", node->getLineNumber()); + } + } + ASTNodeIntegerLiteral* Evaluator::evaluateMathematicalExpression(ASTNodeNumericExpression *node) { - ASTNodeIntegerLiteral *leftInteger, *rightInteger; - - if (auto leftExprLiteral = dynamic_cast(node->getLeftOperand()); leftExprLiteral != nullptr) - leftInteger = leftExprLiteral; - else if (auto leftExprExpression = dynamic_cast(node->getLeftOperand()); leftExprExpression != nullptr) - leftInteger = evaluateMathematicalExpression(leftExprExpression); - else if (auto leftExprRvalue = dynamic_cast(node->getLeftOperand()); leftExprRvalue != nullptr) - leftInteger = evaluateRValue(leftExprRvalue); - else if (auto leftExprScopeResolution = dynamic_cast(node->getLeftOperand()); leftExprScopeResolution != nullptr) - leftInteger = evaluateScopeResolution(leftExprScopeResolution); - else - throwEvaluateError("invalid expression. Expected integer literal", node->getLineNumber()); - - if (auto rightExprLiteral = dynamic_cast(node->getRightOperand()); rightExprLiteral != nullptr) - rightInteger = rightExprLiteral; - else if (auto rightExprExpression = dynamic_cast(node->getRightOperand()); rightExprExpression != nullptr) - rightInteger = evaluateMathematicalExpression(rightExprExpression); - else if (auto rightExprRvalue = dynamic_cast(node->getRightOperand()); rightExprRvalue != nullptr) - rightInteger = evaluateRValue(rightExprRvalue); - else if (auto rightExprScopeResolution = dynamic_cast(node->getRightOperand()); rightExprScopeResolution != nullptr) - rightInteger = evaluateScopeResolution(rightExprScopeResolution); - else - throwEvaluateError("invalid expression. Expected integer literal", node->getLineNumber()); + auto leftInteger = this->evaluateOperand(node->getLeftOperand()); + auto rightInteger = this->evaluateOperand(node->getRightOperand()); return evaluateOperator(leftInteger, rightInteger, node->getOperator()); } diff --git a/source/lang/lexer.cpp b/source/lang/lexer.cpp index 6ee5f4976..8c59cae9f 100644 --- a/source/lang/lexer.cpp +++ b/source/lang/lexer.cpp @@ -257,6 +257,9 @@ namespace hex::lang { } else if (c == '~') { tokens.emplace_back(TOKEN(Operator, BitNot)); offset += 1; + } else if (c == '?') { + tokens.emplace_back(TOKEN(Operator, TernaryConditional)); + offset += 1; } else if (c == '\'') { offset += 1; diff --git a/source/lang/parser.cpp b/source/lang/parser.cpp index 62136e45f..ab21357e9 100644 --- a/source/lang/parser.cpp +++ b/source/lang/parser.cpp @@ -205,9 +205,26 @@ namespace hex::lang { return node; } - // (parseBinaryOrExpression) + // (parseBooleanOr) ? (parseBooleanOr) : (parseBooleanOr) + ASTNode* Parser::parseTernaryConditional() { + auto node = this->parseBooleanXor(); + + while (MATCHES(sequence(OPERATOR_TERNARYCONDITIONAL))) { + auto second = this->parseBooleanXor(); + + if (!MATCHES(sequence(OPERATOR_INHERIT))) + throwParseError("expected ':' in ternary expression"); + + auto third = this->parseBooleanXor(); + node = new ASTNodeTernaryExpression(node, second, third, Token::Operator::TernaryConditional); + } + + return node; + } + + // (parseTernaryConditional) ASTNode* Parser::parseMathematicalExpression() { - return this->parseBooleanOr(); + return this->parseTernaryConditional(); }