#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <openssl/md5.h>

#ifdef _WIN32
#include <filesystem>
#include <windows.h>
namespace fs = std::filesystem;

/// Returns the full path of the running executable, or an empty string if it
/// cannot be determined.
///
/// The ANSI entry point is called explicitly because the result is a narrow
/// std::string; the GetModuleFileName macro resolves to the wide variant when
/// UNICODE is defined and would not accept a char buffer.
/// NOTE(review): paths containing characters outside the active ANSI code page
/// may not round-trip through the A API — switch to the W API if that matters.
std::string GetExecutablePath()
{
    // Upper bound for an extended-length Windows path (in characters).
    constexpr DWORD kMaxPathChars = 32768;

    std::vector<char> buffer(MAX_PATH);
    for (;;) {
        const DWORD capacity = static_cast<DWORD>(buffer.size());
        const DWORD length = GetModuleFileNameA(nullptr, buffer.data(), capacity);
        if (length == 0) {
            return "";
        }
        // A return value equal to the buffer size means the path was
        // truncated; anything smaller is the complete path.
        if (length < capacity) {
            return std::string(buffer.data(), length);
        }
        if (capacity >= kMaxPathChars) {
            return "";
        }
        buffer.resize(buffer.size() * 2);
    }
}
#endif

/// Renders `length` bytes starting at `hash` as lowercase hexadecimal, two
/// characters per byte, most significant nibble first (e.g. {0x0a, 0xff} -> "0aff").
std::string to_hex_string(const unsigned char* hash, size_t length)
{
    static const char kDigits[] = "0123456789abcdef";

    std::string hex;
    hex.reserve(length * 2);
    const unsigned char* const end = hash + length;
    for (const unsigned char* byte = hash; byte != end; ++byte) {
        hex.push_back(kDigits[*byte >> 4]);
        hex.push_back(kDigits[*byte & 0x0F]);
    }
    return hex;
}

std::string md5_string(const std::string& input)
{
    unsigned char hash[MD5_DIGEST_LENGTH];
    MD5(reinterpret_cast<const unsigned char*>(input.c_str()), input.size(),
        hash);
    return to_hex_string(hash, MD5_DIGEST_LENGTH);
}

/// Computes the MD5 digest of a file's contents, read in binary mode.
///
/// @param filename Path of the file to hash.
/// @return The digest as lowercase hexadecimal (32 characters).
/// @throws std::runtime_error if the file cannot be opened, or if reading
///         stops for any reason other than reaching end of file.
///
/// NOTE(review): MD5_Init/MD5_Update/MD5_Final are deprecated since
/// OpenSSL 3.0 in favour of the EVP_Digest* API.
std::string md5_file(const std::string& filename)
{
    std::ifstream file(filename, std::ios::binary);
    if (!file) {
        throw std::runtime_error("Cannot open file: " + filename);
    }

    MD5_CTX ctx;
    MD5_Init(&ctx);

    std::vector<char> buffer(8192);
    // The final short read sets eofbit|failbit, but gcount() still reports the
    // bytes it delivered, so those are hashed before the loop ends.
    while (file.good()) {
        file.read(buffer.data(), static_cast<std::streamsize>(buffer.size()));
        const std::streamsize got = file.gcount();
        if (got > 0) {
            MD5_Update(&ctx, buffer.data(), static_cast<size_t>(got));
        }
    }

    // EOF is the only acceptable way out of the loop. badbit (or stopping
    // without EOF) means the underlying read failed, and the digest would
    // silently cover only a prefix of the file.
    if (file.bad() || !file.eof()) {
        throw std::runtime_error("Error reading file: " + filename);
    }

    unsigned char hash[MD5_DIGEST_LENGTH];
    MD5_Final(hash, &ctx);

    return to_hex_string(hash, MD5_DIGEST_LENGTH);
}

/// Entry point: `<program> -s <text>` hashes a string, `<program> -f <path>`
/// hashes a file. Prints the digest to stdout; returns 1 on usage or
/// hashing errors.
int main(int argc, char* argv[])
{
#ifdef _WIN32
    // Add <exe dir>\mdsha_dll to the DLL search path. This only affects DLLs
    // resolved after this point (presumably delay-loaded OpenSSL DLLs — verify
    // against the build configuration). Skip the call when the executable path
    // is unknown: an empty path would produce a directory relative to the
    // current working directory, which is a DLL-hijacking risk.
    const std::string exe_path = GetExecutablePath();
    if (!exe_path.empty()) {
        const std::string load_dll_dir =
            fs::path(exe_path).parent_path().append("mdsha_dll").string();
        // Explicit ANSI variant: the argument is a narrow string, and the
        // SetDllDirectory macro maps to the wide variant under UNICODE.
        SetDllDirectoryA(load_dll_dir.c_str());
    }
#endif

    // argv[0] may be null when the process is started with an empty argv.
    const char* program = (argc > 0 && argv[0] != nullptr) ? argv[0] : "md5";

    if (argc != 3) {
        std::cerr << "Usage: " << program << " <mode> <input>" << std::endl;
        std::cerr << "Support DelayLoad SubDirectory:mdsha_dll" << std::endl;
        std::cerr << "  -s : Calculate MD5 for a string" << std::endl;
        std::cerr << "  -f : Calculate MD5 for a file" << std::endl;
        return 1;
    }

    const std::string mode = argv[1];
    const std::string input = argv[2];

    try {
        if (mode == "-s") {
            std::cout << "MD5(string): " << md5_string(input) << std::endl;
        } else if (mode == "-f") {
            std::cout << "MD5(file): " << md5_file(input) << std::endl;
        } else {
            std::cerr << "Invalid mode: " << mode << std::endl;
            return 1;
        }
    } catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return 1;
    }

    return 0;
}