/// Lock-Free Allocator — test suite and benchmarks
///
/// Tests cover:
///   1. Size class mapping correctness
///   2. Single-threaded allocation and deallocation
///   3. Write-after-allocate (memory is usable)
///   4. Multi-threaded stress test
///   5. Cross-thread free (allocate on thread A, free on thread B)
///   6. Large allocations (> 32 KB, direct mmap path)
///   7. STL container integration via lfa::Allocator<T>
///   8. Throughput benchmark vs. malloc/free

#include "allocator.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <numeric>
#include <thread>
#include <vector>

// ---------------------------------------------------------------------------
// Minimal test harness
// ---------------------------------------------------------------------------

/// Count of tests that run_all_tests() has started.
static int g_tests_run{0};
/// Count of started tests whose function returned without throwing.
static int g_tests_passed{0};

/// Defines and self-registers a test case.
///
/// `TEST(foo) { ... }` expands to:
///   1. a forward declaration of `static void test_foo()`;
///   2. a file-local struct `Register_foo` whose constructor appends
///      `{"foo", test_foo}` to `test_registry()`;
///   3. a static instance `reg_foo` of that struct, so the registration
///      happens when that object is constructed;
///   4. the header of the definition of `test_foo`, so the braces that follow
///      the macro invocation become the test body.
///
/// `TestEntry` and `test_registry()` must be declared before the first use of
/// the macro (not before its definition), since they are only named in the
/// expansion.
#define TEST(name)                                                             \
    static void test_##name();                                                 \
    static struct Register_##name {                                            \
        Register_##name() { test_registry().push_back({#name, test_##name}); } \
    } reg_##name;                                                              \
    static void test_##name()

/// One registered test: its printable name and the function that runs it.
/// Kept as a plain aggregate so it can be brace-initialised as `{name, fn}`.
struct TestEntry {
    const char* name;
    void (*fn)();
};

/// Returns the single, process-wide list of registered tests.
///
/// The list is a function-local static, so it is constructed on first call
/// and every call yields a reference to the same vector.
static std::vector<TestEntry>& test_registry() {
    static std::vector<TestEntry> entries;
    return entries;
}

static void run_all_tests() {
    for (auto& [name, fn] : test_registry()) {
        ++g_tests_run;
        std::printf("  %-45s ", name);
        std::fflush(stdout);
        try {
            fn();
            ++g_tests_passed;
            std::printf("PASS\n");
        } catch (const std::exception& e) {
            std::printf("FAIL (%s)\n", e.what());
        } catch (...) {
            std::printf("FAIL (unknown exception)\n");
        }
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

TEST(size_class_mapping) {
    // The smallest class (index 0, 16-byte blocks) must absorb every request
    // from 1 through 16 bytes.
    std::size_t request = 1;
    while (request <= 16) {
        assert(lfa::size_class(request) == 0);
        assert(lfa::class_size(0) == 16);
        ++request;
    }

    // A request equal to a block size stays in that class; one byte more
    // moves to the next class.
    assert(lfa::size_class(17) == 1);   // 32-byte class
    assert(lfa::size_class(32) == 1);
    assert(lfa::size_class(33) == 2);   // 64-byte class
    assert(lfa::size_class(64) == 2);
    assert(lfa::size_class(65) == 3);   // 128-byte class

    // Largest size checked here: 32768 bytes <-> class 11.
    assert(lfa::size_class(32768) == 11);
    assert(lfa::class_size(11) == 32768);
}

TEST(single_thread_alloc_free) {
    // Allocate a batch of 64-byte blocks on one thread, then release them in
    // the same order they were obtained.
    constexpr int kBlockCount = 10000;
    std::vector<void*> blocks(kBlockCount);

    for (void*& slot : blocks) {
        slot = lfa::allocate(64);
        assert(slot != nullptr);
    }

    for (void* block : blocks)
        lfa::deallocate(block, 64);
}

TEST(write_after_allocate) {
    // For a spread of request sizes, fill the whole block with a known byte
    // and read every byte back before releasing it.
    for (const std::size_t bytes : {1, 16, 31, 64, 128, 255, 512, 4096, 32768}) {
        auto* const block = static_cast<char*>(lfa::allocate(bytes));
        assert(block != nullptr);

        std::memset(block, 0xAB, bytes);

        const char* const end = block + bytes;
        for (const char* cursor = block; cursor != end; ++cursor)
            assert(static_cast<unsigned char>(*cursor) == 0xAB);

        lfa::deallocate(block, bytes);
    }
}

TEST(zero_and_null_edge_cases) {
    // A zero-byte request must still yield a non-null pointer (the allocator
    // is expected to round the size up to 1 internally), and that pointer
    // must be releasable with the same size.
    void* const zero_sized = lfa::allocate(0);
    assert(zero_sized != nullptr);
    lfa::deallocate(zero_sized, 0);

    // Releasing a null pointer must be harmless.
    lfa::deallocate(nullptr, 64);
}

TEST(large_allocation) {
    // 1 MB is well above the 32 KB small-size limit (kMaxSmallSize), so this
    // request is expected to go through the direct mmap path.
    constexpr std::size_t kBytes = std::size_t{1} << 20;

    auto* const block = static_cast<char*>(lfa::allocate(kBytes));
    assert(block != nullptr);

    // Touch the full range, then spot-check the first and last byte.
    std::memset(block, 0xCD, kBytes);
    assert(static_cast<unsigned char>(block[0]) == 0xCD);
    assert(static_cast<unsigned char>(block[kBytes - 1]) == 0xCD);

    lfa::deallocate(block, kBytes);
}

TEST(multithread_stress) {
    // Several threads hammer the allocator concurrently with a single request
    // size. Roughly two of every three iterations allocate and the third
    // frees the most recent block. Every live block is filled with a sentinel
    // byte, and the sentinel is checked on *every* release, including the
    // blocks a thread still holds when its loop ends; those are the
    // longest-lived blocks and therefore the most exposed to corruption.
    constexpr int kThreads      = 8;
    constexpr int kOpsPerThread = 50000;
    constexpr std::size_t kSize = 128;

    // Shared failure counter; relaxed ordering suffices because it is only
    // read after all threads have been joined.
    std::atomic<int> errors{0};
    std::vector<std::thread> threads;
    threads.reserve(kThreads);

    for (int t = 0; t < kThreads; ++t) {
        threads.emplace_back([&errors] {
            // Checks the sentinel at both ends of a block, records a mismatch
            // in `errors`, and then returns the block to the allocator.
            const auto verify_and_free = [&errors](void* p) {
                const auto* cp = static_cast<const unsigned char*>(p);
                if (cp[0] != 0x42 || cp[kSize - 1] != 0x42)
                    errors.fetch_add(1, std::memory_order_relaxed);
                lfa::deallocate(p, kSize);
            };

            std::vector<void*> ptrs;
            ptrs.reserve(256);

            for (int i = 0; i < kOpsPerThread; ++i) {
                if (i % 3 != 0 || ptrs.empty()) {
                    // Allocate; a null result counts as an error.
                    void* p = lfa::allocate(kSize);
                    if (!p) { errors.fetch_add(1, std::memory_order_relaxed); continue; }
                    // Write a sentinel to detect corruption.
                    std::memset(p, 0x42, kSize);
                    ptrs.push_back(p);
                } else {
                    // Free the most recent allocation.
                    void* p = ptrs.back();
                    ptrs.pop_back();
                    verify_and_free(p);
                }
            }

            // Release whatever is still held, verifying it as well.
            for (void* p : ptrs)
                verify_and_free(p);
        });
    }

    for (auto& th : threads) th.join();
    assert(errors.load() == 0);
}

TEST(cross_thread_free) {
    // Blocks are obtained on the calling thread and handed to worker threads,
    // each of which releases its own contiguous slice.
    constexpr int N = 4000;
    constexpr int kWorkers = 4;
    constexpr std::size_t kSize = 256;

    std::vector<void*> blocks(N);
    for (void*& slot : blocks) {
        slot = lfa::allocate(kSize);
        assert(slot != nullptr);
        std::memset(slot, 0xFF, kSize);
    }

    // Even split; the final worker also takes any remainder.
    const int per_worker = N / kWorkers;
    std::vector<std::thread> workers;
    workers.reserve(kWorkers);

    for (int w = 0; w < kWorkers; ++w) {
        const int first = w * per_worker;
        const bool is_last = (w == kWorkers - 1);
        const int past_last = is_last ? N : first + per_worker;

        // Slices are disjoint and the vector is only read, so the workers
        // need no further synchronisation among themselves.
        workers.emplace_back([&blocks, first, past_last, kSize] {
            int idx = first;
            while (idx < past_last) {
                lfa::deallocate(blocks[idx], kSize);
                ++idx;
            }
        });
    }

    for (std::thread& worker : workers)
        worker.join();
}

TEST(mixed_size_classes) {
    // Interleave allocations of several request sizes, give each block its
    // own fill byte, and check that every block still holds its fill when it
    // is released. A block that overlaps a neighbour (in the same or another
    // size class) would have been overwritten by a later memset.
    constexpr int N = 2000;
    struct Entry { void* ptr; std::size_t size; };
    std::vector<Entry> entries;
    entries.reserve(N);

    std::size_t sizes[] = {8, 24, 64, 100, 512, 1024, 8192, 32768};

    for (int i = 0; i < N; ++i) {
        std::size_t sz = sizes[i % std::size(sizes)];
        void* p = lfa::allocate(sz);
        assert(p != nullptr);
        // Fill byte is the low 8 bits of the allocation index.
        std::memset(p, static_cast<int>(i & 0xFF), sz);
        entries.push_back({p, sz});
    }

    // Verify and free in reverse order. Mismatches are counted rather than
    // asserted per block so that every block is still released.
    int corrupted = 0;
    for (int i = N - 1; i >= 0; --i) {
        const auto* bytes = static_cast<const unsigned char*>(entries[i].ptr);
        const auto expected = static_cast<unsigned char>(i & 0xFF);
        const bool intact = std::all_of(
            bytes, bytes + entries[i].size,
            [expected](unsigned char b) { return b == expected; });
        if (!intact)
            ++corrupted;
        lfa::deallocate(entries[i].ptr, entries[i].size);
    }
    assert(corrupted == 0);
}

TEST(stl_vector_with_custom_allocator) {
    // Grow a std::vector backed by lfa::Allocator<int> one element at a time,
    // then confirm each element kept its value.
    constexpr int kCount = 10000;
    std::vector<int, lfa::Allocator<int>> values;

    int next = 0;
    while (next < kCount) {
        values.push_back(next);
        ++next;
    }

    for (int expected = 0; expected < kCount; ++expected)
        assert(values[expected] == expected);
}

TEST(thread_cache_flush_cycle) {
    // Allocate a large batch, free all of it, then allocate the same batch
    // again. The intent is to push the per-thread cache past its capacity so
    // blocks flow back to the global heap and are reused in the second round.
    constexpr int N = 1000;
    constexpr std::size_t kSize = 32;
    std::vector<void*> blocks(N);

    // Round one: allocate everything, then hand everything back.
    for (void*& slot : blocks) {
        slot = lfa::allocate(kSize);
        assert(slot != nullptr);
    }
    for (void* block : blocks)
        lfa::deallocate(block, kSize);

    // Round two: recycled blocks must be non-null and writable.
    for (void*& slot : blocks) {
        slot = lfa::allocate(kSize);
        assert(slot != nullptr);
        std::memset(slot, 0x55, kSize);
    }
    for (void* block : blocks)
        lfa::deallocate(block, kSize);
}

// ---------------------------------------------------------------------------
// Benchmark: lfa vs. malloc throughput
// ---------------------------------------------------------------------------

/// Times kThreads threads each performing kOpsPerThread allocate+free pairs of
/// kSize bytes, once through lfa::allocate/lfa::deallocate and once through
/// std::malloc/std::free, and prints wall-clock time and throughput for each.
///
/// The measured interval spans thread creation through the last join.
/// Throughput is reported in millions of operations per second:
/// total_ops / (milliseconds * 1000).
static void benchmark() {
    constexpr int kThreads      = 4;
    constexpr int kOpsPerThread  = 500000;
    constexpr std::size_t kSize  = 64;

    // Runs one timed pass with the given allocate/free callables.
    auto bench = [](const char* label, auto alloc_fn, auto free_fn) {
        auto start = std::chrono::steady_clock::now();
        std::vector<std::thread> threads;
        threads.reserve(kThreads);

        for (int t = 0; t < kThreads; ++t) {
            // Capturing by reference is safe: every thread is joined before
            // alloc_fn/free_fn go out of scope.
            threads.emplace_back([&] {
                for (int i = 0; i < kOpsPerThread; ++i) {
                    void* p = alloc_fn(kSize);
                    // Touch the block through a volatile pointer. Without an
                    // observable use, an optimising compiler may remove a
                    // matched malloc/free pair entirely, which would reduce
                    // the baseline row to timing an empty loop. The same
                    // one-byte store is applied to both allocators.
                    if (p != nullptr)
                        *static_cast<volatile unsigned char*>(p) = 0;
                    free_fn(p, kSize);
                }
            });
        }
        for (auto& th : threads) th.join();

        auto elapsed = std::chrono::steady_clock::now() - start;
        double ms = std::chrono::duration<double, std::milli>(elapsed).count();
        double mops = (static_cast<double>(kThreads) * kOpsPerThread) / (ms * 1000.0);
        std::printf("  %-12s %8.1f ms  (%6.2f M ops/sec)\n", label, ms, mops);
    };

    std::printf("\n  Benchmark: %d threads x %d alloc+free pairs (%zu bytes)\n\n",
                kThreads, kOpsPerThread, kSize);

    bench("lfa",
          [](std::size_t sz) { return lfa::allocate(sz); },
          [](void* p, std::size_t sz) { lfa::deallocate(p, sz); });

    bench("malloc",
          [](std::size_t sz) { return std::malloc(sz); },
          [](void* p, [[maybe_unused]] std::size_t sz) { std::free(p); });
}

// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------

/// Runs the registered tests, prints a summary, and runs the benchmark only
/// when every test passed.
///
/// @return 0 when all tests passed, 1 otherwise.
int main() {
#ifdef NDEBUG
    // Every check in this suite is an assert(); with NDEBUG they are all
    // compiled out and the tests would report PASS without checking anything.
    // Say so on stderr instead of silently reporting success.
    std::fprintf(stderr,
                 "warning: built with NDEBUG; assert()-based checks in this "
                 "suite are compiled out.\n");
#endif

    std::printf("Lock-Free Allocator — Tests\n\n");
    run_all_tests();

    std::printf("\n  %d/%d tests passed.\n", g_tests_passed, g_tests_run);

    const bool all_passed = (g_tests_passed == g_tests_run);
    if (!all_passed) {
        std::printf("\n  SOME TESTS FAILED.\n\n");
        return 1;
    }

    benchmark();
    std::printf("\n  All tests passed.\n\n");
    return 0;
}
