#!/usr/bin/env python3
import os
import sys
import subprocess
import argparse
import multiprocessing as mp

def _positive_int(value):
    """Convert *value* to an int and reject anything below 1 (argparse ``type``)."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be at least 1, got {number}")
    return number


def parse_args():
    """Parse the command-line options of the metaDMG pipeline wrapper.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``, ``threads``,
        ``parallel``, ``names``, ``nodes`` and ``acc2tax``.

    Raises:
        SystemExit: On missing/invalid options (standard argparse behaviour),
            including a ``--parallel`` value below 1.
    """
    parser = argparse.ArgumentParser(description="Run metaDMG-cpp pipeline (lca, dfit, aggregate) on BAM files.")
    parser.add_argument("-i", "--input_dir", required=True, help="Directory containing input BAM files")
    parser.add_argument("-o", "--output_dir", required=True, help="Directory for output files")
    parser.add_argument("-t", "--threads", type=int, default=10, help="Threads per metaDMG process (default: 10)")
    # A worker pool needs at least one process; validating here turns what
    # would be a late ValueError traceback into a normal usage error.
    parser.add_argument("-p", "--parallel", type=_positive_int, default=1, help="Number of BAM files to process in parallel (default: 1)")

    # metaDMG specific data files
    parser.add_argument("--names", required=True, help="Path to names.dmp")
    parser.add_argument("--nodes", required=True, help="Path to nodes.dmp")
    parser.add_argument("--acc2tax", required=True, help="Path to nucl_gb.accession2taxid")

    return parser.parse_args()

def run_pipeline(bam_file, args):
    """Run the metaDMG-cpp lca, dfit and aggregate steps for one BAM file.

    The three commands are executed in order; the first failing step stops
    the sequence for this sample.

    Args:
        bam_file: File name of the BAM file inside ``args.input_dir``.
        args: Parsed command-line namespace; reads ``input_dir``,
            ``output_dir``, ``threads``, ``names``, ``nodes`` and ``acc2tax``.

    Returns:
        A ``(success, message)`` tuple. ``success`` is True when all three
        steps exited with status 0. On failure the message names the sample,
        the failing step and the captured stderr (or the OS error when the
        executable could not be started).
    """
    sample_name = os.path.splitext(bam_file)[0]
    input_path = os.path.join(args.input_dir, bam_file)
    # NOTE(review): the literal "h" prefix on the output name is kept as-is;
    # confirm it is intentional and not a stray character.
    output_prefix = os.path.join(args.output_dir, f"h{sample_name}")

    # 1. LCA Command
    lca_cmd = [
        "metaDMG-cpp", "lca",
        "--names", args.names,
        "--nodes", args.nodes,
        "--acc2tax", args.acc2tax,
        "--sim_score_low", "0.95",
        "--sim_score_high", "1.0",
        "--how_many", "30",
        "--weight_type", "1",
        "--fix_ncbi", "0",
        "--lca_rank", "genus",
        "--threads", str(args.threads),
        "--bam", input_path,
        "--out_prefix", output_prefix,
    ]

    # 2. DFIT Command (consumes the .bdamage.gz written under output_prefix)
    dfit_cmd = [
        "metaDMG-cpp", "dfit",
        f"{output_prefix}.bdamage.gz",
        "--threads", str(args.threads),
        "--names", args.names,
        "--nodes", args.nodes,
        "--showfits", "2",
        "--nopt", "10",
        "--nbootstrap", "20",
        "--doboot", "1",
        "--seed", "1234",
        "--lib", "ds",
        "--out_prefix", output_prefix,
    ]

    # 3. AGGREGATE Command (consumes the .stat.gz and .dfit.gz files)
    agg_cmd = [
        "metaDMG-cpp", "aggregate",
        f"{output_prefix}.bdamage.gz",
        "--names", args.names,
        "--nodes", args.nodes,
        "--lcastat", f"{output_prefix}.stat.gz",
        "--dfit", f"{output_prefix}.dfit.gz",
        "--out_prefix", f"{output_prefix}.agg",
    ]

    # Execute the sequence of commands, stopping at the first failure.
    for cmd in (lca_cmd, dfit_cmd, agg_cmd):
        step = " ".join(cmd[:2])
        try:
            # errors="replace" keeps undecodable bytes in the tool's output
            # from raising UnicodeDecodeError inside the worker.
            subprocess.run(
                cmd, check=True, capture_output=True, text=True, errors="replace"
            )
        except subprocess.CalledProcessError as e:
            err_msg = (e.stderr or "").strip() or "Unknown Error"
            return (False, f"Failed: {sample_name} at step {step}\n  Error: {err_msg}")
        except OSError as e:
            # Executable missing or not runnable: report it as a failed sample
            # instead of letting the exception abort the whole worker pool.
            return (False, f"Failed: {sample_name} at step {step}\n  Error: {e}")

    return (True, f"Success: {sample_name}")

def main():
    """Process every ``.bam`` file of the input directory and print a report.

    Parses the command line, validates the input directory, creates the
    output directory, runs ``run_pipeline`` for each BAM file in a process
    pool and prints the failure messages followed by a summary.

    Raises:
        SystemExit: With status 1 when the input directory does not exist,
            contains no ``.bam`` files, or at least one file failed.
    """
    args = parse_args()

    if not os.path.isdir(args.input_dir):
        print(f"Error: Input directory '{args.input_dir}' does not exist.")
        sys.exit(1)

    os.makedirs(args.output_dir, exist_ok=True)

    # os.listdir returns entries in arbitrary order; sorting makes the
    # processing and report order reproducible between runs.
    bam_files = sorted(f for f in os.listdir(args.input_dir) if f.endswith(".bam"))

    if not bam_files:
        print(f"No .bam files found in {args.input_dir}")
        sys.exit(1)

    print("--- Starting metaDMG Pipeline ---")
    print(f"Input:    {args.input_dir}")
    print(f"Output:   {args.output_dir}")
    print(f"Files:    {len(bam_files)}")
    print(f"Threads:  {args.threads} (per file)")
    print(f"Parallel: {args.parallel} (simultaneous files)")
    print("---------------------------------")

    # Prepare arguments for parallel execution
    # Note: We pass the whole 'args' object to access paths easily
    tasks = [(f, args) for f in bam_files]

    # Never start more workers than there are files to process.
    worker_count = min(args.parallel, len(bam_files))

    # chunksize=1 hands out one sample at a time. The default groups several
    # tasks per worker, which can leave workers idle while another one still
    # holds a queue of long-running samples.
    with mp.Pool(processes=worker_count) as pool:
        results = pool.starmap(run_pipeline, tasks, chunksize=1)

    # Process results
    failure_messages = [message for success, message in results if not success]
    failures = len(failure_messages)
    successes = len(results) - failures

    print("\n--- Processing Report ---")
    for message in failure_messages:
        print(message)

    print("-" * 30)
    print(f"Total Files: {len(bam_files)}")
    print(f"Successful:  {successes}")
    print(f"Failed:      {failures}")

    if failures > 0:
        sys.exit(1)

# Run the pipeline only when executed as a script. The guard also keeps
# multiprocessing workers that re-import this module from calling main().
if __name__ == "__main__":
    main()