#!/usr/bin/env python3
"""Fail if public provider/model count claims drift from providers.toml."""

from __future__ import annotations

import argparse
import json
import re
import sys
from pathlib import Path

# Put this script's own directory on sys.path before importing
# `catalog_counts`, so the import below resolves even when this module is
# loaded from somewhere other than its own directory.
# NOTE(review): presumably catalog_counts.py lives next to this file — confirm.
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))

# E402 is silenced because this import must come after the sys.path tweak above.
from catalog_counts import catalog_counts  # noqa: E402


# Count phrases that must appear verbatim in each public file once the
# ``{providers}``, ``{live_bucket}`` and ``{cataloged_bucket}`` placeholders
# have been filled in. Declared grouped by file and flattened into
# ``(relative_path, template)`` pairs, preserving declaration order.
REQUIRED_SURFACES = tuple(
    (rel, template)
    for rel, templates in (
        (
            "README.md",
            (
                "{providers} LLM providers",
                "{live_bucket} live-validated",
                "{cataloged_bucket} cataloged",
            ),
        ),
        (
            "pyproject.toml",
            (
                "{providers} LLM providers",
                "{cataloged_bucket} cataloged",
            ),
        ),
        ("server.json", ("{providers} LLM providers",)),
        (
            "docs/index.html",
            (
                "{providers} LLM providers",
                "{live_bucket} live-validated",
                "{cataloged_bucket} cataloged",
            ),
        ),
        (
            "docs/free-llm-api-providers-list.html",
            ("These {providers} providers",),
        ),
        (
            "docs/run-coding-agents-on-free-models.html",
            ("pools {providers} free providers",),
        ),
        (
            "docs/run-opencode-on-free-models.html",
            ("pools {providers} free providers",),
        ),
    )
    for template in templates
)

# Substrings marking a line as talking about something other than this
# project's own catalog totals; _check_public_drift skips any line that
# contains one of them.
# Keep every entry lowercase: the markers are compared against the lowercased
# line, so an entry containing capitals would never match anything.
EXTERNAL_CONTEXT = (
    "freellmapi",
    "litellm",
    "openrouter free models",
    "openrouter's",
    "keyless",
    "no api key",
    "need no api",
    "providers ok",
)


def _format(template: str, counts) -> str:
    """Fill the count placeholders of *template* from *counts*.

    The placeholders ``{providers}``, ``{live_bucket}`` and
    ``{cataloged_bucket}`` are replaced with the same-named attributes of
    *counts*; the resulting string is returned.
    """
    substitutions = {
        "providers": counts.providers,
        "live_bucket": counts.live_bucket,
        "cataloged_bucket": counts.cataloged_bucket,
    }
    return template.format(**substitutions)


def _read(root: Path, rel: str) -> str:
    """Return the UTF-8 decoded text of the file at *rel* beneath *root*."""
    target = root.joinpath(rel)
    return target.read_text(encoding="utf-8")


def _check_required_surfaces(root: Path, counts) -> list[str]:
    """Verify that every phrase in ``REQUIRED_SURFACES`` appears in its file.

    Each template is filled in with ``_format`` and searched for in the file
    at ``root / rel``. A phrase that begins with a digit must not be directly
    preceded by another digit, so an expected ``"5 LLM providers"`` is not
    satisfied by a stale ``"15 LLM providers"``.

    Args:
        root: Directory the relative paths in ``REQUIRED_SURFACES`` resolve
            against.
        counts: Object exposing the attributes that ``_format`` substitutes.

    Returns:
        One message per missing phrase, in ``REQUIRED_SURFACES`` order; an
        empty list when every phrase is present.

    Raises:
        OSError: Propagated from reading a surface file (for example
            ``FileNotFoundError`` when it does not exist).
    """
    errors: list[str] = []
    # Several phrases live in the same file; read each file only once.
    contents: dict[str, str] = {}
    for rel, template in REQUIRED_SURFACES:
        expected = _format(template, counts)
        if rel not in contents:
            contents[rel] = _read(root, rel)
        # A bare substring test would accept "15 LLM providers" for "5 LLM
        # providers"; anchor the leading number on a non-digit boundary.
        boundary = r"(?<!\d)" if expected[:1].isdigit() else ""
        if re.search(boundary + re.escape(expected), contents[rel]) is None:
            errors.append(f"{rel}: missing expected count phrase {expected!r}")
    return errors


def _check_reference_table(root: Path, counts) -> list[str]:
    """Compare the provider table in the providers-list page with *counts*.

    Rows are recognised as ``<tr data-provider="ID">`` followed, on the same
    line (``.`` does not cross newlines here), by a ``<td class=num>N</td>``
    cell. The set of ids must equal the ids in ``counts.by_provider``; if it
    does not, a single message listing the missing and extra ids is returned
    and per-row counts are not compared. Otherwise one message is returned for
    each provider whose published number differs from its ``enabled_models``.
    """
    rel = "docs/free-llm-api-providers-list.html"
    row_pattern = re.compile(
        r'<tr data-provider="(?P<id>[^"]+)">.*?<td class=num>(?P<count>\d+)</td>'
    )

    # Published numbers, keyed by provider id (a repeated id keeps its last row).
    published: dict[str, int] = {}
    for row in row_pattern.finditer(_read(root, rel)):
        published[row.group("id")] = int(row.group("count"))

    # Numbers the catalog says the table should show.
    wanted = {}
    for provider in counts.by_provider:
        wanted[provider.id] = provider.enabled_models

    published_ids = set(published)
    wanted_ids = set(wanted)
    if published_ids != wanted_ids:
        missing = ", ".join(sorted(wanted_ids - published_ids)) or "none"
        extra = ", ".join(sorted(published_ids - wanted_ids)) or "none"
        return [f"{rel}: provider table ids drifted (missing: {missing}; extra: {extra})"]

    return [
        f"{rel}: {provider_id} count is {published[provider_id]}, expected {wanted_count}"
        for provider_id, wanted_count in wanted.items()
        if published[provider_id] != wanted_count
    ]


def _check_public_drift(root: Path, counts) -> list[str]:
    """Scan public Markdown/HTML files for count claims that disagree with *counts*.

    ``README.md``, ``FAQ.md`` and every ``.md``/``.html`` file below ``docs/``
    are read line by line; files named ``POLISH_PLAN.md`` and paths that are
    not regular files are ignored. A line containing any ``EXTERNAL_CONTEXT``
    marker (compared case-insensitively) is skipped entirely. On the remaining
    lines three kinds of claim are checked:

    * ``<n> [LLM] [free] providers`` against ``counts.providers`` (as an int),
    * ``<n>+ live-validated`` against ``counts.live_bucket`` (as a string),
    * ``<n>+ cataloged`` against ``counts.cataloged_bucket`` (as a string).

    Args:
        root: Repository root containing ``README.md``, ``FAQ.md`` and ``docs/``.
        counts: Object exposing ``providers``, ``live_bucket`` and
            ``cataloged_bucket``.

    Returns:
        One ``"<relative path>:<line number>: ..."`` message per mismatching
        match; an empty list when nothing drifted.
    """
    errors: list[str] = []
    docs = [root / "README.md", root / "FAQ.md", *sorted((root / "docs").glob("**/*"))]
    provider_claim = re.compile(r"\b(?P<n>\d+)\s+(?:LLM\s+)?(?:free\s+)?providers\b")
    live_claim = re.compile(r"\b(?P<n>\d+\+)\s+live-validated\b")
    cataloged_claim = re.compile(r"\b(?P<n>\d+\+)\s+cataloged\b")
    # Each line is lowercased before marker matching, so the markers have to be
    # lowercased as well; a marker containing capitals could otherwise never
    # match and its lines would be wrongly reported as drift.
    external_markers = tuple(token.lower() for token in EXTERNAL_CONTEXT)
    for path in docs:
        if path.name == "POLISH_PLAN.md" or not path.is_file() or path.suffix not in {".md", ".html"}:
            continue
        rel = path.relative_to(root)
        for lineno, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
            low = line.lower()
            if any(marker in low for marker in external_markers):
                continue
            for match in provider_claim.finditer(line):
                if int(match.group("n")) != counts.providers:
                    errors.append(f"{rel}:{lineno}: provider count drift: {line.strip()}")
            for match in live_claim.finditer(line):
                if match.group("n") != counts.live_bucket:
                    errors.append(f"{rel}:{lineno}: live model bucket drift: {line.strip()}")
            for match in cataloged_claim.finditer(line):
                if match.group("n") != counts.cataloged_bucket:
                    errors.append(f"{rel}:{lineno}: cataloged model bucket drift: {line.strip()}")
    return errors


def main(argv: list[str] | None = None) -> int:
    """Run the count-drift check and return a process exit status.

    Options:
        --root: directory to check; defaults to the parent of this script's
            directory.
        --json: print the derived counts as a JSON object and exit with 0
            without running any checks.

    Returns:
        0 when ``--json`` was given or no drift was found; 1 when at least one
        check reported a problem (the problems are listed on stderr).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument("--root", type=Path, default=Path(__file__).resolve().parent.parent)
    cli.add_argument("--json", action="store_true", help="print derived counts as JSON")
    options = cli.parse_args(argv)

    repo_root = options.root.resolve()
    counts = catalog_counts(repo_root)

    if options.json:
        payload = {
            "providers": counts.providers,
            "enabled_chat_models": counts.enabled_chat_models,
            "cataloged_chat_models": counts.cataloged_chat_models,
            "live_bucket": counts.live_bucket,
            "cataloged_bucket": counts.cataloged_bucket,
            "enabled_all_models": counts.enabled_all_models,
            "cataloged_all_models": counts.cataloged_all_models,
        }
        per_provider = {}
        for provider in counts.by_provider:
            per_provider[provider.id] = {
                "enabled_models": provider.enabled_models,
                "cataloged_models": provider.cataloged_models,
            }
        payload["by_provider"] = per_provider
        print(json.dumps(payload, sort_keys=True))
        return 0

    # Run every check so that all problems are reported in a single pass.
    problems: list[str] = []
    problems.extend(_check_required_surfaces(repo_root, counts))
    problems.extend(_check_reference_table(repo_root, counts))
    problems.extend(_check_public_drift(repo_root, counts))

    if not problems:
        summary = (
            "Count check passed: "
            f"{counts.providers} providers, "
            f"{counts.enabled_chat_models} enabled chat models ({counts.live_bucket}), "
            f"{counts.cataloged_chat_models} cataloged chat models ({counts.cataloged_bucket})."
        )
        print(summary)
        return 0

    print("Count drift detected:", file=sys.stderr)
    for problem in problems:
        print(f"  - {problem}", file=sys.stderr)
    return 1


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
