#!/usr/bin/env bash
# beagle-trace: instrumented tracing for compiled beagle modules.
#
# Wraps arithmetic/comparison operations with value capture, runs the
# oracle, and on failure shows the sub-expression trace: exactly which
# operation produced the divergent value and where in the source it lives.
#
# Usage:
#   beagle-trace <build-dir> <verify-script>
#   beagle-trace <build-dir> <verify-script> --focus <fn-name>
#
# Output:
#   For each failing assertion, shows the computation trace through
#   the relevant function(s), highlighting the first divergence point.

set -euo pipefail

# Print the usage line on stderr and abort with status 1.
usage() {
    echo "Usage: beagle-trace <build-dir> <verify-script> [--focus <fn-name>]" >&2
    exit 1
}

if [[ $# -lt 2 ]]; then
    usage
fi

# Fail early with a clear message instead of letting cd/realpath (or the
# Python phases further down) produce an opaque error.
if [[ ! -d "$1" ]]; then
    echo "beagle-trace: build dir not found: $1" >&2
    exit 1
fi
if [[ ! -f "$2" ]]; then
    echo "beagle-trace: verify script not found: $2" >&2
    exit 1
fi

# Absolute paths, so later phases do not depend on the caller's cwd.
# CDPATH is cleared for the cd so it can neither redirect the lookup nor
# make cd echo the directory into the captured output; `--` protects
# directory names that start with a dash.
BUILD_DIR="$(CDPATH= cd -- "$1" && pwd)"
VERIFY_SCRIPT="$(realpath -- "$2")"

# Optional `--focus <fn-name>`. Only the two documented call shapes are
# accepted: previously $4 was used blindly, so a misspelled flag was taken
# as `--focus` and a dangling `--focus` was silently ignored.
FOCUS_FN=""
case $# in
    2)
        ;;
    4)
        if [[ "$3" != "--focus" || -z "$4" ]]; then
            usage
        fi
        FOCUS_FN="$4"
        ;;
    *)
        usage
        ;;
esac

# Scratch directory for instrumented sources; removed on every exit path.
WORK_DIR="$(mktemp -d /tmp/beagle-trace.XXXXXX)"
trap 'rm -rf -- "${WORK_DIR:?}"' EXIT

# ---------------------------------------------------------------------------
# Phase 1: Generate instrumented versions of all .clj modules
# ---------------------------------------------------------------------------

echo "=== beagle-trace: instrumenting $BUILD_DIR ===" >&2

# The instrumenter is an inline Python program. The heredoc delimiter is
# quoted, so the shell expands nothing inside it; the build dir, work dir and
# optional focus function are passed as argv[1..3].
python3 - "$BUILD_DIR" "$WORK_DIR" "$FOCUS_FN" << 'PYEOF'
import sys, os, re, shutil
from pathlib import Path

build_dir = sys.argv[1]
work_dir = sys.argv[2]
# An empty third argument means "no focus": instrument every function.
focus_fn = sys.argv[3] if len(sys.argv) > 3 and sys.argv[3] else None

TRACED_OPS = {'+', '-', '*', '/', 'mod', 'rem', 'quot',
              '>', '<', '>=', '<=', '=', 'not=', 'compare'}

# Source-location metadata prefix: ^{:line N :file "path"} plus trailing
# whitespace. Group 1 is the whole prefix, group 2 the line number and
# group 3 the file path.
META = r'(\^\{:line\s+(\d+)\s+:file\s+"([^"]+)"\}\s*)'

# Start of a (defn name ...) form; group 1 is the function name.
DEFN_RE = re.compile(r'\(defn\s+([\w?!<>*+\-/=]+)')

# Metadata-tagged call of a traced operator; group 4 is the operator.
# The trailing \s+ guarantees a whole-token match, so the order of the
# alternatives cannot make '>' swallow '>='; longest-first is only for
# readability of the compiled pattern.
OP_CALL_RE = re.compile(
    META + r'\(('
    + '|'.join(re.escape(op) for op in sorted(TRACED_OPS, key=len, reverse=True))
    + r')\s+'
)

# Metadata-tagged call of any symbol; group 4 is the called name.
FN_CALL_RE = re.compile(META + r'\(([\w?!<>*+\-/=]+)\s+')

def extract_meta(expr_before):
    """Extract :line and :file from ^{:line N :file "path"} metadata.

    Returns (line_number, file_path), or (None, None) when expr_before does
    not start with such metadata. Not called anywhere in this script at
    present.
    """
    m = re.match(r'\^\{:line\s+(\d+)\s+:file\s+"([^"]+)"\}', expr_before)
    if m:
        return int(m.group(1)), m.group(2)
    return None, None

def collect_user_fns(build_dir):
    """Collect all user-defined function names across modules.

    Every (defn name ...) found in a .clj file directly inside build_dir
    contributes both the bare name and the "module/name" form, where module
    is the file name without its .clj suffix.
    """
    fns = set()
    for fname in sorted(os.listdir(build_dir)):
        if not fname.endswith('.clj'):
            continue
        module = fname[:-4]
        with open(os.path.join(build_dir, fname), encoding='utf-8') as f:
            content = f.read()
        for m in DEFN_RE.finditer(content):
            fn_name = m.group(1)
            fns.add(fn_name)
            fns.add(f'{module}/{fn_name}')
    return fns

user_fns = collect_user_fns(build_dir)

def _wrap_op(m):
    """Rewrite '<meta>(op ' into '<meta>(beagle.trace/t "op" line "file" '."""
    meta, ln, fl, op = m.group(1), m.group(2), m.group(3), m.group(4)
    return f'{meta}(beagle.trace/t "{op}" {ln} "{fl}" '

def _wrap_call(m):
    """Rewrite a call to a user-defined function into a beagle.trace/tc call.

    Calls to anything that is not a known user function are left untouched.
    Names that collide with a traced operator are also left alone here.
    """
    meta, ln, fl, called = m.group(1), m.group(2), m.group(3), m.group(4)
    if called in user_fns and called not in TRACED_OPS:
        return f'{meta}(beagle.trace/tc "{called}" {ln} "{fl}" {called} '
    return m.group(0)

def instrument_file(src_path, dst_path, focus_fn=None):
    """Instrument a .clj file by wrapping arithmetic ops and function calls with trace capture.

    The file is processed line by line. Only call sites that carry
    ^{:line N :file "path"} metadata are rewritten. When focus_fn is given,
    rewriting is limited to lines from the (defn focus_fn ...) line up to the
    next defn line; otherwise every line is eligible.

    Returns True when at least one defn was seen in the file.
    """
    with open(src_path, encoding='utf-8') as f:
        content = f.read()

    instrumented_lines = []
    in_target_fn = focus_fn is None
    current_fn = None

    for line in content.split('\n'):
        fn_match = DEFN_RE.search(line)
        if fn_match:
            current_fn = fn_match.group(1)
            in_target_fn = (current_fn == focus_fn) if focus_fn else True

        if not in_target_fn:
            instrumented_lines.append(line)
            continue

        # Operators first. The rewritten call starts with "beagle.trace/t",
        # and '.' is not in the symbol character class used by FN_CALL_RE, so
        # the second pass cannot wrap an already-wrapped operator again.
        new_line = OP_CALL_RE.sub(_wrap_op, line)

        # Then calls to user-defined functions (call-graph walk):
        #   ^{:line N :file "path"} (fn-name ...)
        # becomes
        #   ^{:line N :file "path"} (beagle.trace/tc "fn-name" N "file" fn-name ...)
        new_line = FN_CALL_RE.sub(_wrap_call, new_line)

        instrumented_lines.append(new_line)

    with open(dst_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(instrumented_lines))

    return current_fn is not None

# Copy all .clj files, instrumenting each (sorted for reproducible output)
for fname in sorted(os.listdir(build_dir)):
    if fname.endswith('.clj'):
        src = os.path.join(build_dir, fname)
        dst = os.path.join(work_dir, fname)
        instrument_file(src, dst, focus_fn)
        print(f"  instrumented: {fname}", file=sys.stderr)

# Write the tracing runtime
trace_runtime = os.path.join(work_dir, "beagle", "trace.clj")
os.makedirs(os.path.dirname(trace_runtime), exist_ok=True)

# The runtime contains non-ASCII arrows, so the encoding is pinned to UTF-8
# rather than left to the locale default.
with open(trace_runtime, 'w', encoding='utf-8') as f:
    f.write('''(ns beagle.trace
  (:require [clojure.string]))

(def ^:dynamic *trace-enabled* (atom true))
(def ^:dynamic *trace-log* (atom []))
(def ^:dynamic *trace-depth* (atom 0))
(def max-trace-depth 200)

(defn t
  "Traced operation wrapper. Captures op, args, result, source location."
  [op-name line file & args]
  (let [depth @*trace-depth*]
    (if (or (not @*trace-enabled*) (> depth max-trace-depth))
      (apply (resolve (symbol op-name)) args)
      (do
        (swap! *trace-depth* inc)
        ;; try/finally keeps the depth counter balanced when the op throws
        ;; (for example a divide by zero); no log entry is written then.
        (let [result (try
                       (apply (resolve (symbol op-name)) args)
                       (finally (swap! *trace-depth* dec)))]
          (swap! *trace-log* conj
            {:op op-name
             :args (vec args)
             :result result
             :line line
             :file file
             :depth depth})
          result)))))

(defn tc
  "Traced function call wrapper. Records call entry/exit with args and result."
  [fn-name line file target-fn & args]
  (let [depth @*trace-depth*]
    (if (or (not @*trace-enabled*) (> depth max-trace-depth))
      (apply target-fn args)
      (do
        (swap! *trace-log* conj
          {:op (str "call:" fn-name)
           :args (vec args)
           :result :pending
           :line line
           :file file
           :depth depth})
        (swap! *trace-depth* inc)
        ;; try/finally keeps the depth counter balanced when the callee
        ;; throws; the call entry stays in the log without a matching ret.
        (let [result (try
                       (apply target-fn args)
                       (finally (swap! *trace-depth* dec)))]
          (swap! *trace-log* conj
            {:op (str "ret:" fn-name)
             :args []
             :result result
             :line line
             :file file
             :depth depth})
          result)))))

(defn reset-trace! [] (reset! *trace-log* []))

(defn get-trace [] @*trace-log*)

(defn format-trace-entry [entry]
  (let [{:keys [op args result line file depth]} entry
        indent (apply str (repeat depth "  "))
        short-file (last (clojure.string/split file #"/"))]
    (cond
      (.startsWith ^String op "call:")
        (format "%s→ %s(%s)  ; %s:%s"
          indent (subs op 5)
          (clojure.string/join ", " (map #(if (> (count (str %)) 30) (str (subs (str %) 0 27) "...") (str %)) args))
          short-file line)
      (.startsWith ^String op "ret:")
        (format "%s← %s = %s  ; %s:%s"
          indent (subs op 4) (str result)
          short-file line)
      :else
        (format "%s(%s %s) = %s  ; %s:%s"
          indent op
          (clojure.string/join " " (map str args))
          (str result)
          short-file line))))

(defn dump-trace
  "Print the trace log, optionally filtered to a file or function."
  ([] (dump-trace nil))
  ([filter-file]
    (doseq [entry @*trace-log*]
      (when (or (nil? filter-file)
                (and (:file entry)
                     (.contains ^String (:file entry) ^String filter-file)))
        (println (format-trace-entry entry))))))

(defn trace-around-assertion
  "Run a thunk with tracing, return {:result value :trace entries}."
  [thunk]
  (reset-trace!)
  (let [result (thunk)]
    {:result result :trace @*trace-log*}))

;; Semantic rules: function name patterns → expected/suspicious ops
(def semantic-rules
  [{:pattern #"total|sum-of|aggregate" :expected #{"+" "*"} :suspicious #{"-"} :reason "aggregation"}
   {:pattern #"discount|rebate|deduct" :expected #{"-"} :suspicious #{"+"} :reason "reduction"}
   {:pattern #"margin|profit|markup" :expected #{"-"} :suspicious #{} :reason "price-minus-cost"}
   {:pattern #"commission|surcharge|fee-amount|tax-amount" :expected #{"*"} :suspicious #{"+"} :reason "rate*base"}
   {:pattern #"line-total|line-cost|poline-total" :expected #{"*"} :suspicious #{"+" "-"} :reason "unit*quantity"}
   {:pattern #"count-|num-|-count$" :expected #{} :suspicious #{"-"} :reason "count (non-negative)"}])

(defn check-op-semantic [fn-name op]
  (when fn-name
    (first
      (for [rule semantic-rules
            :when (re-find (:pattern rule) fn-name)
            :when (contains? (:suspicious rule) op)]
        (:reason rule)))))

(defn annotate-trace [entries fn-context]
  (mapv (fn [entry]
    (let [warning (check-op-semantic fn-context (:op entry))]
      (if warning (assoc entry :warning warning) entry)))
    entries))
''')

print("  wrote trace runtime: beagle/trace.clj", file=sys.stderr)
PYEOF

# ---------------------------------------------------------------------------
# Phase 2: Pre-process verify script with traced assert-eq
# ---------------------------------------------------------------------------

echo "=== Phase 2: Patching verify script ===" >&2

# Inline Python program: argv[1] is the verify script to read, argv[2] the
# patched copy to write. The heredoc delimiter is quoted, so the shell does
# not expand anything inside it.
python3 - "$VERIFY_SCRIPT" "$WORK_DIR/traced-verify.clj" << 'PATCHEOF'
import sys, re

verify_path = sys.argv[1]
output_path = sys.argv[2]

with open(verify_path, encoding='utf-8') as f:
    content = f.read()

# Replace the assert-eq definition with a traced version that dumps the
# trace log on failure and resets it after every assertion.
# The actual value is already computed by the time assert-eq is called, but
# the trace log has captured all ops evaluated since the previous reset, so
# resetting at the end of each assertion scopes the log to the next one.

# The definition this patcher expects to find verbatim.
old_assert = '''(defn assert-eq [label expected actual]
  (if (= expected actual)
    (swap! *pass-count* inc)
    (do (swap! *fail-count* inc)
        (println (str "FAIL: " label "\\n  expected: " expected "\\n  actual:   " actual)))))'''

traced_assert = '''(defn assert-eq [label expected actual]
  (if (= expected actual)
    (do (swap! *pass-count* inc)
        (beagle.trace/reset-trace!))
    (do (swap! *fail-count* inc)
        (println (str "FAIL: " label "\\n  expected: " expected "\\n  actual:   " actual))
        (let [trace (beagle.trace/get-trace)
              fn-name (second (re-find #"^[^/]*/(.+)" label))
              relevant (filter #(or (#{"+" "-" "*" "/" ">" "<" ">=" "<="} (:op %))
                                    (.startsWith ^String (:op %) "call:")
                                    (.startsWith ^String (:op %) "ret:")) trace)
              annotated (beagle.trace/annotate-trace (vec relevant) fn-name)
              trimmed (if (> (count annotated) 20) (take-last 20 annotated) annotated)]
          (when (seq trimmed)
            (println "  trace:")
            (doseq [entry trimmed]
              (print (str "    " (beagle.trace/format-trace-entry entry)))
              (when (:warning entry)
                (print (str "  ⚠ SUSPECT: " (:warning entry))))
              (println)))
          (println))
        (beagle.trace/reset-trace!))))'''

# Head of an assert-eq definition with the expected parameter vector.
ASSERT_HEAD = re.compile(r'\(defn\s+assert-eq\s+\[label\s+expected\s+actual\]')

def find_form_end(text, start):
    """Return the index just past the form that opens at text[start].

    Walks the text counting (), [] and {} nesting while skipping string
    literals (including their backslash escapes), character literals such as
    \\( and ; line comments. Returns -1 when the form is never closed.
    """
    depth = 0
    in_string = False
    i = start
    n = len(text)
    while i < n:
        ch = text[i]
        if in_string:
            if ch == '\\':
                i += 2          # skip the escaped character
                continue
            if ch == '"':
                in_string = False
        elif ch == '\\':
            i += 2              # character literal: skip the next character
            continue
        elif ch == ';':
            newline = text.find('\n', i)
            if newline == -1:
                return -1
            i = newline
        elif ch == '"':
            in_string = True
        elif ch in '([{':
            depth += 1
        elif ch in ')]}':
            depth -= 1
            if depth == 0:
                return i + 1
        i += 1
    return -1

if old_assert in content:
    content = content.replace(old_assert, traced_assert)
    print("  patched assert-eq with trace capture", file=sys.stderr)
else:
    # More forgiving match: locate the defn head, then replace the whole
    # balanced form. A plain "up to the first )))" regex would cut the
    # definition short and leave stray closing parens behind, and re.sub
    # would also interpret the backslash escapes inside the replacement, so
    # the text is spliced in by slicing instead.
    head = ASSERT_HEAD.search(content)
    end = find_form_end(content, head.start()) if head else -1
    if end != -1:
        content = content[:head.start()] + traced_assert + content[end:]
        print("  patched assert-eq (regex fallback)", file=sys.stderr)
    else:
        print("  WARNING: assert-eq definition not found; failures will not show traces",
              file=sys.stderr)

# Add require for beagle.trace at top of ns form
if '(:require ' in content:
    content = content.replace(
        '(:require ',
        '(:require [beagle.trace] '
    )
else:
    # No inline (:require ...) clause to extend: load the runtime with a
    # standalone form ahead of everything else so the patched assert-eq can
    # still resolve beagle.trace.
    content = "(require 'beagle.trace)\n" + content
    print("  no (:require ...) clause found; prepended (require 'beagle.trace)",
          file=sys.stderr)

# The patched assert-eq contains a non-ASCII warning sign, so the output
# encoding is pinned to UTF-8 rather than left to the locale default.
with open(output_path, 'w', encoding='utf-8') as f:
    f.write(content)

print(f"  wrote {output_path}", file=sys.stderr)
PATCHEOF

# ---------------------------------------------------------------------------
# Phase 3: Run traced oracle
# ---------------------------------------------------------------------------

# Progress banners go to stderr.
printf '%s\n' "=== Phase 3: Running traced oracle ===" >&2

# Evaluate the patched verify script with bb, with the work dir (instrumented
# modules plus beagle/trace.clj) on the classpath. bb's stderr is merged into
# stdout so the oracle output arrives as a single stream. The script runs
# under `set -e`, so a non-zero exit from bb ends the run here, before the
# completion banner.
traced_verify="${WORK_DIR}/traced-verify.clj"
bb -cp "${WORK_DIR}" -e "(load-file \"${traced_verify}\")" 2>&1

printf '\n' >&2
printf '%s\n' "=== beagle-trace complete ===" >&2
