#!/usr/bin/env bash
# beagle-dtrace: distributed tracing for beagle microservices.
#
# Instruments compiled beagle modules for cross-service tracing, collects
# spans from multiple services, and provides cross-service blame analysis.
#
# Subcommands:
#   instrument <build-dir> [--services s1,s2,...] [--out <dir>]
#       Auto-instrument compiled modules with span creation at service boundaries.
#
#   collect [--port N] [--dir <trace-dir>]
#       Start a TCP collector daemon that receives spans from services.
#
#   view <trace-dir> [--trace-id <id>]
#       Show trace waterfall — full request flow across services.
#
#   blame <trace-dir> [--trace-id <id>] [--oracle-output <file>]
#       Cross-service blame: find which service first produced a wrong value.
#
#   graph <trace-dir>
#       Service dependency graph from collected traces.
#
#   cascade <trace-dir> [--trace-id <id>]
#       Root cause analysis: walk backwards across service boundaries.

set -euo pipefail

# Locate the install root relative to this script so the bundled runtime
# library is found regardless of the caller's working directory.
BEAGLE_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
DTRACE_LIB="$BEAGLE_ROOT/lib"

# A subcommand is mandatory: without one, print the synopsis on stderr and fail.
# The option lists mirror what each do_* function actually parses.
if [[ $# -lt 1 ]]; then
    echo "Usage: beagle-dtrace <subcommand> [args...]" >&2
    echo "" >&2
    echo "Subcommands:" >&2
    echo "  instrument <build-dir> [--services s1,s2,...] [--out <dir>]" >&2
    echo "  collect    [--port N] [--dir <trace-dir>]" >&2
    echo "  view       <trace-dir> [--trace-id <id>]" >&2
    echo "  blame      <trace-dir> [--trace-id <id>] [--oracle-output <file>]" >&2
    echo "  graph      <trace-dir>" >&2
    echo "  cascade    <trace-dir> [--trace-id <id>]" >&2
    exit 1
fi

# Remaining positional parameters belong to the selected subcommand.
SUBCMD="$1"
shift

# ═══════════════════════════════════════════════════════════════════════════
# INSTRUMENT — auto-instrument compiled modules for distributed tracing
# ═══════════════════════════════════════════════════════════════════════════

#######################################
# Rewrite every .clj module in <build-dir> so that qualified calls crossing a
# service boundary go through beagle.dtrace/traced-call, then copy the dtrace
# runtime and a generated dtrace-init.clj next to the rewritten modules.
# Globals:   DTRACE_LIB (read); BUILD_DIR, SERVICES, OUT_DIR (written)
# Arguments: <build-dir> [--services s1,s2,...] [--out <dir>]
# Outputs:   progress and summary on stderr; instrumented files in OUT_DIR
# Exits:     1 on usage errors or a missing build directory
#######################################
do_instrument() {
    if [[ $# -lt 1 ]]; then
        echo "Usage: beagle-dtrace instrument <build-dir> [--services s1,s2] [--out <dir>]" >&2
        exit 1
    fi

    if [[ ! -d "$1" ]]; then
        echo "Build directory not found: $1" >&2
        exit 1
    fi
    BUILD_DIR="$(cd "$1" && pwd)"
    shift

    SERVICES=""
    OUT_DIR=""
    while [[ $# -gt 0 ]]; do
        case "$1" in
            --services|--out)
                # Check the value exists before touching $2 (unbound under set -u).
                if [[ $# -lt 2 ]]; then
                    echo "Option $1 requires a value" >&2
                    exit 1
                fi
                if [[ "$1" == "--services" ]]; then
                    SERVICES="$2"
                else
                    OUT_DIR="$2"
                fi
                shift 2
                ;;
            *) echo "Unknown option: $1" >&2; exit 1 ;;
        esac
    done

    if [[ -z "$OUT_DIR" ]]; then
        OUT_DIR="$(mktemp -d /tmp/beagle-dtrace-inst.XXXXXX)"
    fi
    mkdir -p "$OUT_DIR"

    echo "=== beagle-dtrace instrument ===" >&2
    echo "  source: $BUILD_DIR" >&2
    echo "  output: $OUT_DIR" >&2

    python3 - "$BUILD_DIR" "$OUT_DIR" "$SERVICES" "$DTRACE_LIB" << 'PYEOF'
import sys, os, re, shutil

build_dir, out_dir, services_str, dtrace_lib = sys.argv[1:5]

# Fail before writing anything if the runtime we have to ship is missing.
runtime_src = os.path.join(dtrace_lib, 'beagle', 'dtrace.clj')
if not os.path.isfile(runtime_src):
    print(f"  error: dtrace runtime not found: {runtime_src}", file=sys.stderr)
    sys.exit(1)

# Each .clj file = one module = one potential service. Read every file once.
modules = {}   # module -> file name
sources = {}   # module -> file content
for fname in sorted(os.listdir(build_dir)):
    if not fname.endswith('.clj'):
        continue
    module = fname[:-4]
    modules[module] = fname
    with open(os.path.join(build_dir, fname)) as f:
        sources[module] = f.read()

# If --services is given, only those names are service boundaries (empty
# entries from stray commas are dropped); otherwise every module is a service.
if services_str:
    service_modules = set(s.strip() for s in services_str.split(',') if s.strip())
else:
    service_modules = set(modules)

print(f"  {len(modules)} modules, {len(service_modules)} service(s)", file=sys.stderr)
print(f"  services: {', '.join(sorted(service_modules))}", file=sys.stderr)

# A qualified call (e.g. catalog/find-product) preceded by ^{:line N :file "..."}
# metadata. The trailing \s+ means only calls followed by whitespace match.
cross_service_pattern = re.compile(
    r'(\^\{:line\s+(\d+)\s+:file\s+"([^"]+)"\}\s*)'
    r'\(([\w?!<>*+\-]+)/([\w?!<>*+\-/=]+)\s+'
)

ALIAS_RE = re.compile(r'\[([\w.\-]+)\s+:as\s+([\w.\-]+)\]')


def parse_aliases(content):
    """Map alias -> real module for every [real-module :as alias] vector.

    Namespaces required without :as need no entry because lookups fall back
    to the prefix itself. Registering every one-symbol vector would let a
    parameter vector such as [cat] overwrite a genuine alias.
    """
    return {m.group(2): m.group(1) for m in ALIAS_RE.finditer(content)}


def instrument_source(module, content, aliases):
    """Wrap cross-service calls made by `module`; return (new_content, sites)."""
    sites = 0

    def wrap_call(m):
        nonlocal sites
        meta, ns_prefix, fn_name = m.group(1), m.group(4), m.group(5)
        qualified = f'{ns_prefix}/{fn_name}'
        to_svc = aliases.get(ns_prefix, ns_prefix)
        # Only instrument when both ends are distinct, recognized services.
        if module in service_modules and to_svc in service_modules and module != to_svc:
            sites += 1
            return (f'{meta}(beagle.dtrace/traced-call '
                    f'"{module}" "{to_svc}" "{qualified}" {qualified} ')
        return m.group(0)

    return cross_service_pattern.sub(wrap_call, content), sites


def ensure_dtrace_require(source):
    """Make the ns form require beagle.dtrace; return (source, succeeded)."""
    if re.search(r'\[beagle\.dtrace[\s\]]', source):
        return source, True
    # Extend the first (:require ...) clause, whatever whitespace follows it.
    patched, n = re.subn(r'\(:require(?=[\s\[])', '(:require [beagle.dtrace]',
                         source, count=1)
    if n:
        return patched, True
    # No :require clause: only a bare (ns name) form is rewritten.
    patched, n = re.subn(r'\(ns\s+([\w.\-]+)\)',
                         r'(ns \1 (:require [beagle.dtrace]))',
                         source, count=1)
    return patched, bool(n)


instrumented_count = 0
span_sites = 0

for module, fname in modules.items():
    content = sources[module]
    instrumented, sites = instrument_source(module, content, parse_aliases(content))

    if sites:
        instrumented_count += 1
        span_sites += sites
        instrumented, ok = ensure_dtrace_require(instrumented)
        if not ok:
            print(f"  warning: could not add [beagle.dtrace] require to {fname}; "
                  f"add it manually", file=sys.stderr)

    # Modules without cross-service calls are copied through unchanged.
    with open(os.path.join(out_dir, fname), 'w') as f:
        f.write(instrumented)

# Copy the dtrace runtime library into the output directory
dtrace_dst = os.path.join(out_dir, 'beagle', 'dtrace.clj')
os.makedirs(os.path.dirname(dtrace_dst), exist_ok=True)
shutil.copy2(runtime_src, dtrace_dst)

# Generate an initialization script that sets up each service
init_script = os.path.join(out_dir, 'dtrace-init.clj')
with open(init_script, 'w') as f:
    f.write('(ns dtrace-init (:require [beagle.dtrace]))\n\n')
    f.write(';; Auto-generated initialization for distributed tracing.\n')
    f.write(';; Include this before running your service or oracle.\n\n')
    f.write('(defn init-dtrace!\n')
    f.write('  "Initialize distributed tracing. Call from each service entry point."\n')
    f.write('  [service-name trace-dir]\n')
    f.write('  (beagle.dtrace/init! {:service service-name :trace-dir trace-dir}))\n')

print(f"\n  Instrumented {instrumented_count} module(s), {span_sites} cross-service call site(s)", file=sys.stderr)
print(f"  Output: {out_dir}", file=sys.stderr)
PYEOF

    echo "  instrumented build: $OUT_DIR" >&2
}

# ═══════════════════════════════════════════════════════════════════════════
# COLLECT — TCP collector daemon for receiving spans
# ═══════════════════════════════════════════════════════════════════════════

#######################################
# Run a TCP collector on 127.0.0.1 that accepts newline-delimited JSON spans
# and appends them to <trace-dir>/<service>.jsonl plus trace-index.jsonl.
# Runs until interrupted (SIGINT/SIGTERM).
# Globals:   PORT, TRACE_DIR (written)
# Arguments: [--port N] [--dir <trace-dir>]
# Outputs:   status messages on stderr; span files under TRACE_DIR
# Exits:     1 on usage errors
#######################################
do_collect() {
    PORT=9876
    TRACE_DIR="/tmp/beagle-traces"

    while [[ $# -gt 0 ]]; do
        case "$1" in
            --port|--dir)
                # Check the value exists before touching $2 (unbound under set -u).
                if [[ $# -lt 2 ]]; then
                    echo "Option $1 requires a value" >&2
                    exit 1
                fi
                if [[ "$1" == "--port" ]]; then
                    PORT="$2"
                else
                    TRACE_DIR="$2"
                fi
                shift 2
                ;;
            *) echo "Unknown option: $1" >&2; exit 1 ;;
        esac
    done

    mkdir -p "$TRACE_DIR"
    echo "=== beagle-dtrace collector ===" >&2
    echo "  port: $PORT" >&2
    echo "  trace-dir: $TRACE_DIR" >&2

    python3 - "$PORT" "$TRACE_DIR" << 'PYEOF'
import sys, os, re, json, socket, threading, signal

try:
    port = int(sys.argv[1])
except ValueError:
    print(f"Invalid port: {sys.argv[1]}", file=sys.stderr)
    sys.exit(1)
trace_dir = sys.argv[2]

span_count = 0
lock = threading.Lock()  # guards the output files and span_count


def safe_service_name(service):
    """Turn a client-supplied service value into a safe file-name stem.

    The value arrives over the network, so path separators and other unusual
    characters are replaced. The reserved index name and dot-leading or empty
    names are prefixed so they cannot collide with trace-index.jsonl or
    produce hidden files.
    """
    name = re.sub(r'[^A-Za-z0-9._-]', '_', str(service))
    if not name or name.startswith('.') or name == 'trace-index':
        name = '_' + name
    return name


def record_span(line, span):
    """Append one span to its per-service file and to the trace index."""
    global span_count
    service = span.get('service', 'unknown')
    trace_id = span.get('trace_id', 'unknown')
    svc_path = os.path.join(trace_dir, f'{safe_service_name(service)}.jsonl')
    idx_path = os.path.join(trace_dir, 'trace-index.jsonl')
    index_entry = json.dumps({
        'trace_id': trace_id,
        'span_id': span.get('span_id'),
        'service': service,
        'operation': span.get('operation'),
        'start_ms': span.get('start_ms'),
    })

    with lock:
        with open(svc_path, 'a') as f:
            f.write(line + '\n')
        with open(idx_path, 'a') as f:
            f.write(index_entry + '\n')
        span_count += 1
        total = span_count  # snapshot under the lock for the progress message

    if total % 100 == 0:
        print(f"  [{total} spans collected]", file=sys.stderr)


def handle_client(conn, addr):
    """Read newline-delimited JSON spans from one connection until it closes.

    Lines are handled as they arrive rather than after the peer disconnects,
    so long-lived connections are persisted incrementally. Lines that are not
    valid JSON objects are skipped.
    """
    try:
        with conn, conn.makefile('rb') as reader:
            for raw in reader:
                line = raw.decode('utf-8', errors='replace').strip()
                if not line:
                    continue
                try:
                    span = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(span, dict):
                    continue
                record_span(line, span)
    except OSError as exc:
        # Best effort: one broken connection or failed write must not stop
        # the collector, but it should not vanish silently either.
        print(f"  [connection {addr[0]}:{addr[1]} aborted: {exc}]", file=sys.stderr)


def server_loop():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as srv:
        srv.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        srv.bind(('127.0.0.1', port))
        srv.listen(16)
        print(f"  Listening on 127.0.0.1:{port}", file=sys.stderr)
        print(f"  Trace data: {trace_dir}", file=sys.stderr)

        while True:
            conn, addr = srv.accept()
            t = threading.Thread(target=handle_client, args=(conn, addr), daemon=True)
            t.start()


def shutdown(signum, frame):
    print(f"\n  Collector shutting down. Total spans: {span_count}", file=sys.stderr)
    sys.exit(0)


signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)

server_loop()
PYEOF
}

# ═══════════════════════════════════════════════════════════════════════════
# VIEW — trace waterfall visualization
# ═══════════════════════════════════════════════════════════════════════════

#######################################
# Print either a summary of all traces in <trace-dir> or, with --trace-id,
# a waterfall of one trace (exact id or unique-looking prefix).
# Globals:   TRACE_DIR, TRACE_ID (written)
# Arguments: <trace-dir> [--trace-id <id>]
# Outputs:   report on stdout; diagnostics on stderr
# Exits:     1 on usage errors or a missing trace directory
#######################################
do_view() {
    if [[ $# -lt 1 ]]; then
        echo "Usage: beagle-dtrace view <trace-dir> [--trace-id <id>]" >&2
        exit 1
    fi

    TRACE_DIR="$1"; shift
    TRACE_ID=""

    while [[ $# -gt 0 ]]; do
        case "$1" in
            --trace-id)
                # Check the value exists before touching $2 (unbound under set -u).
                if [[ $# -lt 2 ]]; then
                    echo "Option $1 requires a value" >&2
                    exit 1
                fi
                TRACE_ID="$2"; shift 2 ;;
            *) echo "Unknown option: $1" >&2; exit 1 ;;
        esac
    done

    if [[ ! -d "$TRACE_DIR" ]]; then
        echo "Trace directory not found: $TRACE_DIR" >&2
        exit 1
    fi

    python3 - "$TRACE_DIR" "$TRACE_ID" << 'PYEOF'
import sys, os, json
from collections import defaultdict

trace_dir = sys.argv[1]
target_trace = sys.argv[2] if len(sys.argv) > 2 and sys.argv[2] else None

# Load all spans. Lines that are not JSON objects are ignored so one stray
# record cannot abort the whole report.
all_spans = []
for fname in sorted(os.listdir(trace_dir)):
    if not fname.endswith('.jsonl') or fname == 'trace-index.jsonl':
        continue
    fpath = os.path.join(trace_dir, fname)
    with open(fpath) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                span = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(span, dict):
                all_spans.append(span)

if not all_spans:
    print("No spans found.", file=sys.stderr)
    sys.exit(1)

# Group by trace ID
traces = defaultdict(list)
for span in all_spans:
    traces[span.get('trace_id', 'unknown')].append(span)

# If no trace ID specified, show summary of all traces
if not target_trace:
    print(f"TRACES ({len(traces)} total)")
    print("=" * 72)
    for tid, spans in sorted(traces.items(), key=lambda x: min(s.get('start_ms', 0) for s in x[1])):
        services = sorted(set(s.get('service', '?') for s in spans))
        min_start = min(s.get('start_ms', 0) for s in spans)
        max_end = max(s.get('end_ms', 0) or s.get('start_ms', 0) for s in spans)
        duration = max_end - min_start if max_end > min_start else 0
        root_ops = [s.get('operation', '?') for s in spans if not s.get('parent_span_id')]
        errors = sum(1 for s in spans if s.get('status') == 'error')
        print(f"\n  trace: {str(tid)[:16]}...")
        print(f"    spans: {len(spans)}  services: {', '.join(services)}")
        print(f"    duration: {duration}ms  errors: {errors}")
        if root_ops:
            print(f"    root: {root_ops[0]}")
    sys.exit(0)

# Resolve the requested trace: exact id first, then the first id with that
# prefix. An unmatched request resolves to nothing rather than to some
# unrelated trace.
if target_trace in traces:
    resolved = target_trace
else:
    resolved = next((k for k in traces if str(k).startswith(target_trace)), None)
spans = traces[resolved] if resolved is not None else None

if not spans:
    print(f"Trace not found: {target_trace}", file=sys.stderr)
    print(f"Available traces: {', '.join(str(t)[:16] for t in traces)}", file=sys.stderr)
    sys.exit(1)

# Sort by start time
spans.sort(key=lambda s: s.get('start_ms', 0))
trace_start = spans[0].get('start_ms', 0)

# Build parent-child tree; spans whose parent is absent from this trace are
# treated as roots.
by_id = {s['span_id']: s for s in spans if 'span_id' in s}
children = defaultdict(list)
roots = []
for s in spans:
    pid = s.get('parent_span_id')
    if pid and pid in by_id:
        children[pid].append(s)
    else:
        roots.append(s)

# Render waterfall (the header shows the full resolved id, not the prefix typed)
print(f"TRACE: {resolved}")
print(f"{'=' * 72}")
print(f"  spans: {len(spans)}  services: {', '.join(sorted(set(s.get('service','?') for s in spans)))}")
total_dur = max(s.get('end_ms', 0) or s.get('start_ms', 0) for s in spans) - trace_start
print(f"  total duration: {total_dur}ms")
print()

# Tree-walk render
WIDTH = 40

def render_bar(start_ms, end_ms, total_start, total_dur, width=WIDTH):
    """Return a fixed-width bar placing [start_ms, end_ms] on the trace timeline."""
    if total_dur == 0:
        return '█' * width
    s = max(0, int((start_ms - total_start) / total_dur * width))
    # Every span gets at least one filled cell so instantaneous spans show up.
    e = min(width, max(s + 1, int(((end_ms or start_ms) - total_start) / total_dur * width)))
    return '·' * s + '█' * (e - s) + '·' * (width - e)

def print_span(span, depth=0):
    """Print one span row, its error/local-trace details, then its children."""
    indent = '  ' * depth
    svc = span.get('service', '?')
    op = span.get('operation', '?')
    start = span.get('start_ms', 0)
    end = span.get('end_ms', start)
    dur = (end or start) - start
    status = span.get('status', '?')
    err = ' ✗' if status == 'error' else ''

    bar = render_bar(start, end, trace_start, total_dur)
    kind = (span.get('tags') or {}).get('kind', '')
    kind_marker = '→' if kind == 'client' else '←' if kind == 'server' else '·'

    print(f"  {bar}  {indent}{kind_marker} [{svc}] {op} ({dur}ms){err}")

    if span.get('error'):
        # The error may be structured (dict/number); stringify before truncating.
        print(f"  {'':>{WIDTH}}  {indent}  ERROR: {str(span['error'])[:60]}")

    # Local trace entries (from beagle-trace integration); long lists are omitted.
    local = span.get('local_trace', [])
    if local and len(local) <= 10:
        for lt in local:
            lt_op = lt.get('op', '?')
            lt_result = lt.get('result', '?')
            lt_file = lt.get('file', '')
            lt_line = lt.get('line', '')
            if lt_file:
                short = os.path.basename(lt_file)
                loc = f"{short}:{lt_line}"
            else:
                loc = ''
            print(f"  {'':>{WIDTH}}  {indent}    ({lt_op}) = {lt_result}  ; {loc}")

    # Recurse to children
    for child in sorted(children.get(span.get('span_id', ''), []),
                        key=lambda c: c.get('start_ms', 0)):
        print_span(child, depth + 1)

print(f"  {'time →':^{WIDTH}}  operation")
print(f"  {'─' * WIDTH}  {'─' * 32}")

for root in roots:
    print_span(root)

# Error summary
errors = [s for s in spans if s.get('status') == 'error']
if errors:
    print(f"\nERRORS ({len(errors)}):")
    print("-" * 60)
    for s in errors:
        svc = s.get('service', '?')
        op = s.get('operation', '?')
        err = str(s.get('error', '?'))
        print(f"  [{svc}] {op}: {err[:80]}")
PYEOF
}

# ═══════════════════════════════════════════════════════════════════════════
# BLAME — cross-service blame analysis
# ═══════════════════════════════════════════════════════════════════════════

#######################################
# Cross-service blame report: combines error spans, duration outliers and
# (optionally) FAIL lines from an oracle run into ranked findings per service.
# Globals:   TRACE_DIR, TRACE_ID, ORACLE_OUTPUT (written)
# Arguments: <trace-dir> [--trace-id <id>] [--oracle-output <file>]
# Outputs:   report on stdout; diagnostics on stderr
# Exits:     1 on usage errors or a missing trace directory
#######################################
do_blame() {
    if [[ $# -lt 1 ]]; then
        echo "Usage: beagle-dtrace blame <trace-dir> [--trace-id <id>] [--oracle-output <file>]" >&2
        exit 1
    fi

    TRACE_DIR="$1"; shift
    TRACE_ID=""
    ORACLE_OUTPUT=""

    while [[ $# -gt 0 ]]; do
        case "$1" in
            --trace-id|--oracle-output)
                # Check the value exists before touching $2 (unbound under set -u).
                if [[ $# -lt 2 ]]; then
                    echo "Option $1 requires a value" >&2
                    exit 1
                fi
                if [[ "$1" == "--trace-id" ]]; then
                    TRACE_ID="$2"
                else
                    ORACLE_OUTPUT="$2"
                fi
                shift 2
                ;;
            *) echo "Unknown option: $1" >&2; exit 1 ;;
        esac
    done

    if [[ ! -d "$TRACE_DIR" ]]; then
        echo "Trace directory not found: $TRACE_DIR" >&2
        exit 1
    fi

    python3 - "$TRACE_DIR" "$TRACE_ID" "$ORACLE_OUTPUT" << 'PYEOF'
import sys, os, json, re
from collections import defaultdict

trace_dir = sys.argv[1]
target_trace = sys.argv[2] if len(sys.argv) > 2 and sys.argv[2] else ""
oracle_output_path = sys.argv[3] if len(sys.argv) > 3 and sys.argv[3] else ""

# Load all spans; lines that are not JSON objects are ignored.
all_spans = []
for fname in sorted(os.listdir(trace_dir)):
    if not fname.endswith('.jsonl') or fname == 'trace-index.jsonl':
        continue
    with open(os.path.join(trace_dir, fname)) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                span = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(span, dict):
                all_spans.append(span)

# Group by trace
traces = defaultdict(list)
for span in all_spans:
    traces[span.get('trace_id', '')].append(span)

# Select trace(s): an explicit id (exact, then prefix) or, by default, every
# trace containing an error span -- falling back to all traces if none do.
if target_trace:
    target_spans = traces.get(target_trace, [])
    if not target_spans:
        for tid, spans in traces.items():
            if str(tid).startswith(target_trace):
                target_spans = spans
                target_trace = tid
                break
    if not target_spans:
        print(f"Trace not found: {target_trace}", file=sys.stderr)
        sys.exit(1)
    trace_sets = [(target_trace, target_spans)]
else:
    trace_sets = [(tid, spans) for tid, spans in traces.items()
                  if any(s.get('status') == 'error' for s in spans)]
    if not trace_sets:
        trace_sets = list(traces.items())

# --- Load oracle failures (if provided) ---
# Parses FAIL lines: "FAIL: cat/product-margin Widget A"
# Maps label prefix (cat/, ord/, etc.) to service name

PREFIX_MAP = {
    'cat/': 'catalog', 'cust/': 'customers', 'inv/': 'inventory',
    'ord/': 'orders', 'rep/': 'reports', 'ship/': 'shipping',
    'bill/': 'billing', 'proc/': 'procurement', 'promo/': 'promotions',
    'emp/': 'employees', 'analytics/': 'analytics',
    'notif/': 'notifications', 'audit/': 'audit',
}

oracle_failures = []  # [{label, service, function, expected, actual}]
if oracle_output_path and os.path.exists(oracle_output_path):
    with open(oracle_output_path) as f:
        lines = f.readlines()
    i = 0
    while i < len(lines):
        m = re.match(r'^FAIL:\s+(.+)', lines[i].rstrip())
        if m:
            label = m.group(1)
            expected = actual = None
            # "expected:" / "actual:" are looked for on the two lines after FAIL.
            if i + 1 < len(lines):
                em = re.match(r'\s+expected:\s+(.+)', lines[i+1].rstrip())
                if em: expected = em.group(1)
            if i + 2 < len(lines):
                am = re.match(r'\s+actual:\s+(.+)', lines[i+2].rstrip())
                if am: actual = am.group(1)

            service = function = None
            for prefix, svc in PREFIX_MAP.items():
                if label.startswith(prefix):
                    service = svc
                    fn_match = re.match(r'[\w?!<>*+\-/]+', label[len(prefix):])
                    if fn_match:
                        function = fn_match.group(0)
                    break

            # Also try "module fn-name ..." pattern (e.g., "order1 discount")
            if not service:
                for prefix, svc in PREFIX_MAP.items():
                    short = prefix.rstrip('/')
                    if label.lower().startswith(short):
                        service = svc
                        break

            oracle_failures.append({
                'label': label,
                'service': service,
                'function': function,
                'expected': expected,
                'actual': actual,
            })
        i += 1
    print(f"  Oracle failures loaded: {len(oracle_failures)}", file=sys.stderr)
elif oracle_output_path:
    # The caller asked for oracle correlation but the file is missing: say so
    # rather than quietly producing a report without it.
    print(f"  warning: oracle output not found: {oracle_output_path}", file=sys.stderr)

print("CROSS-SERVICE BLAME ANALYSIS")
print("=" * 72)
print(f"  Traces: {len(trace_sets)}, spans: {len(all_spans)}, oracle failures: {len(oracle_failures)}\n")

# Build service call graph from traces
svc_calls = defaultdict(set)  # from_service -> set of to_service
svc_operations = defaultdict(set)  # service -> set of operations called
for span in all_spans:
    svc = span.get('service', '?')
    op = span.get('operation', '?')
    target = (span.get('tags') or {}).get('target-service')
    svc_operations[svc].add(op)
    if target:
        svc_calls[svc].add(target)


def ancestors(span, by_id):
    """Return the chain root ... span, following parent_span_id links.

    Visited span ids are tracked so a malformed trace whose parent links form
    a cycle (or a span naming itself as parent) terminates instead of looping
    forever.
    """
    chain = [span]
    seen = {span.get('span_id')}
    current = span
    while True:
        pid = current.get('parent_span_id')
        if pid not in by_id or pid in seen:
            break
        seen.add(pid)
        current = by_id[pid]
        chain.append(current)
    return list(reversed(chain))


# --- Strategy 1: Error span propagation ---
all_findings = []

for trace_id, spans in trace_sets:
    by_id = {s['span_id']: s for s in spans if 'span_id' in s}

    error_spans = [s for s in spans if s.get('status') == 'error']
    for es in error_spans:
        chain = ancestors(es, by_id)
        service_chain = [s.get('service', '?') for s in chain]
        all_findings.append({
            'strategy': 'error-propagation',
            'confidence': 0.85,
            'root_service': es.get('service', '?'),
            'root_operation': es.get('operation', '?'),
            'error': es.get('error', '?'),
            'chain': ' → '.join(f"[{s}]" for s in service_chain),
        })

    # Duration anomalies: spans over 3x the trace average and longer than 10ms
    durations = [(s, s.get('duration_ms', 0)) for s in spans if s.get('duration_ms')]
    if durations:
        avg_dur = sum(d for _, d in durations) / len(durations)
        for span, dur in durations:
            if dur > avg_dur * 3 and dur > 10:
                all_findings.append({
                    'strategy': 'duration-anomaly',
                    'confidence': 0.55,
                    'root_service': span.get('service', '?'),
                    'root_operation': span.get('operation', '?'),
                    'detail': f'{dur}ms (avg {avg_dur:.0f}ms)',
                })

# --- Strategy 2: Oracle failure → service blame via call graph ---

if oracle_failures:
    # Map each failure to the service that owns the failing function
    service_failure_count = defaultdict(int)
    failure_by_service = defaultdict(list)

    for fail in oracle_failures:
        if fail['service']:
            service_failure_count[fail['service']] += 1
            failure_by_service[fail['service']].append(fail)

    # For each failing service, check if its failures could be caused by
    # an upstream service it depends on (via trace call graph)
    upstream_blame = defaultdict(lambda: {'count': 0, 'downstream': set(), 'failures': []})

    for failing_svc, failures in failure_by_service.items():
        # Which services does this service call?
        called_services = svc_calls.get(failing_svc, set())
        for upstream_svc in called_services:
            # Does the upstream service ALSO have failures?
            if upstream_svc in service_failure_count:
                upstream_blame[upstream_svc]['count'] += len(failures)
                upstream_blame[upstream_svc]['downstream'].add(failing_svc)
                upstream_blame[upstream_svc]['failures'].extend(failures)

    # Root cause candidates: services whose bugs cascade to others
    for svc, info in upstream_blame.items():
        if info['downstream']:
            all_findings.append({
                'strategy': 'oracle-cascade',
                'confidence': 0.80,
                'root_service': svc,
                'root_operation': f'{len(failure_by_service.get(svc, []))} own failure(s)',
                'detail': f'cascades to {", ".join(sorted(info["downstream"]))} '
                          f'({info["count"]} downstream failures)',
            })

    # Direct oracle failures per service
    for svc, failures in sorted(failure_by_service.items(), key=lambda x: -len(x[1])):
        fn_names = set(f['function'] for f in failures if f['function'])
        all_findings.append({
            'strategy': 'oracle-direct',
            'confidence': 0.75,
            'root_service': svc,
            'root_operation': ', '.join(sorted(fn_names)[:5]),
            'detail': f'{len(failures)} assertion(s) failed',
            'failures': [(f['label'], f['expected'], f['actual']) for f in failures[:5]],
        })

    # Cross-service mismatch: failure label references a function that
    # is CALLED from another service (visible in traces)
    for fail in oracle_failures:
        if not fail['function']:
            continue
        fn_qual = fail['function']
        # Find which services call this function (from traces)
        callers = set()
        for span in all_spans:
            op = span.get('operation', '')
            if op.endswith(f'/{fn_qual}') or op == fn_qual:
                caller_svc = span.get('service', '?')
                if caller_svc != fail['service']:
                    callers.add(caller_svc)
        if callers:
            all_findings.append({
                'strategy': 'cross-service-call-failure',
                'confidence': 0.70,
                'root_service': fail['service'] or '?',
                'root_operation': fail['function'],
                'detail': f'{fail["label"]}: called by {", ".join(sorted(callers))}; '
                          f'expected={fail["expected"]}, actual={fail["actual"]}',
            })

# --- Output ---

if not all_findings:
    print("  No blame findings.")
    sys.exit(0)

# Sort by confidence, then strategy
all_findings.sort(key=lambda f: (-f['confidence'], f['strategy']))

# Deduplicate by (strategy, root_service, root_operation)
seen = set()
deduped = []
for f in all_findings:
    key = (f['strategy'], f['root_service'], f.get('root_operation', ''))
    if key not in seen:
        seen.add(key)
        deduped.append(f)
all_findings = deduped

for i, f in enumerate(all_findings):
    print(f"\n  [{i+1}] {f['strategy']}  confidence: {f['confidence']:.2f}")
    print(f"      service: {f['root_service']}")
    print(f"      operation: {f['root_operation']}")
    if 'error' in f:
        print(f"      error: {str(f['error'])[:80]}")
    if 'chain' in f:
        print(f"      chain: {f['chain']}")
    if 'detail' in f:
        print(f"      detail: {f['detail']}")
    if 'failures' in f:
        for label, exp, act in f['failures'][:3]:
            print(f"      FAIL: {label}")
            if exp: print(f"        expected: {exp}")
            if act: print(f"        actual:   {act}")

# --- Service blame ranking ---

print(f"\n\n{'─' * 72}")
print("SERVICE BLAME RANKING")
print(f"{'─' * 72}")

# A service's score is the sum of the confidences of its findings.
service_blame_count = defaultdict(float)
for f in all_findings:
    service_blame_count[f['root_service']] += f['confidence']

ranked = sorted(service_blame_count.items(), key=lambda x: -x[1])
max_score = max(s for _, s in ranked) if ranked else 1
for svc, score in ranked:
    bar_len = int(score / max_score * 20)
    bar = '█' * bar_len + '░' * (20 - bar_len)
    findings_for = sum(1 for f in all_findings if f['root_service'] == svc)
    print(f"  {bar}  {svc}: score {score:.1f} ({findings_for} finding{'s' if findings_for != 1 else ''})")

# Cross-service cascade summary
cascades = [f for f in all_findings if f['strategy'] == 'oracle-cascade']
if cascades:
    print(f"\n\nCROSS-SERVICE CASCADES")
    print(f"{'─' * 72}")
    for c in cascades:
        print(f"  [{c['root_service']}] → {c['detail']}")

print()
PYEOF
}

# ═══════════════════════════════════════════════════════════════════════════
# GRAPH — service dependency graph from traces
# ═══════════════════════════════════════════════════════════════════════════

#######################################
# Print the service dependency graph derived from all spans in <trace-dir>:
# per-service statistics, call edges, entry/internal/leaf classification and
# a transitive impact analysis.
# Globals:   TRACE_DIR (written)
# Arguments: <trace-dir>
# Outputs:   report on stdout; diagnostics on stderr
# Exits:     1 on usage errors or a missing trace directory
#######################################
do_graph() {
    if [[ $# -lt 1 ]]; then
        echo "Usage: beagle-dtrace graph <trace-dir>" >&2
        exit 1
    fi

    TRACE_DIR="$1"; shift

    if [[ ! -d "$TRACE_DIR" ]]; then
        echo "Trace directory not found: $TRACE_DIR" >&2
        exit 1
    fi

    python3 - "$TRACE_DIR" << 'PYEOF'
import sys, os, json
from collections import defaultdict

trace_dir = sys.argv[1]

# Load all spans; lines that are not JSON objects are ignored.
spans = []
for fname in sorted(os.listdir(trace_dir)):
    if not fname.endswith('.jsonl') or fname == 'trace-index.jsonl':
        continue
    with open(os.path.join(trace_dir, fname)) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                span = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(span, dict):
                spans.append(span)

if not spans:
    print("No spans found.", file=sys.stderr)
    sys.exit(1)

# Build service graph from client→server span relationships
edges = defaultdict(lambda: defaultdict(int))  # from_svc -> to_svc -> count
service_stats = defaultdict(lambda: {'spans': 0, 'errors': 0, 'total_ms': 0, 'operations': set()})

by_id = {s['span_id']: s for s in spans if 'span_id' in s}

for span in spans:
    svc = span.get('service', 'unknown')
    service_stats[svc]['spans'] += 1
    service_stats[svc]['total_ms'] += span.get('duration_ms', 0) or 0
    service_stats[svc]['operations'].add(span.get('operation', '?'))
    if span.get('status') == 'error':
        service_stats[svc]['errors'] += 1

    # Client spans indicate outgoing calls
    tags = span.get('tags') or {}
    if tags.get('kind', '') == 'client':
        target = tags.get('target-service')
        if target:
            edges[svc][target] += 1

    # Also detect edges from parent→child where services differ.
    # NOTE(review): a call reported both by a client span's target-service tag
    # and by a cross-service child span is counted by both rules -- confirm
    # whether that double count is intended.
    pid = span.get('parent_span_id')
    if pid and pid in by_id:
        parent = by_id[pid]
        parent_svc = parent.get('service', 'unknown')
        if parent_svc != svc:
            edges[parent_svc][svc] += 1

print("SERVICE DEPENDENCY GRAPH")
print("=" * 72)
print(f"\n  {len(service_stats)} services, {len(spans)} spans\n")

# Service summary (only services that emitted spans have statistics)
print("SERVICES")
print("-" * 72)
for svc in sorted(service_stats.keys()):
    stats = service_stats[svc]
    avg_ms = stats['total_ms'] / stats['spans'] if stats['spans'] > 0 else 0
    err_rate = stats['errors'] / stats['spans'] * 100 if stats['spans'] > 0 else 0
    ops = sorted(stats['operations'])
    print(f"\n  [{svc}]")
    print(f"    spans: {stats['spans']}  avg duration: {avg_ms:.0f}ms  error rate: {err_rate:.0f}%")
    print(f"    operations: {', '.join(ops[:8])}")
    if len(ops) > 8:
        print(f"                ... and {len(ops) - 8} more")

# Edge list
print(f"\n\nCALL EDGES")
print("-" * 72)
all_edges = []
for from_svc, targets in sorted(edges.items()):
    for to_svc, count in sorted(targets.items(), key=lambda x: -x[1]):
        all_edges.append((from_svc, to_svc, count))
        print(f"  {from_svc} → {to_svc}  ({count} call{'s' if count > 1 else ''})")

# ASCII dependency diagram
print(f"\n\nDEPENDENCY DIAGRAM")
print("-" * 72)

# Find leaf services (no outgoing calls) and root services (no incoming).
# Edge endpoints are included so a service that is only ever called (and
# emitted no spans itself) still shows up as a leaf and in impact analysis.
has_outgoing = set(from_svc for from_svc, _, _ in all_edges)
has_incoming = set(to_svc for _, to_svc, _ in all_edges)
all_services = set(service_stats.keys()) | has_outgoing | has_incoming
roots = all_services - has_incoming
leaves = all_services - has_outgoing
middle = all_services - roots - leaves

if roots:
    print(f"\n  Entry points: {', '.join(sorted(roots))}")
if middle:
    print(f"  Internal:     {', '.join(sorted(middle))}")
if leaves:
    print(f"  Leaf:         {', '.join(sorted(leaves))}")

# DAG-style rendering
print()
for from_svc, to_svc, count in sorted(all_edges, key=lambda x: -x[2]):
    arrow = f"{'─' * max(1, 20 - len(from_svc))}→"
    print(f"  [{from_svc}] {arrow} [{to_svc}]  ×{count}")

# Impact analysis: which service, if it goes down, affects the most others?
print(f"\n\nIMPACT ANALYSIS")
print("-" * 72)

# Reverse adjacency: service -> services that call it directly.
callers_of = defaultdict(set)
for from_svc, to_svc, _ in all_edges:
    callers_of[to_svc].add(from_svc)

def transitive_dependents(svc):
    """Return every service that reaches `svc` through one or more call edges."""
    found = set()
    pending = [svc]
    while pending:
        current = pending.pop()
        for caller in callers_of.get(current, ()):
            if caller not in found:
                found.add(caller)
                pending.append(caller)
    return found

for svc in sorted(all_services):
    deps = transitive_dependents(svc) - {svc}
    if deps:
        print(f"  [{svc}] failure affects: {', '.join(sorted(deps))}")

print()
PYEOF
}

# ═══════════════════════════════════════════════════════════════════════════
# CASCADE — root cause analysis across service boundaries
# ═══════════════════════════════════════════════════════════════════════════

# do_cascade <trace-dir> [--trace-id <id>]
#
# Reads every *.jsonl file in <trace-dir> (except trace-index.jsonl), groups
# the spans by trace_id, rebuilds each trace's span tree from parent_span_id
# links, and reports two kinds of cross-service cascade per trace:
#   - error-cascade:   the deepest error span on a root-to-leaf chain that lies
#                      at or below a service boundary
#   - latency-cascade: a callee across a service boundary that took more than
#                      80% of its caller's duration_ms and more than 10
# When more than one trace is analysed, an aggregate summary follows.
#
# With --trace-id, traces whose id starts with <id> are analysed; otherwise
# all traces containing an error span (or, if none, the first 10 traces).
#
# Exits 1 on bad usage, a missing trace directory, an empty trace directory,
# or when no trace matches --trace-id.
do_cascade() {
    if [[ $# -lt 1 ]]; then
        echo "Usage: beagle-dtrace cascade <trace-dir> [--trace-id <id>]" >&2
        exit 1
    fi

    local TRACE_DIR="$1"; shift
    local TRACE_ID=""

    while [[ $# -gt 0 ]]; do
        case "$1" in
            --trace-id)
                # Without this check a trailing --trace-id dereferences an
                # unset $2 (fatal under `set -u`, and `shift 2` cannot advance).
                if [[ $# -lt 2 ]]; then
                    echo "Option --trace-id requires a value" >&2
                    exit 1
                fi
                TRACE_ID="$2"; shift 2 ;;
            *) echo "Unknown option: $1" >&2; exit 1 ;;
        esac
    done

    # Fail with a readable message instead of a Python traceback.
    if [[ ! -d "$TRACE_DIR" ]]; then
        echo "Trace directory not found: $TRACE_DIR" >&2
        exit 1
    fi

    python3 - "$TRACE_DIR" "$TRACE_ID" << 'PYEOF'
import sys, os, json
from collections import defaultdict

trace_dir = sys.argv[1]
target_trace = sys.argv[2] if len(sys.argv) > 2 and sys.argv[2] else ""

# Load all spans: one JSON object per line, from every *.jsonl file except
# trace-index.jsonl.
all_spans = []
for fname in sorted(os.listdir(trace_dir)):
    if not fname.endswith('.jsonl') or fname == 'trace-index.jsonl':
        continue
    with open(os.path.join(trace_dir, fname)) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                # Best-effort loading: a malformed line is skipped.
                continue
            # Valid JSON that is not an object (e.g. a list or a number)
            # cannot be a span; keeping it would crash on .get() below.
            if isinstance(record, dict):
                all_spans.append(record)

if not all_spans:
    print("No spans found.", file=sys.stderr)
    sys.exit(1)

# Group by trace. The id is normalised to a string so that prefix matching
# and slicing stay safe when a span carries a null or non-string trace_id.
traces = defaultdict(list)
for span in all_spans:
    traces[str(span.get('trace_id') or '')].append(span)

# Select trace(s)
if target_trace:
    # Prefix match
    matched = [(tid, spans) for tid, spans in traces.items()
               if tid.startswith(target_trace)]
    if not matched:
        print(f"No traces matching: {target_trace}", file=sys.stderr)
        sys.exit(1)
    trace_sets = matched
else:
    # All traces with errors
    trace_sets = [(tid, spans) for tid, spans in traces.items()
                  if any(s.get('status') == 'error' for s in spans)]
    if not trace_sets:
        print("No error traces found. Analyzing all traces.", file=sys.stderr)
        trace_sets = list(traces.items())[:10]

print("CROSS-SERVICE CASCADE ANALYSIS")
print("=" * 72)

for trace_id, spans in trace_sets:
    print(f"\n\nTRACE: {trace_id[:32]}...")
    print("-" * 72)

    # Build span tree: a span whose parent is unknown (or absent) is a root.
    by_id = {s['span_id']: s for s in spans if 'span_id' in s}
    children = defaultdict(list)
    roots = []

    for s in spans:
        pid = s.get('parent_span_id')
        if pid and pid in by_id:
            children[pid].append(s)
        else:
            roots.append(s)

    # Walk tree depth-first to find cascade chains
    # A cascade: an error or anomaly in a leaf service causes failures up the tree

    def find_cascade_chains(span, path=None, on_path=None):
        """Walk from root toward leaves, collecting service transition points.

        Returns a list of root-to-leaf chains; each chain is a list of node
        dicts summarising the spans along it.
        """
        if path is None:
            path = []
        if on_path is None:
            on_path = frozenset()
        path = path + [{
            'service': span.get('service', '?'),
            'operation': span.get('operation', '?'),
            'status': span.get('status', '?'),
            'duration_ms': span.get('duration_ms', 0),
            'span_id': span.get('span_id'),
            'error': span.get('error'),
            'local_trace': span.get('local_trace', []),
        }]
        on_path = on_path | {id(span)}

        # Duplicate span ids can make a span appear among its own
        # descendants; skipping spans already on this path keeps the walk
        # finite instead of recursing without bound.
        child_spans = [c for c in children.get(span.get('span_id', ''), [])
                       if id(c) not in on_path]
        if not child_spans:
            # Leaf — report this chain
            return [path]

        chains = []
        for child in child_spans:
            chains.extend(find_cascade_chains(child, path, on_path))
        return chains

    all_chains = []
    for root in roots:
        all_chains.extend(find_cascade_chains(root))

    # Identify cascade patterns:
    # 1. Error at leaf propagating upward
    # 2. Slow leaf causing slow parents
    # 3. Service boundary where behavior changes

    cascade_findings = []

    for chain in all_chains:
        # Find service transitions (adjacent nodes owned by different services)
        transitions = []
        for i in range(1, len(chain)):
            if chain[i]['service'] != chain[i-1]['service']:
                transitions.append({
                    'from': chain[i-1],
                    'to': chain[i],
                    'index': i,
                })

        # A chain that never leaves one service cannot be cross-service.
        if not transitions:
            continue

        # Check for error cascade: error at deepest, propagates up
        errors_in_chain = [(i, node) for i, node in enumerate(chain)
                          if node['status'] == 'error']

        if errors_in_chain:
            deepest_error_idx, deepest_error = errors_in_chain[-1]
            # Which service boundary did it cross?
            crossed = [t for t in transitions if t['index'] <= deepest_error_idx]
            if crossed:
                last_crossing = crossed[-1]
                cascade_findings.append({
                    'type': 'error-cascade',
                    'root_service': deepest_error['service'],
                    'root_operation': deepest_error['operation'],
                    # Every node carries an 'error' key (possibly None), so a
                    # .get() default would never apply; use `or` to fall back.
                    'error': deepest_error.get('error') or '?',
                    'crossed_boundary': f"{last_crossing['from']['service']} → {last_crossing['to']['service']}",
                    # Ancestors of the deepest error that also failed,
                    # de-duplicated in chain order.
                    'affected_services': list(dict.fromkeys(
                        node['service'] for node in chain[:deepest_error_idx]
                        if node['status'] == 'error')),
                    'depth': deepest_error_idx,
                })

        # Check for latency cascade: the callee accounts for more than 80% of
        # the caller's duration and is itself longer than 10ms.
        for t in transitions:
            from_dur = t['from'].get('duration_ms', 0) or 0
            to_dur = t['to'].get('duration_ms', 0) or 0
            if to_dur > from_dur * 0.8 and to_dur > 10:
                if from_dur > 0:
                    detail = (f"{t['to']['service']} took {to_dur}ms "
                              f"({to_dur/from_dur*100:.0f}% of caller's {from_dur}ms)")
                else:
                    # No usable caller duration: omit the percentage.
                    detail = f"{t['to']['service']} took {to_dur}ms"
                cascade_findings.append({
                    'type': 'latency-cascade',
                    'root_service': t['to']['service'],
                    'root_operation': t['to']['operation'],
                    'detail': detail,
                    'crossed_boundary': f"{t['from']['service']} → {t['to']['service']}",
                })

    # Deduplicate by root_service + root_operation + type (first one wins)
    seen = set()
    unique_findings = []
    for f in cascade_findings:
        key = (f.get('root_service'), f.get('root_operation'), f['type'])
        if key not in seen:
            seen.add(key)
            unique_findings.append(f)

    # Output cascade findings
    if not unique_findings:
        services = sorted(set(s.get('service', '?') for s in spans))
        print(f"  No cascades detected across {len(services)} services")
        continue

    print(f"  {len(unique_findings)} cascade pattern(s) found:\n")

    for i, f in enumerate(unique_findings):
        marker = '✗' if f['type'] == 'error-cascade' else '⏱'
        print(f"  {marker} [{i+1}] {f['type']}")
        print(f"      root: [{f['root_service']}] {f['root_operation']}")
        print(f"      boundary: {f['crossed_boundary']}")
        if 'error' in f:
            print(f"      error: {str(f['error'])[:60]}")
        if 'detail' in f:
            print(f"      detail: {f['detail']}")
        if 'affected_services' in f and f['affected_services']:
            print(f"      affected: {', '.join(f['affected_services'])}")
        print()

    # Root cause ranking (only meaningful with more than one candidate)
    root_counts = defaultdict(int)
    for f in unique_findings:
        root_counts[f['root_service']] += 1

    if len(root_counts) > 1:
        print(f"  ROOT CAUSE RANKING:")
        for svc, count in sorted(root_counts.items(), key=lambda x: -x[1]):
            print(f"    [{svc}]: {count} cascade(s)")

# --- Aggregate across all traces ---

if len(trace_sets) > 1:
    print(f"\n\n{'═' * 72}")
    print(f"AGGREGATE CASCADE ANALYSIS ({len(trace_sets)} traces)")
    print(f"{'═' * 72}")

    # Per service: number of analysed traces in which it has an error span.
    all_roots = defaultdict(int)
    # Per "caller → callee" pair: number of parent/child spans crossing it.
    all_boundaries = defaultdict(int)

    for trace_id, spans in trace_sets:
        error_svcs = set(s.get('service', '?') for s in spans if s.get('status') == 'error')
        for svc in error_svcs:
            all_roots[svc] += 1

        # Count boundary crossings
        by_id = {s['span_id']: s for s in spans if 'span_id' in s}
        for s in spans:
            pid = s.get('parent_span_id')
            if pid and pid in by_id:
                p = by_id[pid]
                if p.get('service') != s.get('service'):
                    boundary = f"{p.get('service','?')} → {s.get('service','?')}"
                    all_boundaries[boundary] += 1

    print(f"\n  Services with errors:")
    for svc, count in sorted(all_roots.items(), key=lambda x: -x[1]):
        # Bar is capped at 40 cells; the exact count is printed after it.
        bar = '█' * min(40, count) + f" {count}"
        print(f"    [{svc}] {bar}")

    print(f"\n  Hot boundaries (most crossings):")
    for boundary, count in sorted(all_boundaries.items(), key=lambda x: -x[1])[:10]:
        print(f"    {boundary}: {count} crossing(s)")

print()
PYEOF
}

# ═══════════════════════════════════════════════════════════════════════════
# Dispatch
# ═══════════════════════════════════════════════════════════════════════════

# Route the requested subcommand to its do_<name> handler, forwarding the
# remaining command-line arguments unchanged. Anything else is a usage error.
case "$SUBCMD" in
    instrument | collect | view | blame | graph | cascade)
        "do_${SUBCMD}" "$@"
        ;;
    *)
        {
            echo "Unknown subcommand: $SUBCMD"
            echo "Use: instrument | collect | view | blame | graph | cascade"
        } >&2
        exit 1
        ;;
esac
