"""End-to-end pipeline test on a synthetic focus-sweep clip.

Run directly (no pytest needed):  python -m tests.test_pipeline
Asserts: stabilization produces N aligned frames, depth-from-focus yields
in-focus pixels, recovered depth rises with image row (row/depth correlation
above 0.8), and the PLY/STL files are structurally valid.
"""
from __future__ import annotations

import struct
import sys
import tempfile
from pathlib import Path

import numpy as np

# Put the backend root (two levels above this file) first on the import path so
# that `python -m tests.test_pipeline` resolves `app` and `tests` from backend/.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from app.pipeline import focus, geometry, stabilize, video  # noqa: E402
from tests.make_synthetic import make_clip  # noqa: E402


def _check_ply(path: Path) -> int:
    raw = path.read_bytes()
    assert raw.startswith(b"ply\n"), "PLY magic missing"
    marker = b"end_header\n"
    i = raw.find(marker)
    assert i > 0, "PLY header end missing"
    header = raw[:i].decode("ascii")
    count = int([l for l in header.splitlines() if l.startswith("element vertex")][0].split()[-1])
    body = raw[i + len(marker):]
    assert len(body) == count * 12, "PLY body size mismatch"
    return count


def _check_stl(path: Path) -> int:
    raw = path.read_bytes()
    assert len(raw) >= 84, "STL too short"
    (tri_count,) = struct.unpack("<I", raw[80:84])
    assert len(raw) == 84 + tri_count * 50, "STL triangle-count/body mismatch"
    return tri_count


def _row_depth_profile(
    depth: np.ndarray, mask: np.ndarray, band: int = 10
) -> tuple[np.ndarray, np.ndarray]:
    """Mean in-focus depth per horizontal band of ``band`` image rows.

    Bands containing no in-focus pixel are skipped. Returns the starting row
    of each kept band and the mean of ``depth`` over that band's masked pixels.
    """
    starts, means = [], []
    for r in range(0, depth.shape[0], band):
        band_mask = mask[r:r + band]
        if band_mask.any():
            starts.append(r)
            means.append(float(depth[r:r + band][band_mask].mean()))
    return np.asarray(starts, dtype=np.float64), np.asarray(means, dtype=np.float64)


def run() -> None:
    """Run the whole pipeline on a synthetic clip and assert each stage's output.

    Stages: probe -> stabilize -> depth-from-focus -> geometry export -> file
    validation. All artifacts live in a temporary directory removed on exit.

    Raises:
        AssertionError: on the first failing check.
    """
    with tempfile.TemporaryDirectory() as td:
        tmp = Path(td)
        clip = make_clip(tmp / "synthetic.mp4")

        meta = video.probe(str(clip))
        assert meta.frame_count > 0, "probe found no frames"
        print(f"probe: {meta.frame_count} frames {meta.width}x{meta.height} @ {meta.fps:.1f}fps")

        aligned = tmp / "aligned"
        n = stabilize.stabilize(str(clip), aligned)
        assert n == meta.frame_count, f"aligned {n} != probed {meta.frame_count}"
        assert len(list(aligned.glob("*.png"))) == n
        print(f"stabilized: {n} aligned frames")

        res = focus.compute_depth(aligned, front=0, back=n - 1, focus_threshold=0.0)
        assert res.in_focus_mask.any(), "no in-focus pixels"

        # Depth-vs-row check: mean recovered depth per row band should rise
        # with the row index (strong positive correlation).
        rows, row_depth = _row_depth_profile(
            res.depth.astype(np.float32), res.in_focus_mask
        )
        # np.corrcoef is undefined (NaN) with fewer than two samples, so fail
        # with an explicit message rather than a misleading "corr=nan".
        assert rows.size >= 2, (
            f"need >= 2 row bands with in-focus pixels for correlation, got {rows.size}"
        )
        corr = float(np.corrcoef(rows, row_depth)[0, 1])
        print(f"depth-vs-row correlation: {corr:.3f}")
        # A constant depth profile also yields NaN; report that case directly.
        assert np.isfinite(corr), "depth is constant across rows (correlation undefined)"
        assert corr > 0.8, f"depth not monotonic with row (corr={corr:.3f})"

        geo = geometry.build(
            res.depth, res.in_focus_mask,
            tmp / "out.ply", tmp / "out.stl",
            z_step=1.0, downsample=1,
        )
        assert geo.vertex_count > 0, "no vertices written"
        print(f"geometry: {geo.vertex_count} verts, {geo.triangle_count} tris, bbox={geo.bbox}")

        ply_n = _check_ply(tmp / "out.ply")
        stl_n = _check_stl(tmp / "out.stl")
        assert ply_n == geo.vertex_count
        assert stl_n == geo.triangle_count
        print(f"file validity OK: PLY {ply_n} verts, STL {stl_n} tris")

    print("\nALL PIPELINE ASSERTIONS PASSED")


if __name__ == "__main__":
    # Script entry point for `python -m tests.test_pipeline`. run() returns
    # None, so a clean pass exits with status 0; a failed assertion propagates
    # as an uncaught exception and yields a non-zero status.
    sys.exit(run())
