#!/usr/bin/env bash
# beagle-oracle: behavioral oracle synthesis from golden reference code.
#
# Compiles golden code, runs all exported functions with generated test data,
# captures outputs as expected values, and emits a verify script. No
# handwritten assertions — the golden code IS the oracle.
#
# Usage:
#   beagle-oracle <golden-source-dir> [--out oracle.verify.clj]
#   beagle-oracle <golden-source-dir> --diff <modified-source-dir>
#
# Modes:
#   (default)   Generate full oracle from golden code
#   --diff      Additionally run the same captured calls against the code in
#               <modified-source-dir> and report functions whose output differs

set -euo pipefail

# Print the usage line to stderr and exit with status 1.
usage() {
    echo "Usage: beagle-oracle <golden-source-dir> [--out FILE] [--diff <modified-dir>]" >&2
    exit 1
}

# Exit with a usage error unless the option named by $1 is followed by a value.
# $2 is the number of arguments still unparsed (including the option itself).
# Without this check a trailing `--out` / `--diff` dereferences an unset `$2`
# and `set -u` aborts with only an "unbound variable" message.
require_value() {
    if (( $2 < 2 )); then
        echo "Option $1 requires an argument" >&2
        usage
    fi
}

if [[ $# -lt 1 ]]; then
    usage
fi

# Resolve to an absolute path. If the directory cannot be entered, `cd` reports
# the error and the failed assignment aborts the script via `set -e`.
SOURCE_DIR="$(cd "$1" && pwd)"
shift

# Both stay empty strings when the corresponding option is not given.
OUT_FILE=""
DIFF_DIR=""

while [[ $# -gt 0 ]]; do
    case "$1" in
        --out) require_value "$1" "$#"; OUT_FILE="$2"; shift 2 ;;
        --diff) require_value "$1" "$#"; DIFF_DIR="$(cd "$2" && pwd)"; shift 2 ;;
        *) echo "Unknown option: $1" >&2; exit 1 ;;
    esac
done

# Sibling beagle-* tools are looked up in the directory containing this script.
BEAGLE_BIN="$(cd "$(dirname "$0")" && pwd)"
# Scratch directory for build output and the capture script; removed on exit.
WORK_DIR="$(mktemp -d /tmp/beagle-oracle.XXXXXX)"
trap 'rm -rf -- "$WORK_DIR"' EXIT

# ---------------------------------------------------------------------------
# Phase 1: Compile golden code
# ---------------------------------------------------------------------------

echo "─── Phase 1: Compile golden code ───" >&2

BUILD_DIR="$WORK_DIR/golden-build"
mkdir -p "$BUILD_DIR"

# The builder's stderr is not shown during a successful run, but it is saved so
# that a failed build can be explained instead of the script dying silently
# under `set -e`.
BUILD_LOG="$WORK_DIR/golden-build.log"
build_status=0
"$BEAGLE_BIN/beagle-build-all" --warn "$SOURCE_DIR" --out "$BUILD_DIR" 2>"$BUILD_LOG" \
    || build_status=$?
if (( build_status != 0 )); then
    echo "beagle-oracle: compiling golden code failed (exit $build_status):" >&2
    cat "$BUILD_LOG" >&2
    exit "$build_status"
fi

# Count compiled modules with a glob instead of `ls | wc -l`: when no .clj file
# exists `ls` fails, and with `pipefail` + `set -e` that failure would abort the
# script before the count is printed. nullglob makes an unmatched pattern
# expand to an empty array.
shopt -s nullglob
compiled_modules=("$BUILD_DIR"/*.clj)
shopt -u nullglob
MODULE_COUNT=${#compiled_modules[@]}
echo "  $MODULE_COUNT modules compiled" >&2

# ---------------------------------------------------------------------------
# Phase 2: Extract type info and generate test data + oracle
# ---------------------------------------------------------------------------

echo "─── Phase 2: Generating oracle ───" >&2

# Everything from here to the PYEOF terminator is a Python program read from
# stdin (`python3 -`). The heredoc delimiter is quoted, so the shell performs no
# expansion inside it; all inputs are handed over as positional arguments
# (sys.argv[1..6]). OUT_FILE and DIFF_DIR are passed even when empty, which
# keeps the argument positions fixed for the Python side.
python3 - "$SOURCE_DIR" "$BUILD_DIR" "$BEAGLE_BIN" "$WORK_DIR" "$OUT_FILE" "$DIFF_DIR" << 'PYEOF'
import sys, os, re, subprocess, json

# Positional arguments supplied by the enclosing shell script, in this order.
source_dir = sys.argv[1]  # golden source directory ($SOURCE_DIR)
build_dir = sys.argv[2]  # directory the golden code was compiled into ($BUILD_DIR)
beagle_bin = sys.argv[3]  # directory containing the beagle-* tools ($BEAGLE_BIN)
work_dir = sys.argv[4]  # temporary working directory ($WORK_DIR)
out_file = sys.argv[5]  # oracle output path; empty string when --out was not given
diff_dir = sys.argv[6]  # modified source directory; empty string when --diff was not given

# File-name suffixes that mark a file as a Beagle source module.
BEAGLE_EXTENSIONS = ('.bclj', '.bjs', '.bnix', '.bsql', '.bpy', '.bgl', '.rkt')


def is_beagle_source(fname):
    """Return True when fname ends with one of BEAGLE_EXTENSIONS."""
    # str.endswith accepts a tuple of candidate suffixes.
    return fname.endswith(BEAGLE_EXTENSIONS)


def strip_beagle_ext(fname):
    """Return fname without its Beagle extension; unchanged if it has none."""
    matched = next((ext for ext in BEAGLE_EXTENSIONS if fname.endswith(ext)), None)
    if matched is None:
        return fname
    return fname[:len(fname) - len(matched)]

# ---------------------------------------------------------------------------
# Extract type information from all modules
# ---------------------------------------------------------------------------

# modules:     module name -> {'records': [...], 'functions': [...]}
# all_records: record name -> {'module': ..., 'fields': [...]}
# Both are filled by parsing the text that `beagle-provides` prints for each
# Beagle source file found directly inside source_dir.
modules = {}
all_records = {}  # RecordName -> {module, fields}

provides_tool = os.path.join(beagle_bin, 'beagle-provides')

for fname in sorted(os.listdir(source_dir)):
    if not is_beagle_source(fname):
        continue
    fpath = os.path.join(source_dir, fname)
    module_name = strip_beagle_ext(fname)

    # A module that cannot be inspected is still skipped (best effort), but the
    # reason is reported so a missing tool or a hang does not silently produce
    # an empty oracle.
    try:
        result = subprocess.run(
            [provides_tool, fpath],
            capture_output=True, text=True, timeout=30
        )
    except subprocess.TimeoutExpired:
        print(f"  warning: beagle-provides timed out on {fname}; module skipped",
              file=sys.stderr)
        continue
    except Exception as exc:
        print(f"  warning: could not run beagle-provides on {fname}: {exc}; module skipped",
              file=sys.stderr)
        continue

    if result.returncode != 0:
        # Whatever was printed on stdout is still parsed below.
        print(f"  warning: beagle-provides exited with status {result.returncode} for {fname}",
              file=sys.stderr)
    output = result.stdout

    records = []
    functions = []

    # Record lines: an indented name followed by a bracketed list of
    # `field:Type` pairs (Type may carry a trailing '?').
    for m in re.finditer(r'^\s+(\w+)\s+\[(.+)\]$', output, re.MULTILINE):
        rec_name = m.group(1)
        fields = [
            {'name': fm.group(1), 'type': fm.group(2)}
            for fm in re.finditer(r'([\w-]+):(\w+\??)', m.group(2))
        ]
        if fields:
            records.append({'name': rec_name, 'fields': fields})
            all_records[rec_name] = {'module': module_name, 'fields': fields}

    # Function lines: an indented name, ' : ', then a bracketed signature whose
    # return type follows the last ' -> '. Signatures without ' -> ' are ignored.
    for m in re.finditer(r'^\s+([\w?!<>*+\-/]+)\s+:\s+\[(.+)\]$', output, re.MULTILINE):
        parts = m.group(2).rsplit(' -> ', 1)
        if len(parts) == 2:
            functions.append({
                'name': m.group(1),
                'args_str': parts[0].strip(),
                'return_type': parts[1].strip(),
            })

    modules[module_name] = {
        'records': records,
        'functions': functions,
    }

print(f"  {len(modules)} modules, {len(all_records)} records, "
      f"{sum(len(m['functions']) for m in modules.values())} functions",
      file=sys.stderr)

# ---------------------------------------------------------------------------
# Generate test data: create sample instances of each record type
# ---------------------------------------------------------------------------

def sample_value(type_name, field_name='', index=1):
    """Return sample data for a field as Clojure literal text.

    type_name is the declared field type, field_name steers the heuristics
    (substring matches such as 'id', 'qty', 'price'), and index varies the
    value between instances. A trailing '?' on type_name is stripped and the
    underlying type is used. Returns None when the type names a known record
    in all_records, since nested records are not generated. Any other unknown
    type is treated as a numeric scalar.
    """
    def mentions(*needles):
        # Case-sensitive substring test against the field name.
        return any(needle in field_name for needle in needles)

    if type_name in ('Long', 'Int', 'Integer'):
        # Rules are checked in order; the first matching one wins.
        if 'id' in field_name.lower():
            number = 1000 + index
        elif mentions('qty', 'quantity', 'count'):
            number = 5 * index
        elif mentions('rate', 'pct'):
            number = 10 + index
        elif mentions('cost', 'price', 'amount'):
            number = 100 * index
        elif mentions('days', 'time', 'date', 'at'):
            number = 1700000000 + index
        else:
            number = index * 10
        return str(number)

    if type_name == 'Double':
        return str(0.1 * index)

    if type_name == 'String':
        if mentions('name'):
            return '"test-{}-{}"'.format(field_name, index)
        if mentions('status'):
            return '"active"'
        if mentions('email'):
            return '"test{}@example.com"'.format(index)
        return '"test-{}"'.format(index)

    if type_name == 'Boolean':
        return 'true'

    if type_name == 'Keyword':
        return ':{}'.format(field_name)

    if type_name.endswith('?'):
        # Type with a trailing '?': sample the type without the marker.
        return sample_value(type_name[:-1], field_name, index)

    if type_name in all_records:
        return None  # Complex type — skip for now

    # Scalar type (ProductId, Amount, etc.) — use Long
    if 'id' in field_name.lower() or 'Id' in type_name:
        return str(1000 + index)
    return str(100 * index)

# Clojure source lines that define the sample record instances. For every
# record, two vars (test-<record>-1 and test-<record>-2) are defined by calling
# the record's positional constructor with generated field values.
test_data_lines = [';; --- Auto-generated test data ---', '']

for module_name, info in modules.items():
    for rec in info['records']:
        ctor = f"{module_name}/->{rec['name']}"

        for inst_idx in (1, 2):
            vals = []
            for offset, field in enumerate(rec['fields']):
                # Shift the index per field so values within one instance differ.
                sample = sample_value(field['type'], field['name'], inst_idx + offset)
                if sample is None:
                    # No sample available for this field: drop the whole instance.
                    break
                vals.append(sample)
            else:
                var_name = f"test-{rec['name'].lower()}-{inst_idx}"
                test_data_lines.append(f"(def {var_name} ({ctor} {' '.join(vals)}))")

    # Blank separator line after each module's definitions.
    test_data_lines.append('')

# ---------------------------------------------------------------------------
# Generate oracle assertions by calling functions on test data
# ---------------------------------------------------------------------------

# The capture script is Clojure source that:
# 1. Requires all modules
# 2. Defines the test data
# 3. Calls each single-arg function whose argument is a record type
# 4. Prints every result as a "CAPTURE: <label> = <pr-str value>" line
#    (or a "SKIP: ..." line when the call throws)

capture_script_lines = ['(ns oracle-capture)', '']

# Require all modules
for mod in modules:
    capture_script_lines.append(f"(require '[{mod}])")
capture_script_lines.append('')

# Test data
capture_script_lines.extend(test_data_lines)
capture_script_lines.append('')
capture_script_lines.append(';; --- Capture function outputs ---')
capture_script_lines.append('(def results (atom []))')
capture_script_lines.append('')

# Names of the test vars that actually received a (def ...). A record whose
# instance generation was skipped has no var; referencing it would be an
# unresolved symbol, which fails when the form is analysed (outside the
# generated try/catch) and aborts the rest of the capture script.
defined_test_vars = {
    line.split()[1] for line in test_data_lines if line.startswith('(def ')
}

# Call single-record-arg functions
calls_generated = 0
calls_without_data = 0
for module_name, info in modules.items():
    for fn in info['functions']:
        fn_name = fn['name']

        # Simple heuristic: only test single-arg functions where arg is a record
        arg_types = [a.strip() for a in fn['args_str'].split() if a.strip()]
        if len(arg_types) != 1 or arg_types[0] not in all_records:
            continue

        var_name = f'test-{arg_types[0].lower()}-1'
        label = f'{module_name}/{fn_name}'
        if var_name not in defined_test_vars:
            calls_without_data += 1
            print(f"  note: no test data for {arg_types[0]}; not capturing {label}",
                  file=sys.stderr)
            continue

        capture_script_lines.append(
            f'(try'
            f'  (let [r ({module_name}/{fn_name} {var_name})]'
            f'    (swap! results conj {{:label "{label}" :expected (pr-str r)}})'
            f'    (println (format "CAPTURE: %s = %s" "{label}" (pr-str r))))'
            f'  (catch Exception e'
            f'    (println (format "SKIP: {label} — %s" (.getMessage e)))))'
        )
        calls_generated += 1

capture_script_lines.append('')
capture_script_lines.append(f'(println (format "\\nCaptured %d assertions" (count @results)))')

# Write capture script
capture_path = os.path.join(work_dir, 'capture.clj')
with open(capture_path, 'w') as f:
    f.write('\n'.join(capture_script_lines))

print(f"  Generated {calls_generated} function calls to capture", file=sys.stderr)
if calls_without_data:
    print(f"  Skipped {calls_without_data} function(s) without generated test data",
          file=sys.stderr)

# ---------------------------------------------------------------------------
# Phase 3: Run capture and generate oracle
# ---------------------------------------------------------------------------

print("─── Phase 3: Running capture ───", file=sys.stderr)

# Run the capture script with `bb`, using the golden build as the classpath.
try:
    result = subprocess.run(
        ['bb', '-cp', build_dir,
         '-e', f'(load-file "{capture_path}")'],
        capture_output=True, text=True, timeout=120
    )
except FileNotFoundError:
    sys.exit("beagle-oracle: 'bb' executable not found on PATH")

# A failing run used to look exactly like "nothing to capture". Lines printed
# before the failure are still parsed below, so only warn here.
if result.returncode != 0:
    print(f"  warning: bb exited with status {result.returncode}; "
          f"captures may be incomplete", file=sys.stderr)
    for err_line in result.stderr.strip().splitlines()[-10:]:
        print(f"    {err_line}", file=sys.stderr)

# Collect every "CAPTURE: <label> = <value>" line from stdout and stderr.
output = result.stdout + result.stderr
captures = []
for line in output.split('\n'):
    m = re.match(r'CAPTURE:\s+(\S+)\s+=\s+(.+)', line)
    if m:
        captures.append({'label': m.group(1), 'value': m.group(2)})

print(f"  Captured {len(captures)} function outputs", file=sys.stderr)

# ---------------------------------------------------------------------------
# Phase 4: Generate oracle verify script
# ---------------------------------------------------------------------------

# Build the verify script: header, ns form, assert helper, the same test data
# used for the capture, one assertion per captured call, and a summary line.
oracle_lines = [
    ';; beagle-oracle: auto-generated behavioral oracle',
    f';; Source: {source_dir}',
    f';; Functions captured: {len(captures)}',
    '',
]

# Namespace with requires
ns_requires = ' '.join(f'[{mod}]' for mod in modules)
oracle_lines.append(f'(ns beagle.oracle\n  (:require {ns_requires}))')
oracle_lines.append('')

# Assert helper: counts passes and failures, and prints each failure.
oracle_lines.extend([
    '(def ^:dynamic *pass-count* (atom 0))',
    '(def ^:dynamic *fail-count* (atom 0))',
    '',
    '(defn assert-eq [label expected actual]',
    '  (if (= expected actual)',
    '    (swap! *pass-count* inc)',
    '    (do (swap! *fail-count* inc)',
    '        (println (str "FAIL: " label "\\n  expected: " expected "\\n  actual:   " actual)))))',
    '',
])

# Test data
oracle_lines.extend(test_data_lines)
oracle_lines.append('')
oracle_lines.append(';; --- Oracle assertions ---')
oracle_lines.append('')

for cap in captures:
    label = cap['label']
    value = cap['value']

    # Reconstruct the function call from the "<module>/<fn>" label.
    parts = label.split('/')
    if len(parts) != 2:
        continue
    module_name, fn_name = parts

    # First function of that name in the module, as found during extraction.
    fn_info = next(
        (fn for fn in modules.get(module_name, {}).get('functions', [])
         if fn['name'] == fn_name),
        None,
    )
    if fn_info is None:
        continue

    arg_types = [a.strip() for a in fn_info['args_str'].split() if a.strip()]
    if len(arg_types) == 1 and arg_types[0] in all_records:
        var_name = f'test-{arg_types[0].lower()}-1'
        # The captured text is pr-str output and is embedded as data, not code:
        # without the quote, a list/seq result such as (1 2 3) would be
        # evaluated as a function call and symbols would be resolved.
        oracle_lines.append(
            f'(assert-eq "{label}" (quote {value}) ({module_name}/{fn_name} {var_name}))'
        )

oracle_lines.append('')
oracle_lines.append(';; --- Summary ---')
oracle_lines.append('(println (format "\\n%d assertions, %d failures"')
oracle_lines.append('  (+ @*pass-count* @*fail-count*) @*fail-count*))')

# Write oracle: to the --out file when given, otherwise to stdout.
oracle_content = '\n'.join(oracle_lines)

if out_file:
    with open(out_file, 'w') as f:
        f.write(oracle_content)
    print(f"\n  Oracle written to: {out_file}", file=sys.stderr)
else:
    # Print to stdout
    print(oracle_content)

print(f"  {len(captures)} assertions generated", file=sys.stderr)

# ---------------------------------------------------------------------------
# Diff mode: if --diff specified, compile modified code and find differences
# ---------------------------------------------------------------------------

if diff_dir:
    print("\n─── Diff mode: comparing golden vs modified ───", file=sys.stderr)

    diff_build_dir = os.path.join(work_dir, 'diff-build')
    os.makedirs(diff_build_dir, exist_ok=True)

    # Compile the modified sources into their own build directory.
    build = subprocess.run(
        [os.path.join(beagle_bin, 'beagle-build-all'), '--warn', diff_dir, '--out', diff_build_dir],
        capture_output=True, timeout=60
    )
    if build.returncode != 0:
        print(f"  warning: beagle-build-all exited with status {build.returncode} "
              f"for {diff_dir}", file=sys.stderr)

    # Run same capture on modified code
    result = subprocess.run(
        ['bb', '-cp', diff_build_dir,
         '-e', f'(load-file "{capture_path}")'],
        capture_output=True, text=True, timeout=120
    )
    if result.returncode != 0:
        print(f"  warning: bb exited with status {result.returncode} on modified code",
              file=sys.stderr)

    diff_output = result.stdout + result.stderr
    diff_captures = {}
    for line in diff_output.split('\n'):
        m = re.match(r'CAPTURE:\s+(\S+)\s+=\s+(.+)', line)
        if m:
            diff_captures[m.group(1)] = m.group(2)

    # Compare. A function captured for the golden code but absent from the
    # modified run (it threw, or the modified build/capture failed) is reported
    # as a difference too; otherwise a completely broken modified build would
    # be summarised as "0 differences".
    not_captured = '<no output captured>'
    golden_dict = {c['label']: c['value'] for c in captures}
    differences = []
    for label, golden_val in golden_dict.items():
        diff_val = diff_captures.get(label, not_captured)
        if diff_val != golden_val:
            differences.append({
                'label': label,
                'golden': golden_val,
                'modified': diff_val,
            })

    print(f"\n  DIFFERENCES: {len(differences)} function(s) produce different output", file=sys.stderr)
    # The per-function details go to stdout (the count above goes to stderr).
    for d in differences:
        print(f"\n  {d['label']}:")
        print(f"    golden:   {d['golden']}")
        print(f"    modified: {d['modified']}")

PYEOF

# Closing banner on stderr, preceded by one blank line.
printf '\n%s\n' "─── beagle-oracle complete ───" >&2
