#!/usr/bin/env python3
"""v11.0 Phase 8 — CI-style performance gate.

Runs `bin/memory-bench` in a subprocess with `MEMORY_MODE=fast`, parses the
markdown table emitted on stdout, and asserts every metric stays under its
threshold AND that `llm_calls == 0` AND `network_calls == 0`. On regression
the script exits non-zero and prints which metrics failed; on success it
exits 0 and prints a summary table.

Usage:
    bin/memory-perf-gate                # runs the live bench
    bin/memory-perf-gate --help         # CLI help
    python bin/memory-perf-gate         # same as the first form, via an explicit interpreter

Designed to be importable too — tests/test_v11_perf_gate.py drives the parse
and compare logic with synthetic input via `runpy.run_path`.
"""

from __future__ import annotations

import argparse
import os
import re
import subprocess
import sys
from pathlib import Path


# ─── Thresholds (P95 in milliseconds) ─────────────────────────────────


# Each pair is (metric name exactly as printed in the bench table's first
# column, P95 ceiling in milliseconds). evaluate_thresholds() walks this
# mapping in insertion order, so the order here is the order of the
# summary rows.
THRESHOLDS_P95_MS: dict[str, float] = dict(
    [
        ("save_fast", 50.0),
        ("save_fast (cached)", 5.0),
        ("search_fast", 200.0),
        ("cached_search", 20.0),
    ]
)

# Counter rows whose value must be exactly zero for the gate to pass.
REQUIRED_ZERO_COUNTERS: tuple[str, ...] = (
    "llm_calls",
    "network_calls",
)


# ─── Pure functions (testable without the bench) ──────────────────────


# A bench row looks like:
#   | save_fast           | 12.3 | 18.5 | 22.1 |
# A counter row looks like:
#   | llm_calls           | 0                |
# The pattern is meant to capture the first two `|`-delimited cells of such a
# row (name, first value), with surrounding whitespace trimmed.
# NOTE(review): nothing in this module appears to reference _ROW_PATTERN —
# parse_bench_table splits lines on `|` instead. It may be kept for external
# importers (e.g. tests that load this script via runpy); confirm before
# removing.
_ROW_PATTERN = re.compile(r"^\|\s*(.+?)\s*\|\s*(.+?)\s*\|")


def parse_bench_table(text: str) -> dict[str, dict[str, float | int]]:
    """Parse the markdown table emitted by bin/memory-bench.

    Returns a dict shaped like:
        {
          "save_fast":          {"p50": 12.3, "p95": 18.5, "p99": 22.1},
          "save_fast (cached)": {...},
          "search_fast":        {...},
          "cached_search":      {...},
          "llm_calls":          {"value": 0},
          "network_calls":      {"value": 0},
        }

    Skips header / separator rows. Tolerates extra rows (returns them too
    so the gate can ignore unknown metrics gracefully).
    """
    out: dict[str, dict[str, float | int]] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line.startswith("|"):
            continue
        # Header row (`| metric ...`) and separator (`|-----`) get filtered.
        if line.lower().startswith("| metric") or line.startswith("|-"):
            continue

        # Split into cells and trim — splitting on `|` then dropping leading
        # / trailing empties matches both 4-column metric rows AND 2-column
        # counter rows.
        cells = [c.strip() for c in line.split("|")]
        cells = [c for c in cells if c != ""]
        if len(cells) == 0:
            continue

        name = cells[0]
        rest = cells[1:]

        # Counter row: `| llm_calls | 0 |` → 1 numeric cell.
        if len(rest) == 1:
            try:
                out[name] = {"value": int(float(rest[0]))}
            except ValueError:
                continue
            continue

        # Metric row: at least 3 numeric cells (p50/p95/p99). Be lenient if a
        # bench variant adds a column.
        try:
            p50 = float(rest[0])
            p95 = float(rest[1])
            p99 = float(rest[2]) if len(rest) >= 3 else float("nan")
        except ValueError:
            continue
        out[name] = {"p50": p50, "p95": p95, "p99": p99}

    return out


def evaluate_thresholds(
    parsed: dict[str, dict[str, float | int]],
    *,
    thresholds: dict[str, float] | None = None,
    zero_counters: tuple[str, ...] | None = None,
) -> tuple[bool, list[str], list[str]]:
    """Return `(passed, failures, summary_lines)`.

    A failure entry is human-readable: which metric, what we measured, what
    we expected. `summary_lines` always returns one row per checked metric
    so callers can print a results table whether the gate passed or failed.
    """
    thr = thresholds if thresholds is not None else THRESHOLDS_P95_MS
    counters = zero_counters if zero_counters is not None else REQUIRED_ZERO_COUNTERS

    failures: list[str] = []
    summary: list[str] = []

    for metric, ceiling in thr.items():
        row = parsed.get(metric)
        if row is None:
            failures.append(
                f"missing metric {metric!r} in bench output"
            )
            summary.append(f"{metric:<22s}: MISSING (need p95 < {ceiling})")
            continue
        p95 = row.get("p95")
        if p95 is None:
            failures.append(
                f"metric {metric!r} has no p95 value in bench output"
            )
            summary.append(f"{metric:<22s}: NO_P95 (need p95 < {ceiling})")
            continue
        ok = float(p95) < float(ceiling)
        if not ok:
            failures.append(
                f"{metric}: p95={p95:.2f}ms exceeds threshold {ceiling:.2f}ms"
            )
        marker = "OK" if ok else "FAIL"
        summary.append(
            f"{metric:<22s}: p95={float(p95):>6.2f}ms < {ceiling:>5.1f}ms  [{marker}]"
        )

    for c in counters:
        row = parsed.get(c)
        if row is None:
            failures.append(f"missing counter {c!r} in bench output")
            summary.append(f"{c:<22s}: MISSING (need 0)")
            continue
        value = row.get("value")
        if value is None:
            failures.append(f"counter {c!r} has no value")
            summary.append(f"{c:<22s}: NO_VALUE (need 0)")
            continue
        ok = int(value) == 0
        if not ok:
            failures.append(f"{c}: got {value}, must be 0")
        marker = "OK" if ok else "FAIL"
        summary.append(
            f"{c:<22s}: value={int(value)} (need 0)  [{marker}]"
        )

    return (len(failures) == 0), failures, summary


# ─── Bench runner (subprocess) ────────────────────────────────────────


def run_bench(
    bench_path: Path | None = None,
    *,
    timeout_s: int = 300,
    max_attempts: int = 3,
) -> tuple[int, str, str]:
    """Run `bin/memory-bench` and return `(exit_code, stdout, stderr)`.

    The bench has a known flakiness window when the async enrichment worker
    is still ticking at process shutdown (sqlite3 commit-without-tx race
    against the worker's connection). We retry up to `max_attempts` times
    on non-zero exit before reporting failure. This mirrors what a CI
    runner would do anyway and keeps the gate signal honest.
    """
    here = Path(__file__).resolve()
    repo_root = here.parent.parent
    bench = bench_path or (repo_root / "bin" / "memory-bench")
    if not bench.is_file():
        return 127, "", f"bench script not found: {bench}"

    env = os.environ.copy()
    env["MEMORY_MODE"] = "fast"
    env["MEMORY_USE_LLM_IN_HOT_PATH"] = "false"
    env["MEMORY_ALLOW_OLLAMA_IN_HOT_PATH"] = "false"
    env["MEMORY_RERANK_ENABLED"] = "false"
    env["MEMORY_ENRICHMENT_ENABLED"] = "false"
    env["MEMORY_LLM_ENABLED"] = "false"
    # Silence the async enrichment worker for the bench duration — it races
    # against the bench process at shutdown.
    env["MEMORY_ASYNC_ENRICHMENT"] = "false"

    last_code = 1
    last_stdout = ""
    last_stderr = ""
    for attempt in range(1, max(1, max_attempts) + 1):
        proc = subprocess.run(
            [sys.executable, str(bench)],
            capture_output=True, text=True, timeout=timeout_s,
            env=env, cwd=str(repo_root),
        )
        last_code, last_stdout, last_stderr = proc.returncode, proc.stdout, proc.stderr
        if proc.returncode == 0:
            return last_code, last_stdout, last_stderr
        if attempt < max_attempts:
            print(
                f"  bench attempt {attempt}/{max_attempts} failed "
                f"(exit={proc.returncode}); retrying ...",
                file=sys.stderr,
            )
    return last_code, last_stdout, last_stderr


def gate_from_text(text: str) -> tuple[bool, list[str], list[str]]:
    """Parse raw bench stdout and evaluate it against the default gate.

    Thin convenience wrapper used by tests; returns the same
    `(passed, failures, summary_lines)` triple as `evaluate_thresholds`.
    """
    return evaluate_thresholds(parse_bench_table(text))


# ─── Entry point ──────────────────────────────────────────────────────


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(
        prog="memory-perf-gate",
        description="v11.0 perf gate — runs bin/memory-bench and validates "
                    "P95 thresholds + zero LLM/network call invariants.",
    )
    parser.add_argument(
        "--bench", type=Path, default=None,
        help="path to memory-bench (defaults to ./bin/memory-bench)",
    )
    parser.add_argument(
        "--timeout", type=int, default=300,
        help="subprocess timeout in seconds (default 300)",
    )
    args = parser.parse_args(argv)

    print("Running bin/memory-bench (MEMORY_MODE=fast) ...", flush=True)
    code, stdout, stderr = run_bench(args.bench, timeout_s=args.timeout)
    if code != 0:
        print(stdout, end="")
        print(stderr, end="", file=sys.stderr)
        print(
            f"\nFAIL: bench exited {code} — gate cannot evaluate.",
            file=sys.stderr,
        )
        return 1

    parsed = parse_bench_table(stdout)
    passed, failures, summary = evaluate_thresholds(parsed)

    print()
    print("v11 perf gate — summary")
    print("-" * 64)
    for line in summary:
        print(line)
    print("-" * 64)

    if not passed:
        print()
        print("FAIL: regression(s) detected:", file=sys.stderr)
        for f in failures:
            print(f"  - {f}", file=sys.stderr)
        return 1

    print("OK: every metric within threshold and counters at zero.")
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
