diff --git a/lib/libimhex/include/hex/pattern_language/ast_node.hpp b/lib/libimhex/include/hex/pattern_language/ast_node.hpp index e2d6e5f11..01521fb91 100644 --- a/lib/libimhex/include/hex/pattern_language/ast_node.hpp +++ b/lib/libimhex/include/hex/pattern_language/ast_node.hpp @@ -97,6 +97,32 @@ namespace hex::pl { return this->m_attributes; } + bool hasAttribute(const std::string &key, bool needsParameter) const { + return std::any_of(this->m_attributes.begin(), this->m_attributes.end(), [&](ASTNodeAttribute *attribute) { + if (attribute->getAttribute() == key) { + if (needsParameter && !attribute->getValue().has_value()) + LogConsole::abortEvaluation(hex::format("attribute '{}' expected a parameter"), attribute); + else if (!needsParameter && attribute->getValue().has_value()) + LogConsole::abortEvaluation(hex::format("attribute '{}' did not expect a parameter "), attribute); + else + return true; + } + + return false; + }); + } + + [[nodiscard]] std::optional getAttributeValue(const std::string &key) const { + auto attribute = std::find_if(this->m_attributes.begin(), this->m_attributes.end(), [&](ASTNodeAttribute *attribute) { + return attribute->getAttribute() == key; + }); + + if (attribute != this->m_attributes.end()) + return (*attribute)->getValue(); + else + return std::nullopt; + } + private: std::vector m_attributes; }; @@ -692,91 +718,91 @@ namespace hex::pl { ASTNode *m_postExpression; }; - inline void applyVariableAttributes(Evaluator *evaluator, const Attributable *attributable, PatternData *pattern) { + inline void applyVariableAttributes(Evaluator *evaluator, const ASTNode *node, PatternData *pattern) { + auto attributable = dynamic_cast(node); + if (attributable == nullptr) + LogConsole::abortEvaluation("attribute cannot be applied here", node); + auto endOffset = evaluator->dataOffset(); evaluator->dataOffset() = pattern->getOffset(); ON_SCOPE_EXIT { evaluator->dataOffset() = endOffset; }; - for (ASTNodeAttribute *attribute : attributable->getAttributes()) { - auto &name = attribute->getAttribute(); - auto value = attribute->getValue(); + if (auto value = attributable->getAttributeValue("color"); value) { + u32 color = strtoul(value->c_str(), nullptr, 16); + pattern->setColor(hex::changeEndianess(color, std::endian::big) >> 8); + } - auto node = reinterpret_cast(attributable); + if (auto value = attributable->getAttributeValue("name"); value) { + pattern->setDisplayName(*value); + } - auto requiresValue = [&]() { - if (!value.has_value()) - LogConsole::abortEvaluation(hex::format("used attribute '{}' without providing a value", name), node); - return true; - }; + if (auto value = attributable->getAttributeValue("comment"); value) { + pattern->setComment(*value); + } - auto noValue = [&]() { - if (value.has_value()) - LogConsole::abortEvaluation(hex::format("provided a value to attribute '{}' which doesn't take one", name), node); - return true; - }; + if (auto value = attributable->getAttributeValue("format"); value) { + auto functions = evaluator->getCustomFunctions(); + if (!functions.contains(*value)) + LogConsole::abortEvaluation(hex::format("cannot find formatter function '{}'", *value), node); - if (name == "color" && requiresValue()) { - u32 color = strtoul(value->c_str(), nullptr, 16); - pattern->setColor(hex::changeEndianess(color, std::endian::big) >> 8); - } else if (name == "name" && requiresValue()) { - pattern->setDisplayName(*value); - } else if (name == "comment" && requiresValue()) { - pattern->setComment(*value); - } else if (name == "hidden" && noValue()) { - pattern->setHidden(true); - } else if (name == "no_unique_address" && noValue()) { - endOffset -= pattern->getSize(); - } else if (name == "inline" && noValue()) { - auto inlinable = dynamic_cast(pattern); + const auto &function = functions[*value]; + if (function.parameterCount != 1) + LogConsole::abortEvaluation("formatter function needs exactly one parameter", node); - if (inlinable == nullptr) - LogConsole::abortEvaluation("inline attribute can only be applied to nested types", node); - else - inlinable->setInlined(true); + pattern->setFormatterFunction(function); + } - } else if (name == "format" && requiresValue()) { - auto functions = evaluator->getCustomFunctions(); - if (!functions.contains(*value)) - LogConsole::abortEvaluation(hex::format("cannot find formatter function '{}'", *value), node); + if (auto value = attributable->getAttributeValue("transform"); value) { + auto functions = evaluator->getCustomFunctions(); + if (!functions.contains(*value)) + LogConsole::abortEvaluation(hex::format("cannot find transform function '{}'", *value), node); - const auto &function = functions[*value]; - if (function.parameterCount != 1) - LogConsole::abortEvaluation("formatter function needs exactly one parameter", node); + const auto &function = functions[*value]; + if (function.parameterCount != 1) + LogConsole::abortEvaluation("transform function needs exactly one parameter", node); - pattern->setFormatterFunction(function); - } else if (name == "transform" && requiresValue()) { - auto functions = evaluator->getCustomFunctions(); - if (!functions.contains(*value)) - LogConsole::abortEvaluation(hex::format("cannot find transform function '{}'", *value), node); + pattern->setTransformFunction(function); + } - const auto &function = functions[*value]; - if (function.parameterCount != 1) - LogConsole::abortEvaluation("transform function needs exactly one parameter", node); + if (auto value = attributable->getAttributeValue("pointer_base"); value) { + auto functions = evaluator->getCustomFunctions(); + if (!functions.contains(*value)) + LogConsole::abortEvaluation(hex::format("cannot find pointer base function '{}'", *value), node); - pattern->setTransformFunction(function); - } else if (name == "pointer_base" && requiresValue()) { - auto functions = evaluator->getCustomFunctions(); - if (!functions.contains(*value)) - LogConsole::abortEvaluation(hex::format("cannot find pointer base function '{}'", *value), node); + const auto &function = functions[*value]; + if (function.parameterCount != 1) + LogConsole::abortEvaluation("pointer base function needs exactly one parameter", node); - const auto &function = functions[*value]; - if (function.parameterCount != 1) - LogConsole::abortEvaluation("pointer base function needs exactly one parameter", node); + if (auto pointerPattern = dynamic_cast(pattern)) { + u128 pointerValue = pointerPattern->getPointedAtAddress(); - if (auto pointerPattern = dynamic_cast(pattern)) { - u128 pointerValue = pointerPattern->getPointedAtAddress(); + auto result = function.func(evaluator, { pointerValue }); - auto result = function.func(evaluator, { pointerValue }); + if (!result.has_value()) + LogConsole::abortEvaluation("pointer base function did not return a value", node); - if (!result.has_value()) - LogConsole::abortEvaluation("pointer base function did not return a value", node); - - pointerPattern->setPointedAtAddress(Token::literalToUnsigned(result.value()) + pointerValue); - } else { - LogConsole::abortEvaluation("pointer_base attribute may only be applied to a pointer"); - } + pointerPattern->setPointedAtAddress(Token::literalToUnsigned(result.value()) + pointerValue); + } else { + LogConsole::abortEvaluation("pointer_base attribute may only be applied to a pointer"); } } + + if (attributable->hasAttribute("hidden", false)) { + pattern->setHidden(true); + } + + if (attributable->hasAttribute("no_unique_address", false)) { + endOffset -= pattern->getSize(); + } + + if (attributable->hasAttribute("inline", false)) { + auto inlinable = dynamic_cast(pattern); + + if (inlinable == nullptr) + LogConsole::abortEvaluation("inline attribute can only be applied to nested types", node); + else + inlinable->setInlined(true); + } } class ASTNodeVariableDecl : public ASTNode, @@ -905,11 +931,7 @@ namespace hex::pl { if (dynamic_cast(type)) pattern = createStaticArray(evaluator); else if (auto attributable = dynamic_cast(type)) { - auto &attributes = attributable->getAttributes(); - - bool isStaticType = std::any_of(attributes.begin(), attributes.end(), [](ASTNodeAttribute *attribute) { - return attribute->getAttribute() == "static" && !attribute->getValue().has_value(); - }); + bool isStaticType = attributable->hasAttribute("static", false); if (isStaticType) pattern = createStaticArray(evaluator); @@ -1524,17 +1546,18 @@ namespace hex::pl { delete field; }; - auto &attributes = this->getAttributes(); - bool isLeftToRight = false; - - if (std::any_of(attributes.begin(), attributes.end(), [](ASTNodeAttribute *attribute) { return attribute->getAttribute() == "left_to_right" && !attribute->getValue().has_value(); })) + if (this->hasAttribute("left_to_right", false)) isLeftToRight = true; - else if (std::any_of(attributes.begin(), attributes.end(), [](ASTNodeAttribute *attribute) { return attribute->getAttribute() == "right_to_left" && !attribute->getValue().has_value(); })) + else if (this->hasAttribute("right_to_left", false)) isLeftToRight = false; + auto entries = this->m_entries; + if (isLeftToRight) + std::reverse(entries.begin(), entries.end()); + evaluator->pushScope(pattern, fields); - for (auto [name, bitSizeNode] : this->m_entries) { + for (auto [name, bitSizeNode] : entries) { auto literal = bitSizeNode->evaluate(evaluator); ON_SCOPE_EXIT { delete literal; }; @@ -1549,10 +1572,7 @@ namespace hex::pl { auto field = new PatternDataBitfieldField(evaluator, evaluator->dataOffset(), bitOffset, bitSize, pattern); field->setVariableName(name); - if (isLeftToRight) - fields.insert(fields.begin(), field); - else - fields.push_back(field); + fields.push_back(field); } bitOffset += bitSize;