#!/usr/bin/env python3

# Downsample a BAM using samtools with a reproducible seed, mirroring flowcell.R logic.

# Usage:
# python downsample_bam.py --p 0.35 --in input.bam --out output.bam \
#   [--seed 175246] [--threads 4] [--samtools /path/to/samtools]

import argparse
import os
import shlex
import shutil
import subprocess
import sys
from typing import Tuple

def parse_args():
    """Parse the script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: Holds ``p``, ``in_bam``, ``out_bam``, ``seed``,
        ``threads`` and ``samtools``. ``--p``, ``--in`` and ``--out`` are
        required; the remaining options fall back to their defaults.
    """
    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_table = (
        ("--p", {"type": float, "required": True,
                 "help": "Downsampling probability (0<p<=1)"}),
        ("--in", {"dest": "in_bam", "required": True,
                  "help": "Input BAM file"}),
        ("--out", {"dest": "out_bam", "required": True,
                   "help": "Output BAM file"}),
        ("--seed", {"type": int, "default": 175246,
                    "help": "Random seed for samtools -s (default: 175246)"}),
        ("--threads", {"type": int, "default": 1,
                       "help": "Threads for samtools view -@ (default: 1)"}),
        ("--samtools", {"default": "",
                        "help": "Path to samtools (default: use PATH)"}),
    )

    cli = argparse.ArgumentParser(
        description="Downsample a BAM using samtools -s seed.p"
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def format_sampling_arg(seed: int, p: float) -> str:
    """Build the ``SEED.FRACTION`` value for ``samtools view -s``.

    samtools reads the integer part of the value as the random seed and the
    digits after the decimal point as the fraction of reads to keep, so
    ``seed=175246, p=0.35`` becomes ``"175246.350000"``.

    Args:
        seed: Random seed; becomes the integer part of the value.
        p: Fraction of reads to keep, in ``(0, 1]``. It is rounded to six
            decimal places.

    Returns:
        The formatted value, always with exactly six fractional digits.

    Raises:
        ValueError: If ``p`` is not in ``(0, 1]``.
    """
    if not (0 < p <= 1):
        raise ValueError("p must be in (0, 1].")
    micros = int(round(p * 1_000_000))
    # The fractional digits can only encode values strictly between 0 and 1.
    # Emitting 1_000_000 would yield seven digits ("seed.1000000"), which
    # samtools reads as 0.1, and 0 would yield "seed.000000". Clamp to the
    # nearest representable fraction; callers that want every read kept
    # should omit -s altogether.
    micros = min(max(micros, 1), 999_999)
    return f"{seed}.{micros:06d}"


def run_cmd(cmd):
    """Echo a command, run it, and return its captured standard output.

    Args:
        cmd: Either a command-line string, which is run through the shell
            (the caller is responsible for quoting), or a sequence of
            arguments, which is executed directly without a shell.

    Returns:
        The command's standard output as text.

    Raises:
        SystemExit: With the command's return code if it exits non-zero, or
            with 127 if an argument-sequence command cannot be started.
    """
    via_shell = isinstance(cmd, str)
    if via_shell:
        shown = cmd
    else:
        cmd = [str(part) for part in cmd]
        shown = " ".join(shlex.quote(part) for part in cmd)

    print(f"Running: {shown}")
    try:
        result = subprocess.run(
            cmd,
            shell=via_shell,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
    except OSError as err:
        # Without a shell a missing or non-executable program raises here;
        # report it the same way the shell would (exit status 127).
        print(f"Command failed: {shown}", file=sys.stderr)
        print(err, file=sys.stderr)
        sys.exit(127)

    if result.returncode != 0:
        print(f"Command failed: {shown}", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        # sys.exit rather than the site-provided exit() builtin, which is
        # not available under `python -S` or in frozen builds.
        sys.exit(result.returncode)

    return result.stdout


def main():
    """Validate the CLI options, then downsample and index the BAM.

    Runs ``samtools view`` (with ``-s`` unless every read is to be kept)
    followed by ``samtools index`` on the output file.

    Raises:
        SystemExit: If an option is invalid, the input BAM or the samtools
            executable cannot be found, or a samtools command fails.
    """
    args = parse_args()

    # Validate inputs
    if not (0 < args.p <= 1):
        raise SystemExit("--p must be in (0,1].")
    if not os.path.exists(args.in_bam):
        raise SystemExit(f"Input BAM not found: {args.in_bam}")
    if os.path.realpath(args.in_bam) == os.path.realpath(args.out_bam):
        # samtools would truncate the input while still reading from it.
        raise SystemExit("--out must not be the same file as --in.")

    # Honour --samtools when given, otherwise look samtools up on PATH.
    requested = args.samtools or "samtools"
    samtools = shutil.which(requested)
    if samtools is None:
        raise SystemExit(f"samtools executable not found: {requested}")

    # Ensure output directory exists (abspath guarantees a non-empty dirname)
    os.makedirs(os.path.dirname(os.path.abspath(args.out_bam)), exist_ok=True)

    # Build commands
    view_cmd = [samtools, "view", "-@", str(args.threads), "-b"]
    # The -s value carries the fraction as six digits after the decimal
    # point, which cannot express 1.0. When p is 1 at that precision, skip
    # -s so that every read is kept.
    if round(args.p * 1_000_000) < 1_000_000:
        view_cmd += ["-s", format_sampling_arg(args.seed, args.p)]
    view_cmd += ["-o", args.out_bam, args.in_bam]

    index_cmd = [samtools, "index", args.out_bam]

    # Run. Every argument is shell-quoted so paths containing spaces or
    # shell metacharacters survive the trip through run_cmd.
    for cmd in (view_cmd, index_cmd):
        run_cmd(" ".join(shlex.quote(part) for part in cmd))
    print(f"Downsampled BAM written: {args.out_bam}")


if __name__ == "__main__":
    main()
