# Who's in that file?
# Porter L

# EXAMPLE:
# python whos_in_that_file.py \
#   -i /tmp/tmpreftest/new \
#   -f /mounts/lovelace2/databases/ncbi/mitochondrion/mitochondrion.1.1.genomic.fna \
#   -o species.csv

import argparse, subprocess, os, sys, multiprocessing, re
from Bio import Entrez
from functools import partial
from pathlib import Path
from tqdm import tqdm

# Contact address set on Biopython's Entrez module before any Entrez call
# below is made. NOTE(review): presumably sent along with every request so
# NCBI can reach the user -- replace it with your own address.
Entrez.email = "charliep@earlham.edu"  # Use your real email

def parse_args():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with the attributes ``input_dir``, ``output``
        and ``fna``; all three options are required.
    """
    parser = argparse.ArgumentParser(
        description=(
            "List the reference, read count and species names for each "
            "input file and write them to a CSV file."
        )
    )
    parser.add_argument("-i", "--input_dir", required=True, help="Path to input files/file")
    parser.add_argument("-o", "--output", required=True, help="Path to output file")
    parser.add_argument(
        "-f",
        "--fna",
        required=True,
        help="Path to the .fna file whose lines map each reference ID to a species name",
    )
    return parser.parse_args()

def run_cmd(cmd):
    """Run a single shell command line and return its standard output.

    Args:
        cmd: The command line as one string. It is executed through the
            shell, so pipes work; callers must quote any path or other
            value they interpolate into it.

    Returns:
        The command's captured standard output as text.

    If the command exits with a non-zero status, the command line and its
    captured standard error are printed to stderr and the interpreter exits
    with the same status.
    """
    result = subprocess.run(
        cmd,
        shell=True,
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        # Report on stderr, like the other error messages in this file, so
        # diagnostics never mix with data written to stdout.
        print(f"Command failed: {cmd}", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        sys.exit(result.returncode)

    return result.stdout

def _first_name(value):
    """Reduce a name field to one value: the first entry if it is a list."""
    if isinstance(value, (list, tuple)):
        return value[0] if value else None
    return value


def get_common_name(scientific_name):
    """Look up a common name for *scientific_name* in the Entrez taxonomy db.

    Args:
        scientific_name: Name used as the search term. ``None``, an empty
            string or whitespace skips the lookup entirely.

    Returns:
        The first non-empty value among the record's ``CommonName``, its
        ``OtherNames["GenbankCommonName"]`` and ``OtherNames["CommonName"]``,
        or the string "Unknown" when the name is empty, a request fails,
        nothing is found, or none of those fields is filled.

    This is a best-effort lookup: request errors are reported on stderr and
    turned into "Unknown" instead of being raised.
    """
    # Nothing to search for (e.g. the caller found no name): avoid a
    # pointless network round trip.
    if not scientific_name or not str(scientific_name).strip():
        print("No scientific name given.", file=sys.stderr)
        return "Unknown"

    # Search for taxonomy ID
    try:
        handle = Entrez.esearch(db="taxonomy", term=scientific_name, retmode="xml")
        try:
            record = Entrez.read(handle)
        finally:
            # Close the handle even when parsing the response fails.
            handle.close()
    except Exception as e:
        print(f"Error during esearch: {e}", file=sys.stderr)
        return "Unknown"

    if not record["IdList"]:
        print("No taxonomy ID found.", file=sys.stderr)
        return "Unknown"

    # Only the first hit is used.
    tax_id = record["IdList"][0]

    # Fetch taxonomy details
    try:
        handle = Entrez.efetch(db="taxonomy", id=tax_id, retmode="xml")
        try:
            records = Entrez.read(handle)
        finally:
            handle.close()
    except Exception as e:
        print(f"Error during efetch: {e}", file=sys.stderr)
        return "Unknown"

    if not records:
        print("No records returned.", file=sys.stderr)
        return "Unknown"

    rec = records[0]
    other = rec["OtherNames"] if "OtherNames" in rec else {}

    # Try multiple common name fields, in order of preference.
    # NOTE(review): some of these fields may come back as a list of names --
    # confirm against a real response. A list is reduced to its first entry
    # so callers always receive a single string.
    candidates = (
        (rec, "CommonName"),
        (other, "GenbankCommonName"),
        (other, "CommonName"),
    )
    for source, key in candidates:
        if key in source:
            name = _first_name(source[key])
            if name:
                return name

    return "Unknown"

def get_latin_name(ref, fna_file):
    """Return the 2nd and 3rd whitespace-separated words of the first line
    of *fna_file* that contains *ref*.

    Args:
        ref: Text to look for (substring match against each line).
        fna_file: Path of the text file to scan from the top.

    Returns:
        The two words joined by a single space (fewer if the line is
        shorter), or None when no line contains *ref*.
    """
    with open(fna_file, 'r') as fasta:
        # Stop reading at the first matching line.
        hit = next((row for row in fasta if ref in row), None)

    if hit is None:
        return None

    words = hit.rstrip().split()
    return " ".join(words[1:3])

def get_ref_list(p):
    """Return the path of every entry in directory *p*.

    Each name reported by ``os.listdir`` is joined onto *p*; the order is
    whatever ``os.listdir`` yields, and no filtering is applied.
    """
    return [os.path.join(p, entry) for entry in os.listdir(p)]

def file_to_ref(f):
    """Extract the first NC_/NW_/NT_ identifier (digits, dot, digits) from *f*.

    Args:
        f: String to scan; the identifier may appear anywhere in it.

    Returns:
        The matched identifier, e.g. "NC_012920.1".

    Raises:
        ValueError: If *f* contains no such identifier.
    """
    # search(), not match(): the identifier need not start the string.
    found = re.search(r'(NC|NW|NT)_\d+\.\d+', f)
    if found is None:
        raise ValueError(f"Filename does not match expected format: {f}")
    return found.group(0)

def main():
    """Write one CSV row per reportable input file.

    Steps:
      1. Collect the input file(s) from ``--input_dir`` (a single file or a
         directory).
      2. Record the reference of every file whose name starts with "N".
      3. For each remaining file whose reference was not recorded in step 2,
         write file name, reference, first character of the file name,
         read count (via samtools) and Latin/common names to ``--output``.

    Exits with status 1 when the input path is neither a file nor a
    directory. ``file_to_ref`` raises ValueError for a file name without a
    reference identifier.
    """
    import shlex  # only needed here, to quote paths for the shell

    args = parse_args()
    input_path = args.input_dir
    fna_file = args.fna
    output = args.output

    if os.path.isfile(input_path):
        # Single File
        file_list = [input_path]
    elif os.path.isdir(input_path):
        # Directory of files
        file_list = get_ref_list(input_path)
    else:
        # Stop here instead of silently producing a header-only output file.
        print("Input is incorrect...", file=sys.stderr)
        sys.exit(1)

    # First pass: parse every file name once and collect the references of
    # files whose name starts with "N" (presumably control samples -- confirm).
    entries = []
    control_refs = set()
    for file in tqdm(file_list):
        ref = file_to_ref(file)
        horizon_id = os.path.basename(file)[0]
        entries.append((file, ref, horizon_id))
        if horizon_id == "N":
            control_refs.add(ref)

    # Second pass: report every non-"N" file whose reference was not seen in
    # an "N" file.
    with open(output, 'w') as f:
        f.write("FILE NAME, REFERENCE NUMBER, HORIZON ID, READ COUNT, LATIN NAME, COMMON NAME\n")
        for file, ref, horizon_id in tqdm(entries):
            print("ref:", ref)

            if horizon_id == "N" or ref in control_refs:
                continue

            print("fetching info")

            latin_name = get_latin_name(ref, fna_file)
            # No Latin name means nothing to look up.
            common_name = get_common_name(latin_name) if latin_name else "Unknown"
            # "view -c" lets samtools count the records itself, so a samtools
            # failure reaches run_cmd's exit-status check instead of being
            # hidden behind "wc -l". The path is quoted because run_cmd
            # executes through the shell.
            read_count = int(run_cmd(f"samtools view -c {shlex.quote(file)}"))
            f.write(f'{file}, {ref}, {horizon_id}, {read_count}, {latin_name}, "{common_name}"\n')


if __name__ == "__main__":
    main()
