#!/usr/bin/env python3
# Split BAM file by reference number (Now in Python!)
# Porter L

# Example:
# splitByRef.py -i my_input_file.bam -o my_output_dir -s sample_id -t 8

import argparse
import multiprocessing
import os
import shlex
import subprocess
import sys
from functools import partial

import pysam

def parse_args():
    """Parse the command-line options of the BAM splitter.

    Returns:
        argparse.Namespace carrying ``input_bam``, ``output_dir``,
        ``sample_id`` and ``threads``. No type conversion is applied, so
        every value is the raw string given on the command line.
    """
    # (short flag, long flag, help text) -- every option is mandatory.
    option_specs = (
        ("-i", "--input_bam", "Path to input BAM file"),
        ("-o", "--output_dir", "Path to output directory"),
        ("-s", "--sample_id", "Sample ID"),
        ("-t", "--threads", "CPU Threads to use"),
    )

    cli = argparse.ArgumentParser(description="Split BAM file by reference number")
    for short_flag, long_flag, help_text in option_specs:
        cli.add_argument(short_flag, long_flag, required=True, help=help_text)

    return cli.parse_args()

def run_cmd(cmd):
    """Run one shell command line and return its captured standard output.

    The command is echoed to stdout before it starts. It is executed through
    the shell (``shell=True``), so redirections and pipes inside ``cmd`` work;
    the caller is responsible for quoting any argument that may contain
    whitespace or shell metacharacters.

    Args:
        cmd: The full command line, as a single string.

    Returns:
        The command's standard output, decoded as text.

    Raises:
        SystemExit: If the command returns a non-zero status. The exit code
            carried by the exception is the command's return code.
    """
    print(f"Running command: {cmd}")

    # Capture both streams so stdout can be returned and stderr reported
    # only when the command actually fails.
    result = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )

    if result.returncode != 0:
        # Diagnostics belong on stderr so they never mix with regular output.
        print(f"Command failed: {cmd}", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        # sys.exit rather than the bare exit() helper: the latter is injected
        # by the site module for interactive use and is not always defined.
        sys.exit(result.returncode)

    return result.stdout

def get_ref_list(input_bam):
    """Return the reference names in ``input_bam`` that have at least one read.

    Runs ``samtools idxstats`` on the file and parses its tab-separated
    output, where column 0 is the reference name, column 2 the mapped read
    count and column 3 the unmapped read count.

    Args:
        input_bam: Path to the BAM file to inspect.

    Returns:
        List of reference names (in idxstats order) whose mapped + unmapped
        read count is non-zero.
    """
    # Quote the path: run_cmd goes through the shell, so a path containing
    # spaces or metacharacters would otherwise be split or interpreted.
    idxstats_output = run_cmd(f"samtools idxstats {shlex.quote(str(input_bam))}")

    refs = []
    for line in idxstats_output.strip().splitlines():
        if not line.strip():
            # Tolerate stray blank lines instead of failing on the index below.
            continue

        fields = line.split("\t")
        ref_name = fields[0]
        mapped_reads = int(fields[2])
        unmapped_reads = int(fields[3])

        # NOTE(review): idxstats normally ends with a "*" row for unplaced
        # reads; it is kept (as before) whenever its count is non-zero --
        # confirm that downstream steps expect a "*" entry.
        if mapped_reads + unmapped_reads != 0:
            refs.append(ref_name)

    return refs

def process_ref(input_bam, output_dir, sample_id, ref):
    """Extract the reads of one reference into its own BAM file.

    Writes ``<output_dir>/<sample_id>-<ref>.bam`` by running
    ``samtools view -b`` on ``input_bam`` restricted to ``ref``.

    Args:
        input_bam: Path to the source BAM file.
        output_dir: Directory that receives the per-reference BAM file.
        sample_id: Sample identifier used as the output file name prefix.
        ref: Reference ID (accession # or TaxID) to extract.

    Raises:
        RuntimeError: If the samtools command fails.
    """
    out_bam = os.path.join(output_dir, f"{sample_id}-{ref}.bam")

    # Every interpolated piece is shell-quoted: reference names routinely
    # contain characters such as "|" or "*" that the shell would otherwise
    # treat as pipes or globs, and paths may contain spaces.
    cmd = (
        f"samtools view -b {shlex.quote(str(input_bam))} {shlex.quote(str(ref))}"
        f" > {shlex.quote(str(out_bam))}"
    )

    try:
        run_cmd(cmd)
    except SystemExit as exc:
        # run_cmd exits on failure. SystemExit is not an Exception subclass,
        # so inside a multiprocessing.Pool worker it would terminate the
        # worker without reporting the task back, and the parent could wait
        # forever. Re-raise as a regular exception so the failure propagates.
        raise RuntimeError(
            f"samtools view failed for reference {ref!r} (exit status {exc.code})"
        ) from exc


def main():
    """Split the input BAM into one BAM file per reference, in parallel.

    Reads the options from the command line (or from the wrapper in the
    workflow), lists the references that have reads, and runs
    ``process_ref`` for each of them in a pool of worker processes.
    """
    args = parse_args()
    input_bam = args.input_bam
    output_dir = args.output_dir
    sample_id = args.sample_id
    threads = int(args.threads)

    # Get refs
    refs = get_ref_list(input_bam)

    # The per-reference commands redirect into this directory, so it has to
    # exist before any worker starts.
    os.makedirs(output_dir, exist_ok=True)

    if refs:
        # Never more workers than CPUs, requested threads or references, and
        # never fewer than one (Pool rejects a process count below 1).
        num_workers = max(1, min(multiprocessing.cpu_count(), threads, len(refs)))
        split_one = partial(process_ref, input_bam, output_dir, sample_id)
        with multiprocessing.Pool(processes=num_workers) as pool:
            pool.map(split_one, refs)
    print("** SplitByRef done! **")

# Run the splitter only when this file is executed as a script, not when it
# is imported. NOTE(review): the guard presumably also matters for
# multiprocessing start methods that re-import this module in the workers.
if __name__ == "__main__":
    main()
