#!/usr/bin/env python3
"""v11.0 Phase 5 — fast-mode hot-path benchmark.

Measures `Store.save_knowledge` and `Recall.search` latency in MEMORY_MODE=fast
with the Ollama fallback explicitly disabled, then asserts that the in-process
telemetry counters report **zero** LLM/network calls. If a counter rises,
the bench exits non-zero and prints which counter spiked — that surfaces a
silent regression in the hot path immediately.

Outputs a Markdown table to stdout AND to `docs/v11/benchmark.md`.

Usage:
    python bin/memory-bench
"""

from __future__ import annotations

import os
import shutil
import statistics
import sys
import tempfile
import time
import uuid
from pathlib import Path


# v11 contract — set BEFORE importing server so resolve_mode_defaults() picks
# the right derived knobs at import time.
# The mode is pinned to "fast" and every LLM / Ollama / rerank / enrichment
# switch to the string "false". These assignments overwrite whatever the
# caller's shell exported, and they must stay above `import server` below.
os.environ["MEMORY_MODE"] = "fast"
os.environ["MEMORY_USE_LLM_IN_HOT_PATH"] = "false"
os.environ["MEMORY_ALLOW_OLLAMA_IN_HOT_PATH"] = "false"
os.environ["MEMORY_RERANK_ENABLED"] = "false"
os.environ["MEMORY_ENRICHMENT_ENABLED"] = "false"
os.environ["MEMORY_LLM_ENABLED"] = "false"

# Use a throw-away memory dir so the bench never pollutes a real install.
# mkdtemp() creates the directory immediately (at script import time) and never
# deletes it on its own; main() removes it with shutil.rmtree when it finishes.
_TMPDIR = Path(tempfile.mkdtemp(prefix="memory-bench-"))
os.environ["CLAUDE_MEMORY_DIR"] = str(_TMPDIR)

# src/ on the path
# ROOT is two levels above this file's resolved location, i.e. the script is
# expected to sit one directory below the repository root (e.g. bin/).
ROOT = Path(__file__).resolve().parent.parent
SRC = ROOT / "src"
# Prepend (not append) so the checkout's src/ wins over any installed copy,
# and skip the insert when it is already present.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

# Deliberately imported late: everything above has to be in place first.
import server  # noqa: E402  (after env setup)


# ─── helpers ──────────────────────────────────────────────────────────


def _percentile(values: list[float], pct: float) -> float:
    """Linear-interpolated percentile in milliseconds."""
    if not values:
        return 0.0
    if len(values) == 1:
        return values[0]
    s = sorted(values)
    rank = (pct / 100.0) * (len(s) - 1)
    lo = int(rank)
    hi = min(lo + 1, len(s) - 1)
    frac = rank - lo
    return s[lo] + (s[hi] - s[lo]) * frac


def _fmt_ms(value: float) -> str:
    return f"{value:.1f}"


def _section(label: str) -> None:
    print(f"\n=== {label} ===")


# ─── bench ────────────────────────────────────────────────────────────


def _run_benchmarks(
    store: "server.Store", recall: "server.Recall"
) -> list[tuple[str, list[float]]]:
    """Seed the store, warm it up, and time the four hot-path scenarios.

    Args:
        store: Freshly constructed store backed by the throw-away memory dir.
        recall: Search front-end built on top of ``store``.

    Returns:
        ``(metric name, per-call latencies in milliseconds)`` pairs, in the
        order they are listed in the results table.
    """
    # Seed a session row that save_knowledge expects.
    store.db.execute(
        "INSERT INTO sessions (id, started_at, project, status) "
        "VALUES (?, ?, ?, 'open')",
        ("bench-session", "2026-04-27T00:00:00Z", "bench"),
    )
    store.db.commit()

    _section("Warmup (load FastEmbed)")
    for i in range(3):
        # Only the first element of the return value (the record id) is used.
        rid, *_ = store.save_knowledge(
            sid="bench-session",
            content=f"warmup record {i} — load model into memory",
            ktype="fact", project="bench",
        )
        print(f"  warmup {i+1}/3 -> id={rid}")

    # Reset telemetry AFTER warmup so whatever the warmup saves recorded does
    # not count against the zero-LLM / zero-network check.
    store._perf_reset()

    # ── Bench A: 50 small unique saves ────────────────────────────────
    _section("Bench A — save_fast (50 unique small records)")
    save_a: list[float] = []
    for i in range(50):
        # The random hex keeps every payload distinct; it is built outside
        # the timed region.
        content = f"benchmark unique record {uuid.uuid4().hex} — payload {i}"
        t0 = time.perf_counter()
        store.save_knowledge(
            sid="bench-session", content=content, ktype="fact", project="bench",
        )
        save_a.append((time.perf_counter() - t0) * 1000.0)

    # ── Bench B: 50 saves of identical content (cache test) ───────────
    _section("Bench B — save_fast cached embedding (50x identical content)")
    same_content = "cached benchmark record — identical payload across attempts"
    save_b: list[float] = []
    # One untimed save first, so every timed iteration repeats content the
    # store has already seen.
    store.save_knowledge(
        sid="bench-session", content=same_content, ktype="fact", project="bench",
    )
    for _ in range(50):
        t0 = time.perf_counter()
        store.save_knowledge(
            sid="bench-session", content=same_content, ktype="fact", project="bench",
        )
        save_b.append((time.perf_counter() - t0) * 1000.0)

    # ── Bench C: 30 unique searches ──────────────────────────────────
    _section("Bench C — search_fast (30 unique queries)")
    search_c: list[float] = []
    queries = [
        "warmup record memory",
        "benchmark unique payload",
        "postgres autovacuum threshold",
        "redis cluster failover",
        "kubernetes operator pattern",
        "fastembed cache warm",
        "binary quantization hamming",
        "vector store hnsw cosine",
        "fts5 keyword bm25",
        "rrf fusion ranking",
    ]
    for i in range(30):
        # Cycle through the base queries; the iteration suffix makes each of
        # the 30 query strings distinct.
        q = queries[i % len(queries)] + f" iter{i}"
        t0 = time.perf_counter()
        recall.search(query=q, project="bench", limit=5)
        search_c.append((time.perf_counter() - t0) * 1000.0)

    # ── Bench D: 30 cached searches (same query) ─────────────────────
    _section("Bench D — cached_search (30x identical query)")
    same_query = "warmup record memory cached"
    search_d: list[float] = []
    # One untimed search first, so the timed loop only sees repeats.
    recall.search(query=same_query, project="bench", limit=5)
    for _ in range(30):
        t0 = time.perf_counter()
        recall.search(query=same_query, project="bench", limit=5)
        search_d.append((time.perf_counter() - t0) * 1000.0)

    return [
        ("save_fast", save_a),
        ("save_fast (cached)", save_b),
        ("search_fast", search_c),
        ("cached_search", search_d),
    ]


def _render_table(
    rows: list[tuple[str, list[float]]], llm_calls: int, net_calls: int
) -> str:
    """Render p50/p95/p99 per metric plus the two counters as a Markdown table."""
    md_lines = [
        "| metric              | p50 | p95 | p99 |",
        "|---------------------|-----|-----|-----|",
    ]
    for name, samples in rows:
        p50, p95, p99 = (_fmt_ms(_percentile(samples, q)) for q in (50, 95, 99))
        md_lines.append(f"| {name:<19s} | {p50:>3} | {p95:>3} | {p99:>3} |")
    md_lines.append(f"| llm_calls           | {llm_calls}                |")
    md_lines.append(f"| network_calls       | {net_calls}                |")
    return "\n".join(md_lines)


def _write_report(table: str) -> Path:
    """Write the report to ``docs/v11/benchmark.md`` under ROOT and return its path."""
    out_path = ROOT / "docs" / "v11" / "benchmark.md"
    out_path.parent.mkdir(parents=True, exist_ok=True)
    header = "# v11.0 hot-path benchmark\n\n"
    header += f"_Generated: {time.strftime('%Y-%m-%d %H:%M:%S %Z')}_\n\n"
    header += "MEMORY_MODE=fast, MEMORY_ALLOW_OLLAMA_IN_HOT_PATH=false.\n\n"
    # Explicit encoding so the report does not depend on the locale default.
    out_path.write_text(header + table + "\n", encoding="utf-8")
    return out_path


def main() -> int:
    """Run the fast-mode benchmark, publish the report, and check the counters.

    Prints progress and the Markdown table to stdout, writes the same table
    to ``docs/v11/benchmark.md``, and always removes the throw-away memory
    directory, including when a benchmark step raises.

    Returns:
        ``0`` when the telemetry snapshot reports zero ``llm_calls`` and zero
        ``network_calls``; ``1`` otherwise (details go to stderr).
    """
    _section("Setup")
    print(f"memory dir = {_TMPDIR}")
    print(f"mode       = {os.environ['MEMORY_MODE']}")
    print(f"llm hot    = {os.environ['MEMORY_USE_LLM_IN_HOT_PATH']}")
    print(f"ollama hot = {os.environ['MEMORY_ALLOW_OLLAMA_IN_HOT_PATH']}")

    store = None
    try:
        store = server.Store()
        recall = server.Recall(store)

        rows = _run_benchmarks(store, recall)

        # ── Counter sanity check ─────────────────────────────────────
        snap = store._perf_snapshot()
        # NOTE(review): a counter missing from the snapshot is read as 0, so a
        # renamed key would make the final check pass vacuously. Confirm the
        # key names against Store._perf_snapshot().
        llm_calls = int(snap.get("llm_calls", 0))
        net_calls = int(snap.get("network_calls", 0))

        _section("Telemetry snapshot")
        for k in sorted(snap):
            # Assumes string keys and numeric values. TODO confirm.
            print(f"  {k:<20s} = {snap[k]:.2f}")

        table = _render_table(rows, llm_calls, net_calls)

        _section("Results (markdown)")
        print(table)

        out_path = _write_report(table)
        print(f"\nWrote {out_path}")
    finally:
        # ── Cleanup tmpdir ────────────────────────────────────────────
        # Runs on success and on failure, so a crashed bench does not leave
        # the temp directory or the open database behind.
        if store is not None:
            try:
                store.db.close()
            except Exception as exc:
                # Best-effort: report the problem but still remove the dir.
                print(
                    f"warning: could not close bench database: {exc}",
                    file=sys.stderr,
                )
        shutil.rmtree(_TMPDIR, ignore_errors=True)

    # ── Assertion ─────────────────────────────────────────────────────
    if llm_calls != 0 or net_calls != 0:
        print(
            f"\nFAIL: hot path produced llm_calls={llm_calls} network_calls={net_calls}",
            file=sys.stderr,
        )
        return 1
    print("\nOK: hot path made 0 LLM calls and 0 network calls.")
    return 0


# Script entry point: run the bench and use main()'s return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
