#!/usr/bin/env bash
# beagle-proptest: schema-driven property test generation.
#
# Reads type information from beagle modules (defrecord, defscalar, function
# signatures) and generates property-based test assertions. No handwritten
# test code — the type system IS the test spec.
#
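# Usage:
#   beagle-proptest <source-dir>              Generate properties to stdout
#   beagle-proptest <source-dir> --run        Generate and run against compiled build
#   beagle-proptest <source-dir> --run --build-dir <dir>
#                                             ...run against an existing build in <dir>
#   beagle-proptest <source-dir> --diff <dir> Differential test: compare the build
#                                             (fresh, or --build-dir) against <dir>
#
# Requires python3; --run additionally requires babashka (bb) on PATH.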
#
# Generated properties:
#   - Record constructor round-trip (construct → access → equals)
#   - Record generators (random valid instances from field types)
#   - Property inference (return type → non-negative, deterministic, vec length)
#   - Scalar non-negativity (Amount, Price, Count → >= 0)
#   - Function return type constraints (returns Amount → non-negative)
#   - Monotonicity (count functions → >= 0)
#   - Filter subset (filter result ⊆ input)

set -euo pipefail

# Print the usage line to stderr and abort.
usage() {
    echo "Usage: beagle-proptest <source-dir> [--run] [--build-dir <dir>] [--diff <dir>]" >&2
    exit 1
}

# Abort unless the option named in $1 is followed by a non-empty value ($2).
# Without this, a trailing "--build-dir" dies with a cryptic
# "$2: unbound variable" under `set -u`.
require_value() {
    if [[ $# -lt 2 || -z "$2" ]]; then
        echo "Option $1 requires a directory argument" >&2
        exit 1
    fi
}

if [[ $# -lt 1 ]]; then
    usage
fi

# Resolve the source dir to an absolute path.  CDPATH is cleared so `cd`
# neither searches elsewhere nor echoes the target (which would end up in the
# captured value); `--` protects directory names starting with '-'.
if ! SOURCE_DIR="$(CDPATH='' cd -- "$1" && pwd)"; then
    echo "beagle-proptest: cannot access source directory: $1" >&2
    exit 1
fi
shift

RUN_MODE=false
BUILD_DIR=""
DIFF_DIR=""

while [[ $# -gt 0 ]]; do
    case "$1" in
        --run) RUN_MODE=true; shift ;;
        --build-dir) require_value "$@"; BUILD_DIR="$2"; shift 2 ;;
        --diff) require_value "$@"; DIFF_DIR="$2"; shift 2 ;;
        *) echo "Unknown option: $1" >&2; exit 1 ;;
    esac
done

# Directory holding this script; the sibling beagle-* tools are run from here.
# CDPATH is cleared so `cd` cannot echo a path into the captured value.
BEAGLE_BIN="$(CDPATH='' cd -- "$(dirname -- "$0")" && pwd)"
# Scratch space for the generated program (and the default build dir).
# Honours $TMPDIR and is removed on every exit path.
WORK_DIR="$(mktemp -d "${TMPDIR:-/tmp}/beagle-proptest.XXXXXX")"
trap 'rm -rf -- "$WORK_DIR"' EXIT

# ---------------------------------------------------------------------------
# Phase 1: Extract type information from all modules
# ---------------------------------------------------------------------------

echo ";; beagle-proptest: auto-generated properties" >&2
echo ";; source: $SOURCE_DIR" >&2

# Phase 1 generator.  The program below is fed to python3 on stdin through a
# quoted heredoc (so the shell performs no expansion inside it) and receives
# the source dir, this script's directory and the scratch dir as argv[1..3].
python3 - "$SOURCE_DIR" "$BEAGLE_BIN" "$WORK_DIR" << 'PYEOF'
import sys, os, re, subprocess

source_dir = sys.argv[1]  # beagle sources to scan (resolved to an absolute path by the shell)
beagle_bin = sys.argv[2]  # directory of this script; beagle-provides is run from here
work_dir = sys.argv[3]    # scratch dir; proptest.clj is written here

BEAGLE_EXTENSIONS = ('.bclj', '.bjs', '.bnix', '.bsql', '.bpy', '.bgl', '.rkt')


def is_beagle_source(fname):
    """True when *fname* ends with one of the recognised beagle suffixes."""
    return fname.endswith(BEAGLE_EXTENSIONS)


def strip_beagle_ext(fname):
    """Drop the first matching beagle suffix from *fname*; else return it as-is."""
    suffix = next((ext for ext in BEAGLE_EXTENSIONS if fname.endswith(ext)), None)
    return fname if suffix is None else fname[:-len(suffix)]

# =========================================================================
# Collect type information from all modules
# =========================================================================

modules = {}
all_scalars = {}   # name -> base type, across all modules
all_records = {}   # name -> {module, fields}, across all modules

# Line shapes in beagle-provides output, as recognised by these patterns:
#   record:   "  Name [field:Type field:Type ...]"
#   function: "  name : [ArgType ... -> ReturnType]"
_RECORD_LINE_RE = re.compile(r'^\s+(\w+)\s+\[(.+)\]$', re.MULTILINE)
_RECORD_FIELD_RE = re.compile(r'([\w-]+):([\w/?]+)')
_FUNCTION_LINE_RE = re.compile(r'^\s+([\w?!<>*+\-/]+)\s+:\s+\[(.+)\]$', re.MULTILINE)
# defscalar forms are read straight from the module source text.
_DEFSCALAR_RE = re.compile(r'\(defscalar\s+(\w+)\s+(\w+)\)')


def _warn(msg):
    """Report a non-fatal problem on stderr (stdout carries generated code)."""
    print(f';; beagle-proptest: warning: {msg}', file=sys.stderr)


def _parse_records(output):
    """Records with at least one typed field listed in beagle-provides output."""
    records = []
    for m in _RECORD_LINE_RE.finditer(output):
        fields = [{'name': fm.group(1), 'type': fm.group(2)}
                  for fm in _RECORD_FIELD_RE.finditer(m.group(2))]
        if fields:
            records.append({'name': m.group(1), 'fields': fields})
    return records


def _parse_functions(output):
    """Function signatures ("args -> ret") listed in beagle-provides output."""
    functions = []
    for m in _FUNCTION_LINE_RE.finditer(output):
        parts = m.group(2).rsplit(' -> ', 1)
        if len(parts) != 2:
            continue
        args_str = parts[0].strip()
        functions.append({
            'name': m.group(1),
            'args': args_str,
            'arg_types': [a for a in args_str.split() if a != '->'],
            'return_type': parts[1].strip(),
        })
    return functions


def _run_provides(fpath):
    """Return beagle-provides stdout for *fpath*, or None if it could not run.

    Failures used to be swallowed silently, so a missing tool made every
    module disappear from the output without any explanation; they are now
    reported on stderr (the module is still skipped, as before).
    """
    fname = os.path.basename(fpath)
    try:
        result = subprocess.run(
            [os.path.join(beagle_bin, 'beagle-provides'), fpath],
            capture_output=True, text=True, timeout=30
        )
    except subprocess.TimeoutExpired:
        _warn(f'beagle-provides timed out on {fname}; module skipped')
        return None
    except Exception as exc:  # e.g. tool missing/not executable, undecodable output
        _warn(f'beagle-provides failed on {fname}: {exc}; module skipped')
        return None
    if result.returncode != 0:
        # Still parse whatever was printed, but make the failure visible.
        _warn(f'beagle-provides exited with status {result.returncode} on {fname}')
    return result.stdout


for fname in sorted(os.listdir(source_dir)):
    if not is_beagle_source(fname):
        continue
    fpath = os.path.join(source_dir, fname)
    module_name = strip_beagle_ext(fname)

    output = _run_provides(fpath)
    if output is None:
        continue

    records = _parse_records(output)
    for rec in records:
        all_records[rec['name']] = {'module': module_name, 'fields': rec['fields']}

    functions = _parse_functions(output)

    # Scalars come from the source text rather than beagle-provides.
    # errors='replace' keeps a stray non-UTF-8 byte from aborting the scan.
    with open(fpath, encoding='utf-8', errors='replace') as f:
        src = f.read()
    scalars = []
    for sm in _DEFSCALAR_RE.finditer(src):
        scalars.append({'name': sm.group(1), 'base': sm.group(2)})
        all_scalars[sm.group(1)] = sm.group(2)

    modules[module_name] = {
        'records': records,
        'functions': functions,
        'scalars': scalars,
    }

# =========================================================================
# Type resolution helpers
# =========================================================================

def _bare_name(t):
    """Strip any namespace qualifier: 'orders/Amount' -> 'Amount'."""
    return t.rpartition('/')[2]


def resolve_base_type(t):
    """Map a (possibly qualified) type name to its underlying base type.

    A name declared with defscalar resolves one level to its base
    (e.g. Amount -> Long); any other name is returned unqualified.
    """
    name = _bare_name(t)
    return all_scalars.get(name, name)


def is_scalar(t):
    """True if the unqualified name of *t* was declared with defscalar."""
    return _bare_name(t) in all_scalars


def is_record(t):
    """True if the unqualified name of *t* is a known record type."""
    return _bare_name(t) in all_records

# Scalar names (matched exactly or as a name prefix) whose values are
# expected to be non-negative.
NONNEG_SCALARS = {'Amount', 'Price', 'Count', 'Quantity', 'Weight', 'Cost'}

# Every declared scalar whose name ends in "Id" (not consulted elsewhere in
# this generator at present).
ID_SCALARS = {name for name in all_scalars if name.endswith('Id')}

# Record field names whose Long values are drawn from a non-negative range
# by gen_expr_for_base.
NONNEG_FIELDS = frozenset({
    'quantity', 'count', 'capacity', 'min-quantity', 'min-qty',
    'max-uses', 'current-uses', 'discount-value', 'discount-pct',
    'lead-time-days', 'uses', 'surcharge-pct', 'commission-pct',
    'weight-kg', 'target', 'total-spend',
})

# =========================================================================
# Record generators
# =========================================================================

def gen_expr_for_type(t, field_name='', depth=0):
    """Build a Clojure expression producing a random value of type *t*.

    beagle scalars (defscalar Amount Long) are erased at compile time --
    ->Amount and amount-value are identities at runtime -- so a scalar-typed
    slot receives a raw base-type value with no wrapping constructor call.
    Record types expand to a real constructor call (Clojure defrecord) whose
    arguments are generated recursively; beyond depth 3 the result is nil.
    """
    if depth > 3:
        return 'nil'
    name = t.split('/')[-1]

    scalar_base = all_scalars.get(name)
    if scalar_base is not None:
        return gen_expr_for_base(scalar_base, name, field_name)

    record = all_records.get(name)
    if record is None:
        return gen_expr_for_base(resolve_base_type(t), name, field_name)

    args = ' '.join(
        gen_expr_for_type(fld['type'], fld['name'], depth + 1)
        for fld in record['fields']
    )
    return f'({record["module"]}/->{name} {args})'


def gen_expr_for_base(base, hint='', field_name=''):
    """Return a Clojure expression yielding a random value of base type *base*.

    For Long the range is chosen from *hint* (the scalar/type name) and
    *field_name*: names ending in "Id" -> 1..10000; names starting with a
    NONNEG_SCALARS entry -> 0..9999; NONNEG_FIELDS -> 0..999; otherwise
    -10000..9999.  Vec/Map types get empty literals and unrecognised types
    fall back to 0..9999.
    """
    if base == 'Long':
        if hint.endswith('Id'):
            return '(+ 1 (rand-int 10000))'
        if hint.startswith(tuple(NONNEG_SCALARS)):
            return '(rand-int 10000)'
        if field_name in NONNEG_FIELDS:
            return '(rand-int 1000)'
        return '(- (rand-int 20000) 10000)'

    fixed_exprs = {
        'Double': '(* (rand) 100.0)',
        'String': '(rand-nth ["alpha" "bravo" "charlie" "delta" "echo" "foxtrot" "golf" "hotel"])',
        'Boolean': '(rand-nth [true false])',
        'Keyword': '(rand-nth [:a :b :c :d :e])',
        'Any': 'nil',
    }
    if base in fixed_exprs:
        return fixed_exprs[base]
    if base.startswith(('Vec', '(Vec')):
        return '[]'
    if base.startswith(('Map', '(Map')):
        return '{}'
    return '(rand-int 10000)'


def gen_fn_name(rec_name):
    """Name of the emitted Clojure generator function for record *rec_name*."""
    return 'gen-{}'.format(rec_name)


# =========================================================================
# Property inference from function signatures
# =========================================================================

def _returns_number(ret_bare):
    """True when *ret_bare* is Long/Double or a scalar wrapping one of them.

    The emitted non-negativity checks apply Clojure `neg?`, which throws on
    non-numbers and would abort the whole generated run, so name-based
    non-negativity is only inferred for numeric return types.
    """
    return resolve_base_type(ret_bare) in ('Long', 'Double')


def infer_properties(ns, fn_info):
    """Infer testable properties from a function's signature and name.

    Args:
        ns: module/namespace name, used in property labels.
        fn_info: dict with at least 'name' and 'return_type'.

    Returns:
        A list of property dicts, each with 'kind', 'label' and 'fn';
        'nonneg-return' properties also carry 'ret_type'.
    """
    props = []
    fn_name = fn_info['name']
    ret = fn_info['return_type']
    ret_bare = ret.split('/')[-1] if '/' in ret else ret

    def add_nonneg(suffix):
        props.append({
            'kind': 'nonneg-return',
            'label': f'prop/{ns}/{fn_name}-{suffix}',
            'fn': fn_name,
            'ret_type': ret,
        })

    # --- Return type is a non-negative scalar (Amount, Price, Count, ...) ---
    # Either a NONNEG_SCALARS name itself, or a Long-based scalar whose name
    # starts with one of them (e.g. AmountCents).
    long_scalar = ret_bare in all_scalars and resolve_base_type(ret_bare) == 'Long'
    if ((ret_bare in NONNEG_SCALARS or long_scalar)
            and any(ret_bare.startswith(s) for s in NONNEG_SCALARS)):
        add_nonneg('returns-nonneg')

    # --- Name-based: *-total, *-sum -> non-negative (numeric returns only) ---
    if ((fn_name.endswith(('-total', '-sum'))
            or fn_name.startswith(('total-', 'sum-')))
            and _returns_number(ret_bare)):
        add_nonneg('nonneg-total')

    # --- Name-based: *-count -> non-negative (numeric returns only; e.g. a
    # count-by-* returning a Map must not be fed to `neg?`) ---
    if ((fn_name.endswith('-count') or fn_name.startswith('count-'))
            and _returns_number(ret_bare)):
        add_nonneg('nonneg-count')

    # --- Return type Boolean -> deterministic (same input -> same output) ---
    if ret_bare == 'Boolean':
        props.append({
            'kind': 'deterministic',
            'label': f'prop/{ns}/{fn_name}-deterministic',
            'fn': fn_name,
        })

    # --- Collection-returning filters/sorts/listings -> length >= 0 ---
    if ret_bare.startswith('Vec') or ret_bare == 'Any':
        if (fn_name.startswith('filter') or fn_name.startswith('sort-')
                or fn_name.endswith('-items') or fn_name.endswith('-list')):
            props.append({
                'kind': 'vec-nonneg-length',
                'label': f'prop/{ns}/{fn_name}-vec-length',
                'fn': fn_name,
            })

    return props


# =========================================================================
# Determine which functions can be tested with generators
# =========================================================================

def can_generate_args(arg_types):
    """Report whether a random value can be generated for every argument type.

    Scalars and records are always generatable, as are the primitive bases
    Long/Double/String/Boolean/Keyword.  Anything else (notably Any, where
    nil or [] would rarely be meaningful) is not.
    """
    def generatable(t):
        name = t.split('/')[-1]
        if name in all_scalars or name in all_records:
            return True
        return resolve_base_type(t) in ('Long', 'Double', 'String', 'Boolean', 'Keyword')

    return all(generatable(t) for t in arg_types)


# =========================================================================
# Generate properties
# =========================================================================

properties = []       # static round-trip assertions (existing behavior)
generators = {}       # rec_name -> gen function code
gen_properties = []   # randomized property tests


def _sample_literal(field_type):
    """Fixed Clojure literal used for one field in a static round-trip.

    The literal must compare equal (Clojure `=`) to what the accessor hands
    back, so Double fields get 1.0: `(= 1 1.0)` is false in Clojure, and a
    field that coerces to double would otherwise fail spuriously.  Scalars
    are erased at runtime, so they use their base type's literal.
    """
    base = resolve_base_type(field_type)
    if base == 'Double':
        return '1.0'
    if base == 'String':
        return '"test"'
    if base == 'Boolean':
        return 'true'
    if base == 'Keyword':
        return ':test'
    if base.startswith(('Vec', '(Vec')):
        return '[]'
    if base.startswith(('Map', '(Map')):
        return '{}'
    return '1'   # Long and anything unrecognised


for module_name, info in modules.items():
    ns = module_name

    for rec in info['records']:
        rec_name = rec['name']
        fields = rec['fields']
        ctor = f'{ns}/->{rec_name}'
        gen_name = gen_fn_name(rec_name)

        # --- Record generator: one random expression per field ---
        field_gen_exprs = [gen_expr_for_type(f['type'], f['name']) for f in fields]
        generators[rec_name] = {
            'code': (f'(defn {gen_name} []\n'
                     f'  ({ctor} {" ".join(field_gen_exprs)}))'),
            'name': gen_name,
            'module': ns,
        }

        # --- Static constructor round-trip: build from fixed literals and
        # read each field back through its accessor ---
        sample_vals = [_sample_literal(f['type']) for f in fields]
        ctor_call = f'({ctor} {" ".join(sample_vals)})'
        accessor_prefix = rec_name.lower()
        for f, expected in zip(fields, sample_vals):
            accessor = f'{ns}/{accessor_prefix}-{f["name"]}'
            properties.append(
                f'(assert-eq "prop/{ns}/{rec_name}/{f["name"]}-roundtrip" '
                f'{expected} ({accessor} {ctor_call}))'
            )

        # --- Randomized round-trip: the generator must yield instances ---
        gen_properties.append({
            'kind': 'gen-roundtrip',
            'label': f'prop/{ns}/{rec_name}/gen-roundtrip',
            'gen_name': gen_name,
            'rec_name': rec_name,
            'module': ns,
        })

    # --- Property inference on functions (only unary ones are exercised) ---
    for fn in info['functions']:
        arg_types = fn.get('arg_types', [])
        if len(arg_types) != 1:
            continue
        generatable = can_generate_args(arg_types)
        for prop in infer_properties(ns, fn):
            if generatable:
                prop['arg_types'] = arg_types
                prop['module'] = ns
            else:
                prop['executable'] = False
            gen_properties.append(prop)

# =========================================================================
# Output as runnable Clojure
# =========================================================================

N_ITERATIONS = 20   # random instances drawn per generative property

n_static = len(properties)
n_gen = sum(1 for p in gen_properties if p.get('executable', True))

# File banner, then the namespace header requiring every scanned module by
# its file stem.
output_lines = [
    ';; beagle-proptest: auto-generated property assertions',
    f';; Generated from {len(modules)} module(s)',
    (f';; {n_static} static properties, {n_gen} generative properties '
     f'({N_ITERATIONS} iterations each)'),
    '',
    '(ns beagle.proptest',
    '  (:require ' + ' '.join(f'[{mod}]' for mod in modules) + '))',
    '',
]

# Clojure runtime support emitted ahead of the properties: pass/fail/shrink
# counters, the assert-* reporters, and a greedy shrinker (find-minimal) that
# minimises a failing input before it is reported.  Kept as one raw string so
# the Clojure reads naturally; "\n" inside it is the Clojure escape.
CLOJURE_PRELUDE = r'''(def ^:dynamic *pass-count* (atom 0))
(def ^:dynamic *fail-count* (atom 0))

(defn assert-eq [label expected actual]
  (if (= expected actual)
    (swap! *pass-count* inc)
    (do (swap! *fail-count* inc)
        (println (str "FAIL: " label
                      "\n  expected: " expected
                      "\n  actual:   " actual)))))

(defn assert-true [label actual]
  (if actual
    (swap! *pass-count* inc)
    (do (swap! *fail-count* inc)
        (println (str "FAIL: " label "\n  got: " actual)))))

(defn assert-nonneg [label v]
  (if (>= v 0)
    (swap! *pass-count* inc)
    (do (swap! *fail-count* inc)
        (println (str "FAIL: " label
                      "\n  expected non-negative, got: " v)))))

;; === Shrinking ===
(def ^:dynamic *shrink-count* (atom 0))

;; Candidates move an integer toward zero: 0 first, then half the value.
(defn shrink-long [n]
  (cond (= n 0) []
        (or (= n 1) (= n -1)) [0]
        :else [0 (quot n 2)]))

(defn shrink-string [s]
  (if (= s "") []
    ["" (subs s 0 (quot (count s) 2))]))

;; Maps shrink one key at a time; candidates equal to the input are dropped
;; so find-minimal never "shrinks" to the same value.
(defn shrink-val [v]
  (cond (integer? v) (shrink-long v)
        (float? v) (if (zero? v) [] [0.0 (/ v 2.0)])
        (string? v) (shrink-string v)
        (boolean? v) []
        (keyword? v) []
        (map? v) (filterv #(not= % v)
                   (mapv (fn [k]
                           (let [sv (first (shrink-val (get v k)))]
                             (if (some? sv) (assoc v k sv) v)))
                         (keys v)))
        :else []))

(defn shrink-record [rec]
  (if-not (record? rec) []
    (let [m (into {} rec)
          ks (keys m)]
      (filterv #(not= % rec)
        (mapv (fn [k]
          (let [v (get m k)
                shrunk (first (shrink-val v))]
            (if (some? shrunk)
              (merge rec {k shrunk})
              rec))) ks)))))

(defn shrink-input [input]
  (if (record? input) (shrink-record input) (shrink-val input)))

;; A candidate whose check throws is treated as "not failing" so that
;; shrinking can never abort the whole run.
(defn find-minimal [pred input max-shrinks]
  (let [safe-pred (fn [c] (try (pred c) (catch Exception _ false)))]
    (loop [current input n 0]
      (if (>= n max-shrinks) current
        (let [candidates (shrink-input current)
              smaller (first (filter safe-pred candidates))]
          (if (some? smaller)
            (do (swap! *shrink-count* inc)
                (recur smaller (inc n)))
            current))))))
'''
output_lines.extend(CLOJURE_PRELUDE.split('\n'))
output_lines.append(f'(def N {N_ITERATIONS})')
output_lines.append('')

# --- Record generators (sorted by record name) ---
output_lines += [';; === Record Generators ===', '']
for _, gen_info in sorted(generators.items()):
    output_lines += [gen_info['code'], '']

# --- Static properties (round-trip), in collection order ---
output_lines += [';; === Static Round-Trip Properties ===', '']
output_lines.extend(properties)
output_lines.append('')

# --- Generative properties ---
output_lines.append(';; === Generative Properties ===')
output_lines.append('')


def _arg_input_expr(arg_t):
    """Clojure expression generating one argument value of type *arg_t*.

    Records use their named generator (emitted earlier in the file); every
    other generatable type (scalars and primitive bases) is generated inline.
    """
    bare = arg_t.split('/')[-1]
    if bare in all_records:
        return f'({gen_fn_name(bare)})'
    return gen_expr_for_type(arg_t)


def _emit_gen_roundtrip(label, gen_name):
    """Emit a loop checking that a record generator yields non-nil values."""
    output_lines.extend([
        f';; {label}: construct random instances',
        '(dotimes [_ N]',
        f'  (let [inst ({gen_name})]',
        f'    (assert-true "{label}/not-nil" (some? inst))))',
        '',
    ])


def _emit_nonneg(label, call, input_expr):
    """Emit a loop asserting `call` on random input never returns a negative.

    Non-numeric results (e.g. nil) are skipped rather than fed to `neg?`,
    which would throw; the shrink predicate applies the same guard.  Passing
    iterations bump *pass-count* so the summary agrees with the assert-*
    helpers.
    """
    output_lines.extend([
        f';; {label}',
        '(dotimes [_ N]',
        f'  (let [input {input_expr}',
        f'        result ({call} input)]',
        '    (when (number? result)',
        '      (if (neg? result)',
        '        (let [minimal (find-minimal',
        f'                       (fn [i] (let [r ({call} i)] (and (number? r) (neg? r))))',
        '                       input 20)]',
        '          (swap! *fail-count* inc)',
        f'          (println (str "FAIL: {label}\\n  input: " input',
        '                        "\\n  shrunk: " minimal',
        f'                        "\\n  result: " ({call} minimal))))',
        '        (swap! *pass-count* inc)))))',
        '',
    ])


def _emit_deterministic(label, call, input_expr):
    """Emit a loop asserting two calls on the same input agree."""
    output_lines.extend([
        f';; {label}',
        '(dotimes [_ N]',
        f'  (let [input {input_expr}',
        f'        r1 ({call} input)',
        f'        r2 ({call} input)]',
        f'    (assert-eq "{label}" r1 r2)))',
        '',
    ])


def _emit_vec_length(label, call, input_expr):
    """Emit a loop asserting a non-nil collection result has length >= 0."""
    output_lines.extend([
        f';; {label}',
        '(dotimes [_ N]',
        f'  (let [input {input_expr}',
        f'        result ({call} input)]',
        '    (when (some? result)',
        f'      (assert-nonneg "{label}" (count result)))))',
        '',
    ])


_PROPERTY_EMITTERS = {
    'nonneg-return': _emit_nonneg,
    'deterministic': _emit_deterministic,
    'vec-nonneg-length': _emit_vec_length,
}

for prop in gen_properties:
    kind = prop['kind']
    label = prop['label']

    if kind == 'gen-roundtrip':
        _emit_gen_roundtrip(label, prop['gen_name'])
        continue

    emit = _PROPERTY_EMITTERS.get(kind)
    if emit is None:
        continue
    if not prop.get('executable', True):
        output_lines.append(f';; [skip] {label}: cannot generate args')
        continue

    # Every executable property is counted in n_gen, so emit code for any
    # generatable argument type, not just records and scalars.
    arg_types = prop.get('arg_types', [])
    if len(arg_types) != 1:
        continue
    call = f'{prop.get("module", "")}/{prop["fn"]}'
    emit(label, call, _arg_input_expr(arg_types[0]))

# --- Summary: report totals and exit non-zero on any failure ---
output_lines.extend([
    '',
    ';; === Summary ===',
    '(println (format "\\n%d properties checked, %d failures, %d shrink steps"',
    '                 (+ @*pass-count* @*fail-count*) @*fail-count* @*shrink-count*))',
    '(when (pos? @*fail-count*) (System/exit 1))',
])

# The program goes both to work_dir/proptest.clj (loaded by --run) and to
# stdout; the closing note goes to stderr so stdout stays pure Clojure.
program = '\n'.join(output_lines)
output_path = os.path.join(work_dir, 'proptest.clj')
with open(output_path, 'w') as fh:
    fh.write(program)

print(program)

print(f"\n;; Wrote {n_static} static + {n_gen} generative properties to {output_path}",
      file=sys.stderr)
PYEOF

# ---------------------------------------------------------------------------
# Phase 2 (optional): Run properties against compiled build
# ---------------------------------------------------------------------------

if [[ "$RUN_MODE" == "true" ]]; then
    # babashka executes the generated Clojure; fail up front with a clear
    # message instead of "command not found" after a possibly long build.
    if ! command -v bb >/dev/null 2>&1; then
        echo "beagle-proptest: --run requires babashka (bb) on PATH" >&2
        exit 1
    fi

    if [[ -z "$BUILD_DIR" ]]; then
        BUILD_DIR="$WORK_DIR/build"
        mkdir -p "$BUILD_DIR"
        echo "" >&2
        echo "=== Building modules ===" >&2
        # Only the tail of the build log is shown; under pipefail the `if`
        # still sees a failing build, so report it explicitly.
        if ! "$BEAGLE_BIN/beagle-build-all" --warn "$SOURCE_DIR" --out "$BUILD_DIR" 2>&1 | tail -3 >&2; then
            echo "beagle-proptest: build failed (see output above)" >&2
            exit 1
        fi
    elif [[ ! -d "$BUILD_DIR" ]]; then
        echo "beagle-proptest: build directory not found: $BUILD_DIR" >&2
        exit 1
    fi

    echo "" >&2
    echo "=== Running properties ===" >&2
    # The generated program calls (System/exit 1) on any failure; with
    # `set -e` that non-zero status ends this script with the same code.
    bb -cp "$BUILD_DIR" -e "(load-file \"$WORK_DIR/proptest.clj\")" 2>&1
fi

# ---------------------------------------------------------------------------
# Phase 3 (optional): Differential testing — compare two builds
# ---------------------------------------------------------------------------

if [[ -n "$DIFF_DIR" ]]; then
    if [[ -z "$BUILD_DIR" ]]; then
        BUILD_DIR="$WORK_DIR/build"
        mkdir -p "$BUILD_DIR"
        echo "" >&2
        echo "=== Building golden modules ===" >&2
        "$BEAGLE_BIN/beagle-build-all" --warn "$SOURCE_DIR" --out "$BUILD_DIR" 2>&1 | tail -3 >&2
    fi

    echo "" >&2
    echo "=== Differential testing: $BUILD_DIR vs $DIFF_DIR ===" >&2

    BEAGLE_BIN="$BEAGLE_BIN" SOURCE_DIR="$SOURCE_DIR" python3 - "$BUILD_DIR" "$DIFF_DIR" "$WORK_DIR" << 'DIFF_PYEOF'
import sys, os, re, subprocess, json

golden_dir = sys.argv[1]
modified_dir = sys.argv[2]
work_dir = sys.argv[3]

BEAGLE_EXTENSIONS = ('.bclj', '.bjs', '.bnix', '.bsql', '.bpy', '.bgl', '.rkt')
def is_beagle_source(fname):
    return any(fname.endswith(ext) for ext in BEAGLE_EXTENSIONS)
def strip_beagle_ext(fname):
    for ext in BEAGLE_EXTENSIONS:
        if fname.endswith(ext):
            return fname[:-len(ext)]
    return fname

# Find functions shared between both builds (with type signatures)
def extract_functions(build_dir):
    """Parse defn signatures from compiled .clj files."""
    fns = {}
    for fname in sorted(os.listdir(build_dir)):
        if not fname.endswith('.clj'):
            continue
        module = fname[:-4]
        path = os.path.join(build_dir, fname)
        with open(path) as f:
            content = f.read()
        for m in re.finditer(r'\(defn\s+([\w?!<>*+\-/]+)\s+\[([^\]]*)\]', content):
            fn_name = m.group(1)
            params_str = m.group(2).strip()
            params = [p.strip() for p in params_str.split() if p.strip()] if params_str else []
            if params == ['r'] or fn_name.startswith('gen-'):
                continue
            fns.setdefault(module, []).append({'name': fn_name, 'arity': len(params)})
    return fns

def extract_typed_functions(source_dir, beagle_bin):
    """Use beagle-provides to get typed function signatures."""
    typed = {}
    if not source_dir or not os.path.isdir(source_dir):
        return typed
    for fname in sorted(os.listdir(source_dir)):
        if not is_beagle_source(fname):
            continue
        fpath = os.path.join(source_dir, fname)
        module = strip_beagle_ext(fname)
        try:
            result = subprocess.run(
                [os.path.join(beagle_bin, 'beagle-provides'), fpath],
                capture_output=True, text=True, timeout=30
            )
            for fm in re.finditer(r'^\s+([\w?!<>*+\-/]+)\s+:\s+\[(.+)\]$', result.stdout, re.MULTILINE):
                fn_name = fm.group(1)
                sig = fm.group(2)
                parts = sig.rsplit(' -> ', 1)
                if len(parts) == 2:
                    arg_types = [a for a in parts[0].strip().split() if a != '->']
                    typed.setdefault(module, {})[fn_name] = arg_types
        except:
            pass
    return typed

golden_fns = extract_functions(golden_dir)
modified_fns = extract_functions(modified_dir)

# Try to get typed signatures from source if available
beagle_bin = os.environ.get('BEAGLE_BIN', '')
source_dir = os.environ.get('SOURCE_DIR', '')
typed_fns = extract_typed_functions(source_dir, beagle_bin) if beagle_bin else {}

shared_modules = set(golden_fns.keys()) & set(modified_fns.keys())
if not shared_modules:
    print("No shared modules found between builds", file=sys.stderr)
    sys.exit(0)

# Find records for test data generation
def extract_records(build_dir):
    records = {}
    for fname in sorted(os.listdir(build_dir)):
        if not fname.endswith('.clj'):
            continue
        module = fname[:-4]
        path = os.path.join(build_dir, fname)
        with open(path) as f:
            content = f.read()
        for m in re.finditer(r'\(defrecord\s+(\w+)\s+\[([^\]]*)\]\)', content):
            rec_name = m.group(1)
            fields = [f.strip() for f in m.group(2).split() if f.strip()]
            records[rec_name] = {'module': module, 'fields': fields}
    return records

records = extract_records(golden_dir)

# Deterministic inputs shared by the golden and modified runs. The generated
# script is identical for both builds, so divergent recorded results point
# at the build, not at the inputs.
TEST_LONGS = [0, 1, -1, 42, 100, 999, -50]
N_RECORD_INSTANCES = 5


def _longs(take=None):
    """(clojure-expr, element-count) for test-longs, optionally (take n ...)'d."""
    if take is None:
        return 'test-longs', len(TEST_LONGS)
    return f'(take {take} test-longs)', min(take, len(TEST_LONGS))


def _record_instances(rec, take=None):
    """(clojure-expr, element-count) for the fixed test-<rec>s instances."""
    if take is None:
        return f'test-{rec}s', N_RECORD_INSTANCES
    return f'(take {take} test-{rec}s)', min(take, N_RECORD_INSTANCES)


def _plan_args(arity, arg_types):
    """Choose one input collection per positional argument, or None to skip.

    Record-typed arguments (per beagle-provides) draw from the fixed record
    instances; everything else, including untyped functions, draws from
    test-longs. Multi-argument plans are truncated with (take n ...) so the
    cartesian product of calls stays small.
    """
    if arity == 0:
        return []
    if arity == 1:
        t = arg_types[0] if arg_types else None
        return [_record_instances(t) if t in records else _longs()]
    if arity == 2:
        t0, t1 = arg_types[:2] if len(arg_types) >= 2 else (None, None)
        if t0 in records:
            second = _record_instances(t1, take=3) if t1 in records else _longs(take=3)
            return [_record_instances(t0), second]
        second = _record_instances(t1) if t1 in records else _longs(take=3)
        return [_longs(take=3), second]
    if arity == 3:
        # Record combinations are not generated for 3-arity functions; scalar
        # or untyped ones get a 2x2x2 grid of longs (mirroring arity 2).
        if any(t in records for t in arg_types[:3]):
            return None
        return [_longs(take=2)] * 3
    return None


def _emit_calls(out, module, fn_name, qual_name, sources):
    """Write Clojure that calls qual_name on every combination of sources.

    Each call's result, or "ERROR:<message>" if it throws, is stored with
    record-result!. An empty sources list emits one zero-argument call.
    Returns the number of calls the emitted code performs.
    """
    if len(sources) == 1:
        params = ['v']
    else:
        params = [chr(ord('a') + i) for i in range(len(sources))]
    arg_vec = ' '.join(params)
    call = f'({qual_name} {arg_vec})' if sources else f'({qual_name})'
    record = f'(record-result! "{module}" "{fn_name}" [{arg_vec}]'
    if not sources:
        out.write(f'(try {record} {call})\n')
        out.write(f'  (catch Exception e {record} (str "ERROR:" (.getMessage e)))))\n')
        return 1
    bindings = ' '.join(f'{p} {expr}' for p, (expr, _) in zip(params, sources))
    out.write(f'(doseq [{bindings}]\n')
    out.write(f'  (try {record} {call})\n')
    out.write(f'    (catch Exception e {record} (str "ERROR:" (.getMessage e))))))\n')
    calls = 1
    for _, size in sources:
        calls *= size
    return calls


# Generate differential test script
diff_script = os.path.join(work_dir, 'diff-test.clj')
with open(diff_script, 'w') as out:
    out.write(';; Differential test: compare golden vs modified outputs\n')
    # Require all shared modules
    module_list = ' '.join(f'[{m}]' for m in sorted(shared_modules))
    out.write(f'(ns beagle.diff-test (:require {module_list}))\n\n')

    # Test data
    out.write(';; Test data\n')
    out.write('(def test-longs [' + ' '.join(map(str, TEST_LONGS)) + '])\n')
    out.write('(def test-strings ["" "hello" "test" "foo bar"])\n')
    out.write('(def test-doubles [0.0 1.5 -3.14 100.0])\n')
    out.write('(def test-booleans [true false])\n\n')

    # Fixed record instances for deterministic diff testing: every field gets
    # a long derived from (instance index, field index).
    for rec_name, rec_info in sorted(records.items()):
        ctor = f"{rec_info['module']}/->{rec_name}"
        n_fields = len(rec_info['fields'])
        instance_names = []
        for i in range(N_RECORD_INSTANCES):
            vals = ' '.join(str((i * 37 + j * 13 + 100) % 10000) for j in range(n_fields))
            out.write(f'(def test-{rec_name}-{i} ({ctor} {vals}))\n')
            instance_names.append(f'test-{rec_name}-{i}')
        out.write(f'(def test-{rec_name}s [' + ' '.join(instance_names) + '])\n')
    out.write('\n')

    out.write('(def results (atom []))\n\n')
    out.write('(defn record-result! [module fn-name args result]\n')
    out.write('  (swap! results conj {:module module :fn fn-name\n')
    out.write('                        :args (pr-str args) :result (pr-str result)}))\n\n')

    test_calls = 0
    for module in sorted(shared_modules):
        golden_fn_set = {f['name']: f for f in golden_fns[module]}
        modified_names = {f['name'] for f in modified_fns[module]}
        module_types = typed_fns.get(module, {})
        # Arity comes from the golden build; an arity change in the modified
        # build surfaces as a recorded ERROR difference.
        for fn_name in sorted(golden_fn_set.keys() & modified_names):
            sources = _plan_args(golden_fn_set[fn_name]['arity'],
                                 module_types.get(fn_name, []))
            if sources is None:
                continue
            test_calls += _emit_calls(out, module, fn_name, f"{module}/{fn_name}", sources)

    out.write(f'\n;; {test_calls} test calls generated\n')
    out.write('(println (str "  " (count @results) " function calls executed"))\n')
    out.write(';; Output results as EDN\n')
    out.write(f'(spit "{work_dir}/results.edn" (pr-str @results))\n')

print(f"  Generated {test_calls} differential test calls across {len(shared_modules)} modules", file=sys.stderr)


def _run_diff_script(label, build_dir, results_file):
    """Run the generated diff script with build_dir on bb's classpath.

    The script always writes <work_dir>/results.edn; it is moved to
    results_file so the next run cannot overwrite it. Exits the process with
    status 1 if bb cannot start, fails, times out, or writes no results.
    """
    try:
        run = subprocess.run(
            ['bb', '-cp', build_dir, '-e', f'(load-file "{diff_script}")'],
            capture_output=True, text=True, timeout=60
        )
    except (OSError, subprocess.SubprocessError) as exc:
        print(f"  {label} run failed: {exc}", file=sys.stderr)
        sys.exit(1)
    if run.returncode != 0:
        print(f"  {label} run failed: {run.stderr[:200]}", file=sys.stderr)
        sys.exit(1)
    print(f"  {label}: {run.stdout.strip()}", file=sys.stderr)
    produced = os.path.join(work_dir, 'results.edn')
    if not os.path.exists(produced):
        print(f"  {label} run failed: no results written to {produced}", file=sys.stderr)
        sys.exit(1)
    os.rename(produced, results_file)


# Run against golden, then modified
golden_results_file = os.path.join(work_dir, 'golden-results.edn')
_run_diff_script('Golden', golden_dir, golden_results_file)

modified_results_file = os.path.join(work_dir, 'modified-results.edn')
_run_diff_script('Modified', modified_dir, modified_results_file)

# Compare results. Both runs execute the same generated script, so entries
# line up positionally; a count mismatch means calls went missing and is
# reported as a failure instead of being silently truncated by (map vector).
compare_script = os.path.join(work_dir, 'compare.clj')
with open(compare_script, 'w') as out:
    out.write(f'(def golden (read-string (slurp "{golden_results_file}")))\n')
    out.write(f'(def modified (read-string (slurp "{modified_results_file}")))\n')
    out.write('(def diffs (atom []))\n')
    out.write('(def count-mismatch? (not= (count golden) (count modified)))\n')
    out.write('(when count-mismatch?\n')
    out.write('  (println (str "result count mismatch: golden=" (count golden) " modified=" (count modified))))\n')
    out.write('(doseq [[g m] (map vector golden modified)]\n')
    out.write('  (when (not= (:result g) (:result m))\n')
    out.write('    (swap! diffs conj {:module (:module g) :fn (:fn g)\n')
    out.write('                       :args (:args g)\n')
    out.write('                       :golden (:result g) :modified (:result m)})))\n')
    out.write('(println (str (count @diffs) " behavioral difference(s) found"))\n')
    out.write('(doseq [d (take 20 @diffs)]\n')
    out.write('  (println (str "  DIFF " (:module d) "/" (:fn d) " args=" (:args d)))\n')
    out.write('  (println (str "    golden:   " (:golden d)))\n')
    out.write('  (println (str "    modified: " (:modified d))))\n')
    out.write('(when (> (count @diffs) 20) (println (str "  ... and " (- (count @diffs) 20) " more")))\n')
    out.write('(when (or count-mismatch? (pos? (count @diffs))) (System/exit 1))\n')

try:
    compare_result = subprocess.run(
        ['bb', '-e', f'(load-file "{compare_script}")'],
        capture_output=True, text=True, timeout=30
    )
except (OSError, subprocess.SubprocessError) as exc:
    print(f"  Comparison failed: {exc}", file=sys.stderr)
    sys.exit(1)
print(compare_result.stdout, end='')
if compare_result.stderr:
    print(compare_result.stderr, file=sys.stderr, end='')
sys.exit(compare_result.returncode)
DIFF_PYEOF
fi
