#!/usr/bin/env python3
import os, sys, subprocess, argparse
import multiprocessing as mp
import shlex

# Example usage:
# python runMapDamage.py -t 16 -i /mounts/lovelace/temporary/porter/full-MT/14_filter-presence/ -o /mounts/lovelace/temporary/porter/full-MT/15_mapdamage/ -r /mounts/lovelace2/databases/ncbi/mitochondrion/mitochondrion.1.1.genomic.fna
def parse_args(argv=None):
    """Parse the command-line options for the mapDamage batch runner.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with attributes ``threads`` (int >= 1),
        ``reference_genome``, ``input_dir`` and ``output_dir`` (strings).
        On invalid input argparse prints a usage error and exits with status 2.
    """
    def positive_int(value):
        # Reject 0 / negative counts up front with a clean usage error instead
        # of letting multiprocessing.Pool fail later with a ValueError traceback.
        number = int(value)
        if number < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
        return number

    parser = argparse.ArgumentParser(description="Run mapDamage on BAM files in parallel.")
    parser.add_argument("-t", "--threads", type=positive_int, default=4, help="Number of threads to use for parallel processing (default: 4)")
    parser.add_argument("-r", "--ref", "--reference", dest="reference_genome", required=True, help="Path to the reference genome FASTA file")
    parser.add_argument("-i", "--input", "--input_dir", dest="input_dir", required=True, help="Path to the input directory")
    parser.add_argument("-o", "--output", "--output_dir", dest="output_dir", required=True, help="Path to the output directory")
    return parser.parse_args(argv)

def run_cmd(cmd):
    """Run a command as a subprocess and return its captured standard output.

    Args:
        cmd: Either a shell command string (run through the shell, as before)
            or a sequence of program arguments (run directly with no shell,
            so individual arguments need no quoting).

    Returns:
        The command's stdout, decoded as text.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status. The failing command and its stderr are printed first.
            An exception is raised rather than exiting the interpreter because
            this runs inside multiprocessing.Pool workers: a SystemExit there
            is not caught by the worker loop, so the worker process dies and
            the parent's Pool.starmap waits forever for a result that never
            arrives. An ordinary exception is instead re-raised in the parent.
    """
    result = subprocess.run(
        cmd,
        # Strings keep the original shell semantics; argument lists bypass the shell.
        shell=isinstance(cmd, str),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        print(f"Command failed: {cmd}")
        print(result.stderr)
        raise subprocess.CalledProcessError(
            result.returncode, cmd, output=result.stdout, stderr=result.stderr
        )

    return result.stdout

def run_mapdamage(bam_file, reference_genome, output_dir, input_dir):
    """Run mapDamage on one BAM file, writing into a per-sample output directory.

    The output directory is ``<output_dir>/<bam_file without its extension>``.
    It is created if needed (including missing parents) and reused if it exists.

    Args:
        bam_file: File name (not a full path) of a BAM file inside ``input_dir``.
        reference_genome: Path to the reference FASTA passed to mapDamage ``-r``.
        output_dir: Parent directory for the per-sample output directories.
        input_dir: Directory containing ``bam_file``.

    Raises:
        OSError: if the output directory cannot be created.
        A failing mapDamage run is reported by ``run_cmd``.
    """
    input_path = os.path.join(input_dir, bam_file)
    output_path = os.path.join(output_dir, os.path.splitext(bam_file)[0])

    print(f"Making output dir: {output_path}")
    # Created in-process rather than via a shell `mkdir`. This tolerates an
    # existing directory (so re-runs don't abort), creates missing parents, and
    # is safe for paths containing spaces or shell metacharacters.
    os.makedirs(output_path, exist_ok=True)

    # The command string goes through a shell, so quote every path to keep each
    # one a single, literal argument.
    command = (
        f"mapDamage --merge-libraries -i {shlex.quote(input_path)} "
        f"-r {shlex.quote(reference_genome)} -d {shlex.quote(output_path)}"
    )
    print(f"Running command: {command}")
    run_cmd(command)

def main():
    """Entry point: run mapDamage on every ``.bam`` file in the input directory.

    Parses the command line and collects the regular files ending in ``.bam``
    directly inside ``--input`` (non-recursive), in sorted order. It then runs
    ``run_mapdamage`` on each one in a pool of ``--threads`` worker processes.

    Exits with status 1 if no ``.bam`` files are found. If a worker's command
    failure surfaces as ``subprocess.CalledProcessError``, exits with that
    command's return code.
    """
    args = parse_args()
    threads = args.threads
    reference_genome = args.reference_genome
    input_dir = args.input_dir
    output_dir = args.output_dir

    # Sorted so processing and log order are deterministic (os.listdir order is
    # arbitrary); directories that happen to end in ".bam" are skipped.
    bam_files = sorted(
        f for f in os.listdir(input_dir)
        if f.endswith(".bam") and os.path.isfile(os.path.join(input_dir, f))
    )

    if not bam_files:
        print(f"No .bam files found in input directory: {input_dir}")
        sys.exit(1)

    print(f"Found {len(bam_files)} .bam files. Processing with {threads} threads using reference: {reference_genome}")

    # Ensure the parent output directory exists before workers create
    # per-sample subdirectories inside it.
    os.makedirs(output_dir, exist_ok=True)

    jobs = [(bam_file, reference_genome, output_dir, input_dir) for bam_file in bam_files]
    try:
        with mp.Pool(processes=threads) as pool:
            # chunksize=1: each mapDamage run is long, so hand out one file at a
            # time to keep workers evenly loaded instead of batching files.
            pool.starmap(run_mapdamage, jobs, chunksize=1)
    except subprocess.CalledProcessError as err:
        # Report a failed external command by its own exit status rather than
        # a Python traceback.
        sys.exit(err.returncode)

    print("runMapDamage done.")

# Script entry point. The guard matters here because main() uses
# multiprocessing: under the "spawn"/"forkserver" start methods, worker
# processes import this module, and must not re-run main() when they do.
if __name__ == "__main__":
    main()
