#pragma once
/// Lock-Free Memory Allocator (C++20)
///
/// Architecture:
///   - Per-thread caches provide a contention-free fast path.
///   - Global free lists (one per size class) use lock-free Treiber stacks
///     with tagged pointers for ABA prevention.
///   - Superblocks (64 KB mmap regions) are carved into fixed-size blocks.
///   - Large allocations (> 32 KB) go directly through mmap/munmap.
///
/// Size classes: 16, 32, 64, 128, ..., 32768 bytes (powers of two).

#include <array>
#include <atomic>
#include <bit>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <limits>
#include <new>

// Detect sanitizer builds.  Under TSan or ASan the shadow memory layout
// conflicts with raw mmap, so we fall back to aligned_alloc/free which the
// sanitizer runtime can intercept and instrument properly.
#if defined(__SANITIZE_THREAD__) || defined(__SANITIZE_ADDRESS__)
#   define LFA_USE_MALLOC_BACKEND 1
#elif defined(__has_feature)
#   if __has_feature(thread_sanitizer) || __has_feature(address_sanitizer)
#       define LFA_USE_MALLOC_BACKEND 1
#   endif
#endif

#ifndef LFA_USE_MALLOC_BACKEND
#   define LFA_USE_MALLOC_BACKEND 0
#   include <sys/mman.h>
#endif

namespace lfa {

// ============================================================================
// Configuration
// ============================================================================

inline constexpr std::size_t kMinBlockSize       = 16;      // Smallest class; must hold a FreeBlock* link.
inline constexpr std::size_t kMaxSmallSize       = 32768;   // 32 KB
inline constexpr std::size_t kSuperblockSize     = 65536;   // 64 KB
inline constexpr std::size_t kNumSizeClasses     = 12;      // log2(32768/16) + 1
inline constexpr std::size_t kBatchTransferSize  = 32;      // Max blocks scavenged global->local per refill.
inline constexpr std::size_t kThreadCacheMaxSize = 256;     // Per size class
inline constexpr std::size_t kPageSize           = 4096;    // Alignment for the sanitizer backend and large-alloc rounding.

// ============================================================================
// Platform memory backend
// ============================================================================

/// Allocate `size` bytes of page-aligned, zero-filled memory from the OS (or
/// from libc when building under a sanitizer).  Callers pass multiples of
/// kPageSize, which keeps the aligned_alloc contract satisfied as well.
/// Returns nullptr on failure.
inline void* os_alloc(std::size_t size) noexcept {
#if LFA_USE_MALLOC_BACKEND
    // aligned_alloc lets the TSan/ASan runtime intercept and track the region.
    void* mem = std::aligned_alloc(kPageSize, size);
    if (mem != nullptr) {
        std::memset(mem, 0, size);  // Emulate mmap's zero-fill guarantee.
    }
    return mem;
#else
    void* mem = ::mmap(nullptr, size, PROT_READ | PROT_WRITE,
                       MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    if (mem == MAP_FAILED) {
        return nullptr;
    }
    return mem;
#endif
}

/// Release memory obtained from os_alloc.  `size` must be the exact value
/// passed to os_alloc (munmap needs the mapping length; the libc backend
/// ignores it).
inline void os_free(void* ptr, [[maybe_unused]] std::size_t size) noexcept {
#if LFA_USE_MALLOC_BACKEND
    std::free(ptr);
#else
    // A munmap failure would mean a bogus pointer/size pair; nothing can be
    // done about it in a noexcept deallocation path, so the result is
    // deliberately discarded.
    const int rc = ::munmap(ptr, size);
    (void)rc;
#endif
}

// ============================================================================
// Size class utilities
// ============================================================================

/// Map a byte size to its size class index (0..kNumSizeClasses-1).
inline constexpr std::size_t size_class(std::size_t size) noexcept {
    if (size <= kMinBlockSize) return 0;
    std::size_t rounded = std::bit_ceil(size);
    auto cls = static_cast<std::size_t>(
        std::countr_zero(rounded) - std::countr_zero(kMinBlockSize));
    assert(cls < kNumSizeClasses);
    return cls;
}

/// Map a size class index back to its block size in bytes.
/// Inverse of size_class in the sense that class_size(size_class(s)) >= s.
inline constexpr std::size_t class_size(std::size_t cls) noexcept {
    return kMinBlockSize * (std::size_t{1} << cls);
}

// ============================================================================
// FreeBlock — intrusive list node embedded inside free memory
// ============================================================================

/// Overlaid on every free block: while a block sits on a free list (thread
/// cache or global stack), its first sizeof(FreeBlock*) bytes hold the link to
/// the next free block.  Once handed to the user, the payload overwrites it.
/// Requires kMinBlockSize >= sizeof(FreeBlock).
struct FreeBlock {
    FreeBlock* next;  // Next node in the intrusive singly-linked list.
};

// ============================================================================
// TaggedPtr — 48-bit pointer + 16-bit ABA counter in a single 64-bit word
//
// On x86-64, user-space virtual addresses occupy at most 48 bits (canonical
// form).  We pack a monotonic tag in the upper 16 bits so that every CAS
// observes a unique value even if a pointer is recycled.  The 16-bit tag can
// wrap, but 65 536 concurrent ABA collisions on the exact same cache line
// between a load and its CAS is not a realistic scenario.
// ============================================================================

struct TaggedPtr {
    // Bit layout: [ 16-bit tag | 48-bit pointer ].
    static constexpr unsigned       kTagShift = 48;
    static constexpr std::uintptr_t kPtrMask  = (std::uintptr_t{1} << kTagShift) - 1;

    std::uintptr_t bits = 0;

    TaggedPtr() noexcept = default;

    /// Pack pointer and ABA counter into one word.
    TaggedPtr(FreeBlock* p, std::uint16_t tag) noexcept {
        bits = reinterpret_cast<std::uintptr_t>(p) |
               (static_cast<std::uintptr_t>(tag) << kTagShift);
    }

    /// Recover the pointer from the low 48 bits.
    FreeBlock* ptr() const noexcept {
        return reinterpret_cast<FreeBlock*>(bits & kPtrMask);
    }

    /// Recover the ABA counter from the high 16 bits.
    std::uint16_t tag() const noexcept {
        return static_cast<std::uint16_t>(bits >> kTagShift);
    }

    bool operator==(const TaggedPtr&) const noexcept = default;
};

static_assert(sizeof(TaggedPtr) == sizeof(std::uintptr_t));
static_assert(std::atomic<TaggedPtr>::is_always_lock_free,
              "TaggedPtr must be lock-free for the allocator to be lock-free");

// ============================================================================
// LockFreeStack — Treiber stack with tagged-pointer ABA prevention
// ============================================================================

/// Treiber stack: a lock-free LIFO used as one size class's global free list.
/// Every mutation is a single CAS on head_; the 16-bit tag inside TaggedPtr is
/// incremented on each successful update, so a stale head snapshot never wins
/// a CAS after the list has changed underneath it (ABA prevention).
class LockFreeStack {
public:
    /// Push a single block onto the stack.
    void push(FreeBlock* node) noexcept {
        TaggedPtr old_head = head_.load(std::memory_order_relaxed);
        TaggedPtr new_head;
        do {
            // Re-link every iteration: a failed CAS refreshes old_head.
            node->next = old_head.ptr();
            new_head = TaggedPtr{node, static_cast<std::uint16_t>(old_head.tag() + 1)};
        } while (!head_.compare_exchange_weak(
            old_head, new_head,
            // release on success publishes node->next (and the block's prior
            // contents) to whichever thread later pops this node; relaxed on
            // failure — we only retry with the refreshed snapshot.
            std::memory_order_release, std::memory_order_relaxed));
    }

    /// Pop a single block.  Returns nullptr if the stack is empty.
    FreeBlock* pop() noexcept {
        TaggedPtr old_head = head_.load(std::memory_order_acquire);
        TaggedPtr new_head;
        do {
            if (!old_head.ptr()) return nullptr;
            // Reading old_head.ptr()->next is safe: the memory backing every
            // block is always mapped (superblocks are never returned to the OS
            // during normal operation).  The tagged pointer prevents the ABA
            // problem from corrupting the list even if the node is concurrently
            // popped and reinserted by another thread.
            new_head = TaggedPtr{old_head.ptr()->next,
                                 static_cast<std::uint16_t>(old_head.tag() + 1)};
        } while (!head_.compare_exchange_weak(
            old_head, new_head,
            // acquire is required on both paths so the ->next read above is
            // ordered after the observed head; acq_rel on success is the
            // conservative choice for the read-modify-write.
            std::memory_order_acq_rel, std::memory_order_acquire));
        return old_head.ptr();
    }

    /// Push a linked batch (first -> ... -> last -> *).
    /// The caller must ensure last->next is unused; it will be overwritten.
    void push_batch(FreeBlock* first, FreeBlock* last) noexcept {
        TaggedPtr old_head = head_.load(std::memory_order_relaxed);
        TaggedPtr new_head;
        do {
            // Splice the whole chain in front of the current head in one CAS.
            last->next = old_head.ptr();
            new_head = TaggedPtr{first, static_cast<std::uint16_t>(old_head.tag() + 1)};
        } while (!head_.compare_exchange_weak(
            old_head, new_head,
            std::memory_order_release, std::memory_order_relaxed));
    }

private:
    std::atomic<TaggedPtr> head_{};
};

// ============================================================================
// SuperblockHeader — bookkeeping at the start of each mmap'd region
// ============================================================================

/// Bookkeeping stored at byte 0 of every superblock.  alignas(64) pads the
/// header to a full cache line, keeping it off the first carved block's line.
/// NOTE(review): blocks are never returned to their superblock, so
/// block_size/num_blocks are currently informational only.
struct alignas(64) SuperblockHeader {
    std::size_t       block_size;   // Byte size of each block carved from this region.
    std::size_t       num_blocks;   // How many blocks were carved.
    SuperblockHeader* next;         // Global tracking list
};

// ============================================================================
// GlobalHeap — singleton owning the per-class free lists and superblocks
// ============================================================================

/// Process-wide singleton owning the per-class lock-free free lists and the
/// tracking list of every superblock ever mapped.
class GlobalHeap {
public:
    /// Meyers singleton; initialization is thread-safe since C++11.
    static GlobalHeap& instance() noexcept {
        static GlobalHeap heap;
        return heap;
    }

    /// Pop one block from the global free list for the given size class.
    FreeBlock* pop(std::size_t cls) noexcept {
        return free_lists_[cls].pop();
    }

    /// Push one block back to the global free list.
    void push(std::size_t cls, FreeBlock* block) noexcept {
        free_lists_[cls].push(block);
    }

    /// Push a linked batch back to the global free list.
    void push_batch(std::size_t cls, FreeBlock* first, FreeBlock* last) noexcept {
        free_lists_[cls].push_batch(first, last);
    }

    /// Allocate a new superblock, carve it into blocks of the requested class,
    /// and return the linked list.  Sets `count` to the number of blocks.
    /// Returns nullptr (count == 0) when the OS refuses memory or the class's
    /// block size cannot fit in a superblock alongside the header.
    FreeBlock* allocate_superblock(std::size_t cls, std::size_t& count) {
        std::size_t bsz = class_size(cls);

        void* mem = os_alloc(kSuperblockSize);
        if (!mem) { count = 0; return nullptr; }

        // First block starts at the next block-aligned offset after the header.
        std::size_t offset = ((sizeof(SuperblockHeader) + bsz - 1) / bsz) * bsz;

        // Guard BEFORE tracking the superblock.  If the rounded header were to
        // consume the whole region, `kSuperblockSize - offset` would underflow
        // (size_t), and the previous code additionally linked the unusable
        // superblock into superblocks_ before noticing, leaking it.
        // Unreachable with the current constants (bsz <= 32 KB < 64 KB), but
        // cheap insurance against future configuration changes.
        if (offset >= kSuperblockSize || (kSuperblockSize - offset) / bsz == 0) {
            os_free(mem, kSuperblockSize);
            count = 0;
            return nullptr;
        }

        std::size_t n = (kSuperblockSize - offset) / bsz;
        char* base = static_cast<char*>(mem) + offset;

        // Place the header at the beginning of the superblock.
        auto* hdr = static_cast<SuperblockHeader*>(mem);
        hdr->block_size = bsz;
        hdr->num_blocks = n;

        // Track this superblock in a lock-free, prepend-only linked list.
        SuperblockHeader* old = superblocks_.load(std::memory_order_relaxed);
        do {
            hdr->next = old;
        } while (!superblocks_.compare_exchange_weak(
            old, hdr,
            std::memory_order_release, std::memory_order_relaxed));

        // Thread the free list through the carved blocks.
        auto* first = reinterpret_cast<FreeBlock*>(base);
        FreeBlock* prev = first;
        for (std::size_t i = 1; i < n; ++i) {
            auto* curr = reinterpret_cast<FreeBlock*>(base + i * bsz);
            prev->next = curr;
            prev = curr;
        }
        prev->next = nullptr;
        count = n;
        return first;
    }

    GlobalHeap(const GlobalHeap&)            = delete;
    GlobalHeap& operator=(const GlobalHeap&) = delete;

    // Intentionally leak superblocks.  Unmapping memory that thread-local
    // caches might still reference during static destruction is unsafe.
    // The OS reclaims everything when the process exits.
    ~GlobalHeap() = default;

private:
    GlobalHeap() = default;

    std::array<LockFreeStack, kNumSizeClasses> free_lists_;
    std::atomic<SuperblockHeader*> superblocks_{nullptr};
};

// ============================================================================
// ThreadCache — per-thread fast path (no synchronization required)
// ============================================================================

/// Per-thread allocation cache.  Each thread owns exactly one ThreadCache (see
/// `tl_cache`), so both fast paths touch only thread-local state and require
/// no synchronization.  Blocks move between this cache and the GlobalHeap in
/// batches to amortize contention on the shared lock-free stacks.
class ThreadCache {
public:
    /// Allocate `size` bytes.  Returns nullptr only if the OS refuses memory.
    void* allocate(std::size_t size) {
        if (size > kMaxSmallSize) return allocate_large(size);

        std::size_t cls = size_class(size);
        auto& bin = bins_[cls];

        // Fast path: pop from the thread-local list.
        if (bin.head) {
            FreeBlock* b = bin.head;
            bin.head = b->next;
            --bin.count;
            return b;
        }

        // Slow path: refill from the global heap.
        return refill(cls);
    }

    /// Return `size` bytes starting at `ptr`.  `size` must equal the value
    /// passed at allocation time — it selects the size class, since small
    /// blocks carry no per-block header.
    void deallocate(void* ptr, std::size_t size) noexcept {
        if (size > kMaxSmallSize) { deallocate_large(ptr); return; }

        std::size_t cls = size_class(size);
        auto& bin = bins_[cls];

        auto* b = static_cast<FreeBlock*>(ptr);
        b->next = bin.head;
        bin.head = b;
        ++bin.count;

        // If the local cache is too full, flush half back to the global heap.
        if (bin.count > kThreadCacheMaxSize) flush(cls);
    }

    ~ThreadCache() {
        // Return every cached block to the global heap on thread exit so other
        // threads can reuse the memory.
        for (std::size_t cls = 0; cls < kNumSizeClasses; ++cls)
            flush_all(cls);
    }

private:
    struct Bin {
        FreeBlock*  head  = nullptr;  // Intrusive list of cached free blocks.
        std::size_t count = 0;        // Length of that list.
    };

    std::array<Bin, kNumSizeClasses> bins_{};

    // ------------------------------------------------------------------
    // Slow-path helpers
    // ------------------------------------------------------------------

    /// Refill the thread-local bin from the global heap.  Returns one block
    /// for the caller; surplus blocks land in the local bin.
    void* refill(std::size_t cls) {
        GlobalHeap& heap = GlobalHeap::instance();
        auto& bin = bins_[cls];

        // Try to scavenge blocks already in the global free list.
        FreeBlock* result = nullptr;
        for (std::size_t i = 0; i < kBatchTransferSize; ++i) {
            FreeBlock* b = heap.pop(cls);
            if (!b) break;
            if (!result) {
                result = b;            // First block goes to the caller.
            } else {
                b->next  = bin.head;   // Remainder fills the local cache.
                bin.head = b;
                ++bin.count;
            }
        }
        if (result) return result;

        // Global list is empty — carve a fresh superblock.
        std::size_t n = 0;
        FreeBlock* list = heap.allocate_superblock(cls, n);
        if (!list) return nullptr;

        // First block for the caller; rest into the local cache.
        result = list;
        FreeBlock* rest = list->next;
        while (rest) {
            FreeBlock* nxt = rest->next;
            rest->next = bin.head;
            bin.head   = rest;
            ++bin.count;
            rest = nxt;
        }
        return result;
    }

    /// Flush half the bin back to the global heap (keeps the hot half local).
    void flush(std::size_t cls) {
        auto& bin = bins_[cls];
        std::size_t to_flush = bin.count / 2;
        if (to_flush == 0) return;

        // Detach the first `to_flush` nodes as one chain.
        FreeBlock* first = bin.head;
        FreeBlock* last  = first;
        for (std::size_t i = 1; i < to_flush; ++i) last = last->next;

        bin.head   = last->next;
        bin.count -= to_flush;
        last->next = nullptr;

        GlobalHeap::instance().push_batch(cls, first, last);
    }

    /// Flush the entire bin back to the global heap.
    void flush_all(std::size_t cls) {
        auto& bin = bins_[cls];
        if (!bin.head) return;

        FreeBlock* first = bin.head;
        FreeBlock* last  = first;
        while (last->next) last = last->next;

        bin.head  = nullptr;
        bin.count = 0;

        GlobalHeap::instance().push_batch(cls, first, last);
    }

    // ------------------------------------------------------------------
    // Large allocations — direct mmap with a prepended size header
    // ------------------------------------------------------------------

    // alignas fixes a latent alignment bug: without it sizeof(LargeHeader) is
    // sizeof(std::size_t), so `hdr + 1` handed out pointers with only 8-byte
    // alignment, below the fundamental alignment (alignof(std::max_align_t),
    // typically 16) that callers of a general allocator may rely on.  The
    // alignas pads the header so the user pointer lands max-aligned (the
    // header itself sits at a page boundary — see os_alloc).
    struct alignas(alignof(std::max_align_t)) LargeHeader {
        std::size_t total_size;   // Length of the whole mapping, header included.
    };

    static std::size_t round_to_page(std::size_t n) noexcept {
        return (n + kPageSize - 1) & ~(kPageSize - 1);
    }

    /// Map a dedicated region for a > kMaxSmallSize request.
    static void* allocate_large(std::size_t size) {
        std::size_t total = round_to_page(sizeof(LargeHeader) + size);
        void* mem = os_alloc(total);
        if (!mem) return nullptr;
        auto* hdr = static_cast<LargeHeader*>(mem);
        hdr->total_size = total;
        return hdr + 1;   // User data begins right after the (padded) header.
    }

    /// Unmap a region produced by allocate_large.
    static void deallocate_large(void* ptr) noexcept {
        auto* hdr = static_cast<LargeHeader*>(ptr) - 1;
        os_free(hdr, hdr->total_size);
    }
};

// ============================================================================
// Public API
// ============================================================================

/// Thread-local cache instance (one per thread, automatically created).
inline thread_local ThreadCache tl_cache;

/// Allocate at least `size` bytes of memory.  Returns nullptr on failure.
/// A zero-byte request is promoted to one byte so each successful call yields
/// a distinct, non-null pointer.
[[nodiscard]] inline void* allocate(std::size_t size) {
    return tl_cache.allocate(size != 0 ? size : 1);
}

/// Deallocate memory previously returned by allocate().
/// The caller must pass the same `size` used at allocation time.
inline void deallocate(void* ptr, std::size_t size) noexcept {
    if (ptr == nullptr) return;
    tl_cache.deallocate(ptr, size != 0 ? size : 1);
}

// ============================================================================
// STL-compatible allocator adapter
// ============================================================================

/// Minimal C++17-style allocator adapter over the lock-free heap.  All
/// instances share the global heap, so any Allocator can free memory obtained
/// from any other (hence the unconditional operator==).
template <typename T>
class Allocator {
public:
    using value_type = T;

    Allocator() noexcept = default;
    template <typename U> Allocator(const Allocator<U>&) noexcept {}

    /// Allocate storage for `n` objects of type T.
    /// @throws std::bad_alloc on overflow of the byte count or out-of-memory.
    [[nodiscard]] T* allocate(std::size_t n) {
        // Without this guard, n * sizeof(T) wraps for huge n and the heap
        // would hand back a block far smaller than requested.
        if (n > std::numeric_limits<std::size_t>::max() / sizeof(T))
            throw std::bad_alloc();
        if (auto* p = static_cast<T*>(lfa::allocate(n * sizeof(T))))
            return p;
        throw std::bad_alloc();
    }

    /// Return storage for `n` objects previously obtained from allocate(n).
    void deallocate(T* p, std::size_t n) noexcept {
        lfa::deallocate(p, n * sizeof(T));
    }

    template <typename U>
    bool operator==(const Allocator<U>&) const noexcept { return true; }
};

} // namespace lfa
