/*
 * Decompiled with CFR 0.152.
 */
package seqTools;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.MotifDiscoverer;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.motifDiscovery.MutableMotifDiscoverer;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;

/**
 * Two-component model in which a sequence is explained either by a single motif occurrence
 * embedded in background (component 0) or by background alone (component 1).
 *
 * <p>Background symbols are scored uniformly with {@code -log(alphabet size)}. A motif occurrence
 * may lie on either strand (each strand with probability 1/2). Its start position is either
 * uniform, or it is taken from the normalized continuous values of an optional
 * {@code "reference"} sequence annotation.
 *
 * <p>Instead of summing over the hidden variables (component, motif start, strand), the model
 * scores only the single best assignment (see {@code setHiddenVariables}). Assignments computed
 * during gradient evaluation on whole sequences ({@code start == 0}) are cached per sequence
 * until {@link #reset()} is called.
 */
public class CompleteViterbiChipper
extends AbstractDifferentiableStatisticalModel
implements MutableMotifDiscoverer {
    /** Precomputed {@code -log(2)}: the log-probability of each strand. */
    private static final double MINUS_LOG_2 = -Math.log(2.0);
    /** Log-probability of a single background symbol: {@code -log(alphabet size at position 0)}. */
    private double logP;
    /** Unnormalized log-weights of the components (index 0 = motif, index 1 = background). */
    private double[] hiddenParameters;
    /** Normalized component probabilities derived from {@link #hiddenParameters}. */
    private double[] hiddenPotential;
    /** Per-component log contributions to the normalization constant (see precomputeNorm). */
    private double[] logPartNorm;
    /** Log of the sum of exponentials of {@link #hiddenParameters}. */
    private double logHiddenNorm;
    /** Log-gamma normalizer of the Dirichlet-like prior on the component weights. */
    private double logGammaSum;
    /** Cached log normalization constant; {@code NaN} means "recompute". */
    private double logNorm;
    /** The wrapped motif model. */
    private DifferentiableStatisticalModel motif;
    /** Length of the motif, or {@code Integer.MIN_VALUE} while the motif is uninitialized. */
    private int motifLength;
    /** Number of motif parameters, or {@code Integer.MIN_VALUE} while the motif is uninitialized. */
    private int numParams;
    /** Value passed to the constructor; stored and serialized, not otherwise used in this class. */
    private int starts;
    /** Scratch array {component, motif start, strand} used when no cached assignment exists. */
    private int[] currentHiddenVariables;
    /** Cached best hidden-variable assignments per sequence. */
    private HashMap<Sequence, int[]> sequenceSpecificHiddenVariables;
    /** Cached log position priors per sequence (only for sequences with a reference annotation). */
    protected HashMap<Sequence, float[]> positionHash;

    /**
     * Creates a new model around a copy of the given motif.
     *
     * @param starts a positive value that is stored and serialized with the model
     * @param motif the motif model; it is cloned, so later changes to the argument have no effect
     * @throws IllegalArgumentException if {@code starts <= 0}
     * @throws CloneNotSupportedException if the motif cannot be cloned
     */
    public CompleteViterbiChipper(int starts, DifferentiableStatisticalModel motif) throws IllegalArgumentException, CloneNotSupportedException {
        super(motif.getAlphabetContainer(), 0);
        if (starts <= 0) {
            throw new IllegalArgumentException("starts must be positive, but was " + starts);
        }
        this.starts = starts;
        this.motif = (DifferentiableStatisticalModel)motif.clone();
        this.init();
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public CompleteViterbiChipper(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /** Returns the XML tag that encloses this model's serialized form (the simple class name). */
    protected String getXMLTag() {
        return this.getClass().getSimpleName();
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, this.getXMLTag());
        this.starts = (Integer)XMLParser.extractObjectForTags(xml, "starts");
        this.motif = (DifferentiableStatisticalModel)XMLParser.extractObjectForTags(xml, "motif");
        this.alphabets = this.motif.getAlphabetContainer();
        this.length = 0;
        this.init();
        // Overwrite the defaults set by init() with the stored component weights.
        this.hiddenParameters = (double[])XMLParser.extractObjectForTags(xml, "hiddenParameters");
        this.logHiddenNorm = Normalisation.logSumNormalisation(this.hiddenParameters, 0, this.hiddenParameters.length, this.hiddenPotential);
    }

    /**
     * Serializes {@code starts}, the motif and the hidden parameters. Caches are not serialized.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.starts, "starts");
        XMLParser.appendObjectWithTags(xml, this.motif, "motif");
        XMLParser.appendObjectWithTags(xml, this.hiddenParameters, "hiddenParameters");
        XMLParser.addTags(xml, this.getXMLTag());
        return xml;
    }

    /** Allocates all state with uniform component weights and empty caches. */
    private void init() {
        this.currentHiddenVariables = new int[3];
        this.hiddenParameters = new double[2];
        // All hidden parameters are 0, so log(sum(exp(h))) = log(n). The sign matters: component
        // log-probabilities are h_i - logHiddenNorm = -log(n), which matches hiddenPotential = 1/n.
        this.logHiddenNorm = Math.log(this.hiddenParameters.length);
        this.hiddenPotential = new double[this.hiddenParameters.length];
        Arrays.fill(this.hiddenPotential, 1.0 / (double)this.hiddenPotential.length);
        this.logPartNorm = new double[this.hiddenParameters.length];
        this.logNorm = Double.NaN;
        this.logP = -Math.log(this.alphabets.getAlphabetLengthAt(0));
        this.sequenceSpecificHiddenVariables = new HashMap<Sequence, int[]>();
        this.positionHash = new HashMap<Sequence, float[]>();
        this.set();
        this.computeLogGammaSum();
    }

    /**
     * Returns a deep copy: the motif, all parameter arrays and both caches are copied.
     */
    @Override
    public CompleteViterbiChipper clone() throws CloneNotSupportedException {
        CompleteViterbiChipper clone = (CompleteViterbiChipper)super.clone();
        clone.motif = (DifferentiableStatisticalModel)this.motif.clone();
        clone.hiddenParameters = this.hiddenParameters.clone();
        clone.hiddenPotential = this.hiddenPotential.clone();
        clone.positionHash = new HashMap<Sequence, float[]>();
        for (Sequence s : this.positionHash.keySet()) {
            clone.positionHash.put(s, this.positionHash.get(s).clone());
        }
        clone.sequenceSpecificHiddenVariables = new HashMap<Sequence, int[]>();
        for (Sequence s : this.sequenceSpecificHiddenVariables.keySet()) {
            clone.sequenceSpecificHiddenVariables.put(s, this.sequenceSpecificHiddenVariables.get(s).clone());
        }
        clone.currentHiddenVariables = this.currentHiddenVariables.clone();
        clone.logPartNorm = this.logPartNorm.clone();
        return clone;
    }

    /** Discards all cached per-sequence hidden-variable assignments. */
    public void reset() {
        this.sequenceSpecificHiddenVariables.clear();
    }

    /**
     * Refreshes {@code numParams} and {@code motifLength} from the motif. Both are set to
     * {@code Integer.MIN_VALUE} while the motif is uninitialized.
     */
    protected void set() {
        if (this.motif.isInitialized()) {
            this.numParams = this.motif.getNumberOfParameters();
            this.motifLength = this.motif.getLength();
        } else {
            this.numParams = Integer.MIN_VALUE;
            this.motifLength = Integer.MIN_VALUE;
        }
    }

    /** Returns the equivalent sample size: twice the motif's ESS (one share per component). */
    @Override
    public double getESS() {
        return 2.0 * this.motif.getESS();
    }

    /** Recomputes {@code logPartNorm} and {@code logNorm} from the current parameters. */
    private void precomputeNorm() {
        this.logPartNorm[0] = this.hiddenParameters[0] - this.logHiddenNorm + this.motif.getLogNormalizationConstant();
        this.logPartNorm[1] = this.hiddenParameters[1] - this.logHiddenNorm;
        this.logNorm = Normalisation.getLogSum(this.logPartNorm);
    }

    /**
     * Returns the log normalization constant: the log-sum over both components of the
     * component's log weight plus, for the motif component, the motif's log normalization
     * constant. The value is computed lazily.
     */
    @Override
    public double getLogNormalizationConstant() {
        if (Double.isNaN(this.logNorm)) {
            this.precomputeNorm();
        }
        return this.logNorm;
    }

    /**
     * Returns the log partial normalization constant for the given parameter.
     *
     * <p>For motif parameters ({@code parameterIndex < numParams}), the result is the motif
     * component's log weight ({@code hiddenParameters[0] - logHiddenNorm}, as in
     * {@code precomputeNorm}) plus the motif's own log partial normalization constant. For hidden
     * parameters, it is the cached per-component term.
     *
     * @param parameterIndex the global parameter index
     * @return the log partial normalization constant
     * @throws Exception if the motif fails to compute its partial normalization constant
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (Double.isNaN(this.logNorm)) {
            this.precomputeNorm();
        }
        if (parameterIndex >= this.numParams) {
            return this.logPartNorm[parameterIndex - this.numParams];
        }
        // Log-space weight (not the probability hiddenPotential[0]), consistent with precomputeNorm.
        return this.hiddenParameters[0] - this.logHiddenNorm + this.motif.getLogPartialNormalizationConstant(parameterIndex);
    }

    /**
     * Adds the gradient of the log prior: first the motif's part, then {@code e - 2e * p_i} for
     * each hidden parameter, where {@code e} is the motif ESS and {@code p_i} the component
     * probability.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        this.motif.addGradientOfLogPriorTerm(grad, start);
        double sum = this.getESS();
        double e = this.motif.getESS();
        // NOTE(review): the -sum * p_i part is added unconditionally, whereas getLogPriorTerm
        // subtracts sum * logHiddenNorm only if isNormalized(). Confirm the two agree.
        for (int i = 0; i < this.hiddenParameters.length; i++) {
            int n = start + this.numParams + i;
            grad[n] += e - sum * this.hiddenPotential[i];
        }
    }

    /**
     * Returns the log prior: the motif's log prior, plus {@code e * h_i} for each hidden
     * parameter, minus {@code sum * logHiddenNorm} if the model is normalized, plus the
     * log-gamma normalizer.
     */
    @Override
    public double getLogPriorTerm() {
        double e = this.motif.getESS();
        double val = 0.0;
        double sum = 0.0;
        for (int i = 0; i < this.hiddenParameters.length; i++) {
            sum += e;
            val += this.hiddenParameters[i] * e;
        }
        if (this.isNormalized()) {
            val -= sum * this.logHiddenNorm;
        }
        val += this.motif.getLogPriorTerm();
        return val + this.logGammaSum;
    }

    /**
     * Computes {@code logGamma(sum of h) - sum of logGamma(h)} with {@code h = motif ESS} per
     * component. The value is 0 if there is a single component or the ESS is 0.
     */
    private void computeLogGammaSum() {
        this.logGammaSum = 0.0;
        int n = this.getNumberOfComponents();
        if (n > 1 && this.getESS() > 0.0) {
            double sum = 0.0;
            for (int i = 0; i < n; i++) {
                double h = this.motif.getESS();
                sum += h;
                this.logGammaSum -= Gamma.logOfGamma(h);
            }
            this.logGammaSum += Gamma.logOfGamma(sum);
        }
    }

    /**
     * Returns the event-space size for a parameter: delegated to the motif for motif
     * parameters, otherwise the number of components.
     *
     * @throws IllegalArgumentException if the motif is not initialized
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        if (this.numParams < 0) {
            throw new IllegalArgumentException("the motif model is not initialized");
        }
        if (index < this.numParams) {
            return this.motif.getSizeOfEventSpaceForRandomVariablesOfParameter(index);
        }
        return this.hiddenParameters.length;
    }

    /**
     * Returns the motif parameters followed by the hidden parameters.
     *
     * @return the parameter vector, or {@code null} if the motif has no parameters or is not
     *         initialized ({@code numParams} is then {@code Integer.MIN_VALUE})
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        // Guard against the MIN_VALUE sentinel, which would otherwise yield a negative array size.
        if (this.numParams <= 0) {
            return null;
        }
        double[] params = this.motif.getCurrentParameterValues();
        double[] res = new double[this.numParams + this.hiddenParameters.length];
        System.arraycopy(params, 0, res, 0, this.numParams);
        System.arraycopy(this.hiddenParameters, 0, res, this.numParams, this.hiddenParameters.length);
        return res;
    }

    /**
     * Returns the log score of {@code seq} from {@code start} under the best hidden-variable
     * assignment. A cached assignment is used if present; otherwise one is computed into the
     * shared scratch array and is not cached.
     *
     * @throws RuntimeException wrapping any exception thrown while scoring the motif
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        int[] hv = this.sequenceSpecificHiddenVariables.get(seq);
        if (hv == null) {
            hv = this.currentHiddenVariables;
            this.setHiddenVariables(hv, seq, start);
        }
        if (hv[0] == 1) {
            // Background component: every symbol is background.
            return this.hiddenParameters[1] - this.logHiddenNorm + this.logP * (double)(seq.getLength() - start);
        }
        try {
            float[] position = this.getPosition(seq, false);
            // Number of background symbols around the motif occurrence.
            int e = seq.getLength() - start - this.motifLength;
            double pos = position == null ? -Math.log(e + 1) : (double)position[hv[1]];
            return this.hiddenParameters[0] - this.logHiddenNorm + pos + this.logP * (double)e + MINUS_LOG_2 + this.getMotifScore(hv[2] == 1, seq, hv[1]);
        }
        catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Returns the log score under the best hidden-variable assignment and appends its partial
     * derivatives to {@code indices}/{@code partialDer}.
     *
     * <p>For the hidden parameters, {@code d/dh_k (h_c - logHiddenNorm) = [k == c] - p_k}.
     *
     * @throws RuntimeException wrapping any exception thrown while scoring the motif
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        int[] hv = this.sequenceSpecificHiddenVariables.get(seq);
        // Only whole sequences (start == 0) that are not cached yet get a cache entry.
        boolean add = hv == null && start == 0;
        if (hv == null) {
            hv = this.currentHiddenVariables;
        }
        // NOTE(review): unlike getLogScoreFor, the assignment is recomputed even when cached, and
        // the cached array is updated in place. Confirm this is intended.
        this.setHiddenVariables(hv, seq, start);
        if (add) {
            this.sequenceSpecificHiddenVariables.put(seq, hv.clone());
        }
        if (hv[0] == 1) {
            indices.add(this.numParams);
            partialDer.add(-this.hiddenPotential[0]);
            indices.add(this.numParams + 1);
            partialDer.add(1.0 - this.hiddenPotential[1]);
            return this.hiddenParameters[1] - this.logHiddenNorm + this.logP * (double)(seq.getLength() - start);
        }
        try {
            float[] position = this.getPosition(seq, start == 0);
            int e = seq.getLength() - start - this.motifLength;
            double pos = position == null ? -Math.log(e + 1) : (double)position[hv[1]];
            double motifScore = hv[2] == 0
                    ? this.motif.getLogScoreAndPartialDerivation(seq, hv[1], indices, partialDer)
                    : this.motif.getLogScoreAndPartialDerivation(seq.reverseComplement(), seq.getLength() - hv[1] - this.motifLength, indices, partialDer);
            double res = this.hiddenParameters[0] - this.logHiddenNorm + pos + this.logP * (double)e + MINUS_LOG_2 + motifScore;
            indices.add(this.numParams);
            partialDer.add(1.0 - this.hiddenPotential[0]);
            indices.add(this.numParams + 1);
            partialDer.add(-this.hiddenPotential[1]);
            return res;
        }
        catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /** Returns the motif parameters plus the two hidden parameters, or -1 if uninitialized. */
    @Override
    public int getNumberOfParameters() {
        if (this.numParams >= 0) {
            return this.numParams + this.hiddenParameters.length;
        }
        return -1;
    }

    /**
     * Sets the motif parameters from {@code params[start...]} and then the hidden parameters
     * that follow them. Cached hidden-variable assignments are kept.
     */
    @Override
    public void setParameters(double[] params, int start) {
        this.motif.setParameters(params, start);
        System.arraycopy(params, start + this.numParams, this.hiddenParameters, 0, this.hiddenParameters.length);
        this.logHiddenNorm = Normalisation.logSumNormalisation(this.hiddenParameters, 0, this.hiddenParameters.length, this.hiddenPotential);
        this.logNorm = Double.NaN;
    }

    @Override
    public String getInstanceName() {
        return "complete viterbi chipper (" + this.motif.getInstanceName() + ")";
    }

    /** Returns whether the motif model is initialized. */
    @Override
    public boolean isInitialized() {
        return this.motif.isInitialized();
    }

    /** Does nothing: the hidden parameters are not adjusted from data by this model. */
    @Override
    public void adjustHiddenParameters(int index, DataSet[] data, double[][] weights) throws Exception {
    }

    /**
     * Draws the index of a sequence from {@code d}: uniformly if {@code weight} is {@code null},
     * otherwise proportionally to the weights.
     */
    public static int draw(DataSet d, double[] weight) {
        if (weight == null) {
            return r.nextInt(d.getNumberOfElements());
        }
        return CompleteViterbiChipper.draw(weight);
    }

    /**
     * Draws an index proportionally to the given non-negative weights.
     *
     * @param weight the weights; must be non-empty
     * @return an index in {@code [0, weight.length)}
     */
    public static int draw(double[] weight) {
        double s = r.nextDouble() * ToolBox.sum(weight);
        int i = 0;
        // The bound keeps floating-point round-off in s from running past the last entry.
        while (i < weight.length - 1 && weight[i] < s) {
            s -= weight[i];
            ++i;
        }
        return i;
    }

    /**
     * Initializes the motif from a randomly drawn subsequence of {@code data[index]} and resets
     * the component weights to uniform.
     *
     * <p>The sequence is drawn according to {@code weights[index]} (uniformly if absent). The
     * start position is drawn uniformly, or proportionally to the reference annotation's values
     * when one is present.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        int num = this.getNumberOfMotifs();
        // motifLength is Integer.MIN_VALUE until the motif is initialized (see set()), so ask the
        // motif directly, as getMotifLength(int) does.
        int w = this.motif.getLength();
        int a = (int)this.alphabets.getAlphabetLengthAt(0);
        double d = 0.1 / (double)(a - 1);
        d = (1.0 - (double)a * d) / ((double)a * d);
        for (int m = 0; m < num; m++) {
            int s = CompleteViterbiChipper.draw(data[index], weights == null ? null : weights[index]);
            Sequence seq = data[index].getElementAt(s);
            Sequence ref = this.getReference(seq);
            // Number of valid motif start positions: 0 .. length - w (inclusive).
            int numStarts = seq.getLength() - w + 1;
            int p;
            if (ref == null) {
                p = r.nextInt(numStarts);
            } else {
                double[] prof = new double[numStarts];
                for (int i = 0; i < prof.length; i++) {
                    prof[i] = ref.continuousVal(i);
                }
                p = CompleteViterbiChipper.draw(prof);
            }
            Sequence sub = seq.getSubSequence(p, w);
            double h = d * this.motif.getESS();
            this.motif.initializeFunction(0, freeParams, new DataSet[]{new DataSet("", sub)}, new double[][]{{h}});
        }
        Arrays.fill(this.hiddenParameters, 0.0);
        this.logHiddenNorm = Normalisation.logSumNormalisation(this.hiddenParameters, 0, this.hiddenParameters.length, this.hiddenPotential, 0);
        this.reset();
        this.set();
        this.logNorm = Double.NaN;
    }

    /** Initializes the motif randomly; {@code freeParams} is ignored. */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.initializeMotifRandomly(0);
    }

    /**
     * Initializes the motif from {@code data} and clears cached assignments.
     *
     * @throws IndexOutOfBoundsException if {@code motifIndex != 0}
     */
    @Override
    public void initializeMotif(int motifIndex, DataSet data, double[] weights) throws Exception {
        if (motifIndex != 0) {
            throw new IndexOutOfBoundsException();
        }
        this.motif.initializeFunction(0, false, new DataSet[]{data}, new double[][]{weights});
        this.reset();
        this.set();
        this.logNorm = Double.NaN;
    }

    /**
     * Initializes the motif randomly and clears cached assignments.
     *
     * @throws IndexOutOfBoundsException if {@code motif != 0}
     */
    @Override
    public void initializeMotifRandomly(int motif) throws Exception {
        if (motif != 0) {
            throw new IndexOutOfBoundsException();
        }
        this.motif.initializeFunctionRandomly(false);
        this.reset();
        this.set();
        this.logNorm = Double.NaN;
    }

    /**
     * Shifts or resizes the motif if it implements {@link Mutable}.
     *
     * @return {@code true} if the motif was modified, {@code false} otherwise (including when
     *         {@code motifIndex != 0} or the motif is not mutable)
     */
    @Override
    public boolean modifyMotif(int motifIndex, int offsetLeft, int offsetRight) throws Exception {
        if (motifIndex != 0 || !(this.motif instanceof Mutable)) {
            return false;
        }
        double normOld = this.motif.getLogNormalizationConstant();
        boolean res = ((Mutable)this.motif).modify(offsetLeft, offsetRight);
        if (res) {
            double normNew = this.motif.getLogNormalizationConstant();
            // Keep exp(h_0) * Z_motif unchanged across the modification.
            this.hiddenParameters[0] += normOld - normNew;
            this.logHiddenNorm = Normalisation.logSumNormalisation(this.hiddenParameters, 0, this.hiddenParameters.length, this.hiddenPotential, 0);
            // Position priors depend on the motif length, so they must be recomputed.
            this.positionHash.clear();
            this.reset();
            this.set();
            this.logNorm = Double.NaN;
        }
        return res;
    }

    @Override
    public int getGlobalIndexOfMotifInComponent(int component, int motif) {
        return component;
    }

    /** Returns the best component for the whole sequence: 0 = motif, 1 = background. */
    @Override
    public int getIndexOfMaximalComponentFor(Sequence sequence) throws Exception {
        int[] help = new int[3];
        this.setHiddenVariables(help, sequence, 0);
        return help[0];
    }

    /**
     * Writes the best assignment {component, motif start, strand} for {@code sequence} from
     * {@code startpos} into {@code hiddenVariables}.
     *
     * <p>Candidates are compared relative to an all-background score. Each start is tried first
     * on the forward strand, then on the reverse strand. Ties keep the earlier candidate. If
     * background wins, only entry 0 is written (set to 1).
     *
     * @throws RuntimeException wrapping any exception thrown while scoring the motif
     */
    private void setHiddenVariables(int[] hiddenVariables, Sequence sequence, int startpos) {
        double best = this.hiddenParameters[1] - this.logHiddenNorm;
        hiddenVariables[0] = 1;
        int end = sequence.getLength() - startpos - this.motifLength + 1;
        float[] position = this.getPosition(sequence, false);
        // Without a reference annotation, the position prior is uniform over the 'end' starts.
        double pos = position == null ? -Math.log(end) : Double.NaN;
        try {
            for (int l = 0; l < end; l++, startpos++) {
                if (position != null) {
                    pos = position[startpos];
                }
                double c = this.hiddenParameters[0] - this.logHiddenNorm + pos + MINUS_LOG_2 + this.getMotifScore(false, sequence, startpos) - (double)this.motifLength * this.logP;
                if (c > best) {
                    best = c;
                    hiddenVariables[0] = 0;
                    hiddenVariables[1] = startpos;
                    hiddenVariables[2] = 0;
                }
                c = this.hiddenParameters[0] - this.logHiddenNorm + pos + MINUS_LOG_2 + this.getMotifScore(true, sequence, startpos) - (double)this.motifLength * this.logP;
                if (c > best) {
                    best = c;
                    hiddenVariables[0] = 0;
                    hiddenVariables[1] = startpos;
                    hiddenVariables[2] = 1;
                }
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /** Returns the sequence of the {@code "reference"} annotation of {@code seq}, or {@code null}. */
    protected Sequence getReference(Sequence seq) {
        SequenceAnnotation seqAn = seq.getSequenceAnnotationByType("reference", 0);
        return seqAn == null ? null : ((ReferenceSequenceAnnotation)seqAn).getReferenceSequence();
    }

    /**
     * Returns log position priors for every motif start in {@code seq}, or {@code null} if the
     * sequence has no reference annotation.
     *
     * <p>The priors are the reference's continuous values at indices {@code 0 .. length - motifLength},
     * normalized to sum to 1 and log-transformed.
     *
     * @param add whether a newly computed result is cached
     */
    protected float[] getPosition(Sequence seq, boolean add) {
        Sequence ref;
        float[] res = this.positionHash.get(seq);
        if (res == null && (ref = this.getReference(seq)) != null) {
            res = new float[seq.getLength() - this.motifLength + 1];
            float sum = 0.0f;
            for (int i = 0; i < res.length; i++) {
                res[i] = (float)ref.continuousVal(i);
                sum += res[i];
            }
            // NOTE(review): a zero sum yields NaN/-Infinity priors; confirm reference values are positive.
            for (int i = 0; i < res.length; i++) {
                res[i] = (float)Math.log(res[i] / sum);
            }
            if (add) {
                this.positionHash.put(seq, res);
            }
        }
        return res;
    }

    /**
     * Returns the motif's length.
     *
     * @throws IndexOutOfBoundsException if {@code motif != 0}
     */
    @Override
    public int getMotifLength(int motif) {
        if (motif == 0) {
            return this.motif.getLength();
        }
        throw new IndexOutOfBoundsException();
    }

    @Override
    public int getNumberOfComponents() {
        return 2;
    }

    @Override
    public int getNumberOfMotifs() {
        return 1;
    }

    /** Component 0 contains the single motif; component 1 (background) contains none. */
    @Override
    public int getNumberOfMotifsInComponent(int component) {
        return component == 0 ? 1 : 0;
    }

    /**
     * Returns, for each motif start from {@code startpos}, the log-sum of the forward- and
     * reverse-strand motif scores, shifted according to {@code kind}.
     *
     * <p>The shift is: the component's log weight for {@code UNNORMALIZED_JOINT}, none for
     * {@code UNNORMALIZED_CONDITIONAL}, and normalization to log-sum 0 for
     * {@code NORMALIZED_CONDITIONAL}.
     *
     * @return the profile; empty if the sequence is shorter than the motif
     * @throws IndexOutOfBoundsException if {@code component} or {@code motif} is not 0, or
     *         {@code kind} is not supported
     */
    @Override
    public double[] getProfileOfScoresFor(int component, int motif, Sequence sequence, int startpos, MotifDiscoverer.KindOfProfile kind) throws Exception {
        if (motif != 0 || component != 0) {
            throw new IndexOutOfBoundsException();
        }
        int l = sequence.getLength() - startpos - this.motifLength;
        if (l < 0) {
            return new double[]{};
        }
        double[] h = new double[2];
        double[] res = new double[l + 1];
        for (int i = 0; i < res.length; i++) {
            this.fill(h, sequence, startpos + i);
            res[i] = Normalisation.getLogSum(h);
        }
        double d;
        switch (kind) {
            case UNNORMALIZED_JOINT:
                d = this.hiddenParameters[component] - this.logHiddenNorm;
                break;
            case UNNORMALIZED_CONDITIONAL:
                d = 0.0;
                break;
            case NORMALIZED_CONDITIONAL:
                d = -Normalisation.getLogSum(res);
                break;
            default:
                throw new IndexOutOfBoundsException();
        }
        for (int i = 0; i < res.length; i++) {
            res[i] += d;
        }
        return res;
    }

    /**
     * Returns the normalized probabilities {forward, reverse} of the motif strand at
     * {@code startpos}.
     *
     * @throws IndexOutOfBoundsException if {@code component} or {@code motif} is not 0
     */
    @Override
    public double[] getStrandProbabilitiesFor(int component, int motif, Sequence sequence, int startpos) throws Exception {
        // Reject negative indices too, consistent with getProfileOfScoresFor.
        if (motif != 0 || component != 0) {
            throw new IndexOutOfBoundsException();
        }
        double[] res = new double[2];
        this.fill(res, sequence, startpos);
        Normalisation.logSumNormalisation(res);
        return res;
    }

    /** Writes the forward (index 0) and reverse (index 1) motif log scores at {@code startpos}. */
    private void fill(double[] res, Sequence sequence, int startpos) throws Exception {
        res[0] = this.getMotifScore(false, sequence, startpos);
        res[1] = this.getMotifScore(true, sequence, startpos);
    }

    /**
     * Scores the motif at forward-strand position {@code startPos}. For the reverse strand,
     * the position is mapped into the reverse complement.
     */
    private double getMotifScore(boolean rc, Sequence sequence, int startPos) throws Exception {
        if (rc) {
            return this.motif.getLogScoreFor(sequence.reverseComplement(), sequence.getLength() - startPos - this.motifLength);
        }
        return this.motif.getLogScoreFor(sequence, startPos);
    }

    /** Returns a deep copy of the motif model. */
    public DifferentiableStatisticalModel getMotif() throws CloneNotSupportedException {
        return (DifferentiableStatisticalModel)this.motif.clone();
    }

    /**
     * Returns a textual representation: the instance name, both component probabilities, and
     * the motif's own representation.
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        StringBuffer sb = new StringBuffer();
        sb.append(this.getInstanceName()).append('\n');
        sb.append("motif component probability: ").append(numberFormat.format(this.hiddenPotential[0])).append('\n');
        sb.append("background component probability: ").append(numberFormat.format(this.hiddenPotential[1])).append('\n');
        sb.append("motif:\n").append(this.motif.toString(numberFormat));
        return sb.toString();
    }
}

