#!/usr/bin/env bash
# beagle-cascade: cross-module impact propagation and predictive blame — GRAPH-NATIVE.
#
# The call graph is derived from the Fram claim graph (via beagle-callgraph):
# every call binds the defn in its OWN module (scope-correct), so same-named
# functions across modules never collide. This replaces the old regex call-graph
# scraping, which was structurally incapable of telling two `helper`s apart — it
# matched the bare text and merged them, corrupting blast radius across modules.
#
# Given a set of modified functions, predicts which assertions will fail by walking
# the call graph transitively (closure computed by Fram Datalog). Also identifies
# cascade patterns where fixing one root bug eliminates multiple downstream failures.
#
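# Usage:
#   beagle-cascade <source-dir> <verify-script> [--modified fn1,fn2,...]
#   beagle-cascade <source-dir> <verify-script> --from-failures
#
# With neither option, prints a call-graph summary ranked by blast radius.
# --from-failures builds the sources and runs <verify-script> under bb as the
# oracle; if both options are given, --from-failures takes precedence.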
#
# fn may be bare ('helper') or module-qualified ('mod_a/helper'). Qualify to target
# one of several same-named defns; a bare name that collides is analyzed per-module.

set -o errexit -o nounset -o pipefail

# Both positionals (<source-dir> <verify-script>) are mandatory; print the
# usage text on stderr and stop before touching anything else.
if (( $# < 2 )); then
    {
        echo "Usage: beagle-cascade <source-dir> <verify-script> [--modified fn1,...] [--from-failures]"
        echo "  fn may be bare ('helper') or module-qualified ('mod_a/helper')."
    } >&2
    exit 1
fi

# die MSG... — print "beagle-cascade: MSG" on stderr and exit 1.
die() {
    printf 'beagle-cascade: %s\n' "$*" >&2
    exit 1
}

# Resolve both inputs to absolute paths up front: the Python phase and the
# bb oracle run them from other working directories.
SOURCE_DIR="$(cd "$1" && pwd)"
VERIFY_SCRIPT="$(realpath "$2")"
# realpath may accept a missing final component (GNU's default), so check the
# file explicitly instead of letting the Python phase crash with a traceback.
[[ -f "$VERIFY_SCRIPT" ]] || die "verify script not found: $2"
shift 2

MODIFIED_FNS=""
FROM_FAILURES=false

while [[ $# -gt 0 ]]; do
    case "$1" in
        --modified)
            # Under `set -u` a trailing bare --modified would otherwise abort
            # with an opaque "$2: unbound variable".
            [[ $# -ge 2 ]] || die "--modified requires a comma-separated list of functions"
            MODIFIED_FNS="$2"
            shift 2
            ;;
        --from-failures) FROM_FAILURES=true; shift ;;
        *) echo "Unknown option: $1" >&2; exit 1 ;;
    esac
done

# Sibling beagle-* tools are looked up next to this script.
BEAGLE_BIN="$(cd "$(dirname "$0")" && pwd)"
# Honour $TMPDIR (falling back to /tmp) instead of hard-coding /tmp.
WORK_DIR="$(mktemp -d "${TMPDIR:-/tmp}/beagle-cascade.XXXXXX")"
trap 'rm -rf -- "$WORK_DIR"' EXIT

# ---------------------------------------------------------------------------
# Phase 1: Build the call graph — from the Fram claim graph, not text.
# beagle-callgraph emits the JSON (defns / edges / blast) that Phase 2 loads.
# ---------------------------------------------------------------------------
GRAPH_JSON="${WORK_DIR}/callgraph.json"
printf '%s\n' "─── Phase 1: Building call graph (graph-native, scope-correct) ───" >&2
"${BEAGLE_BIN}/beagle-callgraph" "${SOURCE_DIR}" >"${GRAPH_JSON}"

# Phase 2 (everything below, up to PYEOF) runs as a Python program read from
# stdin. Its inputs are passed positionally and read back as sys.argv[1..7] in
# exactly this order. The quoted 'PYEOF' delimiter keeps the shell from
# expanding $… or backticks inside the Python source.
python3 - "$SOURCE_DIR" "$BEAGLE_BIN" "$WORK_DIR" "$VERIFY_SCRIPT" "$MODIFIED_FNS" "$FROM_FAILURES" "$GRAPH_JSON" << 'PYEOF'
import sys, os, re, subprocess, json
from collections import defaultdict

# Positional arguments, in the order the bash wrapper passes them.
source_dir       = sys.argv[1]   # absolute path (shell resolved it with cd && pwd)
beagle_bin       = sys.argv[2]   # directory of this script; holds the beagle-* tools
work_dir         = sys.argv[3]   # mktemp scratch dir, removed by the shell's EXIT trap
verify_script    = sys.argv[4]   # absolute path (shell resolved it with realpath)
modified_fns_str = sys.argv[5]   # raw --modified value; '' when the option was absent
from_failures    = sys.argv[6] == 'true'   # shell passes the literal string true/false
graph_path       = sys.argv[7]   # JSON written by beagle-callgraph in Phase 1

# ---------------------------------------------------------------------------
# Load the scope-correct call graph emitted by beagle-callgraph and index it.
#   defns: [{key, file, module, name}]          edges: [[caller_key, callee_key]]
#   blast: {callee_key: [transitive caller keys]}   (closure by Fram Datalog)
# ---------------------------------------------------------------------------
with open(graph_path) as graph_file:
    graph = json.load(graph_file)

defns, edges, blast = graph['defns'], graph['edges'], graph['blast']

# Lookup tables, filled in one pass over the defns:
#   meta     key -> full defn record (module, name, file, key)
#   by_name  bare name -> every key with that name (cross-module collisions kept)
#   by_qual  "module/name" -> key
meta = {}
by_name = defaultdict(list)
by_qual = {}
for defn in defns:
    defn_key = defn['key']
    meta[defn_key] = defn
    by_name[defn['name']].append(defn_key)
    by_qual[f"{defn['module']}/{defn['name']}"] = defn_key

# Direct call adjacency in both directions.
# NOTE(review): fwd is not read anywhere else in this script — drop if unneeded.
fwd = defaultdict(set)          # caller -> {direct callees}
rev_direct = defaultdict(set)   # callee -> {direct callers}
for caller, callee in edges:
    fwd[caller].add(callee)
    rev_direct[callee].add(caller)


def label_of(key):
    """Human-readable 'module/name' for a defn key; unknown keys are echoed back."""
    entry = meta.get(key)
    if not entry:
        return key
    return f"{entry['module']}/{entry['name']}"


def impact(key):
    """Blast radius of `key`: its transitive callers, as a fresh set.

    The closure was already computed by Fram Datalog; a key missing from it
    yields an empty set.
    """
    return set(blast[key]) if key in blast else set()


modules = sorted({defn['module'] for defn in defns})
print(f"  {len(defns)} defns across {len(modules)} modules", file=sys.stderr)
print(f"  {len(edges)} scope-correct call edges", file=sys.stderr)

# ---------------------------------------------------------------------------
# Best-effort signatures (metadata only — the call graph above is 100% graph-
# derived; sigs just enrich display). beagle-provides runs the checker per file.
# A file whose run fails contributes no sigs (they display as '?'); failures
# are tallied and reported once on stderr instead of vanishing silently.
# ---------------------------------------------------------------------------
sig_by_qual = {}
sig_failures = []                                   # "file: error" per failed run
for fpath, module in sorted({d['file']: d['module'] for d in defns}.items()):
    try:
        r = subprocess.run([os.path.join(beagle_bin, 'beagle-provides'), fpath],
                           capture_output=True, text=True, timeout=30)
        # Indented "name : [sig]" lines; everything else in the output is ignored.
        for m in re.finditer(r'^\s+([\w?!<>*+\-/]+)\s*:\s*\[(.+)\]$', r.stdout, re.MULTILINE):
            sig_by_qual[f"{module}/{m.group(1)}"] = m.group(2)
    except Exception as exc:   # deliberately broad: sigs must never abort the analysis
        sig_failures.append(f"{fpath}: {exc}")
if sig_failures:
    print(f"  signatures unavailable for {len(sig_failures)} file(s); "
          f"first: {sig_failures[0]}", file=sys.stderr)


def sig_of(key):
    """Signature text for a defn key, or '?' when beagle-provides gave none."""
    return sig_by_qual.get(label_of(key), '?')

# ---------------------------------------------------------------------------
# Resolve a user/assertion token to scope-correct defn key(s).
#   'module/fn' -> the precise defn ; bare 'fn' -> every defn of that name.
#   An unknown 'module/fn' degrades to the bare-name lookup of 'fn'.
# ---------------------------------------------------------------------------
def resolve_entry(token):
    """Return a fresh list of the defn keys `token` may denote ([] if none)."""
    token = token.strip()
    if token == '':
        return []
    exact = by_qual.get(token) if '/' in token else None
    if exact:
        return [exact]
    bare = token.rsplit('/', 1)[-1]
    return list(by_name.get(bare, []))

# ---------------------------------------------------------------------------
# Map verify-script assertions to the function they exercise, resolved
# scope-correctly.  label -> [keys]
#
# Assertions look like  (assert-eq "label" <expected> (fn arg ...)) : the
# exercised function is the head of the form that follows <expected>.
# <expected> is skipped as ONE whole reader form (string, vector, map, list,
# quoted form, atom), so compound expectations such as [1 2] or "two words"
# are mapped too. A bare \S+ stops at the first space and silently drops them.
# ---------------------------------------------------------------------------
_DELIMS = set('()[]{}";,')          # (plus whitespace) end an atom
_HEAD = re.compile(r'\(([\w?!<>*+\-/]+)')


def _skip_ws(text, i):
    """Index of the first char at/after i that is neither whitespace nor ','
    (commas are whitespace to the Clojure reader)."""
    while i < len(text) and (text[i].isspace() or text[i] == ','):
        i += 1
    return i


def _form_end(text, i):
    """Index just past the single reader form starting at text[i], or -1 when
    there is no complete form there (EOF, stray closer, unterminated string)."""
    n = len(text)
    while i < n and text[i] in "'`~@#":            # reader-macro prefixes: 'x `x ~@x #{} #"re"
        i += 1
    if i >= n:
        return -1
    c = text[i]
    if c == '"':                                     # string literal; \" does not end it
        i += 1
        while i < n:
            if text[i] == '\\':
                i += 2
            elif text[i] == '"':
                return i + 1
            else:
                i += 1
        return -1
    if c in '([{':                                   # balanced collection
        depth = 0
        while i < n:
            c = text[i]
            if c == '"':
                i = _form_end(text, i)
                if i < 0:
                    return -1
                continue
            if c == '\\':                            # char literal, e.g. \( or \]
                i += 2
                continue
            if c == ';':                             # line comment inside the form
                nl = text.find('\n', i)
                i = n if nl < 0 else nl
                continue
            if c in '([{':
                depth += 1
            elif c in ')]}':
                depth -= 1
                if depth == 0:
                    return i + 1
            i += 1
        return -1
    start = i                                        # atom: symbol, number, keyword, char
    if c == '\\':                                    # char literal atom: \a, \space, \(
        i += 2
    while i < n and not (text[i].isspace() or text[i] in _DELIMS):
        i += 1
    return i if i > start else -1


assertion_map = {}
with open(verify_script) as f:
    verify_content = f.read()
for m in re.finditer(r'assert-eq\s+"([^"]+)"', verify_content):
    expected_end = _form_end(verify_content, _skip_ws(verify_content, m.end()))
    if expected_end < 0:
        continue
    head = _HEAD.match(verify_content, _skip_ws(verify_content, expected_end))
    if head:
        assertion_map[m.group(1)] = resolve_entry(head.group(1))
print(f"  {len(assertion_map)} assertions mapped to functions", file=sys.stderr)

# ===========================================================================
# Mode: --from-failures
#   Build the sources, run the verify script under bb as the oracle, map each
#   FAIL label to the function it exercises, then rank failing functions by
#   how many failing functions (themselves included) fixing them explains.
# ===========================================================================
if from_failures:
    print("\n─── Running oracle to find current failures ───", file=sys.stderr)
    build_dir = os.path.join(work_dir, 'build')
    os.makedirs(build_dir, exist_ok=True)

    def _run_step(what, cmd, timeout):
        """Run one oracle step and return its CompletedProcess (text output,
        undecodable bytes replaced). If the command cannot be started or runs
        past `timeout` seconds, report it on stderr and exit 1 (no traceback)."""
        try:
            return subprocess.run(cmd, capture_output=True, text=True,
                                  errors='replace', timeout=timeout)
        except subprocess.TimeoutExpired:
            print(f"  error: {what} timed out after {timeout}s", file=sys.stderr)
        except OSError as exc:
            print(f"  error: could not run {what}: {exc}", file=sys.stderr)
        sys.exit(1)

    build = _run_step('beagle-build-all',
                      [os.path.join(beagle_bin, 'beagle-build-all'), '--warn', source_dir,
                       '--out', build_dir],
                      60)
    if build.returncode != 0:
        # Carry on (the oracle may still report real FAILs), but say so: a
        # broken build otherwise shows up only as a wall of assertion failures.
        print(f"  warning: beagle-build-all exited {build.returncode}; "
              f"results may be unreliable", file=sys.stderr)
        for line in (build.stderr or build.stdout).strip().splitlines()[-5:]:
            print(f"    {line}", file=sys.stderr)

    # The path is embedded in a Clojure string literal: escape \ and ".
    clj_path = verify_script.replace('\\', '\\\\').replace('"', '\\"')
    result = _run_step('bb oracle',
                       ['bb', '-cp', build_dir, '-e', f'(load-file "{clj_path}")'],
                       120)
    output = result.stdout + result.stderr

    # splitlines() (unlike split('\n')) also drops '\r', so CRLF output still
    # yields labels that match assertion_map keys.
    failing_labels = [m.group(1) for line in output.splitlines()
                      if (m := re.match(r'^FAIL: (.+)', line))]
    failing_set = set(failing_labels)
    print(f"  {len(failing_labels)} failing assertions", file=sys.stderr)
    if not failing_labels and result.returncode != 0:
        print(f"  warning: oracle exited {result.returncode} without any FAIL lines "
              f"(did the verify script crash?)", file=sys.stderr)

    failing_keys = set()
    unmapped = []                      # failing labels not tied to any defn
    for label in failing_labels:
        keys = assertion_map.get(label)
        if keys:
            failing_keys.update(keys)
        else:
            unmapped.append(label)

    # cascade score: how many failing fns are explained if this one is fixed
    # (itself plus every failing transitive caller).
    cascade_scores = {k: len((impact(k) | {k}) & failing_keys) for k in failing_keys}

    print()
    print("CASCADE ANALYSIS")
    print("=" * 60)
    print(f"\n{len(failing_labels)} failing assertions trace to {len(failing_keys)} functions\n")

    # Tie-break on label: cascade_scores is filled by iterating a set, whose
    # order for string keys varies between runs (hash randomization).
    ranked = sorted(cascade_scores.items(), key=lambda x: (-x[1], label_of(x[0])))
    print("Root cause candidates (fixing one may fix many):")
    print("-" * 60)
    shown = 0
    for key, score in ranked[:15]:
        if score > 1:
            shown += 1
            downstream = (impact(key) | {key}) & failing_keys
            explained_labels = [l for l, keys in assertion_map.items()
                                if l in failing_set and any(k in downstream for k in keys)]
            print(f"\n  {label_of(key)} — cascade score: {score}")
            print(f"    sig: {sig_of(key)}")
            print(f"    would fix: {', '.join(explained_labels[:5])}")
            if len(explained_labels) > 5:
                print(f"    ... and {len(explained_labels) - 5} more")
    if not shown:
        print("  (none — no failing function explains another)")

    independent = [(k, s) for k, s in ranked if s == 1]
    if independent:
        print(f"\n\nIndependent failures (no cascade, fix individually):")
        print("-" * 60)
        for key, _ in independent[:10]:
            labels = [l for l, keys in assertion_map.items()
                      if l in failing_set and key in keys]
            if labels:
                print(f"  {label_of(key)}: {', '.join(labels[:3])}")

    if unmapped:
        print(f"\n\nUnmapped failures (no function resolved for the assertion):")
        print("-" * 60)
        for label in unmapped[:10]:
            print(f"  {label}")
        if len(unmapped) > 10:
            print(f"  ... and {len(unmapped) - 10} more")
# ===========================================================================
# Mode: --modified
# ===========================================================================
elif modified_fns_str:
    tokens = [t for t in modified_fns_str.split(',') if t.strip()]
    print()
    print("IMPACT PREDICTION")
    print("=" * 60)
    print(f"\nModified: {', '.join(tokens)}\n")

    changed_keys = set()
    union_affected = set()
    for tok in tokens:
        targets = resolve_entry(tok)
        if not targets:
            print(f"\n  {tok}: (no such defn in the call graph)")
            continue
        if len(targets) > 1:
            others = ", ".join(sorted(label_of(k) for k in targets))
            bare = tok.split('/')[-1]
            print(f"\n  ⚠ '{tok}' is defined in {len(targets)} modules: {others}")
            print(f"    (qualify as <module>/{bare} to target one — analyzing each separately)")
        for key in targets:
            changed_keys.add(key)
            aff = impact(key)
            union_affected |= aff
            direct = sorted(label_of(c) for c in rev_direct.get(key, set()))
            print(f"\n  {label_of(key)}:")
            print(f"    direct callers: {', '.join(direct) if direct else '(none)'}")
            print(f"    transitive impact: {len(aff)} function(s)")

    at_risk = [l for l, keys in assertion_map.items()
               if any(k in union_affected or k in changed_keys for k in keys)]

    print(f"\n\nPREDICTED ASSERTION FAILURES ({len(at_risk)}):")
    print("-" * 60)
    for label in sorted(at_risk):
        via = ", ".join(sorted(label_of(k) for k in assertion_map.get(label, []))) or '?'
        print(f"  {label}  (via {via})")

    print(f"\n\nSummary: changing {', '.join(tokens)} affects "
          f"{len(union_affected)} function(s) and risks {len(at_risk)} assertion(s)")

# ===========================================================================
# Mode: default — call graph summary
# ===========================================================================
else:
    print()
    print("CALL GRAPH SUMMARY")
    print("=" * 60)
    print(f"\n  Modules: {len(modules)}")
    print(f"  Functions: {len(defns)}")
    print(f"  Call edges: {len(edges)}")
    print(f"  Assertions mapped: {len(assertion_map)}")

    print(f"\n\nMost-impactful functions (highest blast radius if changed):")
    print("-" * 60)
    ranked = sorted((d['key'] for d in defns), key=lambda k: -len(impact(k)))
    for key in ranked[:15]:
        aff = impact(key)
        if not aff and not rev_direct.get(key):
            continue
        at_risk = [l for l, keys in assertion_map.items() if any(k in aff for k in keys)]
        print(f"  {label_of(key)}: {len(rev_direct.get(key, set()))} direct callers, "
              f"{len(aff)} transitive, {len(at_risk)} assertions at risk")
PYEOF
