import argparse
import multiprocessing as mp
import os
import shlex
import subprocess
import sys

def run_mapdamage(bam_file, reference_genome):
    """Run mapDamage on a single BAM file against a reference genome.

    The command is executed directly (no shell), so file names containing
    spaces or shell metacharacters are passed to mapDamage verbatim instead
    of being word-split or interpreted by ``/bin/sh``.

    Args:
        bam_file: Path to the input BAM file (passed to ``-i``).
        reference_genome: Path to the reference FASTA file (passed to ``-r``).

    Returns:
        The process exit status: 0 on success, non-zero on failure. A
        negative value means the process was killed by that signal; 127
        means the ``mapDamage`` executable could not be launched at all.
    """
    command = [
        "mapDamage", "--log-level", "ERROR", "--merge-libraries",
        "-i", str(bam_file), "-r", str(reference_genome),
    ]
    # Quote each argument so the echoed command can be copy-pasted into a shell.
    printable = " ".join(shlex.quote(arg) for arg in command)
    print(f"Running: {printable}")

    try:
        return_code = subprocess.run(command).returncode
    except OSError as err:
        # e.g. mapDamage not on PATH or not executable. Report it like a failed
        # run rather than raising, so one bad launch does not abort the whole
        # Pool.starmap batch in the caller.
        print(f"Error launching mapDamage for {bam_file}: {err}", file=sys.stderr)
        return 127

    if return_code != 0:
        print(f"Error running command for {bam_file}. Return code: {return_code}",
              file=sys.stderr)
    return return_code

def main(argv=None):
    """Run mapDamage in parallel on every ``.bam`` file in the current directory.

    Args:
        argv: Command-line arguments excluding the program name; when None,
            argparse reads ``sys.argv[1:]``.

    Returns:
        0 when every mapDamage run succeeded (or there were no BAM files),
        1 when at least one run failed.

    Raises:
        SystemExit: On invalid command-line arguments (via ``parser.error``).
    """
    parser = argparse.ArgumentParser(description="Run mapDamage on BAM files in parallel.")
    parser.add_argument("-t", "--threads", type=int, default=4,
                        help="Number of threads to use for parallel processing (default: 4)")
    parser.add_argument("-r", "--ref", "--reference", dest="reference_genome", required=True,
                        help="Path to the reference genome FASTA file")
    args = parser.parse_args(argv)

    threads = args.threads
    reference_genome = args.reference_genome
    # mp.Pool raises ValueError for processes < 1; turn that into a clean usage error.
    if threads < 1:
        parser.error(f"--threads must be a positive integer, got {threads}")

    # Sorted so processing and log order are reproducible (os.listdir order is arbitrary).
    bam_files = sorted(f for f in os.listdir('.') if f.endswith(".bam"))

    if not bam_files:
        print("No .bam files found in the current working directory.")
        return 0

    print(f"Found {len(bam_files)} .bam files. Processing with {threads} threads using reference: {reference_genome}")

    # Note: "threads" are worker processes here (multiprocessing.Pool).
    with mp.Pool(processes=threads) as pool:
        results = pool.starmap(run_mapdamage, [(bam_file, reference_genome) for bam_file in bam_files])

    print("runMapDamage done.")

    # starmap preserves input order, so results[i] belongs to bam_files[i].
    failed = [bam for bam, rc in zip(bam_files, results) if rc != 0]
    if failed:
        print(f"mapDamage failed for {len(failed)} of {len(bam_files)} file(s): {', '.join(failed)}",
              file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())

