#!/usr/bin/env python3
# Merge per-sample BAM files by horizon ID (horizon ID 0 copies them unchanged)
# Charlie P, Porter L

# Example:
# mergeByHorizon.py -i /input/dir -o /output/dir -s 1_1_S1,1_2_S2,1_3_S3 -t 8 -z 1

import argparse
import multiprocessing
import os
import shlex
import subprocess
import sys
from functools import partial
from pathlib import Path

def parse_args(argv=None):
    """Parse the command-line options of mergeByHorizon.

    Args:
        argv: Argument list to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input_dir``, ``output_dir``,
        ``sample_ids`` (comma-separated string), ``horizon_id`` (string)
        and ``threads`` (int).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            ``--threads`` is not an integer.
    """
    parser = argparse.ArgumentParser(
        description="Merge per-sample BAM files by horizon ID"
    )
    parser.add_argument("-i", "--input_dir", required=True, help="Path to input files")
    parser.add_argument("-o", "--output_dir", required=True, help="Path to output directory")
    parser.add_argument("-s", "--sample_ids", required=True, help="Sample IDs")
    parser.add_argument("-z", "--horizon_id", required=True, help="Horizon ID for output")
    # type=int so a bad value is reported as a usage error by argparse
    # instead of a ValueError traceback later on.
    parser.add_argument("-t", "--threads", required=True, type=int, help="CPU Threads to use")
    return parser.parse_args(argv)

def run_cmd(cmd):
    """Run one shell command line and return its captured standard output.

    Args:
        cmd: Command line string. It is executed through the shell
            (``shell=True``), so callers must quote any paths they embed.

    Returns:
        The command's stdout as text.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero
            status; the exception carries the return code, stdout and stderr.
    """
    print(f"Running command: {cmd}")
    result = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        print(f"Command failed: {cmd}")
        print(result.stderr)
        # Raise instead of calling exit(): this function runs inside a
        # multiprocessing.Pool worker, and SystemExit is not an Exception,
        # so the worker would die without reporting a result and pool.map()
        # would block forever. A regular exception is pickled back to the
        # parent and re-raised from pool.map(). Arguments are positional so
        # they survive pickling.
        raise subprocess.CalledProcessError(
            result.returncode, cmd, result.stdout, result.stderr
        )

    return result.stdout

def create_merge_commands(input_dir, output_dir, sample_ids, horizon_id):
    """Build the ``samtools merge`` command for one horizon.

    Args:
        input_dir: Directory containing one ``<sample_id>.bam`` per sample.
        output_dir: Directory the merged BAM is written to.
        sample_ids: Iterable of sample IDs whose BAMs are merged, in order.
        horizon_id: Horizon ID; the output file is
            ``<output_dir>/<horizon_id>.bam``.

    Returns:
        A one-element list holding the shell command line.
    """
    output_path = os.path.join(output_dir, f"{horizon_id}.bam")
    input_paths = [os.path.join(input_dir, f"{sample_id}.bam") for sample_id in sample_ids]

    # The command is run with shell=True, so quote every path; otherwise
    # directories with spaces or shell metacharacters are split/interpreted.
    input_string = " ".join(shlex.quote(path) for path in input_paths)
    cmd = f"samtools merge -o {shlex.quote(output_path)} {input_string}"

    return [cmd]

def create_copy_commands(input_dir, output_dir, sample_ids):
    """Build ``cp`` commands copying the samples' BAM files unchanged.

    A regular file in ``input_dir`` is selected when its extension is
    ``.bam`` and its name contains any of ``sample_ids`` as a substring.

    Args:
        input_dir: Directory to scan for BAM files.
        output_dir: Destination directory; file names are kept.
        sample_ids: Iterable of sample IDs to look for.

    Returns:
        One shell ``cp`` command per selected file. Each file appears at most
        once even if several sample IDs match it, so no two pool workers copy
        onto the same destination concurrently.
    """
    # List the directory once instead of once per sample ID.
    entries = sorted(os.listdir(input_dir))
    selected = []
    seen = set()
    for sample_id in sample_ids:
        for name in entries:
            if name in seen or os.path.splitext(name)[-1] != ".bam":
                continue
            # NOTE(review): substring match, e.g. "1_1_S1" also matches
            # "1_1_S10.bam"; confirm whether an exact "<sample_id>.bam" match
            # is intended.
            if sample_id in name and os.path.isfile(os.path.join(input_dir, name)):
                seen.add(name)
                selected.append(name)

    # Commands run with shell=True, so quote both paths.
    return [
        f"cp {shlex.quote(os.path.join(input_dir, name))} "
        f"{shlex.quote(os.path.join(output_dir, name))}"
        for name in selected
    ]

def main():
    """Command-line entry point: merge or copy BAM files for one horizon.

    A non-zero horizon ID merges the listed samples' BAMs into
    ``<output_dir>/<horizon_id>.bam``. Horizon ID 0 copies the matching
    sample BAMs into ``output_dir`` unchanged. The commands run in a process
    pool sized by ``--threads``, capped at the CPU count. If a command fails,
    the process exits with that command's return code.
    """
    # Parse all input from command line
    # (or from wrapper in workflow)
    args = parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir
    # Drop empty entries (e.g. from "a,b," or "a,,b"). An empty sample ID is
    # a substring of every file name and would select every BAM in copy mode.
    sample_ids = [s.strip() for s in args.sample_ids.split(",") if s.strip()]
    horizon_id = args.horizon_id
    threads = int(args.threads)

    if not sample_ids:
        sys.exit("No sample IDs given (-s/--sample_ids)")

    # Build Commands
    if int(horizon_id) != 0:
        command_queue = create_merge_commands(input_dir, output_dir, sample_ids, horizon_id)
    else:
        command_queue = create_copy_commands(input_dir, output_dir, sample_ids)

    if not command_queue:
        print("No matching BAM files found; nothing to do.")

    # Neither `samtools merge` nor `cp` creates a missing output directory.
    os.makedirs(output_dir, exist_ok=True)

    # Create worker pool; Pool(processes=0) raises ValueError, so keep >= 1.
    num_workers = max(1, min(multiprocessing.cpu_count(), threads))
    try:
        with multiprocessing.Pool(processes=num_workers) as pool:
            pool.map(run_cmd, command_queue)
    except subprocess.CalledProcessError as err:
        # The failing command and its stderr were already printed by run_cmd.
        sys.exit(err.returncode)
    print("** mergeByHorizon done! **")

if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
