#!/usr/bin/env python3
# Remove reads from a BAM file whose read names appear in any of the given control BAM files
# Porter L
import argparse
import os
import sys

import pysam

def parse_args():
    """Parse the command-line arguments.

    Returns:
        argparse.Namespace with ``input_path``, ``input_id``, ``output_file``
        and ``control_id``. ``control_id`` is a list holding one entry per
        ``-c``/``--control_id`` occurrence, because that option uses
        ``action='append'``.
    """
    parser = argparse.ArgumentParser(
        description="Remove reads whose names appear in one or more control BAM files."
    )
    parser.add_argument(
        "-i", "--input_path", required=True,
        help="Directory containing the input and control BAM files.",
    )
    parser.add_argument(
        "-d", "--input_id", required=True,
        help="Input BAM file ID, without the .bam extension (e.g., 1_3_S3).",
    )
    parser.add_argument(
        "-o", "--output_file", required=True,
        help="Path of the filtered output BAM file.",
    )
    parser.add_argument(
        "-c", "--control_id", action='append', required=True,
        help="Control BAM file ID, without the .bam extension; "
             "repeat the flag for multiple controls.",
    )
    return parser.parse_args()

def _collect_control_read_ids(input_path, control_ids):
    """Return the set of read names found in the given control BAM files.

    Each control ID is resolved to ``<input_path>/<control_id>.bam``. A control
    file that is missing or fails to read is reported on stderr and skipped,
    so filtering can still proceed with the IDs collected from the others.
    A file that fails part-way through still contributes the names read
    before the error.
    """
    # A set gives O(1) membership tests when filtering the input file.
    control_read_ids = set()
    for control_id in control_ids:
        control_file_path = os.path.join(input_path, f"{control_id}.bam")
        print(f"  -> Reading from {control_file_path}...")
        try:
            with pysam.AlignmentFile(control_file_path, "rb") as control_bam_file:
                control_read_ids.update(read.query_name for read in control_bam_file)
        except FileNotFoundError:
            print(f"Warning: Control file not found at {control_file_path}. Skipping.",
                  file=sys.stderr)
        except Exception as e:
            # NOTE(review): best-effort by design; a corrupt control file means
            # some contaminating reads may survive. Consider making this fatal.
            print(f"An error occurred while reading {control_file_path}: {e}",
                  file=sys.stderr)
    return control_read_ids


def _filter_bam(input_file_path, output_file_path, excluded_read_ids):
    """Copy every read from the input BAM to a new output BAM, skipping reads
    whose ``query_name`` is in ``excluded_read_ids``.

    The output reuses the input's header. Errors from pysam (missing input,
    unwritable output, read failures) propagate to the caller.
    """
    with pysam.AlignmentFile(input_file_path, "rb") as input_bam_file:
        with pysam.AlignmentFile(output_file_path, "wb",
                                 header=input_bam_file.header) as output_bam_file:
            for read in input_bam_file:
                if read.query_name not in excluded_read_ids:
                    output_bam_file.write(read)


def main():
    """
    Main function to filter a BAM file.
    It removes all reads found in the specified control files.

    Reads ``<input_path>/<input_id>.bam`` and writes the reads whose names do
    not occur in any ``<input_path>/<control_id>.bam`` to ``output_file``.
    Exits with status 1 if the input cannot be filtered, so calling pipelines
    do not mistake a failed or partial output for success.
    """
    args = parse_args()

    print("Collecting read IDs from control files...")
    control_read_ids = _collect_control_read_ids(args.input_path, args.control_id)
    print(f"Found {len(control_read_ids)} unique read IDs to remove.")

    input_file_path = os.path.join(args.input_path, f"{args.input_id}.bam")
    print(f"Filtering reads from {input_file_path}...")
    try:
        _filter_bam(input_file_path, args.output_file, control_read_ids)
    except FileNotFoundError:
        print(f"Error: Input file not found at {input_file_path}. Exiting.", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"An error occurred during filtering: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"Successfully created filtered BAM file at {args.output_file}")


if __name__ == "__main__":
    # Script entry point: main()'s return value (None) becomes exit status 0;
    # any SystemExit raised inside main() propagates unchanged.
    raise SystemExit(main())
