#!/usr/bin/env python3
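"""Run pmdtools on every BAM file in a directory, in parallel.

Each input BAM is streamed through ``pmdtools --writesamfield`` so a DS:Z
tag (the PMD score) is written to every read, then re-compressed to a BAM
in the output directory. Reads scoring below the given threshold are
discarded by pmdtools.
"""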
import os
import sys
import subprocess
import argparse
import multiprocessing as mp

def parse_args():
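    """Parse and return the command-line arguments."""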
    parser = argparse.ArgumentParser(description="Run PMD on BAM files to generate DS:Z tags.")
    parser.add_argument("-t", "--threads", type=int, default=4, help="Number of threads to use (default: 4)")
    parser.add_argument("-i", "--input_dir", dest="input_dir", required=True, help="Path to the input directory")
    parser.add_argument("-o", "--output_dir", dest="output_dir", required=True, help="Path to the output directory")
    parser.add_argument("-r", "--threshold", type=int, required=True, help="PMD Threshold (use a low number like -100 to keep ALL reads)")
    return parser.parse_args()

def run_pipeline(bam_file, threshold, output_dir, input_dir):
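    """Pipe one BAM file through pmdtools and write the tagged BAM.

    Returns a (success, message) tuple so the parent process can build a
    per-file report.
    """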
    input_path = os.path.join(input_dir, bam_file)
    output_path = os.path.join(output_dir, os.path.splitext(bam_file)[0] + ".bam")
    # Stream the BAM through pmdtools (which adds the DS:Z PMD score field)
    # and re-compress to BAM. 'set -o pipefail' makes a failure in ANY stage
    # of the pipe surface as a non-zero exit code, not just the final samtools.
    cmd = (
        f'set -o pipefail; '
        f'samtools view -h "{input_path}" | '
        f'pmdtools --threshold {threshold} --writesamfield --header | '
        f'samtools view -b - > "{output_path}"'
    )

    try:
        # pipefail requires bash, so run the pipeline through /bin/bash and
        # capture stderr so failures can be reported per file.
        subprocess.run(cmd, shell=True, check=True, stderr=subprocess.PIPE,
                       executable="/bin/bash")
        return (True, f"Success: {bam_file}")
    except subprocess.CalledProcessError as e:
        # If failed, decode error message
        err_msg = e.stderr.decode('utf-8').strip() if e.stderr else "Unknown Error"
        return (False, f"Failed: {bam_file}\n  Error: {err_msg}")

def main():
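    """Validate the input/output directories and process all BAMs in parallel."""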
    args = parse_args()
    
    # Check inputs
    if not os.path.isdir(args.input_dir):
        print(f"Error: Input directory '{args.input_dir}' does not exist.")
        sys.exit(1)
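
    # The output BAMs keep the input file names, so writing into the input
    # directory would overwrite the source files while they are being read.
    if os.path.abspath(args.input_dir) == os.path.abspath(args.output_dir):
        print("Error: Output directory must be different from the input directory.")
        sys.exit(1)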

    # Create the output directory if it does not already exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Gather BAM files
    bam_files = [f for f in os.listdir(args.input_dir) if f.endswith(".bam")]

    if not bam_files:
        print(f"No .bam files found in {args.input_dir}")
        sys.exit(1)

    print(f"--- Starting PMDtools Pipeline ---")
    print(f"Input:     {args.input_dir}")
    print(f"Output:    {args.output_dir}")
    print(f"Files:     {len(bam_files)}")
    print(f"Threads:   {args.threads}")
    print(f"Threshold: {args.threshold}")
    print("----------------------------------")

    # Prepare arguments for parallel execution
    tasks = [(f, args.threshold, args.output_dir, args.input_dir) for f in bam_files]

    # Execute in parallel
    with mp.Pool(processes=args.threads) as pool:
        results = pool.starmap(run_pipeline, tasks)

    # Process results
    successes = 0
    failures = 0

    print("\n--- Processing Report ---")
    for success, message in results:
        if success:
            successes += 1
        else:
            failures += 1
            print(message) # Always print failures

    print("-" * 30)
    print(f"Total Files: {len(bam_files)}")
    print(f"Successful:  {successes}")
    print(f"Failed:      {failures}")

    if failures > 0:
        sys.exit(1)

if __name__ == "__main__":
    main()