# addTaxonToSQ.py - decorates the @SQ header records in a SAM file that was aligned against
# a database/index whose reference names are accession numbers (e.g. a mitochondrial
# reference database) with each reference's NCBI taxon ID, stored in the @SQ SP field.
#
# 2025, May - Original coding, charliep and Gemini.
#

import os
import re
import sys
import time

from Bio import Entrez

def get_taxon_id(accession_number):
    """
    Look up the NCBI taxon ID for an accession number using NCBI's Entrez API.

    The nucleotide record is fetched as GenBank XML. The taxon ID is taken from
    the 'db_xref' qualifier of the 'source' feature whose value starts with
    'taxon:' (e.g. 'taxon:9606' -> '9606').

    Args:
        accession_number (str): The accession number to search for; it is passed
            to efetch unchanged (a version suffix, if present, is kept).

    Returns:
        str or None: The taxon ID if found, otherwise None. Fetch/parse errors and
        unexpected record structures are reported on stderr and also yield None,
        so a single failed lookup does not abort processing of a whole file.
    """
    taxon_prefix = "taxon:"

    print(f"Attempting to fetch record for accession: '{accession_number}'")
    # Small pause between successive NCBI requests.
    time.sleep(0.1)

    try:
        handle = Entrez.efetch(db="nucleotide", id=accession_number, rettype="gb", retmode="xml")
        try:
            record = Entrez.read(handle)
        finally:
            # Release the HTTP handle even when parsing the response fails.
            handle.close()

        if not record:
            print(f"Warning: Empty record returned for accession '{accession_number}'.", file=sys.stderr)
            return None

        if not isinstance(record, list):
            print(f"Warning: Unexpected record structure for '{accession_number}'.", file=sys.stderr)
            print(record, file=sys.stderr)  # Dump the structure for debugging
            return None

        gb_record = record[0]
        for feature in gb_record.get('GBSeq_feature-table') or []:
            if feature.get('GBFeature_key') != 'source':
                continue
            for qualifier in feature.get('GBFeature_quals') or []:
                if qualifier.get('GBQualifier_name') != 'db_xref':
                    continue
                db_xref_value = qualifier.get('GBQualifier_value', '')
                if db_xref_value.startswith(taxon_prefix):
                    # Slice off the prefix rather than split(':')[1] so values
                    # containing further colons are not truncated.
                    return db_xref_value[len(taxon_prefix):]

        return None

    except Exception as e:
        # Deliberately best-effort: report the failure and let the caller keep
        # the original @SQ line.
        print(f"Error fetching record for '{accession_number}': {e}", file=sys.stderr)
        return None

def process_sam_file(input_sam_file, output_sam_file):
    """
    Copy a SAM file, adding an SP (taxon ID) field to each resolvable @SQ record.

    For every @SQ line, the SN value (the accession, version suffix included) is
    looked up with get_taxon_id(). When a taxon ID is found, all existing @SQ
    fields are kept in their original order, any previous SP field is dropped,
    and SP:<taxon_id> is appended. @SQ lines with no usable SN value, @SQ lines
    whose taxon ID is not found, and all non-@SQ lines are copied unchanged.

    The Entrez email and API key can be supplied via the NCBI_EMAIL and
    NCBI_API_KEY environment variables; otherwise built-in defaults are used.

    Args:
        input_sam_file (str): Path to the input SAM file.
        output_sam_file (str): Path to the output SAM file.

    Exits the process with status 1 (after printing a message to stderr) if a
    file cannot be opened or any other unexpected error occurs.
    """
    # NOTE(review): an API key committed to source is a leaked credential;
    # prefer exporting NCBI_API_KEY and removing the hard-coded fallback.
    Entrez.email = os.environ.get("NCBI_EMAIL", "charliep@earlham.edu")  # Replace with /your/ actual email address
    Entrez.api_key = os.environ.get("NCBI_API_KEY", "c21e043794c5c83956bfacdb9cc3603d4208")

    try:
        with open(input_sam_file, 'r') as infile, open(output_sam_file, 'w') as outfile:
            for line in infile:
                if not line.startswith('@SQ'):
                    outfile.write(line)  # Write non-@SQ lines as they are
                    continue

                fields = line.rstrip().split('\t')

                # Scan key:value fields left to right; the last SN wins.
                accession_with_version = None
                for field in fields[1:]:
                    key, sep, value = field.partition(':')
                    if sep and key == 'SN':
                        accession_with_version = value

                if not accession_with_version:
                    outfile.write(line)  # Write the original @SQ line if SN field is missing
                    continue

                # Version-stripped accession, used only in progress messages;
                # the lookup itself uses the full SN value.
                accession = re.sub(r'\.\d+$', '', accession_with_version)
                taxon_id = get_taxon_id(accession_with_version)

                if not taxon_id:
                    outfile.write(line)  # Write the original line if taxon ID not found
                    print(f"Taxon ID not found for accession: {accession_with_version} (base: {accession})")
                    continue

                # Preserve every other @SQ tag (LN, M5, UR, AS, ...); drop empty
                # fields and any stale SP so exactly one SP field is written.
                kept_fields = [f for f in fields if f and not f.startswith('SP:')]
                kept_fields.append(f"SP:{taxon_id}")
                outfile.write('\t'.join(kept_fields) + '\n')
                print(f"Processed @SQ record for {accession_with_version} (base: {accession}): Taxon ID - {taxon_id}")

        print(f"\nProcessed SAM file. New @SQ records written to {output_sam_file}")

    except FileNotFoundError as e:
        # Either path can raise this (e.g. a missing output directory), so
        # report the file that actually failed.
        if e.filename == input_sam_file:
            print(f"Error: Input SAM file '{input_sam_file}' not found.", file=sys.stderr)
        else:
            print(f"Error: could not open '{e.filename}': {e.strerror}", file=sys.stderr)
        sys.exit(1)
    except Exception as e:
        print(f"An unexpected error occurred: {e}", file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    # Command-line entry point: addTaxonToSQ.py <input_sam_file> <output_sam_file>
    if len(sys.argv) != 3:
        print("Usage: python addTaxonToSQ.py <input_sam_file> <output_sam_file>", file=sys.stderr)
        # sys.exit, not the site-module exit(), which is absent under `python -S`
        # and in some embedded/frozen interpreters.
        sys.exit(1)

    input_sam, output_sam = sys.argv[1], sys.argv[2]

    process_sam_file(input_sam, output_sam)