#!/usr/bin/env python3
# Split BAM file by reference number (Now in Python!)
# Porter L

# Example:
# splitByRef.py -i my_input_file.bam -o my_output_dir -s sample_id -t 8

import argparse
import os
import multiprocessing
from functools import partial
import pysam
import subprocess

def parse_args():
    """Parse and validate the command-line options.

    Returns:
        argparse.Namespace with ``input_bam``, ``output_dir`` and
        ``sample_id`` as strings and ``threads`` as an int >= 1.

    Exits through argparse (usage message, status 2) when a required option
    is missing or ``--threads`` is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Split BAM file by reference number")
    parser.add_argument("-i", "--input_bam", required=True, help="Path to input BAM file")
    parser.add_argument("-o", "--output_dir", required=True, help="Path to output directory")
    parser.add_argument("-s", "--sample_id", required=True, help="Sample ID")
    # type=int turns a non-numeric value into a usage error instead of a
    # traceback later on when the worker pool is created.
    parser.add_argument("-t", "--threads", required=True, type=int, help="CPU Threads to use")
    args = parser.parse_args()
    if args.threads < 1:
        parser.error("--threads must be a positive integer")
    return args

def get_ref_list(input_bam):
    """Return the names of the references in a BAM file that have reads.

    The counts come from the BAM index via pysam's
    ``AlignmentFile.get_index_statistics()``.

    Args:
        input_bam: Path to the BAM file to inspect.

    Returns:
        A list of reference names whose read total is non-zero. An empty list
        is returned (and the error printed) if the file or its index cannot
        be read.
    """
    try:
        # The context manager closes the file even if reading the index fails.
        with pysam.AlignmentFile(input_bam, "rb") as bam_file:
            # Each entry is (contig, mapped, unmapped, total); total is the
            # sum of the two read counts, so it alone decides emptiness.
            return [
                contig
                for contig, _mapped, _unmapped, total in bam_file.get_index_statistics()
                if total != 0
            ]
    except Exception as e:
        print(f"Error getting reference list: {e}")
        return []

def process_ref(input_bam, output_dir, sample_id, ref):
    """Write the reads of one reference into their own BAM file.

    The result is ``<output_dir>/<sample_id>-<ref>.bam``. Its header keeps the
    @HD, @PG, @RG and @CO lines of ``input_bam`` plus the @SQ line whose SN
    field is exactly ``ref``. Errors are printed instead of raised, so the
    function can be mapped over many references without one failure stopping
    the rest.

    Args:
        input_bam: Path to the source BAM (the region query in step 2 relies
            on samtools being able to use an index for it).
        output_dir: Existing directory that receives the output and the
            temporary files.
        sample_id: Prefix used in every file name written here.
        ref: Reference ID (accession # or TaxID) to extract.

    Returns:
        None.
    """
    # NOTE(review): `samtools reheader` swaps the header only. Confirm that
    # the numeric reference IDs stored in the copied records are still valid
    # once the other @SQ lines have been dropped.

    # Resolve every path up front so the cleanup below can always refer to
    # them, whichever step fails.
    temp_header_path = os.path.join(output_dir, f"{sample_id}-{ref}.header.sam")
    temp_reads_bam_path = os.path.join(output_dir, f"{sample_id}-{ref}.temp.bam")
    out_bam_path = os.path.join(output_dir, f"{sample_id}-{ref}.bam")

    output_started = False
    succeeded = False
    try:
        # Step 1: Build the minimal header. The filtering is done here rather
        # than with a shell pipeline so that paths with spaces and reference
        # names containing regex characters ('.', '|') are handled literally.
        header_bytes = subprocess.run(
            ["samtools", "view", "-H", input_bam],
            stdout=subprocess.PIPE,
            check=True,
        ).stdout
        passthrough_tags = (b"@HD", b"@PG", b"@RG", b"@CO")
        sn_field = f"SN:{ref}".encode()
        kept_lines = [
            line
            for line in header_bytes.splitlines()
            if line.startswith(passthrough_tags)
            or (line.startswith(b"@SQ") and sn_field in line.split(b"\t")[1:])
        ]
        if not kept_lines:
            raise ValueError(f"no usable header lines found in {input_bam}")
        with open(temp_header_path, "wb") as header_handle:
            header_handle.write(b"\n".join(kept_lines) + b"\n")

        # Step 2: Use samtools view to create a temporary BAM with only the
        # reads for this reference
        subprocess.run(
            ["samtools", "view", "-b", "-o", temp_reads_bam_path, input_bam, ref],
            check=True,
        )

        # Step 3: Use samtools reheader to replace the header of the
        # temporary BAM; the result is streamed to the final output file.
        output_started = True
        with open(out_bam_path, "wb") as out_handle:
            subprocess.run(
                ["samtools", "reheader", temp_header_path, temp_reads_bam_path],
                stdout=out_handle,
                check=True,
            )

        succeeded = True
        print(f"Successfully processed {ref}, saved to {out_bam_path}")

    except subprocess.CalledProcessError as e:
        print(f"Command failed during processing of {ref}: {e}")
    except Exception as e:
        print(f"Error processing reference {ref}: {e}")
    finally:
        # Step 4: Clean up temporary files on every path. A half-written
        # output is removed too, but only if this call created it.
        leftovers = [temp_reads_bam_path, temp_header_path]
        if output_started and not succeeded:
            leftovers.append(out_bam_path)
        for path in leftovers:
            if os.path.exists(path):
                os.remove(path)

import subprocess
import os

def fix_bam_header(input_bam, output_bam_fixed):
    """
    Fixes a BAM header by generating a new one from idxstats and reheadering the file.

    The new header holds a fixed ``@HD VN:1.6 SO:coordinate`` line followed by
    one @SQ line per reference that ``samtools idxstats`` reports with at
    least one read; the ``*`` row is skipped. No other line of the original
    header is carried over.

    Args:
        input_bam: Path to the BAM to read (idxstats is run on it).
        output_bam_fixed: Path the reheadered BAM is written to. A temporary
            ``<output_bam_fixed>.fixed_header.sam`` is created next to it and
            removed again before returning.

    Returns:
        True on success, False if a samtools command or any other step
        failed (the error is printed).
    """
    # NOTE(review): the @HD line asserts coordinate sorting without checking
    # the input, and `samtools reheader` swaps the header only. Confirm that
    # the reference IDs stored in the records stay valid after @SQ lines are
    # dropped.
    temp_fixed_header_path = f"{output_bam_fixed}.fixed_header.sam"
    try:
        # Step 1: Get all reference IDs and their lengths from the BAM index
        idxstats_output = subprocess.run(
            ["samtools", "idxstats", input_bam],
            capture_output=True,
            text=True,
            check=True,
        ).stdout

        # Collect header lines in a list and join once at the end.
        header_lines = ["@HD\tVN:1.6\tSO:coordinate"]
        for line in idxstats_output.strip().splitlines():
            parts = line.split('\t')
            ref_name = parts[0]
            ref_length = parts[1]
            read_count = int(parts[2]) + int(parts[3])

            if read_count > 0 and ref_name != '*':
                header_lines.append(f"@SQ\tSN:{ref_name}\tLN:{ref_length}")

        # Step 2: Write the new header to a temporary file
        with open(temp_fixed_header_path, "w") as f:
            f.write("\n".join(header_lines) + "\n")

        # Step 3: Reheader the original BAM file
        print(f"Fixing header for {input_bam}...")
        with open(output_bam_fixed, "wb") as fixed_handle:
            subprocess.run(
                ["samtools", "reheader", temp_fixed_header_path, input_bam],
                stdout=fixed_handle,
                check=True,
            )
        print(f"Header fixed. New file saved to {output_bam_fixed}")
        return True
    except subprocess.CalledProcessError as e:
        # stderr is only captured for the idxstats call; fall back to the
        # exception text so the message never reads "None".
        print(f"Error fixing BAM header: {e.stderr or e}")
        return False
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return False
    finally:
        # Remove the temporary header on success and on failure alike.
        if os.path.exists(temp_fixed_header_path):
            os.remove(temp_fixed_header_path)

# Script entry point: fix the BAM header, then split the reads by reference
def main():
    """Fix the input BAM's header, then split it into one BAM per reference.

    Reads the options from the command line, writes
    ``<output_dir>/<sample_id>.fixed_header.bam``, indexes it, and runs
    ``process_ref`` for every reference with reads in a pool of ``threads``
    worker processes. Problems are reported with a printed message and an
    early return.
    """
    args = parse_args()
    input_bam = args.input_bam
    output_dir = args.output_dir
    sample_id = args.sample_id
    threads = int(args.threads)

    # Every file below is written into output_dir, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    # First, fix the header of the input BAM file
    fixed_bam_path = os.path.join(output_dir, f"{sample_id}.fixed_header.bam")
    if not fix_bam_header(input_bam, fixed_bam_path):
        print("Failed to fix BAM header. Exiting.")
        return

    # The reheadered BAM is a new file without an index, but both the index
    # statistics read by get_ref_list and the per-reference region queries in
    # process_ref need one.
    try:
        subprocess.run(["samtools", "index", fixed_bam_path], check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Failed to index {fixed_bam_path}: {e}")
        return

    # Now, get the refs from the fixed BAM and proceed
    refs = get_ref_list(fixed_bam_path)
    if not refs:
        print(f"No references with reads found in {fixed_bam_path}.")

    # Each worker handles one reference and reads from the fixed BAM.
    with multiprocessing.Pool(processes=threads) as pool:
        pool.map(partial(process_ref, fixed_bam_path, output_dir, sample_id), refs)
    print("** SplitByRef done! **")

# Run the splitter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()