#!/tmp/sdk-test/bin/python3
"""Warm repair agent pool — Unix socket dispatch + filesystem fallback.

The architect agent calls `dispatch-fix` which sends requests over a Unix
socket to this daemon. A warm Claude worker applies edits instantly and
handles cascading errors.

Also watches .beagle/fixes/ for JSON fix plans (legacy/fallback path)
and monitors file mtimes for moved-on detection.

Uses the Claude Code SDK (ClaudeSDKClient) for persistent sessions.

Usage:
    beagle-pool <dir> [--pool-size N] [--model MODEL] [--max-cost USD] [--poll-interval SECS]
"""

import argparse
import asyncio
import glob
import json
import os
import signal
import subprocess
import sys
import time

try:
    import claude_code_sdk as sdk
    # NOTE(review): `_internal` is a private module of the SDK and may move
    # between releases; if it does, this ImportError is reported below as
    # "not installed", which would be misleading.
    from claude_code_sdk._internal import message_parser as mp

    # Replace the SDK's message parser with a wrapper that returns None
    # instead of raising when a message cannot be parsed (presumably message
    # types newer than the installed SDK understands). Worker._run skips
    # None messages.
    # NOTE(review): this only affects code that looks the function up as
    # `message_parser.parse_message` at call time; a module that did
    # `from ...message_parser import parse_message` keeps the original —
    # confirm against the installed SDK version.
    _orig_parse = mp.parse_message
    def _safe_parse(data):
        """Call the SDK's original parser; return None if it raises."""
        try:
            return _orig_parse(data)
        except Exception:
            return None
    mp.parse_message = _safe_parse
except ImportError:
    print("error: claude-code-sdk not installed. Run: uv pip install claude-code-sdk",
          file=sys.stderr)
    sys.exit(1)

# Everything hangs off the repository root, two directory levels above this
# script (presumably <root>/bin/, next to DAEMON_BIN — confirm if moved).
_SCRIPT_PATH = os.path.abspath(__file__)
PROJECT_DIR = os.path.dirname(os.path.dirname(_SCRIPT_PATH))
DAEMON_BIN = os.path.join(PROJECT_DIR, "bin", "beagle-daemon")

# Pool state under <root>/.beagle/: JSONL event log, optional extra system
# prompt, legacy fix-plan drop directory, and the dispatch socket.
_BEAGLE_DIR = os.path.join(PROJECT_DIR, ".beagle")
POOL_LOG = os.path.join(_BEAGLE_DIR, "pool-log.jsonl")
PROMPT_FILE = os.path.join(_BEAGLE_DIR, "repair-agent-prompt.md")
FIXES_DIR = os.path.join(_BEAGLE_DIR, "fixes")
SOCK_PATH = os.path.join(_BEAGLE_DIR, "pool.sock")


def log_event(event, **kwargs):
    """Append one JSON record for *event* to POOL_LOG, best effort.

    The record is ``{"ts": time.time(), "event": event, **kwargs}`` written
    as a single line. Any failure (unwritable path, a value json cannot
    encode, ...) is swallowed so logging can never interrupt the pool.
    """
    record = dict(ts=time.time(), event=event)
    record.update(kwargs)
    try:
        log_dir = os.path.dirname(POOL_LOG)
        os.makedirs(log_dir, exist_ok=True)
        with open(POOL_LOG, "a") as sink:
            sink.write(json.dumps(record) + "\n")
    except Exception:
        pass  # logging is strictly best-effort


def check_file(file_path):
    """Ask the daemon which errors in *file_path* still need a worker.

    Runs ``DAEMON_BIN query check-enriched FILE`` and reads the ``result``
    object of its JSON reply. Returns ``(non_auto, text)``:

    * ``non_auto`` is ``error_count - auto_fixable``;
    * ``text`` is a summary line plus one line per diagnostic (and any
      ``fix-hint``), suitable for a repair prompt.

    Returns ``(0, "")`` when nothing needs a worker, or when the daemon
    cannot be queried or answers with something unusable (missing binary,
    non-zero exit, timeout, bad JSON). This is a best-effort probe and must
    never crash the watcher. A malformed individual diagnostic is skipped
    rather than discarding the whole report, so a file with real errors
    still gets repaired.
    """
    try:
        proc = subprocess.run(
            [DAEMON_BIN, "query", "check-enriched", file_path],
            capture_output=True, text=True, timeout=10
        )
        if proc.returncode != 0:
            return 0, ""
        r = json.loads(proc.stdout).get("result")
        # A clean file may come back as null / empty; anything that is not
        # an object cannot be interpreted.
        if not isinstance(r, dict) or not r:
            return 0, ""
        errors = r.get("error_count") or 0
        if errors == 0:
            return 0, ""
        auto = r.get("auto_fixable") or 0
        non_auto = errors - auto
        if non_auto <= 0:
            return 0, ""
        lines = [f"{os.path.basename(file_path)}: {errors} errors ({auto} auto-fixable)"]
        for e in r.get("errors") or []:
            if not isinstance(e, dict):
                continue  # skip one bad entry instead of losing the report
            lines.append(f"  L{e.get('line', 0)} [{e.get('kind', '?')}]: {e.get('message', '')}")
            fp = e.get("fix_plan")
            if isinstance(fp, dict):
                hint = fp.get("fix-hint", "")
                if hint:
                    lines.append(f"    -> {hint}")
        return non_auto, "\n".join(lines)
    except Exception:
        # Deliberately best-effort: any daemon/IO/parse failure means
        # "nothing to dispatch" rather than a crash of the watch loop.
        return 0, ""


class Worker:
    """A warm Claude session that accepts repair tasks one at a time.

    ``busy`` is True while a repair is running on this session and
    ``current_file`` names the file being repaired; both are cleared when
    the repair finishes or fails.
    """

    def __init__(self, worker_id, model, system_prompt, watch_dir):
        self.id = worker_id
        self.model = model
        self.system_prompt = system_prompt
        self.watch_dir = watch_dir
        self.client = None
        self.busy = False
        self.current_file = None
        self.tasks_completed = 0

    async def connect(self):
        """Open the persistent SDK session with cwd set to ``watch_dir``.

        The session may use only Read/Edit/Bash, runs with permission
        prompts bypassed, and gets ``system_prompt`` appended.
        """
        opts = sdk.ClaudeCodeOptions(
            allowed_tools=["Read", "Edit", "Bash"],
            model=self.model,
            permission_mode="bypassPermissions",
            append_system_prompt=self.system_prompt,
            cwd=self.watch_dir,
        )
        self.client = sdk.ClaudeSDKClient(opts)
        await self.client.connect()

    @staticmethod
    def _fix_prompt(file_path, fixes):
        """Render the repair prompt for a list of fix mappings.

        Each fix must provide "line", "old" and "new" keys.

        Raises:
            KeyError, TypeError: *fixes* is not a list of such mappings.
        """
        fix_lines = [
            f'{i}. Line {f["line"]}: change `{f["old"]}` to `{f["new"]}`'
            for i, f in enumerate(fixes, 1)
        ]
        return (
            f"Apply these fixes to {file_path}:\n\n"
            + "\n".join(fix_lines)
            + "\n\nUse the Edit tool for each fix. After all edits, run:\n"
            f"  {DAEMON_BIN} query check-enriched {file_path}\n"
            "If new errors appear from your edits, fix them too."
        )

    async def dispatch_fix(self, file_path, fixes):
        """Apply fixes from a dispatch-fix request to *file_path*.

        Returns ``(ok, text, duration_s, cost_usd)`` like :meth:`_run`. A
        malformed *fixes* payload is rejected up front with ``ok=False`` and
        zero cost, without touching the session.
        """
        t0 = time.time()
        try:
            prompt = self._fix_prompt(file_path, fixes)
        except (KeyError, TypeError) as e:
            # Validate before marking busy: an exception escaping here would
            # bypass _run's finally and leave the worker busy forever.
            return False, f"malformed fixes: {e!r}", round(time.time() - t0, 1), 0.0

        self.busy = True
        self.current_file = file_path
        return await self._run(prompt, t0)

    async def dispatch_error_repair(self, file_path, error_text):
        """Ask the session to fix the diagnostics in *error_text*.

        Used by the moved-on watcher. Returns the same tuple as :meth:`_run`.
        """
        self.busy = True
        self.current_file = file_path
        t0 = time.time()

        prompt = (
            f"Fix the type errors in {file_path}.\n\n"
            f"Error diagnostics:\n{error_text}\n\n"
            f"Read the file, fix each error, and report what you changed."
        )

        return await self._run(prompt, t0)

    async def _run(self, prompt, t0):
        """Send *prompt* and collect assistant text until the result message.

        Returns ``(ok, text, duration_s, cost_usd)``. On failure ``ok`` is
        False and ``text`` is the exception message. ``busy`` and
        ``current_file`` are always cleared.
        """
        result_lines = []
        cost = 0.0
        try:
            await self.client.query(prompt)
            async for msg in self.client.receive_response():
                if msg is None:
                    continue  # message the patched SDK parser could not parse
                if isinstance(msg, sdk.AssistantMessage):
                    for block in msg.content:
                        if isinstance(block, sdk.TextBlock):
                            result_lines.append(block.text)
                elif isinstance(msg, sdk.ResultMessage):
                    cost = getattr(msg, "total_cost_usd", 0.0) or 0.0
                    break
            self.tasks_completed += 1
            duration = round(time.time() - t0, 1)
            return True, "\n".join(result_lines), duration, cost
        except Exception as e:
            duration = round(time.time() - t0, 1)
            return False, str(e), duration, cost
        finally:
            self.busy = False
            self.current_file = None

    async def disconnect(self):
        """Close the SDK session, ignoring errors (used during shutdown)."""
        if self.client:
            try:
                await self.client.disconnect()
            except Exception:
                pass


class Pool:
    """Manages warm workers and dispatches repair tasks to them.

    Every repair runs in a background asyncio task. The pool keeps a strong
    reference to each in-flight task, because the event loop holds only weak
    ones and an unreferenced task can be garbage-collected mid-run. It also
    reserves the chosen worker synchronously, before the task starts, so a
    request arriving in between cannot be handed the same worker.

    ``active_files`` holds absolute paths. Relative paths are resolved
    against ``watch_dir``, which is the workers' cwd, so socket requests, fix
    plans and the moved-on watcher agree on which files are being repaired.
    """

    def __init__(self, watch_dir, pool_size, model, max_cost):
        self.watch_dir = watch_dir
        self.pool_size = pool_size
        self.model = model
        self.max_cost = max_cost
        self.total_cost = 0.0
        self.total_dispatches = 0
        self.workers = []
        # Absolute paths of files a worker is currently repairing.
        self.active_files = set()
        self.system_prompt = ""
        # Strong references to in-flight repair tasks (see class docstring).
        self._tasks = set()

    async def start(self):
        """Prepare the pool and connect up to ``pool_size`` workers.

        Loads PROMPT_FILE if present, asks the daemon to watch
        ``watch_dir``, ensures FIXES_DIR exists, then connects workers.
        Exits the process if no worker could connect.
        """
        if os.path.exists(PROMPT_FILE):
            with open(PROMPT_FILE) as f:
                self.system_prompt = f.read()

        subprocess.run([DAEMON_BIN, "query", "watch", self.watch_dir],
                       capture_output=True, timeout=5)

        os.makedirs(FIXES_DIR, exist_ok=True)

        print(f"beagle-pool: connecting {self.pool_size} warm worker(s)...", flush=True)
        for i in range(self.pool_size):
            w = Worker(f"w{i}", self.model, self.system_prompt, self.watch_dir)
            try:
                await w.connect()
                self.workers.append(w)
                print(f"  worker {w.id} connected", flush=True)
            except Exception as e:
                print(f"  worker w{i} failed to connect: {e}", file=sys.stderr, flush=True)

        if not self.workers:
            print("error: no workers connected", file=sys.stderr)
            sys.exit(1)

        log_event("pool_start", directory=self.watch_dir,
                  workers=len(self.workers))

    def get_idle_worker(self):
        """Return the first worker that is not busy, or None."""
        for w in self.workers:
            if not w.busy:
                return w
        return None

    def _check_budget(self):
        """Return False (and say so) once total spend has reached max_cost.

        ``total_cost`` only grows when a task finishes, so tasks already in
        flight can together push spend past the cap.
        """
        if self.total_cost >= self.max_cost:
            print(f"  SKIP (cost cap ${self.max_cost:.2f} reached, "
                  f"spent ${self.total_cost:.2f})", flush=True)
            return False
        return True

    def _key(self, file_path):
        """Normalize *file_path* to the absolute form kept in active_files."""
        return os.path.abspath(os.path.join(self.watch_dir, file_path))

    def _launch(self, worker, file_path, run, source):
        """Reserve *worker* for *file_path* and run ``run()`` in a tracked task.

        *run* is a zero-argument callable returning the worker coroutine,
        which resolves to ``(ok, result_text, duration_s, cost_usd)``.
        However the task ends, the worker and the file are released. On
        completion (successful or failed) the outcome is logged and its cost
        is added to ``total_cost``.
        """
        key = self._key(file_path)
        basename = os.path.basename(file_path)
        # Reserve now: the task only starts on a later loop iteration, and
        # get_idle_worker() must not return this worker in the meantime.
        worker.busy = True
        worker.current_file = file_path
        self.active_files.add(key)

        async def run_and_log():
            t0 = time.time()
            try:
                ok, result, duration, cost = await run()
            except Exception as e:
                # Report as a failure rather than dying as an unretrieved
                # task exception with the worker still reserved.
                ok, result = False, str(e)
                duration, cost = round(time.time() - t0, 1), 0.0
            finally:
                self.active_files.discard(key)
                worker.busy = False
                worker.current_file = None
            self.total_cost += cost
            status = "done" if ok else "failed"
            fields = dict(worker=worker.id, file=basename,
                          duration_s=duration, cost_usd=cost,
                          total_cost_usd=self.total_cost, source=source)
            if not ok:
                fields["error"] = result
            log_event(status, **fields)
            print(f"  {status.upper()} {worker.id} ({basename}, {duration}s, "
                  f"${cost:.4f}, total ${self.total_cost:.2f}/${self.max_cost:.2f})",
                  flush=True)

        task = asyncio.create_task(run_and_log())
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

    async def dispatch(self, file_path, fixes, source="socket"):
        """Dispatch a fix request (from the socket or a fix-plan file).

        Returns ``(True, "dispatched")`` once a worker is reserved and its
        task launched. Otherwise returns ``(False, reason)``, where reason is
        "cost cap reached", "file already being repaired" or "no idle
        worker". The repair itself completes in the background.
        """
        if not self._check_budget():
            return False, "cost cap reached"
        # Two sessions editing the same file concurrently could overwrite
        # each other's edits.
        if self.is_handled(file_path):
            return False, "file already being repaired"
        worker = self.get_idle_worker()
        if not worker:
            return False, "no idle worker"

        basename = os.path.basename(file_path)
        n_fixes = len(fixes)
        self.total_dispatches += 1

        log_event("dispatch", worker=worker.id, file=basename,
                  fixes=n_fixes, source=source,
                  dispatch_num=self.total_dispatches)
        print(f"  DISPATCH {worker.id} → {basename} ({n_fixes} fix(es), "
              f"#{self.total_dispatches}, via {source})", flush=True)

        self._launch(worker, file_path,
                     lambda: worker.dispatch_fix(file_path, fixes), source)
        return True, "dispatched"

    async def dispatch_error(self, file_path, error_text):
        """Dispatch an error repair (moved-on fallback); no-op if not possible."""
        if not self._check_budget():
            return
        if self.is_handled(file_path):
            return
        worker = self.get_idle_worker()
        if not worker:
            return

        basename = os.path.basename(file_path)
        self.total_dispatches += 1

        log_event("dispatch_error", worker=worker.id, file=basename,
                  dispatch_num=self.total_dispatches)
        print(f"  DISPATCH {worker.id} → {basename} (errors, "
              f"#{self.total_dispatches})", flush=True)

        self._launch(worker, file_path,
                     lambda: worker.dispatch_error_repair(file_path, error_text),
                     "moved_on")

    def is_handled(self, file_path):
        """True if a worker is currently repairing *file_path*."""
        return self._key(file_path) in self.active_files

    def status(self):
        """Snapshot of pool state, returned to socket "status" requests."""
        return {
            "workers": len(self.workers),
            "busy": sum(1 for w in self.workers if w.busy),
            "dispatches": self.total_dispatches,
            "cost_usd": round(self.total_cost, 4),
            "max_cost_usd": self.max_cost,
            "active_files": list(self.active_files),
        }

    async def stop(self):
        """Cancel in-flight repairs, disconnect every worker, log totals."""
        pending = list(self._tasks)
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)
        for w in self.workers:
            await w.disconnect()
        log_event("pool_stop", total_dispatches=self.total_dispatches,
                  total_cost_usd=self.total_cost)


# Upper bound on one request, so a client that never sends valid JSON cannot
# grow the read buffer without limit.
_MAX_REQUEST_BYTES = 1024 * 1024

# Keys every fix must carry; Worker.dispatch_fix formats these into its prompt.
_FIX_KEYS = ("line", "old", "new")


async def _read_json_request(reader, limit=_MAX_REQUEST_BYTES):
    """Read from *reader* until the bytes received so far parse as JSON.

    A single ``read()`` returns whatever happens to be buffered, so one
    request can arrive over several reads. Keeps reading until the payload
    decodes, the peer closes its end, or *limit* bytes are buffered.

    Raises:
        ValueError: the payload is not valid UTF-8 JSON (JSONDecodeError and
            UnicodeDecodeError are both ValueError subclasses).
    """
    buf = b""
    while True:
        chunk = await reader.read(65536)
        if not chunk:
            break  # EOF: what we have is the whole request
        buf += chunk
        try:
            return json.loads(buf.decode())
        except ValueError:
            if len(buf) >= limit:
                raise
    return json.loads(buf.decode())


def _fix_request_problem(file_path, fixes):
    """Return why a "fix" request is unusable, or None if it is well-formed."""
    if not isinstance(file_path, str):
        return "file must be a string"
    if not isinstance(fixes, list):
        return "fixes must be a list"
    for i, fix in enumerate(fixes, 1):
        if not isinstance(fix, dict):
            return f"fix {i} must be an object"
        missing = [k for k in _FIX_KEYS if k not in fix]
        if missing:
            return f"fix {i} missing {', '.join(missing)}"
    return None


async def handle_socket_client(reader, writer, pool):
    """Serve one request on the dispatch socket and write a JSON reply.

    Requests (a single JSON object, read within 5 s):
      {"cmd": "status"}                             -> pool.status()
      {"cmd": "fix", "file": PATH, "fixes": [...]}  -> {"ok": bool, "msg": str}
    Each fix must be an object with "line", "old" and "new" keys. Any
    failure is reported as {"ok": false, "msg": ...}; the connection is
    always closed afterwards.
    """
    try:
        msg = await asyncio.wait_for(_read_json_request(reader), timeout=5.0)
        if not isinstance(msg, dict):
            raise ValueError("request must be a JSON object")

        cmd = msg.get("cmd", "")

        if cmd == "status":
            resp = pool.status()
        elif cmd == "fix":
            file_path = msg.get("file", "")
            fixes = msg.get("fixes", [])
            if not file_path or not fixes:
                resp = {"ok": False, "msg": "missing file or fixes"}
            else:
                problem = _fix_request_problem(file_path, fixes)
                if problem:
                    # Reject here: a malformed fix would otherwise only fail
                    # inside the worker, after it has been reserved.
                    resp = {"ok": False, "msg": problem}
                else:
                    ok, reason = await pool.dispatch(file_path, fixes, source="socket")
                    resp = {"ok": ok, "msg": reason}
        else:
            resp = {"ok": False, "msg": f"unknown cmd: {cmd}"}

        writer.write(json.dumps(resp).encode())
        await writer.drain()
    except Exception as e:
        try:
            # Some exceptions (e.g. asyncio.TimeoutError) stringify to "".
            detail = str(e) or type(e).__name__
            writer.write(json.dumps({"ok": False, "msg": detail}).encode())
            await writer.drain()
        except Exception:
            pass
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            pass


def scan_mtimes(directory):
    """Map each watched-extension file directly inside *directory* to its mtime.

    Only the top level is scanned (no recursion), and, as with any ``*``
    glob, names starting with a dot are not matched. Paths that vanish
    between listing and ``stat`` are left out.
    """
    suffixes = ('.bclj', '.bjs', '.bnix', '.bsql', '.bpy', '.bgl', '.rkt')
    stamps = {}
    for suffix in suffixes:
        pattern = os.path.join(directory, f"*{suffix}")
        for path in glob.glob(pattern):
            try:
                stamps[path] = os.path.getmtime(path)
            except OSError:
                continue
    return stamps


def scan_fix_plans():
    """Collect pending fix plans from FIXES_DIR.

    A plan is a ``*.json`` file holding an object with a non-empty string
    ``"file"`` and a list ``"fixes"``. Returns ``[(plan, plan_path), ...]``.

    Unreadable, non-UTF-8, non-JSON, or wrongly shaped files are skipped and
    left in place. They must not reach the caller, which indexes
    ``plan["file"]`` and uses it as a set member; before the shape check, a
    JSON array or string that merely *contains* "file" and "fixes" slipped
    through and crashed the watch loop.
    """
    plans = []
    if not os.path.exists(FIXES_DIR):
        return plans
    for path in glob.glob(os.path.join(FIXES_DIR, "*.json")):
        try:
            with open(path, encoding="utf-8") as fh:
                plan = json.load(fh)
        except (OSError, ValueError):
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            continue
        if (isinstance(plan, dict)
                and isinstance(plan.get("file"), str) and plan["file"]
                and isinstance(plan.get("fixes"), list)):
            plans.append((plan, path))
    return plans


def _changed_files(prev_mtimes, curr_mtimes):
    """Return paths in *curr_mtimes* that are new or whose mtime changed."""
    return [f for f in curr_mtimes
            if f not in prev_mtimes or curr_mtimes[f] != prev_mtimes[f]]


async def _dispatch_one_fix_plan(pool):
    """Dispatch at most one pending fix plan from FIXES_DIR (legacy path)."""
    if not pool.get_idle_worker():
        return
    for plan, plan_file in scan_fix_plans():
        file_path = plan["file"]
        if pool.is_handled(file_path):
            continue
        ok, _ = await pool.dispatch(file_path, plan["fixes"], source="fixplan")
        if ok:
            # Consume the plan only once a worker has actually taken it.
            try:
                os.unlink(plan_file)
            except OSError:
                pass
            return


async def _repair_if_abandoned(pool, prev_file, current_file):
    """Queue a moved-on repair if *prev_file* was left with errors."""
    if pool.is_handled(prev_file) or not pool.get_idle_worker():
        return
    # check_file shells out to the daemon with a 10 s timeout; run it off
    # the event loop so socket requests and running workers keep progressing.
    loop = asyncio.get_running_loop()
    error_count, error_text = await loop.run_in_executor(None, check_file, prev_file)
    # Re-check: another dispatch may have claimed the file while we waited.
    if error_count <= 0 or pool.is_handled(prev_file):
        return
    log_event("moved_on",
              source=os.path.basename(current_file),
              abandoned=os.path.basename(prev_file),
              errors=error_count)
    print(f"  Agent moved to {os.path.basename(current_file)}",
          flush=True)
    await pool.dispatch_error(prev_file, error_text)


async def main():
    """Run the pool: parse args, start workers and the socket, then watch.

    Each poll tick (``--poll-interval`` seconds) this:
      1. dispatches at most one pending fix plan from FIXES_DIR, then
      2. checks watched files for edits. When the most recently edited file
         differs from the previous one, the previous file is checked with
         the daemon. If it still has non-auto-fixable errors, it is handed
         to an idle worker (the "moved-on" repair).
    SIGINT/SIGTERM end the loop. The socket is removed and workers are
    disconnected on the way out, even if the loop raised.
    """
    parser = argparse.ArgumentParser(description="Warm repair agent pool")
    parser.add_argument("directory", help="Directory to watch")
    parser.add_argument("--pool-size", type=int, default=1, choices=[1, 2, 3])
    parser.add_argument("--model", default="sonnet")
    parser.add_argument("--max-cost", type=float, default=2.00,
                        help="Stop dispatching after this total USD spend")
    parser.add_argument("--poll-interval", type=float, default=1.0)
    args = parser.parse_args()

    watch_dir = os.path.abspath(args.directory)
    if not os.path.isdir(watch_dir):
        print(f"error: {watch_dir} is not a directory", file=sys.stderr)
        sys.exit(1)

    pool = Pool(watch_dir, args.pool_size, args.model, args.max_cost)
    await pool.start()

    sock_server = None
    try:
        # Clean up a stale socket left by a previous run.
        if os.path.exists(SOCK_PATH):
            os.unlink(SOCK_PATH)

        sock_server = await asyncio.start_unix_server(
            lambda r, w: handle_socket_client(r, w, pool),
            path=SOCK_PATH
        )
        print(f"beagle-pool: socket listening at {SOCK_PATH}", flush=True)

        loop = asyncio.get_running_loop()
        stop_event = asyncio.Event()
        for sig in (signal.SIGINT, signal.SIGTERM):
            loop.add_signal_handler(sig, stop_event.set)

        prev_mtimes = scan_mtimes(watch_dir)
        current_file = None
        print(f"beagle-pool: watching {watch_dir} ({len(prev_mtimes)} beagle files, "
              f"{len(pool.workers)} warm worker(s), ${args.max_cost:.2f} cap)",
              flush=True)

        while not stop_event.is_set():
            try:
                await asyncio.wait_for(stop_event.wait(), timeout=args.poll_interval)
                break
            except asyncio.TimeoutError:
                pass

            # Fix plans from .beagle/fixes/ (legacy/fallback).
            await _dispatch_one_fix_plan(pool)

            # Moved-on watcher (fallback). After a first change, wait 0.5 s
            # and re-scan so writes still in progress are picked up.
            curr_mtimes = scan_mtimes(watch_dir)
            if _changed_files(prev_mtimes, curr_mtimes):
                await asyncio.sleep(0.5)
                curr_mtimes = scan_mtimes(watch_dir)
            changed = _changed_files(prev_mtimes, curr_mtimes)

            if not changed:
                prev_mtimes = curr_mtimes
                continue

            most_recent = max(changed, key=lambda f: curr_mtimes[f])
            print(f"[{time.strftime('%H:%M:%S')}] edit: {os.path.basename(most_recent)}",
                  flush=True)

            prev_file = current_file
            current_file = os.path.abspath(most_recent)

            if prev_file and prev_file != current_file:
                await _repair_if_abandoned(pool, prev_file, current_file)

            prev_mtimes = curr_mtimes
    finally:
        if sock_server is not None:
            sock_server.close()
            await sock_server.wait_closed()
        if os.path.exists(SOCK_PATH):
            os.unlink(SOCK_PATH)
        await pool.stop()
        print(f"beagle-pool: stopped ({pool.total_dispatches} dispatches, "
              f"${pool.total_cost:.2f} spent)", flush=True)


if __name__ == "__main__":
    asyncio.run(main())
