/*
 * Decompiled with CFR 0.152.
 */
package seqTools;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashMap;

public class FastHMM0
extends AbstractDifferentiableStatisticalModel {
    private HashMap<Sequence, double[]> countHash;
    private double classEss;
    private double avgLen;
    private double[] pars;
    private double[] probs;
    private double[] partDers;
    private double norm;
    private boolean fixed;

    public FastHMM0(AlphabetContainer alphabet, double ess, double avgLen) {
        super(alphabet, 0);
        this.classEss = ess;
        this.avgLen = avgLen;
        this.pars = new double[(int)alphabet.getAlphabetLengthAt(0)];
        this.partDers = new double[this.pars.length];
        this.probs = new double[this.pars.length];
        this.norm = Normalisation.logSumNormalisation(this.pars, 0, this.pars.length, this.probs, 0);
        this.countHash = new HashMap();
        this.fixed = false;
    }

    @Override
    public FastHMM0 clone() throws CloneNotSupportedException {
        FastHMM0 clone = (FastHMM0)super.clone();
        clone.countHash = new HashMap();
        clone.pars = (double[])this.pars.clone();
        clone.probs = (double[])this.probs.clone();
        clone.partDers = (double[])this.partDers.clone();
        return clone;
    }

    @Override
    public double getLogPriorTerm() {
        if (this.fixed) {
            return 0.0;
        }
        double lp = 0.0;
        int i = 0;
        while (i < this.pars.length) {
            lp += this.classEss * this.avgLen / (double)this.pars.length * this.pars[i];
            ++i;
        }
        return lp -= this.classEss * this.avgLen * this.norm;
    }

    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int off) {
        if (this.fixed) {
            return;
        }
        int i = 0;
        while (i < this.pars.length) {
            int n = off + i;
            grad[n] = grad[n] + (this.classEss * this.avgLen / (double)this.pars.length - this.classEss * this.avgLen * this.probs[i]);
            ++i;
        }
    }

    private double getFullScore(Sequence seq, int start) {
        double[] temp;
        if (start > 0) {
            seq = seq.getSubSequence(start);
        }
        if ((temp = this.countHash.get(seq)) == null) {
            temp = new double[this.pars.length];
            int i = 0;
            while (i < seq.getLength()) {
                int n = seq.discreteVal(i);
                temp[n] = temp[n] + 1.0;
                ++i;
            }
            this.countHash.put(seq, temp);
        }
        double val = 0.0;
        int i = 0;
        while (i < temp.length) {
            val += temp[i] * this.pars[i];
            ++i;
        }
        return val -= (double)seq.getLength() * this.norm;
    }

    public double getLogScoreForMotifPos(Sequence seq, int start, int motifPos, int motifLength) {
        double full = 0.0;
        if (motifPos >= 0) {
            int i = 0;
            while (i < motifLength) {
                full += this.pars[seq.discreteVal(start + motifPos + i)];
                ++i;
            }
            full -= (double)motifLength * this.norm;
        }
        return full;
    }

    public double getLogScoreFor(Sequence seq, int start, int motifPos, int motifLength) {
        double full = this.getFullScore(seq, start);
        return full -= this.getLogScoreForMotifPos(seq, start, motifPos, motifLength);
    }

    private double getFullScoreAndPartialDer(Sequence seq, int start, int off) {
        double[] temp;
        if (start > 0) {
            seq = seq.getSubSequence(start);
        }
        if ((temp = this.countHash.get(seq)) == null) {
            temp = new double[this.pars.length];
            int i = 0;
            while (i < seq.getLength()) {
                int n = seq.discreteVal(i);
                temp[n] = temp[n] + 1.0;
                ++i;
            }
            this.countHash.put(seq, temp);
        }
        double val = 0.0;
        int i = 0;
        while (i < temp.length) {
            val += temp[i] * this.pars[i];
            int n = i;
            this.partDers[n] = this.partDers[n] + temp[i];
            int n2 = i;
            this.partDers[n2] = this.partDers[n2] - (double)seq.getLength() * this.probs[i];
            ++i;
        }
        return val -= (double)seq.getLength() * this.norm;
    }

    public double getLogScoreAndPartialDerivationFor(Sequence seq, int start, int motifPos, int motifLength, IntList indices, DoubleList partial, int off) {
        int i;
        Arrays.fill(this.partDers, 0.0);
        double full = this.getFullScoreAndPartialDer(seq, start, off);
        if (motifPos >= 0) {
            i = 0;
            while (i < motifLength) {
                full -= this.pars[seq.discreteVal(start + motifPos + i)];
                int n = seq.discreteVal(start + motifPos + i);
                this.partDers[n] = this.partDers[n] - 1.0;
                ++i;
            }
            i = 0;
            while (i < this.partDers.length) {
                int n = i;
                this.partDers[n] = this.partDers[n] - (double)(-motifLength) * this.probs[i];
                ++i;
            }
            full -= (double)(-motifLength) * this.norm;
        }
        if (!this.fixed) {
            i = 0;
            while (i < this.partDers.length) {
                indices.add(i + off);
                partial.add(this.partDers[i]);
                ++i;
            }
        }
        return full;
    }

    @Override
    public int getNumberOfParameters() {
        if (this.fixed) {
            return 0;
        }
        return this.pars.length;
    }

    @Override
    public double[] getCurrentParameterValues() {
        if (this.fixed) {
            return new double[0];
        }
        return (double[])this.pars.clone();
    }

    @Override
    public void setParameters(double[] pars, int off) {
        if (this.fixed) {
            return;
        }
        int i = 0;
        while (i < this.pars.length) {
            this.pars[i] = pars[off + i];
            ++i;
        }
        this.norm = Normalisation.logSumNormalisation(this.pars, 0, this.pars.length, this.probs, 0);
    }

    @Override
    public void initializeFunctionRandomly(boolean free) {
        if (this.fixed) {
            return;
        }
        DirichletMRGParams pars = new DirichletMRGParams(this.classEss * this.avgLen / (double)this.pars.length, this.pars.length);
        DirichletMRG.DEFAULT_INSTANCE.generateLog(this.pars, 0, this.pars.length, pars);
        this.norm = Normalisation.logSumNormalisation(this.pars, 0, this.pars.length, this.probs, 0);
    }

    @Override
    public boolean isInitialized() {
        return true;
    }

    @Override
    public String toString() {
        throw new Error("Unresolved compilation problem: \n\tCannot override the final method from AbstractDifferentiableStatisticalModel\n");
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return this.pars.length;
    }

    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double getESS() {
        return this.classEss;
    }

    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(false);
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        return this.getLogScoreAndPartialDerivationFor(seq, start, -1, 0, indices, partialDer, 0);
    }

    @Override
    public String getInstanceName() {
        return "FastHMM(0)";
    }

    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.getLogScoreFor(seq, start, -1, 0);
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.avgLen, "avgLen");
        XMLParser.appendObjectWithTags(xml, this.classEss, "classEss");
        XMLParser.appendObjectWithTags(xml, this.fixed, "fixed");
        XMLParser.appendObjectWithTags(xml, this.pars, "pars");
        XMLParser.addTags(xml, this.getClass().getSimpleName());
        return xml;
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.avgLen = XMLParser.extractObjectForTags(xml, "avgLen", Double.TYPE);
        this.classEss = XMLParser.extractObjectForTags(xml, "classEss", Double.TYPE);
        this.fixed = XMLParser.extractObjectForTags(xml, "fixed", Boolean.TYPE);
        this.pars = XMLParser.extractObjectForTags(xml, "pars", double[].class);
        this.partDers = new double[this.pars.length];
        this.probs = new double[this.pars.length];
        this.norm = Normalisation.logSumNormalisation(this.pars, 0, this.pars.length, this.probs, 0);
        this.countHash = new HashMap();
    }

    public void fix() {
        this.fixed = true;
    }

    @Override
    public /* synthetic */ String toString(NumberFormat numberFormat) {
        throw new Error("Unresolved compilation problem: \n\tThe type FastHMM0 must implement the inherited abstract method SequenceScore.toString(NumberFormat)\n");
    }
}

