package au.edu.wehi.idsv.debruijn;

import au.edu.wehi.idsv.BreakendDirection;
import au.edu.wehi.idsv.DirectedEvidence;
import au.edu.wehi.idsv.NonReferenceReadPair;
import au.edu.wehi.idsv.configuration.ErrorCorrectionConfiguration;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.util.Log;
import htsjdk.samtools.util.SequenceUtil;
import it.unimi.dsi.fastutil.longs.Long2IntMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongOpenHashSet;
import java.util.HashSet;

/* loaded from: input_file:au/edu/wehi/idsv/debruijn/ReadErrorCorrector.class */
/**
 * Corrects likely sequencing errors in read bases using k-mer occurrence counts.
 *
 * <p>Usage is two-phase, as done by
 * {@link #errorCorrect(int, float, int, boolean, Iterable)}: first every read is passed to
 * {@link #countKmers(SAMRecord, boolean)} to build the k-mer count table, then every read is
 * passed to {@link #errorCorrect(SAMRecord, boolean)}, which rewrites the read bases in place.
 *
 * <p>A k-mer is treated as "safe" (trusted) when its count is greater than
 * {@code floor(maxCount / kmerErrorCorrectionMultiple)}, where {@code maxCount} is the highest
 * count of any k-mer seen so far. An unsafe k-mer may be replaced by a neighbour that differs in
 * exactly one base, provided the neighbour's count is at least
 * {@code ceil(count * kmerErrorCorrectionMultiple)}.
 *
 * <p>NOTE(review): the private method names suggest the approach follows the "Musket" k-mer
 * spectrum corrector; confirm against the upstream project documentation.
 *
 * <p>Instances hold mutable counting state and perform no synchronisation.
 */
public class ReadErrorCorrector {
    private static final Log log = Log.getInstance(ReadErrorCorrector.class);

    /** Largest supported k; k-mers are handled as a single {@code long} with two bits per base. */
    private static final int MAX_K = 31;

    /** Number of times each k-mer has been counted by {@link #countKmers(SAMRecord, boolean)}. */
    private final Long2IntMap kmerCounts = new Long2IntOpenHashMap();
    /** K-mer length in bases. */
    private final int k;
    /** Count ratio separating trusted k-mers from correctable ones (see class documentation). */
    private final float kmerErrorCorrectionMultiple;
    /** When true, a k-mer occurring several times within one read is counted once for that read. */
    private final boolean deduplicateReadKmers;
    /** Maximum number of base differences tolerated within any single k-mer window of a read. */
    private final int maxCorrectionsInKmer;
    /** Highest count of any single k-mer seen so far. */
    private int maxCount = 0;
    /** K-mers with a count above this threshold are considered safe; derived from maxCount. */
    private int maxCollapseCount = 0;

    /**
     * Creates a corrector from the given configuration.
     *
     * @param errorCorrectionConfiguration source of k, the count multiple, the per-k-mer
     *        correction limit and the de-duplication flag
     * @throws IllegalArgumentException if the configured k is outside 1..31
     */
    public ReadErrorCorrector(ErrorCorrectionConfiguration errorCorrectionConfiguration) {
        this(errorCorrectionConfiguration.k, errorCorrectionConfiguration.kmerErrorCorrectionMultiple, errorCorrectionConfiguration.maxCorrectionsInKmer, errorCorrectionConfiguration.deduplicateReadKmers);
    }

    /**
     * Creates a corrector with an empty k-mer count table.
     *
     * @param i k-mer length (k); must be between 1 and 31 inclusive
     * @param f k-mer error correction multiple
     * @param i2 maximum corrections permitted within a single k-mer window
     * @param z whether repeated k-mers within one read are counted only once
     * @throws IllegalArgumentException if k is outside 1..31
     */
    public ReadErrorCorrector(int i, float f, int i2, boolean z) {
        if (i < 1) {
            // A non-positive k would produce meaningless k-mer offsets and bit shifts below.
            throw new IllegalArgumentException("k must be at least 1");
        }
        if (i > MAX_K) {
            throw new IllegalArgumentException("k cannot exceed 31");
        }
        this.k = i;
        this.kmerErrorCorrectionMultiple = f;
        this.maxCorrectionsInKmer = i2;
        this.deduplicateReadKmers = z;
    }

    /**
     * Error corrects, in place, the reads backing the given evidence.
     *
     * <p>Every underlying SAM record is processed in its stored orientation. For read pair
     * evidence the non-reference read is processed as well; it is reverse complemented first when
     * exactly one of "breakend direction is Forward" and "read is on the negative strand" holds.
     * All reads are counted before any read is corrected.
     *
     * @param i k-mer length (k); must be between 1 and 31 inclusive
     * @param f k-mer error correction multiple
     * @param i2 maximum corrections permitted within a single k-mer window
     * @param z whether repeated k-mers within one read are counted only once
     * @param iterable evidence whose reads are to be corrected
     * @throws IllegalArgumentException if k is outside 1..31
     */
    public static void errorCorrect(int i, float f, int i2, boolean z, Iterable<? extends DirectedEvidence> iterable) {
        ReadErrorCorrector corrector = new ReadErrorCorrector(i, f, i2, z);
        // Typed sets: with raw HashSet the elements are Object and cannot be passed to
        // countKmers/errorCorrect without a cast.
        HashSet<SAMRecord> readsAsStored = new HashSet<>();
        HashSet<SAMRecord> readsToReverseComplement = new HashSet<>();
        for (DirectedEvidence evidence : iterable) {
            readsAsStored.add(evidence.getUnderlyingSAMRecord());
            if (evidence instanceof NonReferenceReadPair) {
                SAMRecord nonReferenceRead = ((NonReferenceReadPair) evidence).getNonReferenceRead();
                boolean forwardBreakend = evidence.getBreakendSummary().direction == BreakendDirection.Forward;
                if (forwardBreakend ^ nonReferenceRead.getReadNegativeStrandFlag()) {
                    readsToReverseComplement.add(nonReferenceRead);
                } else {
                    readsAsStored.add(nonReferenceRead);
                }
            }
        }
        // Phase 1: build the complete count table before correcting anything.
        for (SAMRecord read : readsAsStored) {
            corrector.countKmers(read, false);
        }
        for (SAMRecord read : readsToReverseComplement) {
            corrector.countKmers(read, true);
        }
        // Phase 2: correct every read against the finished count table.
        for (SAMRecord read : readsAsStored) {
            corrector.errorCorrect(read, false);
        }
        for (SAMRecord read : readsToReverseComplement) {
            corrector.errorCorrect(read, true);
        }
    }

    /**
     * Adds the k-mers of the given read to the count table and updates the safety threshold.
     * Reads shorter than k contribute nothing.
     *
     * @param sAMRecord read whose bases are counted
     * @param z true to count the reverse complement of the read bases
     */
    public void countKmers(SAMRecord sAMRecord, boolean z) {
        PackedSequence sequence = new PackedSequence(sAMRecord.getReadBases(), z, z);
        int kmersInRead = (sequence.length() - this.k) + 1;
        LongOpenHashSet seenInThisRead = new LongOpenHashSet(Math.max(1, kmersInRead), 0.5f);
        for (int offset = 0; offset < kmersInRead; offset++) {
            long kmer = sequence.getKmer(offset, this.k);
            // add() returns false when the k-mer was already recorded for this read.
            if (this.deduplicateReadKmers && !seenInThisRead.add(kmer)) {
                continue;
            }
            int updatedCount = this.kmerCounts.get(kmer) + 1;
            this.kmerCounts.put(kmer, updatedCount);
            if (updatedCount > this.maxCount) {
                this.maxCount = updatedCount;
                refreshMaxCollapseCount();
            }
        }
    }

    /**
     * Error corrects a single read in place using the current count table.
     *
     * <p>The two-sided pass runs first, followed by the one-sided greedy pass. If the total number
     * of corrections exceeds {@code maxCorrectionsInKmer} and some k-mer window of the corrected
     * sequence differs from the original by more than {@code maxCorrectionsInKmer} bases, all
     * corrections are discarded and the read is left untouched.
     *
     * @param sAMRecord read to correct; its read bases are replaced when corrections are accepted
     * @param z true if the read is to be processed as its reverse complement (the corrected bases
     *        are reverse complemented back before being stored)
     * @return number of corrections applied; 0 if the read is shorter than k, nothing needed
     *         correcting, or the corrections were rejected
     */
    public int errorCorrect(SAMRecord sAMRecord, boolean z) {
        if (sAMRecord.getReadLength() < this.k) {
            return 0;
        }
        PackedSequence corrected = new PackedSequence(sAMRecord.getReadBases(), z, z);
        PackedSequence original = new PackedSequence(corrected);
        int corrections = musket_two_sided(corrected) + musket_one_sided_greedy_without_voting(corrected);
        if (corrections == 0) {
            return 0;
        }
        // The per-window comparison is only needed once the total exceeds the per-window limit.
        if (corrections > this.maxCorrectionsInKmer && tooManyDifferencesInWindow(corrected, original)) {
            return 0;
        }
        byte[] bases = corrected.getBytes(0, sAMRecord.getReadBases().length);
        if (z) {
            // Restore the orientation in which the record stores its bases.
            SequenceUtil.reverseComplement(bases);
        }
        sAMRecord.setReadBases(bases);
        return corrections;
    }

    /**
     * Reports whether any aligned pair of k-mers from the two sequences differs by more than
     * {@code maxCorrectionsInKmer} bases.
     */
    private boolean tooManyDifferencesInWindow(PackedSequence first, PackedSequence second) {
        int lastStart = first.length() - this.k;
        for (int start = 0; start <= lastStart; start++) {
            int differences = KmerEncodingHelper.basesDifference(this.k, first.getKmer(start, this.k), second.getKmer(start, this.k));
            if (differences > this.maxCorrectionsInKmer) {
                return true;
            }
        }
        return false;
    }

    /**
     * Debugging aid (not called from within this class): writes four lines to standard error.
     * The first is the {@code reference} sequence, the second has a '*' for each safe position
     * of {@code reference}, the third a '|' wherever {@code other} differs from
     * {@code reference}, and the fourth the differing base of {@code other} at those positions.
     */
    private void debug_dump_changes(SAMRecord sAMRecord, PackedSequence other, PackedSequence reference) {
        int baseCount = sAMRecord.getReadBases().length;
        // Bases are single-byte characters; decode explicitly rather than via the platform default.
        String referenceBases = new String(reference.getBytes(0, baseCount), java.nio.charset.StandardCharsets.US_ASCII);
        String otherBases = new String(other.getBytes(0, baseCount), java.nio.charset.StandardCharsets.US_ASCII);
        System.err.print("\n" + referenceBases + "\n");
        for (int pos = 0; pos < referenceBases.length() - (this.k - 1); pos++) {
            System.err.print(isSafeBase(reference, pos) ? "*" : " ");
        }
        System.err.print("\n");
        for (int pos = 0; pos < referenceBases.length(); pos++) {
            System.err.print(referenceBases.charAt(pos) == otherBases.charAt(pos) ? " " : "|");
        }
        System.err.print("\n");
        for (int pos = 0; pos < referenceBases.length(); pos++) {
            if (referenceBases.charAt(pos) == otherBases.charAt(pos)) {
                System.err.print(" ");
            } else {
                System.err.print(otherBases.charAt(pos));
            }
        }
        System.err.print("\n");
    }

    /** Recomputes the safe-k-mer threshold as {@code floor(maxCount / kmerErrorCorrectionMultiple)}. */
    private void refreshMaxCollapseCount() {
        this.maxCollapseCount = (int) Math.floor(this.maxCount / this.kmerErrorCorrectionMultiple);
    }

    /**
     * Two-sided correction pass. For a base position, both the k-mer ending at that base and the
     * k-mer starting at that base are examined; a correction is made only when both k-mers are
     * unsafe and both have a better-supported neighbour differing at that base. The neighbour
     * with the higher count is written (the left one on ties).
     *
     * @return number of corrections written into {@code sequence}
     */
    private int musket_two_sided(PackedSequence sequence) {
        int position = this.k - 1;
        int corrections = 0;
        while (position + this.k < sequence.length()) {
            // K-mer in which the examined base is the first base.
            long rightKmer = sequence.getKmer(position, this.k);
            int rightCount = this.kmerCounts.get(rightKmer);
            if (rightCount > this.maxCollapseCount) {
                // Trusted k-mer: skip past all of its bases.
                position += this.k;
                continue;
            }
            // K-mer in which the examined base is the last base.
            int leftStart = position - (this.k - 1);
            long leftKmer = sequence.getKmer(leftStart, this.k);
            int leftCount = this.kmerCounts.get(leftKmer);
            if (leftCount > this.maxCollapseCount) {
                position++;
                continue;
            }
            long leftReplacement = neighbourToCollapseInto(leftKmer, this.k - 1, leftCount);
            if (leftReplacement == leftKmer) {
                position++;
                continue;
            }
            long rightReplacement = neighbourToCollapseInto(rightKmer, 0, rightCount);
            if (rightReplacement == rightKmer) {
                position++;
                continue;
            }
            if (this.kmerCounts.get(leftReplacement) >= this.kmerCounts.get(rightReplacement)) {
                sequence.setKmer(leftReplacement, leftStart, this.k);
            } else {
                sequence.setKmer(rightReplacement, position, this.k);
            }
            position += this.k;
            corrections++;
        }
        return corrections;
    }

    /**
     * One-sided greedy correction pass. Locates the first safe k-mer, attempts to fix the k-mer
     * immediately before it, then repeatedly attempts to fix the k-mer immediately after the end
     * of each safe region.
     *
     * @return number of corrections written into {@code sequence}; 0 if no k-mer is safe
     */
    private int musket_one_sided_greedy_without_voting(PackedSequence sequence) {
        int firstSafe = advanceToSafePosition(sequence, 0);
        if (firstSafe > sequence.length() - this.k) {
            // No safe k-mer anywhere in the read, so there is nothing to anchor a correction on.
            return 0;
        }
        int corrections = 0;
        if (firstSafe > 0) {
            corrections = fixBaseAdjacentTo(sequence, firstSafe, -1);
        }
        int safeEnd = advanceToEndOfSafeRegion(sequence, firstSafe);
        while (safeEnd <= (sequence.length() - this.k) - 1) {
            corrections += fixBaseAdjacentTo(sequence, safeEnd, 1);
            int nextSafe = advanceToSafePosition(sequence, safeEnd + 1);
            safeEnd = advanceToEndOfSafeRegion(sequence, nextSafe);
        }
        return corrections;
    }

    /**
     * Returns the first k-mer start position at or after {@code start} whose k-mer is safe, or a
     * position past the last k-mer start if there is none.
     */
    private int advanceToSafePosition(PackedSequence sequence, int start) {
        int position = start;
        while (position <= sequence.length() - this.k && !isSafeKmerStartPosition(sequence, position)) {
            position++;
        }
        return position;
    }

    /** Reports whether the k-mer starting at {@code start} has a count above the safe threshold. */
    private boolean isSafeKmerStartPosition(PackedSequence sequence, int start) {
        return this.kmerCounts.get(sequence.getKmer(start, this.k)) > this.maxCollapseCount;
    }

    /** Reports whether any k-mer covering the base at {@code basePosition} is safe. */
    private boolean isSafeBase(PackedSequence sequence, int basePosition) {
        int firstStart = Math.max(0, basePosition - (this.k - 1));
        for (int start = firstStart; start <= Math.min(basePosition, sequence.length() - this.k); start++) {
            if (isSafeKmerStartPosition(sequence, start)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Walks forward from {@code start} while k-mers remain safe. Returns the position just before
     * the first unsafe k-mer start encountered, or the first position past the last k-mer start
     * if the end of the sequence is reached.
     */
    private int advanceToEndOfSafeRegion(PackedSequence sequence, int start) {
        int position = start;
        while (position <= sequence.length() - this.k) {
            if (!isSafeKmerStartPosition(sequence, position)) {
                return position - 1;
            }
            // Try to jump a whole k-mer ahead; fall back to a single step when the jump target
            // is out of range or not safe.
            int jumpTarget = position + this.k;
            if (jumpTarget <= sequence.length() - this.k && isSafeKmerStartPosition(sequence, jumpTarget)) {
                position = jumpTarget;
            } else {
                position++;
            }
        }
        return position;
    }

    /**
     * Attempts to correct the k-mer starting one step away from {@code anchor} in the given
     * direction by changing its base furthest from the anchor (first base when stepping
     * backwards, last base when stepping forwards).
     *
     * <p>When a further k-mer exists one more step along, the correction is applied only if the
     * replacement chosen for that further k-mer, at the index where the same read base sits,
     * carries the same base.
     *
     * @param direction -1 to fix the k-mer before the anchor, otherwise the k-mer after it
     * @return 1 if a correction was written, 0 otherwise
     */
    private int fixBaseAdjacentTo(PackedSequence sequence, int anchor, int direction) {
        int lastStart = sequence.length() - this.k;
        int targetStart = anchor + direction;
        if (targetStart < 0 || targetStart > lastStart) {
            return 0;
        }
        long targetKmer = sequence.getKmer(targetStart, this.k);
        int baseIndex = direction == -1 ? 0 : this.k - 1;
        long replacement = neighbourToCollapseInto(targetKmer, baseIndex, this.kmerCounts.get(targetKmer));
        if (replacement == targetKmer) {
            return 0;
        }
        int lookaheadStart = targetStart + direction;
        if (lookaheadStart >= 0 && lookaheadStart <= lastStart) {
            long lookaheadKmer = sequence.getKmer(lookaheadStart, this.k);
            // The same read base sits one index further in within the look-ahead k-mer.
            int lookaheadIndex = direction == -1 ? 1 : this.k - 2;
            long lookaheadReplacement = neighbourToCollapseInto(lookaheadKmer, lookaheadIndex, this.kmerCounts.get(lookaheadKmer));
            if (KmerEncodingHelper.getBase(this.k, replacement, baseIndex) != KmerEncodingHelper.getBase(this.k, lookaheadReplacement, lookaheadIndex)) {
                return 0;
            }
        }
        sequence.setKmer(replacement, targetStart, this.k);
        return 1;
    }

    /**
     * Finds the best-supported k-mer that differs from {@code kmer} only at {@code baseIndex}.
     *
     * <p>A neighbour qualifies only if its count is at least
     * {@code ceil(count * kmerErrorCorrectionMultiple)}; among qualifying neighbours the one with
     * the highest count wins, the earliest examined winning ties.
     *
     * @param kmer encoded k-mer
     * @param baseIndex index of the base to vary (0 is the first base of the k-mer)
     * @param count count of {@code kmer}
     * @return the chosen neighbour, or {@code kmer} itself when no neighbour qualifies
     */
    private long neighbourToCollapseInto(long kmer, int baseIndex, int count) {
        long best = kmer;
        int countToBeat = ((int) Math.ceil(count * this.kmerErrorCorrectionMultiple)) - 1;
        int shift = ((this.k - 1) - baseIndex) * 2;
        // XOR-ing the two-bit base with 1, 2 and 3 yields each of the three other bases.
        for (long flip = 1; flip < 4; flip++) {
            long candidate = kmer ^ (flip << shift);
            int candidateCount = this.kmerCounts.get(candidate);
            if (candidateCount > countToBeat) {
                best = candidate;
                countToBeat = candidateCount;
            }
        }
        return best;
    }
}
