#!/usr/bin/env python3
"""
albert-train — one command to start GPU training + local dashboard

  albert-train                        # start Modal GPU training, open dashboard
  albert-train --no-browser           # same, skip opening Firefox
  albert-train --detach               # Modal detached mode (survives terminal close)
  albert-train --local                # CPU-only local training (no Modal, no CUDA)
  albert-train --contributor NAME     # CPU local training: 30 batches/epoch,
                                      # auto-push spore to albert-spores after each checkpoint
  albert-train pull                   # pull latest checkpoint from Modal volume
"""

import os, sys, subprocess, threading, signal, time, shutil, argparse, re

# Directory that contains this script; project paths below are built from it.
HERE = os.path.dirname(os.path.abspath(__file__))

# Files under dashboard/: the run log, the per-batch loss history and the
# dashboard server script.
_DASHBOARD_DIR = os.path.join(HERE, "dashboard")
LOG = os.path.join(_DASHBOARD_DIR, "training.log")
BATCH_HISTORY = os.path.join(_DASHBOARD_DIR, "batch_history.csv")
DASH_SRV = os.path.join(_DASHBOARD_DIR, "run_server.py")

# Programs and scripts this launcher shells out to.
MODAL_PY = os.path.join(HERE, "train_modal.py")
CARGO = os.path.expanduser("~/.cargo/bin/cargo")
PRODUCE_SPORE = os.path.join(HERE, "scripts", "produce_spore.py")

# Divisor used to turn a batch index into a fractional-epoch x value.
BATCHES_PER_EPOCH = 300

# Per-batch progress line: captures (global epoch, batch index, loss).
_BATCH_RE = re.compile(
    r'Epoch\s+\d+\s+\(Global\s+(\d+)\),\s+Batch\s+(\d+):\s+loss\s*=\s*([\d.]+)'
)
# End-of-epoch summary line: captures (epoch, loss_avg, loss_best).
_EPOCH_SM_RE = re.compile(
    r'EPOCH_SUMMARY epoch=(\d+) loss_avg=([\d.]+).*loss_best=([\d.]+)'
)

# ---------------------------------------------------------------------------
# pull subcommand — just delegate to train_modal.py
# ---------------------------------------------------------------------------
if sys.argv[1:2] == ["pull"]:
    # Run train_modal.py's own "pull" from this script's directory and exit
    # with whatever return code it produced.
    os.chdir(HERE)
    pull_result = subprocess.run([sys.executable, MODAL_PY, "pull"])
    sys.exit(pull_result.returncode)

# ---------------------------------------------------------------------------
# parse flags
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(add_help=False)
# Plain on/off switches that carry no help text.
for _switch in ("--no-browser", "--detach"):
    parser.add_argument(_switch, action="store_true")
parser.add_argument(
    "--local",
    action="store_true",
    help="CPU-only local training: build without cuda, run against local data/models",
)
parser.add_argument(
    "--contributor",
    default=None,
    metavar="NAME",
    help="CPU local training with auto-spore push: 30 batches/epoch, "
         "checkpoint pushed to albert-spores after each epoch",
)
# parse_known_args: arguments the parser does not recognise are tolerated.
args, _ = parser.parse_known_args()

# --contributor implies --local
args.local = args.local or bool(args.contributor)

# ---------------------------------------------------------------------------
# --local / --contributor: CPU-only training without Modal
# ---------------------------------------------------------------------------
if args.local:
    contributor = args.contributor  # None if plain --local
    if contributor:
        batches = 30
        mode_label = f"CONTRIBUTOR ({contributor})"
    else:
        batches = BATCHES_PER_EPOCH
        mode_label = "LOCAL"

    # Build the trainer first; a failed build stops the launcher right here.
    core_dir = os.path.join(HERE, "moe-llm-core")
    build_env = dict(os.environ)
    build_env["CARGO_TERM_COLOR"] = "never"
    print(f"[albert-train] {mode_label} mode — building train_bible (no cuda feature) ...")
    build = subprocess.run(
        [CARGO, "build", "--release", "--bin", "train_bible"],
        cwd=core_dir,
        env=build_env,
    )
    if build.returncode:
        print("[albert-train] build failed", file=sys.stderr)
        sys.exit(1)

    # Command line for the freshly built release binary.
    binary = os.path.join(core_dir, "target", "release", "train_bible")
    cmd = [
        binary,
        f"--root={HERE}",
        "--gate-diversity=0.5",
        "--lb-weight=0.03",
        "--div-weight=0.001",
        f"--batches-per-epoch={batches}",
    ]
    print(f"[albert-train] {' '.join(cmd)}")
    if contributor:
        print(f"[albert-train] auto-spore: every checkpoint → albert-spores (as '{contributor}')")

    # Start dashboard — pass --cpu in contributor mode for relaxed stale thresholds
    dash_args = [sys.executable, DASH_SRV] + (["--cpu"] if contributor else [])
    dash_proc = subprocess.Popen(
        dash_args,
        cwd=os.path.join(HERE, "dashboard"),
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    time.sleep(0.5)
    if not args.no_browser:
        # Use the first launcher found on PATH; do nothing if none is present.
        browser = next(
            (name for name in ("firefox", "xdg-open", "open") if shutil.which(name)),
            None,
        )
        if browser is not None:
            subprocess.Popen(
                [browser, "http://localhost:8888"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )

    # The log is truncated for every run; the batch history is appended to.
    log_f = open(LOG, "w")
    batch_hist_f = open(BATCH_HISTORY, "a")

    # Seed the dedupe set with the x values already present in the history file.
    _seen_x_local: set[float] = set()
    if os.path.exists(BATCH_HISTORY):
        with open(BATCH_HISTORY) as _bh:
            for _row in _bh:
                _x_text, _sep, _rest = _row.partition(',')
                if not (_sep and _x_text):
                    continue
                try:
                    _seen_x_local.add(float(_x_text))
                except ValueError:
                    pass

    # Trainer output (stdout and stderr merged) is read line by line from this pipe.
    proc = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1,
    )

    # Auto-spore: push checkpoint to albert-spores after each epoch.
    # The lock stops a second push from starting while one is still running,
    # which happens when epochs are faster than the git/LFS operation.
    _spore_lock = threading.Lock()

    def _push_spore(name: str, epoch: int, loss: float):
        """Run produce_spore.py for one epoch and echo its output, indented.

        If another push still holds the lock, a notice is printed and this
        call returns without running the script.
        """
        lock_taken = _spore_lock.acquire(blocking=False)
        if not lock_taken:
            print(f"[albert-train] auto-spore: ep{epoch} — push already in progress, skipping")
            return
        try:
            print(f"[albert-train] auto-spore: ep{epoch} loss={loss:.4f} → pushing ...")
            spore_cmd = [
                sys.executable, PRODUCE_SPORE,
                "--name", name,
                "--epoch", str(epoch),
                "--loss", str(loss),
            ]
            result = subprocess.run(
                spore_cmd,
                cwd=HERE,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
            for out_line in result.stdout.splitlines():
                print(f"  {out_line}")
            if result.returncode != 0:
                print(f"[albert-train] auto-spore: ep{epoch} FAILED (rc={result.returncode})")
            else:
                print(f"[albert-train] auto-spore: ep{epoch} done")
        finally:
            _spore_lock.release()

    def _stream_local():
        """Relay trainer output to the terminal, the log file and the batch history.

        Every line read from the trainer is echoed to stdout and written to the
        log file. A line matching _BATCH_RE adds one ``x,loss`` row to the batch
        history, where x = global epoch + batch / batches, unless that x is
        already recorded. In contributor mode, a line matching _EPOCH_SM_RE
        starts a background spore push carrying that epoch's loss_best value.
        """
        for line in proc.stdout:
            sys.stdout.write(line)
            sys.stdout.flush()
            log_f.write(line)
            log_f.flush()
            m = _BATCH_RE.match(line.strip())
            if m:
                x = int(m.group(1)) + int(m.group(2)) / batches
                # Deduplicate on the value exactly as it is written to the CSV
                # (6 decimals). The seen-set is seeded from those rounded rows,
                # so testing the full-precision x would almost never match and
                # a restart would append duplicate rows.
                x_text = f"{x:.6f}"
                x_key = float(x_text)
                if x_key not in _seen_x_local:
                    _seen_x_local.add(x_key)
                    batch_hist_f.write(f"{x_text},{m.group(3)}\n")
                    batch_hist_f.flush()
            if contributor:
                sm = _EPOCH_SM_RE.search(line)
                if sm:
                    ep   = int(sm.group(1))
                    loss = float(sm.group(3))  # loss_best, not loss_avg
                    # Push in the background so reading trainer output never stalls.
                    threading.Thread(target=_push_spore,
                                     args=(contributor, ep, loss), daemon=True).start()

    def _sigint(signum, stack_frame):
        """SIGINT handler: print a notice and send SIGINT to the trainer process."""
        notice = "\n[albert-train] stopping local training ..."
        print(notice)
        proc.send_signal(signal.SIGINT)

    signal.signal(signal.SIGINT, _sigint)
    # Keep a handle on the reader thread so it can be drained before the
    # output files are closed.
    _local_stream_thread = threading.Thread(target=_stream_local, daemon=True)
    _local_stream_thread.start()
    proc.wait()
    # The trainer has exited, but the reader may still be writing its last
    # lines; closing the files first would make those writes fail.
    _local_stream_thread.join(timeout=3)
    log_f.close()
    batch_hist_f.close()
    print("[albert-train] local run ended. Dashboard still at http://localhost:8888")
    # _sigint only forwards the signal to the trainer and never raises, so with
    # it installed the KeyboardInterrupt branch below could not run. Reinstate
    # Python's default SIGINT handler now that training is over.
    signal.signal(signal.SIGINT, signal.default_int_handler)
    try:
        dash_proc.wait()
    except KeyboardInterrupt:
        dash_proc.terminate()
    sys.exit(0)

# ---------------------------------------------------------------------------
# start dashboard server
# ---------------------------------------------------------------------------
print("[albert-train] starting dashboard server on http://localhost:8888 ...")
# The server runs from inside dashboard/ with its output discarded.
dash_proc = subprocess.Popen(
    [sys.executable, DASH_SRV],
    cwd=os.path.join(HERE, "dashboard"),
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)

time.sleep(0.5)  # let server bind before opening browser

if not args.no_browser:
    # Use the first launcher found on PATH; do nothing if none is present.
    browser = next(
        (name for name in ("firefox", "xdg-open", "open") if shutil.which(name)),
        None,
    )
    if browser is not None:
        subprocess.Popen(
            [browser, "http://localhost:8888"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )

# ---------------------------------------------------------------------------
# start Modal training, pipe its stdout to dashboard/training.log + terminal
# ---------------------------------------------------------------------------
# "modal run [--detach] <train_modal.py>"; --detach is inserted only on request.
modal_cmd = ["modal", "run", *(["--detach"] if args.detach else []), MODAL_PY]

print(f"[albert-train] launching:  {' '.join(modal_cmd)}")
print(f"[albert-train] log → {LOG}")
print()

# The local log is overwritten on every run; the batch history is appended to
# so that it is preserved across runs.
log_f = open(LOG, "w")
batch_hist_f = open(BATCH_HISTORY, "a")

# Load x-values already in batch_history.csv so restarts never write duplicates.
_seen_x: set[float] = set()
if os.path.exists(BATCH_HISTORY):
    with open(BATCH_HISTORY) as _bh:
        for _row in _bh:
            _x_text, _sep, _rest = _row.partition(',')
            if not (_sep and _x_text):
                continue
            try:
                _seen_x.add(float(_x_text))
            except ValueError:
                pass

# Modal's stdout and stderr are merged into one line-buffered text pipe.
modal_proc = subprocess.Popen(
    modal_cmd,
    cwd=HERE,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True,
    bufsize=1,
)

def stream():
    """Relay Modal output to the terminal, the log file and the batch history.

    Every line read from the Modal process is echoed to stdout and written to
    the log file. A line matching _BATCH_RE adds one ``x,loss`` row to the
    batch history, where x = global epoch + batch / BATCHES_PER_EPOCH, unless
    that x is already recorded.
    """
    for line in modal_proc.stdout:
        sys.stdout.write(line)
        sys.stdout.flush()
        log_f.write(line)
        log_f.flush()
        m = _BATCH_RE.match(line.strip())
        if m:
            epoch = int(m.group(1))
            batch = int(m.group(2))
            loss  = m.group(3)
            x = epoch + batch / BATCHES_PER_EPOCH
            # Deduplicate on the value exactly as it is written to the CSV
            # (6 decimals). The seen-set is seeded from those rounded rows, so
            # testing the full-precision x would almost never match and a
            # restart would append duplicate rows.
            x_text = f"{x:.6f}"
            x_key = float(x_text)
            if x_key not in _seen_x:
                _seen_x.add(x_key)
                batch_hist_f.write(f"{x_text},{loss}\n")
                batch_hist_f.flush()

stream_thread = threading.Thread(target=stream, daemon=True)
stream_thread.start()

# ---------------------------------------------------------------------------
# handle Ctrl-C: let checkpoint finish, then clean up
# ---------------------------------------------------------------------------
def on_sigint(sig, frame):
    """SIGINT handler: print guidance and send SIGINT to the Modal process."""
    notices = (
        "\n[albert-train] Ctrl-C — modal run will finish current epoch then stop.",
        "[albert-train] Run  albert-train pull  to sync checkpoint back.",
    )
    for notice in notices:
        print(notice)
    modal_proc.send_signal(signal.SIGINT)

signal.signal(signal.SIGINT, on_sigint)

# Wait for the Modal run, give the reader thread a moment to drain the pipe,
# then close the output files.
modal_proc.wait()
stream_thread.join(timeout=3)
log_f.close()
batch_hist_f.close()

print("\n[albert-train] Modal run ended.")
print("[albert-train] Dashboard still running at http://localhost:8888")
print("[albert-train] Run  albert-train pull  to sync checkpoint.")

# on_sigint only forwards the signal to modal_proc and never raises, so with it
# installed the KeyboardInterrupt branch below could not run. Reinstate
# Python's default SIGINT handler now that the Modal run is over.
signal.signal(signal.SIGINT, signal.default_int_handler)

# Keep dashboard alive until user kills the process
try:
    dash_proc.wait()
except KeyboardInterrupt:
    dash_proc.terminate()
