#!/usr/bin/env python3
import argparse
import glob
import multiprocessing as mp
import os
import shlex
import shutil
import subprocess
import sys

def parse_args():
    """Parse the wrapper's own command-line options.

    Options the wrapper intercepts (``-x``, ``--threads``, ``-S``) are declared
    here so they are *not* forwarded to bowtie2 verbatim. Every option this
    parser does not recognise is returned separately, for pass-through to
    bowtie2.

    Returns:
        tuple: ``(namespace, extra_args)`` as produced by
        ``ArgumentParser.parse_known_args``; ``extra_args`` is a list of str.
    """
    cli = argparse.ArgumentParser(
        description="Run bowtie2 in parallel with a split database"
    )
    cli.add_argument(
        "--input_pair",
        action="append",
        help="Input pairs (multiple allowed)",
        required=True,
    )
    cli.add_argument("-x", "--database", help="Database", required=True)
    # Plain required string options, declared in their original order.
    for flags in (("--ramdisk",), ("--threads",), ("-S", "--output")):
        cli.add_argument(*flags, required=True)
    return cli.parse_known_args()

def run_cmd(cmd):
    """Run a single shell command line and return its captured stdout.

    The string is executed with ``shell=True``, so callers must quote any
    paths they interpolate into it.

    Args:
        cmd: Shell command line to execute.

    Returns:
        str: The command's standard output, decoded as text.

    Exits:
        On a non-zero return code, prints the command and its captured stderr
        to this process's stderr and terminates via ``sys.exit`` with the
        command's return code (i.e. raises ``SystemExit``).
    """
    # flush so progress lines from pool workers are written promptly.
    print(f"Running: {cmd}", flush=True)
    result = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Command failed: {cmd}", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        # sys.exit, not the site-module `exit` builtin: the latter may be
        # absent (python -S, frozen builds) and also closes sys.stdin.
        sys.exit(result.returncode)

    return result.stdout

def run_bowtie(database, samplepair, bowtie_args, output_parent, threads=8):
    """Align one read pair against one bowtie2 index chunk.

    Writes ``<output_parent>/<sample>/<sample>_<chunk>.sam``.

    Args:
        database: bowtie2 index prefix passed to ``-x``. The text after the
            last ``_`` in its *basename* is used as the chunk tag in the
            output file name.
        samplepair: ``"R1_path,R2_path"``; only the first two comma-separated
            fields are used. The sample name is the R1 file name up to its
            first ``.`` with the last three characters dropped (presumably a
            read suffix such as ``_R1`` -- TODO confirm naming scheme).
        bowtie_args: List of extra bowtie2 arguments. The literal tokens
            ``SAMPLE`` and ``HORIZON`` are replaced by the sample name and by
            the part of the sample name before its first ``_``.
        output_parent: Directory under which the per-sample directory is made.
        threads: Value passed to bowtie2 ``--threads`` (default 8).

    Raises:
        RuntimeError: if bowtie2 exits with a non-zero status.
    """
    splitpair = samplepair.split(",")
    read1, read2 = splitpair[0], splitpair[1]

    # Get info for tagging
    sample_name = os.path.basename(read1).split(".")[0][:-3]
    horizon = sample_name.split("_")[0]
    # Only look at the index basename: an "_" elsewhere in the path (e.g. in
    # the RAM-disk directory) would otherwise put "/" into the chunk tag.
    chunk = os.path.basename(database).split("_")[-1]

    # Fix bowtie args for sample; quote each one because the command runs
    # through the shell (the original argv boundaries must be preserved).
    extra = [
        shlex.quote(arg.replace("SAMPLE", sample_name).replace("HORIZON", horizon))
        for arg in bowtie_args
    ]

    # Create output dir; exist_ok avoids a check-then-create race.
    sample_output_dir = os.path.join(output_parent, sample_name)
    sample_output_file = os.path.join(sample_output_dir, f"{sample_name}_{chunk}.sam")
    os.makedirs(sample_output_dir, exist_ok=True)

    print(f"=== RUNNING BOWTIE ({sample_output_file})")
    command = " ".join(
        [
            "bowtie2",
            "-x", shlex.quote(database),
            "-S", shlex.quote(sample_output_file),
            "--threads", str(threads),
            *extra,
            "-1", shlex.quote(read1),
            "-2", shlex.quote(read2),
        ]
    )
    try:
        run_cmd(command)
    except SystemExit as exc:
        # run_cmd() calls sys.exit() on failure. multiprocessing.Pool workers
        # only report Exception subclasses back to the parent; a SystemExit
        # kills the worker and the parent's starmap() waits forever for the
        # lost task. Re-raise as a regular exception so the failure propagates.
        raise RuntimeError(
            f"bowtie2 failed with exit code {exc.code}: {command}"
        ) from None


if __name__ == "__main__":
    # Pipeline:
    #   1. For each database chunk, stage its bowtie2 index files on the RAM
    #      disk and align every input pair against it in parallel.
    #   2. For each sample, turn the per-chunk SAMs into coordinate-sorted
    #      BAMs and merge them into <output>/<sample>.bam.
    args, bowtie_args = parse_args()

    # Each sub-directory of --database is one chunk. Its index prefix is
    # assumed to equal the directory name (see bowtie_database_path below).
    # Plain files are ignored. Sorting makes the order deterministic, since
    # os.listdir() order is arbitrary.
    database_chunks = sorted(
        os.path.join(args.database, d)
        for d in os.listdir(args.database)
        if os.path.isdir(os.path.join(args.database, d))
    )

    # Do actual alignments
    for chunk in database_chunks:
        # Load chunk into RAM disk
        print(f"= COPYING {chunk} into {args.ramdisk}...")
        index_files = glob.glob(os.path.join(chunk, "*.bt2*"))  # either bt2 or bt2l
        if not index_files:
            print(f"= SKIPPING {chunk}: no bowtie2 index files found")
            continue

        staged = []
        try:
            for f in index_files:
                # copy2 returns the destination path. Record it so cleanup
                # removes exactly the files staged for this chunk, and nothing
                # else that happens to live on the RAM disk.
                staged.append(shutil.copy2(f, args.ramdisk))

            bowtie_database_path = os.path.join(args.ramdisk, os.path.basename(chunk))

            # Run each sample against current chunk.
            # NOTE(review): --threads is parsed but unused here. Each of the 4
            # workers runs bowtie2 with run_bowtie's own thread count (8).
            jobs = [
                (bowtie_database_path, samplepair, bowtie_args, args.output)
                for samplepair in args.input_pair
            ]
            with mp.Pool(processes=4) as pool:
                pool.starmap(run_bowtie, jobs)
        finally:
            # Clean up RAM disk even if an alignment failed, so a crashed run
            # does not leave the index occupying memory.
            print(f"= CLEANING up {chunk} into {args.ramdisk}...")
            for f in staged:
                os.remove(f)
        print(f"= DONE with {chunk}!\n")

    # For each sample (still split per chunk at this point):
    # SAM -> BAM -> coordinate-sorted BAM, then merge into one BAM per sample.
    for sample_dir in sorted(os.listdir(args.output)):
        sample_path = os.path.join(args.output, sample_dir)
        # Only per-sample directories hold chunk results. Plain files (e.g. a
        # merged BAM from an earlier run) must not be treated as samples.
        if not os.path.isdir(sample_path):
            continue
        print(f"= MERGING sample {sample_dir}")

        sorted_bams = []
        for sam_path in sorted(glob.glob(os.path.join(sample_path, "*.sam"))):
            # splitext touches only the extension. str.replace(".sam", ...)
            # would also rewrite a ".sam" appearing earlier in the path.
            stem = os.path.splitext(sam_path)[0]
            bam_path = f"{stem}.bam"
            sorted_path = f"{stem}.sort.bam"

            # Convert SAM to BAM
            run_cmd(f"samtools view -bS -o {shlex.quote(bam_path)} {shlex.quote(sam_path)}")
            os.remove(sam_path)

            # Sort BAM by coordinate
            run_cmd(f"samtools sort -o {shlex.quote(sorted_path)} {shlex.quote(bam_path)}")
            os.remove(bam_path)

            # Track exactly what this run produced, so stale *.bam / *.sort.bam
            # leftovers in the directory are never re-sorted or merged twice.
            sorted_bams.append(sorted_path)

        if not sorted_bams:
            # samtools merge fails with no inputs; leave the directory alone.
            print(f"= SKIPPING sample {sample_dir}: no SAM files found")
            continue

        # Merge all chunk results for this sample
        output_path = os.path.join(args.output, f"{sample_dir}.bam")
        merge_inputs = " ".join(shlex.quote(p) for p in sorted_bams)
        run_cmd(f"samtools merge {shlex.quote(output_path)} {merge_inputs}")
        shutil.rmtree(sample_path)

