/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions.mix;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.util.Arrays;

/**
 * A mixture of {@link NormalizableScoringFunction} components that all have the same length and
 * alphabets consistent with this function's alphabets.
 *
 * <p>For a sequence, the score of component {@code i} is
 * {@code logHiddenPotential[i] + function[i].getLogScore(seq, start)}. These per-component
 * scores are combined with {@link Normalisation#logSumNormalisation} when the mixture's log score
 * and gradient are computed.
 */
public class MixtureScoringFunction
extends AbstractMixtureScoringFunction {
    /**
     * Creates a mixture of the given components.
     *
     * @param starts passed unchanged to the superclass constructor (presumably the number of
     *        training starts — TODO confirm against {@link AbstractMixtureScoringFunction})
     * @param plugIn passed unchanged to the superclass constructor (presumably selects
     *        {@link #initializeUsingPlugIn} for initialization — TODO confirm)
     * @param component the mixture components; there must be at least one, all must have the
     *        same length, and each AlphabetContainer must be consistent with this function's
     *        alphabets
     * @throws CloneNotSupportedException propagated from the superclass constructor
     * @throws IllegalArgumentException if {@code component} is {@code null} or empty, or if a
     *         component has a different length or an unsuitable AlphabetContainer
     */
    public MixtureScoringFunction(int starts, boolean plugIn, NormalizableScoringFunction ... component) throws CloneNotSupportedException {
        // The helper runs before super(...) so that a missing or empty component array is
        // reported with a clear message rather than an ArrayIndexOutOfBoundsException.
        super(MixtureScoringFunction.firstComponentLength(component), starts, component.length, true, plugIn, component);
        for (int i = 0; i < component.length; ++i) {
            if (this.length != component[i].getLength()) {
                throw new IllegalArgumentException("The length of component " + i + " is not " + this.length + ".");
            }
            if (!this.alphabets.checkConsistency(component[i].getAlphabetContainer())) {
                throw new IllegalArgumentException("The AlphabetContainer of component " + i + " is not suitable.");
            }
        }
        this.computeLogGammaSum();
    }

    /**
     * Returns the length of the first component. This value is used as the length of the whole
     * mixture.
     *
     * @param component the components passed to the constructor
     * @return {@code component[0].getLength()}
     * @throws IllegalArgumentException if {@code component} is {@code null} or empty
     */
    private static int firstComponentLength(NormalizableScoringFunction[] component) {
        if (component == null || component.length == 0) {
            throw new IllegalArgumentException("At least one component is required.");
        }
        return component[0].getLength();
    }

    /**
     * Recreates an instance from its XML representation. All parsing is delegated to the
     * superclass constructor.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the superclass cannot parse {@code xml}
     */
    public MixtureScoringFunction(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns the normalization constant of component {@code i}.
     *
     * @param i the component index
     * @return {@code function[i].getNormalizationConstant()}
     */
    protected double getNormalizationConstantForComponent(int i) {
        return this.function[i].getNormalizationConstant();
    }

    /**
     * Returns the partial derivative of the normalization constant with respect to the parameter
     * with global index {@code parameterIndex}.
     *
     * <p>If this function is normalized, the result is always {@code 0.0}. Otherwise, for a hidden
     * (mixture) parameter the precomputed {@code partNorm} entry is returned. For a component
     * parameter, the result is the component's own partial normalization constant scaled by that
     * component's hidden potential.
     *
     * @param parameterIndex the global parameter index
     * @return the partial normalization constant
     * @throws Exception if a component fails to compute its partial normalization constant
     */
    public double getPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (this.isNormalized) {
            return 0.0;
        }
        if (this.norm < 0.0) {
            // a negative norm is treated as "not computed yet"
            this.precomputeNorm();
        }
        int[] ind = this.getIndices(parameterIndex);
        int componentIndex = ind[0];
        int localIndex = ind[1];
        if (componentIndex == this.function.length) {
            // one past the last component addresses the hidden (mixture) parameters
            return this.partNorm[localIndex];
        }
        return this.hiddenPotential[componentIndex] * this.function[componentIndex].getPartialNormalizationConstant(localIndex);
    }

    /**
     * Returns the hyperparameter used for hidden parameter {@code index}.
     *
     * @param index the hidden parameter (component) index
     * @return {@code function[index].getEss()}
     */
    public double getHyperparameterForHiddenParameter(int index) {
        return this.function[index].getEss();
    }

    /**
     * Returns the sum of {@code getEss()} over all components (presumably their equivalent
     * sample sizes).
     *
     * @return the summed ess of all components
     */
    public double getEss() {
        double ess = 0.0;
        for (int i = 0; i < this.function.length; ++i) {
            ess += this.function[i].getEss();
        }
        return ess;
    }

    /**
     * Initializes the mixture by randomly distributing the weight of every sequence in
     * {@code data[index]} over the components.
     *
     * <p>For each sequence, a component-probability vector is drawn from a Dirichlet distribution.
     * The Dirichlet parameters are all 1 when {@link #getEss()} is 0, and otherwise the per-component
     * hyperparameters. Each component is then initialized on the same data with its share of the
     * sequence weights. The accumulated shares are passed to {@code computeHiddenParameter}.
     *
     * <p>{@code weights[index]} is swapped temporarily while the components are initialized. It
     * is always restored to the caller's value, even if a component throws.
     *
     * @param index the index of the sample in {@code data} to initialize from
     * @param freeParams passed unchanged to each component's {@code initializeFunction}
     * @param data the samples
     * @param weights the sequence weights per sample, or {@code null} for unit weights; an
     *        individual {@code weights[index]} may also be {@code null} (unit weights)
     * @throws Exception if a component fails to initialize
     */
    protected void initializeUsingPlugIn(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception {
        Arrays.fill(this.hiddenParameter, 0.0);
        if (weights == null) {
            weights = new double[data.length][];
        }
        int numComponents = this.function.length;
        int numSequences = data[index].getNumberOfElements();
        double[][] newWeights = new double[numComponents][numSequences];

        double[] alpha = new double[this.getNumberOfComponents()];
        if (this.getEss() == 0.0) {
            Arrays.fill(alpha, 1.0);
        } else {
            for (int j = 0; j < alpha.length; ++j) {
                alpha[j] = this.getHyperparameterForHiddenParameter(j);
            }
        }
        DirichletMRGParams param = new DirichletMRGParams(alpha);
        double[] p = new double[alpha.length];

        double[] sequenceWeights = weights[index];
        for (int i = 0; i < numSequences; ++i) {
            DirichletMRG.DEFAULT_INSTANCE.generate(p, 0, p.length, param);
            double w = sequenceWeights == null ? 1.0 : sequenceWeights[i];
            for (int j = 0; j < p.length; ++j) {
                newWeights[j][i] = w * p[j];
                this.hiddenParameter[j] += newWeights[j][i];
            }
        }

        // Restore the caller's array entry even if a component throws.
        try {
            for (int i = 0; i < numComponents; ++i) {
                weights[index] = newWeights[i];
                this.function[i].initializeFunction(index, freeParams, data, weights);
            }
        } finally {
            weights[index] = sequenceWeights;
        }
        this.computeHiddenParameter(this.hiddenParameter);
    }

    /**
     * Returns a name of the form {@code mixture(name0, name1, ...)} built from the component
     * instance names.
     *
     * @return the instance name of this mixture
     */
    public String getInstanceName() {
        StringBuilder erg = new StringBuilder("mixture(").append(this.function[0].getInstanceName());
        for (int i = 1; i < this.function.length; ++i) {
            erg.append(", ").append(this.function[i].getInstanceName());
        }
        return erg.append(')').toString();
    }

    /**
     * Stores in {@code componentScore[i]} the log hidden potential of component {@code i} plus
     * that component's log score for {@code seq} at {@code start}.
     *
     * @param seq the sequence
     * @param start the start position within {@code seq}
     */
    protected void fillComponentScores(Sequence seq, int start) {
        for (int i = 0; i < this.function.length; ++i) {
            this.componentScore[i] = this.logHiddenPotential[i] + this.function[i].getLogScore(seq, start);
        }
    }

    /**
     * Computes the log score of {@code seq} at {@code start} and appends the partial derivatives
     * to {@code indices}/{@code partialDer}.
     *
     * <p>First, the component parameters are appended, offset by {@code paramRef[i]}, with each
     * component's derivatives scaled by its normalized component score. Then the hidden
     * parameters are appended, starting at {@code paramRef[function.length]}.
     *
     * @param seq the sequence
     * @param start the start position within {@code seq}
     * @param indices receives the global indices of the parameters with (possibly) non-zero
     *        derivative
     * @param partialDer receives the corresponding partial derivatives
     * @return the log score of the mixture
     */
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        int numComponents = this.function.length;
        int last = this.paramRef.length - 1;
        // size of the last parameter block, i.e. the number of hidden parameters
        int numHidden = this.paramRef[last] - this.paramRef[last - 1];
        for (int i = 0; i < numComponents; ++i) {
            this.iList[i].clear();
            this.dList[i].clear();
            this.componentScore[i] = this.logHiddenPotential[i] + this.function[i].getLogScoreAndPartialDerivation(seq, start, this.iList[i], this.dList[i]);
        }
        // Returns the log-sum of the component scores. The normalized scores are written back
        // into componentScore (presumably exp(score_i) / sum_j exp(score_j)) and are used as
        // weights below.
        double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, numComponents, this.componentScore, 0);
        for (int i = 0; i < numComponents; ++i) {
            for (int j = 0; j < this.iList[i].length(); ++j) {
                indices.add(this.paramRef[i] + this.iList[i].get(j));
                partialDer.add(this.componentScore[i] * this.dList[i].get(j));
            }
        }
        int hiddenOffset = this.paramRef[numComponents];
        for (int j = 0; j < numHidden; ++j) {
            indices.add(hiddenOffset + j);
            partialDer.add(this.componentScore[j] - (this.isNormalized ? this.hiddenPotential[j] : 0.0));
        }
        return logScore;
    }

    /**
     * Returns a multi-line description that lists each component after a {@code p(i) = weight}
     * line. The weight is {@code hiddenPotential[i]} if this function is normalized, and
     * {@code partNorm[i] / norm} otherwise.
     *
     * @return the textual representation of this mixture
     */
    public String toString() {
        if (this.norm < 0.0) {
            this.precomputeNorm();
        }
        StringBuilder erg = new StringBuilder(this.function.length * 1000);
        for (int i = 0; i < this.function.length; ++i) {
            double weight = this.isNormalized ? this.hiddenPotential[i] : this.partNorm[i] / this.norm;
            erg.append("p(").append(i).append(") = ").append(weight).append('\n')
               .append(this.function[i].toString()).append('\n');
        }
        return erg.toString();
    }
}

