/*
 * Decompiled with CFR 0.152.
 */
package projects.talGA;

import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.util.Arrays;
import java.util.HashMap;
import projects.tals.TALgetterDiffSM;

/**
 * A compact, array-backed 4-ary trie over every window of {@code length} symbols in the
 * sequences of a {@link DataSet}. It is used to enumerate the model scores of all windows that
 * exceed a threshold, pruning whole subtrees with an upper bound on the remaining score.
 *
 * <p>Storage layout (plain {@code int} arrays, no per-node objects):
 * <ul>
 *   <li>{@code innerNodes} holds four consecutive child slots per inner node; the root occupies
 *       slots {@code 0..3}. A slot value {@code > 0} is the start offset of a child inner node,
 *       {@code < 0} is {@code -leafIndex}, and {@code 0} means "no child".</li>
 *   <li>{@code leaves[leafIndex]} is either {@code <= 0}, encoding exactly one window as
 *       {@code -(sequenceIndex * msl + startPosition)}, or {@code >= 2}, a counter for a window
 *       seen that many times. Counters only occur at the deepest trie level. Index 0 is unused.</li>
 * </ul>
 *
 * <p>Leaves are split lazily: a window stays encoded in a shallow leaf until another window with
 * the same prefix is inserted. Freed leaf indices are recycled through {@code freedLeaves}.
 *
 * <p>NOTE(review): assumes {@code Sequence.discreteVal} yields values in {@code 0..3} (a
 * four-letter alphabet) — confirm against the data's alphabet. The encoding
 * {@code sequenceIndex * msl + startPosition} is computed in {@code int} and could overflow for
 * very large data sets.
 */
public class PartialStringTree2 {
    /** Number of child slots per inner node (size of the assumed alphabet). */
    private static final int ALPHABET_SIZE = 4;
    /** Largest array length that is safe to request on common JVMs. */
    private static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8;

    private final TALgetterDiffSM model;
    /** Child-slot table; grows on demand (see {@link #allocateInnerNode()}). */
    private int[] innerNodes;
    private int[] leaves;
    /** Leaf indices released when a leaf was turned into an inner node, reused before new ones. */
    private IntList freedLeaves;
    private final DataSet data;
    /** Window length, i.e. maximum trie depth. */
    private final int length;
    /** Next free offset in {@code innerNodes}. */
    private int currOff;
    /** Next never-used index in {@code leaves}. */
    private int currOffLeaves;
    /** Maximum number of windows in any single sequence; radix of the leaf encoding. */
    private int msl;
    /** Cache of results per TAL; each list ends with {@code cap, t} of the query that built it. */
    private final HashMap<Sequence, DoubleList> scoreHash;

    /**
     * Builds the trie over all windows of {@code length} symbols in {@code data}.
     *
     * @param model the scoring model; {@code fix()} is called on it before use
     * @param data the sequences to index
     * @param length the window length (trie depth)
     */
    public PartialStringTree2(TALgetterDiffSM model, DataSet data, int length) throws WrongAlphabetException, WrongLengthException, WrongSequenceTypeException {
        this.model = model;
        this.model.fix();
        this.data = data;
        this.length = length;
        this.construct();
        this.scoreHash = new HashMap<>();
    }

    /** Sizes the arrays and inserts every window of every sequence; prints progress to stdout. */
    private void construct() throws WrongAlphabetException, WrongLengthException {
        this.currOff = ALPHABET_SIZE;
        this.currOffLeaves = 1;
        this.msl = 0;
        int n = 0;
        for (int i = 0; i < this.data.getNumberOfElements(); i++) {
            int windows = this.data.getElementAt(i).getLength() - this.length + 1;
            if (windows > 0) {
                this.msl = Math.max(this.msl, windows);
                n += windows;
            }
        }
        System.out.println(n);
        // Live leaves never exceed the number of windows nor the number of distinct windows.
        this.leaves = new int[Math.min((int) Math.pow(ALPHABET_SIZE, this.length), n) + 1];
        int m = this.estimateInnerNodeSlots(n);
        System.out.println(m);
        this.innerNodes = new int[m];
        this.freedLeaves = new IntList();
        for (int i = 0; i < this.data.getNumberOfElements(); i++) {
            Sequence seq = this.data.getElementAt(i);
            if (i % 1000 == 0) {
                System.out.println(i);
            }
            for (int j = 0; j < seq.getLength() - this.length + 1; j++) {
                this.insert2(seq, i, j, 0, 0);
            }
        }
        System.out.println("leaves " + this.currOffLeaves);
        System.out.println("inner nodes " + this.currOff);
    }

    /**
     * Heuristic initial capacity of {@code innerNodes} for {@code n} windows. This is only a
     * starting size: {@link #allocateInnerNode()} grows the array when it is exceeded.
     * The result always leaves room for the root's slots.
     */
    private int estimateInnerNodeSlots(int n) {
        long estimate = 0;
        if (n > 0) {
            int lev = (int) (Math.log(n) / Math.log(ALPHABET_SIZE));
            estimate = lev < this.length
                    ? (long) Math.pow(ALPHABET_SIZE, lev) + (long) (this.length - lev) * n
                    : (long) Math.pow(ALPHABET_SIZE, this.length);
        }
        return (int) Math.min(MAX_ARRAY_SIZE, Math.max(ALPHABET_SIZE, estimate));
    }

    /**
     * Reserves the child slots for a new inner node, growing {@code innerNodes} if needed.
     *
     * @return the start offset of the new node's slots
     * @throws IllegalStateException if the table would exceed the maximum array size
     */
    private int allocateInnerNode() {
        int node = this.currOff;
        long required = (long) node + ALPHABET_SIZE;
        if (required > MAX_ARRAY_SIZE) {
            throw new IllegalStateException("inner node table exhausted at offset " + node);
        }
        if (required > this.innerNodes.length) {
            long grown = Math.max(required, (long) this.innerNodes.length * 3 / 2);
            this.innerNodes = Arrays.copyOf(this.innerNodes, (int) Math.min(grown, MAX_ARRAY_SIZE));
        }
        this.currOff = (int) required;
        return node;
    }

    /**
     * Inserts the window of {@code seq} starting at {@code start}, descending from the node whose
     * slots begin at {@code currNodeStartIdx} at trie depth {@code depth}.
     *
     * @param seq the sequence containing the window
     * @param idx index of {@code seq} in {@code data} (used for the leaf encoding)
     * @param start start position of the window in {@code seq}
     * @param depth current depth; the symbol at {@code start + depth} selects the slot
     * @param currNodeStartIdx offset of the current node's first slot
     */
    private void insert2(Sequence seq, int idx, int start, int depth, int currNodeStartIdx) {
        int currNodeIdx = currNodeStartIdx + seq.discreteVal(start + depth);
        int slot = this.innerNodes[currNodeIdx];
        if (slot == 0) {
            // Empty slot: store the window as a new (or recycled) single-window leaf.
            int leafIdx = this.freedLeaves.length() == 0 ? this.currOffLeaves++ : this.freedLeaves.pop();
            this.innerNodes[currNodeIdx] = -leafIdx;
            this.leaves[leafIdx] = -(idx * this.msl + start);
        } else if (slot > 0) {
            this.insert2(seq, idx, start, depth + 1, slot);
        } else if (depth + 1 < this.length) {
            // Collision with a single-window leaf above full depth: split it into an inner node
            // and re-insert both the new window and the one the leaf encoded.
            int leafIdx = -slot;
            int encIdx = -this.leaves[leafIdx];
            this.freedLeaves.add(leafIdx);
            int child = this.allocateInnerNode();
            this.innerNodes[currNodeIdx] = child;
            this.insert2(seq, idx, start, depth + 1, child);
            int remIdx = encIdx / this.msl;
            int remPos = encIdx % this.msl;
            this.insert2(this.data.getElementAt(remIdx), remIdx, remPos, depth + 1, child);
        } else if (this.leaves[-slot] > 0) {
            // Deepest level, already a counter: one more occurrence.
            this.leaves[-slot]++;
        } else {
            // Deepest level, single-window leaf becomes a counter of two occurrences.
            this.leaves[-slot] = 2;
        }
    }

    /**
     * Returns the scores of all indexed windows scoring above {@code t} for the given TAL.
     * Results are cached per TAL and recomputed only when {@code t} or {@code cap} change.
     *
     * @param tal the TAL whose scores are computed
     * @param t score threshold
     * @param cap soft limit on the number of results; traversal stops once more than {@code cap}
     *     scores were collected, so the result may somewhat exceed {@code cap}
     * @return the collected scores (in trie traversal order)
     */
    public double[] getScoresAbove(Sequence tal, double t, int cap) {
        DoubleList l = this.scoreHash.get(tal);
        if (l == null || l.get(l.length() - 1) != t || (int) l.get(l.length() - 2) != cap) {
            l = new DoubleList();
            int[] currSeq = new int[tal.getLength() + 1];
            double[] scs = new double[tal.getLength() + 1];
            // Presumably fills scs with per-position best scores (see TALgetterDiffSM); turned into
            // suffix sums so scs[d] bounds the score obtainable from position d onwards.
            this.model.getBestPossibleScore(tal, scs);
            for (int i = scs.length - 2; i >= 0; i--) {
                scs[i] += scs[i + 1];
            }
            this.getScoresAbove(0, tal, t, currSeq, scs, 0, this.model, 0.0, l, cap);
            l.add(cap);
            l.add(t);
            this.scoreHash.put(tal, l);
        }
        return l.toArray(0, l.length() - 2);
    }

    /**
     * Depth-first traversal from the node at {@code off}, adding every leaf score above {@code t}
     * to {@code list} and skipping subtrees whose score bound falls below {@code t}.
     */
    private void getScoresAbove(int off, Sequence tal, double t, int[] currSeq, double[] bestSc, int depth, TALgetterDiffSM model, double currScore, DoubleList list, int cap) {
        if (list.length() > cap) {
            return;
        }
        for (int sym = 0; sym < ALPHABET_SIZE; sym++) {
            int slot = this.innerNodes[off + sym];
            if (slot > 0) {
                currSeq[depth] = sym;
                double tempScore = currScore + model.getPartialLogScoreFor(tal, currSeq, 0, depth, 1);
                if (tempScore + bestSc[depth + 1] >= t) {
                    this.getScoresAbove(slot, tal, t, currSeq, bestSc, depth + 1, model, tempScore, list, cap);
                }
            } else if (slot < 0) {
                int leafVal = this.leaves[-slot];
                // Single-window leaves are <= 0 (sequence 0 / position 0 encodes to exactly 0);
                // counters are always >= 2.
                if (leafVal <= 0) {
                    int encIdx = -leafVal;
                    int remIdx = encIdx / this.msl;
                    int remPos = encIdx % this.msl;
                    double tempScore = currScore + model.getPartialLogScoreFor(tal, this.data.getElementAt(remIdx), remPos, depth, tal.getLength() + 1 - depth);
                    if (tempScore > t) {
                        list.add(tempScore);
                    }
                } else {
                    currSeq[depth] = sym;
                    double tempScore = currScore + model.getPartialLogScoreFor(tal, currSeq, 0, depth, 1);
                    if (tempScore > t) {
                        // One entry per occurrence of this window.
                        list.add(tempScore, list.length(), list.length() + leafVal);
                    }
                }
            }
        }
    }

    /** Debug dump of the subtree whose slots start at {@code off} to stdout. */
    public void print(int off) {
        for (int sym = 0; sym < ALPHABET_SIZE; sym++) {
            int slot = this.innerNodes[off + sym];
            if (slot < 0) {
                System.out.println("leaf for " + sym + " with " + this.leaves[-slot]);
            } else if (slot > 0) {
                System.out.println("inner node for " + sym);
                this.print(slot);
                System.out.println("inner node for " + sym + " end");
            } else {
                System.out.println("no child for " + sym);
            }
        }
    }
}

