#!/usr/bin/env bash
# beagle-specfix: oracle-guided speculative fix tool.
#
# Given a build directory of compiled .clj files and an oracle (verify script
# with assert-eq assertions), automatically generates and validates candidate
# fixes for logic bugs that the type system cannot catch.
#
# Usage:
#   beagle-specfix <build-dir> <verify-script>
#
# Strategy:
#   1. Run the oracle, capture failing assertions with expected/actual values
#   2. For each numeric failure, compute the expected/actual ratio
#   3. Based on the ratio, generate candidate fixes (operator swaps, etc.)
#   4. Apply each candidate to a temp copy (metadata-stripped for matching),
#      rerun the full oracle
#   5. Report verified fixes that don't introduce regressions
#
# Output:
#   SPECFIX: <assertion-label>
#     file: <module>.clj
#     function: <fn-name>
#     fix: replace `<old>` with `<new>` at line N
#     confidence: <0-1>
#     verified: oracle passes with this fix applied

set -euo pipefail

if [[ $# -lt 2 ]]; then
    echo "Usage: beagle-specfix <build-dir> <verify-script>" >&2
    exit 1
fi

# Fail fast on a broken setup. The oracle pipeline in Phase 1 deliberately
# ignores its exit status (the pipeline ends in `|| true`), so without these
# checks a missing interpreter or a mistyped path would be reported as
# "No failures found" with exit status 0.
for required_tool in bb python3 realpath mktemp; do
    if ! command -v "$required_tool" > /dev/null 2>&1; then
        echo "beagle-specfix: required tool not found on PATH: $required_tool" >&2
        exit 1
    fi
done

if [[ ! -d "$1" ]]; then
    echo "beagle-specfix: build directory not found: $1" >&2
    exit 1
fi

if [[ ! -f "$2" ]]; then
    echo "beagle-specfix: verify script not found: $2" >&2
    exit 1
fi

BUILD_DIR="$(cd "$1" && pwd)"
VERIFY_SCRIPT="$(realpath "$2")"
# Honour TMPDIR when the caller sets it; fall back to /tmp otherwise.
WORK_DIR="$(mktemp -d "${TMPDIR:-/tmp}/specfix.XXXXXX")"

# Remove the scratch directory on every exit path.
trap 'rm -rf -- "$WORK_DIR"' EXIT

# ---------------------------------------------------------------------------
# Phase 1: Run oracle, capture failures
# ---------------------------------------------------------------------------

echo "=== Phase 1: Running oracle ===" >&2

FAILURES_FILE="$WORK_DIR/failures.txt"

# Keep each "FAIL:" header plus the two lines after it (expected / actual).
# `|| true` is intentional: grep exits 1 when nothing matches, and with
# pipefail a non-zero exit from bb would otherwise abort the script here.
bb -cp "$BUILD_DIR" -e "(load-file \"$VERIFY_SCRIPT\")" 2>&1 \
    | grep -A2 "^FAIL:" > "$FAILURES_FILE" || true

# grep -c prints 0 but exits 1 when there are no matches, hence the fallback.
FAIL_COUNT="$(grep -c "^FAIL:" "$FAILURES_FILE" 2>/dev/null)" || FAIL_COUNT=0
echo "  Found $FAIL_COUNT failing assertions" >&2

if [[ "$FAIL_COUNT" -eq 0 ]]; then
    echo "No failures found. Nothing to fix." >&2
    exit 0
fi

# ---------------------------------------------------------------------------
# Phase 2+3: Analyze, generate candidates, verify
# ---------------------------------------------------------------------------

# Inputs for the embedded Python program, which reads them from os.environ.
export SPECFIX_FAILURES="$FAILURES_FILE"
export SPECFIX_BUILD_DIR="$BUILD_DIR"
export SPECFIX_WORK_DIR="$WORK_DIR"
export SPECFIX_VERIFY="$VERIFY_SCRIPT"

python3 - << 'PYEOF'
import sys, os, re, subprocess, shutil, tempfile

# Paths handed over by the bash wrapper through exported SPECFIX_* variables.
# A missing variable raises KeyError, exactly as a direct lookup would.
failures_file, build_dir, work_dir, verify_script = (
    os.environ[var_name]
    for var_name in (
        'SPECFIX_FAILURES',
        'SPECFIX_BUILD_DIR',
        'SPECFIX_WORK_DIR',
        'SPECFIX_VERIFY',
    )
)

# One ^{...} annotation plus whatever whitespace directly follows it.
META_RE = re.compile(r'\^\{[^}]*\}\s*')

def strip_meta(text):
    """Return `text` with every ^{...} annotation (e.g. ^{:line N :file "..."})
    and the whitespace following it removed."""
    cleaned, _removed = META_RE.subn('', text)
    return cleaned

# ---------------------------------------------------------------------------
# Parse failures
# ---------------------------------------------------------------------------

# Each failure is a "FAIL: <label>" header, optionally followed by an
# indented "expected:" line and then an indented "actual:" line.
failures = []
with open(failures_file) as f:
    lines = f.readlines()

i = 0
while i < len(lines):
    m = re.match(r"^FAIL: (.+)", lines[i].rstrip())
    if not m:
        i += 1
        continue

    # Up to two detail lines belong to this header, but never a line that is
    # itself a FAIL header: blindly skipping three lines would swallow a
    # failure that directly follows one with missing expected/actual lines.
    detail = []
    for following in lines[i + 1:i + 3]:
        stripped = following.rstrip()
        if re.match(r"^FAIL: (.+)", stripped):
            break
        detail.append(stripped)

    expected = None
    actual = None
    if len(detail) >= 1:
        em = re.match(r"\s+expected:\s+(.+)", detail[0])
        if em:
            expected = em.group(1)
    if len(detail) >= 2:
        am = re.match(r"\s+actual:\s+(.+)", detail[1])
        if am:
            actual = am.group(1)

    failures.append({
        'label': m.group(1),
        'expected': expected,
        'actual': actual,
    })
    # Advance past the header and only the detail lines that belong to it.
    i += 1 + len(detail)

print(f"  Parsed {len(failures)} failures", file=sys.stderr)

# ---------------------------------------------------------------------------
# Module / function mapping
# ---------------------------------------------------------------------------

# Label prefix -> module name (the module's file is <module>.clj).
PREFIX_MAP = {
    'cat/': 'catalog',
    'cust/': 'customers',
    'inv/': 'inventory',
    'ord/': 'orders',
    'rep/': 'reports',
    'ship/': 'shipping',
    'bill/': 'billing',
    'proc/': 'procurement',
    'promo/': 'promotions',
    'emp/': 'employees',
    'analytics/': 'analytics',
    'notif/': 'notifications',
    'audit/': 'audit',
}

# Fallback keyword hints, tried in this order against the lower-cased label.
_KEYWORD_HINTS = (
    (('order',), 'orders'),
    (('ship',), 'shipping'),
    (('invoice', 'bill'), 'billing'),
    (('rep/', 'report'), 'reports'),
)

def label_to_module(label):
    """Map an assertion label to a module name, or None when nothing matches.

    A known prefix from PREFIX_MAP wins; otherwise the first keyword hint
    found anywhere in the lower-cased label decides.
    """
    by_prefix = next(
        (mod for pfx, mod in PREFIX_MAP.items() if label.startswith(pfx)),
        None,
    )
    if by_prefix is not None:
        return by_prefix

    lowered = label.lower()
    for needles, mod in _KEYWORD_HINTS:
        if any(needle in lowered for needle in needles):
            return mod
    return None

def label_to_function(label):
    """Return the name following a leading `<lowercase-prefix>/` in `label`.

    The name is the run of letters, digits, `_` and `-` right after the
    slash; None is returned when the label does not start that way.
    """
    found = re.match(r"[a-z]+/([a-zA-Z0-9_-]+)", label)
    return found.group(1) if found else None

# ---------------------------------------------------------------------------
# File reading with metadata handling
# ---------------------------------------------------------------------------

def read_module(module):
    """Load `<build_dir>/<module>.clj`.

    Returns a 3-tuple (path, raw_lines, clean_lines): the file path, the
    file content split on newlines, and the same lines with ^{...} metadata
    stripped from each one. Returns (None, None, None) when the file does
    not exist.
    """
    module_path = os.path.join(build_dir, f"{module}.clj")
    if os.path.exists(module_path):
        with open(module_path) as handle:
            source_lines = handle.read().split('\n')
        return module_path, source_lines, list(map(strip_meta, source_lines))
    return None, None, None

def find_function_range(clean_lines, fn_name):
    """Locate the body of `fn_name` in metadata-stripped source lines.

    Args:
        clean_lines: list of source lines with metadata already removed.
        fn_name: function name to look for, or None.

    Returns:
        (start, end) line indices, `end` exclusive. `start` is the first line
        containing `(defn <fn_name>`; `end` is the next line that begins
        (after leading whitespace) with `(defn ` or `(defrecord `, or the
        end of the file. When `fn_name` is None or no definition is found,
        the whole file `(0, len(clean_lines))` is returned.
    """
    total = len(clean_lines)
    if fn_name is None:
        return 0, total

    # The name must not continue with a word character or '-'. A plain `\b`
    # treats '-' as a boundary, so `order-total` would also match
    # `(defn order-total-with-tax`. A trailing `?`/`!` is still accepted so a
    # name that was truncated at such a character keeps resolving.
    header_re = re.compile(r'\(defn\s+' + re.escape(fn_name) + r'(?![\w-])')

    start = None
    for i, line in enumerate(clean_lines):
        if header_re.search(line):
            start = i
            break

    if start is None:
        return 0, total

    end = total
    for i in range(start + 1, total):
        cl = clean_lines[i].lstrip()
        if cl.startswith('(defn ') or cl.startswith('(defrecord '):
            end = i
            break

    return start, end

# ---------------------------------------------------------------------------
# Record accessor registry (parsed from compiled .clj files)
# ---------------------------------------------------------------------------

_record_accessor_cache = {}  # module -> {record_name: [accessors]}

def build_accessor_registry(module):
    """Return {record_name: [accessor_name, ...]} for one compiled module.

    Records come from single-line `(defrecord Name [...])` forms. An accessor
    is a single-line `(defn name [r] (:field r))` whose name starts with the
    record name (first letter lower-cased, or fully lower-cased) followed by
    `-`. Results are memoized per module in `_record_accessor_cache`; a
    module whose file is missing is cached as an empty registry.
    """
    try:
        return _record_accessor_cache[module]
    except KeyError:
        pass

    path, _raw_lines, clean_lines = read_module(module)
    if path is None:
        _record_accessor_cache[module] = {}
        return {}

    record_re = re.compile(r'\(defrecord\s+(\w+)\s+\[([^\]]*)\]\)')
    accessor_re = re.compile(r'\(defn\s+([\w?!<>*+\-/]+)\s+\[r\]\s+\(:(\w[\w-]*)\s+r\)\)')

    # First pass: every record with the two accessor-name prefixes it accepts.
    record_prefixes = []
    for text in clean_lines:
        hit = record_re.match(text)
        if hit:
            name = hit.group(1)
            record_prefixes.append(
                (name, (name[0].lower() + name[1:] + '-', name.lower() + '-')))

    # Second pass: attach each accessor to the first record whose prefix fits.
    registry = {}
    for text in clean_lines:
        hit = accessor_re.match(text)
        if not hit:
            continue
        accessor_name = hit.group(1)
        owner = next(
            (rec for rec, prefixes in record_prefixes
             if accessor_name.startswith(prefixes)),
            None,
        )
        if owner is not None:
            registry.setdefault(owner, []).append(accessor_name)

    _record_accessor_cache[module] = registry
    return registry

def build_accessor_registry_all():
    """Build (and cache) the accessor registry for every .clj file in build_dir.

    Files are visited in sorted order. os.listdir returns entries in
    arbitrary order, and the order in which modules enter the cache decides
    which module wins when an accessor name exists in more than one, so
    sorting keeps results reproducible between runs.
    """
    for fname in sorted(os.listdir(build_dir)):
        if fname.endswith('.clj'):
            build_accessor_registry(fname[:-4])

def find_record_for_accessor(accessor_name):
    """Look `accessor_name` up across every cached module registry.

    Returns (record_name, sibling_accessors, module) for the first record
    that lists the accessor, or (None, [], None) when no cached registry
    contains it.
    """
    hits = (
        (rec, accs, mod)
        for mod, reg in _record_accessor_cache.items()
        for rec, accs in reg.items()
        if accessor_name in accs
    )
    return next(hits, (None, [], None))

# ---------------------------------------------------------------------------
# Function parameter cache (for arg-swap detection)
# ---------------------------------------------------------------------------

_fn_params_cache = {}

def build_fn_params(module, clean_lines):
    """Return {fn_name: [param, ...]} for defns that take two or more params.

    Only `(defn name [params]` headers at the very start of a line are
    recognised. The result is memoized per module, so `clean_lines` is only
    consulted on the first call for a given module.
    """
    if module not in _fn_params_cache:
        header_re = re.compile(r'\(defn\s+([\w?!<>*+\-/]+)\s+\[([^\]]*)\]')
        collected = {}
        for hit in filter(None, map(header_re.match, clean_lines)):
            names = hit.group(2).split()
            if len(names) >= 2:
                collected[hit.group(1)] = names
        _fn_params_cache[module] = collected
    return _fn_params_cache[module]

# ---------------------------------------------------------------------------
# Balanced-paren argument extraction
# ---------------------------------------------------------------------------

def _skip_string(line, pos):
    """`pos` is at an opening double quote; return the index just past the
    closing quote (clamped to len(line) for an unterminated string)."""
    pos += 1
    while pos < len(line) and line[pos] != '"':
        if line[pos] == '\\':
            pos += 1  # skip the escaped character
        pos += 1
    return min(pos + 1, len(line))


def _skip_form(line, pos):
    """`pos` is at an opening (, [ or {; return the index just past the
    bracket that brings the nesting depth back to zero (len(line) when the
    form is not closed on this line)."""
    depth = 0
    n = len(line)
    while pos < n:
        ch = line[pos]
        if ch == '"':
            # Brackets inside a string literal must not affect nesting.
            pos = _skip_string(line, pos)
            continue
        if ch == '\\':
            # Character literal such as \( : skip the backslash and the char.
            pos += 2
            continue
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            depth -= 1
            if depth == 0:
                return pos + 1
        pos += 1
    return n


def extract_call_args(line, call_start):
    """Extract the arguments of the call whose `(` is at `line[call_start]`.

    Arguments are separated by whitespace. A parenthesised form, a vector
    `[...]`, a map `{...}` or a string literal counts as one argument even
    when it contains spaces or brackets.

    Returns:
        A list of argument source strings, or None when the call has no
        arguments.
    """
    n = len(line)

    # Skip the function name. Stopping at ')' keeps a no-argument call such
    # as `(f)` from running on into the text behind it.
    pos = call_start + 1
    while pos < n and not line[pos].isspace() and line[pos] != ')':
        pos += 1

    args = []
    while True:
        while pos < n and line[pos].isspace():
            pos += 1
        if pos >= n or line[pos] == ')':
            break
        arg_start = pos
        ch = line[pos]
        if ch in '([{':
            pos = _skip_form(line, pos)
        elif ch == '"':
            pos = _skip_string(line, pos)
        else:
            # Plain atom: runs up to whitespace or the call's closing paren.
            while pos < n and not line[pos].isspace() and line[pos] != ')':
                pos += 1
        args.append(line[arg_start:pos])
    return args if args else None

# ---------------------------------------------------------------------------
# Candidate generation
# ---------------------------------------------------------------------------

def compute_ratio(expected, actual):
    """Return actual / expected as a float.

    Returns None when either value cannot be converted with float() or when
    the expected value is zero.
    """
    try:
        denominator, numerator = float(expected), float(actual)
    except (TypeError, ValueError):
        return None
    return numerator / denominator if denominator != 0 else None

def generate_candidates(failure):
    """Generate candidate fixes for one failing assertion.

    Args:
        failure: dict with 'label', 'expected' and 'actual'; the latter two
            are strings or None.

    Returns:
        A list of candidate dicts (possibly empty). Each has the keys
        'label', 'module', 'fn_name', 'line' (1-based), 'line_idx' (0-based),
        'search', 'replace', 'description', 'confidence' and 'type'.
        'search' and 'replace' are literal text relative to the
        metadata-stripped line. The list is empty when the label maps to no
        module or the module file does not exist.
    """
    label = failure['label']
    expected = failure['expected']
    actual = failure['actual']
    module = label_to_module(label)
    fn_name = label_to_function(label)

    if module is None:
        return []

    path, raw_lines, clean_lines = read_module(module)
    if path is None:
        return []

    ratio = compute_ratio(expected, actual)
    candidates = []
    start, end = find_function_range(clean_lines, fn_name)

    def add(line_idx, search, replace, description, confidence, kind, in_fn=fn_name):
        """Append one candidate; `in_fn` overrides the function it is filed under."""
        candidates.append({
            'label': label,
            'module': module,
            'fn_name': in_fn,
            'line': line_idx + 1,
            'line_idx': line_idx,
            'search': search,
            'replace': replace,
            'description': description,
            'confidence': confidence,
            'type': kind,
        })

    # True when the failure carries a usable expected/actual mismatch.
    mismatch = ratio is not None or (
        actual is not None and expected is not None and actual != expected)

    # Strategy 1: ratio == -1 -> swap operands in subtraction
    if ratio is not None and abs(ratio - (-1.0)) < 0.01:
        for i in range(start, end):
            cl = clean_lines[i]
            # Match (- (expr1) (expr2))
            for m in re.finditer(r'\(-\s+(\([^)]+\))\s+(\([^)]+\))\)', cl):
                arg1 = m.group(1)
                arg2 = m.group(2)
                add(i,
                    f'(- {arg1} {arg2})',
                    f'(- {arg2} {arg1})',
                    f'swap operands: `(- {arg1} {arg2})` -> `(- {arg2} {arg1})`',
                    0.85, 'operand-swap')

    # Strategy 2: ratio > 9 -> try + to * (value way too large)
    if ratio is not None and ratio > 9:
        for i in range(start, end):
            cl = clean_lines[i]
            for m in re.finditer(r'\(\+\s+(\S+)\s+(\S+|\([^)]+\))\)', cl):
                arg1 = m.group(1)
                arg2 = m.group(2)
                add(i,
                    f'(+ {arg1} {arg2})',
                    f'(* {arg1} {arg2})',
                    f'operator: `(+ {arg1} {arg2})` -> `(* {arg1} {arg2})`',
                    0.75, 'plus-to-times')

    # Strategy 3: actual == 0, expected != 0 -> wrong filter/accessor
    if actual == '0' and expected != '0':
        for i in range(start, end):
            cl = clean_lines[i]
            # Wrong string arg where variable expected. The search text is
            # built from the function name that actually matched (including
            # any hyphenated prefix), so it can be found on the line again.
            m_str = re.search(r'([\w-]+-for-order)\s+(\w+)\s+"(\w+)"', cl)
            if m_str:
                func_called = m_str.group(1)
                coll_arg = m_str.group(2)
                string_arg = m_str.group(3)
                add(i,
                    f'({func_called} {coll_arg} "{string_arg}")',
                    f'({func_called} {coll_arg} order-id)',
                    f'wrong argument: `"{string_arg}"` should be `order-id`',
                    0.80, 'wrong-arg')

    # Strategy 3b: wrong accessor in arithmetic (not gated on actual==0)
    # Detect (something-id x) used in (+ ... ) context where a cost/rate accessor exists
    if ratio is not None and 0 < ratio < 1:
        raw_content_full = '\n'.join(raw_lines)
        for i in range(start, end):
            cl = clean_lines[i]
            # Look for (+ (something-id x) ...) — ID used where cost/rate expected
            for m in re.finditer(r'\((\w+-id)\s+(\w+)\)', cl):
                accessor = m.group(1)
                arg = m.group(2)
                # Check this is inside an arithmetic expression
                if not re.search(r'\(\+\s.*' + re.escape(m.group(0)), cl):
                    continue
                prefix = accessor.rsplit('-id', 1)[0]
                for alt_suffix in ['base-rate', 'base-fee', 'cost', 'rate', 'amount', 'fee']:
                    alt_accessor = f'{prefix}-{alt_suffix}'
                    # Exact match: (defn name followed by space or [
                    if re.search(r'\(defn ' + re.escape(alt_accessor) + r'[\s\[]', raw_content_full):
                        add(i,
                            f'({accessor} {arg})',
                            f'({alt_accessor} {arg})',
                            f'wrong accessor: `({accessor} {arg})` -> `({alt_accessor} {arg})`',
                            0.80, 'wrong-accessor')

    # Strategy 4: percentage/factor bugs
    if ratio is not None and abs(ratio - 0.1) < 0.02:
        for i in range(start, end):
            cl = clean_lines[i]
            # Keep the delimiters around the literal in the search text so the
            # replacement cannot land inside another number (e.g. 100, 210).
            m_ten = re.search(r' 10[ )]', cl)
            if m_ten:
                token = m_ten.group(0)
                add(i,
                    token,
                    ' 100' + token[-1],
                    'factor: `10` -> `100`',
                    0.70, 'factor-fix')

    # Strategy 5: + where * needed (surcharge calculations: 0 < ratio < 0.7)
    # e.g., (+ cost pct) / 100 should be (* cost pct) / 100
    if ratio is not None and 0 < ratio < 0.7:
        for i in range(start, end):
            cl = clean_lines[i]
            for m in re.finditer(r'\(\+\s+(\S+)\s+(\([^)]+\)|\S+)\)', cl):
                arg1 = m.group(1)
                arg2 = m.group(2)
                add(i,
                    f'(+ {arg1} {arg2})',
                    f'(* {arg1} {arg2})',
                    f'operator: `(+ {arg1} {arg2})` -> `(* {arg1} {arg2})`',
                    0.75, 'plus-to-times')

    # Strategy 6: wrong comparator (> vs <) — actual==0 or boolean inversion
    # For filterv predicates where the comparison is inverted
    if actual == '0' and expected != '0':
        # Look for (> ...) in this function or in functions it calls
        # (within the same module).
        # First: collect the names of functions called by fn_name
        called_fns = set()
        for i in range(start, end):
            cl = clean_lines[i]
            for m in re.finditer(r'\(([a-z][\w-]*)\s', cl):
                called_fns.add(m.group(1))

        # For each called function in this module, look for comparator flips
        for called_fn in called_fns:
            c_start, c_end = find_function_range(clean_lines, called_fn)
            # A whole-file range is what find_function_range returns for an
            # unknown name, so treat it as "not defined in this module".
            if c_start == 0 and c_end == len(clean_lines) and called_fn != fn_name:
                continue  # not found
            for i in range(c_start, c_end):
                cl = clean_lines[i]
                # Look for (> X Y) that could be (< X Y)
                for m in re.finditer(r'\(>\s+(\([^)]+\)|\S+)\s+(\([^)]+\)|\S+)\)', cl):
                    arg1 = m.group(1)
                    arg2 = m.group(2)
                    add(i,
                        f'(> {arg1} {arg2})',
                        f'(< {arg1} {arg2})',
                        f'comparator flip in {called_fn}: `(> {arg1} {arg2})` -> `(< {arg1} {arg2})`',
                        0.75, 'comparator-flip', in_fn=called_fn)

    # Strategy 7: swapped cond values
    # Detect: two failures in same fn with reciprocal expected/actual
    # e.g., gold expects 15, gets 5; bronze expects 5, gets 15
    # This needs more than one failure, so it is not generated per failure here.

    # Strategy 8: accessor swap detection
    # When a wrong value is produced, look for accessor calls (record-field var)
    # in the function body and try swapping for other accessors on the same record.
    # Uses the compiled-file accessor registry built from defrecord/defn patterns.
    if mismatch:
        # Make sure the registries consulted by find_record_for_accessor exist.
        build_accessor_registry(module)
        for other_mod in PREFIX_MAP.values():
            build_accessor_registry(other_mod)

        NUMERIC_SUFFIXES = {'id', 'cost', 'rate', 'price', 'total', 'amount',
                           'fee', 'quantity', 'count', 'pct', 'days', 'value',
                           'min-quantity', 'subtotal', 'tax', 'discount',
                           'weight', 'weight-kg', 'line-total', 'surcharge-pct',
                           'per-kg-rate', 'base-rate', 'base-fee', 'lead-time-days',
                           'hire-date', 'commission-pct', 'balance', 'revenue',
                           'shipping-cost', 'created-at', 'delivered-at',
                           'issued-at', 'due-at', 'processed-at', 'salary'}
        STRING_SUFFIXES = {'name', 'sku', 'status', 'email', 'phone', 'city',
                          'address', 'tracking-number', 'description', 'segment',
                          'tier', 'type', 'zone-name', 'title'}
        BOOL_SUFFIXES = {'active', 'active?'}

        def accessor_suffix(acc):
            """Everything after the first '-' (the whole name if there is none)."""
            parts = acc.split('-')
            return '-'.join(parts[1:]) if len(parts) >= 2 else acc

        def same_type_group(s1, s2):
            """True when both suffixes fall in the same numeric/string/bool group."""
            s1n = s1 in NUMERIC_SUFFIXES or any(s1.endswith(s) for s in NUMERIC_SUFFIXES)
            s2n = s2 in NUMERIC_SUFFIXES or any(s2.endswith(s) for s in NUMERIC_SUFFIXES)
            s1s = s1 in STRING_SUFFIXES or any(s1.endswith(s) for s in STRING_SUFFIXES)
            s2s = s2 in STRING_SUFFIXES or any(s2.endswith(s) for s in STRING_SUFFIXES)
            if s1n and s2n: return True
            if s1s and s2s: return True
            if s1 in BOOL_SUFFIXES and s2 in BOOL_SUFFIXES: return True
            return False

        for i in range(start, end):
            cl = clean_lines[i]
            in_arith = bool(re.search(r'\([+\-*/><]=?\s', cl))
            in_filter = bool(re.search(r'(filterv|filter|=|not=)', cl))
            in_computation = in_arith or in_filter

            for m in re.finditer(r'\(([\w-]+)\s+(\w+)\)', cl):
                accessor = m.group(1)
                arg = m.group(2)
                rec_name, siblings, acc_module = find_record_for_accessor(accessor)
                if rec_name is None or len(siblings) < 2:
                    continue
                old_suffix = accessor_suffix(accessor)
                for alt in siblings:
                    if alt == accessor:
                        continue
                    # Skip siblings that are already called on this line.
                    if f'({alt} ' in cl:
                        continue
                    alt_suffix = accessor_suffix(alt)
                    if not same_type_group(old_suffix, alt_suffix):
                        continue
                    conf = 0.75 if in_computation else 0.65
                    add(i,
                        f'({accessor} {arg})',
                        f'({alt} {arg})',
                        f'accessor swap ({rec_name}): `({accessor} {arg})` -> `({alt} {arg})`',
                        conf, 'accessor-swap')

    # Strategy 9: wrong-argument detection (argument permutation)
    # When a function call has correct arity but wrong output, try swapping
    # pairs of arguments.  Only for functions with 2-5 params.
    if mismatch:
        fn_params = build_fn_params(module, clean_lines)
        for i in range(start, end):
            cl = clean_lines[i]
            for called_fn_name, params in fn_params.items():
                if called_fn_name == fn_name:
                    continue
                nparams = len(params)
                if nparams < 2 or nparams > 5:
                    continue
                call_pattern = r'\(' + re.escape(called_fn_name) + r'\s+'
                m_call = re.search(call_pattern, cl)
                if not m_call:
                    continue
                call_args = extract_call_args(cl, m_call.start())
                if call_args is None or len(call_args) != nparams:
                    continue
                for a_idx in range(len(call_args)):
                    for b_idx in range(a_idx + 1, len(call_args)):
                        if call_args[a_idx] == call_args[b_idx]:
                            continue
                        swapped = list(call_args)
                        swapped[a_idx], swapped[b_idx] = swapped[b_idx], swapped[a_idx]
                        orig_call = f'({called_fn_name} {" ".join(call_args)})'
                        new_call = f'({called_fn_name} {" ".join(swapped)})'
                        # Only usable when the single-space rendering is
                        # literally present on the line.
                        if orig_call in cl:
                            add(i,
                                orig_call,
                                new_call,
                                f'arg swap in {called_fn_name}: swap arg {a_idx+1} <-> {b_idx+1}',
                                0.65, 'arg-swap')

    return candidates

# ---------------------------------------------------------------------------
# Apply fix to a file (handles metadata-annotated code)
# ---------------------------------------------------------------------------

# A ^{...} annotation without the whitespace that follows it.
_ANNOTATION_RE = re.compile(r'\^\{[^}]*\}')


def _first_match_outside_meta(pattern, text):
    """Return the first match of regex `pattern` in `text` that does not
    start inside a ^{...} annotation, or None when there is none."""
    meta_spans = [m.span() for m in _ANNOTATION_RE.finditer(text)]
    for m in re.finditer(pattern, text):
        if not any(s <= m.start() < e for s, e in meta_spans):
            return m
    return None


def _splice(text, match, replacement):
    """Return `text` with the span of `match` replaced by literal `replacement`."""
    return text[:match.start()] + replacement + text[match.end():]


def apply_fix_to_file(src_path, dst_path, candidate):
    """Apply a candidate fix to `src_path` and write the result to `dst_path`.

    Uses line_idx to target the correct line, avoiding accidental matches
    elsewhere in the file. Text inside ^{...} metadata annotations is never
    chosen as the replacement site, so a literal such as `10` cannot be
    "fixed" inside `^{:line 10}` instead of in the code.

    Args:
        src_path: file to read.
        dst_path: file to write; only written when the fix could be applied.
        candidate: candidate dict. 'value-swap' candidates carry
            'swap_lines' and 'swap_values'; every other type carries
            'line_idx', 'search' and 'replace'.

    Returns:
        True when the fix was applied and written, False otherwise.
    """
    with open(src_path) as f:
        raw_lines = f.readlines()

    def write_raw_lines():
        with open(dst_path, 'w') as out:
            out.writelines(raw_lines)

    # Special case: value-swap (swap two values on different lines)
    if candidate.get('type') == 'value-swap':
        line1_idx, line2_idx = candidate['swap_lines']
        val1, val2 = candidate['swap_values']
        if not (0 <= line1_idx < len(raw_lines) and 0 <= line2_idx < len(raw_lines)):
            return False
        # The lookarounds keep the value from matching inside a longer number.
        m1 = _first_match_outside_meta(
            r'(?<![0-9])' + re.escape(val1) + r'(?![0-9])', raw_lines[line1_idx])
        m2 = _first_match_outside_meta(
            r'(?<![0-9])' + re.escape(val2) + r'(?![0-9])', raw_lines[line2_idx])
        # Both halves must be found; a half-applied swap is not this fix.
        if m1 is None or m2 is None:
            return False
        if line1_idx != line2_idx:
            raw_lines[line1_idx] = _splice(raw_lines[line1_idx], m1, val2)
            raw_lines[line2_idx] = _splice(raw_lines[line2_idx], m2, val1)
        else:
            if m1.start() < m2.end() and m2.start() < m1.end():
                return False  # both values resolve to the same text
            # Splice right-to-left so the earlier span's offsets stay valid.
            line = raw_lines[line1_idx]
            for m, new_text in sorted(((m1, val2), (m2, val1)),
                                      key=lambda pair: pair[0].start(),
                                      reverse=True):
                line = _splice(line, m, new_text)
            raw_lines[line1_idx] = line
        write_raw_lines()
        return True

    target_line_idx = candidate['line_idx']
    search_text = candidate['search']
    replace_text = candidate['replace']

    if 0 <= target_line_idx < len(raw_lines):
        raw_line = raw_lines[target_line_idx]

        # Strategy A: exact match on the target line, outside metadata
        m = _first_match_outside_meta(re.escape(search_text), raw_line)
        if m:
            raw_lines[target_line_idx] = _splice(raw_line, m, replace_text)
            write_raw_lines()
            return True

        # Strategy B: metadata-aware replacement on the target line
        # The search_text is what the CLEAN line looks like; the raw line has metadata
        if search_text in strip_meta(raw_line):
            # Build a flexible regex that allows metadata between tokens
            meta_opt = r'(?:\^\{[^}]*\}\s*)?'
            escaped = re.escape(search_text)
            # Allow metadata between spaces
            flexible_pattern = escaped.replace(r'\ ', r'\s+' + meta_opt)
            # Allow metadata after opening parens
            flexible_pattern = flexible_pattern.replace(r'\(', r'\(' + meta_opt)

            try:
                m = _first_match_outside_meta(flexible_pattern, raw_line)
            except re.error:
                m = None
            if m:
                # Annotations inside the matched span are dropped with it.
                raw_lines[target_line_idx] = _splice(raw_line, m, replace_text)
                write_raw_lines()
                return True

    # Strategy C: strip all metadata from the entire file, do replacement
    # (last resort — loses metadata but preserves logic)
    raw_content = ''.join(raw_lines)
    stripped_content = strip_meta(raw_content)
    if search_text in stripped_content:
        # To avoid wrong-line hits, use the clean content split by lines
        clean_lines_all = stripped_content.split('\n')
        if 0 <= target_line_idx < len(clean_lines_all):
            cl = clean_lines_all[target_line_idx]
            if search_text in cl:
                clean_lines_all[target_line_idx] = cl.replace(search_text, replace_text, 1)
                with open(dst_path, 'w') as f:
                    f.write('\n'.join(clean_lines_all))
                return True
        # If not on exact line, try nearby lines. Whole-file stripping can
        # also remove newlines that follow an annotation, which shifts line
        # indices relative to the raw file.
        for offset in range(-3, 4):
            idx = target_line_idx + offset
            if 0 <= idx < len(clean_lines_all) and search_text in clean_lines_all[idx]:
                clean_lines_all[idx] = clean_lines_all[idx].replace(search_text, replace_text, 1)
                with open(dst_path, 'w') as f:
                    f.write('\n'.join(clean_lines_all))
                return True

    return False

# ---------------------------------------------------------------------------
# Oracle runner
# ---------------------------------------------------------------------------

def run_oracle(build_dir_path, verify_path):
    """Run the oracle with `bb` and collect the labels of failing assertions.

    Args:
        build_dir_path: directory passed to `bb -cp`.
        verify_path: path of the verify script to load.

    Returns:
        A set of labels taken from `FAIL: <label>` output lines (empty when
        nothing failed), or None when the oracle could not be run, timed out
        after 120 seconds, or exited non-zero without printing any line
        containing "passed".
    """
    # The path is embedded in a Clojure string literal, so backslashes and
    # double quotes have to be escaped to keep the expression well-formed.
    clj_path = str(verify_path).replace('\\', '\\\\').replace('"', '\\"')
    try:
        result = subprocess.run(
            ['bb', '-cp', build_dir_path,
             '-e', f'(load-file "{clj_path}")'],
            capture_output=True, text=True, timeout=120
        )
        output = result.stdout + result.stderr
        failing = set()
        passed = False
        for line in output.split('\n'):
            m_f = re.match(r'^FAIL: (.+)', line)
            if m_f:
                failing.add(m_f.group(1))
            if 'passed' in line:
                passed = True
        # A non-zero exit is only accepted when a "passed" line was printed.
        if result.returncode != 0 and not passed:
            print(f"    Oracle crashed (exit {result.returncode})", file=sys.stderr)
            return None
        return failing
    except Exception as e:
        # Covers subprocess.TimeoutExpired and a missing `bb` executable.
        print(f"    Error running oracle: {e}", file=sys.stderr)
        return None

# ---------------------------------------------------------------------------
# Phase 2: Generate candidates
# ---------------------------------------------------------------------------

print("=== Phase 2: Generating candidates ===", file=sys.stderr)

# Populate the accessor registry cache up front, then report its size.
build_accessor_registry_all()
_n_recs = 0
_n_accs = 0
for _registry in _record_accessor_cache.values():
    _n_recs += len(_registry)
    _n_accs += sum(map(len, _registry.values()))
print(f"  Accessor registry: {_n_recs} records, {_n_accs} accessors", file=sys.stderr)

# Per-failure candidates, kept in failure order.
all_candidates = []
for failure in failures:
    all_candidates += generate_candidates(failure)

# Cross-failure strategy: swapped cond values.
# Two failures that resolve to the same function and whose expected/actual
# values mirror each other suggest two return values were exchanged, e.g.
#   "cust/tier-discount-pct gold"   expects 15, gets 5
#   "cust/tier-discount-pct bronze" expects 5, gets 15
from itertools import combinations


def _first_trailing_value_line(lines, lo, hi, value):
    """Return the first index in [lo, hi) whose line ends with ') value'.

    The value must follow a closing paren and whitespace and be the last
    token on the line (trailing whitespace allowed).  Returns None when no
    line in the range matches.
    """
    return next(
        (i for i in range(lo, hi)
         if re.search(r'\)\s+' + re.escape(value) + r'\s*$', lines[i])),
        None,
    )


# Bucket failures by the (module, function) pair derived from their label;
# labels that do not resolve to both parts are left out.
by_function = {}
for failure_rec in failures:
    fn = label_to_function(failure_rec['label'])
    mod = label_to_module(failure_rec['label'])
    if not (fn and mod):
        continue
    by_function.setdefault((mod, fn), []).append(failure_rec)

for (mod, fn), group in by_function.items():
    if len(group) < 2:
        continue
    for f1, f2 in combinations(group, 2):
        # Only pairs whose expected/actual values are exact mirrors (and
        # whose expected values are present) qualify.
        mirrored = (f1['expected'] == f2['actual'] and f1['actual'] == f2['expected']
                    and f1['expected'] is not None and f2['expected'] is not None)
        if not mirrored:
            continue

        path, raw_lines, clean_lines = read_module(mod)
        if path is None:
            continue
        start, end = find_function_range(clean_lines, fn)
        val1, val2 = f1['expected'], f2['expected']

        # Locate, inside the function, the first line ending in each value.
        line_with_val1 = _first_trailing_value_line(clean_lines, start, end, val1)
        line_with_val2 = _first_trailing_value_line(clean_lines, start, end, val2)
        if line_with_val1 is None or line_with_val2 is None:
            continue
        if line_with_val1 == line_with_val2:
            continue

        # Two-line swap candidate; search/replace carry a special marker
        # instead of literal text, the real data is in swap_lines/swap_values.
        swap_candidate = {
            'label': f1['label'],
            'module': mod,
            'fn_name': fn,
            'line': line_with_val1 + 1,
            'line_idx': line_with_val1,
            'search': f'__SWAP__{val1}__{val2}__',
            'replace': f'__SWAP__{val2}__{val1}__',
            'description': f'swap values in cond: `{val1}` <-> `{val2}`',
            'confidence': 0.80,
            'type': 'value-swap',
            'swap_lines': (line_with_val1, line_with_val2),
            'swap_values': (val1, val2),
        }
        all_candidates.append(swap_candidate)

# Deduplicate by (module, search, replace, swap_lines), keeping the first
# occurrence so candidate order is preserved.
# Value-swap candidates use a synthetic search/replace marker built only from
# the two values, so without swap_lines two different swaps of the same
# values in one module would collide and the second would never be verified.
# Candidates without swap_lines contribute None there, so their key is
# effectively the original (module, search, replace).
seen = set()
unique_candidates = []
for c in all_candidates:
    key = (c['module'], c['search'], c['replace'], c.get('swap_lines'))
    if key not in seen:
        seen.add(key)
        unique_candidates.append(c)

print(f"  Generated {len(all_candidates)} total, {len(unique_candidates)} unique candidates", file=sys.stderr)

# ---------------------------------------------------------------------------
# Phase 3: Verify candidates
# ---------------------------------------------------------------------------

print("=== Phase 3: Verifying candidates ===", file=sys.stderr)

# Labels failing before any fix is applied; anything outside this set that
# fails after a fix counts as a regression.
baseline_failures = set(f['label'] for f in failures)
print(f"  Baseline: {len(baseline_failures)} failures", file=sys.stderr)

verified_fixes = []

for idx, candidate in enumerate(unique_candidates):
    desc = candidate['description']
    print(f"  [{idx+1}/{len(unique_candidates)}] {candidate['module']}/{candidate.get('fn_name','?')}: {desc}", file=sys.stderr)

    # Work on a throwaway copy of the build's .clj files with the fix applied.
    temp_dir = tempfile.mkdtemp(dir=work_dir)
    try:
        for entry in os.listdir(build_dir):
            if entry.endswith('.clj'):
                shutil.copy2(os.path.join(build_dir, entry), temp_dir)

        module_file = f"{candidate['module']}.clj"
        src_path = os.path.join(build_dir, module_file)
        dst_path = os.path.join(temp_dir, module_file)

        applied = apply_fix_to_file(src_path, dst_path, candidate)
        if not applied:
            print("    SKIP: pattern not found in file", file=sys.stderr)
            continue

        # Rerun the full oracle against the patched copy.
        new_failures = run_oracle(temp_dir, verify_script)
    finally:
        # Remove the copy right away (also on SKIP or an unexpected error) so
        # disk usage stays bounded by one copy instead of growing with the
        # number of candidates.  Nothing after this point reads temp_dir.
        shutil.rmtree(temp_dir, ignore_errors=True)

    if new_failures is None:
        print("    ERROR: oracle crashed", file=sys.stderr)
        continue

    target_label = candidate['label']
    target_fixed = target_label not in new_failures
    regressions = new_failures - baseline_failures
    fixes_count = len(baseline_failures - new_failures)

    # A fix is accepted only when its own target assertion now passes and no
    # previously passing assertion started failing.
    if target_fixed and len(regressions) == 0:
        candidate['verified'] = True
        candidate['fixes_count'] = fixes_count
        candidate['remaining_failures'] = len(new_failures)
        verified_fixes.append(candidate)
        print(f"    VERIFIED: fixes {fixes_count} assertion(s), 0 regressions, {len(new_failures)} remaining", file=sys.stderr)
    elif target_fixed and len(regressions) > 0:
        print(f"    REJECTED: target fixed but {len(regressions)} regression(s): {list(regressions)[:3]}", file=sys.stderr)
    else:
        print(f"    FAILED: target still failing ({len(new_failures)} total failures)", file=sys.stderr)

# ---------------------------------------------------------------------------
# Phase 4: Output
# ---------------------------------------------------------------------------

# Progress goes to stderr: a blank separator line, then the result headline.
print(file=sys.stderr)
print(f"=== Results: {len(verified_fixes)} verified fix(es) ===", file=sys.stderr)

# Nothing verified: say so on stdout and finish successfully.
if not verified_fixes:
    print("No verified fixes found.")
    sys.exit(0)


def _fix_rank(entry):
    """Sort key: most assertions fixed first, then highest confidence."""
    return (-entry['fixes_count'], -entry['confidence'])


verified_fixes.sort(key=_fix_rank)

# One report stanza per verified fix on stdout, each followed by a blank line.
for fix in verified_fixes:
    for report_line in (
        f"SPECFIX: {fix['label']}",
        f"  file: {fix['module']}.clj",
        f"  function: {fix.get('fn_name', '(unknown)')}",
        f"  fix: {fix['description']}",
        f"  line: {fix['line']}",
        f"  confidence: {fix['confidence']}",
        "  verified: oracle passes with this fix applied",
        f"  assertions-fixed: {fix['fixes_count']}",
        f"  remaining-failures: {fix['remaining_failures']}",
    ):
        print(report_line)

    # Value swaps report the two line indices and values; every other fix
    # with search/replace text reports that text and its line index.
    swap_lines = fix.get('swap_lines')
    swap_values = fix.get('swap_values')
    if fix.get('type') == 'value-swap' and swap_lines and swap_values:
        print(f"  swap-lines: {swap_lines[0]},{swap_lines[1]}")
        print(f"  swap-values: {swap_values[0]}|||{swap_values[1]}")
    elif fix.get('search') and fix.get('replace'):
        print(f"  search: {fix['search']}")
        print(f"  replace: {fix['replace']}")
        print(f"  line-idx: {fix.get('line_idx', '?')}")
    print()

# Summary: number of fixes, distinct modules touched, distinct fix types.
modules_fixed = len({entry['module'] for entry in verified_fixes})
bug_types = sorted({entry['type'] for entry in verified_fixes})
print("---")
print(f"Summary: {len(verified_fixes)} verified fixes across {modules_fixed} modules")
print(f"  Bug types found: {', '.join(bug_types)}")

PYEOF
