/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.mixture;

import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sampling.BurnInTest;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.mixture.AbstractMixtureTrainSM;
import de.jstacs.utils.random.MRGParams;
import de.jstacs.utils.random.MultivariateRandomGenerator;
import java.text.NumberFormat;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;

/**
 * Mixture model over a fixed set of {@link TrainableStatisticalModel} components.
 *
 * <p>The public constructors select the training algorithm (EM or Gibbs sampling) and
 * whether the component probabilities are estimated or fixed to given weights; all of
 * them delegate to the protected main constructor, which in turn delegates to
 * {@link AbstractMixtureTrainSM}.
 */
public class MixtureTrainSM
extends AbstractMixtureTrainSM {
    /**
     * Absolute tolerance used when checking that the parts of a user supplied
     * partitioning sum to 1. An exact comparison would reject valid partitionings
     * such as {0.1, 0.2, 0.7} because of floating point rounding.
     */
    private static final double PARTITION_SUM_TOLERANCE = 1.0E-9;

    /**
     * Main constructor; forwards all arguments to the super class and uses
     * {@code models.length} as the number of components.
     *
     * @param length the sequence length handled by the model (0 is treated as the homogeneous case by this class)
     * @param models the component models
     * @param starts the number of starts
     * @param estimateComponentProbs whether the component probabilities are estimated
     * @param componentHyperParams hyper parameters for the component probabilities (may be {@code null}, see public constructors)
     * @param weights fixed component weights (may be {@code null}, see public constructors)
     * @param algorithm the training algorithm
     * @param alpha parameter forwarded to the super class (used by the EM constructors)
     * @param tc the termination condition (used by the EM constructors)
     * @param parametrization the parameterization
     * @param initialIteration forwarded to the super class (used by the Gibbs sampling constructors)
     * @param stationaryIteration the length of the stationary phase (used by the Gibbs sampling constructors)
     * @param burnInTest the burn-in test (used by the Gibbs sampling constructors)
     * @throws IllegalArgumentException propagated from the super constructor
     * @throws WrongAlphabetException propagated from the super constructor
     * @throws CloneNotSupportedException propagated from the super constructor
     */
    protected MixtureTrainSM(int length, TrainableStatisticalModel[] models, int starts, boolean estimateComponentProbs, double[] componentHyperParams, double[] weights, AbstractMixtureTrainSM.Algorithm algorithm, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parametrization, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        super(length, models, null, models.length, starts, estimateComponentProbs, componentHyperParams, weights, algorithm, alpha, tc, parametrization, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Creates an instance trained by EM that estimates the component probabilities.
     *
     * @param length the sequence length handled by the model
     * @param models the component models
     * @param starts the number of starts
     * @param componentHyperParams hyper parameters for the component probabilities
     * @param alpha parameter forwarded to the super class
     * @param tc the termination condition
     * @param parametrization the parameterization
     * @throws IllegalArgumentException propagated from the main constructor
     * @throws WrongAlphabetException propagated from the main constructor
     * @throws CloneNotSupportedException propagated from the main constructor
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, int starts, double[] componentHyperParams, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parametrization) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, true, componentHyperParams, null, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, parametrization, 0, 0, null);
    }

    /**
     * Creates an instance trained by EM that uses fixed component weights
     * (component probabilities are not estimated).
     *
     * @param length the sequence length handled by the model
     * @param models the component models
     * @param weights the fixed component weights
     * @param starts the number of starts
     * @param alpha parameter forwarded to the super class
     * @param tc the termination condition
     * @param parametrization the parameterization
     * @throws IllegalArgumentException propagated from the main constructor
     * @throws WrongAlphabetException propagated from the main constructor
     * @throws CloneNotSupportedException propagated from the main constructor
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, double[] weights, int starts, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parametrization) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, false, null, weights, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, parametrization, 0, 0, null);
    }

    /**
     * Creates an instance trained by Gibbs sampling that estimates the component
     * probabilities. The parameterization is fixed to {@code LAMBDA}.
     *
     * @param length the sequence length handled by the model
     * @param models the component models
     * @param starts the number of starts
     * @param componentHyperParams hyper parameters for the component probabilities
     * @param initialIteration forwarded to the super class
     * @param stationaryIteration the length of the stationary phase
     * @param burnInTest the burn-in test
     * @throws IllegalArgumentException propagated from the main constructor
     * @throws WrongAlphabetException propagated from the main constructor
     * @throws CloneNotSupportedException propagated from the main constructor
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, int starts, double[] componentHyperParams, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, true, componentHyperParams, null, AbstractMixtureTrainSM.Algorithm.GIBBS_SAMPLING, 0.0, null, AbstractMixtureTrainSM.Parameterization.LAMBDA, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Creates an instance trained by Gibbs sampling that uses fixed component
     * weights. The parameterization is fixed to {@code LAMBDA}.
     *
     * @param length the sequence length handled by the model
     * @param models the component models
     * @param weights the fixed component weights
     * @param starts the number of starts
     * @param initialIteration forwarded to the super class
     * @param stationaryIteration the length of the stationary phase
     * @param burnInTest the burn-in test
     * @throws IllegalArgumentException propagated from the main constructor
     * @throws WrongAlphabetException propagated from the main constructor
     * @throws CloneNotSupportedException propagated from the main constructor
     */
    public MixtureTrainSM(int length, TrainableStatisticalModel[] models, double[] weights, int starts, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(length, models, starts, false, null, weights, AbstractMixtureTrainSM.Algorithm.GIBBS_SAMPLING, 0.0, null, AbstractMixtureTrainSM.Parameterization.LAMBDA, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Restores an instance from its XML representation by delegating to the super class.
     *
     * @param xml the XML representation
     * @throws NonParsableException propagated from the super constructor
     */
    public MixtureTrainSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Emits {@code n} sequences: the component of every sequence is drawn from the
     * current weights, then every component emits as many sequences as were assigned
     * to it. The result is grouped by component (component 0 first).
     *
     * @param n the number of sequences to emit
     * @param lengths for {@code this.length == 0} (homogeneous case) either a single
     *        length used for all sequences or at least {@code n} lengths that are
     *        consumed in output order; otherwise (inhomogeneous case) {@code null} or empty
     * @return an array of {@code n} sequences
     * @throws IllegalArgumentException in the homogeneous case if {@code n > 0} and
     *         {@code lengths} is {@code null}, empty, or has more than one but fewer than {@code n} entries
     * @throws Exception in the inhomogeneous case if {@code lengths} is non-empty, or if a component fails to emit
     */
    @Override
    protected Sequence[] emitDataSetUsingCurrentParameterSet(int n, int ... lengths) throws Exception {
        // numbers[c] = how many of the n sequences are emitted by component c
        int[] numbers = new int[this.dimension];
        for (int i = 0; i < n; ++i) {
            int component = AbstractMixtureTrainSM.draw(this.weights, 0);
            numbers[component] = numbers[component] + 1;
        }
        Sequence[] seqs = new Sequence[n];
        // next free position in seqs
        int filled = 0;
        if (this.length == 0) {
            // homogeneous case: the caller has to supply the lengths
            if (n > 0 && (lengths == null || lengths.length == 0 || (lengths.length > 1 && lengths.length < n))) {
                throw new IllegalArgumentException("This is a homogeneous model. Please check parameter lengths.");
            }
            // position of the first not yet consumed entry of lengths
            int offset = 0;
            for (int c = 0; c < this.dimension; ++c) {
                int count = numbers[c];
                if (count <= 0) {
                    continue;
                }
                DataSet help;
                if (lengths.length == 1) {
                    help = this.model[c].emitDataSet(count, lengths);
                } else {
                    int[] array = new int[count];
                    System.arraycopy(lengths, offset, array, 0, count);
                    offset += count;
                    help = this.model[c].emitDataSet(count, array);
                }
                for (int k = 0; k < count; ++k) {
                    seqs[filled++] = help.getElementAt(k);
                }
            }
        } else if (lengths == null || lengths.length == 0) {
            // inhomogeneous case: components without own length get the length of the mixture
            for (int c = 0; c < this.dimension; ++c) {
                int count = numbers[c];
                if (count <= 0) {
                    continue;
                }
                DataSet help = this.model[c].getLength() == 0 ? this.model[c].emitDataSet(count, this.length) : this.model[c].emitDataSet(count, lengths);
                for (int k = 0; k < count; ++k) {
                    seqs[filled++] = help.getElementAt(k);
                }
            }
        } else {
            throw new Exception("This is an inhomogeneous model. Please check parameter lengths.");
        }
        return seqs;
    }

    /**
     * Performs the first iteration using randomly generated component weights for
     * every sequence of {@code sample[0]}.
     *
     * @param dataWeights optional weights of the sequences; {@code null} is treated as weight 1 for every sequence
     * @param m the generator that yields one vector of {@code dimension} values per sequence
     * @param params the generator parameters, one entry per sequence
     * @return the sequence weights, indexed as {@code [component][sequence]}
     * @throws Exception propagated from the inherited helper methods
     */
    @Override
    protected double[][] doFirstIteration(double[] dataWeights, MultivariateRandomGenerator m, MRGParams[] params) throws Exception {
        int d = this.sample[0].getNumberOfElements();
        double[][] seqweights = this.createSeqWeightsArray();
        double[] w = new double[this.dimension];
        this.initWithPrior(w);
        for (int seqIdx = 0; seqIdx < d; ++seqIdx) {
            double[] help = m.generate(this.dimension, params[seqIdx]);
            // multiplying by exactly 1.0 leaves the generated value unchanged
            double weight = dataWeights == null ? 1.0 : dataWeights[seqIdx];
            for (int c = 0; c < this.dimension; ++c) {
                seqweights[c][seqIdx] = weight * help[c];
                w[c] = w[c] + seqweights[c][seqIdx];
            }
        }
        this.getNewParameters(0, seqweights, w);
        return seqweights;
    }

    /**
     * Performs the first iteration using a user supplied soft partitioning of the data.
     *
     * @param data the training data; it is set as train data before any validation
     * @param dataWeights optional weights of the sequences; {@code null} is treated as weight 1 for every sequence
     * @param partitioning {@code partitioning[i][c]} is the share of sequence {@code i}
     *        assigned to component {@code c}; every row needs {@code dimension} entries
     *        in [0,1] that sum to 1 (up to a small rounding tolerance)
     * @return the sequence weights, indexed as {@code [component][sequence]}
     * @throws IllegalArgumentException if a row of the partitioning is invalid
     * @throws OperationNotSupportedException if the model has at most one component
     * @throws Exception propagated from the inherited helper methods
     */
    public double[][] doFirstIteration(DataSet data, double[] dataWeights, double[][] partitioning) throws Exception {
        this.setTrainData(data);
        if (this.dimension <= 1) {
            throw new OperationNotSupportedException();
        }
        int d = data.getNumberOfElements();
        double[][] seqweights = this.createSeqWeightsArray();
        double[] w = new double[this.dimension];
        this.initWithPrior(w);
        for (int seqIdx = 0; seqIdx < d; ++seqIdx) {
            double[] parts = partitioning[seqIdx];
            if (parts.length != this.dimension) {
                throw new IllegalArgumentException("The partitioning for sequence " + seqIdx + " was wrong. (number of parts)");
            }
            double weight = dataWeights == null ? 1.0 : dataWeights[seqIdx];
            double sum = 0.0;
            for (int c = 0; c < this.dimension; ++c) {
                if (parts[c] < 0.0 || parts[c] > 1.0) {
                    throw new IllegalArgumentException("The partitioning for sequence " + seqIdx + " was wrong. (part " + c + " was incorrect)");
                }
                seqweights[c][seqIdx] = weight * parts[c];
                sum += parts[c];
                w[c] = w[c] + seqweights[c][seqIdx];
            }
            // negated "<=" so that a NaN sum is rejected as well
            if (!(Math.abs(sum - 1.0) <= PARTITION_SUM_TOLERANCE)) {
                throw new IllegalArgumentException("The partitioning for sequence " + seqIdx + " was wrong. (sum of parts not 1)");
            }
        }
        this.getNewParameters(0, seqweights, w);
        return seqweights;
    }

    /**
     * Returns the log weight of the component plus the log probability the component
     * assigns to the given part of the sequence.
     *
     * @param component the index of the component
     * @param s the sequence
     * @param start the start position, forwarded to the component model
     * @param end the end position, forwarded to the component model
     * @return {@code logWeights[component] + model[component].getLogProbFor(s, start, end)}
     * @throws Exception propagated from the component model
     */
    @Override
    protected double getLogProbUsingCurrentParameterSetFor(int component, Sequence s, int start, int end) throws Exception {
        double componentLogProb = this.model[component].getLogProbFor(s, start, end);
        return this.logWeights[component] + componentLogProb;
    }

    /**
     * Returns a textual representation of the mixture: the algorithm name, the number
     * of starts and, depending on the algorithm, either the weighted components (EM)
     * or the sampling settings together with the component names (Gibbs sampling).
     *
     * @param nf the format used for numbers
     * @return the textual representation
     * @throws IllegalArgumentException if the algorithm is neither EM nor Gibbs sampling
     */
    @Override
    public String toString(NumberFormat nf) {
        // capacity hint kept from the original implementation
        StringBuilder sb = new StringBuilder(this.model.length * 100000);
        sb.append("Mixture model with parameter estimation by ").append(this.getNameOfAlgorithm()).append(": \n");
        sb.append("number of starts:\t").append(this.starts).append("\n");
        switch (this.algorithm) {
            case EM: {
                for (int i = 0; i < this.dimension; ++i) {
                    sb.append(nf.format(this.weights[i])).append("\t").append(this.model[i].getInstanceName()).append("\n");
                    sb.append(this.model[i].toString(nf)).append("\n");
                }
                break;
            }
            case GIBBS_SAMPLING: {
                sb.append("burn in test              :\t").append(this.burnInTest.getInstanceName()).append("\n");
                sb.append("length of stationary phase:\t").append(this.stationaryIteration).append("\n");
                sb.append("Mixture model components:\n");
                for (int i = 0; i < this.dimension; ++i) {
                    sb.append(i + 1).append(". component: ").append(this.model[i].getInstanceName()).append("\n");
                }
                break;
            }
            default: {
                throw new IllegalArgumentException("The type of algorithm is unknown.");
            }
        }
        return sb.toString();
    }

    /**
     * Recomputes the sequence weights for all sequences of {@code sample[0]} from the
     * current component models and log weights.
     *
     * @param dataWeights optional weights of the sequences; {@code null} is treated as weight 1 for every sequence
     * @param w receives the prior initialization plus the summed sequence weights per component
     * @param seqweights receives the new sequence weights, indexed as {@code [component][sequence]}
     * @return the sum over all sequences of the value returned by {@code modifyWeights}
     *         times the weight of the sequence
     * @throws Exception propagated from the component models or inherited helpers
     */
    @Override
    protected double getNewWeights(double[] dataWeights, double[] w, double[][] seqweights) throws Exception {
        double L = 0.0;
        double currentWeight = 1.0;
        this.initWithPrior(w);
        double[] help = new double[this.dimension];
        int numberOfSequences = seqweights[0].length;
        for (int seqIdx = 0; seqIdx < numberOfSequences; ++seqIdx) {
            Sequence seq = this.sample[0].getElementAt(seqIdx);
            if (dataWeights != null) {
                currentWeight = dataWeights[seqIdx];
            }
            for (int c = 0; c < this.dimension; ++c) {
                help[c] = this.model[c].getLogProbFor(seq) + this.logWeights[c];
            }
            // NOTE(review): modifyWeights is inherited; presumably it replaces the joint
            // log scores in help by component weights in place - confirm in the super class.
            L += this.modifyWeights(help) * currentWeight;
            for (int c = 0; c < this.dimension; ++c) {
                seqweights[c][seqIdx] = help[c] * currentWeight;
                w[c] = w[c] + seqweights[c][seqIdx];
            }
        }
        return L;
    }

    /**
     * Stores the given data as the single train data set of this model.
     *
     * @param data the training data
     */
    @Override
    protected void setTrainData(DataSet data) {
        this.sample = new DataSet[]{data};
    }
}

