#!/usr/bin/env python3
# Filter BAM alignments, keeping only references whose mapped-read count meets a threshold
# Porter L

# Example:
# filterPresence.py  -i /input/dir -o /output/dir -p 10 -t 8

import argparse, subprocess, os, sys, multiprocessing
from functools import partial
from pathlib import Path
import pysam

def parse_args(argv=None):
    """Parse the command-line options for the presence filter.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``input_dir`` and ``output_dir`` (str) and
        ``presence_threshold`` and ``threads`` (int).
    """
    parser = argparse.ArgumentParser(
        description="Filter BAM files, keeping only references that reach a read-count threshold."
    )
    parser.add_argument("-i", "--input_dir", required=True, help="Path to input files.")
    parser.add_argument("-o", "--output_dir", required=True, help="Path to output directory.")
    # type=int makes argparse reject non-numeric values with a usage error
    # instead of letting them fail later with a traceback.
    parser.add_argument("-p", "--presence_threshold", required=True, type=int,
                        help="Number of reads required to establish presence.")
    parser.add_argument("-t", "--threads", required=True, type=int,
                        help="CPU Threads to use.")
    return parser.parse_args(argv)

def run_cmd(cmd):
    """Run a single shell command line and return its standard output.

    The command is executed through the shell (``shell=True``), so ``cmd``
    is interpreted as a shell string; callers are responsible for quoting
    any paths or arguments they interpolate into it.

    Args:
        cmd: Command line to execute.

    Returns:
        The command's captured stdout as text.

    Raises:
        SystemExit: With the command's return code when it exits non-zero.
    """
    print(f"Running command: {cmd}")
    result = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        print(f"Command failed: {cmd}")
        print(result.stderr)
        # sys.exit rather than the bare exit() helper: exit() is injected by
        # the site module for interactive use and is not guaranteed to exist.
        sys.exit(result.returncode)

    return result.stdout

def find_files_from_dir(dir):
    """Return the paths of the entries in *dir* whose extension is ".bam".

    Only the immediate contents of *dir* are inspected. Each returned path
    is *dir* joined with the entry name, in the order ``os.listdir`` yields
    the entries.
    """
    return [
        os.path.join(dir, entry)
        for entry in os.listdir(dir)
        if os.path.splitext(entry)[-1] == ".bam"
    ]

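# Earlier whole-file version, kept for reference only: it copied a BAM to the
# output directory when its total read count exceeded the threshold.
# def process_single_file(theta, output_dir, file):
#     file_read_count = int(run_cmd(f"samtools view {file} | wc -l"))
#     if file_read_count > int(theta):
#         run_cmd(f"cp {file} {output_dir}")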

def process_single_file(theta, output_dir, file):
    """Write a filtered copy of one BAM file, dropping low-count references.

    The input is read twice: once to count mapped reads per reference, and
    once to copy the alignments of every reference whose count is at least
    the threshold. Unmapped reads are never written. The output keeps the
    input's header and is written to ``output_dir`` under the input's
    basename.

    Args:
        theta: Minimum number of mapped reads a reference needs to be kept
            (int, or a string that ``int()`` accepts).
        output_dir: Directory the filtered BAM is written into.
        file: Path of the input BAM file.

    Raises:
        ValueError: If ``theta`` is not an integer value, or if the output
            path resolves to the input file itself.
    """
    # Convert once up front instead of once per reference.
    threshold = int(theta)

    out_bam = os.path.join(output_dir, os.path.basename(file))

    # Opening the output in "wb" mode would truncate the input while it is
    # still being read, so refuse to write a file onto itself.
    if os.path.realpath(out_bam) == os.path.realpath(file):
        raise ValueError(
            f"Output path {out_bam} is the same file as the input; "
            "choose a different output directory."
        )

    # First pass: count mapped reads per reference. The context manager
    # closes the handle even if iteration raises.
    with pysam.AlignmentFile(file, "rb") as bamfile:
        read_counts = {ref: 0 for ref in bamfile.references}
        for read in bamfile.fetch(until_eof=True):
            if not read.is_unmapped:
                ref_name = bamfile.get_reference_name(read.reference_id)
                read_counts[ref_name] += 1

    # Keep only references at or above the threshold.
    keep_refs = {ref for ref, count in read_counts.items() if count >= threshold}

    # Second pass: copy the alignments that belong to a kept reference.
    with pysam.AlignmentFile(file, "rb") as bamfile:
        with pysam.AlignmentFile(out_bam, "wb", header=bamfile.header) as out:
            for read in bamfile.fetch(until_eof=True):
                if read.is_unmapped:
                    continue
                if bamfile.get_reference_name(read.reference_id) in keep_refs:
                    out.write(read)

    print(f"Filtered BAM written to {out_bam}")

                 
def main():
    """Filter every ".bam" file in the input directory using a worker pool.

    Reads the options from the command line (or from a wrapper in a
    workflow), makes sure the output directory exists, and runs
    ``process_single_file`` on each input file in parallel.
    """
    args = parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir
    threads = int(args.threads)
    # Validate the threshold here so a bad value fails once, before any
    # worker process is started.
    presence_threshold = int(args.presence_threshold)

    # The filtered BAMs are written straight into output_dir, so it must exist.
    os.makedirs(output_dir, exist_ok=True)

    file_list = find_files_from_dir(input_dir)

    # Never ask for more workers than there are CPUs or files, and never
    # fewer than one (Pool rejects a process count below 1).
    num_workers = max(1, min(multiprocessing.cpu_count(), threads, len(file_list)))
    with multiprocessing.Pool(processes=num_workers) as pool:
        pool.map(partial(process_single_file, presence_threshold, output_dir), file_list)
    print("** filterPresence done! **")

# Run the filter only when this file is executed as a script, so that
# importing the module does not start main().
if __name__ == "__main__":
    main()