package gridss.kraken;

import au.edu.wehi.idsv.debruijn.ContigKmerCounter;
import au.edu.wehi.idsv.kraken.KrakenReportLine;
import au.edu.wehi.idsv.kraken.SeqIdToTaxIdMap;
import au.edu.wehi.idsv.ncbi.MinimalTaxonomyNode;
import au.edu.wehi.idsv.ncbi.TaxonomyHelper;
import au.edu.wehi.idsv.ncbi.TaxonomyNode;
import com.google.common.collect.Streams;
import gridss.cmdline.ReferenceCommandLineProgram;
import gridss.cmdline.programgroups.DataConversion;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.samtools.fastq.FastqReader;
import htsjdk.samtools.reference.FastaReferenceWriter;
import htsjdk.samtools.reference.FastaReferenceWriterBuilder;
import htsjdk.samtools.reference.IndexedFastaSequenceFile;
import htsjdk.samtools.reference.ReferenceSequence;
import htsjdk.samtools.util.IOUtil;
import htsjdk.samtools.util.Log;
import htsjdk.samtools.util.RuntimeIOException;
import htsjdk.samtools.util.SequenceUtil;
import htsjdk.variant.vcf.VCFHeader;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.configuration.tree.DefaultExpressionEngine;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang3.tuple.Pair;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.CommandLineProgramProperties;
import picard.cmdline.CommandLineProgram;
import picard.cmdline.StandardOptionDefinitions;

/**
 * Selects, for every taxon listed in a {@code gridss.IdentifyViralTaxa} summary TSV, the Kraken2
 * library contig(s) sharing the most kmers with the supplied viral reads, and writes them to a
 * FASTA file (with .dict and .fai) plus an annotated copy of the summary.
 *
 * <p>Candidate contigs are the {@link #KRAKEN_REFERENCES} sequences whose seqid2taxid taxon is
 * flagged by {@code TaxonomyHelper.createInclusionLookup} for the summary taxa. Each candidate is
 * attributed to the summary taxon found by walking up the NCBI taxonomy. Candidates are ranked by
 * reference file order (only when {@link #FAVOUR_EARLY_KRAKEN_REFERENCES} is set), then by
 * descending matched-kmer count.
 */
@CommandLineProgramProperties(summary = "Processes a Kraken2 report and extracts the sequences with the most hits", oneLineSummary = "Processes a Kraken2 report and extracts the sequences with the most hits.", programGroup = DataConversion.class)
public class ExtractBestViralReference extends CommandLineProgram {
    private static final Log log = Log.getInstance(ExtractBestViralReference.class);

    /** Zero-based column of the input summary TSV holding the NCBI taxid of each row. */
    private static final int TAXID_COLUMN = 6;

    @Argument(shortName = StandardOptionDefinitions.INPUT_SHORT_NAME, doc = "TSV from gridss.IdentifyViralTaxa")
    public File INPUT_SUMMARY;

    @Argument(doc = "Viral reads. Used to determine which genome for the chosen taxa to return.")
    public File INPUT_VIRAL_READS;

    @Argument(shortName = StandardOptionDefinitions.OUTPUT_SHORT_NAME, doc = "Output fasta file")
    public File OUTPUT;

    @Argument(doc = "TSV annotated with extracted references")
    public File OUTPUT_SUMMARY;

    @Argument(doc = "TSV containing number of matched kmers for each candidate reference sequence.")
    public File OUTPUT_MATCHING_KMERS;

    @Argument(doc = "NCBI taxonomy nodes.dmp. Download and extract from https://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip")
    public File NCBI_NODES_DMP;

    @Argument(doc = "Kraken2 seqid2taxid.map mapping file")
    public File SEQID2TAXID_MAP;

    @Argument(doc = "Kraken2 library.fna files. Downloaded by kraken2-build. Must be indexed. Do not run kraken2-build --clean as these files will be removed. Files are checked in order and all the contigs for the given taxid from the first matching file are extracted.", optional = true)
    public List<File> KRAKEN_REFERENCES;

    @Argument(doc = "Kmer used determining best viral genome match for viral reads.")
    public int KMER = 16;

    @Argument(doc = "Distance between kmers in reference lookup. Longer stride reduces memory usage. Should not be more than kmer length")
    public int STRIDE = 16;

    @Argument(doc = "Maximum number of contigs to extract per NCBI taxonomic identifiers.", optional = true)
    public int CONTIGS_PER_TAXID = 1;

    @Argument(doc = "Use the viral references that occur in earlier KRAKEN_REFERENCES files whenever possible.", optional = true)
    public boolean FAVOUR_EARLY_KRAKEN_REFERENCES = true;

    /**
     * Rejects a missing KRAKEN_REFERENCES list and numeric arguments the extraction cannot work
     * with. Otherwise defers to the superclass validation.
     *
     * @return the validation error messages, or the superclass result when there are none
     */
    @Override // picard.cmdline.CommandLineProgram
    public String[] customCommandLineValidation() {
        List<String> errors = new ArrayList<>();
        if (this.KRAKEN_REFERENCES == null || this.KRAKEN_REFERENCES.isEmpty()) {
            errors.add("KRAKEN_REFERENCES required. At minimum, this is the library/viral/library.fna file in the kraken2 database directory.");
        }
        if (this.KMER <= 0) {
            errors.add("KMER must be positive.");
        }
        if (this.STRIDE <= 0) {
            errors.add("STRIDE must be positive.");
        }
        if (this.CONTIGS_PER_TAXID < 0) {
            errors.add("CONTIGS_PER_TAXID must not be negative.");
        }
        return errors.isEmpty() ? super.customCommandLineValidation() : errors.toArray(new String[0]);
    }

    /**
     * Runs the extraction.
     *
     * <p>Steps: load the seqid2taxid map and taxonomy, count read kmers against every candidate
     * contig, choose up to {@link #CONTIGS_PER_TAXID} contigs per summary row, then write the
     * FASTA, the optional matching-kmer TSV and the optional annotated summary.
     *
     * @return 0 on success
     * @throws RuntimeIOException wrapping any {@link IOException} raised while reading or writing
     */
    @Override // picard.cmdline.CommandLineProgram
    protected int doWork() {
        IOUtil.assertFileIsReadable(this.SEQID2TAXID_MAP);
        IOUtil.assertFileIsReadable(this.NCBI_NODES_DMP);
        IOUtil.assertFileIsReadable(this.INPUT_SUMMARY);
        if (this.INPUT_VIRAL_READS != null) {
            IOUtil.assertFileIsReadable(this.INPUT_VIRAL_READS);
        }
        IOUtil.assertFileIsWritable(this.OUTPUT);
        if (this.OUTPUT_SUMMARY != null) {
            IOUtil.assertFileIsWritable(this.OUTPUT_SUMMARY);
        }
        if (this.OUTPUT_MATCHING_KMERS != null) {
            IOUtil.assertFileIsWritable(this.OUTPUT_MATCHING_KMERS);
        }
        List<IndexedFastaSequenceFile> references = new ArrayList<>(this.KRAKEN_REFERENCES.size());
        try {
            for (File file : this.KRAKEN_REFERENCES) {
                IOUtil.assertFileIsReadable(file);
                ReferenceCommandLineProgram.ensureSequenceDictionary(file);
                references.add(new IndexedFastaSequenceFile(file));
            }
            log.info("Loading seqid2taxid.map from ", this.SEQID2TAXID_MAP);
            Map<String, Integer> seqIdToTaxId = SeqIdToTaxIdMap.createLookup(this.SEQID2TAXID_MAP);
            log.info("Loading NCBI taxonomy from ", this.NCBI_NODES_DMP);
            Map<Integer, TaxonomyNode> taxonomy = TaxonomyHelper.parseFull(this.NCBI_NODES_DMP);
            log.info("Parsing ", this.INPUT_SUMMARY);
            List<List<String>> summary = readTsv(this.INPUT_SUMMARY);
            // Row 0 is the header; every other row names one taxon of interest.
            Set<Integer> taxIdsOfInterest = summary.stream()
                    .skip(1)
                    .map(row -> Integer.parseInt(row.get(TAXID_COLUMN)))
                    .collect(Collectors.toSet());
            Map<Integer, List<Pair<String, Long>>> candidatesByTaxId =
                    countCandidateKmers(references, seqIdToTaxId, taxonomy, taxIdsOfInterest);
            if (this.OUTPUT_MATCHING_KMERS != null) {
                writeMatchingKmers(candidatesByTaxId);
            }

            List<List<String>> annotated = new ArrayList<>();
            // Insertion-ordered and de-duplicated: the FASTA writer rejects repeated contig names.
            Set<String> selectedContigs = new LinkedHashSet<>();
            for (int rowIndex = 0; rowIndex < summary.size(); rowIndex++) {
                List<String> row = summary.get(rowIndex);
                if (rowIndex == 0) {
                    List<String> header = new ArrayList<>(row);
                    header.add("reference");
                    header.add("reference_taxid");
                    header.add("reference_kmer_count");
                    header.add("alternate_kmer_count");
                    annotated.add(header);
                    continue;
                }
                int taxId = Integer.parseInt(row.get(TAXID_COLUMN));
                List<Pair<String, Long>> ranked =
                        rankCandidates(references, candidatesByTaxId.getOrDefault(taxId, new ArrayList<>()));
                log.debug("Processing taxid " + taxId);
                for (Pair<String, Long> candidate : ranked) {
                    log.debug(candidate.getKey() + "\t" + candidate.getValue());
                }
                if (ranked.isEmpty()) {
                    log.warn("No candidate reference contigs found for taxid ", taxId);
                    continue;
                }
                int selectedCount = Math.min(this.CONTIGS_PER_TAXID, ranked.size());
                // Kmer count of the best-ranked candidate that was NOT selected (0 if none remain).
                long alternateKmerCount = ranked.size() > selectedCount ? ranked.get(selectedCount).getValue() : 0L;
                for (Pair<String, Long> chosen : ranked.subList(0, selectedCount)) {
                    String contig = chosen.getKey();
                    // Non-null: only contigs with a seqid2taxid entry pass the candidate filter.
                    int contigTaxId = seqIdToTaxId.get(contig);
                    log.info("Using " + contig + " as viral reference for " + taxId + " (" + contigTaxId + ")");
                    List<String> out = new ArrayList<>(row);
                    out.add(contig);
                    out.add(Integer.toString(contigTaxId));
                    out.add(chosen.getValue().toString());
                    out.add(Long.toString(alternateKmerCount));
                    annotated.add(out);
                    selectedContigs.add(contig);
                }
            }

            log.info("Writing " + selectedContigs.size() + " viral contigs to ", this.OUTPUT);
            if (selectedContigs.isEmpty()) {
                Files.write(this.OUTPUT.toPath(), new byte[0]);
                log.warn("No sequences written to ", this.OUTPUT);
            } else {
                writeFasta(references, selectedContigs);
            }
            if (this.OUTPUT_SUMMARY != null) {
                Files.write(this.OUTPUT_SUMMARY.toPath(), annotated.stream()
                        .map(fields -> String.join("\t", fields))
                        .collect(Collectors.toList()));
            }
            return 0;
        } catch (IOException e) {
            log.error(e, new Object[0]);
            throw new RuntimeIOException(e);
        } finally {
            closeAll(references);
        }
    }

    /**
     * Reads a tab-separated file into rows of fields. Blank lines are skipped. Trailing empty
     * fields are kept so that appended columns stay aligned with the header.
     */
    private static List<List<String>> readTsv(File file) throws IOException {
        return Files.readAllLines(file.toPath()).stream()
                .filter(line -> !line.isEmpty())
                .map(line -> Arrays.asList(line.split("\t", -1)))
                .collect(Collectors.toList());
    }

    /**
     * Counts read kmers against every candidate contig and groups the results.
     *
     * <p>A contig is a candidate when its seqid2taxid taxon is flagged by the inclusion lookup for
     * {@code taxIdsOfInterest}. Results are grouped by the taxon of interest the contig falls
     * under.
     *
     * @return map from summary taxid to (contig name, matched kmer count) pairs
     */
    private Map<Integer, List<Pair<String, Long>>> countCandidateKmers(
            List<IndexedFastaSequenceFile> references,
            Map<String, Integer> seqIdToTaxId,
            Map<Integer, TaxonomyNode> taxonomy,
            Set<Integer> taxIdsOfInterest) throws IOException {
        boolean[] included = TaxonomyHelper.createInclusionLookup(taxIdsOfInterest, taxonomy);
        Stream<ReferenceSequence> candidates = references.stream().flatMap(ref ->
                ref.getSequenceDictionary().getSequences().stream()
                        .map(SAMSequenceRecord::getSequenceName)
                        .filter(name -> isIncluded(included, seqIdToTaxId.get(name)))
                        .map(ref::getSequence));
        ContigKmerCounter counter = new ContigKmerCounter(candidates, this.KMER, this.STRIDE);
        if (this.INPUT_VIRAL_READS != null) {
            log.info("Identifying best viral reference genomes from ", this.INPUT_VIRAL_READS);
            try (FastqReader reader = new FastqReader(this.INPUT_VIRAL_READS)) {
                while (reader.hasNext()) {
                    counter.count(reader.next().getReadBases());
                }
            }
        }
        return Streams.zip(counter.getContigs().stream(), counter.getKmerCounts().stream(),
                        (contig, count) -> Pair.of(contig, count))
                .collect(Collectors.groupingBy(pair ->
                        getParentTaxaOfInterest(taxIdsOfInterest, taxonomy, seqIdToTaxId.get(pair.getKey()))));
    }

    /**
     * Returns whether {@code taxId} is non-null, lies within the bounds of {@code lookup}, and is
     * flagged in it.
     */
    private static boolean isIncluded(boolean[] lookup, Integer taxId) {
        return taxId != null && taxId >= 0 && taxId < lookup.length && lookup[taxId];
    }

    /** Writes one {@code taxid<TAB>contig<TAB>kmerCount} line per candidate to OUTPUT_MATCHING_KMERS. */
    private void writeMatchingKmers(Map<Integer, List<Pair<String, Long>>> candidatesByTaxId) throws IOException {
        log.info("Writing matching kmer counts to ", this.OUTPUT_MATCHING_KMERS);
        List<String> lines = new ArrayList<>();
        for (Map.Entry<Integer, List<Pair<String, Long>>> entry : candidatesByTaxId.entrySet()) {
            for (Pair<String, Long> candidate : entry.getValue()) {
                lines.add(String.format("%d\t%s\t%d", entry.getKey(), candidate.getKey(), candidate.getValue()));
            }
        }
        Files.write(this.OUTPUT_MATCHING_KMERS.toPath(), lines);
    }

    /**
     * Sorts candidates best-first. When FAVOUR_EARLY_KRAKEN_REFERENCES is set, candidates from
     * earlier reference files come first. Within that, candidates are ordered by descending kmer
     * count. The sort is stable.
     */
    private List<Pair<String, Long>> rankCandidates(List<IndexedFastaSequenceFile> references, List<Pair<String, Long>> candidates) {
        Comparator<Pair<String, Long>> byFileOrder = Comparator.comparingInt(
                candidate -> this.FAVOUR_EARLY_KRAKEN_REFERENCES ? offsetInList(references, candidate.getKey()) : 0);
        Comparator<Pair<String, Long>> byKmerCountDescending =
                Comparator.comparingLong((Pair<String, Long> candidate) -> candidate.getValue()).reversed();
        return candidates.stream()
                .sorted(byFileOrder.thenComparing(byKmerCountDescending))
                .collect(Collectors.toList());
    }

    /**
     * Writes {@code contigs} to OUTPUT, with .dict and .fai files. Each contig is taken from the
     * first reference file that contains it, and non-IUPAC bases are replaced with N.
     *
     * @throws IllegalStateException if a contig is present in none of the reference files
     */
    private void writeFasta(List<IndexedFastaSequenceFile> references, Iterable<String> contigs) throws IOException {
        try (FastaReferenceWriter writer = new FastaReferenceWriterBuilder()
                .setMakeDictOutput(true)
                .setMakeFaiOutput(true)
                .setFastaFile(this.OUTPUT.toPath())
                .build()) {
            for (String contig : contigs) {
                IndexedFastaSequenceFile source = references.stream()
                        .filter(ref -> ref.getSequenceDictionary().getSequence(contig) != null)
                        .findFirst()
                        .orElseThrow(() -> new IllegalStateException("Contig " + contig + " not found in any KRAKEN_REFERENCES file"));
                writer.addSequence(cleanSequence(source.getSequence(contig)));
            }
        }
    }

    /** Closes every reference file. A failure is logged and does not stop the remaining closes. */
    private static void closeAll(List<IndexedFastaSequenceFile> references) {
        for (IndexedFastaSequenceFile reference : references) {
            try {
                reference.close();
            } catch (IOException e) {
                log.warn(e, "Unable to close reference ", reference);
            }
        }
    }

    /**
     * Walks up the taxonomy from {@code taxId} until it reaches a taxon in
     * {@code taxIdsOfInterest}. The walk also stops at a node whose parent taxid is 1 or lower, or
     * at a taxid missing from {@code taxonomy}.
     *
     * @return the taxon of interest, or the taxid where the walk stopped
     */
    private int getParentTaxaOfInterest(Set<Integer> taxIdsOfInterest, Map<Integer, ? extends MinimalTaxonomyNode> taxonomy, int taxId) {
        int current = taxId;
        while (!taxIdsOfInterest.contains(current)) {
            MinimalTaxonomyNode node = taxonomy.get(current);
            if (node == null || node.parentTaxId <= 1) {
                break;
            }
            current = node.parentTaxId;
        }
        return current;
    }

    /**
     * Returns the index of the first reference file whose dictionary contains {@code contig}, or
     * {@code references.size()} if none does.
     */
    private static int offsetInList(List<IndexedFastaSequenceFile> references, String contig) {
        for (int i = 0; i < references.size(); i++) {
            if (references.get(i).getSequenceDictionary().getSequence(contig) != null) {
                return i;
            }
        }
        return references.size();
    }

    /**
     * Replaces every non-IUPAC base with {@code 'N'}. The array returned by
     * {@code getBases()} is modified in place.
     */
    private static ReferenceSequence cleanSequence(ReferenceSequence referenceSequence) {
        byte[] bases = referenceSequence.getBases();
        for (int i = 0; i < bases.length; i++) {
            if (!SequenceUtil.isIUPAC(bases[i])) {
                bases[i] = 'N';
            }
        }
        return new ReferenceSequence(referenceSequence.getName(), referenceSequence.getContigIndex(), bases);
    }

    public static void main(String[] argv) {
        System.exit(new ExtractBestViralReference().instanceMain(argv));
    }
}
