#!/usr/bin/env python3
"""
claude-interactive: PTY wrapper that runs claude in interactive mode for subscription billing.

NEEDLE captures the wrapper's stdout as a pipe (Stdio::piped), so claude would normally
detect a non-TTY stdout and enter programmatic-credit billing mode.  This wrapper creates
an internal PTY so the claude subprocess sees a real TTY on both stdin and stdout, keeping
it in interactive/subscription billing mode.  The wrapper itself then synthesises JSONL
on its own piped stdout, compatible with needle-transform-claude.
"""

import os
import sys
import pty
import select
import time
import signal
import json
import fcntl
import termios
import struct
import argparse

import re
import pyte


# ── Screen dimensions ─────────────────────────────────────────────────────────

# Size of the pyte virtual screen built in extract_response().  The same values
# are the fallback returned by _get_tty_size() when the controlling terminal
# cannot be queried, and _SCREEN_COLS is the minimum PTY width in run_session().
_SCREEN_COLS = 220
_SCREEN_ROWS = 2000

# ── TTY observer output ────────────────────────────────────────────────────────

# Write-only descriptor for /dev/tty, set by _tty_open().  Stays None until then,
# or when /dev/tty cannot be opened; _tty_write() silently drops output while None.
_tty_fd: int | None = None


def _tty_open() -> None:
    """Open /dev/tty for writing and store the descriptor in `_tty_fd`.

    The descriptor is opened with O_NOCTTY.  If the open fails with OSError,
    `_tty_fd` is set to None so that observer output becomes a no-op.
    """
    global _tty_fd
    flags = os.O_WRONLY | os.O_NOCTTY
    try:
        opened = os.open('/dev/tty', flags)
    except OSError:
        opened = None
    _tty_fd = opened


def _tty_write(data: bytes) -> None:
    """Write raw bytes to the observer terminal, ignoring any failure.

    Does nothing when no observer descriptor is open.  OSError from the write
    is swallowed on purpose: observer output is best-effort.
    """
    fd = _tty_fd
    if fd is None:
        return
    try:
        os.write(fd, data)
    except OSError:
        # Observer output must never interfere with the session itself.
        pass


def _tty_print(text: str) -> None:
    """Write `text` followed by CR LF to the observer terminal."""
    line = text + '\r\n'
    _tty_write(line.encode())


def _set_title(title: str) -> None:
    """Send an OSC 2 (ESC ] 2 ; title BEL) sequence to the observer terminal."""
    sequence = f'\033]2;{title}\007'
    _tty_write(sequence.encode())


def _get_tty_size() -> tuple[int, int]:
    """Return (rows, cols) of the observer terminal, or the screen defaults.

    The size is read with TIOCGWINSZ on `_tty_fd`.  The defaults
    (`_SCREEN_ROWS`, `_SCREEN_COLS`) are returned when no descriptor is open,
    when the ioctl fails with OSError, or when either dimension is reported
    as zero.
    """
    fallback = (_SCREEN_ROWS, _SCREEN_COLS)
    if _tty_fd is None:
        return fallback
    try:
        packed = fcntl.ioctl(_tty_fd, termios.TIOCGWINSZ, b'\x00' * 8)
    except OSError:
        return fallback
    # struct winsize: rows, cols, then two pixel fields that are not needed here.
    rows, cols, _xpixel, _ypixel = struct.unpack('HHHH', packed)
    if rows > 0 and cols > 0:
        return rows, cols
    return fallback


def _extract_bead_id(prompt: str) -> str | None:
    m = re.search(r'Bead[- ]?ID[:\s]+([a-zA-Z0-9_-]+)', prompt, re.IGNORECASE)
    return m.group(1) if m else None


# ── Screen-based response extraction ──────────────────────────────────────────

# Chrome patterns: lines that look like separator bars, status indicators, etc.
_BOX_CHARS = frozenset('─│╭╮╰╯╔╗╚╝═║┌┐└┘├┤┬┴┼')
_SKIP_PREFIXES = ('✻', '⏵', '⊕ ─', '⊘ ─', '⊙ ─')


def _is_chrome(line: str) -> bool:
    """Tell whether `line` is TUI decoration rather than response content.

    A line counts as chrome when it is empty, consists solely of box-drawing
    characters and spaces, begins with one of `_SKIP_PREFIXES`, or contains
    any of the block characters █ ▓ ░ used by status/progress bars.
    """
    if not line:
        return True
    # Separator bar: every character is a box-drawing char or a space.
    if set(line) <= _BOX_CHARS | {' '}:
        return True
    if line.startswith(_SKIP_PREFIXES):
        return True
    # Status bar: any unicode progress block present.
    return any(block_char in line for block_char in '█▓░')


def extract_response(all_pty_bytes: bytes) -> str:
    """
    Render the complete PTY byte stream on a virtual terminal and return the
    text that lies between the submitted-input line and the final prompt line.

    The expected layout of the final prompt is:
        ─────── separator ───────
        ❯ [optional ghost text]
        ─────── separator ───────
    Requiring a nearby separator keeps ghost text such as "❯ commit this" from
    being mistaken for the submitted input.

    Returns an empty string when either anchor row cannot be found or when the
    input row does not precede the final prompt row.
    """
    screen = pyte.Screen(_SCREEN_COLS, _SCREEN_ROWS)
    stream = pyte.ByteStream(screen)
    stream.feed(all_pty_bytes)
    lines = [row.rstrip() for row in screen.display]
    bottom_up = range(len(lines) - 1, -1, -1)

    # Rows made up only of box-drawing characters and spaces.
    bar_alphabet = _BOX_CHARS | {' '}
    separators = [
        idx for idx, text in enumerate(lines)
        if text and set(text) <= bar_alphabet
    ]

    def near_separator(idx: int) -> bool:
        return any(abs(idx - sep) <= 3 for sep in separators)

    # Final prompt: the last ❯ row with a separator within 3 rows ...
    final_prompt = next(
        (idx for idx in bottom_up
         if lines[idx].startswith('❯') and near_separator(idx)),
        -1,
    )
    # ... otherwise the last row that is a bare ❯.
    if final_prompt == -1:
        final_prompt = next((idx for idx in bottom_up if lines[idx] == '❯'), -1)

    # Submitted input: the last "❯ text" row other than the final prompt.
    submitted = next(
        (idx for idx in bottom_up
         if idx != final_prompt and lines[idx].startswith('❯ ')),
        -1,
    )

    if submitted == -1 or final_prompt == -1 or submitted >= final_prompt:
        return ''

    # Rows right after the input that are indented by two spaces are the echo of
    # a multi-line input, not response content.  The response begins at the first
    # non-empty row that starts with ●, has ⎿ among its first four characters,
    # or is not indented by two spaces.
    first_body = next(
        (idx for idx in range(submitted + 1, final_prompt)
         if lines[idx] and (
             lines[idx].startswith('●')
             or '⎿' in lines[idx][:4]
             or not lines[idx].startswith('  ')
         )),
        submitted + 1,
    )

    kept = []
    for text in lines[first_body:final_prompt]:
        if _is_chrome(text):
            continue
        # Drop a response bullet found within the first four characters,
        # together with everything before it and the whitespace after it.
        if '●' in text[:4]:
            text = text.partition('●')[2].lstrip()
        if text:
            kept.append(text)

    return '\n'.join(kept).strip()


# ── JSONL emission ─────────────────────────────────────────────────────────────

def emit(obj: dict) -> None:
    """Write `obj` to stdout as one JSON line and flush immediately."""
    out = sys.stdout
    out.write(f'{json.dumps(obj)}\n')
    out.flush()


def emit_success(text: str, model: str, elapsed_ms: int) -> None:
    """Emit the JSONL records for a completed turn.

    When `text` is non-empty an "assistant" record carrying it is emitted
    first (token usage is reported as all zeros).  A "result" record with
    subtype "success" and `duration_ms` set to `elapsed_ms` always follows.
    """
    if text:
        usage = {
            "input_tokens": 0,
            "output_tokens": 0,
            "cache_creation_input_tokens": 0,
            "cache_read_input_tokens": 0,
        }
        message = {
            "id": "msg_interactive",
            "type": "message",
            "role": "assistant",
            "content": [{"type": "text", "text": text}],
            "model": model,
            "stop_reason": "end_turn",
            "usage": usage,
        }
        emit({"type": "assistant", "message": message})

    result = {
        "type": "result",
        "subtype": "success",
        "is_error": False,
        "duration_ms": elapsed_ms,
        "cost_usd": 0,
        "session_id": "interactive",
        "num_turns": 1,
    }
    emit(result)


def emit_error(message: str, code: str = "error_during_execution") -> None:
    """Emit a single failing "result" record.

    `code` becomes the record's subtype and `message` its error_message.
    """
    failure = {
        "type": "result",
        "subtype": code,
        "is_error": True,
        "error_message": message,
        "cost_usd": 0,
        "session_id": "interactive",
        "num_turns": 0,
    }
    emit(failure)


# ── PTY session ───────────────────────────────────────────────────────────────

# Claude startup sequence:
#  1. Initial burst: trust-folder dialog (new workspace) + terminal queries
#  2. Wrapper responds to terminal queries; sends CR to dismiss any trust dialog
#  3. Second burst: actual REPL renders
#  4. Wrapper sends the real prompt after a quiet gap
#
# For already-trusted workspaces there is no trust dialog; the extra CR at the
# `>` prompt is harmless (claude just shows another prompt).
# All durations in this group are seconds, compared against time.monotonic() deltas.
_GAP_ENTER   = 0.8   # gap after initial burst → send CR
_GAP_PROMPT  = 2.0   # gap after CR response → send actual prompt
_MIN_STARTUP = 200   # minimum startup bytes before gap detection triggers

STARTUP_WAIT      = 45.0    # hard fallback if burst never settles
IDLE_TIMEOUT      = 30.0    # idle after last PTY chunk = task done (once response started)
PRE_RESP_TIMEOUT  = 120.0   # max wait from prompt-sent to first response byte
MAX_WALL_TIME     = 3600.0  # hard cap per invocation

# Response threshold: PTY bytes after prompt needed to consider a response "started".
# The input display (echoed prompt) is typically 5-20 KB.  A real response adds tool
# call output / text on top of that.  We use byte count growth as a proxy.
# NOTE(review): this constant is not referenced anywhere in the visible source;
# run_session() detects the start of a response by the first ● in the PTY
# stream instead.  Confirm whether it is still needed.
_RESPONSE_BYTES_THRESHOLD = 30_000


# UTF-8 encoding of ● (U+25CF), which claude prints for tool calls and assistant
# bullets.  Its first appearance after the prompt marks the start of the response.
_RESPONSE_MARKER = b'\xe2\x97\x8f'

# Seconds to wait for the child after SIGTERM before escalating to SIGKILL.
_KILL_GRACE = 10.0


def _terminate_child(pid: int, grace: float = _KILL_GRACE) -> None:
    """Send SIGTERM to `pid` and reap it, using SIGKILL if it outlives `grace` seconds."""
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        pass

    deadline = time.monotonic() + grace
    while True:
        try:
            reaped, _ = os.waitpid(pid, os.WNOHANG)
        except ChildProcessError:
            return  # nothing left to reap
        if reaped == pid:
            return
        if time.monotonic() >= deadline:
            break
        time.sleep(0.05)

    # The child ignored SIGTERM; a blocking wait would hang forever without this.
    try:
        os.kill(pid, signal.SIGKILL)
    except ProcessLookupError:
        pass
    try:
        os.waitpid(pid, 0)
    except ChildProcessError:
        pass


def run_session(prompt: str, model: str, max_turns: int) -> int:
    """Run one interactive claude session inside a PTY and emit JSONL for it.

    Phases, driven by quiet gaps on the PTY:
      1. wait for the startup burst to settle, then send CR;
      2. wait for the REPL to settle again, then paste the prompt (bracketed
         paste) and submit it;
      3. after the first ● appears, finish once the PTY has been idle for
         IDLE_TIMEOUT; if no ● appears, give up once PRE_RESP_TIMEOUT has
         passed since the prompt was sent.

    Args:
        prompt: text pasted into the REPL; also searched for a bead id used in
            observer output.
        model: value passed to claude's --model option and echoed in the JSONL.
        max_turns: value passed to claude's --max-turns option.

    Returns:
        0 once the prompt was sent (a success result is emitted even when the
        extracted response is empty); 1 on wall-clock timeout, when claude
        produced no startup output, or when the session ended before the
        prompt could be sent.
    """
    _tty_open()
    bead_id = _extract_bead_id(prompt) or 'unknown'
    _set_title(f'NEEDLE  {bead_id}  starting…')
    _tty_print(f'\033[1;36m── NEEDLE worker: {bead_id}  model={model} ──\033[0m')

    start = time.monotonic()
    master, slave = pty.openpty()

    # Use the real terminal size for the PTY so forwarded output renders correctly.
    # Fall back to defaults if /dev/tty is unavailable.
    pty_rows, pty_cols = _get_tty_size()
    # Keep minimum dimensions for pyte extraction (pyte screen is always _SCREEN_ROWS x _SCREEN_COLS)
    pty_rows = max(pty_rows, 24)
    pty_cols = max(pty_cols, _SCREEN_COLS)
    fcntl.ioctl(slave, termios.TIOCSWINSZ,
                struct.pack('HHHH', pty_rows, pty_cols, 0, 0))

    pid = os.fork()
    if pid == 0:
        # ── Child: exec claude attached to the PTY slave ──
        # Nothing here may return or raise into the parent's code path: the child
        # must either become claude or leave through os._exit().
        try:
            os.setsid()
            try:
                fcntl.ioctl(slave, termios.TIOCSCTTY, 0)
            except OSError:
                pass
            for fd in (0, 1, 2):
                os.dup2(slave, fd)
            os.close(master)
            if slave > 2:
                os.close(slave)
            os.execvp('claude', [
                'claude',
                '--dangerously-skip-permissions',
                '--model', model,
                '--max-turns', str(max_turns),
            ])
        except OSError as exc:
            # execvp raises on failure instead of returning; report it on the PTY.
            try:
                os.write(2, f'claude-interactive: cannot exec claude: {exc}\r\n'.encode())
            except OSError:
                pass
        finally:
            os._exit(127)

    os.close(slave)

    all_buf          = bytearray()  # complete PTY stream for pyte rendering
    startup_bytes    = 0            # bytes received before prompt is sent
    enter_sent       = False
    prompt_sent      = False
    prompt_sent_time = 0.0
    response_started = False
    # Last bytes seen after the prompt, kept so that a ● split across two reads
    # is still recognised.
    marker_tail      = b''
    last_chunk       = time.monotonic()
    last_data        = time.monotonic()
    term_ack         : set[str] = set()

    def write_all(data: bytes) -> None:
        # os.write may accept fewer bytes than offered; keep going until done.
        view = memoryview(data)
        while view:
            written = os.write(master, view)
            view = view[written:]

    def respond_terminal_queries(chunk: bytes) -> None:
        # The master side IS the terminal emulator; we must answer DA queries or
        # claude will block waiting for responses.
        if b'\x1b[c' in chunk and 'da1' not in term_ack:
            write_all(b'\x1b[?6c')
            term_ack.add('da1')
        if b'\x1b[>0q' in chunk and 'xtver' not in term_ack:
            write_all(b'\x1bP>|wrapper-1.0\x1b\\')
            term_ack.add('xtver')

    try:
        while True:
            now = time.monotonic()

            if now - start > MAX_WALL_TIME:
                emit_error(f"wall-clock timeout after {MAX_WALL_TIME:.0f}s", "error_max_turns")
                return 1

            r, _, _ = select.select([master], [], [], 0.25)

            if r:
                try:
                    chunk = os.read(master, 8192)
                except OSError:
                    break
                if not chunk:
                    break

                all_buf += chunk
                _tty_write(chunk)  # forward live PTY output to observer
                last_chunk = time.monotonic()
                if not prompt_sent:
                    startup_bytes += len(chunk)
                    respond_terminal_queries(chunk)
                else:
                    last_data = time.monotonic()
                    if not response_started:
                        window = marker_tail + chunk
                        if _RESPONSE_MARKER in window:
                            response_started = True
                            elapsed = int(time.monotonic() - start)
                            _set_title(f'NEEDLE  {bead_id}  responding  {elapsed}s')
                        else:
                            marker_tail = window[-(len(_RESPONSE_MARKER) - 1):]
            else:
                gap = now - last_chunk
                if not enter_sent:
                    # Phase 1: after startup burst settles, send CR to dismiss
                    # any trust dialog and to get to the `>` prompt.
                    enough_data   = startup_bytes >= _MIN_STARTUP
                    gap_triggered = gap >= _GAP_ENTER
                    timed_out     = (now - start) > STARTUP_WAIT
                    if (enough_data and gap_triggered) or timed_out:
                        if not startup_bytes:
                            emit_error("claude produced no output within startup window")
                            return 1
                        write_all(b'\r')
                        enter_sent = True
                        last_chunk = time.monotonic()
                        _set_title(f'NEEDLE  {bead_id}  prompting…')
                elif not prompt_sent:
                    # Phase 2: REPL re-renders after CR; send real prompt once settled.
                    if gap >= _GAP_PROMPT:
                        time.sleep(0.15)
                        # Use bracketed paste mode so the REPL treats embedded \n as
                        # literal newlines (not Enter presses that would split the prompt
                        # into multiple turns).
                        payload = prompt.encode('utf-8')
                        write_all(b'\x1b[200~' + payload + b'\x1b[201~')
                        time.sleep(0.5)  # Let the REPL display the paste before submitting
                        write_all(b'\r')
                        prompt_sent = True
                        prompt_sent_time = time.monotonic()
                        last_data = time.monotonic()
                        _set_title(f'NEEDLE  {bead_id}  thinking…')
                elif response_started:
                    # Response has begun (saw ●); end after short idle.
                    if (now - last_data) >= IDLE_TIMEOUT:
                        break
                else:
                    # No ● yet; give claude up to PRE_RESP_TIMEOUT to start responding.
                    if (now - prompt_sent_time) >= PRE_RESP_TIMEOUT:
                        break

    finally:
        # Ask the REPL to quit, then make sure the child is gone and reaped.
        try:
            os.write(master, b'/exit\r')
            time.sleep(0.3)
        except OSError:
            pass
        _terminate_child(pid)
        try:
            os.close(master)
        except OSError:
            pass

    elapsed_ms = int((time.monotonic() - start) * 1000)
    elapsed_s = elapsed_ms // 1000

    if not prompt_sent:
        _set_title(f'NEEDLE  {bead_id}  ERROR: claude did not start')
        emit_error("claude did not start")
        return 1

    response_text = extract_response(bytes(all_buf))
    if response_text:
        _set_title(f'NEEDLE  {bead_id}  done  {elapsed_s}s')
        _tty_print(f'\033[1;32m── done: {bead_id}  {elapsed_s}s ──\033[0m')
    else:
        _set_title(f'NEEDLE  {bead_id}  WARN: empty response  {elapsed_s}s')
        _tty_print(f'\033[1;33m── warning: empty response for {bead_id} ──\033[0m')
    emit_success(response_text, model, elapsed_ms)
    return 0


# ── CLI ───────────────────────────────────────────────────────────────────────

def main() -> None:
    """Parse the command line, resolve the prompt and run one session.

    The prompt comes from the positional argument when it is non-empty,
    otherwise from stdin when stdin is not a TTY.  A missing or empty prompt
    emits an error record and exits with status 1; otherwise the process
    exits with run_session()'s return value.
    """
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('-p', '--print', dest='pipe_compat', action='store_true')
    parser.add_argument('--dangerously-skip-permissions', dest='skip_perms', action='store_true')
    parser.add_argument('--model', default='claude-sonnet-4-6')
    parser.add_argument('--max-turns', dest='max_turns', type=int, default=30)
    parser.add_argument('prompt', nargs='?')
    # Unrecognised arguments are tolerated and discarded.
    # NOTE(review): the value of an unrecognised option (e.g. `--flag value`) can
    # be picked up as the positional prompt — confirm callers never pass such options.
    args, _ignored = parser.parse_known_args()

    prompt = args.prompt
    if not prompt:
        if sys.stdin.isatty():
            emit_error("no prompt provided")
            sys.exit(1)
        prompt = sys.stdin.read().strip()

    if not prompt:
        emit_error("empty prompt")
        sys.exit(1)

    exit_code = run_session(prompt, args.model, args.max_turns)
    sys.exit(exit_code)


# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
