-- stdlib/list.rail — List utilities for Rail native compiler
-- Import: import "stdlib/list.rail"

-- Named binary operators so they can be handed to `fold`.
-- add_fn returns the sum of its two arguments.
add_fn x y = x + y
-- mul_fn returns the product of its two arguments.
mul_fn x y = x * y

-- Total of a list of integers; folds add_fn from 0, so the empty list gives 0.
sum numbers = fold add_fn 0 numbers

-- Product of a list of integers; folds mul_fn from 1, so the empty list gives 1.
product numbers = fold mul_fn 1 numbers

-- First n elements of xs. Yields [] when n <= 0 or xs is empty, and the
-- whole list when it has fewer than n elements.
take n xs =
  if n <= 0 || xs == [] then []
  else
    let first = head xs
    let rest = take (n - 1) (tail xs)
    cons first rest

-- xs without its first n elements. An empty list stays empty; n <= 0
-- returns xs untouched.
drop n xs =
  if xs == [] then []
  else if n <= 0 then xs
  else
    let rest = tail xs
    drop (n - 1) rest

-- Pair up elements of xs and ys position by position; stops as soon as
-- either list runs out.
zip xs ys =
  if xs == [] || ys == [] then []
  else
    let pair = (head xs, head ys)
    cons pair (zip (tail xs) (tail ys))

-- Element at 0-based position n.
-- No bounds check: an n outside the list (including negative n, which never
-- reaches 0) ends up applying head/tail to an empty list.
nth n xs =
  if n == 0 then head xs
  else
    let rest = tail xs
    nth (n - 1) rest

-- true when pred holds for at least one element; false for the empty list.
-- Stops at the first element that satisfies pred.
any pred xs =
  if xs == [] then false
  else
    let hit = pred (head xs)
    if hit then true
    else any pred (tail xs)

-- true when pred holds for every element; true for the empty list.
-- Stops at the first element that fails pred.
all pred xs =
  if xs == [] then true
  else
    let ok = pred (head xs)
    if ok then all pred (tail xs)
    else false

-- Concatenate a list of lists into one list, preserving order.
flatten xss =
  if xss == [] then []
  else
    let first = head xss
    let rest = flatten (tail xss)
    append first rest

-- Map over a list with the element's 0-based index: f is called as
-- `f index element`.
mapi f items = mapi_acc f items 0

-- Worker for mapi: i is the index of the current head of xs.
mapi_acc f xs i =
  if xs == [] then []
  else
    let mapped = f i (head xs)
    let rest = mapi_acc f (tail xs) (i + 1)
    cons mapped rest

-- Ascending insertion sort: sorts the tail, then inserts the head into it.
-- Quadratic, so intended for small lists only.
sort xs =
  if xs == [] then []
  else
    let sorted_rest = sort (tail xs)
    insert (head xs) sorted_rest

-- Insert x into an already ascending list, placing it before the first
-- element that is >= x.
insert x sorted =
  if sorted == [] then [x]
  else
    let first = head sorted
    if x <= first then cons x sorted
    else cons first (insert x (tail sorted))
-- stdlib/string.rail — String utilities for Rail native compiler
-- Import: import "stdlib/string.rail"

-- true when s begins with prefix. A prefix longer than s is rejected up
-- front; otherwise both strings are compared character by character.
starts_with prefix s =
  if length prefix <= length s then starts_with_acc (chars prefix) (chars s)
  else false

-- Character-list worker for starts_with.
--   pcs  remaining prefix characters
--   scs  remaining subject characters
-- Returns true once the prefix is exhausted, false on the first mismatch or
-- when the subject runs out first.
-- The guards compare against [] rather than calling `length`, which walks
-- the whole cons-list on every step and made the scan quadratic.
starts_with_acc pcs scs =
  if pcs == [] then true
  else if scs == [] then false
  else if head pcs == head scs then starts_with_acc (tail pcs) (tail scs)
  else false

-- true when the single character c occurs anywhere in s.
-- (Single characters only; there is no substring search here.)
contains_char c s =
  let cs = chars s
  any (\x -> x == c) cs

-- s concatenated with itself n times; "" when n <= 0.
repeat_str s n =
  if n > 0 then
    let rest = repeat_str s (n - 1)
    append s rest
  else ""

-- Append spaces to s until it is `width` long; s is returned unchanged
-- when it is already at least that long.
pad_right s width =
  let missing = width - length s
  if missing > 0 then append s (repeat_str " " missing)
  else s

-- Prepend spaces to s until it is `width` long; s is returned unchanged
-- when it is already at least that long.
pad_left s width =
  let missing = width - length s
  if missing > 0 then append (repeat_str " " missing) s
  else s

-- Concatenate a list of strings with no separator. Lives here so every
-- program that imports the string module gets it.
cat pieces = join "" pieces

-- Join the strings in items with sep between them (alias for builtin join).
intercalate sep items = join sep items

-- Convert a list of chars back to a string; [] gives "".
-- The end-of-list test compares against [] instead of calling `length`,
-- which walks the whole cons-list on every step.
from_chars cs =
  if cs == [] then ""
  else append (head cs) (from_chars (tail cs))
-- stdlib/map.rail — Pure Rail tree map (BST)
-- O(log n) average insert/lookup, O(n) worst case (unbalanced)
-- Supports string and integer keys via Rail's polymorphic ==, <, >

-- Binary search tree: MapEmpty is the empty tree; MapNode carries one
-- key/value pair plus its left and right subtrees.
type Map = | MapEmpty | MapNode key value left right

-- Create an empty map. The argument is ignored.
map_new _ = MapEmpty

-- Return a map like m but with key k bound to v. An existing binding for k
-- is replaced; otherwise a new leaf is added on the side chosen by `<`.
map_put m k v = match m
  | MapEmpty -> MapNode k v MapEmpty MapEmpty
  | MapNode nk nv lt rt ->
    if k == nk then MapNode nk v lt rt
    else if k < nk then
      let lt2 = map_put lt k v
      MapNode nk nv lt2 rt
    else
      let rt2 = map_put rt k v
      MapNode nk nv lt rt2

-- Look up k in m. Returns the stored value, or 0 when k is absent (use
-- map_has to tell a stored 0 from a miss).
map_get m k = match m
  | MapEmpty -> 0
  | MapNode nk nv lt rt ->
    if k == nk then nv
    else
      let side = if k < nk then lt else rt
      map_get side k

-- true when m contains a binding for k.
map_has m k = match m
  | MapEmpty -> false
  | MapNode nk _ lt rt ->
    if k == nk then true
    else
      let side = if k < nk then lt else rt
      map_has side k

-- Return m without any binding for k. Removing a node splices its two
-- subtrees back together with map_merge; a missing key leaves the tree
-- structurally unchanged.
map_del m k = match m
  | MapEmpty -> MapEmpty
  | MapNode nk nv lt rt ->
    if k == nk then map_merge lt rt
    else if k < nk then
      let lt2 = map_del lt k
      MapNode nk nv lt2 rt
    else
      let rt2 = map_del rt k
      MapNode nk nv lt rt2

-- Join two subtrees where every key in l sorts before every key in r.
-- If either side is empty the other is returned as is; otherwise the
-- smallest entry of r becomes the new root.
map_merge l r = match r
  | MapEmpty -> l
  | _ -> match l
    | MapEmpty -> r
    | _ ->
      let (root_k, root_v, r_rest) = map_pop_min r
      MapNode root_k root_v l r_rest

-- Remove the leftmost (smallest-key) entry of m.
-- Returns (key, value, remaining_tree); an empty tree gives (0, 0, MapEmpty).
map_pop_min m = match m
  | MapEmpty -> (0, 0, MapEmpty)
  | MapNode nk nv lt rt -> match lt
    | MapEmpty -> (nk, nv, rt)
    | _ ->
      let (min_k, min_v, lt_rest) = map_pop_min lt
      (min_k, min_v, MapNode nk nv lt_rest rt)

-- All keys of m via in-order traversal (left subtree, node, right subtree).
map_keys m = match m
  | MapEmpty -> []
  | MapNode nk _ lt rt ->
    let smaller = map_keys lt
    let larger = map_keys rt
    append smaller (cons nk larger)

-- Number of bindings in m (visits every node).
map_size m = match m
  | MapEmpty -> 0
  | MapNode _ _ lt rt ->
    let below = map_size lt + map_size rt
    1 + below
-- stdlib/bytes.rail — Byte-level helpers for binary protocols.
--
-- Designed to support pure-Rail crypto (SHA-256, ChaCha20-Poly1305),
-- TLS record framing, and X.509 certificate parsing. Byte buffers are
-- represented as mutable int arrays (arr_new / arr_get / arr_set), each
-- slot holding a byte value in 0..255.
--
-- API:
--   hex_to_bytes "deadbeef"         -> int array (4 bytes)
--   bytes_to_hex arr len            -> hex string (lowercase)
--   be32_read arr offset            -> 32-bit unsigned int
--   be32_write arr offset value     -> unit (arr mutated in place)
--   add32 a b                       -> (a + b) mod 2^32
--   and32 a b, or32 a b, xor32 a b  -> masked to 32 bits
--   rotr32 x n                      -> 32-bit right rotate
--   shr32 x n                       -> 32-bit logical shift right

-- ── 32-bit arithmetic (Rail ints are 63-bit, so we mask to 32 ─────────────
-- bits after each op that could overflow or sign-extend unexpectedly).

-- 2^32 - 1: all ones in the low 32 bits. Written in decimal because hex
-- literals are not parsed.
mask32 = 4294967295
-- 2^32.
two32 = 4294967296

-- Bitwise AND, result truncated to the low 32 bits.
and32 a b =
  let raw = bit_and a b
  bit_and raw mask32
-- Bitwise OR, result truncated to the low 32 bits.
or32 a b =
  let raw = bit_or a b
  bit_and raw mask32
-- Bitwise XOR, result truncated to the low 32 bits.
xor32 a b =
  let raw = bit_xor a b
  bit_and raw mask32
-- Addition modulo 2^32.
add32 a b =
  let raw = a + b
  bit_and raw mask32

-- 32-bit right rotation: ((x >> n) | (x << (32 - n))) & 0xFFFFFFFF.
-- Used by the SHA-256 sigma functions. x is expected to already be a
-- 32-bit value; only the result is masked.
rotr32 x n = bit_and (bit_or (shr x n) (shl x (32 - n))) mask32

-- 32-bit logical shift right: shift, then mask to 32 bits for safety.
shr32 x n =
  let shifted = shr x n
  bit_and shifted mask32

-- ── Hex helpers ───────────────────────────────────────────────────────────
-- Rail has no 0xFF literal parsing (parses silently as 0), so every crypto
-- constant comes through hex_to_bytes or is written in decimal. This helper
-- converts a lowercase-or-mixed-case hex string to an int array.

hex_digit c =
  -- c is an ASCII code. Result is the digit value 0..15, or -1 when c is
  -- not a hex digit. The three ranges are disjoint, so test order is free.
  if c >= 97 && c <= 102 then c - 87          -- 'a'..'f'
  else if c >= 65 && c <= 70 then c - 55      -- 'A'..'F'
  else if c >= 48 && c <= 57 then c - 48      -- '0'..'9'
  else 0 - 1

-- Decode hex character pairs from cs into arr, starting at slot idx.
--   cs   remaining hex characters
--   arr  destination int array (mutated in place)
--   idx  next slot to write
-- Returns the index one past the last byte written. A trailing odd
-- character is ignored.
-- The "fewer than two chars left" test inspects the list shape directly;
-- calling `length` here walked the whole list on every pair and made
-- decoding quadratic.
-- NOTE(review): hex_digit returns -1 for a non-hex character and that value
-- is combined into the byte unchecked — callers must pass valid hex.
hex_to_bytes_at cs arr idx =
  if cs == [] then idx
  else if tail cs == [] then idx
  else
    let hi = hex_digit (char_to_int (head cs))
    let lo = hex_digit (char_to_int (head (tail cs)))
    let b = bit_or (shl hi 4) lo
    let _ = arr_set arr idx b
    hex_to_bytes_at (tail (tail cs)) arr (idx + 1)

-- Decode a hex string into a fresh int array holding one byte per pair of
-- characters (length / 2 slots).
hex_to_bytes s =
  let cs = chars s
  let arr = arr_new (length cs / 2) 0
  let _ = hex_to_bytes_at cs arr 0
  arr

-- ── bytes_to_hex: int array → lowercase hex string ────────────────────────

-- Lowercase hex character for a nibble: 0..9 map to '0'..'9' (ASCII 48+),
-- 10..15 map to 'a'..'f' (ASCII 87+n).
hex_char_of n =
  let base = if n < 10 then 48 else 87
  char_from_int (base + n)

-- Worker for bytes_to_hex: appends two hex characters for arr[i] to acc and
-- moves on until i reaches n, then returns the accumulated string.
bytes_to_hex_at arr n i acc =
  if i < n then
    let byte = arr_get arr i
    let hi_ch = hex_char_of (shr byte 4)
    let lo_ch = hex_char_of (bit_and byte 15)
    bytes_to_hex_at arr n (i + 1) (join "" [acc, hi_ch, lo_ch])
  else acc

-- Lowercase hex encoding of the first `count` slots of `bytes`.
bytes_to_hex bytes count = bytes_to_hex_at bytes count 0 ""

-- ── Big-endian 32-bit word I/O ────────────────────────────────────────────
-- SHA-256 processes 16 big-endian 32-bit words per 64-byte block.

-- Read four bytes at arr[offset..offset+3] as one big-endian 32-bit word
-- (the byte at `offset` is the most significant).
be32_read arr offset =
  let hi = or32 (shl (arr_get arr offset) 24) (shl (arr_get arr (offset + 1)) 16)
  let lo = or32 (shl (arr_get arr (offset + 2)) 8) (arr_get arr (offset + 3))
  or32 hi lo

-- Store the low 32 bits of value into arr[offset..offset+3], most
-- significant byte first. Mutates arr; returns 0.
be32_write arr offset value =
  let b0 = bit_and (shr value 24) 255
  let b1 = bit_and (shr value 16) 255
  let b2 = bit_and (shr value 8) 255
  let b3 = bit_and value 255
  let _ = arr_set arr offset b0
  let _ = arr_set arr (offset + 1) b1
  let _ = arr_set arr (offset + 2) b2
  let _ = arr_set arr (offset + 3) b3
  0

-- ── Little-endian 32-bit word I/O ─────────────────────────────────────────
-- ChaCha20 uses little-endian 32-bit words (RFC 8439 §2.3).

-- Read four bytes at arr[offset..offset+3] as one little-endian 32-bit word
-- (the byte at `offset` is the least significant).
le32_read arr offset =
  let lo = or32 (arr_get arr offset) (shl (arr_get arr (offset + 1)) 8)
  let hi = or32 (shl (arr_get arr (offset + 2)) 16) (shl (arr_get arr (offset + 3)) 24)
  or32 lo hi

-- Store the low 32 bits of value into arr[offset..offset+3], least
-- significant byte first. Mutates arr; returns 0.
le32_write arr offset value =
  let b0 = bit_and value 255
  let b1 = bit_and (shr value 8) 255
  let b2 = bit_and (shr value 16) 255
  let b3 = bit_and (shr value 24) 255
  let _ = arr_set arr offset b0
  let _ = arr_set arr (offset + 1) b1
  let _ = arr_set arr (offset + 2) b2
  let _ = arr_set arr (offset + 3) b3
  0

-- 32-bit left rotation. The naive `bit_and (rotl x n) mask32` fails because
-- Layer 0's rotl is a 64-bit rotate — bits rotated from position 31 and
-- below land above bit 32, where the mask clears them. Use the same
-- shift-OR pattern as rotr32.
-- ((x << n) | (x >> (32 - n))) masked to 32 bits.
rotl32 x n = bit_and (bit_or (shl x n) (shr x (32 - n))) mask32

-- ── String-to-bytes adapter ───────────────────────────────────────────────
-- For hashing arbitrary Rail strings (prompts, keys, etc.) without going
-- through the hex encoder.

-- WARNING: do not check `length cs == 0` as the loop guard here.
-- `length` on a Rail cons-list is O(N), so the recursive walk would
-- be O(N^2) and a 360 KB string took *hours* before this was fixed.
-- Pattern-match on the empty list shape instead, which is O(1).
-- Write the code of each character in cs into arr, starting at slot i.
-- Returns the index one past the last slot written.
string_to_bytes_at cs arr i =
  if cs == [] then i
  else
    let _ = arr_set arr i (char_to_int (head cs))
    string_to_bytes_at (tail cs) arr (i + 1)

-- Convert a string into a fresh int array with one slot per character of
-- `chars s`, each holding that character's code.
string_to_bytes s =
  let cs = chars s
  let arr = arr_new (length cs) 0
  let _ = string_to_bytes_at cs arr 0
  arr

-- Number of elements in `chars s`; matches the array size produced by
-- string_to_bytes.
string_length_bytes s =
  let cs = chars s
  length cs
-- stdlib/json.rail — Pure Rail JSON parser
-- Import: import "stdlib/json.rail"
-- Parses JSON strings into Rail values
-- Representation: ["obj", pairs] | ["arr", items] | ["str", s] | ["num", n] | ["bool", b] | ["null"]

-- Drop leading JSON whitespace from a character list and return the rest.
-- JSON whitespace is space, tab, line feed and carriage return; CR (code 13)
-- is matched by code so CRLF-terminated documents parse.
json_skip_ws cs =
  if cs == [] then []
  else
    let c = head cs
    if c == " " || c == "\t" || c == "\n" || char_to_int c == 13 then json_skip_ws (tail cs)
    else cs

-- Parse the body of a JSON string; cs starts just after the opening quote.
--   cs   remaining characters
--   acc  string decoded so far
-- Returns (decoded_string, chars_after_closing_quote). Input that ends
-- before the closing quote returns what was decoded with [] as the rest.
-- Escapes \n, \t, \\ and \" are decoded; any other escaped character is
-- kept as that character itself.
-- NOTE(review): \r, \b, \f and \uXXXX are therefore not decoded.
-- The "backslash is the last character" test checks the list shape rather
-- than calling `length`, which walked the rest of the input on every escape.
json_parse_str cs acc =
  if cs == [] then (acc, [])
  else if head cs == "\"" then (acc, tail cs)
  else if head cs == "\\" then
    if tail cs == [] then (acc, [])
    else
      let nx = head (tail cs)
      let esc = if nx == "n" then "\n"
                else if nx == "t" then "\t"
                else if nx == "\\" then "\\"
                else if nx == "\"" then "\""
                else nx
      json_parse_str (tail (tail cs)) (append acc esc)
  else json_parse_str (tail cs) (append acc (head cs))

-- Collect a run of digit / '-' characters (integers only for now).
-- Returns (collected_text, remaining_chars); nothing is consumed when the
-- first character is neither a digit nor '-'.
json_parse_num cs acc =
  if cs == [] then (acc, cs)
  else
    let c = head cs
    let code = char_to_int c
    let is_digit = code >= 48 && code <= 57
    if is_digit || c == "-" then json_parse_num (tail cs) (append acc c)
    else (acc, cs)

-- Parse one JSON value from a character list.
-- Returns (value, remaining_chars). The value is chosen from the first
-- non-whitespace character: '"' string, '[' array, '{' object, 't'/'f'
-- bool (4 / 5 characters are skipped without checking the spelling),
-- 'n' null, anything else is read as an integer. Empty input gives null.
json_parse cs =
  let ws = json_skip_ws cs
  if ws == [] then (["null"], [])
  else
    let c = head ws
    if c == "\"" then
      let (s, rest) = json_parse_str (tail ws) ""
      (["str", s], rest)
    else if c == "[" then json_parse_array (tail ws) []
    else if c == "{" then json_parse_object (tail ws) []
    else if c == "t" then (["bool", true], drop 4 ws)
    else if c == "f" then (["bool", false], drop 5 ws)
    else if c == "n" then (["null"], drop 4 ws)
    else
      let (num_s, rest) = json_parse_num ws ""
      -- An unexpected character yields an empty number and consumes nothing.
      -- Step over it so the array/object loops always make progress instead
      -- of recursing forever on malformed input such as "[x]".
      let rest2 = if num_s == "" then tail ws else rest
      (["num", json_str_to_int num_s], rest2)

-- String to integer in pure Rail. The builtin `to_int` is float-only and
-- silently returns 0 for plain integer strings, so digits are folded by hand.
json_str_to_int s = json_str_to_int_acc (chars s) 0 1

-- Worker for json_str_to_int.
--   acc   magnitude accumulated so far
--   sign  1 or -1; a '-' anywhere in the input switches it to -1
-- Stops at the first non-digit character and returns acc * sign.
json_str_to_int_acc cs acc sign =
  if cs == [] then acc * sign
  else
    let c = head cs
    if c == "-" then json_str_to_int_acc (tail cs) acc (0 - 1)
    else
      let d = char_to_int c - 48
      if d < 0 || d > 9 then acc * sign
      else json_str_to_int_acc (tail cs) (acc * 10 + d) sign

-- Parse array elements; cs starts just after '['.
--   acc  elements parsed so far, newest first
-- Returns (["arr", items], remaining_chars) with items in source order.
-- Commas between elements are optional. Input that ends before ']' returns
-- the elements collected so far.
-- Every exit reverses acc; the end-of-input arm used to return it newest
-- first, so a truncated "[1,2," came back as [2, 1].
-- Emptiness is tested against [] rather than with `length`, which re-walked
-- the rest of the document after every element.
json_parse_array cs acc =
  let ws = json_skip_ws cs
  if ws == [] then (["arr", reverse acc], [])
  else if head ws == "]" then (["arr", reverse acc], tail ws)
  else
    let (val, rest) = json_parse ws
    let ws2 = json_skip_ws rest
    if ws2 == [] then (["arr", reverse (cons val acc)], [])
    else if head ws2 == "," then json_parse_array (tail ws2) (cons val acc)
    else json_parse_array ws2 (cons val acc)

-- Parse object members; cs starts just after '{'.
--   acc  (key, value) pairs parsed so far, newest first
-- Returns (["obj", pairs], remaining_chars) with pairs in source order.
-- The colon and the commas between members are optional. Input that ends
-- before '}' returns the pairs collected so far.
-- Every exit reverses acc; the end-of-input arm used to return it newest
-- first. Emptiness is tested against [] rather than with `length`, which
-- re-walked the rest of the document after every member.
json_parse_object cs acc =
  let ws = json_skip_ws cs
  if ws == [] then (["obj", reverse acc], [])
  else if head ws == "}" then (["obj", reverse acc], tail ws)
  else
    -- Parse key (must be string); step over its opening quote if present.
    let ws2 = if head ws == "\"" then tail ws else ws
    let (key, rest) = json_parse_str ws2 ""
    let ws3 = json_skip_ws rest
    -- Skip colon
    let ws4 = if ws3 == [] then ws3
              else if head ws3 == ":" then json_skip_ws (tail ws3)
              else ws3
    let (val, rest2) = json_parse ws4
    let pair = (key, val)
    let ws5 = json_skip_ws rest2
    if ws5 == [] then (["obj", reverse (cons pair acc)], [])
    else if head ws5 == "," then json_parse_object (tail ws5) (cons pair acc)
    else json_parse_object ws5 (cons pair acc)

-- Parse a JSON document held in a string and return its tagged value.
-- Anything left over after the first value is discarded.
parse_json str =
  let (parsed, _) = json_parse (chars str)
  parsed

-- Look up key in a JSON object value (["obj", pairs]); ["null"] when the
-- key is absent.
json_get obj key = json_lookup (head (tail obj)) key

-- Linear search of a (key, value) pair list. Returns the value of the first
-- pair whose key equals `key`, or ["null"] when there is none.
json_lookup pairs key =
  if pairs == [] then ["null"]
  else
    let (name, found) = head pairs
    let others = tail pairs
    if name == key then found
    else json_lookup others key

-- Payload of a ["str", s] value; "" for any other tag.
json_str val =
  let tag = head val
  if tag == "str" then head (tail val)
  else ""

-- Payload of a ["num", n] value; 0 for any other tag.
json_int val =
  let tag = head val
  if tag == "num" then head (tail val)
  else 0

-- Item list of an ["arr", items] value; [] for any other tag.
json_items val =
  let tag = head val
  if tag == "arr" then head (tail val)
  else []

-- Serialize JSON value to string.  Split into per-tag helpers because
-- a single `if/then/else` chain with let-bindings under several arms,
-- plus a multi-line lambda inside `map`, defeated the parser; flat
-- helpers parse cleanly and read better anyway.
-- Dispatch on the value's tag (first list element) to the matching
-- per-tag encoder; an unrecognised tag is encoded as "null".
json_encode val =
  let tag = head val
  if tag == "str" then json_encode_str val
  else if tag == "num" then json_encode_num val
  else if tag == "bool" then json_encode_bool val
  else if tag == "null" then "null"
  else if tag == "arr" then json_encode_arr val
  else if tag == "obj" then json_encode_obj val
  else "null"

-- Escape one character for use inside a JSON string literal. Covers the
-- same four escapes json_parse_str decodes (\" \\ \n \t); every other
-- character passes through unchanged.
json_escape_char c =
  if c == "\"" then "\\\""
  else if c == "\\" then "\\\\"
  else if c == "\n" then "\\n"
  else if c == "\t" then "\\t"
  else c

-- Encode a ["str", s] value as a quoted JSON string. The payload is escaped
-- character by character; emitting it raw produced invalid JSON whenever it
-- contained a quote, backslash, newline or tab.
json_encode_str val =
  let escaped = map json_escape_char (chars (head (tail val)))
  cat ["\"", cat escaped, "\""]
-- Encode a ["num", n] value using `show` on its payload.
json_encode_num val =
  let n = head (tail val)
  show n

-- Encode a ["bool", b] value as the literal true or false.
json_encode_bool val =
  let flag = head (tail val)
  if flag then "true" else "false"

-- Encode an ["arr", items] value: each item is encoded recursively and the
-- results are joined with ", " inside square brackets.
json_encode_arr val =
  let encoded = map json_encode (head (tail val))
  let body = join ", " encoded
  cat ["[", body, "]"]

-- Encode one (key, value) object member as `"key": value`.
-- The key is emitted through json_encode_str so it is quoted by the same
-- routine as string values instead of being wrapped in quotes by hand.
json_encode_pair p =
  let (k, v) = p
  cat [json_encode_str ["str", k], ": ", json_encode v]

-- Encode an ["obj", pairs] value: each member is encoded with
-- json_encode_pair and the results are joined with ", " inside braces.
json_encode_obj val =
  let members = map json_encode_pair (head (tail val))
  let body = join ", " members
  cat ["{", body, "}"]

-- Helper: xs without its first n elements. An empty list stays empty;
-- n <= 0 returns xs untouched.
drop n xs =
  if xs == [] then []
  else if n <= 0 then xs
  else
    let rest = tail xs
    drop (n - 1) rest
-- stdlib/optim.rail — Optimizers for Rail tensor training
-- Import: import "stdlib/optim.rail"
--
-- Currently exports: Adam and AdamW (fused GPU kernel, one dispatch per tensor
-- per step, with a pure-Rail CPU fallback), the cosine_decay LR schedule, and
-- clip_grad_norm. Queued for v2.3: mini-batch iterator.
--
-- All scalars travel as raw doubles through Rail's foreign ABI; Adam's six
-- hyperparameters are packed into a float_arr of length 6 to keep the FFI
-- signature inside the 8-register window.
--
--   hyp[0] = lr      learning rate
--   hyp[1] = β1      first-moment decay
--   hyp[2] = β2      second-moment decay
--   hyp[3] = ε       numerical stability floor
--   hyp[4] = bc1     1 - β1^t  (bias correction for m, computed host-side)
--   hyp[5] = bc2     1 - β2^t  (bias correction for v)
--
-- The kernel mutates w, m, v in place; g is read-only.
--
-- Prereq: import "stdlib/tensor.rail" (or autograd.rail, which pulls tensor)
-- BEFORE importing this module. Rail imports do not dedupe, so optim does
-- not re-import tensor. The tgl_adam_update_f64 foreign decl lives in
-- tensor.rail alongside the other GPU op foreigns.

-- Foreign math functions, each declared as returning a float.
-- pow is used for Adam's bias correction, cos for the cosine LR schedule,
-- sqrt for the Adam denominator and the gradient norm.
foreign pow x y -> float
foreign cos x -> float
foreign sqrt x -> float

-- ─────────────────────────────────────────────────────────────
-- AdamState: persistent first/second moment + step counter.
-- m and v are float_arrs, same size as the parameter tensor data.
-- step is a 1-element int array so we can mutate it across calls
-- without allocating a new state record each update.
-- ─────────────────────────────────────────────────────────────

-- Fields: first-moment array m, second-moment array v, and a 1-element int
-- array holding the step counter.
type AdamState = | AdamState m v step_arr

-- Fresh optimizer state for a parameter with n elements: two zero-filled
-- float_arrs of length n (m and v) and a 1-slot int array holding step 0.
adam_state n =
  AdamState (float_arr_new n 0.0) (float_arr_new n 0.0) (arr_new 1 0)

-- Pack the six hyperparameters for this step into a float_arr the dylib
-- can read. Bias correction is recomputed every step because pow is cheap.
adam_hyp lr b1 b2 eps t =
  -- t is the 1-based integer step; converted to float for pow.
  let t_f = int_to_float t
  -- Bias-correction denominators: 1 - b1^t and 1 - b2^t.
  let bc1 = 1.0 - (pow b1 t_f)
  let bc2 = 1.0 - (pow b2 t_f)
  -- A new 6-slot float_arr is allocated on every call; slot order is
  -- lr, b1, b2, eps, bc1, bc2.
  let h = float_arr_new 6 0.0
  let _ = float_arr_set h 0 lr
  let _ = float_arr_set h 1 b1
  let _ = float_arr_set h 2 b2
  let _ = float_arr_set h 3 eps
  let _ = float_arr_set h 4 bc1
  let _ = float_arr_set h 5 bc2
  h

-- CPU Adam inner loop. Hyperparameter layout matches adam_hyp:
-- h[0]=lr, h[1]=b1, h[2]=b2, h[3]=eps, h[4]=bc1=1-b1^t, h[5]=bc2=1-b2^t.
-- Updates elements i..n-1 of w, m and v in place (g is only read) and
-- returns 0 once i reaches n.
cpu_adam_loop w g m v h n i =
  if i >= n then 0
  else
    -- Hyperparameters are re-read from h on every element.
    let lr = float_arr_get h 0
    let b1 = float_arr_get h 1
    let b2 = float_arr_get h 2
    let eps = float_arr_get h 3
    let bc1 = float_arr_get h 4
    let bc2 = float_arr_get h 5
    let gi = float_arr_get g i
    -- Exponential moving averages of the gradient and its square.
    let mi = b1 * (float_arr_get m i) + (1.0 - b1) * gi
    let vi = b2 * (float_arr_get v i) + (1.0 - b2) * gi * gi
    let _ = float_arr_set m i mi
    let _ = float_arr_set v i vi
    -- Bias-corrected moments.
    let m_hat = mi / bc1
    let v_hat = vi / bc2
    let wi = float_arr_get w i
    -- w[i] -= lr * m_hat / (sqrt(v_hat) + eps)
    let _ = float_arr_set w i (wi - lr * m_hat / (sqrt v_hat + eps))
    cpu_adam_loop w g m v h n (i + 1)

-- One fused Adam update on a raw float_arr (weights data). Mutates w, m, v.
-- Returns the new step counter. Dylib path when gpu_available; pure-Rail
-- CPU fallback otherwise (machines where the Metal dylib can't be loaded).
adam_update_raw w g state lr b1 b2 eps = match state
  | AdamState m v step_arr ->
    -- Advance the persistent step counter first, so the first call uses t = 1.
    let t = arr_get step_arr 0 + 1
    let _ = arr_set step_arr 0 t
    let n = float_arr_len w
    let h = adam_hyp lr b1 b2 eps t
    -- Same arguments go to either backend; the CPU loop additionally takes
    -- its start index 0.
    let _ = if gpu_available 0 then tgl_adam_update_f64 w g m v h n
            else cpu_adam_loop w g m v h n 0
    t

-- Tensor-level wrapper: unwraps the Tensor ADT and calls the raw path.
-- param and grad must share shape; state must have been built with
-- adam_state matching that shape's element count.
-- Unwraps the data arrays of both tensors, runs adam_update_raw on them, and
-- returns `param` itself (its data has been updated in place).
adam_update param grad state lr b1 b2 eps = match grad
  | Tensor grad_data _ _ -> match param
    | Tensor weight_data _ _ ->
      let _ = adam_update_raw weight_data grad_data state lr b1 b2 eps
      param

-- ─────────────────────────────────────────────────────────────
-- AdamW: Adam + decoupled weight decay (Loshchilov & Hutter 2019).
-- Instead of L2 regularization embedded in the gradient, we shrink
-- weights toward zero by a small factor (1 - lr*wd) each step, then
-- run the normal Adam step on the unregularized gradient.
--
-- Used by every modern transformer (GPT / Llama / Qwen) with wd
-- typically 0.01-0.1.
-- ─────────────────────────────────────────────────────────────
-- Multiply elements i..n-1 of w in place by (1 - lr_wd), where lr_wd is
-- the caller's lr * wd product. Returns 0 once i reaches n.
apply_decay_loop w lr_wd n i =
  if i >= n then 0
  else
    let wi = float_arr_get w i
    let _ = float_arr_set w i (wi * (1.0 - lr_wd))
    apply_decay_loop w lr_wd n (i + 1)

-- AdamW step on raw float_arrs: shrink w by (1 - lr*wd) in place, then run
-- the plain Adam step on the unmodified gradient. Returns the new step
-- counter from adam_update_raw.
adamw_update_raw w g state lr b1 b2 eps wd =
  let n = float_arr_len w
  let lr_wd = lr * wd
  -- Decay is skipped entirely when wd is not positive.
  let _ = if wd > 0.0 then apply_decay_loop w lr_wd n 0 else 0
  adam_update_raw w g state lr b1 b2 eps

-- Tensor-level AdamW wrapper: unwraps the data arrays of both tensors, runs
-- adamw_update_raw on them, and returns `param` itself.
adamw_update param grad state lr b1 b2 eps wd = match grad
  | Tensor grad_data _ _ -> match param
    | Tensor weight_data _ _ ->
      let _ = adamw_update_raw weight_data grad_data state lr b1 b2 eps wd
      param

-- ─────────────────────────────────────────────────────────────
-- Per-parameter LR multipliers. Motivated by the γ bounce at
-- step 349→399 (1.59 → 1.78 loss) in run 924e0c5 with a single
-- learnable RMSNorm γ and uniform lr=0.02. Damping γ's effective
-- LR by 0.3× keeps the norm scale from oscillating; tied embedding
-- (W_E reused as output projection) benefits from 0.5× because it
-- receives gradients through two paths.
--
-- Rail's cross-function float-param inference doesn't propagate a
-- float-typed result through a thin wrapper fn: `base_lr * lr_mult`
-- inside a helper receives untyped params and the multiply misses
-- the __float_ marker path. So no wrapper here — pre-multiply at
-- the call site, which DOES get the marker via `let scaled_lr = ...`:
--
--   let scaled_lr_gamma = base_lr * adam_lr_mult_gamma
--   let _ = adam_update_raw w_gamma g_gamma st_gamma scaled_lr_gamma b1 b2 eps
--   let _ = adam_update_raw w_embed g_embed st_embed
--             (base_lr * adam_lr_mult_tied_embed) b1 b2 eps
--   let _ = adam_update_raw w_other g_other st_other base_lr b1 b2 eps
--
-- Verified working in /tmp/optim_smoke2.rail; wrapper variant returns
-- 0.0 delta because the inline multiply goes through integer codegen.
-- ─────────────────────────────────────────────────────────────

-- LR multiplier for the learnable RMSNorm gamma.
adam_lr_mult_gamma = 0.3
-- LR multiplier for a tied embedding / output projection.
adam_lr_mult_tied_embed = 0.5
-- LR multiplier for every other parameter.
adam_lr_mult_default = 1.0

-- Default AdamW weight-decay coefficient.
adam_default_wd = 0.01

-- Default hyperparameters (Kingma & Ba 2015). Use as:
--   let lr = adam_default_lr
--   let b1 = adam_default_b1
-- etc. Not a record because Rail lacks record-of-floats destructure.
-- Learning rate.
adam_default_lr = 0.001
-- First-moment decay.
adam_default_b1 = 0.9
-- Second-moment decay.
adam_default_b2 = 0.999
-- Numerical stability floor added to the denominator.
adam_default_eps = 1e-8

-- ─────────────────────────────────────────────────────────────
-- Learning rate schedule — linear warmup → half-cosine decay to 0.
--   step 0..warmup-1       : lr = base_lr * (step/warmup)
--   step warmup..max_steps : lr = base_lr * 0.5 * (1 + cos(π * p))
--     where p = (step - warmup) / (max_steps - warmup)
-- Clamps at 0.0 after max_steps to avoid negative lrs from overshoots.
-- ─────────────────────────────────────────────────────────────

-- step, warmup and max_steps are integers; base_lr is a float.
-- Returns the learning rate for `step`.
cosine_decay step warmup max_steps base_lr =
  -- Checked first, so the decay branch below always has max_steps > warmup
  -- and never divides by zero.
  if step >= max_steps then 0.0
  else if step < warmup then
    -- Linear ramp; step 0 gives a learning rate of 0.
    let p = int_to_float step / int_to_float warmup
    base_lr * p
  else
    let decay_steps = max_steps - warmup
    -- p runs from 0 at step == warmup towards 1 at max_steps.
    let p = int_to_float (step - warmup) / int_to_float decay_steps
    let pi_val = 3.14159265358979323846
    let c = cos (pi_val * p)
    base_lr * 0.5 * (1.0 + c)

-- ─────────────────────────────────────────────────────────────
-- Gradient global-norm clipping.
--   Given a list of gradient Tensors and a max_norm threshold,
--   compute the global L2 norm (sqrt of sum of all squared entries)
--   and if it exceeds max_norm, return a new list of tensors scaled
--   by max_norm/norm. Returns the PRE-clip norm so callers can log
--   gradient magnitudes regardless of whether clipping fired.
--
-- Per-tensor sum of squares is implemented as tensor_mul(g,g) then
-- tensor_sum; the mul dispatches to GPU, sum runs on CPU (the reduce
-- is cheap compared to the element-wise square).
-- ─────────────────────────────────────────────────────────────

-- Sum of squared entries of one tensor: element-wise g*g, then tensor_sum.
sum_sq_tensor g = tensor_sum (tensor_mul g g)

-- Accumulate sum-of-squares across a list of tensors into a mutable
-- float_arr slot. Rail's cross-function float-param inference can
-- misclassify a float accumulator passed recursively, so we route
-- the running total through a float_arr instead of an `acc` arg.
-- Adds sum_sq_tensor of every tensor in grads to slot 0 of acc_arr.
-- Returns 0; the running total lives in acc_arr.
-- The end-of-list test compares against [] instead of calling `length`,
-- which walks the remaining list on every step.
total_sq_into grads acc_arr =
  if grads == [] then 0
  else
    let cur = float_arr_get acc_arr 0
    let contrib = sum_sq_tensor (head grads)
    let _ = float_arr_set acc_arr 0 (cur + contrib)
    total_sq_into (tail grads) acc_arr

-- Fold helper: scale every tensor in a list by a constant factor.
-- `factor` is a float but only participates in arithmetic inside
-- tensor_scale, which already handles float scalars correctly via
-- its own float_arr dispatch.
-- Scales every tensor in grads by factor via tensor_scale.
--   acc  scaled tensors so far, newest first (callers pass [])
-- Returns the scaled tensors in the same order as grads.
-- The end-of-list test compares against [] instead of calling `length`,
-- which walks the remaining list on every step.
scale_each grads factor acc =
  if grads == [] then reverse acc
  else scale_each (tail grads) factor (cons (tensor_scale (head grads) factor) acc)

-- Returns (clipped_grads, total_norm).
clip_grad_norm grads max_norm =
  -- One-slot float_arr used as the sum-of-squares accumulator.
  let acc_arr = float_arr_new 1 0.0
  let _ = total_sq_into grads acc_arr
  let total_sq = float_arr_get acc_arr 0
  let norm = sqrt total_sq
  -- Within the limit: hand back the original list untouched.
  if norm <= max_norm then (grads, norm)
  else
    -- norm > max_norm here, so factor is below 1.
    let factor = max_norm / norm
    let clipped = scale_each grads factor []
    -- The second component is the norm before clipping.
    (clipped, norm)
-- stdlib/sha256.rail — Pure-Rail SHA-256 per RFC 6234 (FIPS 180-4).
--
-- Depends on stdlib/bytes.rail for bit-op helpers + BE word I/O.
-- Input: Rail string. Output: 32-byte int-array (hash digest).
-- Also provides sha256_hex for convenience testing.
--
-- Usage:
--   let digest = sha256 "hello"
--   let hex = bytes_to_hex digest 32
--
-- Not constant-time; not side-channel-hardened. Adequate for TLS
-- handshake hashing (where timing reveals little) and message
-- authentication (HMAC-SHA256). Do NOT use for password hashing
-- unless wrapped in PBKDF2/Argon2.

import "stdlib/bytes.rail"

-- ── Constants ─────────────────────────────────────────────────────────────
-- K[t] round constants and H[0..7] initial hash values. FIPS 180-4 §4.2.2.
-- Encoded as hex strings to avoid Rail's missing hex-literal parser;
-- decoded once at first-use into 32-bit word arrays.

-- 64 round constants K[0..63], 8 hex characters each, concatenated.
sha256_k_hex = "428a2f9871374491b5c0fbcfe9b5dba53956c25b59f111f1923f82a4ab1c5ed5d807aa9812835b01243185be550c7dc372be5d7480deb1fe9bdc06a7c19bf174e49b69c1efbe47860fc19dc6240ca1cc2de92c6f4a7484aa5cb0a9dc76f988da983e5152a831c66db00327c8bf597fc7c6e00bf3d5a7914706ca63511429296727b70a852e1b21384d2c6dfc53380d13650a7354766a0abb81c2c92e92722c85a2bfe8a1a81a664bc24b8b70c76c51a3d192e819d6990624f40e3585106aa07019a4c1161e376c082748774c34b0bcb5391c0cb34ed8aa4a5b9cca4f682e6ff3748f82ee78a5636f84c878148cc7020890befffaa4506cebbef9a3f7c67178f2"

-- 8 initial hash words H[0..7], 8 hex characters each, concatenated.
sha256_h_hex = "6a09e667bb67ae853c6ef372a54ff53a510e527f9b05688c1f83d9ab5be0cd19"

-- Build the 64-word round-constant table. The argument is ignored; the hex
-- string is decoded again on every call.
sha256_k_table _ =
  -- 64 × 32-bit constants = 512 hex chars = 256 bytes.
  let bytes = hex_to_bytes sha256_k_hex
  let tbl = arr_new 64 0
  load_words bytes tbl 0 64

-- Fresh 8-word array holding the initial hash values. The argument is
-- ignored; every call returns a new array.
sha256_h_init _ =
  let words = arr_new 8 0
  load_words (hex_to_bytes sha256_h_hex) words 0 8

-- Fill dst[i..n-1] with big-endian 32-bit words read from bytes, word k
-- coming from bytes[4k..4k+3]. Returns dst.
load_words bytes dst i n =
  if i < n then
    let _ = arr_set dst i (be32_read bytes (i * 4))
    load_words bytes dst (i + 1) n
  else dst

-- ── Compression function helpers ──────────────────────────────────────────
-- Ch, Maj, Σ0, Σ1, σ0, σ1 per FIPS 180-4 §4.1.2.
-- All args are 32-bit unsigned (caller's responsibility — ops mask internally).

-- Ch(x, y, z) = (x AND y) OR ((NOT x) AND z); NOT is done as XOR with mask32.
sha_ch x y z = or32 (bit_and x y) (bit_and (bit_xor x mask32) z)

-- Maj(x, y, z) = (x AND y) OR (x AND z) OR (y AND z).
sha_maj x y z = or32 (or32 (bit_and x y) (bit_and x z)) (bit_and y z)

-- Σ0(x) = ROTR2(x) XOR ROTR13(x) XOR ROTR22(x).
sha_big_sigma0 x = xor32 (xor32 (rotr32 x 2) (rotr32 x 13)) (rotr32 x 22)

-- Σ1(x) = ROTR6(x) XOR ROTR11(x) XOR ROTR25(x).
sha_big_sigma1 x = xor32 (xor32 (rotr32 x 6) (rotr32 x 11)) (rotr32 x 25)

-- σ0(x) = ROTR7(x) XOR ROTR18(x) XOR SHR3(x).
sha_small_sigma0 x = xor32 (xor32 (rotr32 x 7) (rotr32 x 18)) (shr32 x 3)

-- σ1(x) = ROTR17(x) XOR ROTR19(x) XOR SHR10(x).
sha_small_sigma1 x = xor32 (xor32 (rotr32 x 17) (rotr32 x 19)) (shr32 x 10)

-- ── Message schedule: expand W[0..15] → W[16..63] ─────────────────────────

-- Fill w[t..63] in place:
--   W[t] = σ1(W[t-2]) + W[t-7] + σ0(W[t-15]) + W[t-16]   (mod 2^32)
-- Callers start at t = 16, after W[0..15] are loaded. Returns 0.
expand_schedule w t =
  if t < 64 then
    let newer = add32 (sha_small_sigma1 (arr_get w (t - 2))) (arr_get w (t - 7))
    let older = add32 (sha_small_sigma0 (arr_get w (t - 15))) (arr_get w (t - 16))
    let _ = arr_set w t (add32 newer older)
    expand_schedule w (t + 1)
  else 0

-- ── Compression: 64 rounds updating (a..h) ────────────────────────────────
-- Store state as an 8-int array indexed 0..7 = (a,b,c,d,e,f,g,h) to avoid
-- 8-tuple returns (Rail's tuple support is limited).

-- Run rounds t..63 on `state` in place, using schedule w and constants
-- k_tbl. Returns 0.
sha_round state w k_tbl t =
  if t >= 64 then 0
  else
    -- Snapshot all eight working variables before any slot is overwritten.
    let a = arr_get state 0
    let b = arr_get state 1
    let c = arr_get state 2
    let d = arr_get state 3
    let e = arr_get state 4
    let f = arr_get state 5
    let g = arr_get state 6
    let h = arr_get state 7
    let k = arr_get k_tbl t
    let wt = arr_get w t
    -- T1 = h + Σ1(e) + Ch(e,f,g) + K[t] + W[t];  T2 = Σ0(a) + Maj(a,b,c)
    let t1 = add32 (add32 (add32 h (sha_big_sigma1 e)) (sha_ch e f g)) (add32 k wt)
    let t2 = add32 (sha_big_sigma0 a) (sha_maj a b c)
    -- Shift the variables down one slot; e becomes d + T1, a becomes T1 + T2.
    let _ = arr_set state 7 g
    let _ = arr_set state 6 f
    let _ = arr_set state 5 e
    let _ = arr_set state 4 (add32 d t1)
    let _ = arr_set state 3 c
    let _ = arr_set state 2 b
    let _ = arr_set state 1 a
    let _ = arr_set state 0 (add32 t1 t2)
    sha_round state w k_tbl (t + 1)

-- ── Block processing ──────────────────────────────────────────────────────
-- One 64-byte block updates the 8 × 32-bit hash state.

-- Compress one 64-byte block into the running hash.
--   h      8-word hash, updated in place
--   state  8-word scratch array, overwritten with the working variables
--   block  array whose first 64 bytes are the block
--   k_tbl  64-word round-constant table
-- Returns 0.
sha_process_block h state block k_tbl =
  let w = arr_new 64 0
  let _ = load_words block w 0 16
  let _ = expand_schedule w 16
  -- Working variables start as a copy of h.
  let _ = sha_copy_bytes h state 0 0 8
  let _ = sha_round state w k_tbl 0
  let _ = sha_add_state h state 0
  0

-- H[i] += state[i] (mod 2^32) for i..7. Returns 0.
sha_add_state h state i =
  if i >= 8 then 0
  else
    let _ = arr_set h i (add32 (arr_get h i) (arr_get state i))
    sha_add_state h state (i + 1)

-- ── Padding + block iteration ─────────────────────────────────────────────
-- Pad message to 64-byte-multiple with 0x80 + zeros + 8-byte big-endian
-- bit-length. Output digest: H[0..7] as 32 big-endian bytes.

-- Copy n slots from src[src_off..] to dst[dst_off..], lowest index first.
-- Does nothing when n <= 0. Returns 0.
sha_copy_bytes src dst src_off dst_off n =
  if n > 0 then
    let byte = arr_get src src_off
    let _ = arr_set dst dst_off byte
    sha_copy_bytes src dst (src_off + 1) (dst_off + 1) (n - 1)
  else 0

-- Compress every complete 64-byte block of the message starting at `off`.
-- Returns the offset of the first byte that was not consumed (fewer than
-- 64 bytes remain from there).
sha_process_from msg_bytes msg_len h state k_tbl off =
  if msg_len - off < 64 then off
  else
    -- arr_slice allocates a fresh 64-slot array and copies the block into it.
    let block = arr_slice msg_bytes off 64
    let _ = sha_process_block h state block k_tbl
    sha_process_from msg_bytes msg_len h state k_tbl (off + 64)

-- Pad and compress the tail of the message.
--   msg_bytes   whole message as a byte array
--   msg_len     total message length in bytes
--   h           8-word running hash, updated in place
--   state       8-word scratch array
--   k_tbl       64-word round-constant table
--   remain_off  offset of the first byte not yet compressed
-- Returns 0.
sha_final_block msg_bytes msg_len h state k_tbl remain_off =
  -- Build the final 1 or 2 blocks with 0x80 marker + length padding.
  let rem = msg_len - remain_off
  let bit_len = msg_len * 8
  -- If rem + 1 (for 0x80) + 8 (for length) <= 64, one more block suffices.
  -- Otherwise two blocks.
  let need_two = rem + 9 > 64
  let nblocks = if need_two then 2 else 1
  let buf_size = nblocks * 64
  let buf = arr_new buf_size 0
  let _ = sha_copy_bytes msg_bytes buf remain_off 0 rem
  let _ = arr_set buf rem 128
  -- Zero-fill is automatic (arr_new default 0). The last 8 bytes carry the
  -- bit length as a 64-bit big-endian integer, written as two 32-bit words.
  -- The high word is 0 for messages under 2^29 bytes; writing it keeps the
  -- digest correct for larger inputs, where it was previously left at 0.
  let _ = be32_write buf (buf_size - 8) (shr bit_len 32)
  let _ = be32_write buf (buf_size - 4) bit_len
  let _ = sha_process_block h state buf k_tbl
  if need_two then sha_process_block h state (arr_slice buf 64 64) k_tbl
  else 0

-- New n-slot array holding a copy of src[start..start+n-1].
arr_slice src start n =
  let copy = arr_new n 0
  let _ = sha_copy_bytes src copy start 0 n
  copy

-- ── Public entry points ───────────────────────────────────────────────────

-- One-shot SHA-256 of a Rail string. Returns a 32-slot array holding the
-- digest: H[0..7], each word serialized big-endian via be32_write.
sha256 msg =
  let msg_bytes = string_to_bytes msg
  let msg_len = string_length_bytes msg
  let h = sha256_h_init 0
  let k_tbl = sha256_k_table 0
  -- 8-slot scratch array handed to sha_process_block and reused for every block.
  let state = arr_new 8 0
  -- Hash all complete 64-byte blocks; `end` is where the leftover tail starts.
  let end = sha_process_from msg_bytes msg_len h state k_tbl 0
  -- Pad the tail (0x80 marker + bit length) and hash the final block(s).
  let _ = sha_final_block msg_bytes msg_len h state k_tbl end
  -- Serialize H[0..7] to 32 big-endian bytes.
  let out = arr_new 32 0
  let _ = be32_write out 0 (arr_get h 0)
  let _ = be32_write out 4 (arr_get h 1)
  let _ = be32_write out 8 (arr_get h 2)
  let _ = be32_write out 12 (arr_get h 3)
  let _ = be32_write out 16 (arr_get h 4)
  let _ = be32_write out 20 (arr_get h 5)
  let _ = be32_write out 24 (arr_get h 6)
  let _ = be32_write out 28 (arr_get h 7)
  out

-- SHA-256 of a Rail string rendered through bytes_to_hex: the 32 digest
-- bytes produced by sha256 are passed on together with their count (32).
sha256_hex msg =
  bytes_to_hex (sha256 msg) 32

-- ── Streaming API ─────────────────────────────────────────────────────────
-- Init / update / finalize for hashes whose input doesn't fit in memory at
-- once.  Used by the streaming HTTPS download path so we can hash a multi-GB
-- response on the way to disk without ever holding it as a Rail string.
--
-- State layout (single Rail array of 5 slots):
--   st[0]  H               8-int running hash (mutated in place)
--   st[1]  K_TBL            64-int round constants (built once, shared)
--   st[2]  pending_buf      64-byte working buffer for partial blocks
--   st[3]  pending_len      bytes currently in pending_buf  (0..63)
--   st[4]  total_bytes      total bytes consumed (used for length padding)
--
-- arena_mark/reset is the caller's responsibility.  The state itself is
-- allocated on entry to sha256_init; subsequent updates allocate scratch
-- (the message schedule W and the round-state) per processed block, which
-- the GC reclaims on the caller's reset.

-- Allocate and return a fresh streaming state (5-slot layout described in the
-- section header above). The argument is ignored; it only exists so the
-- function is not a nullary binding.
sha256_init _ =
  let st = arr_new 5 0
  -- st[0]: running hash H, st[1]: round-constant table K.
  let _ = arr_set st 0 (sha256_h_init 0)
  let _ = arr_set st 1 (sha256_k_table 0)
  -- st[2]: 64-slot pending buffer for a partial block.
  let _ = arr_set st 2 (arr_new 64 0)
  -- st[3]: bytes currently pending, st[4]: total bytes consumed so far.
  let _ = arr_set st 3 0
  let _ = arr_set st 4 0
  st

-- Feed `n` bytes from `arr` starting at `off` into the running hash.
-- Returns the same state array (mutated in place). A non-positive `n` is a
-- no-op that returns `st` untouched.
sha256_update_arr st arr off n =
  if n <= 0 then st
  else
    let h = arr_get st 0
    let k_tbl = arr_get st 1
    let pbuf = arr_get st 2
    let plen = arr_get st 3
    let total = arr_get st 4
    -- Account for the new bytes up front; finalize derives the bit length
    -- from this running total.
    let _ = arr_set st 4 (total + n)
    -- Free room left in the pending buffer (1..64, since plen is 0..63).
    let space = 64 - plen
    if n < space then
      -- Not enough to complete a block: just append to the pending buffer.
      let _ = sha_copy_bytes arr pbuf off plen n
      let _ = arr_set st 3 (plen + n)
      st
    else
      -- Top the pending buffer up to a full 64-byte block and hash it.
      let _ = sha_copy_bytes arr pbuf off plen space
      let scratch = arr_new 8 0
      let _ = sha_process_block h scratch pbuf k_tbl
      -- Pending buffer is now empty; hand the rest of the input to the
      -- block-at-a-time loop, which also stashes any trailing partial block.
      let _ = arr_set st 3 0
      sha256_stream_full st arr (off + space) (n - space) k_tbl h

-- Drain `remaining` bytes of `arr` starting at `off`: hash every complete
-- 64-byte block, then stash the trailing partial block (0..63 bytes) in the
-- pending buffer st[2] and record its length in st[3]. Returns `st`.
--
-- Any bytes already sitting in the pending buffer are overwritten, so the
-- caller must have flushed it first (sha256_update_arr does, and resets
-- st[3] to 0 before calling).
--
-- The pending buffer doubles as the per-block staging area: it is idle while
-- full blocks are being streamed and is fully overwritten by each 64-byte
-- copy, so no fresh 64-slot array has to be allocated per block. On the
-- multi-GB streaming path this removes one arena allocation per 64 input bytes.
sha256_stream_full st arr off remaining k_tbl h =
  let pbuf = arr_get st 2
  if remaining < 64 then
    let _ = sha_copy_bytes arr pbuf off 0 remaining
    let _ = arr_set st 3 remaining
    st
  else
    let _ = sha_copy_bytes arr pbuf off 0 64
    -- 8-slot scratch array required by sha_process_block.
    let scratch = arr_new 8 0
    let _ = sha_process_block h scratch pbuf k_tbl
    sha256_stream_full st arr (off + 64) (remaining - 64) k_tbl h

-- Convenience: feed a Rail string.  Allocates a temporary byte array;
-- not appropriate for huge strings — use sha256_update_arr from a streaming
-- producer instead.
-- Returns the state array `st` (the value returned by sha256_update_arr).
sha256_update_str st s =
  let arr = string_to_bytes s
  let n = string_length_bytes s
  sha256_update_arr st arr 0 n

-- Pad and produce the 32-byte digest.  After finalize the state is spent
-- (the running hash in st[0] has absorbed the padding blocks).
--
-- Returns a 32-slot array: H[0..7], each word serialized big-endian.
--
-- The 8-byte big-endian length field is written as two 32-bit halves so that
-- streams of 2^29 bytes or more (bit length >= 2^32) hash correctly; this is
-- the multi-GB case the streaming API exists for. For shorter streams the
-- high half is 0 and the padding is byte-identical to writing only the low
-- half.
sha256_finalize st =
  let h = arr_get st 0
  let k_tbl = arr_get st 1
  let pbuf = arr_get st 2
  let plen = arr_get st 3
  let total = arr_get st 4
  let bit_len = total * 8
  -- Split the bit length into 32-bit halves. 4294967296 = 2^32; it is used
  -- only as an operand of / and *, never as a bare value.
  let bit_hi = bit_len / 4294967296
  let bit_lo = bit_len - bit_hi * 4294967296
  -- Pending bytes + 0x80 marker + 8-byte length must fit in 64 bytes,
  -- otherwise the padding spills into a second block.
  let need_two = plen + 9 > 64
  let nblocks = if need_two then 2 else 1
  let buf_size = nblocks * 64
  let buf = arr_new buf_size 0
  let _ = sha_copy_bytes pbuf buf 0 0 plen
  -- 128 = 0x80 padding marker right after the pending bytes; the rest of
  -- `buf` is already zero from arr_new.
  let _ = arr_set buf plen 128
  -- Length field occupies the last 8 bytes of the padded buffer. It can never
  -- overlap the marker: plen <= 55 in the one-block case, <= 63 otherwise.
  let _ = be32_write buf (buf_size - 8) bit_hi
  let _ = be32_write buf (buf_size - 4) bit_lo
  let scratch = arr_new 8 0
  let _ = sha_process_block h scratch buf k_tbl
  let _ = if need_two then sha_process_block h scratch (arr_slice buf 64 64) k_tbl else 0
  -- Serialize H[0..7] to 32 big-endian bytes.
  let out = arr_new 32 0
  let _ = be32_write out 0 (arr_get h 0)
  let _ = be32_write out 4 (arr_get h 1)
  let _ = be32_write out 8 (arr_get h 2)
  let _ = be32_write out 12 (arr_get h 3)
  let _ = be32_write out 16 (arr_get h 4)
  let _ = be32_write out 20 (arr_get h 5)
  let _ = be32_write out 24 (arr_get h 6)
  let _ = be32_write out 28 (arr_get h 7)
  out

-- Finalize the streaming state and render the 32 digest bytes through
-- bytes_to_hex. Like sha256_finalize, this spends the state.
sha256_finalize_hex st =
  bytes_to_hex (sha256_finalize st) 32
-- BX fixed-point integer library (no main; importable).
-- Shared exact-integer primitives for the bit-exact transformer rungs (BX4+).
--
-- Representation: fixed-point with F=24 fractional bits, S=2^24=16777216. A real x is the
-- integer X = round(x*S). Every operation is a FIXED INTEGER algorithm: integer arithmetic is
-- exact and associative, so any device/runtime running the same algorithm produces BIT-IDENTICAL
-- output. All internal divides truncate toward zero (ARM sdiv); a foreign witness must do the same.
--
-- Codegen traps honored here (see memory):
--   * value-position integer literals >=32768 are miscompiled -> keep big consts as operands,
--     route constant returns through computation (rail-bug-value-literal-tag-mov).
--   * const-multiply in a self-loop tail-arg is miscompiled -> mutual recursion
--     (rail-bug-self-loop-const-mul-arg).

-- ================= integer sqrt (digit-by-digit) =================
-- bx4_topbit n d: starting from d, keep multiplying d by 4 while d*4 <= n and
-- return the last d. Called with d = 1 and n >= 1 this yields the largest
-- power of 4 that does not exceed n, the starting "bit" for bx4_isqrt.
-- The _a/_b pair is one loop written as mutual recursion, honoring the
-- self-loop const-multiply codegen trap listed in the file header.
bx4_topbit_a n d = if d * 4 > n then d else bx4_topbit_b n (d * 4)
bx4_topbit_b n d = if d * 4 > n then d else bx4_topbit_a n (d * 4)
bx4_topbit n d = bx4_topbit_a n d
-- Digit-by-digit integer square root step. `n` is the remaining remainder,
-- `c` the partial result, `d` the current power-of-4 "bit" (divided by 4 each
-- step until it reaches 0). When the trial value c + d fits into n it is
-- subtracted and the bit is set in the result; otherwise the result is only
-- shifted (c / 2). _a and _b are the same step, alternating as mutual
-- recursion. NOTE(review): presumably split for the same self-loop codegen
-- trap as bx4_topbit — confirm.
bx4_isqrt_a n c d =
  if d == 0 then c
  else
    let cd = c + d in
    if n >= cd then bx4_isqrt_b (n - cd) ((c / 2) + d) (d / 4)
    else bx4_isqrt_b n (c / 2) (d / 4)
bx4_isqrt_b n c d =
  if d == 0 then c
  else
    let cd = c + d in
    if n >= cd then bx4_isqrt_a (n - cd) ((c / 2) + d) (d / 4)
    else bx4_isqrt_a n (c / 2) (d / 4)
-- Integer square root entry point: floor(sqrt(n)) for n >= 1, and 0 for any
-- n < 1 (including negative inputs).
bx4_isqrt n = if n < 1 then 0 else bx4_isqrt_a n 0 (bx4_topbit n 1)

-- Fixed-point square root: for xq = x*S returns isqrt(xq * S), i.e.
-- floor(sqrt(x) * S). Non-positive xq yields 0 (bx4_isqrt returns 0 for n < 1).
bx4_fxsqrt xq = bx4_isqrt (xq * 16777216)
-- Fixed-point reciprocal square root: S^2 / sqrt_q with truncating division
-- (281474976710656 = 2^48 = S^2). Returns 0 instead of dividing by zero when
-- the square root is 0.
bx4_fxrsqrt xq = let q = bx4_fxsqrt xq in if q == 0 then 0 else 281474976710656 / q

-- ================= fixed-point exp (arg-reduced, unrolled poly) =================
-- Doubling loop: returns acc * 2^k for k >= 0. Doubling is written acc + acc
-- rather than a const-multiply in the self-loop tail argument (see the codegen
-- traps in the file header). Does not terminate for negative k; callers go
-- through bx4_pow2, which guards that case.
bx4_pow2_go acc k = if k == 0 then acc else bx4_pow2_go (acc + acc) (k - 1)
-- 2^k as an integer; any k < 1 yields 1.
bx4_pow2 k = if k < 1 then 1 else bx4_pow2_go 1 k

-- e^f * S for f = fq/S in [0, ln2). Degree-8 Horner, fully unrolled. Coeffs c_i = round(S / i!).
-- Evaluation runs from the highest coefficient (p8 = c_8) down to p0; each
-- step is p_i = c_i + (p_{i+1} * fq) / S with truncating integer division,
-- so the sequence of roundings is fixed and part of the bit-exact contract.
bx4_exp_poly fq =
  let p8 = 416 in
  let p7 = 3329 + (p8 * fq) / 16777216 in
  let p6 = 23302 + (p7 * fq) / 16777216 in
  let p5 = 139810 + (p6 * fq) / 16777216 in
  let p4 = 699051 + (p5 * fq) / 16777216 in
  let p3 = 2796203 + (p4 * fq) / 16777216 in
  let p2 = 8388608 + (p3 * fq) / 16777216 in
  let p1 = 16777216 + (p2 * fq) / 16777216 in
  let p0 = 16777216 + (p1 * fq) / 16777216 in
  p0

-- e^x * S for x <= 0 (ax = |x|q > 0). LN2_q = round(ln2 * S) = 11629080.
-- Argument reduction: ax = k*LN2_q + fq, so e^-|x| = e^-f / 2^k.
--   ef  = e^f * S            (polynomial, always >= S so the divide is safe)
--   emf = S^2 / ef = e^-f*S  (281474976710656 = 2^48 = S^2, truncating)
-- The final divide by 2^k rounds to nearest by adding p2 / 2 first.
-- For k >= 50 the result is defined as 0; emf is still computed before that
-- test, which is harmless because nothing above depends on k being small.
bx4_exp_neg ax =
  let k = ax / 11629080 in
  let fq = ax - k * 11629080 in
  let ef = bx4_exp_poly fq in
  let emf = 281474976710656 / ef in
  if k >= 50 then 0
  else let p2 = bx4_pow2 k in (emf + p2 / 2) / p2

-- e^x * S for x > 0. The k>=36 branch is dead for the bounded domain; its sentinel is a runtime
-- product (NOT a bare literal >=32768) per rail-bug-value-literal-tag-mov.
-- Argument reduction: xq = k*LN2_q + fq (LN2_q = 11629080), result is
-- (e^f * S) * 2^k. For k >= 36 the scale factor is clamped to 2^35.
bx4_exp_pos xq =
  let k = xq / 11629080 in
  let fq = xq - k * 11629080 in
  let ef = bx4_exp_poly fq in
  if k >= 36 then ef * (bx4_pow2 35)
  else (ef * (bx4_pow2 k))

-- Fixed-point exp entry point: dispatches on the sign of xq. Negative inputs
-- are passed to bx4_exp_neg as their magnitude; xq == 0 takes the positive path.
bx4_fxexp xq =
  if xq < 0 then bx4_exp_neg (0 - xq)
  else bx4_exp_pos xq

-- ================= tanh / gelu (built on exp<=0) =================
-- tanh for a non-negative fixed-point argument, using
-- tanh(a) = (1 - e^(-2a)) / (1 + e^(-2a)). em = e^(-2a) * S comes from the
-- x <= 0 branch of bx4_fxexp; the quotient is rescaled by S and truncated.
bx4_tanh_pos axq =
  let em = bx4_fxexp (0 - (axq + axq)) in
  ((16777216 - em) * 16777216) / (16777216 + em)
-- Fixed-point tanh for any sign: negative inputs are evaluated on |x| and
-- negated, so the function is exactly odd.
bx4_tanh xq = if xq < 0 then 0 - (bx4_tanh_pos (0 - xq)) else bx4_tanh_pos xq

-- gelu(x) = 0.5*x*(1 + tanh(g*(x + a*x^3))), g_q=13386610, a_q=750194
-- All values are fixed-point (scale S = 16777216). Every product is rescaled
-- by a truncating divide by S right after the multiply; the order of these
-- divides is part of the bit-exact contract and must not be rearranged.
bx4_gelu xq =
  let x2 = (xq * xq) / 16777216 in
  let x3 = (x2 * xq) / 16777216 in
  let ax3 = (750194 * x3) / 16777216 in
  let inner0 = xq + ax3 in
  let inner = (13386610 * inner0) / 16777216 in
  let th = bx4_tanh inner in
  let onep = 16777216 + th in
  -- 33554432 = 2*S: rescales the product and applies the 0.5 factor in one divide.
  (xq * onep) / 33554432

-- gelu'(x) = 0.5(1+t) + 0.5 x sech^2(u) du/dx, where u = g(x + a x^3), t = tanh(u),
-- sech^2(u) = 1 - t^2, du/dx = g(1 + 3a x^2). Pure fixed-point; built on bx4_tanh.
-- This is the transcendental DERIVATIVE that the bit-exact backward needs: like exp/tanh
-- it is a fixed integer algorithm, so every gradient reproduces bit-for-bit on any device.
-- Uses the same constants as bx4_gelu (g_q = 13386610, a_q = 750194) and the
-- same rescale-after-every-multiply discipline (truncating divide by S).
bx4_gelu_grad xq =
  let x2 = (xq * xq) / 16777216 in
  let x3 = (x2 * xq) / 16777216 in
  let ax3 = (750194 * x3) / 16777216 in
  let inner0 = xq + ax3 in
  let u = (13386610 * inner0) / 16777216 in
  let t = bx4_tanh u in
  -- term1 = 0.5 * (1 + t)
  let term1 = (16777216 + t) / 2 in
  -- sech2 = 1 - t^2
  let sech2 = 16777216 - (t * t) / 16777216 in
  -- dudx = g * (1 + 3*a*x^2)
  let a3x2 = (3 * (750194 * x2)) / 16777216 in
  let inner1 = 16777216 + a3x2 in
  let dudx = (13386610 * inner1) / 16777216 in
  -- term2 = 0.5 * x * sech2 * dudx, rescaled after each product
  let p1 = (xq * sech2) / 16777216 in
  let p2 = (p1 * dudx) / 16777216 in
  let term2 = p2 / 2 in
  term1 + term2

-- ================= float<->fixed helpers =================
-- Float -> fixed: scale by S, push half a unit away from zero, then convert
-- with float_to_int. NOTE(review): this rounds to nearest only if
-- float_to_int truncates toward zero — confirm.
bx4_quant x = let s = x *. 16777216.0 in
  let r = if s <. 0.0 then s -. 0.5 else s +. 0.5 in float_to_int r
-- Fixed -> float: divide the integer by S.
bx4_back rq = (int_to_float rq) /. 16777216.0
-- Absolute value of a float.
bx4_fabs x = if x <. 0.0 then 0.0 -. x else x
-- stdlib/tensor.rail — Tensor type for RailML
-- Phase 2: CPU tensors + Metal GPU auto-dispatch
-- Import: import "stdlib/tensor.rail"
--
-- GPU dispatch: if ~/projects/rail/tools/metal/tensor_gpu exists,
-- matmul and activations auto-dispatch to Metal GPU (M4 Pro: 269 GFLOPS).
-- Falls back to CPU loops transparently.
--
-- NOTE: Due to Rail's float TCO limitation, all loops return int (0).
-- Float results are stored in mutable float arrays and read back after loops.
-- Use show_float for printing tensor_get/tensor_sum/tensor_mean results
-- (cross-function float return inference does not propagate through match).

foreign tanh x -> float
foreign sqrt x -> float
foreign exp x -> float

-- ADT: flat float array + shape + strides (row-major)
--   data     float_arr holding every element (allocated via float_arr_new)
--   shape    list of ints, one per dimension
--   strides  list of ints as produced by compute_strides shape
type Tensor = | Tensor data shape strides

-- HalfTensor: the same shape/strides machinery, but `half_data` is a
-- float_arr whose 8-byte slots are reinterpreted as packed uint16_t[4].
-- Storage size = ceil(n_elements / 4) slots, where n_elements is
-- product(shape). See tools/metal/tensor_gpu_lib.m: tgl_f64_to_half /
-- tgl_half_to_f64 / tgl_matmul_half_host cross the boundary once at
-- cast time so hot-path matmuls don't pay the f64↔f16 cast.
type HalfTensor = | HalfTensor half_data shape strides

-- ============================================================
-- Shape utilities
-- ============================================================

-- Multiply two ints. Kept as a named two-argument helper so it can be
-- passed to fold.
mul_acc a b = a * b

-- Product of every element of a list of ints; the empty list yields 1.
list_product xs =
  if length xs == 0 then 1
  else (head xs) * (list_product (tail xs))

-- Compute row-major strides from shape
-- [2,3,4] -> [12, 4, 1]
--
-- The stride of a dimension is the product of all dimensions to its right
-- (1 for the innermost). An empty shape yields [].
--
-- compute_strides_acc walks the REVERSED shape (innermost dimension first).
-- `acc` is the product of the dimensions already visited, which is exactly
-- the stride of the dimension being visited; strides are consed onto `out`,
-- so they come out in the original outermost-first order.
compute_strides_acc rev_shape acc out =
  if length rev_shape == 0 then out
  else
    let next = acc * (head rev_shape)
    compute_strides_acc (tail rev_shape) next (cons acc out)

-- Strides for a shape supplied innermost-dimension-first; the result is in
-- outermost-first order (matching the un-reversed shape).
compute_strides_rev rev_shape = compute_strides_acc rev_shape 1 []

compute_strides shape = compute_strides_rev (reverse shape)

-- Flat offset of a multi-dimensional index: the sum of indices[d] * strides[d]
-- over every entry of `indices`. An empty index list gives 0. `strides` must
-- be at least as long as `indices`.
compute_offset indices strides =
  if length indices == 0 then 0
  else
    let here = (head indices) * (head strides)
    let rest = compute_offset (tail indices) (tail strides)
    here + rest

-- Element at 0-based position `n` of `xs`. No bounds checking: `n` must be
-- in range, otherwise head/tail end up applied to an empty list.
list_nth n xs =
  if n == 0 then head xs
  else
    let rest = tail xs
    list_nth (n - 1) rest

-- ============================================================
-- Constructors
-- ============================================================

-- Build a Tensor of the given shape with every element equal to init_val.
-- The backing float array has product(shape) slots and the strides are the
-- row-major strides of `shape`.
tensor_new shape init_val =
  let n_elems = list_product shape
  Tensor (float_arr_new n_elems init_val) shape (compute_strides shape)

-- Convenience constructors: a tensor of the given shape filled with 0.0
-- (tensor_zeros) or 1.0 (tensor_ones). Both delegate to tensor_new.
tensor_zeros shape = tensor_new shape 0.0
tensor_ones shape = tensor_new shape 1.0

-- ============================================================
-- Accessors
-- ============================================================

-- Get tensor shape (the list of dimension sizes stored in the Tensor).
tensor_shape t = match t
  | Tensor d s st -> s

-- Get total number of elements: the product of the shape's dimensions.
tensor_size t = match t
  | Tensor d s st -> list_product s

-- Get size of dimension at axis (0-based). `axis` is not bounds-checked;
-- it is forwarded to list_nth as-is.
dim t axis = match t
  | Tensor d s st -> list_nth axis s

-- Get element at multi-dimensional indices (returns float).
-- `indices` is a list with one index per dimension; the flat position is
-- computed from the tensor's strides. Indices are not range-checked here.
tensor_get t indices = match t
  | Tensor d s st ->
    let offset = compute_offset indices st
    float_arr_get d offset

-- Set element at multi-dimensional indices. Mutates the tensor's backing
-- array in place and returns whatever float_arr_set returns.
tensor_set t indices val = match t
  | Tensor d s st ->
    let offset = compute_offset indices st
    float_arr_set d offset val

-- Get/set by flat index: `i` addresses the backing float array directly,
-- bypassing shape and strides.
tensor_get_flat t i = match t
  | Tensor d s st -> float_arr_get d i

-- In-place write at flat index `i`; returns whatever float_arr_set returns.
tensor_set_flat t i val = match t
  | Tensor d s st -> float_arr_set d i val

-- ============================================================
-- GPU dispatch helpers
-- ============================================================

-- Path of the standalone tensor_gpu executable used by the file-based
-- dispatch helpers. The leading ~ is left for the shell to expand, so this
-- string is only meaningful at the start of a `shell` command line.
gpu_binary_path = "~/projects/rail/tools/metal/tensor_gpu"

-- Cached GPU availability flag (0 = unchecked, 1 = available, -1 = unavailable)
-- NOTE(review): this is a nullary top-level binding. The comment on
-- ensure_dylib says such bindings are re-evaluated on every reference
-- (float_arr_new runs fresh each time); if that applies here the cache never
-- hits and gpu_available spawns a shell on every call — confirm.
gpu_flag_arr = float_arr_new 1 0.0

-- `tools/metal/.no_gpu` sentinel forces the CPU path for every tensor op.
-- Motivation: on machines where the Metal dylib loads but silently returns
-- zeros (Studio M1 Ultra w/ stale Xcode CLT is the current case), touch
-- tools/metal/.no_gpu once and CPU fallbacks stay engaged until it's
-- removed — no per-run binary shuffling. File-based over env-based because
-- `foreign getenv` doesn't propagate reliably through rail_native run's
-- /bin/sh -c child (see docs/ffi-str-return-bug.md).
--
-- Returns true when the tensor_gpu binary is executable AND the .no_gpu
-- sentinel is absent. The argument is ignored.
gpu_available _ =
  let cached = float_arr_get gpu_flag_arr 0
  if cached > 0.5 then true
  else if cached < (0.0 - 0.5) then false
  else
    -- Unchecked: probe via the shell, which prints 1 or 0.
    let result = shell "test -x ~/projects/rail/tools/metal/tensor_gpu && test ! -f ~/projects/rail/tools/metal/.no_gpu && echo 1 || echo 0"
    let avail = str_contains "1" result
    -- Remember the answer as +1.0 / -1.0 (negative literals are spelled 0.0 - x).
    let _ = float_arr_set gpu_flag_arr 0 (if avail then 1.0 else (0.0 - 1.0))
    avail

-- Write float_arr to text file for GPU consumption: the first `n` elements
-- of `arr`, one show_float rendering per line, each followed by "\n".
gpu_write_floats path arr n =
  write_file path (gpu_write_loop arr n 0 "")

-- Accumulate the text for elements i..n-1 onto `acc` and return it.
-- NOTE(review): `cat` presumably concatenates a list of strings (the sibling
-- gpu_format_loop uses `join ""` for the same job) — confirm. The whole
-- accumulator is rebuilt on every step.
gpu_write_loop arr n i acc =
  if i >= n then acc
  else gpu_write_loop arr n (i + 1) (cat (cons acc (cons (show_float (float_arr_get arr i)) (cons "\n" []))))

-- Read float_arr from GPU output text file: one float per line. Returns a
-- new n-element float_arr; slots for which the file has no value stay 0.0.
gpu_read_floats path n =
  let content = read_file path
  let lines = str_split "\n" content
  let arr = float_arr_new n 0.0
  let _ = gpu_parse_lines arr lines 0 n
  arr

-- Fill arr[i..] from `lines`, stopping after n values or when the lines run
-- out. Empty lines are skipped without consuming an output slot. Returns 0
-- (int) so the loop stays within Rail's float TCO limitation.
gpu_parse_lines arr lines i n =
  if i >= n then 0
  else if length lines == 0 then 0
  else
    let line = head lines
    if length (chars line) == 0 then gpu_parse_lines arr (tail lines) i n
    else
      let _ = float_arr_set arr i (parse_float line)
      gpu_parse_lines arr (tail lines) (i + 1) n

-- Cleanup tmp files after GPU ops: removes the three text exchange files in
-- /tmp used by the text-file dispatch path. The argument is ignored; returns
-- the output of the shell command.
gpu_cleanup _ = shell "rm -f /tmp/rail_tg_a.txt /tmp/rail_tg_b.txt /tmp/rail_tg_c.txt 2>/dev/null"

-- Check if tensor daemon is running on :9300
-- Cache slot: 0 = unchecked, 1 = available, -1 = unavailable.
-- NOTE(review): nullary top-level float_arr; per the comment on ensure_dylib
-- such bindings are re-evaluated on every reference, in which case this cache
-- never hits and every call runs `nc` — confirm.
gpu_daemon_flag = float_arr_new 1 0.0

-- Returns true when something accepts connections on localhost:9300
-- (probed with `nc -z`). The argument is ignored.
gpu_daemon_available _ =
  let cached = float_arr_get gpu_daemon_flag 0
  if cached > 0.5 then true
  else if cached < (0.0 - 0.5) then false
  else
    let result = shell "nc -z localhost 9300 2>/dev/null && echo 1 || echo 0"
    let avail = str_contains "1" result
    let _ = float_arr_set gpu_daemon_flag 0 (if avail then 1.0 else (0.0 - 1.0))
    avail

-- Format float_arr as space-separated string: the first `n` elements of
-- `arr`, each rendered with show_float, with no leading or trailing space.
gpu_format_floats arr n = gpu_format_loop arr n 0 ""

-- Append elements i..n-1 to `acc`. A separating space is emitted before every
-- element except the first (i == 0).
gpu_format_loop arr n i acc =
  if i >= n then acc
  else
    let sep = if i > 0 then " " else ""
    gpu_format_loop arr n (i + 1) (join "" (cons acc (cons sep (cons (show_float (float_arr_get arr i)) []))))

-- Parse space-separated floats into float_arr. Returns a new n-element
-- float_arr; slots for which `str` has no value stay 0.0.
gpu_parse_inline str n =
  let parts = str_split " " str
  let arr = float_arr_new n 0.0
  let _ = gpu_parse_inline_loop arr parts 0 n
  arr

-- Fill arr[i..] from `parts`, stopping after n values or when the parts run
-- out. Empty parts (e.g. from consecutive spaces) are skipped without
-- consuming an output slot. Returns 0 (int).
gpu_parse_inline_loop arr parts i n =
  if i >= n then 0
  else if length parts == 0 then 0
  else
    let p = head parts
    if length (chars p) == 0 then gpu_parse_inline_loop arr (tail parts) i n
    else
      let _ = float_arr_set arr i (parse_float p)
      gpu_parse_inline_loop arr (tail parts) (i + 1) n

-- GPU via dylib FFI (fastest, zero-copy).
-- dylib path: ~/projects/rail/tools/metal/libtensor_gpu.dylib
-- absolute install_name makes it discoverable without DYLD env vars.
-- Every foreign float_arr argument is passed as a pointer; the dylib
-- skips the 8-byte Rail count header internally.
foreign tgl_init dummy -> int
foreign tgl_matmul_f64 a b c m k n -> int
foreign tgl_matmul_f32x_halfw_host x w y m k n -> int
foreign tgl_matmul_relu_f64 a b bias c m k n -> int
foreign tgl_add_f64 a b c n -> int
foreign tgl_mul_f64 a b c n -> int
foreign tgl_scale_f64 a c scalar_arr n -> int
foreign tgl_relu_f64 x y n -> int
foreign tgl_relu_backward_f64 x g o n -> int
foreign tgl_sigmoid_f64 x y n -> int
foreign tgl_exp_f64 x y n -> int
foreign tgl_tanh_f64 x y n -> int
foreign tgl_softmax_rows_f64 x y rows cols -> int
foreign tgl_transpose_f64 a b m n -> int
foreign tgl_sgd_update_f64 w g lr_arr n -> int
foreign tgl_adam_update_f64 w g m v hyp n -> int
foreign tgl_cross_entropy_f64 p t l batch vocab -> int
foreign tgl_matmul_gelu_f64 a b bias c m k n -> int
foreign tgl_matmul_batched_f64 a b c bdim m k n -> int
foreign tgl_softmax_backward_f64 y dy dx rows cols -> int
foreign tgl_ce_softmax_backward_f64 probs targets grad batch vocab -> int
foreign tgl_layernorm_backward_f64 x mean rstd gamma dy dx rows dim -> int

-- Transformer fused ops added 2026-05-14 — Metal kernels for the elementwise
-- bottleneck of the training step (RMSNorm fwd, RoPE in-place, SiLU fwd).
-- CPU equivalents stay in stdlib/transformer.rail; these are GPU drop-ins.
foreign tgl_rmsnorm_save_f64 x y g rstd dim n_rows -> int  -- eps hardcoded 1e-5
foreign tgl_rope_apply_f64   x seq d sign -> int
foreign tgl_silu_fwd_f64     x y sig n -> int

-- Phase 0 of the trace+JIT plan (rail-jit-fused-kernels-plan.md, 2026-05-14).
-- tgl_jit_compile_from_tmp_file reads MSL source from /tmp/rail_jit_kernel.metal
-- and returns an integer pipeline ID. tgl_jit_dispatch_1in1out runs a 1-in/1-out
-- kernel with that ID at a 1D grid of size n. Both prove the JIT loop end-to-end.
foreign tgl_jit_compile_from_tmp_file dummy -> int
foreign tgl_jit_dispatch_1in1out kid x y n -> int

-- bf16 forward matmul (added 2026-05-14, replaces fp16 for the long-tail
-- stability question — bf16 has f32's exponent range, no 65504 overflow).
foreign tgl_matmul_bf16 a b c m k n -> int

-- BX3 (2026-06-03): exact integer matmul. aq,bq are float_arr of EXACT-INTEGER
-- doubles (fixed-point quantized inputs, |q| < 2^31). hi,lo are float_arr that
-- receive the per-output 2-limb accumulator (hi, lo in [0,2^31)) as exact-integer
-- doubles. One GPU thread per output does its OWN sequential 2-limb accumulation
-- (no cross-thread reduction, no atomics) so the result is schedule-independent
-- and bit-for-bit identical to the CPU reference in tools/bitexact/bx2.
foreign tgl_exact_matmul aq bq hi lo m k n -> int

-- Phase 1.1 (2026-05-15): JIT dispatcher for the fused rmsnorm+QKV kernel.
-- 5 inputs (X, G, WQ, WK, WV) + 3 outputs (Q, K, V) + SEQ + D.  The kid
-- comes from tgl_jit_compile_from_tmp_file after writing the fused MSL.
foreign tgl_jit_dispatch_rmsnorm_qkv kid x g wq wk wv qkv_packed seq_d -> int
foreign tgl_jit_dispatch_2in1out      kid a b c n -> int
-- silu+hadamard variant: output buffer is 2*N doubles (h_act | sigmoid).
foreign tgl_jit_dispatch_silu_hadamard kid g u out_sig n -> int

-- fp16 variants (Phase 4a Option A, labrat-produced kernels).
-- Same Rail-side signature as the f64 siblings: pass float_arr (double*);
-- the dylib stages down to half on the way in and back to double on the
-- way out. Bias buffer stays fp32 on the GPU per the "fp32-bias" hint
-- that landed the bias-fused kernels (see docs/plans/LABRAT_FIRST_WIN.md).
foreign tgl_matmul_f16           a b c m k n -> int
foreign tgl_matmul_blocked_f16   a b c m k n -> int
foreign tgl_matmul_bias_relu_f16 a b bias c m k n -> int
foreign tgl_matmul_bias_gelu_f16 a b bias c m k n -> int

-- HalfTensor primitives (Phase 4b follow-on). Pack/unpack cross the
-- host-side f64↔f16 boundary once; matmul_half_host accepts already-
-- packed half buffers so hot-path calls skip the cast.
foreign tgl_f64_to_half      src dst n -> int
foreign tgl_half_to_f64      src dst n -> int
foreign tgl_matmul_half_host a b c m k n -> int
foreign tgl_add_half_host    a b c n -> int
foreign tgl_scale_half_host  a c scalar_arr n -> int
foreign tgl_transpose_half_host a b m n -> int
foreign tgl_softmax_half_host   a c rows cols -> int

-- Has dylib been initialized this process?
-- Note: the previous top-level `float_arr` cache was broken by Rail's
-- re-evaluate-per-reference rule for nullary bindings (float_arr_new
-- runs fresh on every reference), so ensure_dylib paid a shell + tgl_init
-- on EVERY matmul call. Fix: tgl_init is idempotent (the dylib has its
-- own g_initialized flag); call it unconditionally and trust it. Shell
-- existence check is compile-time anyway — if the dylib weren't linked,
-- ld would have failed, so this symbol exists by construction.
--
-- Always returns true; the return code of tgl_init is discarded. The `else`
-- branches of callers that gate on this function are therefore never taken.
ensure_dylib _ =
  let _ = tgl_init 0
  true

-- GPU matmul through the in-process dylib (zero-copy, ~1ms per 128x128).
-- Multiplies a_data (m x k) by b_data (k x n) and returns a freshly
-- allocated m*n float_arr with the product. The return code of
-- tgl_matmul_f64 is not inspected.
-- The historical binary-file and text-file fallbacks are no longer on this
-- hot path: if the dylib isn't linked, the build fails at link time. Callers
-- that need a non-dylib path should call gpu_matmul_file explicitly, which
-- keeps the common case branch-free.
gpu_matmul_dispatch a_data b_data m k n =
  let out = float_arr_new (m * n) 0.0
  let _ = tgl_matmul_f64 a_data b_data out m k n
  out

-- Legacy binary-file dispatch, kept for debugging. ~50ms/call.
-- Writes both operands as f32 files under /tmp, runs the tensor_gpu binary
-- in `matmul_bin` mode, reads the m*n result back and deletes the three
-- exchange files. Returns a new m*n float_arr.
gpu_matmul_file a_data b_data m k n =
  let _ = float_arr_to_f32_file "/tmp/rail_tg_a.bin" a_data (m * k)
  let _ = float_arr_to_f32_file "/tmp/rail_tg_b.bin" b_data (k * n)
  -- Command line: <binary> matmul_bin M K N <a> <b> <c>
  let cmd = join "" (cons gpu_binary_path (cons " matmul_bin " (cons (show m) (cons " " (cons (show k) (cons " " (cons (show n) (cons " /tmp/rail_tg_a.bin /tmp/rail_tg_b.bin /tmp/rail_tg_c.bin" []))))))))
  let _ = shell cmd
  let result = float_arr_new (m * n) 0.0
  let _ = float_arr_from_f32_file "/tmp/rail_tg_c.bin" result (m * n)
  let _ = shell "rm -f /tmp/rail_tg_a.bin /tmp/rail_tg_b.bin /tmp/rail_tg_c.bin 2>/dev/null"
  result

-- GPU relu: daemon (fast) or file-mode (fallback)
-- Returns a new n-element float_arr with the result.
gpu_relu_dispatch src_data n =
  if gpu_daemon_available 0 then
    -- Daemon path: send "relu <n>\n<values>\nEND\n" to localhost:9300 through
    -- printf | nc and parse the space-separated reply.
    let a_str = gpu_format_floats src_data n
    let cmd = join "" (cons "printf 'relu " (cons (show n) (cons "\\n" (cons a_str (cons "\\nEND\\n' | nc -w2 localhost 9300" [])))))
    let result_str = shell cmd
    gpu_parse_inline result_str n
  else
    -- File path: write the input as text, run the tensor_gpu binary in relu
    -- mode, read the output file back, then remove the exchange files.
    let _ = gpu_write_floats "/tmp/rail_tg_a.txt" src_data n
    let cmd = join "" (cons gpu_binary_path (cons " relu " (cons (show n) (cons " /tmp/rail_tg_a.txt /tmp/rail_tg_c.txt" []))))
    let _ = shell cmd
    let result = gpu_read_floats "/tmp/rail_tg_c.txt" n
    let _ = gpu_cleanup 0
    result

-- ============================================================
-- dylib FFI path for every op — zero-copy, amortized allocation.
-- Each helper: if dylib is up, call it; else fall back to file/daemon.
-- Callers that hit these do NOT have to do their own ensure_dylib
-- gate — the helper handles it.
-- ============================================================

-- Unary activation (relu, sigmoid, exp, tanh). Op string maps to FFI.
-- Recognized op strings: "relu", "sigmoid", "exp", "tanh_fwd". Returns a new
-- n-element float_arr. An unrecognized op, or a negative return code from the
-- dylib, falls through to gpu_unop_dispatch with the same arguments.
gpu_unop_ffi op src_data n =
  if ensure_dylib 0 then
    let result = float_arr_new n 0.0
    -- rc is the FFI return code, or -1 (spelled 0 - 1) for an unknown op.
    let rc = if op == "relu"     then tgl_relu_f64    src_data result n
             else if op == "sigmoid" then tgl_sigmoid_f64 src_data result n
             else if op == "exp"      then tgl_exp_f64     src_data result n
             else if op == "tanh_fwd" then tgl_tanh_f64    src_data result n
             else 0 - 1
    if rc >= 0 then result
    else gpu_unop_dispatch op src_data n
  else
    gpu_unop_dispatch op src_data n

-- Binary elementwise (add, mul).
-- Recognized op strings: "add", "mul". Returns a new n-element float_arr.
-- An unrecognized op, or a negative return code from the dylib, falls
-- through to gpu_binop_dispatch with the same arguments.
gpu_binop_ffi op a_data b_data n =
  if ensure_dylib 0 then
    let result = float_arr_new n 0.0
    -- rc is the FFI return code, or -1 (spelled 0 - 1) for an unknown op.
    let rc = if op == "add" then tgl_add_f64 a_data b_data result n
             else if op == "mul" then tgl_mul_f64 a_data b_data result n
             else 0 - 1
    if rc >= 0 then result
    else gpu_binop_dispatch op a_data b_data n
  else
    gpu_binop_dispatch op a_data b_data n

-- Scalar multiply. Scalar wrapped into a tiny float_arr so we can pass by pointer.
-- On success returns a new n-element float_arr filled in by tgl_scale_f64.
-- NOTE(review): on failure (negative rc) the INPUT array is returned
-- unchanged, so a caller cannot tell "scaled" from "not scaled" by the
-- result alone — callers are expected to run their own CPU loop.
gpu_scale_ffi src_data s n =
  if ensure_dylib 0 then
    let sarr = float_arr_new 1 s
    let result = float_arr_new n 0.0
    let rc = tgl_scale_f64 src_data result sarr n
    if rc >= 0 then result
    else src_data   -- fallback callers always have CPU loop nearby
  else
    src_data

-- Softmax (row-wise). 2D input [rows,cols]. Shape-preserving.
-- Returns a new rows*cols float_arr. A negative return code from the dylib
-- falls back to gpu_unop_dispatch with op "softmax" over the flat buffer.
gpu_softmax_ffi src_data rows cols =
  if ensure_dylib 0 then
    let result = float_arr_new (rows * cols) 0.0
    let rc = tgl_softmax_rows_f64 src_data result rows cols
    if rc >= 0 then result
    else gpu_unop_dispatch "softmax" src_data (rows * cols)
  else
    gpu_unop_dispatch "softmax" src_data (rows * cols)

-- Transpose: A[M,N]^T returns a new flat buffer of shape [N,M].
-- A negative return code from the dylib falls back to gpu_unop_dispatch with
-- op "transpose" over the flat buffer. NOTE(review): that fallback receives
-- only the element count, not m and n — confirm it can actually transpose.
gpu_transpose_ffi src_data m n =
  if ensure_dylib 0 then
    let result = float_arr_new (m * n) 0.0
    let rc = tgl_transpose_f64 src_data result m n
    if rc >= 0 then result
    else gpu_unop_dispatch "transpose" src_data (m * n)
  else
    gpu_unop_dispatch "transpose" src_data (m * n)

-- ReLU backward: ∂L/∂x = (x > 0) ? ∂L/∂y : 0
-- x_data holds the forward inputs, grad_data the upstream gradient. On
-- success returns a new n-element float_arr with the masked gradient.
-- NOTE(review): on failure (negative rc) this returns x_data, i.e. the
-- forward input rather than any gradient — callers must not use that value
-- as a gradient.
gpu_relu_backward_ffi x_data grad_data n =
  if ensure_dylib 0 then
    let result = float_arr_new n 0.0
    let rc = tgl_relu_backward_f64 x_data grad_data result n
    if rc >= 0 then result
    else x_data
  else x_data

-- SGD step performed in place on w_data: w -= lr * grad.
-- The learning rate is boxed in a 1-element float_arr because the FFI takes
-- it by pointer. Returns the integer return code of tgl_sgd_update_f64, or
-- -1 (spelled 0 - 1) if the dylib is reported unavailable.
gpu_sgd_ffi w_data g_data lr n =
  if ensure_dylib 0 then
    let lr_box = float_arr_new 1 lr
    tgl_sgd_update_f64 w_data g_data lr_box n
  else 0 - 1

-- Fused matmul+bias+relu
-- a_data is (m x k), b_data is (k x n), bias_data is passed straight to the
-- kernel. Returns a new m*n float_arr.
-- NOTE(review): both branches of the rc test return `result`, so a failing
-- kernel call is silently ignored and the zero-initialized buffer is
-- returned. The outer fallback (gpu_matmul_dispatch) is a plain matmul with
-- no bias and no relu; it is unreachable while ensure_dylib always returns
-- true.
gpu_matmul_relu_ffi a_data b_data bias_data m k n =
  if ensure_dylib 0 then
    let result = float_arr_new (m * n) 0.0
    let rc = tgl_matmul_relu_f64 a_data b_data bias_data result m k n
    if rc >= 0 then result
    else result
  else
    gpu_matmul_dispatch a_data b_data m k n

-- ============================================================
-- Matmul: (M,K) @ (K,N) -> (M,N)
-- Auto-dispatches to Metal GPU when available, falls back to CPU
-- ============================================================

-- Innermost loop: accumulate a[i,k]*b[k,j] for k
-- The running sum lives in acc_arr[0] rather than in a float argument, and
-- the loop returns int 0 (see the float TCO note in the file header). The
-- caller must zero acc_arr[0] first and read the dot product from it after.
matmul_k a_data b_data acc_arr k_dim n_dim i j kk =
  if kk >= k_dim then 0
  else
    -- Row-major addressing: a is (rows x k_dim), b is (k_dim x n_dim).
    let av = float_arr_get a_data (i * k_dim + kk)
    let bv = float_arr_get b_data (kk * n_dim + j)
    let cur = float_arr_get acc_arr 0
    let _ = float_arr_set acc_arr 0 (cur + av * bv)
    matmul_k a_data b_data acc_arr k_dim n_dim i j (kk + 1)

-- Middle loop: iterate over columns j
-- For each column j of row i: reset the accumulator, run the k-loop, and
-- store the dot product at c[i,j] (row-major, n_dim columns). Returns 0.
matmul_j a_data b_data c_data acc_arr m_dim n_dim k_dim i j =
  if j >= n_dim then 0
  else
    let _ = float_arr_set acc_arr 0 0.0
    let _ = matmul_k a_data b_data acc_arr k_dim n_dim i j 0
    let _ = float_arr_set c_data (i * n_dim + j) (float_arr_get acc_arr 0)
    matmul_j a_data b_data c_data acc_arr m_dim n_dim k_dim i (j + 1)

-- Outer loop: iterate over rows i
-- Runs the column loop for every row from `i` up to m_dim - 1, filling
-- c_data in place. Returns 0.
matmul_i a_data b_data c_data acc_arr m_dim n_dim k_dim i =
  if i >= m_dim then 0
  else
    let _ = matmul_j a_data b_data c_data acc_arr m_dim n_dim k_dim i 0
    matmul_i a_data b_data c_data acc_arr m_dim n_dim k_dim (i + 1)

-- Matrix multiply: a(M,K) @ b(K,N) -> c(M,N)
-- Tries Metal GPU first (269 GFLOPS), falls back to CPU loops
-- M and K are taken from a's shape and N from the second entry of b's shape.
-- Shapes are not validated: b's first dimension is never compared with K,
-- and both operands are assumed to be 2-D.
matmul a b = match a
  | Tensor a_data a_shape a_strides -> match b
    | Tensor b_data b_shape b_strides ->
      let m_dim = head a_shape
      let k_dim = head (tail a_shape)
      let n_dim = head (tail b_shape)
      if gpu_available 0 then
        -- GPU path: the dylib fills a fresh buffer; wrap it as an (M,N) tensor.
        let c_data = gpu_matmul_dispatch a_data b_data m_dim k_dim n_dim
        Tensor c_data (cons m_dim (cons n_dim [])) (compute_strides (cons m_dim (cons n_dim [])))
      else
        -- CPU path: triple loop writing into a zero tensor; `acc` is the
        -- one-slot float accumulator shared by the inner loops.
        let c = tensor_zeros (cons m_dim (cons n_dim []))
        let acc = float_arr_new 1 0.0
        match c
          | Tensor c_data c_shape c_strides ->
            let _ = matmul_i a_data b_data c_data acc m_dim n_dim k_dim 0
            c

-- fp16 matmul: a(M,K) @ b(K,N) -> c(M,N) via tgl_matmul_f16, wrapped as a
-- Tensor. The dylib stages f64→f16 on the way in and f16→f64 on the way out;
-- the accumulator stays fp32 inside the kernel. GPU-only: there is no CPU
-- fallback, so callers that need graceful degradation should gate on
-- gpu_available. ~1.7-1.8× speedup vs tgl_matmul_f64 on M1 Ultra per
-- fp16_drafts/RESULTS.md; Phase 4b training uses this in
-- lm_v3_chunked_fp16.rail. The kernel's return code is not inspected.
matmul_f16 a b = match a
  | Tensor a_data a_shape a_strides -> match b
    | Tensor b_data b_shape b_strides ->
      let rows = head a_shape
      let inner = head (tail a_shape)
      let cols = head (tail b_shape)
      let out_shape = cons rows (cons cols [])
      let out = float_arr_new (rows * cols) 0.0
      let _ = tgl_matmul_f16 a_data b_data out rows inner cols
      Tensor out out_shape (compute_strides out_shape)

-- bf16 matmul (added 2026-05-14): a(M,K) @ b(K,N) -> c(M,N), same Rail
-- signature as matmul_f16. The dylib stages f64→bf16 on input and bf16→f64
-- on output; the accumulator stays fp32 inside the kernel. bf16's exponent
-- range matches f32 (~3.4e38), so fp16's step-2759 overflow simply cannot
-- happen here. The kernel's return code is not inspected.
matmul_bf16 a b = match a
  | Tensor a_data a_shape a_strides -> match b
    | Tensor b_data b_shape b_strides ->
      let rows = head a_shape
      let inner = head (tail a_shape)
      let cols = head (tail b_shape)
      let out_shape = cons rows (cons cols [])
      let out = float_arr_new (rows * cols) 0.0
      let _ = tgl_matmul_bf16 a_data b_data out rows inner cols
      Tensor out out_shape (compute_strides out_shape)

-- fp16 fused matmul+bias+relu: out[i,j] = max(0, sum_k a[i,k]*b[k,j] + bias[j]).
-- `bias` is a float_arr of length n (not a Tensor); it stays fp32 on the GPU
-- per the fp32-bias hint that landed the fused kernels. Returns an (M,N)
-- Tensor; the kernel's return code is not inspected.
matmul_bias_relu_f16 a b bias = match a
  | Tensor a_data a_shape a_strides -> match b
    | Tensor b_data b_shape b_strides ->
      let rows = head a_shape
      let inner = head (tail a_shape)
      let cols = head (tail b_shape)
      let out_shape = cons rows (cons cols [])
      let out = float_arr_new (rows * cols) 0.0
      let _ = tgl_matmul_bias_relu_f16 a_data b_data bias out rows inner cols
      Tensor out out_shape (compute_strides out_shape)

-- fp16 fused matmul+bias+gelu: same calling pattern as matmul_bias_relu_f16
-- with GELU as the activation. `bias` is a float_arr, not a Tensor. Returns
-- an (M,N) Tensor; the kernel's return code is not inspected.
matmul_bias_gelu_f16 a b bias = match a
  | Tensor a_data a_shape a_strides -> match b
    | Tensor b_data b_shape b_strides ->
      let rows = head a_shape
      let inner = head (tail a_shape)
      let cols = head (tail b_shape)
      let out_shape = cons rows (cons cols [])
      let out = float_arr_new (rows * cols) 0.0
      let _ = tgl_matmul_bias_gelu_f16 a_data b_data bias out rows inner cols
      Tensor out out_shape (compute_strides out_shape)

-- ============================================================
-- HalfTensor — f16-native representation
-- ============================================================

-- Packed-half storage size: one 8-byte float_arr slot holds 4 halfs.
-- For n half elements we need ceil(n/4) slots. Slot 0 of the underlying
-- float_arr still holds the slot-count header (set by float_arr_new).
-- (n + 3) / 4 is the integer form of ceil(n / 4) for n >= 0.
half_storage_slots n = (n + 3) / 4

-- Allocate a zero-filled HalfTensor with the given shape. The backing
-- float_arr is created with 0.0 in every slot, i.e. the all-zero bit
-- pattern, which in fp16 is 0.0 — so this is effectively half_zeros, the
-- counterpart of tensor_zeros for Tensor.
half_tensor_new shape =
  let n = list_product shape
  let data = float_arr_new (half_storage_slots n) 0.0
  HalfTensor data shape (compute_strides shape)

-- Convert an f64 Tensor into a packed HalfTensor (host-side, one cast via
-- tgl_f64_to_half). Meant for init, checkpoint load and eval boundaries,
-- NOT for use inside the training loop. Shape and strides are carried
-- over from the input unchanged.
half_of_tensor t =
  match t | Tensor src shape strides ->
    let count = list_product shape
    let packed = float_arr_new (half_storage_slots count) 0.0
    let _ = tgl_f64_to_half src packed count
    HalfTensor packed shape strides

-- Convert a packed HalfTensor back into an f64 Tensor (inverse of
-- half_of_tensor, via tgl_half_to_f64). Used at measurement and
-- checkpoint-save time. Shape and strides are carried over unchanged.
tensor_of_half h =
  match h | HalfTensor packed shape strides ->
    let count = list_product shape
    let wide = float_arr_new count 0.0
    let _ = tgl_half_to_f64 packed wide count
    Tensor wide shape strides

-- fp16 matmul on HalfTensors that are already packed. No cast happens at
-- the host boundary: memcpy in, the GPU kernel runs on fp16, memcpy out.
-- This is the Phase 4b follow-on: the cast is paid once at
-- init/checkpoint instead of on every call.
-- m and k come from a's shape, n from b's second dimension. The result
-- is a new [m, n] HalfTensor.
matmul_half a b =
  match a | HalfTensor lhs lhs_shape _ ->
    match b | HalfTensor rhs rhs_shape _ ->
      let rows = head lhs_shape
      let inner = head (tail lhs_shape)
      let cols = head (tail rhs_shape)
      let slots = half_storage_slots (rows * cols)
      let out = float_arr_new slots 0.0
      let _ = tgl_matmul_half_host lhs rhs out rows inner cols
      let out_shape = cons rows (cons cols [])
      HalfTensor out out_shape (compute_strides out_shape)

-- Mixed-precision matmul: f64 activations x fp16 weights -> f64.
-- x arrives as a Tensor (f64) and the result leaves as a Tensor (f64), so
-- Rail-side code never handles fp32 directly. On the GPU the activations
-- are kept in fp32 and the weights in fp16, with the dot product run in
-- fp32 (Apple Silicon GPU's native compute precision).
-- This is the precision-preserving sibling of matmul_half: HalfTensor
-- weights compose with f64 activations across block boundaries without
-- the half-cast precision loss that flipped argmax in v3_half.
-- m and k come from x's shape, n from w_h's second dimension.
matmul_mixed x w_h =
  match x | Tensor acts acts_shape _ ->
    match w_h | HalfTensor weights w_shape _ ->
      let rows = head acts_shape
      let inner = head (tail acts_shape)
      let cols = head (tail w_shape)
      let out = float_arr_new (rows * cols) 0.0
      let _ = tgl_matmul_f32x_halfw_host acts weights out rows inner cols
      let out_shape = cons rows (cons cols [])
      Tensor out out_shape (compute_strides out_shape)

-- Element-wise add of two packed HalfTensors. Same zero-cast contract as
-- matmul_half: storage stays fp16 and the GPU keeps an fp32 accumulator
-- per element. The element count, shape and strides of the result are
-- taken from a; b's shape is not inspected here.
add_half a b =
  match a | HalfTensor lhs shape strides ->
    match b | HalfTensor rhs _ _ ->
      let count = list_product shape
      let out = float_arr_new (half_storage_slots count) 0.0
      let _ = tgl_add_half_host lhs rhs out count
      HalfTensor out shape strides

-- Multiply every element of a HalfTensor by the scalar s.
-- The scalar is passed to the foreign call as a 1-element float_arr (ABI
-- match with tgl_scale_f64); the GPU narrows it to fp32. The result keeps
-- the shape and strides of t.
scale_half t s =
  match t | HalfTensor src shape strides ->
    let count = list_product shape
    let factor = float_arr_new 1 s
    let out = float_arr_new (half_storage_slots count) 0.0
    let _ = tgl_scale_half_host src out factor count
    HalfTensor out shape strides

-- Transpose a 2-D HalfTensor: shape [M, N] becomes [N, M]. Storage stays
-- fp16, so the op is bit-exact against its f64 cast-path counterpart (it
-- only permutes elements, no arithmetic).
transpose_half t =
  match t | HalfTensor src shape _ ->
    let rows = head shape
    let cols = head (tail shape)
    let dst = float_arr_new (half_storage_slots (rows * cols)) 0.0
    let _ = tgl_transpose_half_host src dst rows cols
    let flipped = cons cols (cons rows [])
    HalfTensor dst flipped (compute_strides flipped)

-- Row-wise softmax over a 2-D HalfTensor.
-- Per the kernel notes, the GPU does the max-subtract and exp-sum in
-- fp32, so exp is bounded by fp32's overflow threshold (~88) rather than
-- by fp16 arithmetic. The output is stored in fp16 and lies in [0, 1] by
-- construction. The result keeps the shape and strides of t.
softmax_half t =
  match t | HalfTensor src shape strides ->
    let n_rows = head shape
    let n_cols = head (tail shape)
    let slots = half_storage_slots (n_rows * n_cols)
    let out = float_arr_new slots 0.0
    let _ = tgl_softmax_half_host src out n_rows n_cols
    HalfTensor out shape strides

-- ============================================================
-- Element-wise ops (explicit loops, no function passing)
-- ============================================================

-- Add loop: for every index k in [i, n) stores a_data[k] + b_data[k]
-- into c_data[k]. Start with i = 0 to cover the whole buffer.
-- Returns 0; the result is delivered through c_data.
add_loop a_data b_data c_data n i =
  if i >= n then 0
  else
    let total = float_arr_get a_data i + float_arr_get b_data i
    let _ = float_arr_set c_data i total
    add_loop a_data b_data c_data n (i + 1)

-- Sub loop: for every index k in [i, n) stores a_data[k] - b_data[k]
-- into c_data[k]. Returns 0; the result is delivered through c_data.
sub_loop a_data b_data c_data n i =
  if i >= n then 0
  else
    let diff = float_arr_get a_data i - float_arr_get b_data i
    let _ = float_arr_set c_data i diff
    sub_loop a_data b_data c_data n (i + 1)

-- Hadamard loop: for every index k in [i, n) stores the element-wise
-- product a_data[k] * b_data[k] into c_data[k]. Returns 0; the result is
-- delivered through c_data.
hmul_loop a_data b_data c_data n i =
  if i >= n then 0
  else
    let prod = float_arr_get a_data i * float_arr_get b_data i
    let _ = float_arr_set c_data i prod
    hmul_loop a_data b_data c_data n (i + 1)

-- Scalar multiply loop: for every index k in [i, n) stores src[k] * s
-- into dst[k]. Returns 0; the result is delivered through dst.
scale_loop src dst n i s =
  if i >= n then 0
  else
    let _ = float_arr_set dst i (float_arr_get src i * s)
    scale_loop src dst n (i + 1) s

-- GPU dispatch for a binary element-wise op (add, mul).
-- If gpu_daemon_available 0 holds, the request is piped to the daemon on
-- localhost:9300 through `printf ... | nc` and the reply is parsed
-- inline. Otherwise both operands are written to fixed files under /tmp,
-- the GPU binary is run on them, the result file is read back and
-- gpu_cleanup is called.
-- op: operation name placed verbatim in the command; n: element count.
gpu_binop_dispatch op a_data b_data n =
  if gpu_daemon_available 0 then
    let lhs_txt = gpu_format_floats a_data n
    let rhs_txt = gpu_format_floats b_data n
    let payload = cons lhs_txt (cons "\\n" (cons rhs_txt (cons "\\nEND\\n' | nc -w2 localhost 9300" [])))
    let request = join "" (cons "printf '" (cons op (cons " " (cons (show n) (cons "\\n" payload)))))
    gpu_parse_inline (shell request) n
  else
    let _ = gpu_write_floats "/tmp/rail_tg_a.txt" a_data n
    let _ = gpu_write_floats "/tmp/rail_tg_b.txt" b_data n
    let file_args = cons (show n) (cons " /tmp/rail_tg_a.txt /tmp/rail_tg_b.txt /tmp/rail_tg_c.txt" [])
    let invocation = join "" (cons gpu_binary_path (cons " " (cons op (cons " " file_args))))
    let _ = shell invocation
    let output = gpu_read_floats "/tmp/rail_tg_c.txt" n
    let _ = gpu_cleanup 0
    output

-- GPU dispatch for a unary element-wise op (relu, exp, tanh, sigmoid).
-- Daemon path: the request goes to localhost:9300 via `printf ... | nc`
-- and the reply is parsed inline. File path: the input is written to
-- /tmp/rail_tg_a.txt, the GPU binary is run, /tmp/rail_tg_c.txt is read
-- back and gpu_cleanup is called.
-- op: operation name placed verbatim in the command; n: element count.
gpu_unop_dispatch op src_data n =
  if gpu_daemon_available 0 then
    let src_txt = gpu_format_floats src_data n
    let payload = cons src_txt (cons "\\nEND\\n' | nc -w2 localhost 9300" [])
    let request = join "" (cons "printf '" (cons op (cons " " (cons (show n) (cons "\\n" payload)))))
    gpu_parse_inline (shell request) n
  else
    let _ = gpu_write_floats "/tmp/rail_tg_a.txt" src_data n
    let file_args = cons (show n) (cons " /tmp/rail_tg_a.txt /tmp/rail_tg_c.txt" [])
    let invocation = join "" (cons gpu_binary_path (cons " " (cons op (cons " " file_args))))
    let _ = shell invocation
    let output = gpu_read_floats "/tmp/rail_tg_c.txt" n
    let _ = gpu_cleanup 0
    output

-- tensor_add: element-wise a + b.
-- Uses gpu_binop_ffi "add" when gpu_available 0 holds, otherwise the CPU
-- add_loop. The element count, shape and strides come from a; b's shape
-- is not checked against a's.
tensor_add a b =
  match a | Tensor lhs shape strides ->
    match b | Tensor rhs _ _ ->
      let count = list_product shape
      if gpu_available 0 then
        let summed = gpu_binop_ffi "add" lhs rhs count
        Tensor summed shape strides
      else
        let out = tensor_zeros shape
        match out | Tensor out_data _ _ ->
          let _ = add_loop lhs rhs out_data count 0
          out

-- tensor_sub: element-wise a - b, always on the CPU via sub_loop.
-- The element count and output shape come from a; b's shape is not
-- checked against a's.
tensor_sub a b =
  match a | Tensor lhs shape _ ->
    match b | Tensor rhs _ _ ->
      let count = list_product shape
      let out = tensor_zeros shape
      match out | Tensor out_data _ _ ->
        let _ = sub_loop lhs rhs out_data count 0
        out

-- tensor_mul: element-wise (Hadamard) product a * b.
-- Uses gpu_binop_ffi "mul" when gpu_available 0 holds, otherwise the CPU
-- hmul_loop. The element count, shape and strides come from a; b's shape
-- is not checked against a's.
tensor_mul a b =
  match a | Tensor lhs shape strides ->
    match b | Tensor rhs _ _ ->
      let count = list_product shape
      if gpu_available 0 then
        let prod = gpu_binop_ffi "mul" lhs rhs count
        Tensor prod shape strides
      else
        let out = tensor_zeros shape
        match out | Tensor out_data _ _ ->
          let _ = hmul_loop lhs rhs out_data count 0
          out

-- tensor_scale: multiply every element of t by the scalar s.
-- Uses gpu_scale_ffi when gpu_available 0 holds, otherwise the CPU
-- scale_loop. The result has the same shape as t.
tensor_scale t s =
  match t | Tensor src shape strides ->
    let count = list_product shape
    if gpu_available 0 then
      let scaled = gpu_scale_ffi src s count
      Tensor scaled shape strides
    else
      let out = tensor_zeros shape
      match out | Tensor out_data _ _ ->
        let _ = scale_loop src out_data count 0 s
        out

-- ============================================================
-- Activation functions (element-wise, explicit loops)
-- ============================================================

-- ReLU loop: for every index k in [i, n) stores max(0, src[k]) into
-- dst[k] (values that are not > 0.0 become 0.0). Returns 0; the result is
-- delivered through dst.
relu_loop src dst n i =
  if i >= n then 0
  else
    let x = float_arr_get src i
    let _ = float_arr_set dst i (if x > 0.0 then x else 0.0)
    relu_loop src dst n (i + 1)

-- tensor_relu: element-wise ReLU. Uses gpu_unop_ffi "relu" when
-- gpu_available 0 holds, otherwise the CPU relu_loop. The result has the
-- same shape as t.
tensor_relu t =
  match t | Tensor src shape strides ->
    let count = list_product shape
    if gpu_available 0 then
      Tensor (gpu_unop_ffi "relu" src count) shape strides
    else
      let out = tensor_zeros shape
      match out | Tensor out_data _ _ ->
        let _ = relu_loop src out_data count 0
        out

-- GELU loop (tanh approximation), for every index in [i, n):
--   dst = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
-- Returns 0; the result is delivered through dst.
gelu_loop src dst n i =
  if i >= n then 0
  else
    let pi_val = 3.14159265358979
    let k = sqrt (2.0 / pi_val)
    let x = float_arr_get src i
    let arg = k * (x + 0.044715 * x * x * x)
    let _ = float_arr_set dst i (0.5 * x * (1.0 + tanh arg))
    gelu_loop src dst n (i + 1)

-- tensor_gelu: element-wise GELU (tanh approximation), always on the CPU
-- via gelu_loop. The result has the same shape as t.
tensor_gelu t =
  match t | Tensor src shape _ ->
    let count = list_product shape
    let out = tensor_zeros shape
    match out | Tensor out_data _ _ ->
      let _ = gelu_loop src out_data count 0
      out

-- GPU-accelerated element-wise ops: exp, tanh, sigmoid.
-- tensor_exp_gpu: element-wise exp. Uses gpu_unop_ffi "exp" when
-- gpu_available 0 holds, otherwise the CPU exp_loop.
tensor_exp_gpu t =
  match t | Tensor src shape strides ->
    let count = list_product shape
    if gpu_available 0 then
      let mapped = gpu_unop_ffi "exp" src count
      Tensor mapped shape strides
    else
      let out = tensor_zeros shape
      match out
        | Tensor out_data _ _ ->
          let _ = exp_loop src out_data count 0
          out

-- exp loop: for every index k in [i, n) stores exp(src[k]) into dst[k].
-- Returns 0; the result is delivered through dst.
exp_loop src dst n i =
  if i >= n then 0
  else
    let x = float_arr_get src i
    let _ = float_arr_set dst i (exp x)
    exp_loop src dst n (i + 1)

-- tensor_tanh_gpu: element-wise tanh. Uses gpu_unop_ffi "tanh_fwd" when
-- gpu_available 0 holds, otherwise the CPU tanh_loop.
tensor_tanh_gpu t =
  match t | Tensor src shape strides ->
    let count = list_product shape
    if gpu_available 0 then
      let mapped = gpu_unop_ffi "tanh_fwd" src count
      Tensor mapped shape strides
    else
      let out = tensor_zeros shape
      match out
        | Tensor out_data _ _ ->
          let _ = tanh_loop src out_data count 0
          out

-- tanh loop: for every index k in [i, n) stores tanh(src[k]) into
-- dst[k]. Returns 0; the result is delivered through dst.
tanh_loop src dst n i =
  if i >= n then 0
  else
    let x = float_arr_get src i
    let _ = float_arr_set dst i (tanh x)
    tanh_loop src dst n (i + 1)

-- tensor_sigmoid: element-wise logistic sigmoid. Uses gpu_unop_ffi
-- "sigmoid" when gpu_available 0 holds, otherwise the CPU sigmoid_loop.
tensor_sigmoid t =
  match t | Tensor src shape strides ->
    let count = list_product shape
    if gpu_available 0 then
      let mapped = gpu_unop_ffi "sigmoid" src count
      Tensor mapped shape strides
    else
      let out = tensor_zeros shape
      match out
        | Tensor out_data _ _ ->
          let _ = sigmoid_loop src out_data count 0
          out

-- sigmoid loop: for every index k in [i, n) stores
-- 1 / (1 + exp(-src[k])) into dst[k]. Returns 0; the result is delivered
-- through dst.
sigmoid_loop src dst n i =
  if i >= n then 0
  else
    let x = float_arr_get src i
    let denom = 1.0 + exp (0.0 - x)
    let _ = float_arr_set dst i (1.0 / denom)
    sigmoid_loop src dst n (i + 1)

-- ============================================================
-- Reductions
-- ============================================================

-- Sum loop: adds data[k] for every k in [i, n) onto acc_arr[0], which
-- serves as a mutable accumulator. The caller initialises acc_arr[0] and
-- reads the total from it afterwards. Returns 0.
sum_loop data n i acc_arr =
  if i >= n then 0
  else
    let running = float_arr_get acc_arr 0
    let _ = float_arr_set acc_arr 0 (running + float_arr_get data i)
    sum_loop data n (i + 1) acc_arr

-- tensor_sum: sum of all elements of t, returned as a float.
tensor_sum t =
  match t | Tensor src shape _ ->
    let count = list_product shape
    let total = float_arr_new 1 0.0
    let _ = sum_loop src count 0 total
    float_arr_get total 0

-- tensor_mean: arithmetic mean of all elements of t (sum divided by the
-- element count), returned as a float.
tensor_mean t =
  match t | Tensor src shape _ ->
    let count = list_product shape
    let total = float_arr_new 1 0.0
    let _ = sum_loop src count 0 total
    float_arr_get total 0 / to_float count

-- ============================================================
-- Autograd tensor primitives (13 ops required by autograd.rail)
-- ============================================================

-- ReLU mask loop (helper for tensor_relu_mask): for every index k in
-- [i, n) stores 1.0 into dst[k] where src[k] > 0 and 0.0 otherwise.
-- Returns 0; the result is delivered through dst.
relu_mask_loop src dst n i =
  if i >= n then 0
  else
    let x = float_arr_get src i
    let flag = if x > 0.0 then 1.0 else 0.0
    let _ = float_arr_set dst i flag
    relu_mask_loop src dst n (i + 1)

-- tensor_relu_mask: tensor of the same shape as t holding 1.0 where the
-- input is > 0 and 0.0 elsewhere.
tensor_relu_mask t =
  match t | Tensor src shape _ ->
    let count = list_product shape
    let out = tensor_zeros shape
    match out
      | Tensor out_data _ _ ->
        let _ = relu_mask_loop src out_data count 0
        out

-- GELU backward loop (helper for tensor_gelu_backward): for every index
-- in [i, n) stores upstream_grad * dGELU/dx, using the derivative of the
-- tanh approximation:
--   u      = sqrt(2/pi) * (x + 0.044715 * x^3)
--   dGELU  = 0.5 * (1 + tanh u)
--          + 0.5 * x * (1 - tanh(u)^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
-- Returns 0; the result is delivered through dst.
gelu_back_loop x_data grad_data dst n i =
  if i >= n then 0
  else
    let pi_val = 3.14159265358979
    let k = sqrt (2.0 / pi_val)
    let x = float_arr_get x_data i
    let upstream = float_arr_get grad_data i
    let th = tanh (k * (x + 0.044715 * x * x * x))
    let sech2 = 1.0 - th * th
    let d_arg = k * (1.0 + 3.0 * 0.044715 * x * x)
    let local = 0.5 * (1.0 + th) + 0.5 * x * sech2 * d_arg
    let _ = float_arr_set dst i (upstream * local)
    gelu_back_loop x_data grad_data dst n (i + 1)

-- tensor_gelu_backward: gradient of GELU with respect to its input,
-- multiplied element-wise by the upstream gradient. The result has x's
-- shape; grad's shape is not checked against it.
tensor_gelu_backward x grad =
  match x | Tensor x_data shape _ ->
    match grad | Tensor upstream _ _ ->
      let count = list_product shape
      let out = tensor_zeros shape
      match out
        | Tensor out_data _ _ ->
          let _ = gelu_back_loop x_data upstream out_data count 0
          out

-- tensor_softmax: numerically stable softmax along last axis
-- For 2D [batch, classes]: softmax each row
--
-- softmax_row: softmax of one row of `cols` elements starting at `offset`
-- in src, written to the same positions in dst. Three passes share a
-- 1-element accumulator array: find the row max, write exp(x - max) while
-- summing, then divide by the sum. Returns the 0 from softmax_normalize.
softmax_row src dst cols offset =
  -- Seed the running max with the row's first element rather than a fixed
  -- sentinel: with a -999999 sentinel, a row whose entries are all below
  -- it keeps the sentinel as "max", every exp underflows to 0 and the
  -- normalisation divides by zero. The sentinel is kept only for an empty
  -- row, where no element can be read.
  let seed = if cols > 0 then float_arr_get src offset else 0.0 - 999999.0
  let acc = float_arr_new 1 seed
  let _ = softmax_find_max src cols offset acc 0
  let mx = float_arr_get acc 0
  -- Reuse the accumulator for the exp-sum.
  let _ = float_arr_set acc 0 0.0
  let _ = softmax_exp_sum src dst cols offset mx acc 0
  let sm = float_arr_get acc 0
  softmax_normalize dst cols offset sm 0

-- Softmax pass 1: raises acc[0] to the largest of src[offset + k] for k
-- in [j, cols). acc[0] must hold the starting maximum on entry.
-- Returns 0.
softmax_find_max src cols offset acc j =
  if j >= cols then 0
  else
    let best = float_arr_get acc 0
    let x = float_arr_get src (offset + j)
    let _ = if x > best then float_arr_set acc 0 x else 0
    softmax_find_max src cols offset acc (j + 1)

-- Softmax pass 2: for k in [j, cols) writes exp(src[offset+k] - mx) to
-- dst[offset+k] and adds it onto acc[0]. Returns 0.
softmax_exp_sum src dst cols offset mx acc j =
  if j >= cols then 0
  else
    let slot = offset + j
    let e = exp (float_arr_get src slot - mx)
    let _ = float_arr_set dst slot e
    let _ = float_arr_set acc 0 (float_arr_get acc 0 + e)
    softmax_exp_sum src dst cols offset mx acc (j + 1)

-- Softmax pass 3: divides dst[offset+k] by sm in place for k in
-- [j, cols). Returns 0.
softmax_normalize dst cols offset sm j =
  if j >= cols then 0
  else
    let slot = offset + j
    let _ = float_arr_set dst slot (float_arr_get dst slot / sm)
    softmax_normalize dst cols offset sm (j + 1)

-- Applies softmax_row to each row in [r, rows) of a row-major
-- [rows, cols] buffer. Returns 0; the result is delivered through dst.
softmax_rows src dst rows cols r =
  if r >= rows then 0
  else
    let row_start = r * cols
    let _ = softmax_row src dst cols row_start
    softmax_rows src dst rows cols (r + 1)

-- GPU softmax dispatch over a [rows, cols] buffer.
-- Daemon path: the request goes to localhost:9300 via `printf ... | nc`
-- and the reply is parsed inline. File path: the input is written to
-- /tmp/rail_tg_a.txt, the GPU binary is run, /tmp/rail_tg_c.txt is read
-- back and gpu_cleanup is called.
gpu_softmax_dispatch src_data rows cols =
  let count = rows * cols
  if gpu_daemon_available 0 then
    let src_txt = gpu_format_floats src_data count
    let payload = cons src_txt (cons "\\nEND\\n' | nc -w2 localhost 9300" [])
    let request = join "" (cons "printf 'softmax " (cons (show rows) (cons " " (cons (show cols) (cons "\\n" payload)))))
    gpu_parse_inline (shell request) count
  else
    let _ = gpu_write_floats "/tmp/rail_tg_a.txt" src_data count
    let dims = cons (show rows) (cons " " (cons (show cols) (cons " /tmp/rail_tg_a.txt /tmp/rail_tg_c.txt" [])))
    let invocation = join "" (cons gpu_binary_path (cons " softmax " dims))
    let _ = shell invocation
    let output = gpu_read_floats "/tmp/rail_tg_c.txt" count
    let _ = gpu_cleanup 0
    output

-- tensor_softmax: softmax along the last axis. The buffer is treated as
-- [rows, cols] with cols = last dimension and rows = total / cols (for an
-- empty shape, cols = total). Uses gpu_softmax_ffi when gpu_available 0
-- holds, otherwise the CPU softmax_rows. The result has t's shape.
tensor_softmax t =
  match t | Tensor src shape strides ->
    let total = list_product shape
    let width = if length shape >= 1 then head (reverse shape) else total
    let height = total / width
    if gpu_available 0 then
      let probs = gpu_softmax_ffi src height width
      Tensor probs shape strides
    else
      let out = tensor_zeros shape
      match out
        | Tensor out_data _ _ ->
          let _ = softmax_rows src out_data height width 0
          out

-- Row-sum loop (helper for tensor_sum_last, [batch, N] -> [batch, 1]):
-- for each row in [r, rows) stores the sum of that row's `cols` elements
-- into dst[row]. A fresh 1-element accumulator is used per row.
-- Returns 0.
sum_last_loop src dst rows cols r =
  if r >= rows then 0
  else
    let row_acc = float_arr_new 1 0.0
    let _ = sum_last_row src cols (r * cols) row_acc 0
    let row_total = float_arr_get row_acc 0
    let _ = float_arr_set dst r row_total
    sum_last_loop src dst rows cols (r + 1)

-- Adds src[offset + k] for k in [j, cols) onto acc[0]. Returns 0.
sum_last_row src cols offset acc j =
  if j >= cols then 0
  else
    let running = float_arr_get acc 0
    let x = float_arr_get src (offset + j)
    let _ = float_arr_set acc 0 (running + x)
    sum_last_row src cols offset acc (j + 1)

-- tensor_sum_last: sum along the last axis. The input is viewed as
-- [rows, cols] with cols = last dimension; the result has shape
-- [rows, 1].
tensor_sum_last t =
  match t | Tensor src shape _ ->
    let width = head (reverse shape)
    let height = list_product shape / width
    let out = tensor_zeros (cons height (cons 1 []))
    match out
      | Tensor out_data _ _ ->
        let _ = sum_last_loop src out_data height width 0
        out

-- Column-sum loop (helper for tensor_sum_batch, [batch, N] -> [1, N]):
-- adds each row in [r, rows) element-wise onto dst[0..cols). dst must
-- hold the starting totals on entry. Returns 0.
sum_batch_loop src dst rows cols r =
  if r >= rows then 0
  else
    let row_start = r * cols
    let _ = sum_batch_add_row src dst cols row_start 0
    sum_batch_loop src dst rows cols (r + 1)

-- Adds src[offset + k] onto dst[k] for k in [j, cols). Returns 0.
sum_batch_add_row src dst cols offset j =
  if j >= cols then 0
  else
    let updated = float_arr_get dst j + float_arr_get src (offset + j)
    let _ = float_arr_set dst j updated
    sum_batch_add_row src dst cols offset (j + 1)

-- tensor_sum_batch: sum along the first axis. The input is viewed as
-- [rows, cols] with cols = last dimension; the result has shape
-- [1, cols].
tensor_sum_batch t =
  match t | Tensor src shape _ ->
    let width = head (reverse shape)
    let height = list_product shape / width
    let out = tensor_zeros (cons 1 (cons width []))
    match out
      | Tensor out_data _ _ ->
        let _ = sum_batch_loop src out_data height width 0
        out

-- tensor_mean_last: mean along the last axis, [batch, N] -> [batch, 1].
-- Computed as tensor_sum_last scaled by 1 / N.
tensor_mean_last t =
  match t | Tensor _ shape _ ->
    let width = head (reverse shape)
    let row_sums = tensor_sum_last t
    let inv_width = 1.0 / to_float width
    tensor_scale row_sums inv_width

-- Broadcast loop (helper for tensor_broadcast_last, [batch, 1] ->
-- [batch, N]): for each row in [r, rows) fills that row's `cols` slots in
-- dst with src[row]. Returns 0.
broadcast_last_loop src dst rows cols r =
  if r >= rows then 0
  else
    let row_start = r * cols
    let _ = broadcast_last_fill dst row_start cols (float_arr_get src r) 0
    broadcast_last_loop src dst rows cols (r + 1)

-- Writes v into dst[offset + k] for k in [j, cols). Returns 0.
broadcast_last_fill dst offset cols v j =
  if j >= cols then 0
  else
    let slot = offset + j
    let _ = float_arr_set dst slot v
    broadcast_last_fill dst offset cols v (j + 1)

-- tensor_broadcast_last: repeat each element of t `cols` times along a
-- new last axis. Every element of t becomes one output row, so the result
-- has shape [list_product(shape of t), cols].
tensor_broadcast_last t cols =
  match t | Tensor src shape _ ->
    let height = list_product shape
    let out = tensor_zeros (cons height (cons cols []))
    match out
      | Tensor out_data _ _ ->
        let _ = broadcast_last_loop src out_data height cols 0
        out

-- Broadcast-multiply loop (helper for tensor_mul_broadcast,
-- [batch, N] * [1, N]): for each row in [r, rows) multiplies that row of
-- a_data element-wise by b_data[0..cols) into dst. Returns 0.
mul_bcast_loop a_data b_data dst rows cols r =
  if r >= rows then 0
  else
    let row_start = r * cols
    let _ = mul_bcast_row a_data b_data dst cols row_start 0
    mul_bcast_loop a_data b_data dst rows cols (r + 1)

-- Stores a_data[offset + k] * b_data[k] into dst[offset + k] for k in
-- [j, cols). Returns 0.
mul_bcast_row a_data b_data dst cols offset j =
  if j >= cols then 0
  else
    let slot = offset + j
    let prod = float_arr_get a_data slot * float_arr_get b_data j
    let _ = float_arr_set dst slot prod
    mul_bcast_row a_data b_data dst cols offset (j + 1)

-- tensor_mul_broadcast: element-wise [batch, N] * [1, N], with b
-- broadcast over a's first dimension. Rows and columns are derived from
-- a's shape; b's shape is not inspected. The result has a's shape.
tensor_mul_broadcast a b =
  match a | Tensor lhs shape _ ->
    match b | Tensor rhs _ _ ->
      let width = head (reverse shape)
      let height = list_product shape / width
      let out = tensor_zeros shape
      match out
        | Tensor out_data _ _ ->
          let _ = mul_bcast_loop lhs rhs out_data height width 0
          out

-- One-hot loop (helper for tensor_one_hot, [batch] -> [batch, vocab]):
-- for each entry in [i, batch) converts indices[entry] to an int and sets
-- dst[entry * vocab + idx] to 1.0. Indices outside [0, vocab) are
-- skipped, leaving that row untouched. Returns 0.
one_hot_loop indices dst batch vocab i =
  if i >= batch then 0
  else
    let hot = float_to_int (float_arr_get indices i)
    let _ = if hot < vocab then if hot >= 0 then float_arr_set dst (i * vocab + hot) 1.0 else 0 else 0
    one_hot_loop indices dst batch vocab (i + 1)

-- tensor_one_hot: turn a tensor of indices (stored as floats) into a
-- [batch, vocab_size] one-hot tensor, where batch is the element count of
-- t. Rows for out-of-range indices stay all zero.
tensor_one_hot t vocab_size =
  match t | Tensor idx_data shape _ ->
    let batch = list_product shape
    let out = tensor_zeros (cons batch (cons vocab_size []))
    match out
      | Tensor out_data _ _ ->
        let _ = one_hot_loop idx_data out_data batch vocab_size 0
        out

-- Embedding lookup loop (helper for tensor_embedding_lookup, weights
-- [vocab, dim], indices [batch] -> [batch, dim]): for each entry in
-- [i, batch) copies weight row indices[entry] into output row `entry`.
-- The index is not range-checked here. Returns 0.
embed_lookup_loop weights indices dst batch dim i =
  if i >= batch then 0
  else
    let row = float_to_int (float_arr_get indices i)
    let from_off = row * dim
    let to_off = i * dim
    let _ = embed_copy_row weights dst dim from_off to_off 0
    embed_lookup_loop weights indices dst batch dim (i + 1)

-- Copies src[src_off + k] to dst[dst_off + k] for k in [j, dim).
-- Returns 0.
embed_copy_row src dst dim src_off dst_off j =
  if j >= dim then 0
  else
    let x = float_arr_get src (src_off + j)
    let _ = float_arr_set dst (dst_off + j) x
    embed_copy_row src dst dim src_off dst_off (j + 1)

-- tensor_embedding_lookup: gather rows of a [vocab, dim] weight tensor.
-- indices holds row numbers stored as floats; batch is its element count.
-- The result has shape [batch, dim]. Indices are not range-checked
-- against vocab, so the caller must supply valid row numbers.
tensor_embedding_lookup weights indices = match weights
  | Tensor w_data w_shape _ -> match indices
    | Tensor i_data i_shape _ ->
      -- Only the row width is needed; the vocab size was bound but never
      -- used, so that dead local has been dropped.
      let dim = head (tail w_shape)
      let batch = list_product i_shape
      let out = tensor_zeros (cons batch (cons dim []))
      match out | Tensor o_data _ _ ->
        let _ = embed_lookup_loop w_data i_data o_data batch dim 0
        out

-- tensor_slice_row: copy row row_idx of a 2-D [rows, cols] tensor into a
-- new [1, cols] tensor. row_idx is not range-checked.
tensor_slice_row t row_idx =
  match t | Tensor src shape _ ->
    let width = head (tail shape)
    let out = tensor_zeros (cons 1 (cons width []))
    match out
      | Tensor out_data _ _ ->
        let from_off = row_idx * width
        let _ = embed_copy_row src out_data width from_off 0 0
        out

-- Scatter-add loop (helper for tensor_accumulate_row, used for the
-- embedding backward pass): for each entry in [i, batch) adds grad row
-- `entry` onto weight row indices[entry], in place in `weights`.
-- The index is not range-checked here. Returns 0.
accum_row_loop grad weights indices batch dim i =
  if i >= batch then 0
  else
    let row = float_to_int (float_arr_get indices i)
    let w_off = row * dim
    let g_off = i * dim
    let _ = accum_row_add weights grad dim w_off g_off 0
    accum_row_loop grad weights indices batch dim (i + 1)

-- Adds src[src_off + k] onto dst[dst_off + k] in place for k in [j, dim).
-- Returns 0.
accum_row_add dst src dim dst_off src_off j =
  if j >= dim then 0
  else
    let slot = dst_off + j
    let updated = float_arr_get dst slot + float_arr_get src (src_off + j)
    let _ = float_arr_set dst slot updated
    accum_row_add dst src dim dst_off src_off (j + 1)

-- tensor_accumulate_row: returns a clone of weights with each row of grad
-- added onto the weight row named by the matching entry of indices.
-- weights itself is left untouched (the scatter-add runs on the clone).
tensor_accumulate_row weights grad indices =
  match weights | Tensor _ w_shape _ ->
    match grad | Tensor g_data _ _ ->
      match indices | Tensor i_data i_shape _ ->
        let dim = head (tail w_shape)
        let batch = list_product i_shape
        let updated = tensor_clone weights
        match updated
          | Tensor u_data _ _ ->
            let _ = accum_row_loop g_data u_data i_data batch dim 0
            updated

-- tensor_scale_by_loss: multiply every element of t by the scalar loss
-- gradient. Thin wrapper over tensor_scale.
tensor_scale_by_loss t loss_grad =
  tensor_scale t loss_grad

-- ============================================================
-- Additional autograd utilities
-- ============================================================

-- tensor_copy: same as tensor_clone (deep copy), under another name.
tensor_copy t =
  tensor_clone t

-- tensor_matmul: same as matmul; autograd refers to it by this name.
tensor_matmul a b =
  matmul a b

-- tensor_sum_all: same as tensor_sum (sum over every element).
tensor_sum_all t =
  tensor_sum t

-- tensor_ones_like_shape: tensor of ones with the given shape; forwards
-- to tensor_ones.
tensor_ones_like_shape shape =
  tensor_ones shape

-- tensor_scalar_val: read the value of a 1-element tensor (flat index 0).
tensor_scalar_val t =
  tensor_get_flat t 0

-- tensor_from_int: build a 1-element tensor of shape [1] holding n
-- converted to float.
tensor_from_int n =
  let cell = tensor_new (cons 1 []) 0.0
  let as_float = to_float n
  let _ = tensor_set_flat cell 0 as_float
  cell

-- tensor_to_int: read flat element 0 of t and convert it to an int with
-- float_to_int.
tensor_to_int t =
  let first = tensor_get_flat t 0
  float_to_int first

-- tensor_get_int: read flat element i of t as an int (for indexing).
tensor_get_int t i =
  let x = tensor_get_flat t i
  float_to_int x

-- tensor_map_scalar: apply a scalar function to all elements.
-- Since Rail lambdas can segfault, the function is chosen by a tag string
-- and dispatched here instead of being passed as a value. Use named
-- functions instead.
--
-- map_scalar_loop: for every index in [i, n) reads src, applies the op
-- named by fn_tag and writes the result to dst. Returns 0; the result is
-- delivered through dst.
-- Recognised tags: "neg", "abs", "sq", "sqrt", "exp", "log", "inv".
-- Any other tag copies the value through unchanged.
map_scalar_loop src dst n i fn_tag =
  if i >= n then 0
  else
    let v = float_arr_get src i
    -- The tag is re-compared for every element; the first match wins.
    let r = if fn_tag == "neg" then 0.0 - v
            else if fn_tag == "abs" then (if v < 0.0 then 0.0 - v else v)
            else if fn_tag == "sq" then v * v
            else if fn_tag == "sqrt" then sqrt v
            else if fn_tag == "exp" then exp v
            else if fn_tag == "log" then log v
            else if fn_tag == "inv" then 1.0 / v
            else v
    let _ = float_arr_set dst i r
    map_scalar_loop src dst n (i + 1) fn_tag

-- tensor_map_scalar: apply the scalar op named by fn_tag to every element
-- of t and return a new tensor of the same shape. The dispatch on the tag
-- is done in map_scalar_loop.
tensor_map_scalar fn_tag t =
  match t | Tensor src shape _ ->
    let count = list_product shape
    let out = tensor_zeros shape
    match out
      | Tensor out_data _ _ ->
        let _ = map_scalar_loop src out_data count 0 fn_tag
        out

-- Cross-entropy loop (helper for tensor_cross_entropy_loss).
-- probs: [batch, vocab] (after softmax), targets: [batch] (indices stored
-- as floats). For each sample in [i, batch) subtracts log(p) of the
-- target class from acc, with p floored at 0.000001 so log never sees a
-- value at or below zero. Returns the accumulated (un-averaged) loss.
-- The target index is not range-checked here.
ce_loss_loop probs targets batch vocab i acc =
  if i >= batch then acc
  else
    let target = float_to_int (float_arr_get targets i)
    let p = float_arr_get probs (i * vocab + target)
    let floored = if p > 0.000001 then p else 0.000001
    ce_loss_loop probs targets batch vocab (i + 1) (acc - log floored)

-- tensor_cross_entropy_loss: mean negative log-probability of the target
-- class, -sum(log p[target]) / batch. probs is [batch, vocab] (after
-- softmax), targets holds one class index per sample. Returns a 1-element
-- tensor of shape [1].
tensor_cross_entropy_loss probs targets =
  match probs | Tensor p_data p_shape _ ->
    match targets | Tensor t_data _ _ ->
      let batch = head p_shape
      let vocab = head (tail p_shape)
      let total = ce_loss_loop p_data t_data batch vocab 0 0.0
      let out = tensor_new (cons 1 []) 0.0
      let _ = tensor_set_flat out 0 (total / to_float batch)
      out

-- ============================================================
-- Shape operations
-- ============================================================

-- Transpose for 2-D tensors: copies the data into a new row-major layout.
-- This is a REAL transpose (data copy), not a stride swap, because matmul
-- assumes contiguous row-major data.
--
-- transpose_copy_loop: for each source row in [i, rows) scatters that row
-- into column `row` of dst. Returns 0.
transpose_copy_loop src dst rows cols i =
  if i >= rows then 0
  else
    let next = i + 1
    let _ = transpose_copy_row src dst rows cols i 0
    transpose_copy_loop src dst rows cols next

-- Copies src[i * cols + k] to dst[k * rows + i] for k in [j, cols),
-- i.e. source row i becomes destination column i. Returns 0.
transpose_copy_row src dst rows cols i j =
  if j >= cols then 0
  else
    let from_idx = i * cols + j
    let to_idx = j * rows + i
    let _ = float_arr_set dst to_idx (float_arr_get src from_idx)
    transpose_copy_row src dst rows cols i (j + 1)

-- GPU transpose dispatch for a row-major [rows, cols] buffer.
-- The daemon protocol for transpose is not wired, so this always uses
-- file mode: the input is written to /tmp/rail_tg_a.txt, the GPU binary
-- is run with "transpose rows cols", /tmp/rail_tg_c.txt is read back and
-- gpu_cleanup is called. Returns the rows * cols floats read back.
--
-- The earlier version branched on gpu_daemon_available 0, but both
-- branches were the same file-mode code; the dead conditional (and its
-- probe call) has been removed.
gpu_transpose_dispatch src_data rows cols =
  let n = rows * cols
  let _ = gpu_write_floats "/tmp/rail_tg_a.txt" src_data n
  let cmd = join "" (cons gpu_binary_path (cons " transpose " (cons (show rows) (cons " " (cons (show cols) (cons " /tmp/rail_tg_a.txt /tmp/rail_tg_c.txt" []))))))
  let _ = shell cmd
  let r = gpu_read_floats "/tmp/rail_tg_c.txt" n
  let _ = gpu_cleanup 0
  r

-- tensor_transpose: transpose a 2-D [rows, cols] tensor into a new
-- contiguous [cols, rows] tensor. Uses gpu_transpose_ffi when
-- gpu_available 0 holds, otherwise the CPU transpose_copy_loop.
tensor_transpose t =
  match t | Tensor src shape _ ->
    let rows = head shape
    let cols = head (tail shape)
    let flipped = cons cols (cons rows [])
    if gpu_available 0 then
      let moved = gpu_transpose_ffi src rows cols
      Tensor moved flipped (compute_strides flipped)
    else
      let moved = float_arr_new (rows * cols) 0.0
      let _ = transpose_copy_loop src moved rows cols 0
      Tensor moved flipped (compute_strides flipped)

-- Reshape: returns a tensor with new_shape over the SAME data buffer (no
-- copy), with strides recomputed for the new shape. The caller must
-- ensure the total element count is unchanged.
tensor_reshape t new_shape =
  match t | Tensor buffer _ _ ->
    Tensor buffer new_shape (compute_strides new_shape)

-- ============================================================
-- Phase-1 gap fills (2026-04-16)
-- rank, general slice, layer_norm forward
-- ============================================================

-- tensor_rank: number of dimensions, i.e. the length of the shape list.
tensor_rank t =
  match t | Tensor _ dims _ ->
    length dims

-- Helper: copies src[src_off + k] to dst[k] for k in [i, n), so the
-- destination always starts at offset 0. Returns 0.
copy_range src dst src_off n i =
  if i >= n then 0
  else
    let x = float_arr_get src (src_off + i)
    let _ = float_arr_set dst i x
    copy_range src dst src_off n (i + 1)

-- tensor_slice: extract a contiguous slice along axis 0 (outermost).
-- For a tensor of shape [D0, D1, D2, ...] the result has shape
-- [end - start, D1, D2, ...] and owns a fresh copy of the data.
-- start is inclusive, end is exclusive. No bounds checking is done: the
-- caller must ensure 0 <= start < end <= D0.
-- This is the common case for attention KV-cache rolling, batch splits,
-- etc. For slicing along inner axes, compose with transpose first.
tensor_slice t start end = match t
  | Tensor t_data t_shape _ ->
    -- D0 itself is not needed (the range is trusted), so the unused
    -- binding for it has been dropped.
    let inner = tail t_shape
    let inner_size = list_product inner
    let new_d0 = end - start
    let new_shape = cons new_d0 inner
    let total = new_d0 * inner_size
    let out_data = float_arr_new total 0.0
    -- Row-major layout: slice `start` begins start * inner_size elements in.
    let src_off = start * inner_size
    let _ = copy_range t_data out_data src_off total 0
    Tensor out_data new_shape (compute_strides new_shape)

-- ============================================================
-- LayerNorm (forward)
--   y = gamma * (x - mean) / sqrt(var + eps) + beta
-- Normalizes along the LAST axis. For input shape [..., D]:
-- - rows = product of all but the last dim
-- - dim  = last dim
-- gamma and beta are 1D tensors of shape [D] (or scalars if null-shaped).
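-- Intended to line up with the tgl_layernorm_backward_f64 foreign contract
-- for the backward pass (reserved for Phase 4 autodiff). NOTE: the forward
-- code below returns only y; the per-row mean and rstd are computed inside
-- ln_row and discarded, so they are not yet returned as a triple alongside y.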
-- ============================================================

-- Inner: per-row normalization. Writes y[row*dim..(row+1)*dim] in place:
--   y = gamma * (x - mean) / sqrt(var + eps) + beta
-- where mean and var (biased, divided by dim) are taken over the row.
-- Returns 0 (float-result-via-mutation convention — see module header).
ln_row x y gamma beta row dim eps =
  let base = row * dim
  -- Convert the row length with to_float, as the rest of the file does,
  -- instead of relying on mixed int/float addition (0.0 + dim).
  let count = to_float dim
  let sum_arr = float_arr_new 1 0.0
  let _ = ln_row_sum x base dim 0 sum_arr
  let mean = (float_arr_get sum_arr 0) / count
  let sq_arr = float_arr_new 1 0.0
  let _ = ln_row_sq x base dim 0 mean sq_arr
  let var = (float_arr_get sq_arr 0) / count
  let rstd = 1.0 / (sqrt (var + eps))
  let _ = ln_row_apply x y gamma beta base dim 0 mean rstd
  0

-- Adds x[base + k] for k in [i, dim) onto acc[0]. Returns 0.
ln_row_sum x base dim i acc =
  if i >= dim then 0
  else
    let running = float_arr_get acc 0
    let _ = float_arr_set acc 0 (running + float_arr_get x (base + i))
    ln_row_sum x base dim (i + 1) acc

-- Adds the squared deviation (x[base + k] - mean)^2 for k in [i, dim)
-- onto acc[0]. Returns 0.
ln_row_sq x base dim i mean acc =
  if i >= dim then 0
  else
    let dev = float_arr_get x (base + i) - mean
    let running = float_arr_get acc 0
    let _ = float_arr_set acc 0 (running + dev * dev)
    ln_row_sq x base dim (i + 1) mean acc

-- For k in [i, dim) writes
--   y[base + k] = gamma[k] * ((x[base + k] - mean) * rstd) + beta[k].
-- gamma and beta are indexed per column, not per row. Returns 0.
ln_row_apply x y gamma beta base dim i mean rstd =
  if i >= dim then 0
  else
    let slot = base + i
    let normed = (float_arr_get x slot - mean) * rstd
    let scale = float_arr_get gamma i
    let shift = float_arr_get beta i
    let _ = float_arr_set y slot (scale * normed + shift)
    ln_row_apply x y gamma beta base dim (i + 1) mean rstd

-- Runs ln_row on every row in [r, rows). Returns 0; the result is
-- delivered through y.
ln_rows x y gamma beta rows dim eps r =
  if r >= rows then 0
  else
    let next = r + 1
    let _ = ln_row x y gamma beta r dim eps
    ln_rows x y gamma beta rows dim eps next

-- tensor_layer_norm: LayerNorm forward over the last axis.
-- The input is viewed as [rows, dim] with dim = last dimension and
-- rows = total / dim. gamma and beta are Tensors whose buffers are
-- indexed 0..dim-1. Returns a new Tensor with t's shape and strides.
tensor_layer_norm t gamma beta eps =
  match t | Tensor x_data shape strides ->
    let rank = length shape
    let dim = list_nth (rank - 1) shape
    let total = list_product shape
    let rows = total / dim
    let y = float_arr_new total 0.0
    match gamma
      | Tensor g_data _ _ ->
        match beta
          | Tensor b_data _ _ ->
            let _ = ln_rows x_data y g_data b_data rows dim eps 0
            Tensor y shape strides

-- ============================================================
-- Print utility for debugging
-- ============================================================

-- Print the entries t[i, j], t[i, j+1], ..., t[i, n_cols-1], one print call
-- per entry, each prefixed with two spaces. After the last column an empty
-- string is printed. Returns 0.
print_row t n_cols i j =
  if j >= n_cols then
    let _ = print ""
    0
  else
    let cell = cat ["  ", show_float (tensor_get t [i, j])]
    let _ = print cell
    print_row t n_cols i (j + 1)

-- Print rows i, i+1, ..., n_rows-1 of t via print_row, each starting at
-- column 0. Returns 0.
print_rows t n_rows n_cols i =
  if i >= n_rows then 0
  else
    let _ = print_row t n_cols i 0
    let next_row = i + 1
    print_rows t n_rows n_cols next_row

-- Print every entry of a tensor using the first two entries of its shape as
-- the row and column counts. The shape must have at least two entries, since
-- head and head-of-tail are taken without a check. Returns 0.
tensor_print_2d t =
  let dims = tensor_shape t
  let n_rows = head dims
  let n_cols = head (tail dims)
  print_rows t n_rows n_cols 0

-- Clone: build a tensor with its own buffer holding the same values.
-- Allocates a new float array as long as the source buffer, copies every
-- element across, and reuses the source's shape and strides on the result.
tensor_clone t =
  match t | Tensor src dims st ->
    let count = float_arr_len src
    let fresh = float_arr_new count 0.0
    let _ = tensor_clone_copy src fresh count 0
    Tensor fresh dims st

-- Copy src[i .. n-1] into the same positions of dst, one element per call.
-- Returns 0 once i reaches n.
tensor_clone_copy src dst n i =
  if i >= n then 0
  else
    let elem = float_arr_get src i
    let _ = float_arr_set dst i elem
    tensor_clone_copy src dst n (i + 1)

-- Read binary file as float_arr of bytes
-- Runs `xxd -p <path> | tr -d '\n'` through shell to obtain the file as one
-- string of hex digits, allocates a float array with one slot per pair of
-- digits (n = length / 2), and lets read_binary_fill store each pair's value
-- as a float. Returns the filled array.
-- NOTE(review): path is concatenated into the shell command unquoted, so a
-- path containing spaces or shell metacharacters will be interpreted by the
-- shell — only pass trusted, simple paths, or confirm callers quote it.
-- NOTE(review): what content holds when the command fails is not visible
-- here — presumably an empty string, giving an empty array; confirm.
read_binary path =
  let content = shell (cat ["xxd -p ", path, " | tr -d '\\n'"])
  let n = length content / 2
  let arr = float_arr_new n 0.0
  let _ = read_binary_fill arr content n 0
  arr

-- Decode hex digit pairs i .. n-1 of `hex` into arr[i .. n-1].
-- For each index the two one-character slices at positions 2*i and 2*i+1 are
-- converted with hex_val, combined as high*16 + low, and stored as a float.
-- Returns 0 once i reaches n.
read_binary_fill arr hex n i =
  if i >= n then 0
  else
    let hi_digit = head (chars (str_sub hex (i * 2) 1))
    let lo_digit = head (chars (str_sub hex (i * 2 + 1) 1))
    let byte = (hex_val hi_digit) * 16 + (hex_val lo_digit)
    let _ = float_arr_set arr i (to_float byte)
    read_binary_fill arr hex n (i + 1)

-- Convert one hexadecimal digit, given as a single-character value comparable
-- to a string literal, into its integer value 0..15.
-- Both lowercase (a-f) and uppercase (A-F) digits are accepted; previously an
-- uppercase digit silently decoded to 0.
-- Any other input still yields 0, so callers cannot tell it apart from "0".
hex_val c =
  if c == "0" then 0 else if c == "1" then 1 else if c == "2" then 2
  else if c == "3" then 3 else if c == "4" then 4 else if c == "5" then 5
  else if c == "6" then 6 else if c == "7" then 7 else if c == "8" then 8
  else if c == "9" then 9
  else if c == "a" then 10 else if c == "A" then 10
  else if c == "b" then 11 else if c == "B" then 11
  else if c == "c" then 12 else if c == "C" then 12
  else if c == "d" then 13 else if c == "D" then 13
  else if c == "e" then 14 else if c == "E" then 14
  else if c == "f" then 15 else if c == "F" then 15
  else 0

-- ============================================================
-- Test / main
-- ============================================================

-- Smoke test for the tensor module. Ignores its argument and returns 0.
-- Builds small tensors and exercises matmul, element-wise ops, activations,
-- reductions, transpose, reshape and zeros/ones, printing each result next to
-- its expected value.
-- NOTE(review): nothing is compared programmatically — the expected values
-- appear only in the printed text, and the final "All tests passed" banner is
-- printed unconditionally, so the output must be checked by eye.
tensor_test_main _ =
  -- Create A = [[1,2,3],[4,5,6]] (2x3)
  let a = tensor_new [2, 3] 0.0
  let _ = tensor_set a [0, 0] 1.0
  let _ = tensor_set a [0, 1] 2.0
  let _ = tensor_set a [0, 2] 3.0
  let _ = tensor_set a [1, 0] 4.0
  let _ = tensor_set a [1, 1] 5.0
  let _ = tensor_set a [1, 2] 6.0

  -- Create B = [[7,8],[9,10],[11,12]] (3x2)
  let b = tensor_new [3, 2] 0.0
  let _ = tensor_set b [0, 0] 7.0
  let _ = tensor_set b [0, 1] 8.0
  let _ = tensor_set b [1, 0] 9.0
  let _ = tensor_set b [1, 1] 10.0
  let _ = tensor_set b [2, 0] 11.0
  let _ = tensor_set b [2, 1] 12.0

  -- Matmul: A(2,3) @ B(3,2) = C(2,2)
  -- Expected: [[58, 64], [139, 154]]
  let c = matmul a b
  let _ = print "=== Matmul A(2x3) @ B(3x2) ==="
  let _ = print (cat ["C[0,0] = ", show_float (tensor_get c [0, 0]), " (expect 58)"])
  let _ = print (cat ["C[0,1] = ", show_float (tensor_get c [0, 1]), " (expect 64)"])
  let _ = print (cat ["C[1,0] = ", show_float (tensor_get c [1, 0]), " (expect 139)"])
  let _ = print (cat ["C[1,1] = ", show_float (tensor_get c [1, 1]), " (expect 154)"])

  -- Test element-wise ops
  let _ = print "=== Element-wise ops ==="
  let x = tensor_new [2, 2] 3.0
  let y = tensor_new [2, 2] 2.0
  let added = tensor_add x y
  let _ = print (cat ["3+2 = ", show_float (tensor_get added [0, 0]), " (expect 5)"])
  let subbed = tensor_sub x y
  let _ = print (cat ["3-2 = ", show_float (tensor_get subbed [0, 0]), " (expect 1)"])
  let mulled = tensor_mul x y
  let _ = print (cat ["3*2 = ", show_float (tensor_get mulled [0, 0]), " (expect 6)"])
  let scaled = tensor_scale x 10.0
  let _ = print (cat ["3*10 = ", show_float (tensor_get scaled [0, 0]), " (expect 30)"])

  -- Test activations
  let _ = print "=== Activations ==="
  let act = tensor_new [4, 1] 0.0
  let _ = tensor_set act [0, 0] (0.0 - 2.0)
  let _ = tensor_set act [1, 0] (0.0 - 1.0)
  let _ = tensor_set act [2, 0] 0.0
  let _ = tensor_set act [3, 0] 1.0
  let r = tensor_relu act
  let _ = print (cat ["relu(-2) = ", show_float (tensor_get r [0, 0]), " (expect 0)"])
  let _ = print (cat ["relu(1)  = ", show_float (tensor_get r [3, 0]), " (expect 1)"])
  let g = tensor_gelu act
  let _ = print (cat ["gelu(1)  = ", show_float (tensor_get g [3, 0]), " (expect ~0.8412)"])

  -- Test reductions
  let _ = print "=== Reductions ==="
  let vals = tensor_new [2, 2] 0.0
  let _ = tensor_set vals [0, 0] 1.0
  let _ = tensor_set vals [0, 1] 2.0
  let _ = tensor_set vals [1, 0] 3.0
  let _ = tensor_set vals [1, 1] 4.0
  let s = tensor_sum vals
  let _ = print (cat ["sum([1,2,3,4]) = ", show_float s, " (expect 10)"])
  let m = tensor_mean vals
  let _ = print (cat ["mean([1,2,3,4]) = ", show_float m, " (expect 2.5)"])

  -- Test transpose
  let _ = print "=== Transpose ==="
  let _ = print (cat ["A shape: ", show (tensor_shape a)])
  let at = tensor_transpose a
  let _ = print (cat ["A^T shape: ", show (tensor_shape at)])
  let _ = print (cat ["A^T[0,0] = ", show_float (tensor_get at [0, 0]), " (expect 1)"])
  let _ = print (cat ["A^T[0,1] = ", show_float (tensor_get at [0, 1]), " (expect 4)"])
  let _ = print (cat ["A^T[1,0] = ", show_float (tensor_get at [1, 0]), " (expect 2)"])
  let _ = print (cat ["A^T[2,0] = ", show_float (tensor_get at [2, 0]), " (expect 3)"])

  -- Test reshape
  let _ = print "=== Reshape ==="
  let flat = tensor_reshape vals [4]
  let _ = print (cat ["reshape shape: ", show (tensor_shape flat)])
  let _ = print (cat ["flat[2] = ", show_float (tensor_get flat [2]), " (expect 3)"])

  -- Test tensor_zeros / tensor_ones
  let _ = print "=== Zeros/Ones ==="
  let z = tensor_zeros [2, 3]
  let _ = print (cat ["zeros[0,0] = ", show_float (tensor_get z [0, 0]), " (expect 0)"])
  let o = tensor_ones [2, 3]
  let _ = print (cat ["ones[1,2] = ", show_float (tensor_get o [1, 2]), " (expect 1)"])

  let _ = print "=== All tests passed ==="
  0

-- ── Gradient checking (inline — cross-file float return broken) ──

-- Perturbation added to and subtracted from a single element by the numgrad_*
-- functions; their result is (f(x + grad_eps) - f(x - grad_eps)) / (2 * grad_eps).
grad_eps = 0.0001

-- Sum of squares of every element of t, read by flat index.
sum_sq_grad t =
  sum_sq_grad_acc t (tensor_size t) 0 0.0
-- Worker: adds t[i]^2 .. t[n-1]^2 onto acc and returns the total.
sum_sq_grad_acc t n i acc =
  if i >= n then acc
  else
    let elem = tensor_get_flat t i
    sum_sq_grad_acc t n (i + 1) (acc + elem * elem)

-- Sum of relu(x) over every element x of t, read by flat index.
sum_relu_grad t =
  sum_relu_grad_acc t (tensor_size t) 0 0.0
-- Worker: adds the strictly positive elements among t[i] .. t[n-1] onto acc
-- (other elements contribute 0.0) and returns the total.
sum_relu_grad_acc t n i acc =
  if i >= n then acc
  else
    let elem = tensor_get_flat t i
    let clipped = if elem > 0.0 then elem else 0.0
    sum_relu_grad_acc t n (i + 1) (acc + clipped)

-- Sum of exp(x) over every element x of t, read by flat index.
sum_exp_grad t =
  sum_exp_grad_acc t (tensor_size t) 0 0.0
-- Worker: adds exp(t[i]) .. exp(t[n-1]) onto acc and returns the total.
sum_exp_grad_acc t n i acc =
  if i >= n then acc
  else
    let term = exp (tensor_get_flat t i)
    sum_exp_grad_acc t n (i + 1) (acc + term)

-- Central-difference estimate of d/dx_i sum_sq_grad(t).
-- Temporarily overwrites element i of t with (value + grad_eps) and then
-- (value - grad_eps), evaluating sum_sq_grad after each write, and restores
-- the saved value before returning the quotient.
numgrad_sq t i =
  let saved = tensor_get_flat t i
  let _ = tensor_set_flat t i (saved + grad_eps)
  let f_plus = sum_sq_grad t
  let _ = tensor_set_flat t i (saved - grad_eps)
  let f_minus = sum_sq_grad t
  let _ = tensor_set_flat t i saved
  (f_plus - f_minus) / (2.0 * grad_eps)

-- Central-difference estimate of d/dx_i sum_relu_grad(t).
-- Temporarily overwrites element i of t with (value + grad_eps) and then
-- (value - grad_eps), evaluating sum_relu_grad after each write, and restores
-- the saved value before returning the quotient.
numgrad_relu_fn t i =
  let saved = tensor_get_flat t i
  let _ = tensor_set_flat t i (saved + grad_eps)
  let f_plus = sum_relu_grad t
  let _ = tensor_set_flat t i (saved - grad_eps)
  let f_minus = sum_relu_grad t
  let _ = tensor_set_flat t i saved
  (f_plus - f_minus) / (2.0 * grad_eps)

-- Central-difference estimate of d/dx_i sum_exp_grad(t).
-- Temporarily overwrites element i of t with (value + grad_eps) and then
-- (value - grad_eps), evaluating sum_exp_grad after each write, and restores
-- the saved value before returning the quotient.
numgrad_exp_fn t i =
  let saved = tensor_get_flat t i
  let _ = tensor_set_flat t i (saved + grad_eps)
  let f_plus = sum_exp_grad t
  let _ = tensor_set_flat t i (saved - grad_eps)
  let f_minus = sum_exp_grad t
  let _ = tensor_set_flat t i saved
  (f_plus - f_minus) / (2.0 * grad_eps)

-- Absolute value of a float: x itself when strictly positive, otherwise 0.0 - x.
abs_f x =
  if x > 0.0 then x
  else 0.0 - x

-- Compare a numerical gradient against its analytical value and report it.
-- The error is |numerical - analytical| / (|analytical| + 0.000001); the check
-- is tagged "PASS" when that error is below 0.01 and "FAIL" otherwise.
-- Prints one line with name, index i, both values, the error and the tag,
-- and returns the tag string.
check_g name i numerical analytical =
  let gap = abs_f (numerical - analytical)
  let scale = (abs_f analytical) + 0.000001
  let rel_err = gap / scale
  let verdict = if rel_err < 0.01 then "PASS" else "FAIL"
  let line = cat ["  ", name, "[", show i, "]: num=", show_float numerical, " ana=", show_float analytical, " err=", show_float rel_err, " ", verdict]
  let _ = print line
  verdict

-- Gradient-check driver. Ignores its argument and returns 0.
-- For three scalar functions (sum of squares, sum of relu, sum of exp) it
-- fills a small 1-D tensor, computes the central-difference gradient at each
-- index and passes it to check_g together with the hand-written analytical
-- value, which prints a PASS/FAIL line per element.
-- NOTE(review): the tag returned by check_g is discarded, so a FAIL only
-- shows up in the printed output and does not change the return value.
gradient_test_main _ =
  let _ = print "=== Gradient Checks ==="

  let _ = print "--- d/dx sum(x^2) = 2x ---"
  let x1 = tensor_new [4] 0.0
  let _ = tensor_set_flat x1 0 1.0
  let _ = tensor_set_flat x1 1 (0.0 - 2.0)
  let _ = tensor_set_flat x1 2 3.0
  let _ = tensor_set_flat x1 3 0.5
  let _ = check_g "x^2" 0 (numgrad_sq x1 0) 2.0
  let _ = check_g "x^2" 1 (numgrad_sq x1 1) (0.0 - 4.0)
  let _ = check_g "x^2" 2 (numgrad_sq x1 2) 6.0
  let _ = check_g "x^2" 3 (numgrad_sq x1 3) 1.0

  let _ = print "--- d/dx sum(relu(x)) = 1{x>0} ---"
  let x2 = tensor_new [4] 0.0
  let _ = tensor_set_flat x2 0 2.0
  let _ = tensor_set_flat x2 1 (0.0 - 1.0)
  let _ = tensor_set_flat x2 2 0.5
  let _ = tensor_set_flat x2 3 (0.0 - 3.0)
  let _ = check_g "relu" 0 (numgrad_relu_fn x2 0) 1.0
  let _ = check_g "relu" 1 (numgrad_relu_fn x2 1) 0.0
  let _ = check_g "relu" 2 (numgrad_relu_fn x2 2) 1.0
  let _ = check_g "relu" 3 (numgrad_relu_fn x2 3) 0.0

  let _ = print "--- d/dx sum(exp(x)) = exp(x) ---"
  let x3 = tensor_new [3] 0.0
  let _ = tensor_set_flat x3 0 0.0
  let _ = tensor_set_flat x3 1 1.0
  let _ = tensor_set_flat x3 2 (0.0 - 1.0)
  let _ = check_g "exp" 0 (numgrad_exp_fn x3 0) (exp 0.0)
  let _ = check_g "exp" 1 (numgrad_exp_fn x3 1) (exp 1.0)
  let _ = check_g "exp" 2 (numgrad_exp_fn x3 2) (exp (0.0 - 1.0))

  let _ = print "=== Done ==="
  0

-- Standalone test: ./rail_native run stdlib/tensor.rail
-- main =
--   let _ = tensor_test_main 0
--   let _ = gradient_test_main 0
--   0
