Language refactoring, added builtin function registry

This commit is contained in:
WerWolv
2021-01-11 23:54:12 +01:00
parent 90e0aa83d8
commit c09a8bca7f
20 changed files with 203 additions and 208 deletions

View File

@@ -73,4 +73,15 @@ namespace hex {
return *SharedData::get().commandPaletteCommands;
}
/* Pattern Language Functions */
/**
 * Registers a pattern language function under the given name. An entry that is
 * already stored under the same name is overwritten.
 *
 * @param name           Name the function is stored under in the registry.
 * @param parameterCount Parameter count specification stored alongside the callback.
 * @param func           Callback stored for the function.
 */
void ContentRegistry::PatternLanguageFunctions::add(std::string_view name, u32 parameterCount, const std::function<hex::lang::ASTNode*(std::vector<hex::lang::ASTNode*>)> &func) {
    // Build the key from the full extent of the view: a std::string_view is not
    // guaranteed to be null-terminated, so name.data() alone may read past its end.
    (*SharedData::get().patternLanguageFunctions)[std::string(name)] = Function{ parameterCount, func };
}
/**
 * Returns the registered pattern language functions, keyed by function name.
 *
 * The map is returned by value, so every call hands the caller its own copy.
 */
std::map<std::string, ContentRegistry::PatternLanguageFunctions::Function> ContentRegistry::PatternLanguageFunctions::getEntries() {
    const auto &registeredFunctions = *SharedData::get().patternLanguageFunctions;

    return registeredFunctions;
}
}

View File

@@ -0,0 +1,143 @@
#include "lang/evaluator.hpp"
#include "helpers/content_registry.hpp"
namespace hex::lang {
// Visits the variant `literal` and evaluates `cond` with the name `literal` bound to the
// alternative currently held, yielding whether `cond` is non-zero.
// The generated lambda captures `this`, so the macro can only be used inside member functions.
#define LITERAL_COMPARE(literal, cond) std::visit([&, this](auto &&literal) { return (cond) != 0; }, literal)
void Evaluator::registerBuiltinFunctions() {
/* findSequence */
ContentRegistry::PatternLanguageFunctions::add("findSequence", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 1, [this](auto params) {
auto& occurrenceIndex = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
std::vector<u8> sequence;
for (u32 i = 1; i < params.size(); i++) {
sequence.push_back(std::visit([](auto &&value) -> u8 {
if (value <= 0xFF)
return value;
else
throwEvaluateError("sequence bytes need to fit into 1 byte");
}, asType<ASTNodeIntegerLiteral>(params[i])->getValue()));
}
std::vector<u8> bytes(sequence.size(), 0x00);
u32 occurrences = 0;
for (u64 offset = 0; offset < this->m_provider->getSize() - sequence.size(); offset++) {
this->m_provider->read(offset, bytes.data(), bytes.size());
if (bytes == sequence) {
if (LITERAL_COMPARE(occurrenceIndex, occurrences < occurrenceIndex)) {
occurrences++;
continue;
}
return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, offset });
}
}
throwEvaluateError("failed to find sequence");
});
/* assert */
ContentRegistry::PatternLanguageFunctions::add("readUnsigned", 2, [this](auto params) {
auto address = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto size = asType<ASTNodeIntegerLiteral>(params[1])->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range");
return std::visit([this](auto &&address, auto &&size) {
if (size <= 0 || size > 16)
throwEvaluateError("invalid read size");
u8 value[(u8)size];
this->m_provider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, hex::changeEndianess(*reinterpret_cast<u8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit, hex::changeEndianess(*reinterpret_cast<u16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit, hex::changeEndianess(*reinterpret_cast<u32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit, hex::changeEndianess(*reinterpret_cast<u64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid read size");
}
}, address, size);
});
ContentRegistry::PatternLanguageFunctions::add("readSigned", 2, [this](auto params) {
auto address = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto size = asType<ASTNodeIntegerLiteral>(params[1])->getValue();
if (LITERAL_COMPARE(address, address >= this->m_provider->getActualSize()))
throwEvaluateError("address out of range");
return std::visit([this](auto &&address, auto &&size) {
if (size <= 0 || size > 16)
throwEvaluateError("invalid read size");
u8 value[(u8)size];
this->m_provider->read(address, value, size);
switch ((u8)size) {
case 1: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit, hex::changeEndianess(*reinterpret_cast<s8*>(value), 1, this->getCurrentEndian()) });
case 2: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit, hex::changeEndianess(*reinterpret_cast<s16*>(value), 2, this->getCurrentEndian()) });
case 4: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit, hex::changeEndianess(*reinterpret_cast<s32*>(value), 4, this->getCurrentEndian()) });
case 8: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit, hex::changeEndianess(*reinterpret_cast<s64*>(value), 8, this->getCurrentEndian()) });
case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, this->getCurrentEndian()) });
default: throwEvaluateError("invalid read size");
}
}, address, size);
});
ContentRegistry::PatternLanguageFunctions::add("assert", 2, [this](auto params) {
auto condition = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto message = asType<ASTNodeStringLiteral>(params[1])->getString();
if (LITERAL_COMPARE(condition, condition == 0))
throwEvaluateError(hex::format("assert failed \"%s\"", message.data()));
return nullptr;
});
ContentRegistry::PatternLanguageFunctions::add("warnAssert", 2, [this](auto params) {
auto condition = asType<ASTNodeIntegerLiteral>(params[0])->getValue();
auto message = asType<ASTNodeStringLiteral>(params[1])->getString();
if (LITERAL_COMPARE(condition, condition == 0))
this->emmitWaring(hex::format("assert failed \"%s\"", message.data()));
return nullptr;
});
ContentRegistry::PatternLanguageFunctions::add("print", ContentRegistry::PatternLanguageFunctions::MoreParametersThan | 0, [this](auto params) {
std::string message;
for (auto& param : params) {
if (auto integerLiteral = dynamic_cast<ASTNodeIntegerLiteral*>(param); integerLiteral != nullptr) {
switch (integerLiteral->getType()) {
case Token::ValueType::Character: message += std::get<s8>(integerLiteral->getValue()); break;
case Token::ValueType::Unsigned8Bit: message += std::to_string(std::get<u8>(integerLiteral->getValue())); break;
case Token::ValueType::Signed8Bit: message += std::to_string(std::get<s8>(integerLiteral->getValue())); break;
case Token::ValueType::Unsigned16Bit: message += std::to_string(std::get<u16>(integerLiteral->getValue())); break;
case Token::ValueType::Signed16Bit: message += std::to_string(std::get<s16>(integerLiteral->getValue())); break;
case Token::ValueType::Unsigned32Bit: message += std::to_string(std::get<u32>(integerLiteral->getValue())); break;
case Token::ValueType::Signed32Bit: message += std::to_string(std::get<s32>(integerLiteral->getValue())); break;
case Token::ValueType::Unsigned64Bit: message += std::to_string(std::get<u64>(integerLiteral->getValue())); break;
case Token::ValueType::Signed64Bit: message += std::to_string(std::get<s64>(integerLiteral->getValue())); break;
case Token::ValueType::Unsigned128Bit: message += "A lot"; break; // TODO: Implement u128 to_string
case Token::ValueType::Signed128Bit: message += "A lot"; break; // TODO: Implement s128 to_string
case Token::ValueType::Float: message += std::to_string(std::get<float>(integerLiteral->getValue())); break;
case Token::ValueType::Double: message += std::to_string(std::get<double>(integerLiteral->getValue())); break;
case Token::ValueType::Boolean: message += std::get<s32>(integerLiteral->getValue()) ? "true" : "false"; break;
case Token::ValueType::CustomType: message += "< Custom Type >"; break;
}
}
else if (auto stringLiteral = dynamic_cast<ASTNodeStringLiteral*>(param); stringLiteral != nullptr)
message += stringLiteral->getString();
}
this->emmitInfo(message);
return nullptr;
});
}
}

View File

@@ -0,0 +1,718 @@
#include "lang/evaluator.hpp"
#include "lang/token.hpp"
#include "helpers/utils.hpp"
#include "helpers/content_registry.hpp"
#include <bit>
#include <algorithm>
#include <unistd.h>
namespace hex::lang {
/**
 * Sets up an evaluator for the given provider and default data endianness and
 * registers the builtin pattern language functions.
 *
 * @param provider          Provider the evaluator is initialised with.
 * @param defaultDataEndian Endianness stored as the evaluator's default.
 */
Evaluator::Evaluator(prv::Provider* &provider, std::endian defaultDataEndian)
    : m_provider(provider)
    , m_defaultDataEndian(defaultDataEndian)
{
    // NOTE(review): the builtins are registered from every constructed Evaluator — confirm
    // that re-registration under the same names is intended.
    registerBuiltinFunctions();
}
/**
 * Resolves a scope resolution path to the value of an enum entry.
 *
 * The first path element is looked up among the known types; once a scope has been
 * found, the following element is looked up among that scope's entries if the scope
 * is an enum.
 *
 * @param node Scope resolution node whose path is walked.
 * @return The evaluated value of the addressed enum entry.
 * Raises an evaluation error if the path cannot be resolved.
 */
ASTNodeIntegerLiteral* Evaluator::evaluateScopeResolution(ASTNodeScopeResolution *node) {
    ASTNode *scope = nullptr;

    for (const auto &pathPart : node->getPath()) {
        if (scope != nullptr) {
            auto enumScope = dynamic_cast<ASTNodeEnum*>(scope);

            // Scopes that are not enums are stepped over without resolving anything
            if (enumScope == nullptr)
                continue;

            if (enumScope->getEntries().contains(pathPart))
                return evaluateMathematicalExpression(static_cast<ASTNodeNumericExpression*>(enumScope->getEntries().at(pathPart)));

            break;
        }

        // No scope selected yet: this path element has to name a known type
        if (!this->m_types.contains(pathPart))
            break;

        scope = this->m_types[pathPart.data()];
    }

    throwEvaluateError("failed to find identifier");
}
/**
 * Resolves an rvalue path to an already evaluated pattern and reads that pattern's
 * current value from the provider.
 *
 * The first path element is searched among the members of the innermost scope and
 * the global variables; every following element is searched among the members of the
 * struct or union found so far. Pointers are followed to the pattern they point at.
 *
 * @param node RValue node whose path is resolved.
 * @return A newly allocated integer literal holding the value that was read.
 * Raises an evaluation error if no variables exist, an identifier cannot be found,
 * a member of a non-struct/union is accessed, the resolved pattern is not an
 * unsigned, signed or enum pattern, or its size is not 1, 2, 4, 8 or 16 bytes.
 */
ASTNodeIntegerLiteral* Evaluator::evaluateRValue(ASTNodeRValue *node) {
    if (this->m_currMembers.empty() && this->m_globalMembers.empty())
        throwEvaluateError("no variables available");

    // Candidates for the first path element: innermost scope members, then globals
    std::vector<PatternData*> currMembers;

    if (!this->m_currMembers.empty())
        std::copy(this->m_currMembers.back()->begin(), this->m_currMembers.back()->end(), std::back_inserter(currMembers));
    if (!this->m_globalMembers.empty())
        std::copy(this->m_globalMembers.begin(), this->m_globalMembers.end(), std::back_inserter(currMembers));

    PatternData *currPattern = nullptr;
    for (u32 i = 0; i < node->getPath().size(); i++) {
        const auto &identifier = node->getPath()[i];

        if (auto structPattern = dynamic_cast<PatternDataStruct*>(currPattern); structPattern != nullptr)
            currMembers = structPattern->getMembers();
        else if (auto unionPattern = dynamic_cast<PatternDataUnion*>(currPattern); unionPattern != nullptr)
            currMembers = unionPattern->getMembers();
        else if (auto pointerPattern = dynamic_cast<PatternDataPointer*>(currPattern); pointerPattern != nullptr) {
            // Follow the pointer and retry the same path element on the pointed-at pattern
            currPattern = pointerPattern->getPointedAtPattern();
            i--;
            continue;
        }
        else if (currPattern != nullptr)
            throwEvaluateError("tried to access member of a non-struct/union type");

        auto candidate = std::find_if(currMembers.begin(), currMembers.end(), [&](auto member) {
            return member->getVariableName() == identifier;
        });

        if (candidate != currMembers.end())
            currPattern = *candidate;
        else
            throwEvaluateError(hex::format("could not find identifier '%s'", identifier.c_str()));
    }

    if (auto pointerPattern = dynamic_cast<PatternDataPointer*>(currPattern); pointerPattern != nullptr)
        currPattern = pointerPattern->getPointedAtPattern();

    if (auto unsignedPattern = dynamic_cast<PatternDataUnsigned*>(currPattern); unsignedPattern != nullptr) {
        u8 value[unsignedPattern->getSize()];
        this->m_provider->read(unsignedPattern->getOffset(), value, unsignedPattern->getSize());

        switch (unsignedPattern->getSize()) {
            case 1:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit,   hex::changeEndianess(*reinterpret_cast<u8*>(value),   1,  unsignedPattern->getEndian()) });
            case 2:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit,  hex::changeEndianess(*reinterpret_cast<u16*>(value),  2,  unsignedPattern->getEndian()) });
            case 4:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit,  hex::changeEndianess(*reinterpret_cast<u32*>(value),  4,  unsignedPattern->getEndian()) });
            case 8:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit,  hex::changeEndianess(*reinterpret_cast<u64*>(value),  8,  unsignedPattern->getEndian()) });
            case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, unsignedPattern->getEndian()) });
            default: throwEvaluateError("invalid rvalue size");
        }
    } else if (auto signedPattern = dynamic_cast<PatternDataSigned*>(currPattern); signedPattern != nullptr) {
        // Size and width have to come from signedPattern: unsignedPattern is null in this branch
        u8 value[signedPattern->getSize()];
        this->m_provider->read(signedPattern->getOffset(), value, signedPattern->getSize());

        switch (signedPattern->getSize()) {
            case 1:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed8Bit,   hex::changeEndianess(*reinterpret_cast<s8*>(value),   1,  signedPattern->getEndian()) });
            case 2:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed16Bit,  hex::changeEndianess(*reinterpret_cast<s16*>(value),  2,  signedPattern->getEndian()) });
            case 4:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed32Bit,  hex::changeEndianess(*reinterpret_cast<s32*>(value),  4,  signedPattern->getEndian()) });
            case 8:  return new ASTNodeIntegerLiteral({ Token::ValueType::Signed64Bit,  hex::changeEndianess(*reinterpret_cast<s64*>(value),  8,  signedPattern->getEndian()) });
            case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Signed128Bit, hex::changeEndianess(*reinterpret_cast<s128*>(value), 16, signedPattern->getEndian()) });
            default: throwEvaluateError("invalid rvalue size");
        }
    } else if (auto enumPattern = dynamic_cast<PatternDataEnum*>(currPattern); enumPattern != nullptr) {
        // Enum values are read as unsigned integers of the enum's size
        u8 value[enumPattern->getSize()];
        this->m_provider->read(enumPattern->getOffset(), value, enumPattern->getSize());

        switch (enumPattern->getSize()) {
            case 1:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit,   hex::changeEndianess(*reinterpret_cast<u8*>(value),   1,  enumPattern->getEndian()) });
            case 2:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned16Bit,  hex::changeEndianess(*reinterpret_cast<u16*>(value),  2,  enumPattern->getEndian()) });
            case 4:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned32Bit,  hex::changeEndianess(*reinterpret_cast<u32*>(value),  4,  enumPattern->getEndian()) });
            case 8:  return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned64Bit,  hex::changeEndianess(*reinterpret_cast<u64*>(value),  8,  enumPattern->getEndian()) });
            case 16: return new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned128Bit, hex::changeEndianess(*reinterpret_cast<u128*>(value), 16, enumPattern->getEndian()) });
            default: throwEvaluateError("invalid rvalue size");
        }
    } else
        throwEvaluateError("tried to use non-integer value in numeric expression");
}
/**
 * Evaluates the parameters of a function call, checks their number against the
 * registered function's parameter count specification and invokes the function.
 *
 * Numeric expressions are evaluated, string literals are cloned; parameters of any
 * other kind are not passed on. The evaluated parameters are deleted again when this
 * function is left.
 *
 * @param node Function call node to evaluate.
 * @return Whatever the registered function returns.
 * Raises an evaluation error if no function with that name is registered or the
 * number of parameters does not satisfy the function's specification.
 */
ASTNode* Evaluator::evaluateFunctionCall(ASTNodeFunctionCall *node) {
    std::vector<ASTNode*> evaluatedParams;
    ScopeExit paramCleanup([&] {
        for (auto &param : evaluatedParams)
            delete param;
    });

    for (auto &param : node->getParams()) {
        if (auto numericExpression = dynamic_cast<ASTNodeNumericExpression*>(param); numericExpression != nullptr)
            evaluatedParams.push_back(this->evaluateMathematicalExpression(numericExpression));
        else if (auto stringLiteral = dynamic_cast<ASTNodeStringLiteral*>(param); stringLiteral != nullptr)
            evaluatedParams.push_back(stringLiteral->clone());
    }

    // getEntries() returns the registry by value. Keep that copy alive in a local for the
    // whole call: a reference into the returned temporary would dangle as soon as the
    // statement that obtained it has finished.
    const auto functions = ContentRegistry::PatternLanguageFunctions::getEntries();

    const auto entry = functions.find(node->getFunctionName().data());
    if (entry == functions.end())
        throwEvaluateError(hex::format("no function named '%s' found", node->getFunctionName().data()));

    const auto &function = entry->second;

    if (function.parameterCount == ContentRegistry::PatternLanguageFunctions::UnlimitedParameters) {
        ; // Don't check parameter count
    }
    else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::LessParametersThan) {
        // The flag bit is masked out to obtain the exclusive upper limit
        if (evaluatedParams.size() >= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan))
            throwEvaluateError(hex::format("too many parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::LessParametersThan));
    } else if (function.parameterCount & ContentRegistry::PatternLanguageFunctions::MoreParametersThan) {
        // The flag bit is masked out to obtain the exclusive lower limit
        if (evaluatedParams.size() <= (function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan))
            throwEvaluateError(hex::format("too few parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount & ~ContentRegistry::PatternLanguageFunctions::MoreParametersThan));
    } else if (function.parameterCount != evaluatedParams.size()) {
        throwEvaluateError(hex::format("invalid number of parameters for function '%s'. Expected %d", node->getFunctionName().data(), function.parameterCount));
    }

    return function.func(evaluatedParams);
}
// Declares four overloads of the binary function `name(left, right)`.
// The three overloads in which at least one operand satisfies hex::floating_point throw a
// std::runtime_error with an empty message. The fourth overload, taking two hex::integral
// operands, is emitted without a body: the body has to follow the macro invocation.
#define FLOAT_BIT_OPERATION(name) \
auto name(hex::floating_point auto left, auto right) { throw std::runtime_error(""); return 0; } \
auto name(auto left, hex::floating_point auto right) { throw std::runtime_error(""); return 0; } \
auto name(hex::floating_point auto left, hex::floating_point auto right) { throw std::runtime_error(""); return 0; } \
auto name(hex::integral auto left, hex::integral auto right)
// Bitwise helpers local to this translation unit. Each body below is the integral overload
// of the respective operation.
namespace {
FLOAT_BIT_OPERATION(shiftLeft) {
return left << right;
}
FLOAT_BIT_OPERATION(shiftRight) {
return left >> right;
}
FLOAT_BIT_OPERATION(bitAnd) {
return left & right;
}
FLOAT_BIT_OPERATION(bitOr) {
return left | right;
}
FLOAT_BIT_OPERATION(bitXor) {
return left ^ right;
}
// Unary operation expressed with the binary signature: `left` is ignored, only `right` is inverted.
FLOAT_BIT_OPERATION(bitNot) {
return ~right;
}
}
/**
 * Applies a binary operator to two integer literals and returns the result as a
 * newly allocated literal.
 *
 * Result type: if exactly one operand has the type Any, the other operand's type is
 * used. Otherwise the first type of the list Double, Float, Unsigned128Bit,
 * Signed128Bit, Unsigned64Bit, Signed64Bit, Unsigned32Bit, Signed32Bit,
 * Unsigned16Bit, Signed16Bit, Unsigned8Bit, Signed8Bit, Character, Boolean that
 * either operand has is used, falling back to Signed32Bit.
 *
 * BitNot and BoolNot only use the right operand.
 *
 * @param left  Left operand.
 * @param right Right operand.
 * @param op    Operator to apply.
 * @return Newly allocated literal holding the result.
 * Raises an evaluation error for unsupported operators, for a division of two
 * integral operands by zero, and when a std::runtime_error is thrown while computing
 * the result (which the bitwise helpers do for floating point operands).
 */
ASTNodeIntegerLiteral* Evaluator::evaluateOperator(ASTNodeIntegerLiteral *left, ASTNodeIntegerLiteral *right, Token::Operator op) {
    auto newType = [&] {
        #define CHECK_TYPE(type) if (left->getType() == (type) || right->getType() == (type)) return (type)
        #define DEFAULT_TYPE(type) return (type)

        if (left->getType() == Token::ValueType::Any && right->getType() != Token::ValueType::Any)
            return right->getType();
        if (left->getType() != Token::ValueType::Any && right->getType() == Token::ValueType::Any)
            return left->getType();

        CHECK_TYPE(Token::ValueType::Double);
        CHECK_TYPE(Token::ValueType::Float);
        CHECK_TYPE(Token::ValueType::Unsigned128Bit);
        CHECK_TYPE(Token::ValueType::Signed128Bit);
        CHECK_TYPE(Token::ValueType::Unsigned64Bit);
        CHECK_TYPE(Token::ValueType::Signed64Bit);
        CHECK_TYPE(Token::ValueType::Unsigned32Bit);
        CHECK_TYPE(Token::ValueType::Signed32Bit);
        CHECK_TYPE(Token::ValueType::Unsigned16Bit);
        CHECK_TYPE(Token::ValueType::Signed16Bit);
        CHECK_TYPE(Token::ValueType::Unsigned8Bit);
        CHECK_TYPE(Token::ValueType::Signed8Bit);
        CHECK_TYPE(Token::ValueType::Character);
        CHECK_TYPE(Token::ValueType::Boolean);
        DEFAULT_TYPE(Token::ValueType::Signed32Bit);

        #undef CHECK_TYPE
        #undef DEFAULT_TYPE
    }();

    // Tells whether an operand is of an integral type; used to single out integer division
    auto isIntegral = [](auto value) { return hex::integral<decltype(value)>; };

    try {
        return std::visit([&](auto &&leftValue, auto &&rightValue) -> ASTNodeIntegerLiteral * {
            switch (op) {
                case Token::Operator::Plus:
                    return new ASTNodeIntegerLiteral({ newType, leftValue + rightValue });
                case Token::Operator::Minus:
                    return new ASTNodeIntegerLiteral({ newType, leftValue - rightValue });
                case Token::Operator::Star:
                    return new ASTNodeIntegerLiteral({ newType, leftValue * rightValue });
                case Token::Operator::Slash:
                    // Dividing two integers by zero is undefined behaviour; report it as an
                    // evaluation error. Divisions involving a floating point operand are left alone.
                    if (isIntegral(leftValue) && isIntegral(rightValue) && rightValue == 0)
                        throwEvaluateError("division by zero");
                    return new ASTNodeIntegerLiteral({ newType, leftValue / rightValue });
                case Token::Operator::ShiftLeft:
                    return new ASTNodeIntegerLiteral({ newType, shiftLeft(leftValue, rightValue) });
                case Token::Operator::ShiftRight:
                    return new ASTNodeIntegerLiteral({ newType, shiftRight(leftValue, rightValue) });
                case Token::Operator::BitAnd:
                    return new ASTNodeIntegerLiteral({ newType, bitAnd(leftValue, rightValue) });
                case Token::Operator::BitXor:
                    return new ASTNodeIntegerLiteral({ newType, bitXor(leftValue, rightValue) });
                case Token::Operator::BitOr:
                    return new ASTNodeIntegerLiteral({ newType, bitOr(leftValue, rightValue) });
                case Token::Operator::BitNot:
                    return new ASTNodeIntegerLiteral({ newType, bitNot(leftValue, rightValue) });
                case Token::Operator::BoolEquals:
                    return new ASTNodeIntegerLiteral({ newType, leftValue == rightValue });
                case Token::Operator::BoolNotEquals:
                    return new ASTNodeIntegerLiteral({ newType, leftValue != rightValue });
                case Token::Operator::BoolGreaterThan:
                    return new ASTNodeIntegerLiteral({ newType, leftValue > rightValue });
                case Token::Operator::BoolLessThan:
                    return new ASTNodeIntegerLiteral({ newType, leftValue < rightValue });
                case Token::Operator::BoolGreaterThanOrEquals:
                    return new ASTNodeIntegerLiteral({ newType, leftValue >= rightValue });
                case Token::Operator::BoolLessThanOrEquals:
                    return new ASTNodeIntegerLiteral({ newType, leftValue <= rightValue });
                case Token::Operator::BoolAnd:
                    return new ASTNodeIntegerLiteral({ newType, leftValue && rightValue });
                case Token::Operator::BoolXor:
                    return new ASTNodeIntegerLiteral({ newType, leftValue && !rightValue || !leftValue && rightValue });
                case Token::Operator::BoolOr:
                    return new ASTNodeIntegerLiteral({ newType, leftValue || rightValue });
                case Token::Operator::BoolNot:
                    return new ASTNodeIntegerLiteral({ newType, !rightValue });
                default:
                    throwEvaluateError("invalid operator used in mathematical expression");
            }
        }, left->getValue(), right->getValue());
    } catch (std::runtime_error &e) {
        // The bitwise helpers signal floating point operands with a std::runtime_error
        throwEvaluateError("bitwise operations on floating point numbers are forbidden");
    }
}
/**
 * Turns an operand node into an integer literal.
 *
 * Integer literals are returned as they are; numeric expressions, rvalues, scope
 * resolutions, ternary expressions and function calls are dispatched to their
 * respective evaluation function.
 *
 * @param node Operand to evaluate.
 * @return Integer literal representing the operand's value.
 * Raises an evaluation error for any other node kind and for function calls that
 * return nothing or a node that is not an integer literal.
 */
ASTNodeIntegerLiteral* Evaluator::evaluateOperand(ASTNode *node) {
    if (auto literal = dynamic_cast<ASTNodeIntegerLiteral*>(node))
        return literal;

    if (auto expression = dynamic_cast<ASTNodeNumericExpression*>(node))
        return evaluateMathematicalExpression(expression);

    if (auto rvalue = dynamic_cast<ASTNodeRValue*>(node))
        return evaluateRValue(rvalue);

    if (auto scopeResolution = dynamic_cast<ASTNodeScopeResolution*>(node))
        return evaluateScopeResolution(scopeResolution);

    if (auto ternary = dynamic_cast<ASTNodeTernaryExpression*>(node))
        return evaluateTernaryExpression(ternary);

    // Function calls are the last supported operand kind
    auto functionCall = dynamic_cast<ASTNodeFunctionCall*>(node);
    if (functionCall == nullptr)
        throwEvaluateError("invalid operand");

    ASTNode *result = evaluateFunctionCall(functionCall);
    if (result == nullptr)
        throwEvaluateError("function returning void used in expression");

    auto integerResult = dynamic_cast<ASTNodeIntegerLiteral*>(result);
    if (integerResult == nullptr)
        throwEvaluateError("function not returning a numeric value used in expression");

    return integerResult;
}
/**
 * Evaluates a ternary conditional expression.
 *
 * The first operand is evaluated as the condition; the second operand is evaluated
 * and returned if the condition is non-zero, the third one otherwise. The evaluated
 * condition is deleted when the function is left.
 *
 * @param node Ternary expression node.
 * @return Value of the selected operand.
 * Raises an evaluation error if the node's operator is not TernaryConditional.
 */
ASTNodeIntegerLiteral* Evaluator::evaluateTernaryExpression(ASTNodeTernaryExpression *node) {
    if (node->getOperator() != Token::Operator::TernaryConditional)
        throwEvaluateError("invalid operator used in ternary expression");

    auto conditionValue = this->evaluateOperand(node->getFirstOperand());
    SCOPE_EXIT( delete conditionValue; );

    const bool conditionHolds = std::visit([](auto &&value) { return value != 0; }, conditionValue->getValue());

    // Only the selected branch is evaluated
    if (!conditionHolds)
        return this->evaluateOperand(node->getThirdOperand());

    return this->evaluateOperand(node->getSecondOperand());
}
/**
 * Evaluates a numeric expression: left operand first, then the right operand, then
 * the node's operator is applied to both.
 *
 * @param node Numeric expression to evaluate.
 * @return Result of applying the operator, as returned by evaluateOperator().
 */
ASTNodeIntegerLiteral* Evaluator::evaluateMathematicalExpression(ASTNodeNumericExpression *node) {
    // Separate statements keep the left-before-right evaluation order explicit
    ASTNodeIntegerLiteral *lhs = this->evaluateOperand(node->getLeftOperand());
    ASTNodeIntegerLiteral *rhs = this->evaluateOperand(node->getRightOperand());

    // NOTE(review): neither operand is released here — confirm who owns them
    return evaluateOperator(lhs, rhs, node->getOperator());
}
/**
 * Creates the pattern for a builtin type at the current offset and advances the
 * current offset by the size of that type.
 *
 * The pattern's type name is set to the builtin type's name and its endianness to
 * the evaluator's current endianness.
 *
 * @param node Builtin type node.
 * @return Newly allocated pattern.
 * Raises an evaluation error if the type is neither a character, a boolean, an
 * unsigned or signed integer, nor a floating point type.
 */
PatternData* Evaluator::evaluateBuiltinType(ASTNodeBuiltinType *node) {
    auto &valueType = node->getType();
    auto byteCount = Token::getTypeSize(valueType);

    // Character and boolean are tested before the generic integer categories
    PatternData *result = [&]() -> PatternData* {
        if (valueType == Token::ValueType::Character)
            return new PatternDataCharacter(this->m_currOffset);
        if (valueType == Token::ValueType::Boolean)
            return new PatternDataBoolean(this->m_currOffset);
        if (Token::isUnsigned(valueType))
            return new PatternDataUnsigned(this->m_currOffset, byteCount);
        if (Token::isSigned(valueType))
            return new PatternDataSigned(this->m_currOffset, byteCount);
        if (Token::isFloatingPoint(valueType))
            return new PatternDataFloat(this->m_currOffset, byteCount);

        throwEvaluateError("invalid builtin type");
    }();

    this->m_currOffset += byteCount;

    result->setTypeName(Token::getTypeName(valueType));
    result->setEndian(this->getCurrentEndian());

    return result;
}
/**
 * Evaluates one member statement of a struct or union and appends the resulting
 * patterns to `currMembers`.
 *
 * Variable, array and pointer declarations each yield one pattern. A conditional
 * statement evaluates its condition and then recursively evaluates every statement
 * of the true body (condition non-zero) or the false body (condition zero).
 *
 * @param node           Member statement to evaluate.
 * @param currMembers    Receives the patterns created for the member.
 * @param increaseOffset When false, the current offset is reset to the value it had
 *                       on entry once the member has been evaluated.
 * Raises an evaluation error if the node is none of the supported statement kinds.
 */
void Evaluator::evaluateMember(ASTNode *node, std::vector<PatternData*> &currMembers, bool increaseOffset) {
    auto startOffset = this->m_currOffset;

    if (auto memberVariableNode = dynamic_cast<ASTNodeVariableDecl*>(node); memberVariableNode != nullptr)
        currMembers.push_back(this->evaluateVariable(memberVariableNode));
    else if (auto memberArrayNode = dynamic_cast<ASTNodeArrayVariableDecl*>(node); memberArrayNode != nullptr)
        currMembers.push_back(this->evaluateArray(memberArrayNode));
    else if (auto memberPointerNode = dynamic_cast<ASTNodePointerVariableDecl*>(node); memberPointerNode != nullptr)
        currMembers.push_back(this->evaluatePointer(memberPointerNode));
    else if (auto conditionalNode = dynamic_cast<ASTNodeConditionalStatement*>(node); conditionalNode != nullptr) {
        auto condition = this->evaluateMathematicalExpression(static_cast<ASTNodeNumericExpression*>(conditionalNode->getCondition()));
        // Release the condition on every exit path: evaluating a body statement may throw
        SCOPE_EXIT( delete condition; );

        if (std::visit([](auto &&value) { return value != 0; }, condition->getValue())) {
            for (auto &statement : conditionalNode->getTrueBody()) {
                this->evaluateMember(statement, currMembers, increaseOffset);
            }
        } else {
            for (auto &statement : conditionalNode->getFalseBody()) {
                this->evaluateMember(statement, currMembers, increaseOffset);
            }
        }
    }
    else
        throwEvaluateError("invalid struct member");

    if (!increaseOffset)
        this->m_currOffset = startOffset;
}
/**
 * Evaluates all members of a struct one after another and wraps them in a struct
 * pattern that spans from the offset on entry to the offset after the last member.
 *
 * While the members are evaluated, the member list is pushed onto m_currMembers; it
 * is popped again when the function is left.
 *
 * @param node Struct node.
 * @return Newly allocated struct pattern.
 */
PatternData* Evaluator::evaluateStruct(ASTNodeStruct *node) {
    std::vector<PatternData*> members;

    this->m_currMembers.push_back(&members);
    SCOPE_EXIT( this->m_currMembers.pop_back(); );

    const auto structStart = this->m_currOffset;

    // Members advance the current offset as they are evaluated
    for (auto &&memberNode : node->getMembers())
        this->evaluateMember(memberNode, members, true);

    const auto structSize = this->m_currOffset - structStart;

    return new PatternDataStruct(structStart, structSize, members);
}
/**
 * Evaluates all members of a union at the same start offset and wraps them in a union
 * pattern whose size is that of its largest member.
 *
 * While the members are evaluated, the member list is pushed onto m_currMembers; it
 * is popped again when the function is left. The current offset is advanced by the
 * union's size.
 *
 * @param node Union node.
 * @return Newly allocated union pattern.
 */
PatternData* Evaluator::evaluateUnion(ASTNodeUnion *node) {
    std::vector<PatternData*> members;

    this->m_currMembers.push_back(&members);
    SCOPE_EXIT( this->m_currMembers.pop_back(); );

    const auto unionStart = this->m_currOffset;

    // Members are evaluated without keeping the offset they advanced to
    for (auto &&memberNode : node->getMembers())
        this->evaluateMember(memberNode, members, false);

    size_t largestMember = 0;
    for (const auto &member : members) {
        const size_t memberSize = member->getSize();
        if (memberSize > largestMember)
            largestMember = memberSize;
    }

    this->m_currOffset += largestMember;

    return new PatternDataUnion(unionStart, largestMember, members);
}
/**
 * Evaluates the value of every enum entry and creates an enum pattern at the current
 * offset, sized after the enum's builtin underlying type. The current offset is
 * advanced by that size.
 *
 * @param node Enum node.
 * @return Newly allocated enum pattern.
 * Raises an evaluation error if an entry's value is not a numeric expression or the
 * underlying type is not a type declaration wrapping a builtin type.
 */
PatternData* Evaluator::evaluateEnum(ASTNodeEnum *node) {
    std::vector<std::pair<Token::IntegerLiteral, std::string>> entries;

    const auto enumStart = this->m_currOffset;

    // Entries are evaluated before the underlying type is inspected
    for (auto &[entryName, entryNode] : node->getEntries()) {
        auto entryExpression = dynamic_cast<ASTNodeNumericExpression*>(entryNode);
        if (entryExpression == nullptr)
            throwEvaluateError("invalid expression in enum value");

        auto entryValue = evaluateMathematicalExpression(entryExpression);
        SCOPE_EXIT( delete entryValue; );

        entries.push_back({{ entryValue->getType(), entryValue->getValue() }, entryName });
    }

    auto typeDecl = dynamic_cast<ASTNodeTypeDecl*>(node->getUnderlyingType());
    if (typeDecl == nullptr)
        throwEvaluateError("enum underlying type was not ASTNodeTypeDecl. This is a bug");

    auto builtinType = dynamic_cast<ASTNodeBuiltinType*>(typeDecl->getType());
    if (builtinType == nullptr)
        throwEvaluateError("invalid enum underlying type");

    const size_t enumSize = Token::getTypeSize(builtinType->getType());

    this->m_currOffset += enumSize;

    return new PatternDataEnum(enumStart, enumSize, entries);
}
/**
 * Evaluates the bit width of every bitfield entry and creates a bitfield pattern at
 * the current offset. The pattern occupies the summed bit widths rounded up to whole
 * bytes; the current offset is advanced by that many bytes.
 *
 * @param node Bitfield node.
 * @return Newly allocated bitfield pattern.
 * Raises an evaluation error if an entry's width is not a numeric expression, is a
 * floating point value, or lies outside the range 1 to 64.
 */
PatternData* Evaluator::evaluateBitfield(ASTNodeBitfield *node) {
    std::vector<std::pair<std::string, size_t>> fields;

    const auto bitfieldStart = this->m_currOffset;

    size_t totalBits = 0;
    for (auto &[fieldName, widthNode] : node->getEntries()) {
        auto widthExpression = dynamic_cast<ASTNodeNumericExpression*>(widthNode);
        if (widthExpression == nullptr)
            throwEvaluateError("invalid expression in bitfield field size");

        auto widthValue = evaluateMathematicalExpression(widthExpression);
        SCOPE_EXIT( delete widthValue; );

        if (Token::isFloatingPoint(widthValue->getType()))
            throwEvaluateError("bitfield entry size must be an integer value");

        auto fieldWidth = std::visit([](auto &&value) {
            return static_cast<s128>(value);
        }, widthValue->getValue());

        if (fieldWidth <= 0 || fieldWidth > 64)
            throwEvaluateError("bitfield entry must occupy between 1 and 64 bits");

        totalBits += fieldWidth;
        fields.emplace_back(fieldName, fieldWidth);
    }

    // Round up to the number of whole bytes needed to hold all bits
    const size_t byteCount = (totalBits + 7) / 8;

    this->m_currOffset += byteCount;

    return new PatternDataBitfield(bitfieldStart, byteCount, fields);
}
/**
 * Evaluates a type declaration into a pattern.
 *
 * The declaration's endianness (or the evaluator's default if none is given) is
 * pushed onto the endian stack for the duration of the evaluation. Builtin types are
 * handed to evaluateBuiltinType() and returned directly; for nested type
 * declarations, structs, unions, enums and bitfields the resulting pattern
 * additionally receives the declaration's name (if not empty) and the current
 * endianness.
 *
 * @param node Type declaration node.
 * @return Newly allocated pattern.
 * Raises an evaluation error if the declared type is of none of the kinds above.
 */
PatternData* Evaluator::evaluateType(ASTNodeTypeDecl *node) {
    auto type = node->getType();

    this->m_endianStack.push_back(node->getEndian().value_or(this->m_defaultDataEndian));
    // Pop on every exit path. The builtin early return below and any evaluation error
    // would otherwise leave this entry on the stack.
    SCOPE_EXIT( this->m_endianStack.pop_back(); );

    PatternData *pattern;

    if (auto builtinTypeNode = dynamic_cast<ASTNodeBuiltinType*>(type); builtinTypeNode != nullptr)
        return this->evaluateBuiltinType(builtinTypeNode);
    else if (auto typeDeclNode = dynamic_cast<ASTNodeTypeDecl*>(type); typeDeclNode != nullptr)
        pattern = this->evaluateType(typeDeclNode);
    else if (auto structNode = dynamic_cast<ASTNodeStruct*>(type); structNode != nullptr)
        pattern = this->evaluateStruct(structNode);
    else if (auto unionNode = dynamic_cast<ASTNodeUnion*>(type); unionNode != nullptr)
        pattern = this->evaluateUnion(unionNode);
    else if (auto enumNode = dynamic_cast<ASTNodeEnum*>(type); enumNode != nullptr)
        pattern = this->evaluateEnum(enumNode);
    else if (auto bitfieldNode = dynamic_cast<ASTNodeBitfield*>(type); bitfieldNode != nullptr)
        pattern = this->evaluateBitfield(bitfieldNode);
    else
        throwEvaluateError("type could not be evaluated");

    if (!node->getName().empty())
        pattern->setTypeName(node->getName().data());

    // Still runs with this declaration's endianness on the stack; the pop happens on scope exit
    pattern->setEndian(this->getCurrentEndian());

    return pattern;
}
/**
 * Evaluates a variable declaration into a named pattern.
 *
 * If the declaration carries a placement offset expression, the current offset is
 * set to its value first. The pattern is then created from the declared type.
 *
 * @param node Variable declaration node.
 * @return Newly allocated pattern carrying the variable's name.
 * Raises an evaluation error if the placement offset is a floating point value, the
 * current offset lies at or beyond the provider's actual size, or the declared type
 * is neither a type declaration nor a builtin type.
 */
PatternData* Evaluator::evaluateVariable(ASTNodeVariableDecl *node) {
    auto placement = dynamic_cast<ASTNodeNumericExpression*>(node->getPlacementOffset());
    if (placement != nullptr) {
        auto placementValue = evaluateMathematicalExpression(placement);
        SCOPE_EXIT( delete placementValue; );

        if (Token::isFloatingPoint(placementValue->getType()))
            throwEvaluateError("placement offset must be an integer value");

        this->m_currOffset = std::visit([](auto &&value) {
            return static_cast<u64>(value);
        }, placementValue->getValue());
    }

    if (this->m_currOffset >= this->m_provider->getActualSize())
        throwEvaluateError("variable placed out of range");

    // Type declarations take precedence over bare builtin types
    PatternData *result = nullptr;
    if (auto typeDecl = dynamic_cast<ASTNodeTypeDecl*>(node->getType()))
        result = this->evaluateType(typeDecl);
    else if (auto builtinType = dynamic_cast<ASTNodeBuiltinType*>(node->getType()))
        result = this->evaluateBuiltinType(builtinType);
    else
        throwEvaluateError("ASTNodeVariableDecl had an invalid type. This is a bug!");

    result->setVariableName(node->getName().data());

    return result;
}
// Evaluates an array declaration into a pattern.
//
// * An optional placement offset expression moves the evaluation cursor first.
// * With an explicit size expression the array has that many entries; an array of the
//   builtin `padding` type is turned directly into a PatternDataPadding of that size.
// * Without a size expression the provider is scanned byte by byte from the start
//   offset until a 0x00 byte (inclusive) or the end of the data, and the number of
//   bytes read becomes the entry count.
//
// The result is a PatternDataString when the entries are characters, a
// PatternDataArray otherwise (which requires an explicit size), or an empty
// PatternDataPadding when there are no entries.
//
// Throws an evaluate error for floating point offsets/sizes, non-numeric sizes,
// invalid element types, arrays running past the end of the data and unsized
// non-character arrays. Entries that were already evaluated are released on every
// path that does not hand them over to a PatternDataArray.
PatternData* Evaluator::evaluateArray(ASTNodeArrayVariableDecl *node) {
    if (auto offset = dynamic_cast<ASTNodeNumericExpression*>(node->getPlacementOffset()); offset != nullptr) {
        auto valueNode = evaluateMathematicalExpression(offset);
        SCOPE_EXIT( delete valueNode; );

        this->m_currOffset = std::visit([type = valueNode->getType()] (auto &&value) {
            if (Token::isFloatingPoint(type))
                throwEvaluateError("placement offset must be an integer value");

            return static_cast<u64>(value);
        }, valueNode->getValue());
    }

    auto startOffset = this->m_currOffset;

    ASTNodeIntegerLiteral *valueNode;
    u64 arraySize = 0;

    if (node->getSize() != nullptr) {
        if (auto sizeNumericExpression = dynamic_cast<ASTNodeNumericExpression*>(node->getSize()); sizeNumericExpression != nullptr)
            valueNode = evaluateMathematicalExpression(sizeNumericExpression);
        else
            throwEvaluateError("array size not a numeric expression");

        SCOPE_EXIT( delete valueNode; );

        arraySize = std::visit([type = valueNode->getType()] (auto &&value) {
            if (Token::isFloatingPoint(type))
                throwEvaluateError("array size must be an integer value");

            return static_cast<u64>(value);
        }, valueNode->getValue());

        // Padding is not evaluated element-wise, it just skips `arraySize` bytes
        if (auto typeDecl = dynamic_cast<ASTNodeTypeDecl*>(node->getType()); typeDecl != nullptr) {
            if (auto builtinType = dynamic_cast<ASTNodeBuiltinType*>(typeDecl->getType()); builtinType != nullptr) {
                if (builtinType->getType() == Token::ValueType::Padding) {
                    this->m_currOffset += arraySize;
                    return new PatternDataPadding(startOffset, arraySize);
                }
            }
        }
    } else {
        // No size given: count bytes up to and including the first 0x00 byte
        u8 currByte = 0x00;
        u64 offset = startOffset;

        do {
            this->m_provider->read(offset, &currByte, sizeof(u8));
            offset += sizeof(u8);
            arraySize += sizeof(u8);
        } while (currByte != 0x00 && offset < this->m_provider->getSize());
    }

    std::vector<PatternData*> entries;

    // Frees all entries that have not been passed on to a resulting pattern
    auto releaseEntries = [&entries] {
        for (auto *entry : entries)
            delete entry;
        entries.clear();
    };

    std::optional<u32> color;
    PatternData *pattern;

    try {
        for (s128 i = 0; i < arraySize; i++) {
            PatternData *entry;
            if (auto typeDecl = dynamic_cast<ASTNodeTypeDecl*>(node->getType()); typeDecl != nullptr)
                entry = this->evaluateType(typeDecl);
            else if (auto builtinTypeDecl = dynamic_cast<ASTNodeBuiltinType*>(node->getType()); builtinTypeDecl != nullptr)
                entry = this->evaluateBuiltinType(builtinTypeDecl);
            else
                throwEvaluateError("ASTNodeVariableDecl had an invalid type. This is a bug!");

            entry->setVariableName(hex::format("[%llu]", (u64)i));
            entry->setEndian(this->getCurrentEndian());

            // Every entry shares the color of the first one
            if (!color.has_value())
                color = entry->getColor();
            entry->setColor(color.value_or(0));

            entries.push_back(entry);

            if (this->m_currOffset > this->m_provider->getActualSize())
                throwEvaluateError("array exceeds size of file");
        }

        if (entries.empty()) {
            pattern = new PatternDataPadding(startOffset, 0);
        } else if (dynamic_cast<PatternDataCharacter*>(entries[0])) {
            pattern = new PatternDataString(startOffset, (this->m_currOffset - startOffset), color.value_or(0));
            // The string pattern does not receive the per-character entries, so free them here
            releaseEntries();
        } else {
            if (node->getSize() == nullptr)
                throwEvaluateError("no bounds provided for array");

            // The entries are handed over to the array pattern
            pattern = new PatternDataArray(startOffset, (this->m_currOffset - startOffset), entries, color.value_or(0));
        }
    } catch (...) {
        // Evaluation failed half-way: nobody else knows about the entries yet
        releaseEntries();
        throw;
    }

    pattern->setVariableName(node->getName().data());

    return pattern;
}
// Evaluates a pointer declaration into a PatternDataPointer.
// The pointer value is read from the provider at the pointer's own offset (either the
// placement offset or the current cursor), using the size of the declared builtin
// size type and that type's endianness (falling back to the default data endian).
// The pointed-at type is then evaluated at the address that was read, and finally the
// cursor is placed right behind the pointer itself.
// Throws an evaluate error for floating point offsets, non-builtin size types, targets
// past the end of the data and invalid pointed-at types.
PatternData* Evaluator::evaluatePointer(ASTNodePointerVariableDecl *node) {
    s128 pointerOffset;

    auto placement = dynamic_cast<ASTNodeNumericExpression*>(node->getPlacementOffset());
    if (placement == nullptr) {
        pointerOffset = this->m_currOffset;
    } else {
        auto placementValue = evaluateMathematicalExpression(placement);
        SCOPE_EXIT( delete placementValue; );

        pointerOffset = std::visit([valueType = placementValue->getType()] (auto &&value) {
            if (Token::isFloatingPoint(valueType))
                throwEvaluateError("pointer offset must be an integer value");

            return static_cast<s128>(value);
        }, placementValue->getValue());

        this->m_currOffset = pointerOffset;
    }

    auto underlyingType = dynamic_cast<ASTNodeTypeDecl*>(node->getSizeType());
    if (underlyingType == nullptr)
        throwEvaluateError("underlying type is not ASTNodeTypeDecl. This is a bug");

    auto builtinSizeNode = dynamic_cast<ASTNodeBuiltinType*>(underlyingType->getType());
    if (builtinSizeNode == nullptr)
        throwEvaluateError("pointer size is not a builtin type");

    // The size type is only evaluated to learn how many bytes the pointer occupies
    PatternData *sizeType = evaluateBuiltinType(builtinSizeNode);
    size_t pointerSize = sizeType->getSize();

    u128 pointedAtOffset = 0;
    this->m_provider->read(pointerOffset, &pointedAtOffset, pointerSize);

    const auto pointerEndian = underlyingType->getEndian().value_or(this->m_defaultDataEndian);
    this->m_currOffset = hex::changeEndianess(pointedAtOffset, pointerSize, pointerEndian);

    delete sizeType;

    if (this->m_currOffset > this->m_provider->getActualSize())
        throwEvaluateError("pointer points past the end of the data");

    PatternData *target;
    if (auto customType = dynamic_cast<ASTNodeTypeDecl*>(node->getType()); customType != nullptr) {
        target = this->evaluateType(customType);
    } else if (auto builtinType = dynamic_cast<ASTNodeBuiltinType*>(node->getType()); builtinType != nullptr) {
        target = this->evaluateBuiltinType(builtinType);
    } else {
        throwEvaluateError("ASTNodeVariableDecl had an invalid type. This is a bug!");
    }

    // Continue evaluation directly behind the pointer, not behind the pointed-at data
    this->m_currOffset = pointerOffset + pointerSize;

    auto result = new PatternDataPointer(pointerOffset, pointerSize, target);
    result->setVariableName(node->getName().data());
    result->setEndian(this->getCurrentEndian());

    return result;
}
// Evaluates all top-level AST nodes in order.
// Variable, array and pointer declarations are appended to m_globalMembers, type
// declarations are recorded in m_types and function calls are executed with their
// result discarded. The endian stack is seeded with the default data endian for each
// node and cleared afterwards.
// Returns the global members, or an empty optional after logging the message when an
// EvaluateError was thrown.
std::optional<std::vector<PatternData*>> Evaluator::evaluate(const std::vector<ASTNode *> &ast) {
    try {
        for (ASTNode *statement : ast) {
            this->m_endianStack.push_back(this->m_defaultDataEndian);

            if (auto variableDecl = dynamic_cast<ASTNodeVariableDecl*>(statement); variableDecl != nullptr) {
                this->m_globalMembers.push_back(this->evaluateVariable(variableDecl));
            } else if (auto arrayDecl = dynamic_cast<ASTNodeArrayVariableDecl*>(statement); arrayDecl != nullptr) {
                this->m_globalMembers.push_back(this->evaluateArray(arrayDecl));
            } else if (auto pointerDecl = dynamic_cast<ASTNodePointerVariableDecl*>(statement); pointerDecl != nullptr) {
                this->m_globalMembers.push_back(this->evaluatePointer(pointerDecl));
            } else if (auto typeDecl = dynamic_cast<ASTNodeTypeDecl*>(statement); typeDecl != nullptr) {
                this->m_types[typeDecl->getName().data()] = typeDecl->getType();
            } else if (auto functionCall = dynamic_cast<ASTNodeFunctionCall*>(statement); functionCall != nullptr) {
                // Top-level calls are only run for their side effects
                delete this->evaluateFunctionCall(functionCall);
            }

            this->m_endianStack.clear();
        }
    } catch (EvaluateError &error) {
        this->m_consoleLog.emplace_back(ConsoleLogLevel::Error, error);
        this->m_endianStack.clear();

        return std::nullopt;
    }

    this->m_endianStack.clear();

    return this->m_globalMembers;
}
}

View File

@@ -0,0 +1,471 @@
#include "lang/lexer.hpp"
#include <algorithm>
#include <functional>
#include <optional>
#include <vector>
namespace hex::lang {
// Expands to the argument list (token type, enumerator Token::<type>::<value>, line number)
// used when emplacing a token; relies on a variable named `lineNumber` being in scope.
#define TOKEN(type, value) Token::Type::type, Token::type::value, lineNumber
// Same as TOKEN, but passes `value` through unchanged instead of looking up an enumerator.
#define VALUE_TOKEN(type, value) Token::Type::type, value, lineNumber
// The lexer has nothing to initialise in its constructor body.
Lexer::Lexer() { }
// Collects characters from the null-terminated string `characters`.
// The first character is always taken (the predicate is not applied to it); every
// following character is taken as long as `predicate` accepts it. Collection stops at
// the first rejected character or at the terminating null byte.
// Returns the collected prefix; an empty input yields an empty string.
std::string matchTillInvalid(const char* characters, std::function<bool(char)> predicate) {
    std::string matched;

    if (*characters == 0x00)
        return matched;

    do {
        matched.push_back(*characters);
        ++characters;
    } while (predicate(*characters) && *characters != 0x00);

    return matched;
}
// Returns the number of leading characters of `string` that may belong to a numeric
// literal (digits, hex digits, '.', the 'x' of a prefix and the 'U'/'L' suffixes).
// When the whole input consists of such characters its full length is returned, so
// the result can always be used to advance a read position.
size_t getIntegerLiteralLength(std::string_view string) {
    const auto end = string.find_first_not_of("0123456789ABCDEFabcdef.xUL");

    // find_first_not_of reports npos when the literal runs up to the end of the input
    return end == std::string_view::npos ? string.length() : end;
}
// Parses the numeric literal at the start of `string`.
//
// Recognised suffixes: U (unsigned 32 bit), UL (unsigned 64 bit), ULL (unsigned 128 bit),
// L (signed 64 bit), LL (signed 128 bit) and, for literals without a 0x/0b prefix,
// F (float) and D (double). Recognised prefixes: 0x (hexadecimal) and 0b (binary).
// Literals containing a '.' default to double, all other unsuffixed literals to
// signed 32 bit.
//
// Returns the literal with its value type, or an empty optional when the text is not
// a valid literal.
std::optional<Token::IntegerLiteral> parseIntegerLiteral(std::string_view string) {
    Token::ValueType type = Token::ValueType::Any;

    u8 base;

    auto endPos = getIntegerLiteralLength(string);
    std::string_view numberData = string.substr(0, endPos);

    // Longer suffixes have to be tested first, otherwise "LL" would already match "L"
    if (numberData.ends_with("ULL")) {
        type = Token::ValueType::Unsigned128Bit;
        numberData.remove_suffix(3);
    } else if (numberData.ends_with("UL")) {
        type = Token::ValueType::Unsigned64Bit;
        numberData.remove_suffix(2);
    } else if (numberData.ends_with('U')) {
        type = Token::ValueType::Unsigned32Bit;
        numberData.remove_suffix(1);
    } else if (numberData.ends_with("LL")) {
        type = Token::ValueType::Signed128Bit;
        numberData.remove_suffix(2);
    } else if (numberData.ends_with('L')) {
        type = Token::ValueType::Signed64Bit;
        numberData.remove_suffix(1);
    } else if (!numberData.starts_with("0x") && !numberData.starts_with("0b")) {
        // 'F' and 'D' are valid hex digits, so they only act as suffix without a prefix
        if (numberData.ends_with('F')) {
            type = Token::ValueType::Float;
            numberData.remove_suffix(1);
        } else if (numberData.ends_with('D')) {
            type = Token::ValueType::Double;
            numberData.remove_suffix(1);
        }
    }

    // Nothing but a suffix (or nothing at all) is not a literal
    if (numberData.empty())
        return { };

    if (numberData.starts_with("0x")) {
        numberData = numberData.substr(2);
        base = 16;

        if (Token::isFloatingPoint(type))
            return { };

        if (numberData.find_first_not_of("0123456789ABCDEFabcdef") != std::string_view::npos)
            return { };
    } else if (numberData.starts_with("0b")) {
        numberData = numberData.substr(2);
        base = 2;

        if (Token::isFloatingPoint(type))
            return { };

        if (numberData.find_first_not_of("01") != std::string_view::npos)
            return { };
    } else if (numberData.find('.') != std::string_view::npos || Token::isFloatingPoint(type)) {
        base = 10;
        if (type == Token::ValueType::Any)
            type = Token::ValueType::Double;

        if (std::count(numberData.begin(), numberData.end(), '.') > 1 || numberData.find_first_not_of("0123456789.") != std::string_view::npos)
            return { };

        if (numberData.ends_with('.'))
            return { };
    } else if (isdigit(numberData[0])) {
        base = 10;

        if (numberData.find_first_not_of("0123456789") != std::string_view::npos)
            return { };
    } else return { };

    if (type == Token::ValueType::Any)
        type = Token::ValueType::Signed32Bit;

    // A bare "0x" / "0b" prefix has no digits left at this point
    if (numberData.length() == 0)
        return { };

    if (Token::isUnsigned(type) || Token::isSigned(type)) {
        u128 integer = 0;
        for (const char& c : numberData) {
            integer *= base;

            if (isdigit(c))
                integer += (c - '0');
            else if (c >= 'A' && c <= 'F')
                integer += 10 + (c - 'A');
            else if (c >= 'a' && c <= 'f')
                integer += 10 + (c - 'a');
            else return { };
        }

        switch (type) {
            case Token::ValueType::Unsigned32Bit:  return {{ type, u32(integer) }};
            case Token::ValueType::Signed32Bit:    return {{ type, s32(integer) }};
            case Token::ValueType::Unsigned64Bit:  return {{ type, u64(integer) }};
            case Token::ValueType::Signed64Bit:    return {{ type, s64(integer) }};
            case Token::ValueType::Unsigned128Bit: return {{ type, u128(integer) }};
            case Token::ValueType::Signed128Bit:   return {{ type, s128(integer) }};
            default: return { };
        }
    } else if (Token::isFloatingPoint(type)) {
        double floatingPoint = strtod(numberData.data(), nullptr);

        switch (type) {
            case Token::ValueType::Float:  return {{ type, float(floatingPoint) }};
            case Token::ValueType::Double: return {{ type, double(floatingPoint) }};
            default: return { };
        }
    }

    return { };
}
// Decodes the (possibly escaped) character at the start of `string`.
//
// Supported escapes: \a \b \f \n \r \t \v \\ \' \" as well as \xHH (exactly two
// hexadecimal digits) and \oOOO (exactly three octal digits). Anything following the
// character or escape sequence is ignored, so the rest of the source may be passed in.
//
// Returns the decoded character together with the number of input characters it
// occupied, or an empty optional for empty input and malformed escapes.
std::optional<std::pair<char, size_t>> getCharacter(std::string_view string) {
    if (string.length() < 1)
        return { };

    // Plain, unescaped character
    if (string[0] != '\\')
        return {{ string[0], 1 }};

    if (string.length() < 2)
        return { };

    // Handle simple escape sequences
    switch (string[1]) {
        case 'a':  return {{ '\a', 2 }};
        case 'b':  return {{ '\b', 2 }};
        case 'f':  return {{ '\f', 2 }};
        case 'n':  return {{ '\n', 2 }};
        case 'r':  return {{ '\r', 2 }};
        case 't':  return {{ '\t', 2 }};
        case 'v':  return {{ '\v', 2 }};
        case '\\': return {{ '\\', 2 }};
        case '\'': return {{ '\'', 2 }};
        case '\"': return {{ '\"', 2 }};
    }

    // Hexadecimal number
    if (string[1] == 'x') {
        if (string.length() < 4)
            return { };

        if (!std::isxdigit(static_cast<unsigned char>(string[2])) || !std::isxdigit(static_cast<unsigned char>(string[3])))
            return { };

        // Convert from a private, terminated copy so that no characters behind the escape are consumed
        const char digits[] = { string[2], string[3], '\0' };
        return {{ static_cast<char>(std::strtoul(digits, nullptr, 16)), 4 }};
    }

    // Octal number
    if (string[1] == 'o') {
        if (string.length() < 5)
            return { };

        for (size_t i = 2; i < 5; i++) {
            if (string[i] < '0' || string[i] > '7')
                return { };
        }

        const char digits[] = { string[2], string[3], string[4], '\0' };
        return {{ static_cast<char>(std::strtoul(digits, nullptr, 8)), 5 }};
    }

    return { };
}
// Parses the double-quoted string literal at the start of `string`.
// Escape sequences are decoded through getCharacter().
// Returns the decoded contents together with the length of the literal including both
// quotes, or an empty optional when the input does not start with '"', contains an
// invalid character/escape or is not terminated by a closing '"'.
std::optional<std::pair<std::string, size_t>> getStringLiteral(std::string_view string) {
    if (!string.starts_with('\"'))
        return { };

    size_t size = 1;

    std::string result;
    while (true) {
        // Ran off the end of the input without seeing the closing quote
        if (size >= string.length())
            return { };

        if (string[size] == '\"')
            break;

        auto character = getCharacter(string.substr(size));

        if (!character.has_value())
            return { };

        auto &[c, charSize] = character.value();

        result += c;
        size += charSize;
    }

    return {{ result, size + 1 }};
}
// Parses the single-quoted character literal at the start of `string`.
// The character between the quotes is decoded through getCharacter(), so escape
// sequences are accepted.
// Returns the character together with the length of the literal including both
// quotes, or an empty optional when the input is not a well-formed character literal.
std::optional<std::pair<char, size_t>> getCharacterLiteral(std::string_view string) {
    if (string.empty())
        return { };

    if (string[0] != '\'')
        return { };

    auto character = getCharacter(string.substr(1));

    if (!character.has_value())
        return { };

    auto &[c, charSize] = character.value();

    // The closing quote sits right behind the opening quote and the character itself
    if (string.length() < charSize + 2 || string[charSize + 1] != '\'')
        return { };

    return {{ c, charSize + 2 }};
}
// Splits `code` into a list of tokens.
// Whitespace is skipped (counting '\n' to track the line number), then separators,
// operators, character/string literals, keywords, builtin type names, identifiers and
// numeric literals are recognised in that order. Two-character operators are always
// tested before the single-character operators they start with.
// On success the token list, terminated by an EndOfProgram separator, is returned.
// When a LexerError is thrown it is stored in m_error and an empty optional is returned.
std::optional<std::vector<Token>> Lexer::lex(const std::string& code) {
std::vector<Token> tokens;
u32 offset = 0;
// Line of the token currently being lexed, picked up by the TOKEN / VALUE_TOKEN macros
u32 lineNumber = 1;
try {
while (offset < code.length()) {
const char& c = code[offset];
// An embedded null byte ends lexing early
if (c == 0x00)
break;
if (std::isblank(c) || std::isspace(c)) {
if (code[offset] == '\n') lineNumber++;
offset += 1;
// Single-character separators
} else if (c == ';') {
tokens.emplace_back(TOKEN(Separator, EndOfExpression));
offset += 1;
} else if (c == '(') {
tokens.emplace_back(TOKEN(Separator, RoundBracketOpen));
offset += 1;
} else if (c == ')') {
tokens.emplace_back(TOKEN(Separator, RoundBracketClose));
offset += 1;
} else if (c == '{') {
tokens.emplace_back(TOKEN(Separator, CurlyBracketOpen));
offset += 1;
} else if (c == '}') {
tokens.emplace_back(TOKEN(Separator, CurlyBracketClose));
offset += 1;
} else if (c == '[') {
tokens.emplace_back(TOKEN(Separator, SquareBracketOpen));
offset += 1;
} else if (c == ']') {
tokens.emplace_back(TOKEN(Separator, SquareBracketClose));
offset += 1;
} else if (c == ',') {
tokens.emplace_back(TOKEN(Separator, Comma));
offset += 1;
} else if (c == '.') {
tokens.emplace_back(TOKEN(Separator, Dot));
offset += 1;
} else if (c == '@') {
tokens.emplace_back(TOKEN(Operator, AtDeclaration));
offset += 1;
// Two-character comparison and boolean operators; these come before '=', '!', '>', '<', '&', '|' and '^'
} else if (code.substr(offset, 2) == "==") {
tokens.emplace_back(TOKEN(Operator, BoolEquals));
offset += 2;
} else if (code.substr(offset, 2) == "!=") {
tokens.emplace_back(TOKEN(Operator, BoolNotEquals));
offset += 2;
} else if (code.substr(offset, 2) == ">=") {
tokens.emplace_back(TOKEN(Operator, BoolGreaterThanOrEquals));
offset += 2;
} else if (code.substr(offset, 2) == "<=") {
tokens.emplace_back(TOKEN(Operator, BoolLessThanOrEquals));
offset += 2;
} else if (code.substr(offset, 2) == "&&") {
tokens.emplace_back(TOKEN(Operator, BoolAnd));
offset += 2;
} else if (code.substr(offset, 2) == "||") {
tokens.emplace_back(TOKEN(Operator, BoolOr));
offset += 2;
} else if (code.substr(offset, 2) == "^^") {
tokens.emplace_back(TOKEN(Operator, BoolXor));
offset += 2;
} else if (c == '=') {
tokens.emplace_back(TOKEN(Operator, Assignment));
offset += 1;
// "::" has to be tested before the single ':'
} else if (code.substr(offset, 2) == "::") {
tokens.emplace_back(TOKEN(Separator, ScopeResolution));
offset += 2;
} else if (c == ':') {
tokens.emplace_back(TOKEN(Operator, Inherit));
offset += 1;
// Arithmetic operators
} else if (c == '+') {
tokens.emplace_back(TOKEN(Operator, Plus));
offset += 1;
} else if (c == '-') {
tokens.emplace_back(TOKEN(Operator, Minus));
offset += 1;
} else if (c == '*') {
tokens.emplace_back(TOKEN(Operator, Star));
offset += 1;
} else if (c == '/') {
tokens.emplace_back(TOKEN(Operator, Slash));
offset += 1;
// Shifts have to be tested before the single '<' and '>'
} else if (code.substr(offset, 2) == "<<") {
tokens.emplace_back(TOKEN(Operator, ShiftLeft));
offset += 2;
} else if (code.substr(offset, 2) == ">>") {
tokens.emplace_back(TOKEN(Operator, ShiftRight));
offset += 2;
} else if (c == '>') {
tokens.emplace_back(TOKEN(Operator, BoolGreaterThan));
offset += 1;
} else if (c == '<') {
tokens.emplace_back(TOKEN(Operator, BoolLessThan));
offset += 1;
} else if (c == '!') {
tokens.emplace_back(TOKEN(Operator, BoolNot));
offset += 1;
// Single-character bitwise operators and the ternary '?'
} else if (c == '|') {
tokens.emplace_back(TOKEN(Operator, BitOr));
offset += 1;
} else if (c == '&') {
tokens.emplace_back(TOKEN(Operator, BitAnd));
offset += 1;
} else if (c == '^') {
tokens.emplace_back(TOKEN(Operator, BitXor));
offset += 1;
} else if (c == '~') {
tokens.emplace_back(TOKEN(Operator, BitNot));
offset += 1;
} else if (c == '?') {
tokens.emplace_back(TOKEN(Operator, TernaryConditional));
offset += 1;
// Character literal, emitted as an integer token of value type Character
} else if (c == '\'') {
auto character = getCharacterLiteral(code.substr(offset));
if (!character.has_value())
throwLexerError("invalid character literal", lineNumber);
auto [c, charSize] = character.value();
tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Character, c) ));
offset += charSize;
// String literal
} else if (c == '\"') {
auto string = getStringLiteral(code.substr(offset));
if (!string.has_value())
throwLexerError("invalid string literal", lineNumber);
auto [s, stringSize] = string.value();
tokens.emplace_back(VALUE_TOKEN(String, s));
offset += stringSize;
// Words: start with a letter and continue with letters, digits or '_'
} else if (std::isalpha(c)) {
std::string identifier = matchTillInvalid(&code[offset], [](char c) -> bool { return std::isalnum(c) || c == '_'; });
// Check for reserved keywords
if (identifier == "struct")
tokens.emplace_back(TOKEN(Keyword, Struct));
else if (identifier == "union")
tokens.emplace_back(TOKEN(Keyword, Union));
else if (identifier == "using")
tokens.emplace_back(TOKEN(Keyword, Using));
else if (identifier == "enum")
tokens.emplace_back(TOKEN(Keyword, Enum));
else if (identifier == "bitfield")
tokens.emplace_back(TOKEN(Keyword, Bitfield));
else if (identifier == "be")
tokens.emplace_back(TOKEN(Keyword, BigEndian));
else if (identifier == "le")
tokens.emplace_back(TOKEN(Keyword, LittleEndian));
else if (identifier == "if")
tokens.emplace_back(TOKEN(Keyword, If));
else if (identifier == "else")
tokens.emplace_back(TOKEN(Keyword, Else));
// Boolean constants are emitted as integer tokens of value type Boolean
else if (identifier == "false")
tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(0))));
else if (identifier == "true")
tokens.emplace_back(VALUE_TOKEN(Integer, Token::IntegerLiteral(Token::ValueType::Boolean, s32(1))));
// Check for built-in types
else if (identifier == "u8")
tokens.emplace_back(TOKEN(ValueType, Unsigned8Bit));
else if (identifier == "s8")
tokens.emplace_back(TOKEN(ValueType, Signed8Bit));
else if (identifier == "u16")
tokens.emplace_back(TOKEN(ValueType, Unsigned16Bit));
else if (identifier == "s16")
tokens.emplace_back(TOKEN(ValueType, Signed16Bit));
else if (identifier == "u32")
tokens.emplace_back(TOKEN(ValueType, Unsigned32Bit));
else if (identifier == "s32")
tokens.emplace_back(TOKEN(ValueType, Signed32Bit));
else if (identifier == "u64")
tokens.emplace_back(TOKEN(ValueType, Unsigned64Bit));
else if (identifier == "s64")
tokens.emplace_back(TOKEN(ValueType, Signed64Bit));
else if (identifier == "u128")
tokens.emplace_back(TOKEN(ValueType, Unsigned128Bit));
else if (identifier == "s128")
tokens.emplace_back(TOKEN(ValueType, Signed128Bit));
else if (identifier == "float")
tokens.emplace_back(TOKEN(ValueType, Float));
else if (identifier == "double")
tokens.emplace_back(TOKEN(ValueType, Double));
else if (identifier == "char")
tokens.emplace_back(TOKEN(ValueType, Character));
else if (identifier == "bool")
tokens.emplace_back(TOKEN(ValueType, Boolean));
else if (identifier == "padding")
tokens.emplace_back(TOKEN(ValueType, Padding));
// If it's not a keyword and a builtin type, it has to be an identifier
else
tokens.emplace_back(VALUE_TOKEN(Identifier, identifier));
offset += identifier.length();
// Numeric literal; its length is determined separately from its value
} else if (std::isdigit(c)) {
auto integer = parseIntegerLiteral(&code[offset]);
if (!integer.has_value())
throwLexerError("invalid integer literal", lineNumber);
tokens.emplace_back(VALUE_TOKEN(Integer, integer.value()));
offset += getIntegerLiteralLength(&code[offset]);
} else
throwLexerError("unknown token", lineNumber);
}
// Sentinel marking the end of the token stream
tokens.emplace_back(TOKEN(Separator, EndOfProgram));
} catch (LexerError &e) {
// Keep the error in m_error and signal failure with an empty optional
this->m_error = e;
return { };
}
return tokens;
}
}

View File

@@ -0,0 +1,635 @@
#include "lang/parser.hpp"
#include <optional>
#include <variant>
// Calls begin() and, only if that yields true, evaluates the matcher expression `x`.
#define MATCHES(x) (begin() && x)
// Wraps `node` into the numeric expression `node + 0`, so that any operand can be
// handled as an ASTNodeNumericExpression.
#define TO_NUMERIC_EXPRESSION(node) new ASTNodeNumericExpression((node), new ASTNodeIntegerLiteral({ Token::ValueType::Any, s32(0) }), Token::Operator::Plus)
// Definition syntax:
// [A] : Either A or no token
// [A|B] : Either A, B or no token
// <A|B> : Either A or B
// <A...> : One or more of A
// A B C : Sequence of tokens A then B then C
// (parseXXXX) : Parsing handled by other function
namespace hex::lang {
/* Mathematical expressions */
// Identifier([(parseMathematicalExpression)|<(parseMathematicalExpression),...>(parseMathematicalExpression)])
// Parses the argument list of a function call and builds the call node.
// The function name is taken from two tokens back. Each argument is either a string
// literal or a mathematical expression; arguments are separated by ',' and the list
// is closed by ')'. A trailing ',' or a missing ',' raises a parse error, in which
// case all arguments parsed so far are deleted.
ASTNode* Parser::parseFunctionCall() {
    auto functionName = getValue<std::string>(-2);

    std::vector<ASTNode*> arguments;
    ScopeExit argumentGuard([&]{
        for (auto *argument : arguments)
            delete argument;
    });

    for (;;) {
        if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE)))
            break;

        if (MATCHES(sequence(STRING))) {
            arguments.push_back(parseStringLiteral());
        } else {
            arguments.push_back(parseMathematicalExpression());
        }

        if (MATCHES(sequence(SEPARATOR_COMMA, SEPARATOR_ROUNDBRACKETCLOSE)))
            throwParseError("unexpected ',' at end of function parameter list", -1);

        if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE)))
            break;

        if (!MATCHES(sequence(SEPARATOR_COMMA)))
            throwParseError("missing ',' between parameters", -1);
    }

    // The call node takes over the arguments from here on
    argumentGuard.release();

    return new ASTNodeFunctionCall(functionName, arguments);
}
// Builds a string literal node from the string token that was matched last.
ASTNode* Parser::parseStringLiteral() {
    const auto &text = getValue<std::string>(-1);

    return new ASTNodeStringLiteral(text);
}
// Identifier::<Identifier[::]...>
// Collects the identifiers of a `A::B::C` path into `path` and returns a scope
// resolution node wrapped in a numeric expression once no further '::' follows.
// Raises a parse error when '::' is not followed by an identifier.
ASTNode* Parser::parseScopeResolution(std::vector<std::string> &path) {
    if (peek(IDENTIFIER, -1))
        path.push_back(getValue<std::string>(-1));

    // End of the path reached
    if (!MATCHES(sequence(SEPARATOR_SCOPE_RESOLUTION)))
        return TO_NUMERIC_EXPRESSION(new ASTNodeScopeResolution(path));

    if (!MATCHES(sequence(IDENTIFIER)))
        throwParseError("expected member name", -1);

    return this->parseScopeResolution(path);
}
// <Identifier[.]...>
// Collects the identifiers of a `a.b.c` member access into `path` and returns an
// rvalue node wrapped in a numeric expression once no further '.' follows.
// Raises a parse error when '.' is not followed by an identifier.
ASTNode* Parser::parseRValue(std::vector<std::string> &path) {
    if (peek(IDENTIFIER, -1))
        path.push_back(getValue<std::string>(-1));

    // End of the access chain reached
    if (!MATCHES(sequence(SEPARATOR_DOT)))
        return TO_NUMERIC_EXPRESSION(new ASTNodeRValue(path));

    if (!MATCHES(sequence(IDENTIFIER)))
        throwParseError("expected member name", -1);

    return this->parseRValue(path);
}
// <Integer|((parseMathematicalExpression))>
// Parses the smallest unit of a mathematical expression: an integer literal, a
// parenthesised expression, a scope resolution (`A::B`), a function call or an
// rvalue (`a.b`). Everything except the parenthesised expression is wrapped in a
// numeric expression. Raises a parse error when none of these match or when a
// parenthesised expression is not closed.
ASTNode* Parser::parseFactor() {
    if (MATCHES(sequence(INTEGER)))
        return TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral(getValue<Token::IntegerLiteral>(-1)));
    else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETOPEN))) {
        auto node = this->parseMathematicalExpression();
        if (!MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) {
            // Nobody owns the sub-expression yet, release it before bailing out
            delete node;
            throwParseError("expected closing parenthesis");
        }
        return node;
    } else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_SCOPE_RESOLUTION))) {
        std::vector<std::string> path;
        // Step back onto the '::' so parseScopeResolution sees the identifier as previous token
        this->m_curr--;
        return this->parseScopeResolution(path);
    } else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) {
        return TO_NUMERIC_EXPRESSION(this->parseFunctionCall());
    } else if (MATCHES(sequence(IDENTIFIER))) {
        std::vector<std::string> path;
        return this->parseRValue(path);
    } else
        throwParseError("expected integer or parenthesis");
}
// <+|-|!|~> (parseFactor)
// Parses an optional unary operator ('+', '-', '!' or '~') followed by a factor.
// A unary operation is represented as the binary expression `0 <op> factor`.
ASTNode* Parser::parseUnaryExpression() {
    const bool hasUnaryOperator = MATCHES(sequence(OPERATOR_PLUS) || sequence(OPERATOR_MINUS) || sequence(OPERATOR_BOOLNOT) || sequence(OPERATOR_BITNOT));
    if (!hasUnaryOperator)
        return this->parseFactor();

    auto unaryOperator = getValue<Token::Operator>(-1);

    return new ASTNodeNumericExpression(new ASTNodeIntegerLiteral({ Token::ValueType::Any, 0 }), this->parseFactor(), unaryOperator);
}
// (parseUnaryExpression) <*|/> (parseUnaryExpression)
// Parses a left-associative chain of unary expressions joined by '*' or '/'.
ASTNode* Parser::parseMultiplicativeExpression() {
    ASTNode *expression = this->parseUnaryExpression();

    for (;;) {
        if (!MATCHES(variant(OPERATOR_STAR, OPERATOR_SLASH)))
            break;

        auto binaryOperator = getValue<Token::Operator>(-1);
        auto rightHandSide = this->parseUnaryExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, binaryOperator);
    }

    return expression;
}
// (parseMultiplicativeExpression) <+|-> (parseMultiplicativeExpression)
// Parses a left-associative chain of multiplicative expressions joined by '+' or '-'.
ASTNode* Parser::parseAdditiveExpression() {
    ASTNode *expression = this->parseMultiplicativeExpression();

    for (;;) {
        if (!MATCHES(variant(OPERATOR_PLUS, OPERATOR_MINUS)))
            break;

        auto binaryOperator = getValue<Token::Operator>(-1);
        auto rightHandSide = this->parseMultiplicativeExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, binaryOperator);
    }

    return expression;
}
// (parseAdditiveExpression) < >>|<< > (parseAdditiveExpression)
// Parses a left-associative chain of additive expressions joined by '<<' or '>>'.
ASTNode* Parser::parseShiftExpression() {
    ASTNode *expression = this->parseAdditiveExpression();

    for (;;) {
        if (!MATCHES(variant(OPERATOR_SHIFTLEFT, OPERATOR_SHIFTRIGHT)))
            break;

        auto binaryOperator = getValue<Token::Operator>(-1);
        auto rightHandSide = this->parseAdditiveExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, binaryOperator);
    }

    return expression;
}
// (parseShiftExpression) < >=|<=|>|< > (parseShiftExpression)
// Parses a left-associative chain of shift expressions joined by '>', '<', '>=' or '<='.
ASTNode* Parser::parseRelationExpression() {
    ASTNode *expression = this->parseShiftExpression();

    for (;;) {
        if (!MATCHES(sequence(OPERATOR_BOOLGREATERTHAN) || sequence(OPERATOR_BOOLLESSTHAN) || sequence(OPERATOR_BOOLGREATERTHANOREQUALS) || sequence(OPERATOR_BOOLLESSTHANOREQUALS)))
            break;

        auto comparison = getValue<Token::Operator>(-1);
        auto rightHandSide = this->parseShiftExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, comparison);
    }

    return expression;
}
// (parseRelationExpression) <==|!=> (parseRelationExpression)
// Parses a left-associative chain of relation expressions joined by '==' or '!='.
ASTNode* Parser::parseEqualityExpression() {
    ASTNode *expression = this->parseRelationExpression();

    for (;;) {
        if (!MATCHES(sequence(OPERATOR_BOOLEQUALS) || sequence(OPERATOR_BOOLNOTEQUALS)))
            break;

        auto comparison = getValue<Token::Operator>(-1);
        auto rightHandSide = this->parseRelationExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, comparison);
    }

    return expression;
}
// (parseEqualityExpression) & (parseEqualityExpression)
// Parses a left-associative chain of equality expressions joined by the bitwise '&'.
ASTNode* Parser::parseBinaryAndExpression() {
    ASTNode *expression = this->parseEqualityExpression();

    for (;;) {
        if (!MATCHES(sequence(OPERATOR_BITAND)))
            break;

        auto rightHandSide = this->parseEqualityExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, Token::Operator::BitAnd);
    }

    return expression;
}
// (parseBinaryAndExpression) ^ (parseBinaryAndExpression)
// Parses a left-associative chain of binary-and expressions joined by the bitwise '^'.
ASTNode* Parser::parseBinaryXorExpression() {
    ASTNode *expression = this->parseBinaryAndExpression();

    for (;;) {
        if (!MATCHES(sequence(OPERATOR_BITXOR)))
            break;

        auto rightHandSide = this->parseBinaryAndExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, Token::Operator::BitXor);
    }

    return expression;
}
// (parseBinaryXorExpression) | (parseBinaryXorExpression)
// Parses a left-associative chain of binary-xor expressions joined by the bitwise '|'.
ASTNode* Parser::parseBinaryOrExpression() {
    ASTNode *expression = this->parseBinaryXorExpression();

    for (;;) {
        if (!MATCHES(sequence(OPERATOR_BITOR)))
            break;

        auto rightHandSide = this->parseBinaryXorExpression();
        expression = new ASTNodeNumericExpression(expression, rightHandSide, Token::Operator::BitOr);
    }

    return expression;
}
// (parseBinaryOrExpression) && (parseBinaryOrExpression)
// Parses a left-associative chain of binary-or expressions joined by the boolean '&&'.
// Each '&&' produces a numeric expression node carrying the BoolAnd operator.
ASTNode* Parser::parseBooleanAnd() {
    auto node = this->parseBinaryOrExpression();

    while (MATCHES(sequence(OPERATOR_BOOLAND))) {
        // '&&' is a boolean and, not a bitwise or
        node = new ASTNodeNumericExpression(node, this->parseBinaryOrExpression(), Token::Operator::BoolAnd);
    }

    return node;
}
// (parseBooleanAnd) ^^ (parseBooleanAnd)
// Parses a left-associative chain of boolean-and expressions joined by the boolean '^^'.
// Each '^^' produces a numeric expression node carrying the BoolXor operator.
ASTNode* Parser::parseBooleanXor() {
    auto node = this->parseBooleanAnd();

    while (MATCHES(sequence(OPERATOR_BOOLXOR))) {
        // '^^' is a boolean xor, not a bitwise or
        node = new ASTNodeNumericExpression(node, this->parseBooleanAnd(), Token::Operator::BoolXor);
    }

    return node;
}
// (parseBooleanXor) || (parseBooleanXor)
// Parses a left-associative chain of boolean-xor expressions joined by the boolean '||'.
// Each '||' produces a numeric expression node carrying the BoolOr operator.
ASTNode* Parser::parseBooleanOr() {
    auto node = this->parseBooleanXor();

    while (MATCHES(sequence(OPERATOR_BOOLOR))) {
        // '||' is a boolean or, not a bitwise or
        node = new ASTNodeNumericExpression(node, this->parseBooleanXor(), Token::Operator::BoolOr);
    }

    return node;
}
// (parseBooleanOr) ? (parseBooleanOr) : (parseBooleanOr)
// Parses `condition ? a : b`, where all three operands are boolean-or expressions.
// Several ternaries in a row nest to the left. The ':' is matched as the Inherit
// operator token; a missing ':' raises a parse error.
ASTNode* Parser::parseTernaryConditional() {
    ASTNode *expression = this->parseBooleanOr();

    for (;;) {
        if (!MATCHES(sequence(OPERATOR_TERNARYCONDITIONAL)))
            break;

        auto whenTrue = this->parseBooleanOr();

        if (!MATCHES(sequence(OPERATOR_INHERIT)))
            throwParseError("expected ':' in ternary expression");

        auto whenFalse = this->parseBooleanOr();
        expression = new ASTNodeTernaryExpression(expression, whenTrue, whenFalse, Token::Operator::TernaryConditional);
    }

    return expression;
}
// (parseTernaryConditional)
// Entry point for parsing a full mathematical expression; the ternary conditional is
// the outermost level of the expression grammar.
ASTNode* Parser::parseMathematicalExpression() {
    ASTNode *expression = this->parseTernaryConditional();

    return expression;
}
/* Control flow */
// if ((parseMathematicalExpression)) { (parseMember) }
// Parses the remainder of an `if (` statement: the condition, the closing ')' and the
// true body, followed by an optional `else` with its false body. Each body is either a
// `{ ... }` block of members or a single member.
// On a parse error the condition and all members parsed so far are deleted.
ASTNode* Parser::parseConditional() {
    auto condition = parseMathematicalExpression();
    std::vector<ASTNode*> trueBody, falseBody;

    ScopeExit cleanup([&]{
        delete condition;
        for (auto *member : trueBody)
            delete member;
        for (auto *member : falseBody)
            delete member;
    });

    // Reads members into `body` until the closing '}' has been consumed
    auto parseBlockInto = [this](std::vector<ASTNode*> &body) {
        while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) {
            body.push_back(parseMember());
        }
    };

    if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE, SEPARATOR_CURLYBRACKETOPEN))) {
        parseBlockInto(trueBody);
    } else if (MATCHES(sequence(SEPARATOR_ROUNDBRACKETCLOSE))) {
        trueBody.push_back(parseMember());
    } else {
        throwParseError("expected body of conditional statement");
    }

    if (MATCHES(sequence(KEYWORD_ELSE, SEPARATOR_CURLYBRACKETOPEN))) {
        parseBlockInto(falseBody);
    } else if (MATCHES(sequence(KEYWORD_ELSE))) {
        falseBody.push_back(parseMember());
    }

    // The conditional node takes over the condition and both bodies
    cleanup.release();

    return new ASTNodeConditionalStatement(condition, trueBody, falseBody);
}
/* Type declarations */
// [be|le] <Identifier|u8|u16|u32|u64|u128|s8|s16|s32|s64|s128|float|double>
// Builds an unnamed type declaration for the type token at `startIndex`.
// An optional `le` / `be` keyword sets the endianness of the declaration.
// Identifier tokens are looked up in the known custom types (a clone of the registered
// type is used); an unknown name raises a parse error. Any other token is treated as
// a builtin value type.
ASTNode* Parser::parseType(s32 startIndex) {
    std::optional<std::endian> endian;

    if (peekOptional(KEYWORD_LE, 0))
        endian = std::endian::little;
    else if (peekOptional(KEYWORD_BE, 0))
        endian = std::endian::big;

    // Builtin type
    if (getType(startIndex) != Token::Type::Identifier)
        return new ASTNodeTypeDecl({ }, new ASTNodeBuiltinType(getValue<Token::ValueType>(startIndex)), endian);

    // Custom type
    const auto &typeName = getValue<std::string>(startIndex);
    if (!this->m_types.contains(typeName))
        throwParseError("failed to parse type");

    return new ASTNodeTypeDecl({ }, this->m_types[typeName]->clone(), endian);
}
// using Identifier = (parseType)
// Parses `using Identifier = (type)` into a named type declaration.
// The alias name sits one token further back when the aliased type carries an
// endianness keyword.
ASTNode* Parser::parseUsingDeclaration() {
    auto *aliasedType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-1));
    if (aliasedType == nullptr)
        throwParseError("invalid type used in variable declaration", -1);

    const s32 nameIndex = (peekOptional(KEYWORD_BE) || peekOptional(KEYWORD_LE)) ? -4 : -3;

    return new ASTNodeTypeDecl(getValue<std::string>(nameIndex), aliasedType, aliasedType->getEndian());
}
// padding[(parseMathematicalExpression)]
// Parses the size expression and closing ']' of a `padding[...]` member and returns
// it as an unnamed array of the builtin padding type.
// The size expression is deleted again when the closing ']' is missing.
ASTNode* Parser::parsePadding() {
    auto size = parseMathematicalExpression();

    if (MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE))) {
        auto paddingType = new ASTNodeTypeDecl({ }, new ASTNodeBuiltinType(Token::ValueType::Padding));
        return new ASTNodeArrayVariableDecl({ }, paddingType, size);
    }

    delete size;
    throwParseError("expected closing ']' at end of array declaration", -1);
}
// (parseType) Identifier
// Parses `(type) Identifier` into a variable declaration; the type token sits two
// tokens back, the variable name one token back.
ASTNode* Parser::parseMemberVariable() {
    auto *declaredType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-2));
    if (declaredType == nullptr)
        throwParseError("invalid type used in variable declaration", -1);

    const auto &variableName = getValue<std::string>(-1);

    return new ASTNodeVariableDecl(variableName, declaredType);
}
// (parseType) Identifier[(parseMathematicalExpression)]
// Parses `(type) Identifier[(size)]` into an array declaration.
// The size expression is optional: `Identifier[]` yields a declaration whose size is
// a null pointer. A missing closing ']' raises a parse error and frees the size.
ASTNode* Parser::parseMemberArrayVariable() {
    auto elementType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-3));
    if (elementType == nullptr)
        throwParseError("invalid type used in variable declaration", -1);

    auto arrayName = getValue<std::string>(-2);

    ASTNode *sizeExpression = nullptr;
    ScopeExit sizeGuard([&]{ delete sizeExpression; });

    const bool isUnsized = MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE));
    if (!isUnsized) {
        sizeExpression = parseMathematicalExpression();

        if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE)))
            throwParseError("expected closing ']' at end of array declaration", -1);
    }

    // The declaration node takes over the size expression
    sizeGuard.release();

    return new ASTNodeArrayVariableDecl(arrayName, elementType, sizeExpression);
}
// (parseType) *Identifier : (parseType)
// Parses `(type) *Identifier : (size type)` into a pointer declaration.
// The size type has to be an unsigned builtin type, optionally preceded by an
// endianness keyword; anything else raises a parse error.
ASTNode* Parser::parseMemberPointerVariable() {
    auto pointerName = getValue<std::string>(-2);

    auto pointedAtType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-4));
    if (pointedAtType == nullptr)
        throwParseError("invalid type used in variable declaration", -1);

    const bool hasSizeType = MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && sequence(VALUETYPE_UNSIGNED));
    if (!hasSizeType)
        throwParseError("expected unsigned builtin type as size", -1);

    auto pointerSizeType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-1));
    if (pointerSizeType == nullptr)
        throwParseError("invalid type used for pointer size", -1);

    return new ASTNodePointerVariableDecl(pointerName, pointedAtType, pointerSizeType);
}
// [(parsePadding)|(parseMemberVariable)|(parseMemberArrayVariable)|(parseMemberPointerVariable)]
// Parses a single member of a struct, union or conditional body: padding, an array,
// a plain variable, a pointer or a nested `if` statement.
// Conditionals are returned directly; every other member has to be terminated by ';'.
// Raises a parse error at the end of the program, for unrecognised members and for a
// missing ';' (in which case the already parsed member is deleted again).
ASTNode* Parser::parseMember() {
    ASTNode *member;

    if (MATCHES(sequence(VALUETYPE_PADDING, SEPARATOR_SQUAREBRACKETOPEN)))
        member = parsePadding();
    else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER, SEPARATOR_SQUAREBRACKETOPEN)))
        member = parseMemberArrayVariable();
    else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER)))
        member = parseMemberVariable();
    else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(OPERATOR_STAR, IDENTIFIER, OPERATOR_INHERIT)))
        member = parseMemberPointerVariable();
    else if (MATCHES(sequence(KEYWORD_IF, SEPARATOR_ROUNDBRACKETOPEN)))
        return parseConditional();
    else if (MATCHES(sequence(SEPARATOR_ENDOFPROGRAM)))
        throwParseError("unexpected end of program", -2);
    else
        throwParseError("invalid struct member", 0);

    if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION))) {
        // The member has not been handed to anyone yet, so free it before reporting the error
        delete member;
        throwParseError("missing ';' at end of expression", -1);
    }

    return member;
}
// struct Identifier { <(parseMember)...> }
//
// Collects members until the closing '}' and wraps the resulting ASTNodeStruct in an
// ASTNodeTypeDecl named after the identifier found at token offset -2.
// The struct node is deleted again if parsing a member throws.
ASTNode* Parser::parseStruct() {
    const auto body = new ASTNodeStruct();
    const auto &structName = getValue<std::string>(-2);

    ScopeExit bodyGuard([&]{ delete body; });

    for (;;) {
        if (MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE)))
            break;

        auto member = parseMember();
        body->addMember(member);
    }

    // Ownership of the body moves into the type declaration returned below
    bodyGuard.release();

    return new ASTNodeTypeDecl(structName, body);
}
// union Identifier { <(parseMember)...> }
//
// Collects members until the closing '}' and wraps the resulting ASTNodeUnion in an
// ASTNodeTypeDecl named after the identifier found at token offset -2.
// The union node is deleted again if parsing a member throws.
ASTNode* Parser::parseUnion() {
    const auto body = new ASTNodeUnion();
    const auto &unionName = getValue<std::string>(-2);

    ScopeExit bodyGuard([&]{ delete body; });

    for (;;) {
        if (MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE)))
            break;

        auto member = parseMember();
        body->addMember(member);
    }

    // Ownership of the body moves into the type declaration returned below
    bodyGuard.release();

    return new ASTNodeTypeDecl(unionName, body);
}
// enum Identifier : (parseType) { <<Identifier|Identifier = (parseMathematicalExpression)[,]>...> }
//
// Parses an enum definition into an ASTNodeTypeDecl wrapping an ASTNodeEnum.
//  - Entries written as `Name = expression` use that expression as their value.
//  - Entries without a value get 0 if they are the first entry, otherwise
//    "value of the previous entry + 1".
// Throws a parse error for an invalid underlying type, an underlying type carrying an
// endian specification, a malformed entry, a missing ',' or a premature end of program.
ASTNode* Parser::parseEnum() {
    // The enum name sits one token further back when the peeked be/le keyword is present
    std::string typeName;
    if (peekOptional(KEYWORD_BE) || peekOptional(KEYWORD_LE))
        typeName = getValue<std::string>(-5);
    else
        typeName = getValue<std::string>(-4);

    auto underlyingType = dynamic_cast<ASTNodeTypeDecl*>(parseType(-2));
    if (underlyingType == nullptr) throwParseError("failed to parse type", -2);
    if (underlyingType->getEndian().has_value()) throwParseError("underlying type may not have an endian specification", -2);

    const auto enumNode = new ASTNodeEnum(underlyingType);
    ScopeExit enumGuard([&]{ delete enumNode; });

    // Value expression of the most recently added entry; implicit values are derived from it
    ASTNode *lastEntry = nullptr;
    while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) {
        if (MATCHES(sequence(IDENTIFIER, OPERATOR_ASSIGNMENT))) {
            auto name = getValue<std::string>(-2);
            auto value = parseMathematicalExpression();

            enumNode->addEntry(name, value);
            lastEntry = value;
        }
        else if (MATCHES(sequence(IDENTIFIER))) {
            ASTNode *valueExpr;
            auto name = getValue<std::string>(-1);
            if (enumNode->getEntries().empty())
                valueExpr = TO_NUMERIC_EXPRESSION(new ASTNodeIntegerLiteral({ Token::ValueType::Unsigned8Bit, u8(0) }));
            else
                valueExpr = new ASTNodeNumericExpression(lastEntry->clone(), new ASTNodeIntegerLiteral({ Token::ValueType::Any, s32(1) }), Token::Operator::Plus);

            // Remember this entry so the next implicit value continues counting from it.
            // Without this every implicit entry after the second one repeated the same value.
            lastEntry = valueExpr;

            enumNode->addEntry(name, valueExpr);
        }
        else if (MATCHES(sequence(SEPARATOR_ENDOFPROGRAM)))
            throwParseError("unexpected end of program", -2);
        else
            throwParseError("invalid enum entry", -1);

        // Entries are separated by ','; the last entry may be followed directly by '}'
        if (!MATCHES(sequence(SEPARATOR_COMMA))) {
            if (MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE)))
                break;
            else
                throwParseError("missing ',' between enum entries", -1);
        }
    }

    enumGuard.release();

    return new ASTNodeTypeDecl(typeName, enumNode);
}
// bitfield Identifier { <Identifier : (parseMathematicalExpression)[;]...> }
//
// Parses `Name : expression;` entries until the closing '}' and wraps the resulting
// ASTNodeBitfield in an ASTNodeTypeDecl named after the identifier at token offset -2.
// The bitfield node is deleted again if parsing throws.
ASTNode* Parser::parseBitfield() {
    std::string bitfieldName = getValue<std::string>(-2);

    const auto fields = new ASTNodeBitfield();
    ScopeExit fieldsGuard([&]{ delete fields; });

    while (!MATCHES(sequence(SEPARATOR_CURLYBRACKETCLOSE))) {
        if (MATCHES(sequence(IDENTIFIER, OPERATOR_INHERIT))) {
            auto fieldName = getValue<std::string>(-2);
            auto fieldSize = parseMathematicalExpression();

            fields->addEntry(fieldName, fieldSize);
        } else if (MATCHES(sequence(SEPARATOR_ENDOFPROGRAM))) {
            throwParseError("unexpected end of program", -2);
        } else {
            throwParseError("invalid bitfield member", 0);
        }

        if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION)))
            throwParseError("missing ';' at end of expression", -1);
    }

    // Ownership of the bitfield node moves into the type declaration returned below
    fieldsGuard.release();

    return new ASTNodeTypeDecl(bitfieldName, fields);
}
// (parseType) Identifier @ Integer
//
// Parses a variable that is placed at an explicit offset. The type is read from token
// offset -3 and the variable name from offset -2; the placement offset expression
// follows the '@'. Returns a newly allocated ASTNodeVariableDecl.
ASTNode* Parser::parseVariablePlacement() {
    auto type = dynamic_cast<ASTNodeTypeDecl *>(parseType(-3));
    if (type == nullptr) throwParseError("invalid type used in variable declaration", -1);

    // Fetch the name BEFORE parsing the offset expression. Both used to be arguments of
    // the same constructor call, whose evaluation order is unspecified: if the expression
    // was parsed first, the token cursor had already moved and offset -2 no longer
    // referred to the variable name.
    auto name = getValue<std::string>(-2);
    auto placementOffset = parseMathematicalExpression();

    return new ASTNodeVariableDecl(name, type, placementOffset);
}
// (parseType) Identifier[[(parseMathematicalExpression)]] @ Integer
//
// Parses an array variable that is placed at an explicit offset. The element type is read
// from token offset -3 and the name from offset -2. The size expression is optional
// (`name[]` leaves it as nullptr); the '@' placement is mandatory.
// Returns a newly allocated ASTNodeArrayVariableDecl.
ASTNode* Parser::parseArrayVariablePlacement() {
    auto type = dynamic_cast<ASTNodeTypeDecl *>(parseType(-3));
    if (type == nullptr) throwParseError("invalid type used in variable declaration", -1);

    auto name = getValue<std::string>(-2);

    ASTNode *size = nullptr;
    ScopeExit sizeCleanup([&]{ delete size; });

    if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE))) {
        size = parseMathematicalExpression();

        if (!MATCHES(sequence(SEPARATOR_SQUAREBRACKETCLOSE)))
            throwParseError("expected closing ']' at end of array declaration", -1);
    }

    if (!MATCHES(sequence(OPERATOR_AT)))
        throwParseError("expected placement instruction", -1);

    // Parse the placement offset while the guard is still armed, so the size expression
    // is freed if this parse throws. Previously the guard was released first.
    auto placementOffset = parseMathematicalExpression();

    sizeCleanup.release();

    return new ASTNodeArrayVariableDecl(name, type, size, placementOffset);
}
// (parseType) *Identifier : (parseType) @ Integer
//
// Parses a pointer variable that is placed at an explicit offset. The name is read from
// token offset -2 and the pointed-to type from offset -4. An unsigned builtin type
// (optionally prefixed with be/le) must follow as the size of the pointer itself, then
// '@' and the placement offset expression.
// Returns a newly allocated ASTNodePointerVariableDecl.
ASTNode* Parser::parsePointerVariablePlacement() {
    const auto variableName = getValue<std::string>(-2);

    auto *pointeeType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-4));
    if (pointeeType == nullptr)
        throwParseError("invalid type used in variable declaration", -1);

    const bool sizeTypeMatched = MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && sequence(VALUETYPE_UNSIGNED));
    if (!sizeTypeMatched)
        throwParseError("expected unsigned builtin type as size", -1);

    auto *pointerSizeType = dynamic_cast<ASTNodeTypeDecl *>(parseType(-1));
    if (pointerSizeType == nullptr)
        throwParseError("invalid size type used in pointer declaration", -1);

    const bool placementMatched = MATCHES(sequence(OPERATOR_AT));
    if (!placementMatched)
        throwParseError("expected placement instruction", -1);

    auto placementOffset = parseMathematicalExpression();

    return new ASTNodePointerVariableDecl(variableName, pointeeType, pointerSizeType, placementOffset);
}
/* Program */
// <(parseUsingDeclaration)|(parseVariablePlacement)|(parseStruct)>
// Parses one top-level statement: a using declaration, a placed variable (plain, array or
// pointer), a struct/union/enum/bitfield definition or a function call. The alternatives
// are tried in a fixed order because every successful MATCHES consumes the matched tokens.
// Every statement must be terminated by ';'. Statements that produce an ASTNodeTypeDecl
// are additionally registered in m_types under their name.
ASTNode* Parser::parseStatement() {
    ASTNode *parsedStatement = nullptr;

    if (MATCHES(sequence(KEYWORD_USING, IDENTIFIER, OPERATOR_ASSIGNMENT) && (optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY))) {
        parsedStatement = dynamic_cast<ASTNodeTypeDecl*>(parseUsingDeclaration());
    } else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER, SEPARATOR_SQUAREBRACKETOPEN))) {
        parsedStatement = parseArrayVariablePlacement();
    } else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(IDENTIFIER, OPERATOR_AT))) {
        parsedStatement = parseVariablePlacement();
    } else if (MATCHES((optional(KEYWORD_BE), optional(KEYWORD_LE)) && variant(IDENTIFIER, VALUETYPE_ANY) && sequence(OPERATOR_STAR, IDENTIFIER, OPERATOR_INHERIT))) {
        parsedStatement = parsePointerVariablePlacement();
    } else if (MATCHES(sequence(KEYWORD_STRUCT, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) {
        parsedStatement = parseStruct();
    } else if (MATCHES(sequence(KEYWORD_UNION, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) {
        parsedStatement = parseUnion();
    } else if (MATCHES(sequence(KEYWORD_ENUM, IDENTIFIER, OPERATOR_INHERIT) && (optional(KEYWORD_BE), optional(KEYWORD_LE)) && sequence(VALUETYPE_UNSIGNED, SEPARATOR_CURLYBRACKETOPEN))) {
        parsedStatement = parseEnum();
    } else if (MATCHES(sequence(KEYWORD_BITFIELD, IDENTIFIER, SEPARATOR_CURLYBRACKETOPEN))) {
        parsedStatement = parseBitfield();
    } else if (MATCHES(sequence(IDENTIFIER, SEPARATOR_ROUNDBRACKETOPEN))) {
        parsedStatement = parseFunctionCall();
    } else {
        throwParseError("invalid sequence", 0);
    }

    if (!MATCHES(sequence(SEPARATOR_ENDOFEXPRESSION)))
        throwParseError("missing ';' at end of expression", -1);

    // Make newly declared types available to the statements that follow
    auto declaredType = dynamic_cast<ASTNodeTypeDecl*>(parsedStatement);
    if (declaredType != nullptr)
        this->m_types.insert({ declaredType->getName().data(), declaredType });

    return parsedStatement;
}
// <(parseStatement)...> EndOfProgram
std::optional<std::vector<ASTNode*>> Parser::parse(const std::vector<Token> &tokens) {
this->m_curr = tokens.begin();
this->m_types.clear();
try {
auto program = parseTillToken(SEPARATOR_ENDOFPROGRAM);
if (program.empty() || this->m_curr != tokens.end())
throwParseError("program is empty!", -1);
return program;
} catch (ParseError &e) {
this->m_error = e;
}
return { };
}
}

View File

@@ -0,0 +1,220 @@
#include "lang/preprocessor.hpp"
namespace hex::lang {
// Nothing to set up here; all members are default-initialized.
Preprocessor::Preprocessor() = default;
// Runs the preprocessor over `code` and returns the resulting source text.
//
// Handles `#include <file>` / `#include "file"`, `#define NAME VALUE` and
// `#pragma KEY VALUE` directives and strips `//` and `/* */` comments. Line breaks inside
// block comments are kept in the output.
//
// code       - source text to process
// initialRun - true for the outermost call: defines and pragmas collected by earlier runs
//              are cleared first, and once the text has been scanned the defines are
//              substituted and the pragma handlers are invoked. Included files are
//              processed recursively with initialRun == false.
//
// Returns an empty optional on failure; the error is stored in m_error.
std::optional<std::string> Preprocessor::preprocess(const std::string& code, bool initialRun) {
    u32 offset = 0;
    u32 lineNumber = 1;

    if (initialRun) {
        this->m_defines.clear();
        this->m_pragmas.clear();
    }

    std::string output;
    output.reserve(code.length());

    try {
        while (offset < code.length()) {
            if (code[offset] == '#') {
                offset += 1;

                if (code.substr(offset, 7) == "include") {
                    offset += 7;

                    while (std::isblank(code[offset]) || std::isspace(code[offset]))
                        offset += 1;

                    if (code[offset] != '<' && code[offset] != '"')
                        throwPreprocessorError("expected '<' or '\"' before file name", lineNumber);

                    char endChar = code[offset];
                    if (endChar == '<') endChar = '>';

                    offset += 1;

                    std::string includeFile;
                    while (code[offset] != endChar) {
                        includeFile += code[offset];
                        offset += 1;

                        if (offset >= code.length())
                            throwPreprocessorError(hex::format("missing terminating '%c' character", endChar), lineNumber);
                    }
                    offset += 1;

                    // Paths not starting with '/' are looked up inside the "include" folder
                    if (includeFile[0] != '/')
                        includeFile = "include/" + includeFile;

                    FILE *file = fopen(includeFile.c_str(), "r");
                    if (file == nullptr)
                        throwPreprocessorError(hex::format("%s: No such file or directory", includeFile.c_str()), lineNumber);

                    // Read the whole file into an owning buffer and close the handle right away,
                    // so neither of them is leaked when preprocessing the included code fails
                    fseek(file, 0, SEEK_END);
                    const long fileSize = ftell(file);
                    rewind(file);

                    std::string buffer;
                    if (fileSize > 0) {
                        buffer.resize(fileSize);
                        // Text mode may deliver fewer bytes than ftell reported; keep only what was read
                        buffer.resize(fread(buffer.data(), 1, buffer.size(), file));
                    }
                    fclose(file);

                    if (fileSize < 0)
                        throwPreprocessorError(hex::format("%s: failed to determine file size", includeFile.c_str()), lineNumber);

                    // Passed as C string, as before: the content ends at the first NUL byte
                    auto preprocessedInclude = this->preprocess(buffer.c_str(), false);
                    if (!preprocessedInclude.has_value())
                        throw this->m_error;

                    // Flatten the included code onto the current line
                    auto content = preprocessedInclude.value();

                    std::replace(content.begin(), content.end(), '\n', ' ');
                    std::replace(content.begin(), content.end(), '\r', ' ');

                    output += content;
                } else if (code.substr(offset, 6) == "define") {
                    offset += 6;

                    while (std::isblank(code[offset])) {
                        offset += 1;
                    }

                    std::string defineName;
                    while (!std::isblank(code[offset])) {
                        defineName += code[offset];

                        if (offset >= code.length() || code[offset] == '\n' || code[offset] == '\r')
                            throwPreprocessorError("no value given in #define directive", lineNumber);
                        offset += 1;
                    }

                    while (std::isblank(code[offset])) {
                        offset += 1;
                        if (offset >= code.length())
                            throwPreprocessorError("no value given in #define directive", lineNumber);
                    }

                    std::string replaceValue;
                    while (code[offset] != '\n' && code[offset] != '\r') {
                        if (offset >= code.length())
                            throwPreprocessorError("missing new line after #define directive", lineNumber);

                        replaceValue += code[offset];
                        offset += 1;
                    }

                    if (replaceValue.empty())
                        throwPreprocessorError("no value given in #define directive", lineNumber);

                    this->m_defines.emplace(defineName, replaceValue);
                } else if (code.substr(offset, 6) == "pragma") {
                    offset += 6;

                    while (std::isblank(code[offset]))
                        offset += 1;

                    std::string pragmaKey;
                    while (!std::isblank(code[offset])) {
                        pragmaKey += code[offset];

                        if (offset >= code.length() || code[offset] == '\n' || code[offset] == '\r')
                            throwPreprocessorError("no instruction given in #pragma directive", lineNumber);

                        offset += 1;
                    }

                    while (std::isblank(code[offset]))
                        offset += 1;

                    std::string pragmaValue;
                    while (code[offset] != '\n' && code[offset] != '\r') {
                        if (offset >= code.length())
                            throwPreprocessorError("missing new line after #pragma directive", lineNumber);

                        pragmaValue += code[offset];
                        offset += 1;
                    }

                    if (pragmaValue.empty())
                        throwPreprocessorError("missing value in #pragma directive", lineNumber);

                    this->m_pragmas.emplace(pragmaKey, pragmaValue);
                } else
                    throwPreprocessorError("unknown preprocessor directive", lineNumber);
            } else if (code.substr(offset, 2) == "//") {
                // Line comment: skip up to (not including) the line break
                while (offset < code.length() && code[offset] != '\n')
                    offset += 1;
            } else if (code.substr(offset, 2) == "/*") {
                // Block comment: skip to the terminator but keep the line breaks
                while (code.substr(offset, 2) != "*/" && offset < code.length()) {
                    if (code[offset] == '\n') {
                        output += '\n';
                        lineNumber++;
                    }

                    offset += 1;
                }

                // Reaching the end of the input means no "*/" was found. This is checked before
                // skipping the terminator so a comment closing exactly at the end of the input
                // is not mistaken for an unterminated one.
                if (offset >= code.length())
                    throwPreprocessorError("unterminated comment", lineNumber - 1);

                offset += 2;
            }

            // A directive or comment may have consumed the rest of the input; in that case
            // there is no character left to copy (copying would append a stray '\0')
            if (offset >= code.length())
                break;

            if (code[offset] == '\n')
                lineNumber++;

            output += code[offset];
            offset += 1;
        }

        if (initialRun) {
            // Apply defines, longest name first so that a define whose name is a prefix of
            // another one does not replace part of the longer name
            std::vector<std::pair<std::string, std::string>> sortedDefines;
            std::copy(this->m_defines.begin(), this->m_defines.end(), std::back_inserter(sortedDefines));
            std::sort(sortedDefines.begin(), sortedDefines.end(), [](const auto &left, const auto &right) {
                return left.first.size() > right.first.size();
            });

            for (const auto &[define, value] : sortedDefines) {
                // size_type instead of s32: find() returns a size_type / npos
                std::string::size_type index = 0;
                while ((index = output.find(define, index)) != std::string::npos) {
                    output.replace(index, define.length(), value);
                    // Continue after the inserted text so the replacement is not rescanned
                    index += value.length();
                }
            }

            // Handle pragmas
            for (const auto &[type, value] : this->m_pragmas) {
                if (this->m_pragmaHandlers.contains(type)) {
                    if (!this->m_pragmaHandlers[type](value))
                        throwPreprocessorError(hex::format("invalid value provided to '%s' #pragma directive", type.c_str()), lineNumber);
                } else
                    throwPreprocessorError(hex::format("no #pragma handler registered for type %s", type.c_str()), lineNumber);
            }
        }
    } catch (PreprocessorError &e) {
        this->m_error = e;
        return { };
    }

    return output;
}
// Registers `function` as the handler for `#pragma <pragmaType> ...`.
// A handler that is already registered for this type is kept; the new one is ignored.
void Preprocessor::addPragmaHandler(const std::string &pragmaType, const std::function<bool(const std::string&)> &function) {
    const bool alreadyRegistered = this->m_pragmaHandlers.contains(pragmaType);
    if (alreadyRegistered)
        return;

    this->m_pragmaHandlers.emplace(pragmaType, function);
}
// Registers the built-in pragma handlers:
//   MIME   - accepts any value that is not made up of whitespace only and does not end
//            in a line break
//   endian - accepts exactly "big", "little" or "native"
void Preprocessor::addDefaultPragmaHandlers() {
    this->addPragmaHandler("MIME", [](const std::string &mimeType) {
        const bool whitespaceOnly    = std::all_of(mimeType.begin(), mimeType.end(), isspace);
        const bool endsWithLineBreak = mimeType.ends_with('\n') || mimeType.ends_with('\r');

        return !(whitespaceOnly || endsWithLineBreak);
    });

    this->addPragmaHandler("endian", [](const std::string &endianness) {
        if (endianness == "big")
            return true;
        if (endianness == "little")
            return true;

        return endianness == "native";
    });
}
}

View File

@@ -0,0 +1,163 @@
#include "lang/validator.hpp"
#include <unordered_set>
#include <string>
#include "helpers/utils.hpp"
namespace hex::lang {
// Nothing to set up here; all members are default-initialized.
Validator::Validator() = default;
// Checks a list of AST nodes for semantic errors.
//
// Within one list, variable and type names must be unique. Types of variables, aliased
// types and struct/union members are validated recursively, each nested list with its own
// set of names. Enum constants must be unique within their enum.
//
// Returns true if no problem was found. On failure returns false and stores the error in
// m_error; this includes failures found by a recursive call, which were previously
// discarded so that nested errors (including all enum checks) never reached the caller.
bool Validator::validate(const std::vector<ASTNode*>& ast) {
    std::unordered_set<std::string> identifiers;

    try {
        for (const auto &node : ast) {
            if (node == nullptr)
                throwValidateError("nullptr in AST. This is a bug!", 1);

            if (auto variableDeclNode = dynamic_cast<ASTNodeVariableDecl*>(node); variableDeclNode != nullptr) {
                if (!identifiers.insert(variableDeclNode->getName().data()).second)
                    throwValidateError(hex::format("redefinition of identifier '%s'", variableDeclNode->getName().data()), variableDeclNode->getLineNumber());

                // The nested call has already stored its error in m_error; just propagate the failure
                if (!this->validate({ variableDeclNode->getType() }))
                    return false;
            } else if (auto typeDeclNode = dynamic_cast<ASTNodeTypeDecl*>(node); typeDeclNode != nullptr) {
                if (!identifiers.insert(typeDeclNode->getName().data()).second)
                    throwValidateError(hex::format("redefinition of identifier '%s'", typeDeclNode->getName().data()), typeDeclNode->getLineNumber());

                if (!this->validate({ typeDeclNode->getType() }))
                    return false;
            } else if (auto structNode = dynamic_cast<ASTNodeStruct*>(node); structNode != nullptr) {
                if (!this->validate(structNode->getMembers()))
                    return false;
            } else if (auto unionNode = dynamic_cast<ASTNodeUnion*>(node); unionNode != nullptr) {
                if (!this->validate(unionNode->getMembers()))
                    return false;
            } else if (auto enumNode = dynamic_cast<ASTNodeEnum*>(node); enumNode != nullptr) {
                std::unordered_set<std::string> enumIdentifiers;
                for (auto &[name, value] : enumNode->getEntries()) {
                    if (!enumIdentifiers.insert(name).second)
                        throwValidateError(hex::format("redefinition of enum constant '%s'", name.c_str()), value->getLineNumber());
                }
            }
        }
    } catch (ValidatorError &e) {
        this->m_error = e;
        return false;
    }

    return true;
}
// Debug helper: prints the given AST nodes and their children as an indented tree on stdout.
// Does nothing unless DEBUG is set. The indentation depth is tracked in a function-local
// static, so nested calls print two columns deeper than their caller.
void Validator::printAST(const std::vector<ASTNode*>& ast){
#if DEBUG
#define INDENT_VALUE indent, ' '
    // Starts at -2 so that the outermost call prints with an indentation of 0
    static s32 indent = -2;

    indent += 2;
    for (const auto &node : ast) {
        if (auto variableDeclNode = dynamic_cast<ASTNodeVariableDecl*>(node); variableDeclNode != nullptr) {
            if (auto offset = dynamic_cast<ASTNodeNumericExpression*>(variableDeclNode->getPlacementOffset()); offset != nullptr) {
                printf("%*c ASTNodeVariableDecl (%s) @\n", INDENT_VALUE, variableDeclNode->getName().data());
                printAST({ offset });
            }
            else
                printf("%*c ASTNodeVariableDecl (%s)\n", INDENT_VALUE, variableDeclNode->getName().data());
            printAST({ variableDeclNode->getType() });
        } else if (auto pointerDeclNode = dynamic_cast<ASTNodePointerVariableDecl*>(node); pointerDeclNode != nullptr) {
            if (auto offset = dynamic_cast<ASTNodeNumericExpression*>(pointerDeclNode->getPlacementOffset()); offset != nullptr) {
                printf("%*c ASTNodePointerVariableDecl (*%s) @\n", INDENT_VALUE, pointerDeclNode->getName().data());
                printAST({ offset });
            }
            else
                printf("%*c ASTNodePointerVariableDecl (*%s)\n", INDENT_VALUE, pointerDeclNode->getName().data());
            printAST({ pointerDeclNode->getType() });
            printAST({ pointerDeclNode->getSizeType() });
        } else if (auto arrayDeclNode = dynamic_cast<ASTNodeArrayVariableDecl*>(node); arrayDeclNode != nullptr) {
            auto sizeExpr = dynamic_cast<ASTNodeNumericExpression*>(arrayDeclNode->getSize());
            if (sizeExpr == nullptr) {
                printf("%*c Invalid size!\n", INDENT_VALUE);
                continue;
            }

            if (auto offset = dynamic_cast<ASTNodeNumericExpression*>(arrayDeclNode->getPlacementOffset()); offset != nullptr) {
                printf("%*c ASTNodeArrayVariableDecl (%s[]) @\n", INDENT_VALUE, arrayDeclNode->getName().data());
                printAST({ sizeExpr });
                printAST({ offset });
            }
            else {
                printf("%*c ASTNodeArrayVariableDecl (%s[])\n", INDENT_VALUE, arrayDeclNode->getName().data());
                printAST({ sizeExpr });
            }

            printAST({ arrayDeclNode->getType() });
            printAST({ arrayDeclNode->getSize() });
        } else if (auto typeDeclNode = dynamic_cast<ASTNodeTypeDecl*>(node); typeDeclNode != nullptr) {
            printf("%*c ASTNodeTypeDecl (%s %s)\n", INDENT_VALUE, typeDeclNode->getEndian().value_or(std::endian::native) == std::endian::little ? "le" : "be", typeDeclNode->getName().empty() ? "<unnamed>" : typeDeclNode->getName().data());
            printAST({ typeDeclNode->getType() });
        } else if (auto builtinTypeNode = dynamic_cast<ASTNodeBuiltinType*>(node); builtinTypeNode != nullptr) {
            std::string typeName = Token::getTypeName(builtinTypeNode->getType());
            printf("%*c ASTNodeTypeDecl (%s)\n", INDENT_VALUE, typeName.c_str());
        } else if (auto integerLiteralNode = dynamic_cast<ASTNodeIntegerLiteral*>(node); integerLiteralNode != nullptr) {
            // Visit instead of std::get<s128>: the literal may hold any alternative of the
            // variant, and std::get throws std::bad_variant_access for every other one
            const long long literalValue = std::visit([](auto &&value) { return static_cast<long long>(value); }, integerLiteralNode->getValue());
            printf("%*c ASTNodeIntegerLiteral %lld\n", INDENT_VALUE, literalValue);
        } else if (auto numericExpressionNode = dynamic_cast<ASTNodeNumericExpression*>(node); numericExpressionNode != nullptr) {
            std::string op;
            switch (numericExpressionNode->getOperator()) {
                case Token::Operator::Plus: op = "+"; break;
                case Token::Operator::Minus: op = "-"; break;
                case Token::Operator::Star: op = "*"; break;
                case Token::Operator::Slash: op = "/"; break;
                // These two labels used to be swapped
                case Token::Operator::ShiftLeft: op = "<<"; break;
                case Token::Operator::ShiftRight: op = ">>"; break;
                case Token::Operator::BitAnd: op = "&"; break;
                case Token::Operator::BitOr: op = "|"; break;
                case Token::Operator::BitXor: op = "^"; break;
                default: op = "???";
            }
            printf("%*c ASTNodeNumericExpression %s\n", INDENT_VALUE, op.c_str());
            printf("%*c Left:\n", INDENT_VALUE);
            printAST({ numericExpressionNode->getLeftOperand() });
            printf("%*c Right:\n", INDENT_VALUE);
            printAST({ numericExpressionNode->getRightOperand() });
        } else if (auto structNode = dynamic_cast<ASTNodeStruct*>(node); structNode != nullptr) {
            printf("%*c ASTNodeStruct\n", INDENT_VALUE);
            printAST(structNode->getMembers());
        } else if (auto unionNode = dynamic_cast<ASTNodeUnion*>(node); unionNode != nullptr) {
            printf("%*c ASTNodeUnion\n", INDENT_VALUE);
            printAST(unionNode->getMembers());
        } else if (auto enumNode = dynamic_cast<ASTNodeEnum*>(node); enumNode != nullptr) {
            printf("%*c ASTNodeEnum\n", INDENT_VALUE);

            for (const auto &[name, entry] : enumNode->getEntries()) {
                printf("%*c ::%s\n", INDENT_VALUE, name.c_str());
                printAST({ entry });
            }
        } else if (auto bitfieldNode = dynamic_cast<ASTNodeBitfield*>(node); bitfieldNode != nullptr) {
            printf("%*c ASTNodeBitfield\n", INDENT_VALUE);

            for (const auto &[name, entry] : bitfieldNode->getEntries()) {
                printf("%*c %s : \n", INDENT_VALUE, name.c_str());
                printAST({ entry });
            }
        } else if (auto rvalueNode = dynamic_cast<ASTNodeRValue*>(node); rvalueNode != nullptr) {
            printf("%*c ASTNodeRValue\n", INDENT_VALUE);

            printf("%*c ", INDENT_VALUE);
            for (const auto &path : rvalueNode->getPath())
                printf("%s.", path.c_str());
            printf("\n");
        } else {
            printf("%*c Invalid AST node!\n", INDENT_VALUE);
        }
    }
    indent -= 2;

#undef INDENT_VALUE
#endif
}
}