#!/usr/bin/env python
import optparse
import sys

#from get_LCA_functions import *
from Bio import SeqIO

import glob

def _accession_of(record):
    """Return the lookup key of a FASTA record.

    The key is everything in ``record.description`` before the first space.
    """
    return record.description.split(' ')[0]   # was '|' [1]


def _collect_accessions(fastafile):
    """Return a dict mapping each distinct accession in *fastafile* to the
    placeholder 'TAXID_NOT_FOUND'.

    A dict is used directly so that de-duplication is O(1) per record
    instead of a linear scan over a list.
    """
    av_dict = {}
    for record in SeqIO.parse(fastafile, format='fasta'):
        av_dict[_accession_of(record)] = 'TAXID_NOT_FOUND'
    return av_dict


def _scan_accession_file(av_name, pending):
    """Return {accession: taxid} for the *pending* accessions found in *av_name*.

    Each line is split on whitespace; the second field is the lookup key and
    the third field is the taxid. (Presumably these are chunks of an NCBI
    accession2taxid table -- confirm against the files on disk.)
    Lines with fewer than three fields are skipped instead of crashing.
    If a key occurs several times in one file, the last occurrence wins.
    """
    hits = {}
    lines_processed = 0
    with open(av_name, 'r') as handle:
        # Stream the file: these tables can be far larger than memory, and
        # only the accessions present in the FASTA file are ever needed.
        for line in handle:
            lines_processed += 1
            if lines_processed % 5000000 == 0:
                sys.stdout.write('-')   # progress marker

            text = line.split()
            if len(text) < 3:
                continue
            if text[1] in pending:
                hits[text[1]] = text[2]
    return hits


def _resolve_taxids(av_dict, filelist):
    """Fill *av_dict* in place with taxids found in the files of *filelist*.

    Files are processed in order and the first file that contains an
    accession decides its taxid. Scanning stops as soon as every accession
    has been resolved.
    """
    pending = set(av_dict)
    for count, av_name in enumerate(filelist, 1):
        if not pending:
            break   # nothing left to look up; skip the remaining files
        print('\nprocessing file '+str(count)+' of '+str(len(filelist))+'\n')
        hits = _scan_accession_file(av_name, pending)
        av_dict.update(hits)
        pending.difference_update(hits)


def _write_annotated(fastafile, outfile, av_dict):
    """Write *fastafile* to *outfile* with '<taxid> ' prepended to each header.

    Sequences are written on a single line. Records whose accession is
    missing from *av_dict* get the marker 'TAXID_NOT_FOUND_in_dict'.
    """
    with open(outfile, 'w') as out:
        for record in SeqIO.parse(fastafile, format='fasta'):
            taxid = av_dict.get(_accession_of(record), 'TAXID_NOT_FOUND_in_dict')
            header = taxid+' '+record.description   # was '|'
            out.write('>'+str(header)+'\n'+str(record.seq)+'\n')


def main():
    """Prefix every FASTA header with the taxid looked up for its accession.

    Command line: ``[-o OUTFILE] [-a ACCESSION_GLOB] FASTAFILE``

    * The accession of a record is the header text before the first space.
    * Taxids are looked up in the files matched by ``--accession-glob``.
    * Accessions that are not found get the prefix 'TAXID_NOT_FOUND'.
    * Without ``-o`` the output name is derived from the input name by
      replacing '.fasta' with '.taxid.fasta'; if the input name does not
      contain '.fasta', '.taxid.fasta' is appended so that the input file
      is never overwritten.

    Exits through ``OptionParser.error`` when no input file is given or when
    the output file would be the input file itself.
    """
    import os

    p = optparse.OptionParser(usage='%prog [options] FASTAFILE')
    p.add_option('--outfile', '-o')
    p.add_option('--accession-glob', '-a', dest='accession_glob',
                 default='/mounts/lovelace/reference-genomes/accession/nucl_*.a?',
                 help='glob pattern of the accession-to-taxid files '
                      '[default: %default]')
    options, arguments = p.parse_args()
    if not arguments:
        p.error('no input FASTA file given')
    fastafile = arguments[0]

    if options.outfile is None:
        outfile = fastafile.replace('.fasta', '.taxid.fasta')
        if outfile == fastafile:
            # No '.fasta' in the name: writing to the same path would
            # truncate the input before it is read a second time.
            outfile = fastafile+'.taxid.fasta'
        print('No output file specified, writing to: '+outfile)
    else:
        outfile = options.outfile
        if os.path.abspath(outfile) == os.path.abspath(fastafile):
            p.error('output file must differ from the input FASTA file')

    av_dict = _collect_accessions(fastafile)

    filelist = glob.glob(options.accession_glob)
    if not filelist:
        sys.stderr.write('warning: no accession files match '
                         + options.accession_glob + '\n')
    _resolve_taxids(av_dict, filelist)

    print('\n"accession "-files loaded into memory \n')

    _write_annotated(fastafile, outfile, av_dict)

#####################
if __name__ == "__main__":
    # Run the annotation only when executed as a script, not when imported.
    main()
