#!/usr/bin/env python3
"""Focused aft_search hybrid-fusion quality benchmark.

Runs the existing in-tree golden fixtures plus an exact-identifier stress set,
then applies bench-only offline rerankers to the same candidate lists. The
rerankers are intentionally outside production code: they make the suspected
fusion weakness measurable without changing `crates/aft`.
"""

from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

from run import (  # type: ignore
    AftClient,
    AftProtocolError,
    binary_sha256,
    binary_version,
    git_rev,
    normalize_result_path,
    percentile,
)

# Alias for a JSON-style mapping with string keys and arbitrary values.
JsonObject = Dict[str, Any]

# Default for --top-k: the cut-off used by the "top{k}" pass metric.
DEFAULT_TOP_K = 5
# Default for --wide-top-k: how many results are requested per query to form
# the offline rerank candidate pool.
DEFAULT_WIDE_TOP_K = 100
# Default for --ready-timeout, in seconds.
BENCH_READY_TIMEOUT_SECS = 600.0
# Constant added to each lane rank in the reciprocal-rank-fusion reranker.
RRF_K = 60.0
# Fixture shapes for which the "identifier_lexical_uncapped" reranker applies.
IDENTIFIER_LIKE_SHAPES = {"identifier", "path", "error-code"}

# Exclude benchmark/report artifacts from the bench-only exact-match oracle so
# the query strings in fixture/result/report files do not become self-fulfilling
# exact hits. The production lexical index is broader; this oracle is only for
# comparing rerank policies against the intended corpus files.
# Directory names pruned (at any depth) while walking the project tree.
EXACT_SCAN_EXCLUDED_DIRS = {
    ".bench",
    ".git",
    ".hg",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
    ".venv",
    "coverage",
    "dist",
    "node_modules",
    "target",
}
# Project-relative POSIX path prefixes whose files are never scanned.
EXACT_SCAN_EXCLUDED_PREFIXES = (
    ".alfonso/",
    "benchmarks/aft-search/results/",
)
# Individual project-relative files that are never scanned.
EXACT_SCAN_EXCLUDED_FILES = {
    "benchmarks/aft-search/baseline.json",
    "benchmarks/aft-search/external-fixtures.json",
    "benchmarks/aft-search/fixtures.json",
    "benchmarks/aft-search/identifier-fusion-fixtures.json",
}
# Files larger than this many bytes are skipped by the exact-match scan.
MAX_EXACT_SCAN_BYTES = 2_000_000

# Rerank policies evaluated for every fixture, in reporting order. Each name
# must be one that rerank() knows how to dispatch.
CANDIDATES = (
    "current",
    "rrf_exact_lane",
    "exact_identifier_first",
    "identifier_lexical_uncapped",
)


def parse_args(argv: Sequence[str]) -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument strings, excluding the program name.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description=__doc__)

    # (flag, default, help) for the plain string-valued options, in the order
    # they should appear in --help.
    text_options = (
        ("--binary", "../../target/release/aft", "Path to the aft binary to measure."),
        ("--project-root", "../..", "Project root to configure aft against."),
        ("--existing-fixtures", "fixtures.json", "Existing golden fixture JSON."),
        (
            "--identifier-fixtures",
            "identifier-fusion-fixtures.json",
            "Focused exact-identifier fixture JSON.",
        ),
        ("--out", "results/search-fusion-quality.json", "Detailed JSON output path."),
        ("--summary", "results/search-fusion-quality-summary.tsv", "TSV summary output path."),
    )
    for flag, default, help_text in text_options:
        parser.add_argument(flag, default=default, help=help_text)

    # (flag, converter, default, help) for the numeric options.
    numeric_options = (
        ("--top-k", int, DEFAULT_TOP_K, "Top-k threshold for pass metrics."),
        (
            "--wide-top-k",
            int,
            DEFAULT_WIDE_TOP_K,
            "Wide aft_search result count used as the offline rerank candidate pool.",
        ),
        (
            "--ready-timeout",
            float,
            BENCH_READY_TIMEOUT_SECS,
            "Seconds to wait for search and semantic indexes.",
        ),
    )
    for flag, converter, default, help_text in numeric_options:
        parser.add_argument(flag, type=converter, default=default, help=help_text)

    return parser.parse_args(list(argv))


def script_dir() -> Path:
    """Return the absolute directory that contains this script file."""
    this_file = Path(__file__).resolve()
    return this_file.parent


def resolve_script_path(value: str) -> Path:
    """Turn a CLI path argument into a concrete path.

    Absolute paths are returned as given; relative paths are interpreted
    relative to the script's own directory (not the working directory) and
    resolved.
    """
    candidate = Path(value)
    if not candidate.is_absolute():
        candidate = (script_dir() / candidate).resolve()
    return candidate


def display_path(path: str | Path) -> str:
    """Render a path for reports, abbreviating the home directory as ``~/``.

    If the resolved path does not live under the home directory, or it cannot
    be resolved, the input is returned unchanged as a string.
    """
    value = Path(path)
    try:
        resolved = value.expanduser().resolve()
        below_home = resolved.relative_to(Path.home())
    except (OSError, ValueError):
        return str(path)
    return "~/" + below_home.as_posix()


def ensure_ort_env() -> JsonObject:
    """Make sure ``ORT_DYLIB_PATH`` is set, probing known locations if needed.

    A non-empty existing value is left untouched. Otherwise the first existing
    library file among a fixed list of candidate locations is exported into
    ``os.environ``.

    Returns:
        A dict with ``source`` (``"env"``, ``"auto"`` or ``"unresolved"``) and
        ``ort_dylib_path`` (display form of the path, or ``None``).
    """
    preset = os.environ.get("ORT_DYLIB_PATH")
    if preset:
        return {"source": "env", "ort_dylib_path": display_path(preset)}

    if sys.platform == "darwin":
        lib_name = "libonnxruntime.dylib"
    else:
        lib_name = "libonnxruntime.so"

    # Probe order matters: the first existing file wins.
    search_dirs = (
        Path.home() / ".local/share/cortexkit/aft/onnxruntime/1.24.4",
        Path.home() / "Library/Application Support/cortexkit/aft/onnxruntime/1.24.4",
        Path("/opt/homebrew/lib"),
        Path("/usr/local/lib"),
    )
    for directory in search_dirs:
        library = directory / lib_name
        if not library.exists():
            continue
        os.environ["ORT_DYLIB_PATH"] = str(library)
        return {"source": "auto", "ort_dylib_path": display_path(library)}

    return {"source": "unresolved", "ort_dylib_path": None}


def load_fixtures(path: Path, suite: str) -> List[JsonObject]:
    """Load and validate a fixture file, tagging every entry with ``suite``.

    Args:
        path: JSON file containing an array of fixture objects. Each object
            needs ``query``, ``shape`` and a non-empty ``expected_top_files``
            list; ``id`` and ``notes`` are optional.
        suite: Suite label stored on every returned fixture and used to build
            a default id (``"{suite}:{query}"``) when ``id`` is absent/empty.

    Returns:
        Normalized fixture dicts with keys ``suite``, ``id``, ``query``,
        ``shape``, ``expected_top_files`` and ``notes``.

    Raises:
        ValueError: If the document is not an array, an item is not an object,
            a required key is missing, ``expected_top_files`` is not a
            non-empty list, or two fixtures share an id.
    """
    # JSON interchange is UTF-8; do not depend on the locale's default codec.
    data = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(data, list):
        raise ValueError(f"{path} must contain a JSON array")
    fixtures: List[JsonObject] = []
    seen = set()
    for index, raw in enumerate(data, start=1):
        if not isinstance(raw, dict):
            raise ValueError(f"{path} item {index} must be an object")
        for key in ("query", "shape", "expected_top_files"):
            if key not in raw:
                raise ValueError(f"{path} item {index} missing {key}")
        query = str(raw["query"])
        fixture_id = str(raw.get("id") or f"{suite}:{query}")
        if fixture_id in seen:
            raise ValueError(f"duplicate fixture id in {path}: {fixture_id}")
        seen.add(fixture_id)
        raw_expected = raw["expected_top_files"]
        # A bare string would otherwise be iterated character by character and
        # silently produce nonsense "files".
        if not isinstance(raw_expected, list):
            raise ValueError(f"fixture {fixture_id} expected_top_files must be a list")
        expected = [str(item) for item in raw_expected]
        if not expected:
            raise ValueError(f"fixture {fixture_id} expected_top_files must be non-empty")
        fixtures.append(
            {
                "suite": suite,
                "id": fixture_id,
                "query": query,
                "shape": str(raw["shape"]),
                "expected_top_files": expected,
                "notes": str(raw.get("notes", "")),
            }
        )
    return fixtures


def missing_expected_files(fixtures: Sequence[JsonObject], project_root: Path) -> List[JsonObject]:
    """List expected fixture files that do not exist under ``project_root``.

    Returns:
        One ``{"suite", "id", "file"}`` dict per missing file, in fixture
        order and then in ``expected_top_files`` order.
    """
    return [
        {"suite": fixture["suite"], "id": fixture["id"], "file": rel_path}
        for fixture in fixtures
        for rel_path in fixture["expected_top_files"]
        if not (project_root / rel_path).exists()
    ]


def should_scan_file(rel_path: str, path: Path) -> bool:
    """Decide whether a file belongs in the exact-match scan corpus.

    Args:
        rel_path: Project-relative POSIX path, matched against the exclusion
            lists.
        path: Filesystem path used for the size check.

    Returns:
        ``False`` for excluded files/prefixes, files that cannot be stat'ed,
        and files larger than ``MAX_EXACT_SCAN_BYTES``; ``True`` otherwise.
    """
    is_excluded = rel_path in EXACT_SCAN_EXCLUDED_FILES or rel_path.startswith(
        tuple(EXACT_SCAN_EXCLUDED_PREFIXES)
    )
    if is_excluded:
        return False
    try:
        size = path.stat().st_size
    except OSError:
        # Unreadable or vanished entries are simply left out of the corpus.
        return False
    return size <= MAX_EXACT_SCAN_BYTES


def load_exact_scan_corpus(project_root: Path) -> List[Tuple[str, str]]:
    """Read every scannable text file under ``project_root`` into memory.

    Excluded directories are pruned during the walk, files rejected by
    ``should_scan_file`` or unreadable files are skipped, and files with a NUL
    byte in their first 4096 bytes are treated as binary and skipped.

    Returns:
        ``(project-relative POSIX path, text)`` pairs. Text is decoded as
        UTF-8 with undecodable bytes dropped.
    """
    corpus: List[Tuple[str, str]] = []
    for dirpath, subdirs, filenames in os.walk(project_root):
        # In-place assignment so os.walk does not descend into pruned dirs.
        subdirs[:] = [name for name in subdirs if name not in EXACT_SCAN_EXCLUDED_DIRS]
        base = Path(dirpath)
        for filename in filenames:
            file_path = base / filename
            try:
                rel_path = file_path.relative_to(project_root).as_posix()
            except ValueError:
                continue
            if not should_scan_file(rel_path, file_path):
                continue
            try:
                raw = file_path.read_bytes()
            except OSError:
                continue
            head = raw[:4096]
            if b"\0" in head:
                continue
            # Lenient decoding yields exactly the strict result for valid
            # UTF-8 and drops only the offending bytes otherwise.
            corpus.append((rel_path, raw.decode("utf-8", errors="ignore")))
    return corpus


def exact_counts(query: str, corpus: Sequence[Tuple[str, str]]) -> Dict[str, int]:
    """Count non-overlapping, case-sensitive occurrences of ``query`` per file.

    Returns:
        Mapping of file path to occurrence count, containing only files with
        at least one occurrence. An empty query yields an empty mapping.
    """
    if not query:
        return {}
    occurrences = ((rel_path, text.count(query)) for rel_path, text in corpus)
    return {rel_path: hits for rel_path, hits in occurrences if hits > 0}


def round_float(value: Any, digits: int = 6) -> Optional[float]:
    """Round a numeric value to ``digits`` decimals; return ``None`` otherwise.

    Any ``int`` or ``float`` (which includes ``bool``) is converted to
    ``float`` before rounding. Every other type maps to ``None``.
    """
    return round(float(value), digits) if isinstance(value, (int, float)) else None


def response_candidates(response: JsonObject, project_root: Path) -> List[JsonObject]:
    """Convert a search response into normalized candidate dicts.

    Non-dict entries in ``results`` are skipped but still consume a rank, so
    ``rank`` always reflects the 1-based position in the raw response.

    Raises:
        AftProtocolError: If ``results`` is present but not a list.
    """
    results = response.get("results") or []
    if not isinstance(results, list):
        raise AftProtocolError(f"semantic_search returned non-list results: {response}")

    candidates: List[JsonObject] = []
    for position, entry in enumerate(results):
        if not isinstance(entry, dict):
            continue
        rel_file = normalize_result_path(str(entry.get("file", "")), project_root)
        candidate: JsonObject = {
            "rank": position + 1,
            "file": rel_file,
            "name": entry.get("name"),
            "kind": entry.get("kind"),
            "score": round_float(entry.get("score")),
            # Results without an explicit source are attributed to "semantic".
            "source": entry.get("source", "semantic"),
            "semantic_score": round_float(entry.get("semantic_score")),
            "lexical_score": round_float(entry.get("lexical_score")),
            "hybrid_boosted": bool(entry.get("hybrid_boosted")),
        }
        candidates.append(candidate)
    return candidates


def candidate_pool(candidates: Sequence[JsonObject], counts: Dict[str, int]) -> List[JsonObject]:
    """Build the rerank pool from returned candidates plus offline exact hits.

    Every returned candidate is copied and annotated with ``exact_count``
    (0 when its file has no exact match). Files that have exact matches but
    were not returned are appended afterwards, sorted by path, as synthetic
    ``offline_exact`` entries with ``rank`` set to ``None``.
    """
    pool: List[JsonObject] = []
    returned_files = set()
    for candidate in candidates:
        rel_file = str(candidate["file"])
        returned_files.add(rel_file)
        annotated = dict(candidate)
        annotated["exact_count"] = counts.get(rel_file, 0)
        pool.append(annotated)

    for rel_file in sorted(counts):
        if rel_file in returned_files:
            continue
        pool.append(
            {
                "rank": None,
                "file": rel_file,
                "name": "",
                "kind": "FileSummary",
                "score": 0.0,
                "source": "offline_exact",
                "semantic_score": None,
                "lexical_score": None,
                "hybrid_boosted": False,
                "exact_count": counts[rel_file],
            }
        )
    return pool


def current_order(pool: Sequence[JsonObject], original_len: int) -> List[JsonObject]:
    """Return shallow copies of the first ``original_len`` pool entries.

    Those entries are the candidates as returned by the search, so this is the
    unmodified ("current") ordering without the appended offline entries.
    """
    returned = pool[:original_len]
    return list(map(dict, returned))


def semantic_rank_map(pool: Sequence[JsonObject]) -> Dict[str, int]:
    """Reconstruct 1-based semantic-lane ranks for pool entries.

    Only entries with a non-``None`` ``semantic_score`` take part. They are
    ordered by descending semantic score, then original rank (missing ranks
    last), then file and name. When several entries share a result key, the
    best (first) rank is kept.
    """

    def lane_order(item: JsonObject) -> Tuple[float, int, str, str]:
        return (
            -float(item.get("semantic_score") or 0.0),
            int(item.get("rank") or 1_000_000),
            str(item.get("file") or ""),
            str(item.get("name") or ""),
        )

    scored = sorted(
        (item for item in pool if item.get("semantic_score") is not None),
        key=lane_order,
    )
    ranks: Dict[str, int] = {}
    for position, item in enumerate(scored, start=1):
        key = result_key(item)
        if key not in ranks:
            ranks[key] = position
    return ranks


def lexical_strength(item: JsonObject) -> float:
    """Return the lexical-lane strength of a pool entry.

    The reported ``lexical_score`` is used as-is unless the entry has exact
    matches, in which case the strength is at least ``1 + log1p(exact_count)``.
    """
    reported = float(item.get("lexical_score") or 0.0)
    hits = int(item.get("exact_count") or 0)
    if hits == 0:
        return reported
    # Exact files that were not present in the returned fused list still get
    # a lane rank. log1p keeps repeated examples from swamping exactness.
    return max(reported, 1.0 + math.log1p(hits))


def lexical_rank_map(pool: Sequence[JsonObject]) -> Dict[str, int]:
    """Reconstruct 1-based lexical-lane ranks for pool entries.

    Only entries with positive ``lexical_strength`` take part. They are
    ordered by descending strength, then original rank (missing ranks last),
    then file and name. When several entries share a result key, the best
    (first) rank is kept.
    """

    def lane_order(item: JsonObject) -> Tuple[float, int, str, str]:
        return (
            -lexical_strength(item),
            int(item.get("rank") or 1_000_000),
            str(item.get("file") or ""),
            str(item.get("name") or ""),
        )

    matched = sorted(
        (item for item in pool if lexical_strength(item) > 0.0),
        key=lane_order,
    )
    ranks: Dict[str, int] = {}
    for position, item in enumerate(matched, start=1):
        key = result_key(item)
        if key not in ranks:
            ranks[key] = position
    return ranks


def result_key(item: JsonObject) -> str:
    """Build the identity key of a result: file, name and kind joined by NUL.

    Missing or falsy fields contribute an empty string.
    """
    parts = (str(item.get(field) or "") for field in ("file", "name", "kind"))
    return "\0".join(parts)


def rerank_rrf(pool: Sequence[JsonObject], shape: str) -> List[JsonObject]:
    """Rerank the pool by reciprocal-rank fusion of two reconstructed lanes.

    Each entry scores ``1 / (RRF_K + rank)`` for every lane (semantic,
    lexical) it appears in. Ties are broken by higher exact-match count, then
    original rank (missing ranks last), then file path.

    Args:
        pool: Candidate pool entries.
        shape: Fixture shape; accepted for signature parity with the other
            rerankers and not used here.

    Returns:
        Copies of the pool entries with ``rerank_score``, ``semantic_rank``
        and ``lexical_rank`` added, best first.
    """
    semantic_lane = semantic_rank_map(pool)
    lexical_lane = lexical_rank_map(pool)

    def fuse(item: JsonObject) -> JsonObject:
        key = result_key(item)
        sem_rank = semantic_lane.get(key)
        lex_rank = lexical_lane.get(key)
        fused = 0.0
        if sem_rank is not None:
            fused += 1.0 / (RRF_K + sem_rank)
        if lex_rank is not None:
            fused += 1.0 / (RRF_K + lex_rank)
        return dict(item, rerank_score=round(fused, 9), semantic_rank=sem_rank, lexical_rank=lex_rank)

    def final_order(item: JsonObject) -> Tuple[float, int, int, str]:
        return (
            -float(item["rerank_score"]),
            -int(item.get("exact_count") or 0),
            int(item.get("rank") or 1_000_000),
            str(item.get("file") or ""),
        )

    return sorted([fuse(item) for item in pool], key=final_order)


def rerank_exact_identifier_first(pool: Sequence[JsonObject], shape: str) -> List[JsonObject]:
    """Sort exact-match files ahead of everything else for identifier queries.

    Args:
        pool: Candidate pool entries. Entries whose ``rank`` is ``None`` are
            the synthetic offline exact-match files that the search itself did
            not return.
        shape: Fixture shape. Only ``"identifier"`` triggers the guardrail.

    Returns:
        Copies of the relevant pool entries, best first. For other shapes this
        policy is a no-op, so the result is the order the search returned.
    """
    if shape != "identifier":
        # Pass-through must reproduce the search's own result list. Offline
        # exact-match entries were never returned by the search, so keeping
        # them here would let an "unchanged" policy score hits the real
        # ranking does not contain.
        returned = [dict(item) for item in pool if item.get("rank") is not None]
        returned.sort(key=lambda item: int(item["rank"]))
        return returned
    ranked = [dict(item) for item in pool]
    ranked.sort(
        key=lambda item: (
            # Exact-match files first; original rank decides within each group.
            0 if int(item.get("exact_count") or 0) > 0 else 1,
            int(item.get("rank") or 1_000_000),
            str(item.get("file") or ""),
            str(item.get("name") or ""),
        )
    )
    return ranked


def rerank_identifier_lexical_uncapped(pool: Sequence[JsonObject], shape: str) -> List[JsonObject]:
    """Rerank identifier-like queries as if lexical scores were not capped.

    For shapes in ``IDENTIFIER_LIKE_SHAPES`` each entry's ``rerank_score`` is
    the larger of its fused ``score`` and ``0.8 * lexical_strength``; entries
    are then sorted by that score, original rank (missing ranks last), file
    and name.

    Args:
        pool: Candidate pool entries. Entries whose ``rank`` is ``None`` are
            the synthetic offline exact-match files that the search itself did
            not return.
        shape: Fixture shape.

    Returns:
        Copies of the relevant pool entries, best first. For other shapes this
        policy is a no-op, so the result is the order the search returned.
    """
    if shape not in IDENTIFIER_LIKE_SHAPES:
        # Pass-through must reproduce the search's own result list. Offline
        # exact-match entries were never returned by the search, so keeping
        # them here would let an "unchanged" policy score hits the real
        # ranking does not contain.
        returned = [dict(item) for item in pool if item.get("rank") is not None]
        returned.sort(key=lambda item: int(item["rank"]))
        return returned
    ranked: List[JsonObject] = []
    for item in pool:
        current_score = float(item.get("score") or 0.0)
        uncapped_lexical = lexical_strength(item) * 0.8
        ranked.append(dict(item, rerank_score=round(max(current_score, uncapped_lexical), 9)))
    ranked.sort(
        key=lambda item: (
            -float(item.get("rerank_score") or 0.0),
            int(item.get("rank") or 1_000_000),
            str(item.get("file") or ""),
            str(item.get("name") or ""),
        )
    )
    return ranked


def rerank(pool: Sequence[JsonObject], original_len: int, shape: str, candidate: str) -> List[JsonObject]:
    """Apply the named rerank policy to the candidate pool.

    Args:
        pool: Candidate pool entries.
        original_len: Number of leading pool entries that came from the
            search response; only used by the ``"current"`` policy.
        shape: Fixture shape forwarded to the shape-aware policies.
        candidate: Policy name.

    Raises:
        ValueError: If ``candidate`` is not a known policy name.
    """
    if candidate == "current":
        return current_order(pool, original_len)

    if candidate == "rrf_exact_lane":
        policy = rerank_rrf
    elif candidate == "exact_identifier_first":
        policy = rerank_exact_identifier_first
    elif candidate == "identifier_lexical_uncapped":
        policy = rerank_identifier_lexical_uncapped
    else:
        raise ValueError(f"unknown candidate: {candidate}")
    return policy(pool, shape)


def first_match_rank(order: Sequence[JsonObject], expected_files: Iterable[str]) -> Optional[int]:
    """Return the 1-based position of the first expected file, if any.

    Returns:
        The rank of the first entry in ``order`` whose ``file`` is one of
        ``expected_files``, or ``None`` when no entry matches.
    """
    wanted = set(expected_files)
    matching_positions = (
        position
        for position, item in enumerate(order, start=1)
        if str(item.get("file")) in wanted
    )
    return next(matching_positions, None)


def slim_top(order: Sequence[JsonObject], top_k: int) -> List[JsonObject]:
    """Return a compact, re-ranked view of the first ``top_k`` entries.

    ``rank`` is the 1-based position in ``order`` (not the original search
    rank). Missing fields become ``None``, except ``exact_count`` which
    defaults to 0. Any other keys on the entries are dropped.
    """
    copied_fields = (
        "file",
        "name",
        "kind",
        "score",
        "source",
        "semantic_score",
        "lexical_score",
    )
    slimmed: List[JsonObject] = []
    for position, item in enumerate(order[:top_k], start=1):
        entry: JsonObject = {"rank": position}
        for field in copied_fields:
            entry[field] = item.get(field)
        entry["exact_count"] = item.get("exact_count", 0)
        entry["rerank_score"] = item.get("rerank_score")
        slimmed.append(entry)
    return slimmed


def evaluate_one(
    client: AftClient,
    fixture: JsonObject,
    project_root: Path,
    exact_corpus: Sequence[Tuple[str, str]],
    top_k: int,
    wide_top_k: int,
) -> JsonObject:
    """Run one fixture query and score every rerank policy on its results.

    Args:
        client: Client used to issue the search.
        fixture: Normalized fixture (``suite``, ``id``, ``query``, ``shape``,
            ``expected_top_files``, optional ``notes``).
        project_root: Root used to normalize result paths.
        exact_corpus: ``(path, text)`` pairs for the offline exact-match scan.
        top_k: Cut-off for the ``top{k}`` pass metric and the stored results.
        wide_top_k: Number of results requested from the search.

    Returns:
        A per-fixture report including search latency, per-policy rerank
        latency and per-policy ``first_match_rank``/``rank1``/``top{k}``/``mrr``.

    Raises:
        AftProtocolError: If the search reports failure or is not ready.
    """
    query = fixture["query"]
    shape = fixture["shape"]
    expected_files = fixture["expected_top_files"]

    response, latency_ms = client.semantic_search(query, top_k=wide_top_k)
    if response.get("success") is False:
        raise AftProtocolError(f"semantic_search failed for {query!r}: {response}")
    if response.get("status") != "ready":
        raise AftProtocolError(f"semantic_search not ready for {query!r}: {response}")

    candidates = response_candidates(response, project_root)
    counts = exact_counts(query, exact_corpus)
    pool = candidate_pool(candidates, counts)
    returned_count = len(candidates)
    top_key = f"top{top_k}"

    candidate_results: JsonObject = {}
    rerank_latencies: JsonObject = {}
    for candidate in CANDIDATES:
        # Time only the rerank itself, in milliseconds.
        started = time.perf_counter()
        order = rerank(pool, returned_count, shape, candidate)
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        rerank_latencies[candidate] = round(elapsed_ms, 6)

        rank = first_match_rank(order, expected_files)
        within_top_k = rank is not None and rank <= top_k
        candidate_results[candidate] = {
            "first_match_rank": rank,
            "rank1": 1.0 if rank == 1 else 0.0,
            top_key: 1.0 if within_top_k else 0.0,
            "mrr": round(1.0 / rank, 6) if rank else 0.0,
            "top_results": slim_top(order, top_k),
        }

    # Expected files that the offline scan confirms contain the query verbatim.
    exact_expected_files = sorted(set(counts).intersection(expected_files))
    return {
        "suite": fixture["suite"],
        "id": fixture["id"],
        "query": query,
        "shape": shape,
        "expected_top_files": expected_files,
        "notes": fixture.get("notes", ""),
        "latency_ms": round(latency_ms, 3),
        "rerank_latency_ms": rerank_latencies,
        "exact_match_file_count": len(counts),
        "exact_expected_files": exact_expected_files,
        "candidates_returned": returned_count,
        "candidate_results": candidate_results,
    }


def aggregate(rows: Sequence[JsonObject], suite: str, candidate: str, top_k: int) -> JsonObject:
    """Summarize one rerank policy over one suite of per-fixture rows.

    Args:
        rows: Per-fixture reports as produced by ``evaluate_one``.
        suite: Suite to select, or ``"all"`` for every row.
        candidate: Policy name whose metrics are averaged.
        top_k: Cut-off that names the ``top{k}`` metric.

    Returns:
        Mean ``rank1``/``top{k}``/``mrr`` rates plus p50/p95 of query and
        rerank latency. All metrics are 0.0 when no row is selected.
    """
    metric_key = f"top{top_k}"
    rate_key = f"top{top_k}_rate"
    selected = [row for row in rows if suite == "all" or row["suite"] == suite]
    total = len(selected)

    if total == 0:
        return {
            "suite": suite,
            "candidate": candidate,
            "fixtures": 0,
            "rank1_rate": 0.0,
            rate_key: 0.0,
            "mrr": 0.0,
            "query_latency_ms_p50": 0.0,
            "query_latency_ms_p95": 0.0,
            "rerank_latency_ms_p50": 0.0,
            "rerank_latency_ms_p95": 0.0,
        }

    def mean_of(metric: str) -> float:
        values = (float(row["candidate_results"][candidate][metric]) for row in selected)
        return round(sum(values) / total, 6)

    query_latencies = [float(row["latency_ms"]) for row in selected]
    policy_latencies = [float(row["rerank_latency_ms"].get(candidate, 0.0)) for row in selected]
    return {
        "suite": suite,
        "candidate": candidate,
        "fixtures": total,
        "rank1_rate": mean_of("rank1"),
        rate_key: mean_of(metric_key),
        "mrr": mean_of("mrr"),
        "query_latency_ms_p50": percentile(query_latencies, 50),
        "query_latency_ms_p95": percentile(query_latencies, 95),
        "rerank_latency_ms_p50": percentile(policy_latencies, 50),
        "rerank_latency_ms_p95": percentile(policy_latencies, 95),
    }


def aggregate_by_suite(rows: Sequence[JsonObject], top_k: int) -> List[JsonObject]:
    """Aggregate every policy for the identifier, existing and combined suites.

    Returns:
        One summary per (suite, policy) pair, grouped by suite in the order
        ``identifier``, ``existing``, ``all`` and by policy in ``CANDIDATES``
        order within each suite.
    """
    summaries: List[JsonObject] = []
    for suite in ("identifier", "existing", "all"):
        for candidate in CANDIDATES:
            summaries.append(aggregate(rows, suite, candidate, top_k))
    return summaries


def write_json(path: Path, value: JsonObject) -> None:
    """Write ``value`` as indented, key-sorted JSON with a trailing newline.

    Missing parent directories are created first.
    """
    output_dir = path.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    document = json.dumps(value, indent=2, sort_keys=True)
    path.write_text(document + "\n")


def write_summary(path: Path, aggregates: Sequence[JsonObject], top_k: int) -> None:
    """Write the aggregate rows as a tab-separated table with a header line.

    Missing parent directories are created first. Columns absent from a row
    are written as empty cells.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    columns = (
        "suite",
        "candidate",
        "fixtures",
        "rank1_rate",
        f"top{top_k}_rate",
        "mrr",
        "query_latency_ms_p50",
        "query_latency_ms_p95",
        "rerank_latency_ms_p50",
        "rerank_latency_ms_p95",
    )
    table = ["\t".join(columns)]
    table.extend(
        "\t".join(str(row.get(column, "")) for column in columns)
        for row in aggregates
    )
    # Every line, including the last, is newline-terminated.
    path.write_text("".join(line + "\n" for line in table))


def print_summary(aggregates: Sequence[JsonObject], top_k: int) -> None:
    """Print a fixed-width table of the aggregate metrics to stdout."""
    rate_key = f"top{top_k}_rate"
    print("\nFusion quality summary")
    print("  suite        candidate                       n  R@1    R@%d   MRR    rerank p95 ms" % top_k)
    for row in aggregates:
        identity = f"  {row['suite']:<12} {row['candidate']:<30} {row['fixtures']:>2}  "
        quality = f"{row['rank1_rate']:.3f}  {row[rate_key]:.3f}  {row['mrr']:.3f}  "
        latency = f"{row['rerank_latency_ms_p95']:.4f}"
        print(identity + quality + latency)


def main(argv: Sequence[str]) -> int:
    """Run the fusion-quality benchmark end to end.

    Loads both fixture suites, builds the offline exact-match corpus, queries
    the aft binary for every fixture, aggregates the per-policy metrics, and
    writes the detailed JSON report and the TSV summary.

    Args:
        argv: Command-line arguments, excluding the program name.

    Returns:
        Process exit code (0 on success).

    Raises:
        FileNotFoundError: If the aft binary does not exist.
    """
    args = parse_args(argv)
    ort_info = ensure_ort_env()
    project_root = Path(args.project_root).resolve()
    binary = Path(args.binary).resolve()
    if not binary.exists():
        raise FileNotFoundError(f"aft binary not found: {binary}")

    existing_fixtures = load_fixtures(resolve_script_path(args.existing_fixtures), "existing")
    identifier_fixtures = load_fixtures(resolve_script_path(args.identifier_fixtures), "identifier")
    # The focused identifier suite is evaluated before the existing suite.
    fixtures = identifier_fixtures + existing_fixtures
    missing = missing_expected_files(fixtures, project_root)
    exact_corpus = load_exact_scan_corpus(project_root)

    client = AftClient(binary, project_root, args.ready_timeout)
    try:
        client.configure()
        status = client.wait_for_indexes(require_search=True)
        version_response = client.call("version", timeout_secs=10.0)
        if version_response.get("success"):
            protocol_version = version_response.get("version")
        else:
            protocol_version = None
        rows = []
        for fixture in fixtures:
            rows.append(
                evaluate_one(client, fixture, project_root, exact_corpus, args.top_k, args.wide_top_k)
            )
    finally:
        # Always shut the client down, even when a query fails.
        client.close()

    aggregates = aggregate_by_suite(rows, args.top_k)

    generated_at = int(time.time())
    binary_info = {
        "path": args.binary,
        # Prefer the version reported over the protocol; fall back to the binary.
        "version": protocol_version or binary_version(binary),
        "sha256": binary_sha256(binary),
    }
    project_info = {"root": args.project_root, "git_rev": git_rev(project_root)}
    fixture_info = {
        "existing": args.existing_fixtures,
        "identifier": args.identifier_fixtures,
        "missing_expected_files": missing,
    }
    exact_scan_info = {
        "files_scanned": len(exact_corpus),
        "excluded_dirs": sorted(EXACT_SCAN_EXCLUDED_DIRS),
        "excluded_files": sorted(EXACT_SCAN_EXCLUDED_FILES),
        "excluded_prefixes": list(EXACT_SCAN_EXCLUDED_PREFIXES),
        "max_file_bytes": MAX_EXACT_SCAN_BYTES,
    }
    candidate_descriptions = {
        "current": "Current aft_search fused order from semantic_search(top_k=wide_top_k).",
        "rrf_exact_lane": "Bench-only reciprocal-rank fusion over reconstructed semantic ranks and an exact-match lexical lane.",
        "exact_identifier_first": "Bench-only identifier guardrail: exact-match files sort before non-exact files for identifier-shaped queries.",
        "identifier_lexical_uncapped": "Bench-only simulation of removing the 0.25 lexical-only ceiling for identifier/path/error-code queries.",
    }
    output: JsonObject = {
        "schema_version": 1,
        "benchmark": "aft-search-fusion-quality",
        "generated_at_unix": generated_at,
        "top_k": args.top_k,
        "wide_top_k": args.wide_top_k,
        "rrf_k": RRF_K,
        "binary": binary_info,
        "project": project_info,
        "ort": ort_info,
        "semantic_index": status.get("semantic_index"),
        "search_index": status.get("search_index"),
        "fixtures": fixture_info,
        "exact_scan": exact_scan_info,
        "candidate_descriptions": candidate_descriptions,
        "aggregates": aggregates,
        "per_fixture": rows,
    }

    out_path = resolve_script_path(args.out)
    summary_path = resolve_script_path(args.summary)
    write_json(out_path, output)
    write_summary(summary_path, aggregates, args.top_k)
    print_summary(aggregates, args.top_k)
    print(f"\nwrote {out_path}")
    print(f"wrote {summary_path}")
    return 0


# Script entry point: run the benchmark with the command-line arguments and
# use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main(sys.argv[1:]))
