#!/usr/bin/env python3
# Remove all reads from a BAM file by TaxID
# Porter L

# python3 removeTaxIds.py --input_dir /input/dir --output_dir /output/dir --taxid 12345 --taxid 67890
# python3 removeTaxIds.py --input_dir /input/dir --output_dir /output/dir --taxid 12345 --taxid 67890 --flip

import argparse, subprocess, os, sys, multiprocessing, pysam, psycopg2
from functools import partial
from typing import Optional, List

# Module-level memo of accession -> TaxID results; written by
# get_taxid_from_accession_cached (the "Unknown" sentinel is stored too).
ACCESSION_TAXID_CACHE = {} 
# Module-level memo of TaxID -> list of TaxIDs collected while walking up the
# ``nodes`` table; written by get_classification_cached.
CLASSIFICATION_CACHE = {}

# Shared database connection, opened as a side effect of loading this module
# and reused by every lookup below.
# NOTE(review): the credentials are hard-coded in source; consider supplying
# them through the environment or a config file instead.
conn = psycopg2.connect(dbname="ncbi", user="fieldsci", password="skalanes")

def parse_args():
    """Parse the command-line options of the TaxID read filter.

    Returns:
        argparse.Namespace with ``input_dir`` and ``output_dir`` (str),
        ``taxid`` (list of int, one entry per ``-t``/``--taxid``) and
        ``flip`` (bool, False unless ``-f``/``--flip`` is given).
    """
    parser = argparse.ArgumentParser(description="Remove all reads from a BAM file by TaxID")
    parser.add_argument("-i", "--input_dir", required=True, help="Path to input files")
    parser.add_argument("-o", "--output_dir", required=True, help="Path to output directory")
    # type=int: without it the values stay strings, and a string such as
    # "12345" never compares equal to an integer TaxID in a read's lineage
    # (the lineage code in this file compares TaxIDs against the integer 1).
    # It also makes argparse reject non-numeric TaxIDs up front.
    parser.add_argument("-t", "--taxid", action='append', type=int, required=True, help="TaxIDs to remove, multiple allowed")
    parser.add_argument("-f", '--flip', action='store_true', help='keep ONLY reads whose lineage includes any of the provided -t TaxIDs')
    return parser.parse_args()

def get_taxid_from_accession(accession, conn):
    """Look up the TaxID stored for one accession.version string.

    Runs a single parameterised query against the ``accession_taxid`` table.

    Args:
        accession: Value matched against the ``accession_version`` column.
        conn: Open database connection whose ``cursor()`` can be used as a
            context manager (a psycopg2 connection in this script).

    Returns:
        The ``tax_id`` of the first matching row, or the string ``"Unknown"``
        when no row matches or the query fails.
    """
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT tax_id
                FROM accession_taxid
                WHERE accession_version = %s
                LIMIT 1;
            """, (accession,))
            row = cur.fetchone()
    except Exception as e:
        print(f"[Error] {e}")
        # A failed statement leaves the connection's transaction aborted;
        # without a rollback every later query on this connection fails too,
        # so all following reads would silently resolve to "Unknown".
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"[Error] rollback failed: {rollback_error}")
        return "Unknown"

    return row[0] if row else "Unknown"


def get_taxid_from_accession_cached(ref_accession: str, conn):
    """Return the TaxID for ``ref_accession``, consulting the module cache first.

    On a cache miss the lookup is delegated to ``get_taxid_from_accession``
    and whatever it returns (including ``"Unknown"``) is stored in
    ``ACCESSION_TAXID_CACHE`` before being returned.
    """
    try:
        # Fast path: this accession was already resolved earlier.
        return ACCESSION_TAXID_CACHE[ref_accession]
    except KeyError:
        pass

    resolved = get_taxid_from_accession(ref_accession, conn)
    ACCESSION_TAXID_CACHE[ref_accession] = resolved
    return resolved


def get_classification_cached(tax_id: Optional[int], conn):
    """Return the lineage of ``tax_id`` as a list of TaxIDs, with caching.

    Starting at ``tax_id`` the function repeatedly looks up ``parent_tax_id``
    in the ``nodes`` table and stops when it reaches TaxID 1, a falsy parent,
    a TaxID it has already visited, or a TaxID with no row.

    Args:
        tax_id: TaxID to resolve. ``None``, ``"Unknown"`` and ``1`` yield an
            empty lineage without touching the database.
        conn: Open database connection whose ``cursor()`` can be used as a
            context manager (a psycopg2 connection in this script).

    Returns:
        The visited TaxIDs ordered from ``tax_id`` upwards (TaxID 1 itself is
        never included). The list is stored in ``CLASSIFICATION_CACHE`` and
        the same list object is returned on later calls. On a database error
        an empty list is returned and nothing is cached.
    """
    if tax_id is None or tax_id == "Unknown" or tax_id == 1:
        return []

    if tax_id in CLASSIFICATION_CACHE:
        return CLASSIFICATION_CACHE[tax_id]

    lineage = []      # visited TaxIDs in walk order (leaf first)
    visited = set()   # same TaxIDs, for O(1) cycle detection
    current_tax_id = tax_id

    try:
        with conn.cursor() as cur:
            while current_tax_id and current_tax_id not in visited and current_tax_id != 1:
                visited.add(current_tax_id)
                lineage.append(current_tax_id)
                cur.execute("""
                    SELECT parent_tax_id
                    FROM nodes
                    WHERE tax_id = %s
                    LIMIT 1;
                """, (current_tax_id,))

                row = cur.fetchone()
                if not row:
                    break

                current_tax_id = row[0]
    except Exception as e:
        print(f"[Database Error in get_classification] {e}")
        # A failed statement leaves the connection's transaction aborted;
        # roll back so later lookups on the shared connection can still work.
        try:
            conn.rollback()
        except Exception as rollback_error:
            print(f"[Database Error in get_classification] rollback failed: {rollback_error}")
        return []

    CLASSIFICATION_CACHE[tax_id] = lineage
    return lineage


def _normalize_taxid(taxid):
    """Return a string TaxID as an int when it parses as one, else unchanged."""
    if isinstance(taxid, str):
        try:
            return int(taxid)
        except ValueError:
            return taxid
    return taxid


def filter_bam_by_taxid(input_bam: str, output_bam: str, target_taxids_to_remove: List[int], flip_filter: bool = False) -> None:
    """Copy reads from ``input_bam`` to ``output_bam``, filtered by lineage.

    For every mapped read the reference name is resolved to a TaxID and then
    to its lineage (both via the cached lookups on the module-level ``conn``).
    A read "matches" when any target TaxID appears in that lineage. Unmapped
    reads and reads without a reference name never match.

    Args:
        input_bam: Path of the BAM file to read.
        output_bam: Path of the BAM file to write (same header as the input).
        target_taxids_to_remove: TaxIDs to test lineages against. Numeric
            strings are accepted and converted to int so they compare equal
            to integer lineage entries.
        flip_filter: False drops matching reads and keeps the rest; True
            keeps only matching reads.

    Returns:
        None. Progress and totals are printed. Any exception is reported on
        stdout and a partially written ``output_bam`` is deleted; the
        exception is not re-raised.
    """
    # Command-line values arrive as strings unless the parser converts them;
    # normalising here keeps "12345" from silently never matching 12345.
    target_taxids_set = {_normalize_taxid(taxid) for taxid in target_taxids_to_remove}

    print(f"\nStarting to process {input_bam} with caching...")
    if not flip_filter:
        print(f"    -> Mode: REMOVE reads whose lineage includes (and descendants): {target_taxids_to_remove}")
    else:
        print(f"    -> Mode: KEEP ONLY reads whose lineage includes (and descendants): {target_taxids_to_remove}")
    print("    -> RUNNING...")

    reads_kept = 0
    reads_removed = 0
    try:
        # Context managers close both files on every path, so the partial
        # output is flushed and released before the cleanup below removes it.
        with pysam.AlignmentFile(input_bam, "rb") as infile:
            with pysam.AlignmentFile(output_bam, "wb", header=infile.header) as outfile:
                for read in infile.fetch(until_eof=True):
                    ref_accession = read.reference_name

                    if ref_accession is None or read.is_unmapped:
                        # Treat unmapped/unknown accession as non-matched
                        matched = False
                    else:
                        read_taxid = get_taxid_from_accession_cached(ref_accession, conn)
                        # An unresolved TaxID yields an empty lineage, i.e. no match.
                        lineage = get_classification_cached(read_taxid, conn)
                        matched = not target_taxids_set.isdisjoint(lineage)

                    keep = matched if flip_filter else not matched
                    if keep:
                        outfile.write(read)
                        reads_kept += 1
                    else:
                        reads_removed += 1

        print(f"    -> Hierarchical Filtering Complete!")
        print(f"\n    Total Unique Accessions Looked Up: {len(ACCESSION_TAXID_CACHE)}")
        print(f"    Total Unique TaxIDs Lineages Resolved: {len(CLASSIFICATION_CACHE)}")
        print(f"    Reads Kept: {reads_kept}")
        print(f"    Reads Removed: {reads_removed}")
        # No index is built here, so the message no longer claims one.
        print(f"    Output file created: {output_bam}")

    except Exception as e:
        print(f"??? An error occurred: {e}")
        if os.path.exists(output_bam):
            os.remove(output_bam)
            print(f"Cleaned up partial output file: {output_bam}")
            

if __name__ == '__main__':
    # Filter every *.bam file in --input_dir into a same-named file in
    # --output_dir, then close the shared database connection.
    args = parse_args()

    try:
        # Output files reuse the input file names, so writing into the input
        # directory would truncate each BAM while it is still being read.
        if os.path.realpath(args.input_dir) == os.path.realpath(args.output_dir):
            sys.exit("[Error] --output_dir must be different from --input_dir (outputs reuse the input file names)")

        os.makedirs(args.output_dir, exist_ok=True)

        # sorted() gives a reproducible processing order across runs.
        for name in sorted(os.listdir(args.input_dir)):
            if name.endswith(".bam"):
                input_bam = os.path.join(args.input_dir, name)
                output_bam = os.path.join(args.output_dir, name)
                filter_bam_by_taxid(input_bam, output_bam, args.taxid, flip_filter=args.flip)
            else:
                print(f"-? Skipping non-BAM file: {name}")
    finally:
        # Release the connection even when a file fails or the run is aborted.
        conn.close()
