diff --git a/CMakeLists.txt b/CMakeLists.txt index 3fa17c484..a428ed79c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,9 @@ if (USE_SYSTEM_YARA) include_directories(include ${YARA_INCLUDE_DIRS}) endif() +enable_testing() +add_subdirectory(tests) + addVersionDefines() configurePackageCreation() diff --git a/plugins/libimhex/include/hex/helpers/file.hpp b/plugins/libimhex/include/hex/helpers/file.hpp index 6f5d984c6..fad35f296 100644 --- a/plugins/libimhex/include/hex/helpers/file.hpp +++ b/plugins/libimhex/include/hex/helpers/file.hpp @@ -24,15 +24,18 @@ namespace hex { }; explicit File(const std::string &path, Mode mode); + File(); ~File(); bool isValid() { return this->m_file != nullptr; } void seek(u64 offset); + size_t readBuffer(u8 *buffer, size_t size); std::vector readBytes(size_t numBytes = 0); std::string readString(size_t numBytes = 0); + void write(const u8 *buffer, size_t size); void write(const std::vector &bytes); void write(const std::string &string); diff --git a/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp b/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp index 94e5540bf..a0d3814e5 100644 --- a/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp +++ b/plugins/libimhex/include/hex/pattern_language/pattern_data.hpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace hex::pl { @@ -189,6 +190,23 @@ namespace hex::pl { return this->m_local; } + [[nodiscard]] virtual bool operator!=(const PatternData &other) const final { return !operator==(other); } + [[nodiscard]] virtual bool operator==(const PatternData &other) const = 0; + + template + [[nodiscard]] bool areCommonPropertiesEqual(const PatternData &other) const { + return + typeid(other) == typeid(std::remove_cvref_t) && + this->m_offset == other.m_offset && + this->m_size == other.m_size && + this->m_hidden == other.m_hidden && + this->m_endian == other.m_endian && + this->m_variableName == other.m_variableName && + this->m_typeName == other.m_typeName && + this->m_comment == other.m_comment && + this->m_local == other.m_local; + } + protected: void createDefaultEntry(const std::string &value) const { ImGui::TableNextRow(); @@ -252,6 +270,8 @@ namespace hex::pl { [[nodiscard]] std::string getFormattedName() const override { return ""; } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataPointer : public PatternData { @@ -341,6 +361,11 @@ namespace hex::pl { return this->m_pointedAt; } + [[nodiscard]] bool operator==(const PatternData &other) const override { + return areCommonPropertiesEqual(other) && + *static_cast(&other)->m_pointedAt == *this->m_pointedAt; + } + private: PatternData *m_pointedAt; }; @@ -372,6 +397,8 @@ namespace hex::pl { default: return "Unsigned data"; } } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataSigned : public PatternData { @@ -432,6 +459,8 @@ namespace hex::pl { default: return "Signed data"; } } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataFloat : public PatternData { @@ -466,6 +495,8 @@ namespace hex::pl { default: return "Floating point data"; } } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataBoolean : public PatternData { @@ -492,6 +523,8 @@ namespace hex::pl { [[nodiscard]] std::string getFormattedName() const override { return "bool"; } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataCharacter : public PatternData { @@ -513,6 +546,8 @@ namespace hex::pl { [[nodiscard]] std::string getFormattedName() const override { return "char"; } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataCharacter16 : public PatternData { @@ -535,6 +570,8 @@ namespace hex::pl { [[nodiscard]] std::string getFormattedName() const override { return "char16"; } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataString : public PatternData { @@ -556,6 +593,8 @@ namespace hex::pl { [[nodiscard]] std::string getFormattedName() const override { return "String"; } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataString16 : public PatternData { @@ -582,6 +621,8 @@ namespace hex::pl { [[nodiscard]] std::string getFormattedName() const override { return "String16"; } + + [[nodiscard]] bool operator==(const PatternData &other) const override { return areCommonPropertiesEqual(other); } }; class PatternDataDynamicArray : public PatternData { @@ -692,6 +733,22 @@ namespace hex::pl { } } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherArray = *static_cast(&other); + if (this->m_entries.size() != otherArray.m_entries.size()) + return false; + + for (u64 i = 0; i < this->m_entries.size(); i++) { + if (*this->m_entries[i] != *otherArray.m_entries[i]) + return false; + } + + return true; + } + private: std::vector m_entries; }; @@ -813,6 +870,14 @@ namespace hex::pl { this->m_template->setParent(this); } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherArray = *static_cast(&other); + return *this->m_template == *otherArray.m_template && this->m_entryCount == otherArray.m_entryCount; + } + private: PatternData *m_template; size_t m_entryCount; @@ -927,6 +992,22 @@ namespace hex::pl { this->m_sortedMembers = this->m_members; } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherStruct = *static_cast(&other); + if (this->m_members.size() != otherStruct.m_members.size()) + return false; + + for (u64 i = 0; i < this->m_members.size(); i++) { + if (*this->m_members[i] != *otherStruct.m_members[i]) + return false; + } + + return true; + } + private: std::vector m_members; std::vector m_sortedMembers; @@ -1042,6 +1123,22 @@ namespace hex::pl { this->m_sortedMembers = this->m_members; } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherUnion = *static_cast(&other); + if (this->m_members.size() != otherUnion.m_members.size()) + return false; + + for (u64 i = 0; i < this->m_members.size(); i++) { + if (*this->m_members[i] != *otherUnion.m_members[i]) + return false; + } + + return true; + } + private: std::vector m_members; std::vector m_sortedMembers; @@ -1115,6 +1212,22 @@ namespace hex::pl { this->m_enumValues = enumValues; } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherEnum = *static_cast(&other); + if (this->m_enumValues.size() != otherEnum.m_enumValues.size()) + return false; + + for (u64 i = 0; i < this->m_enumValues.size(); i++) { + if (this->m_enumValues[i] != otherEnum.m_enumValues[i]) + return false; + } + + return true; + } + private: std::vector> m_enumValues; }; @@ -1175,6 +1288,14 @@ namespace hex::pl { return this->m_bitSize; } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherBitfieldField = *static_cast(&other); + return this->m_bitOffset == otherBitfieldField.m_bitOffset && this->m_bitSize == otherBitfieldField.m_bitSize; + } + private: u8 m_bitOffset, m_bitSize; }; @@ -1246,6 +1367,22 @@ namespace hex::pl { field->setSize(this->getSize()); } + [[nodiscard]] bool operator==(const PatternData &other) const override { + if (!areCommonPropertiesEqual(other)) + return false; + + auto &otherBitfield = *static_cast(&other); + if (this->m_fields.size() != otherBitfield.m_fields.size()) + return false; + + for (u64 i = 0; i < this->m_fields.size(); i++) { + if (this->m_fields[i] != otherBitfield.m_fields[i]) + return false; + } + + return true; + } + private: std::vector m_fields; }; diff --git a/plugins/libimhex/source/helpers/file.cpp b/plugins/libimhex/source/helpers/file.cpp index 94ac0aa01..b3cb89979 100644 --- a/plugins/libimhex/source/helpers/file.cpp +++ b/plugins/libimhex/source/helpers/file.cpp @@ -13,6 +13,10 @@ namespace hex { this->m_file = fopen64(path.c_str(), "w+b"); } + File::File() { + this->m_file = nullptr; + } + File::~File() { if (isValid()) fclose(this->m_file); @@ -22,7 +26,15 @@ namespace hex { fseeko64(this->m_file, offset, SEEK_SET); } + size_t File::readBuffer(u8 *buffer, size_t size) { + if (!isValid()) return 0; + + return fread(buffer, size, 1, this->m_file); + } + std::vector File::readBytes(size_t numBytes) { + if (!isValid()) return { }; + std::vector bytes(numBytes ?: getSize()); auto bytesRead = fread(bytes.data(), bytes.size(), 1, this->m_file); @@ -32,18 +44,32 @@ namespace hex { } std::string File::readString(size_t numBytes) { + if (!isValid()) return { }; + return reinterpret_cast(readBytes(numBytes).data()); } + void File::write(const u8 *buffer, size_t size) { + if (!isValid()) return; + + fwrite(buffer, size, 1, this->m_file); + } + void File::write(const std::vector &bytes) { + if (!isValid()) return; + fwrite(bytes.data(), bytes.size(), 1, this->m_file); } void File::write(const std::string &string) { + if (!isValid()) return; + fwrite(string.data(), string.size(), 1, this->m_file); } size_t File::getSize() { + if (!isValid()) return 0; + auto startPos = ftello64(this->m_file); fseeko64(this->m_file, 0, SEEK_END); size_t size = ftello64(this->m_file); @@ -53,6 +79,8 @@ namespace hex { } void File::setSize(u64 size) { + if (!isValid()) return; + ftruncate64(fileno(this->m_file), size); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e69de29bb..fcbae7851 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -0,0 +1,23 @@ +cmake_minimum_required(VERSION 3.16) + +project(tests) + + +# Add new tests here # +set(AVAILABLE_TESTS + Placement +) + + + +add_executable(tests source/main.cpp ) +target_include_directories(tests PRIVATE include) +target_link_libraries(tests libimhex) +set_target_properties(tests PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +add_custom_command(TARGET tests + POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${CMAKE_CURRENT_SOURCE_DIR}/test_data" ${CMAKE_BINARY_DIR}) + +foreach (test IN LISTS AVAILABLE_TESTS) + add_test(NAME "${test}" COMMAND tests "${test}" WORKING_DIRECTORY ${CMAKE_BINARY_DIR}) +endforeach () \ No newline at end of file diff --git a/tests/include/test_patterns/test_pattern.hpp b/tests/include/test_patterns/test_pattern.hpp index 99bae8e86..aa423b2ea 100644 --- a/tests/include/test_patterns/test_pattern.hpp +++ b/tests/include/test_patterns/test_pattern.hpp @@ -1,8 +1,39 @@ -// -// Created by werwo on 11/09/2021. -// +#pragma once -#ifndef IMHEX_TEST_PATTERN_HPP -#define IMHEX_TEST_PATTERN_HPP +#include +#include -#endif //IMHEX_TEST_PATTERN_HPP +#include + +namespace hex::test { + + class TestPattern { + public: + TestPattern() = default; + virtual ~TestPattern() { + for (auto &pattern : this->m_patterns) + delete pattern; + } + + template + static T* createVariablePattern(u64 offset, size_t size, const std::string &typeName, const std::string &varName) { + auto pattern = new T(offset, size); + pattern->setTypeName(typeName); + pattern->setVariableName(varName); + + return pattern; + } + + virtual std::string getSourceCode() const = 0; + + [[nodiscard]] + virtual const std::vector& getPatterns() const final { return this->m_patterns; } + virtual void addPattern(pl::PatternData *pattern) final { + this->m_patterns.push_back(pattern); + } + + private: + std::vector m_patterns; + }; + +} \ No newline at end of file diff --git a/tests/include/test_patterns/test_pattern_example.hpp b/tests/include/test_patterns/test_pattern_example.hpp index d8ef065e2..40d66695a 100644 --- a/tests/include/test_patterns/test_pattern_example.hpp +++ b/tests/include/test_patterns/test_pattern_example.hpp @@ -1,30 +1,23 @@ #pragma once -#include -#include - -#include +#include "test_pattern.hpp" namespace hex::test { - class TestPattern { + class TestPatternExample : public TestPattern { public: - TestPattern() = default; - virtual TestPattern() { - for (auto &pattern : this->m_patterns) - delete pattern; - } + TestPatternExample() { - virtual std::string getSourceCode() = 0; + } + ~TestPatternExample() override = default; [[nodiscard]] - virtual const std::vector& getPatterns() const final { return this->m_patterns; } - virtual void addPattern(pl::PatternData *pattern) final { - this->m_patterns.push_back(pattern); + std::string getSourceCode() const override { + return R"( + + )"; } - private: - std::vector m_patterns; }; } \ No newline at end of file diff --git a/tests/include/test_patterns/test_pattern_placement.hpp b/tests/include/test_patterns/test_pattern_placement.hpp index 4a413e109..a78c36202 100644 --- a/tests/include/test_patterns/test_pattern_placement.hpp +++ b/tests/include/test_patterns/test_pattern_placement.hpp @@ -4,20 +4,29 @@ namespace hex::test { - class TestPatternExample : public TestPattern { + class TestPatternPlacement : public TestPattern { public: - TestPatternExample() { - auto placementTest = new pl::PatternDataSigned(0x00, sizeof(u32)); - placementTest->setTypeName("u32"); - placementTest->setVariableName("placementTest"); - addPattern(placementTest); + TestPatternPlacement() { + // placementVar + { + addPattern(createVariablePattern(0x00, sizeof(u32), "u32", "placementVar")); + } + + // placementArray + { + auto placementArray = createVariablePattern(0x10, sizeof(u8) * 10, "u8", "placementArray"); + placementArray->setEntries(createVariablePattern(0x10, sizeof(u8), "u8", ""), 10); + addPattern(placementArray); + } + } - ~TestPatternExample() override = default; + ~TestPatternPlacement() override = default; [[nodiscard]] std::string getSourceCode() const override { return R"( - u32 placementTest @ 0x00; + u32 placementVar @ 0x00; + u8 placementArray[10] @ 0x10; )"; } diff --git a/tests/include/test_provider.hpp b/tests/include/test_provider.hpp index ac255fc28..a88c22710 100644 --- a/tests/include/test_provider.hpp +++ b/tests/include/test_provider.hpp @@ -1,8 +1,49 @@ -// -// Created by werwo on 11/09/2021. -// +#include -#ifndef IMHEX_TEST_PROVIDER_HPP -#define IMHEX_TEST_PROVIDER_HPP +#include +#include +#include -#endif //IMHEX_TEST_PROVIDER_HPP +namespace hex::test { + using namespace hex::prv; + + class TestProvider : public prv::Provider { + public: + TestProvider() : Provider() { + this->m_testFile = File("test_data", File::Mode::Read); + if (!this->m_testFile.isValid()) { + hex::log::fatal("Failed to open test data!"); + throw std::runtime_error(""); + } + } + ~TestProvider() override = default; + + bool isAvailable() override { return true; } + bool isReadable() override { return true; } + bool isWritable() override { return false; } + bool isResizable() override { return false; } + bool isSavable() override { return false; } + + std::vector> getDataInformation() override { + return { }; + } + + void readRaw(u64 offset, void *buffer, size_t size) override { + this->m_testFile.seek(offset); + this->m_testFile.readBuffer(static_cast(buffer), size); + } + + void writeRaw(u64 offset, const void *buffer, size_t size) override { + this->m_testFile.seek(offset); + this->m_testFile.write(static_cast(buffer), size); + } + + size_t getActualSize() override { + return m_testFile.getSize(); + } + + private: + File m_testFile; + }; + +} \ No newline at end of file diff --git a/tests/source/main.cpp b/tests/source/main.cpp index 137fcc3ef..62f8578ab 100644 --- a/tests/source/main.cpp +++ b/tests/source/main.cpp @@ -1 +1,83 @@ -#include \ No newline at end of file +#include +#include +#include + +#include +#include +#include + +#include "test_provider.hpp" +#include "test_patterns/test_pattern_placement.hpp" + +using namespace hex::test; + +static std::map testPatterns { + { "Placement", new TestPatternPlacement() } +}; + +int main(int argc, char **argv) { + ON_SCOPE_EXIT { + for (auto &[key, value] : testPatterns) + delete value; + }; + + // Check if a test to run has been provided + if (argc != 2) { + hex::log::fatal("Invalid number of arguments specified! {}", argc); + return EXIT_FAILURE; + } + + // Check if that test exists + std::string testName = argv[1]; + if (!testPatterns.contains(testName)) { + hex::log::fatal("No test with name {} found!", testName); + return EXIT_FAILURE; + } + + const auto &currTest = testPatterns[testName]; + + auto provider = new TestProvider(); + ON_SCOPE_EXIT { delete provider; }; + if (provider->getActualSize() == 0) { + hex::log::fatal("Failed to load Testing Data"); + return EXIT_FAILURE; + } + + hex::pl::PatternLanguage language; + + // Check if compilation succeeded + auto patterns = language.executeString(provider, testPatterns[testName]->getSourceCode()); + if (!patterns.has_value()) { + hex::log::fatal("Error during compilation!"); + for (auto &[level, line] : language.getConsoleLog()) + hex::log::info("PL: {}", line); + + return EXIT_FAILURE; + } + + ON_SCOPE_EXIT { + for (auto &pattern : *patterns) + delete pattern; + }; + + // Check if the right number of patterns have been produced + if (patterns->size() != currTest->getPatterns().size()) { + hex::log::fatal("Source didn't produce expected number of patterns"); + return EXIT_FAILURE; + } + + // Check if the produced patterns are the ones expected + for (u32 i = 0; i < patterns->size(); i++) { + auto &left = *patterns->at(i); + auto &right = *currTest->getPatterns().at(i); + + if (left != right) { + hex::log::fatal("Pattern with name {}:{} didn't match template", patterns->at(i)->getTypeName(), patterns->at(i)->getVariableName()); + return EXIT_FAILURE; + } + } + + hex::log::info("Success!"); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/tests/test_data b/tests/test_data new file mode 100644 index 000000000..15ef05f10 Binary files /dev/null and b/tests/test_data differ