import argparse
import multiprocessing as mp
import os
import shlex
import subprocess
import sys

def run_DamageProfiler(bam_file, reference_genome):
    command = f"java -jar /eccs/home/charliep/aDNA-packages/DamageProfiler-1.1-java11.jar -i {bam_file} -o {bam_file}.DamageProfiler -r {reference_genome}"
    print(f"Running: {command}")
    return_code = os.system(command)

    if return_code != 0:
        print(f"Error running command for {bam_file}. Return code: {return_code}")
    return return_code

def main(argv=None):
    """Run DamageProfiler in parallel on every ``*.bam`` file in the CWD.

    Args:
        argv: Command-line arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        0 if there were no BAM files or every run succeeded, 1 if any run
        returned a non-zero status. Invalid arguments exit with status 2 via
        ``argparse``.
    """
    parser = argparse.ArgumentParser(description="Run DamageProfiler on BAM files in parallel.")
    parser.add_argument("-t", "--threads", type=int, default=8,
                        help="Number of threads to use for parallel processing (default: 8)")
    parser.add_argument("-r", "--ref", "--reference", dest="reference_genome", required=True,
                        help="Path to the reference genome FASTA file")
    args = parser.parse_args(argv)

    threads = args.threads
    reference_genome = args.reference_genome

    # mp.Pool raises ValueError for processes < 1; report it as a usage error.
    if threads < 1:
        parser.error(f"--threads must be at least 1 (got {threads})")
    # Every job would fail on a missing reference, so fail fast instead.
    if not os.path.isfile(reference_genome):
        parser.error(f"reference genome not found: {reference_genome}")

    # Sorted for a deterministic processing/report order; isfile() skips
    # anything that is not a regular file (e.g. a directory named "x.bam").
    bam_files = sorted(f for f in os.listdir('.') if f.endswith(".bam") and os.path.isfile(f))

    if not bam_files:
        print("No .bam files found in the current working directory.")
        return 0

    print(f"Found {len(bam_files)} .bam files. Processing with {threads} threads using reference: {reference_genome}")

    # No point starting more worker processes than there are files.
    with mp.Pool(processes=min(threads, len(bam_files))) as pool:
        results = pool.starmap(run_DamageProfiler, [(bam_file, reference_genome) for bam_file in bam_files])

    print("run_DamageProfiler done.")

    failed = [bam for bam, code in zip(bam_files, results) if code != 0]
    if failed:
        print(f"{len(failed)} of {len(bam_files)} file(s) failed: {', '.join(failed)}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())

