import os
import re
import shlex
import sys
from collections import defaultdict

def group_and_sort_bam_files(directory, output_file):
    """Write one ``samtools merge`` command per group of BAM files.

    Every ``*.bam`` entry in ``directory`` is grouped by the pair
    (first character of the file name, suffix), where the suffix is the text
    from the first ``NW``/``NC``/``NT`` marker up to, but excluding, the
    trailing ``.bam``. Files without such a marker share the placeholder
    suffix ``XXX-Error`` so they stay visible in the generated script instead
    of being dropped silently.

    Each group becomes one line of ``output_file``::

        samtools merge -o tmp/<first char>_<suffix, "." -> "_">.bam <inputs>

    Input files are written as bare names (no directory component) and the
    merged file goes to a relative ``tmp/`` path.
    NOTE(review): this presumably means the generated script is run from
    inside ``directory`` with an existing ``tmp/`` sub-directory -- confirm.

    Args:
        directory: Directory to scan for ``.bam`` files (not recursive).
        output_file: Path of the text file that receives the commands; it is
            overwritten if it already exists.

    Returns:
        None. A status line is printed on success; if the output file cannot
        be written an error line is printed instead and nothing is raised.
    """
    suffix_pattern = re.compile(r"(NW|NC|NT).*\.bam$")

    def get_suffix(filename):
        """Return the NW/NC/NT-anchored suffix without ".bam", or a sentinel."""
        match = suffix_pattern.search(filename)
        if match:
            return match.group(0)[:-4]  # Exclude ".bam"
        return "XXX-Error"

    def sort_key(filename):
        # The full name is the final tie-breaker: os.listdir() returns entries
        # in arbitrary order, so without it files sharing a suffix and a
        # 5-character prefix would be emitted in a platform-dependent order.
        return (get_suffix(filename), filename[:-4][:5], filename)

    # A file named exactly ".bam" has an empty stem and therefore no first
    # character to group on; skip it rather than fail on an empty index.
    bam_files = [
        f for f in os.listdir(directory)
        if f.endswith(".bam") and len(f) > len(".bam")
    ]

    # dicts preserve insertion order, so groups come out in sorted order.
    grouped_files = defaultdict(list)
    for filename in sorted(bam_files, key=sort_key):
        grouped_files[(filename[0], get_suffix(filename))].append(filename)

    merge_commands = []
    for (first_char, suffix), group in grouped_files.items():
        output_bam_name = f"{first_char}_{suffix.replace('.', '_')}.bam"
        # Quote every path: the lines are meant to be executed by a shell, and
        # names containing spaces or metacharacters would otherwise be split
        # or interpreted. Ordinary names are left unchanged by shlex.quote.
        target = shlex.quote(f"tmp/{output_bam_name}")
        inputs = " ".join(shlex.quote(name) for name in group)
        merge_commands.append(f"samtools merge -o {target} {inputs}")

    try:
        with open(output_file, 'w') as outfile:
            outfile.writelines(command + '\n' for command in merge_commands)
    except OSError:
        print(f"Error: Could not write to the output file: {output_file}")
    else:
        print(f"samtools merge commands written to: {output_file}")

if __name__ == "__main__":
    # Command-line entry point: expects exactly two positional arguments,
    # the directory to scan and the file that will receive the commands.
    cli_args = sys.argv[1:]
    if len(cli_args) != 2:
        print("Usage: python makeBAMmergeScript.py <directory> <output_file>")
        sys.exit(1)

    directory_path, output_filename = cli_args

    # Refuse to continue when the first argument is not an existing directory.
    directory_exists = os.path.isdir(directory_path)
    if not directory_exists:
        print(f"Error: Directory not found: {directory_path}")
        sys.exit(1)

    group_and_sort_bam_files(directory_path, output_filename)