#!/usr/bin/env python3
# Clean extra SQ records from a sam file
# Porter L

# Example:
#   python3 cleanSQ.py -i /path/to/sam_dir -o /path/to/output_dir -t 8

import argparse, subprocess, os, sys, multiprocessing
from functools import partial


def parse_args():
    """Parse command-line arguments.

    Returns:
        argparse.Namespace with attributes:
            input_dir (str): directory containing *.sam files to clean.
            output_dir (str): directory where cleaned files are written.
            threads (int): maximum number of worker processes.
    """
    # NOTE: the previous description ("Split BAM file by reference number")
    # described a different tool; this script removes unused @SQ records.
    parser = argparse.ArgumentParser(description="Remove @SQ header records for unused references from SAM files")
    parser.add_argument("-i", "--input_dir", required=True, help="Path to input files")
    parser.add_argument("-o", "--output_dir", required=True, help="Path to output directory")
    # type=int validates at parse time; callers doing int(args.threads) still work.
    parser.add_argument("-t", "--threads", required=True, type=int, help="Threads to use")

    return parser.parse_args()

def collect_used_refs(sam_path):
    """Scan a SAM file and return the set of reference names actually used.

    A reference is "used" when at least one alignment record names it in
    field 3 (RNAME). Header lines ('@'-prefixed) and unmapped reads
    (RNAME == '*') are ignored.

    Fix: the old version printed "Retained/Discarded SQ entries", but its
    counters actually tracked unique vs. duplicate alignment references —
    no @SQ line is examined here. The report now states what is measured.

    Parameters:
        sam_path (str): path to the input SAM file.

    Returns:
        set[str]: reference names referenced by at least one mapped read.
    """
    used_refs = set()
    with open(sam_path, 'r') as f:
        for line in f:
            if line.startswith('@'):
                continue  # header line, not an alignment record
            fields = line.split('\t')
            # fields[2] is RNAME; '*' marks an unmapped read.
            if len(fields) > 2 and fields[2] != '*':
                used_refs.add(fields[2])
    print(f"References in use: {len(used_refs)}")
    return used_refs

def clean_sam_header(sam_path, used_refs, output_path):
    """Copy a SAM file to output_path, dropping @SQ header lines whose
    SN: reference name is absent from used_refs.

    All other lines (non-@SQ headers and alignment records) pass through
    unchanged. Prints a summary of written vs. discarded line counts.
    """
    kept = 0
    dropped = 0
    with open(sam_path, 'r') as src, open(output_path, 'w') as dst:
        for record in src:
            if not record.startswith('@SQ'):
                # Headers other than @SQ and all alignment records pass through.
                dst.write(record)
                kept += 1
                continue
            # Pull the SN:<name> tag out of the @SQ line, if present.
            ref = next(
                (tag[3:] for tag in record.strip().split('\t') if tag.startswith('SN:')),
                None,
            )
            if ref and ref in used_refs:
                dst.write(record)
                kept += 1
            else:
                dropped += 1
    print(f"Written lines: {kept}, Discarded lines: {dropped}")


def clean(output_dir, sam_path):
    """Clean a single SAM file: collect its used references, then write a
    copy with unused @SQ records removed into output_dir (same basename).
    """
    destination = os.path.join(output_dir, os.path.basename(sam_path))
    clean_sam_header(sam_path, collect_used_refs(sam_path), destination)

def main():
    """Entry point: clean every *.sam file in input_dir in parallel."""
    import glob

    args = parse_args()
    sam_files = glob.glob(os.path.join(args.input_dir, "*.sam"))

    # Cap the pool size at the machine's core count even if more threads
    # were requested.
    workers = min(multiprocessing.cpu_count(), int(args.threads))
    with multiprocessing.Pool(processes=workers) as pool:
        # Bind the output directory so workers only receive the file path.
        pool.map(partial(clean, args.output_dir), sam_files)

    print("** cleanSQ done! **")

if __name__ == '__main__':
    main()
