"""FastAPI app: upload → stabilize → trim → generate → view/export.

Heavy CV runs in background threads keyed off the in-memory JobStore; the
frontend polls /status for progress.
"""
from __future__ import annotations

import shutil
import threading
from typing import Optional

import cv2
import numpy as np
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from pydantic import BaseModel

from .jobs import JobStore
from .pipeline import focus, geometry, stabilize, video

# ASGI application; every route in this module registers on this instance.
app = FastAPI(title="Focus-Stacking → 3D")

# Vite dev server runs on :5173 and proxies /api, but allow direct CORS too.
# NOTE(review): "*" for origins/methods/headers is wide open — fine for local
# development, but restrict allow_origins before exposing this server.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Job registry shared by the request handlers and the background worker
# threads (_run_stabilize / _run_generate look jobs up here). In-memory per
# the module docstring, so presumably jobs are lost on process restart.
store = JobStore()


# ---- helpers ---------------------------------------------------------------

def _require(job_id: str):
    """Look up *job_id* in the store, aborting the request with 404 if absent."""
    found = store.get(job_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="job not found")


def _jpeg_response(bgr: np.ndarray) -> Response:
    """Encode an OpenCV (BGR) image as a quality-85 JPEG HTTP response.

    Raises:
        HTTPException(500): OpenCV reported that encoding failed.
    """
    encode_params = [cv2.IMWRITE_JPEG_QUALITY, 85]
    success, encoded = cv2.imencode(".jpg", bgr, encode_params)
    if success:
        return Response(content=encoded.tobytes(), media_type="image/jpeg")
    raise HTTPException(status_code=500, detail="encode failed")


# ---- upload ----------------------------------------------------------------

# Copy uploads to disk in 1 MiB pieces instead of buffering the whole video.
_UPLOAD_CHUNK_BYTES = 1024 * 1024


def _save_upload(src, dest) -> None:
    """Copy the uploaded file object *src* to the path *dest* in bounded chunks.

    Blocking file I/O — called via the threadpool from ``create_job`` so the
    event loop keeps serving other requests while a large video is written.
    """
    src.seek(0)  # defensive rewind; don't rely on where the form parser left it
    with open(dest, "wb") as out:
        shutil.copyfileobj(src, out, _UPLOAD_CHUNK_BYTES)


@app.post("/api/jobs")
async def create_job(video_file: UploadFile = File(..., alias="video")):
    """Create a job from an uploaded video; return its id and probed metadata.

    Both saving the upload and probing it are blocking operations, so they run
    in the threadpool rather than on the event loop (which would otherwise
    stall every concurrent request, e.g. /status polls, for their duration).

    Raises:
        HTTPException(500): the upload could not be written to disk; the job
            is marked as errored.
        HTTPException(400): the saved file could not be probed as a video;
            the job is marked as errored.
    """
    job = store.create()
    try:
        await run_in_threadpool(_save_upload, video_file.file, job.input_path)
    except OSError as exc:
        # Record the failure so the job isn't left orphaned in its initial state.
        store.set_error(job, f"upload failed: {exc}")
        raise HTTPException(status_code=500, detail="could not save upload") from exc
    try:
        meta = await run_in_threadpool(video.probe, str(job.input_path))
    except Exception as exc:  # noqa: BLE001
        store.set_error(job, f"probe failed: {exc}")
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    job.meta = meta.to_dict()
    job.frame_count = meta.frame_count
    return {"job_id": job.id, "meta": job.meta}


# ---- stabilize -------------------------------------------------------------

def _run_stabilize(job_id: str) -> None:
    """Background-thread body: stabilize the job's input video into aligned frames.

    Progress reported by ``stabilize.stabilize`` is forwarded to the store.
    On success the returned frame count replaces ``job.frame_count`` and the
    status becomes "stabilized"; failures inside the run are recorded on the
    job via ``store.set_error`` rather than escaping the thread.
    """
    job = store.require(job_id)

    def report(fraction):
        store.set_progress(job, fraction)

    try:
        store.set_status(job, "stabilizing", progress=0.0)
        produced = stabilize.stabilize(
            str(job.input_path), job.aligned_dir, progress_cb=report
        )
        job.frame_count = produced
        store.set_status(job, "stabilized", progress=1.0)
    except Exception as exc:  # noqa: BLE001
        store.set_error(job, f"stabilize failed: {exc}")


@app.post("/api/jobs/{job_id}/stabilize")
def start_stabilize(job_id: str):
    """Start stabilizing the job's video in a background thread.

    Returns the job's current status if a stabilization run is already in
    progress. The status is switched to "stabilizing" here, *before* the
    worker starts, so that:
      * a repeated request is recognized by the guard below instead of
        launching a second worker writing the same ``aligned_dir``
        (previously the guard only saw the flag once the thread had run);
      * an immediate /status poll cannot observe a stale "stabilized".
    The check-and-set is not atomic, but the race window shrinks from
    thread start-up to two adjacent statements.
    """
    job = _require(job_id)
    if job.status == "stabilizing":
        return job.to_status()
    store.set_status(job, "stabilizing", progress=0.0)
    threading.Thread(target=_run_stabilize, args=(job.id,), daemon=True).start()
    return {"job_id": job.id, "status": "stabilizing"}


@app.get("/api/jobs/{job_id}/status")
def status(job_id: str):
    """Return the job's status snapshot — the endpoint the frontend polls."""
    job = _require(job_id)
    snapshot = job.to_status()
    return snapshot


@app.get("/api/jobs/{job_id}/frames/{idx}")
def frame(job_id: str, idx: int):
    """Serve aligned frame *idx* (``aligned_dir/NNNNN.png``) as a JPEG.

    Raises:
        HTTPException(404): unknown job, or no aligned frame with that index.
        HTTPException(500): the PNG exists but OpenCV could not decode it.
    """
    job = _require(job_id)
    frame_path = job.aligned_dir / f"{idx:05d}.png"
    if frame_path.exists():
        bgr = cv2.imread(str(frame_path))
        if bgr is not None:
            return _jpeg_response(bgr)
        raise HTTPException(status_code=500, detail="frame read failed")
    raise HTTPException(status_code=404, detail="frame not found")


@app.get("/api/jobs/{job_id}/focus/{idx}")
def focus_preview(job_id: str, idx: int, threshold: Optional[float] = None):
    """Render the in-focus mask overlay for aligned frame *idx* as a JPEG.

    When *threshold* is omitted, ``focus.threshold_suggestion`` is asked for
    one over the full frame range ``0 .. frame_count - 1``.

    Raises:
        HTTPException(404): unknown job, or aligned frame *idx* does not exist
            (out of range, or stabilization hasn't written it yet). This
            mirrors the /frames endpoint instead of surfacing a 500 from
            inside the focus code.
    """
    job = _require(job_id)
    # Checked before the threshold suggestion so a request that can't succeed
    # doesn't pay for a full-range scan. Assumes focus.mask_overlay reads the
    # same NNNNN.png naming that /frames serves.
    if not (job.aligned_dir / f"{idx:05d}.png").exists():
        raise HTTPException(status_code=404, detail="frame not found")
    if threshold is None:
        # NOTE(review): recomputed on every preview request over the whole
        # range; consider caching per job if frame scrubbing feels slow.
        threshold = focus.threshold_suggestion(job.aligned_dir, 0, job.frame_count - 1)
    overlay = focus.mask_overlay(job.aligned_dir, idx, threshold)
    return _jpeg_response(overlay)


# ---- generate --------------------------------------------------------------

class GenerateBody(BaseModel):
    """Request body for POST /api/jobs/{job_id}/generate."""

    front: int  # start frame index of the range passed to focus.compute_depth
    back: int  # end frame index of that range (preview uses frame_count - 1 as the last frame — inclusive? TODO confirm)
    z_step: float = 1.0  # forwarded unchanged to geometry.build(z_step=...)
    focus_threshold: Optional[float] = None  # None → focus.threshold_suggestion over [front, back]
    downsample: int = 1  # forwarded unchanged to geometry.build(downsample=...)
    xy_scale: float = 1.0  # forwarded unchanged to geometry.build(xy_scale=...)


def _run_generate(job_id: str, body: GenerateBody) -> None:
    """Background-thread body: depth-from-focus over [front, back], then mesh export.

    Depth estimation is reported as the first 90% of progress. ``geometry.build``
    is given no progress callback, so progress then jumps to 1.0 when the
    job becomes "ready". Summary numbers land in ``job.result``; failures are
    recorded via ``store.set_error`` rather than escaping the thread.
    """
    job = store.require(job_id)

    def depth_progress(fraction):
        store.set_progress(job, 0.9 * fraction)

    try:
        store.set_status(job, "generating", progress=0.0)
        chosen = (
            body.focus_threshold
            if body.focus_threshold is not None
            else focus.threshold_suggestion(job.aligned_dir, body.front, body.back)
        )
        depth_map = focus.compute_depth(
            job.aligned_dir, body.front, body.back, chosen,
            progress_cb=depth_progress,
        )
        mesh = geometry.build(
            depth_map.depth,
            depth_map.in_focus_mask,
            job.ply_path,
            job.stl_path,
            z_step=body.z_step,
            xy_scale=body.xy_scale,
            downsample=body.downsample,
        )
        job.result = dict(
            vertex_count=mesh.vertex_count,
            triangle_count=mesh.triangle_count,
            bbox=mesh.bbox,
            front=depth_map.front,
            back=depth_map.back,
            focus_threshold=float(chosen),
        )
        store.set_status(job, "ready", progress=1.0)
    except Exception as exc:  # noqa: BLE001
        store.set_error(job, f"generate failed: {exc}")


@app.post("/api/jobs/{job_id}/generate")
def start_generate(job_id: str, body: GenerateBody):
    """Start depth/mesh generation for the job in a background thread.

    Mirrors ``start_stabilize``: if a generation run is already in progress,
    the current status is returned instead of launching a second worker that
    would write the same ply/stl files concurrently. The status is switched
    to "generating" before the worker starts, so an immediate /status poll
    can't observe a stale "ready" left over from a previous run.
    """
    job = _require(job_id)
    if job.status == "generating":
        return job.to_status()
    # NOTE(review): generation is not blocked while status is "stabilizing",
    # i.e. while aligned frames may still be being written — confirm intended.
    store.set_status(job, "generating", progress=0.0)
    threading.Thread(target=_run_generate, args=(job.id, body), daemon=True).start()
    return {"job_id": job.id, "status": "generating"}


@app.get("/api/jobs/{job_id}/points")
def points(job_id: str, max_points: int = 200_000):
    """Decimated vertices as binary Float32 (xyz interleaved) for three.js.

    At most *max_points* vertices are returned, picked by a uniform stride over
    the PLY's vertex order. The number actually sent is in the
    ``X-Vertex-Count`` response header.

    Raises:
        HTTPException(404): unknown job, or no PLY has been written yet.
        HTTPException(400): *max_points* < 1. Previously 0 divided by zero and
            negatives produced a zero/negative slice stride, both as 500s.
    """
    job = _require(job_id)
    if max_points < 1:
        raise HTTPException(status_code=400, detail="max_points must be >= 1")
    if not job.ply_path.exists():
        raise HTTPException(status_code=404, detail="no geometry yet")
    verts = _read_ply_points(job.ply_path)
    n = verts.shape[0]
    if n > max_points:
        # Integer ceil(n / max_points): smallest stride keeping the result <= max_points.
        stride = -(-n // max_points)
        verts = verts[::stride]
    payload = np.ascontiguousarray(verts, dtype="<f4").tobytes()
    return Response(
        content=payload,
        media_type="application/octet-stream",
        headers={"X-Vertex-Count": str(verts.shape[0])},
    )


@app.get("/api/jobs/{job_id}/download/{kind}")
def download(job_id: str, kind: str):
    """Send the generated PLY or STL as an attachment named ``model.<kind>``.

    Raises:
        HTTPException(404): unknown job, or the requested file isn't generated yet.
        HTTPException(400): *kind* is neither "ply" nor "stl".
    """
    job = _require(job_id)
    path_attr = {"ply": "ply_path", "stl": "stl_path"}.get(kind)
    if path_attr is None:
        raise HTTPException(status_code=400, detail="kind must be ply or stl")
    export_path = getattr(job, path_attr)
    if not export_path.exists():
        raise HTTPException(status_code=404, detail="file not generated")
    return FileResponse(
        str(export_path),
        media_type="application/octet-stream",
        filename=f"model.{kind}",
    )


def _read_ply_points(path) -> np.ndarray:
    """Read xyz floats from our binary-little-endian PLY (header ends at
    'end_header\\n')."""
    with open(path, "rb") as f:
        raw = f.read()
    marker = b"end_header\n"
    i = raw.find(marker)
    if i < 0:
        return np.zeros((0, 3), dtype=np.float32)
    body = raw[i + len(marker):]
    arr = np.frombuffer(body, dtype="<f4")
    return arr.reshape(-1, 3)


@app.get("/api/health")
def health():
    """Liveness probe; unconditionally reports ok."""
    return dict(ok=True)
