/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.mixture;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.FileManager;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import de.jstacs.sampling.BurnInTest;
import de.jstacs.sampling.GibbsSamplingModel;
import de.jstacs.sampling.SamplingComponent;
import de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.SafeOutputStream;
import de.jstacs.utils.Time;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jstacs.utils.random.MRGParams;
import de.jstacs.utils.random.MultivariateRandomGenerator;
import de.jstacs.utils.random.SoftOneOfN;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;
import java.util.TreeMap;
import javax.naming.OperationNotSupportedException;

public abstract class AbstractMixtureTrainSM
extends AbstractTrainableStatisticalModel {
    /** Component probabilities; setWeights(double...) requires them to be non-negative and to sum to 1. */
    protected double[] weights;
    /** Natural logarithms of {@link #weights}; recomputed whenever the weights are changed. */
    protected double[] logWeights;
    /** Hyperparameters of the prior on the component probabilities (used in getLogPriorTermForComponentProbs). */
    protected double[] componentHyperParams;
    /** The component models of the mixture. */
    protected TrainableStatisticalModel[] model;
    /** Second set of component models; exchanged with {@link #model} by swap() to remember the best EM start. */
    protected TrainableStatisticalModel[] alternativeModel;
    /** Number of EM starts or Gibbs sampling runs. */
    protected int starts;
    /** Number of mixture components. */
    protected int dimension;
    /** Score of the best (EM) run, set in train/iterate. */
    protected double best;
    /** Stream used for progress output. */
    protected SafeOutputStream sostream;
    /** Internal data set(s) used during training; training cannot be continued while this is null. */
    protected DataSet[] sample;
    /** Whether the component probabilities are (re-)estimated during training. */
    protected boolean estimateComponentProbs;
    /** Per component: whether its parameters are (re-)estimated during training. */
    protected boolean[] optimizeModel;
    /** The training algorithm (EM or Gibbs sampling). */
    protected Algorithm algorithm;
    /** Whether iterate(...) has been executed at least once. */
    protected boolean algorithmHasBeenRun;
    /** Its getCount() is added to the expected counts / hyperparameters of the component probabilities. */
    private Parameterization parametrization;
    /** Parameter of the FastDirichletMRGParams used to initialise EM. */
    private double alpha = 1.0;
    /** Termination condition of the EM iterations. */
    private TerminationCondition tc;
    /** Number of sampling steps performed per Gibbs run in iterate(int, ...). */
    protected int initialIteration;
    /** Target number of samples after burn-in, summed over all Gibbs runs (see train). */
    protected int stationaryIteration;
    /** Burn-in test for Gibbs sampling; may be null. */
    protected BurnInTest burnInTest;
    /** Writer for the parameter file of the current sampling run. */
    protected BufferedWriter filewriter;
    /** Reader for sampled parameter files; reset to null in clone(). */
    protected BufferedReader filereader;
    /** Per sampling run: temporary file holding the sampled parameters (entries may be null). */
    protected File[] file;
    /** Per sampling run: number of sampling steps performed so far. */
    protected int[] counter;
    /** Index of the sampling run that is currently extended. */
    protected int samplingIndex;
    /** Scratch buffer holding one log-probability per component. */
    protected double[] compProb;
    /** Per component: sequence weights kept for nested mixture components during EM. */
    private double[][][] usedWeights;
    /** Sequence weights produced by the most recent doFirstIteration during Gibbs sampling. */
    protected double[][] seqWeights;
    // NOTE(review): no use is visible in this part of the file; presumably used by draw(...) - confirm.
    private static final Random r = new Random();

    protected AbstractMixtureTrainSM(int length, TrainableStatisticalModel[] models, boolean[] optimizeModel, int dimension, int starts, boolean estimateComponentProbs, double[] componentHyperParams, double[] weights, Algorithm algorithm, double alpha, TerminationCondition tc, Parameterization parametrization, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        super(models[0].getAlphabetContainer(), length);
        if (dimension < 1) {
            throw new IllegalArgumentException("The dimension has to be at least 1.");
        }
        this.dimension = dimension;
        this.set((TrainableStatisticalModel[])ArrayHandler.clone((Cloneable[])models), optimizeModel, starts, weights, estimateComponentProbs, componentHyperParams, algorithm, alpha, tc, parametrization, initialIteration, stationaryIteration, burnInTest != null ? burnInTest.clone() : null);
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
        this.algorithmHasBeenRun = false;
    }

    /**
     * Re-creates a mixture model from its XML representation by delegating to the superclass constructor.
     *
     * @param representation the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    protected AbstractMixtureTrainSM(StringBuffer representation) throws NonParsableException {
        super(representation);
    }

    /**
     * Returns a deep copy of this mixture model.
     *
     * Component models and the burn-in test are cloned. Parameter files of sampling runs are copied
     * to fresh temporary files. Open readers and writers are not shared with the clone.
     *
     * @throws CloneNotSupportedException if cloning a component fails or a parameter file cannot be
     *         copied; the original exception is attached as cause
     */
    @Override
    public AbstractMixtureTrainSM clone() throws CloneNotSupportedException {
        try {
            AbstractMixtureTrainSM clone = (AbstractMixtureTrainSM) super.clone();
            // cleared so that set(...) allocates new weight arrays instead of sharing ours
            clone.weights = null;
            clone.set((TrainableStatisticalModel[]) ArrayHandler.clone((Cloneable[]) this.model), this.optimizeModel, this.starts, this.weights, this.estimateComponentProbs, this.componentHyperParams, this.algorithm, this.alpha, this.tc, this.parametrization, this.initialIteration, this.stationaryIteration, this.burnInTest != null ? this.burnInTest.clone() : null);
            if (this.file != null) {
                clone.counter = this.counter.clone();
                clone.file = new File[this.file.length];
                try {
                    for (int i = 0; i < this.file.length; ++i) {
                        if (this.file[i] == null) {
                            continue;
                        }
                        clone.file[i] = File.createTempFile("pi-", ".dat", null);
                        FileManager.copy(this.file[i].getAbsolutePath(), clone.file[i].getAbsolutePath());
                    }
                }
                catch (IOException e) {
                    // do not leave the temporary files created so far behind
                    for (File created : clone.file) {
                        if (created != null) {
                            created.delete();
                        }
                    }
                    CloneNotSupportedException c = new CloneNotSupportedException(e.getMessage());
                    c.initCause(e);
                    c.setStackTrace(e.getStackTrace());
                    throw c;
                }
            }
            clone.filereader = null;
            clone.filewriter = null;
            clone.setOutputStream(this.sostream.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
            clone.best = this.best;
            return clone;
        }
        catch (IllegalArgumentException e) {
            throw AbstractMixtureTrainSM.getCloneNotSupportedException(e);
        }
        catch (WrongAlphabetException e) {
            throw AbstractMixtureTrainSM.getCloneNotSupportedException(e);
        }
    }

    /**
     * Wraps an exception that is not expected during cloning into a {@link CloneNotSupportedException}.
     *
     * The original exception is attached as cause, and its stack trace is copied.
     *
     * @param e the exception that occurred while cloning
     * @return the wrapping exception (not thrown)
     */
    private static CloneNotSupportedException getCloneNotSupportedException(Exception e) {
        CloneNotSupportedException ex = new CloneNotSupportedException("impossible Exception in method clone in class AbstractMixtureTrainSM: " + e.getMessage());
        ex.initCause(e);
        ex.setStackTrace(e.getStackTrace());
        return ex;
    }

    /**
     * Returns the random generator used to draw the initial sequence weights.
     *
     * EM uses {@code DirichletMRG.DEFAULT_INSTANCE}. Gibbs sampling uses a new {@code SoftOneOfN}.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    protected MultivariateRandomGenerator getMRG() {
        MultivariateRandomGenerator generator;
        switch (this.algorithm) {
            case EM:
                generator = DirichletMRG.DEFAULT_INSTANCE;
                break;
            case GIBBS_SAMPLING:
                generator = new SoftOneOfN();
                break;
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        return generator;
    }

    /**
     * Returns the parameters for the generator returned by getMRG().
     *
     * EM uses Dirichlet parameters built from {@code alpha}. Gibbs sampling needs none and gets {@code null}.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    protected MRGParams getMRGParams() {
        MRGParams generatorParams;
        switch (this.algorithm) {
            case EM:
                generatorParams = new FastDirichletMRGParams(this.alpha);
                break;
            case GIBBS_SAMPLING:
                generatorParams = null;
                break;
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        return generatorParams;
    }

    /**
     * Trains the mixture on {@code data}.
     *
     * EM: runs {@code starts} independent starts and keeps the models and weights of the start with
     * the highest score in {@code best}.
     *
     * Gibbs sampling: runs {@code starts} sampling runs. It then keeps extending all runs until every
     * run has passed the burn-in and about {@code stationaryIteration} post-burn-in samples are available.
     *
     * @param data the training data
     * @param dataWeights the weights of the sequences (passed on to the iterations)
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    @Override
    public void train(DataSet data, double[] dataWeights) throws Exception {
        this.sample = null;
        System.gc();
        this.setTrainData(data);
        MultivariateRandomGenerator rg = this.getMRG();
        MRGParams[] params = new MRGParams[data.getNumberOfElements()];
        Arrays.fill(params, this.getMRGParams());
        switch (this.algorithm) {
            case EM: {
                double max = Double.NEGATIVE_INFINITY;
                double[] p = this.weights.clone();
                if (this.alternativeModel == null) {
                    this.alternativeModel = (TrainableStatisticalModel[]) ArrayHandler.clone((Cloneable[]) this.model);
                }
                for (int i = 0; i < this.starts; ++i) {
                    double current = this.iterate(i, dataWeights, rg, params);
                    if (max < current) {
                        // park the better models in alternativeModel; the next start overwrites the other set
                        this.swap();
                        p = this.weights.clone();
                        max = current;
                    }
                }
                // bring the parked (best) models back into this.model
                this.swap();
                this.setWeights(p);
                p = null;
                this.best = max;
                this.sostream.writeln("best = " + max);
                break;
            }
            case GIBBS_SAMPLING: {
                this.burnInTest.resetAllValues();
                this.initModelForSampling(this.starts);
                for (int i = 0; i < this.starts; ++i) {
                    this.iterate(i, dataWeights, rg, params);
                }
                boolean finished;
                do {
                    int burnIn = this.burnInTest.getLengthOfBurnIn();
                    int collected = 0;
                    finished = true;
                    for (int i = 0; i < this.starts; ++i) {
                        int afterBurnIn = this.counter[i] - burnIn;
                        if (afterBurnIn > 0) {
                            collected += afterBurnIn;
                        } else {
                            finished = false;
                        }
                    }
                    int extension = (int) Math.ceil((double) (this.stationaryIteration - collected) / (double) this.starts);
                    if (!finished && extension <= 0) {
                        // some run is still in its burn-in; without new samples nothing changes between
                        // passes and the loop would never terminate, so sample at least one more step
                        extension = 1;
                    }
                    if (extension > 0) {
                        for (int i = 0; i < this.starts; ++i) {
                            this.sostream.writeln("=== extend start: " + i + " ==========");
                            this.continueIterations(dataWeights, this.seqWeights, extension, i);
                        }
                    }
                } while (!finished);
                break;
            }
            default: {
                throw new IllegalArgumentException("The type of algorithm is unknown.");
            }
        }
        rg = null;
        params = null;
        System.gc();
    }

    /** Exchanges the current component models with the alternative set. */
    protected void swap() {
        TrainableStatisticalModel[] previous = this.model;
        this.model = this.alternativeModel;
        this.alternativeModel = previous;
    }

    /**
     * Installs the data used for training.
     *
     * Callers set {@code sample} to null before calling this, and later code requires it to be
     * non-null. Presumably implementations fill {@code sample} - confirm in subclasses.
     *
     * @param data the training data
     */
    protected abstract void setTrainData(DataSet data) throws Exception;

    /**
     * Allocates a zero-filled weight matrix with one row per component model and one column per
     * sequence of the first internal data set.
     */
    protected double[][] createSeqWeightsArray() {
        int components = this.model.length;
        int sequences = this.sample[0].getNumberOfElements();
        return new double[components][sequences];
    }

    /**
     * Installs {@code data} as training data and performs one training run, using start index 0.
     *
     * @return the score returned by iterate(int, ...)
     */
    public double iterate(DataSet data, double[] dataWeights, MultivariateRandomGenerator m, MRGParams[] params) throws Exception {
        // release the previous data before installing the new data
        this.sample = null;
        System.gc();
        this.setTrainData(data);
        final int firstStart = 0;
        return this.iterate(firstStart, dataWeights, m, params);
    }

    /**
     * Performs one training run with the given start index.
     *
     * EM: draws initial weights and iterates until the termination condition stops; the result is
     * stored in {@code best}.
     *
     * Gibbs sampling: initialises sampling run {@code start} and performs {@code initialIteration}
     * sampling steps.
     *
     * @return the current value of {@code best}
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    protected double iterate(int start, double[] dataWeights, MultivariateRandomGenerator m, MRGParams[] params) throws Exception {
        this.sostream.writeln("========== start: " + start + " ==========");
        switch (this.algorithm) {
            case EM: {
                double[][] initialSeqWeights = this.doFirstIteration(dataWeights, m, params);
                this.best = this.continueIterations(dataWeights, initialSeqWeights);
                break;
            }
            case GIBBS_SAMPLING: {
                this.extendSampling(start);
                this.burnInTest.setCurrentSamplingIndex(start);
                this.seqWeights = this.doFirstIteration(dataWeights, m, params);
                this.samplingStopped();
                this.continueIterations(dataWeights, this.seqWeights, this.initialIteration, start);
                break;
            }
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        this.algorithmHasBeenRun = true;
        return this.best;
    }

    /**
     * Performs the first iteration on {@code data}.
     *
     * Initial weights are drawn with {@code DirichletMRG.DEFAULT_INSTANCE}. Every sequence shares one
     * {@code FastDirichletMRGParams(alpha)}.
     */
    protected double[][] doFirstIteration(DataSet data, double[] dataWeights) throws Exception {
        FastDirichletMRGParams[] perSequence = new FastDirichletMRGParams[data.getNumberOfElements()];
        Arrays.fill(perSequence, new FastDirichletMRGParams(this.alpha));
        return this.doFirstIteration(data, dataWeights, DirichletMRG.DEFAULT_INSTANCE, (MRGParams[]) perSequence);
    }

    /**
     * Installs {@code data} as training data and performs the first iteration on it.
     *
     * @return the sequence weights produced by the first iteration
     */
    protected double[][] doFirstIteration(DataSet data, double[] dataWeights, MultivariateRandomGenerator generator, MRGParams[] params) throws Exception {
        // release the previous data before installing the new data
        this.sample = null;
        System.gc();
        this.setTrainData(data);
        double[][] firstWeights = this.doFirstIteration(dataWeights, generator, params);
        return firstWeights;
    }

    /**
     * Performs the first iteration on the installed training data.
     *
     * The initial weights are drawn with {@code generator} and {@code params}.
     *
     * @return the sequence weights (presumably one row per component - confirm in subclasses)
     */
    protected abstract double[][] doFirstIteration(double[] dataWeights, MultivariateRandomGenerator generator, MRGParams[] params) throws Exception;

    /**
     * Continues the (EM) iterations until the termination condition {@code tc} stops.
     *
     * Each iteration logs the step number, elapsed time, score, log-prior and their sum.
     *
     * @param dataWeights the sequence weights of the data
     * @param seqweights the sequence weights per component, or null to allocate new ones
     * @return the last value of getNewWeights plus the log-prior term
     * @throws OperationNotSupportedException if no internal data set is available
     */
    protected double continueIterations(double[] dataWeights, double[][] seqweights) throws Exception {
        if (this.sample == null) {
            throw new OperationNotSupportedException("There is no reference to an internal data set, so you can not go on with training.");
        }
        double[] componentWeights = new double[this.dimension];
        double[][] weightsPerSeq = seqweights != null ? seqweights : this.createSeqWeightsArray();
        int iteration = 0;
        double logPrior = this.getLogPriorTerm();
        double previous = Double.NEGATIVE_INFINITY;
        double current = this.getNewWeights(dataWeights, componentWeights, weightsPerSeq);
        this.sostream.write(iteration + "\t" + 0 + "\t" + current + "\t " + logPrior + "\t");
        current += logPrior;
        this.sostream.writeln(current + "\t" + (current - previous));
        Time timer = Time.getTimeInstance(this.sostream);
        while (this.tc.doNextIteration(iteration, previous, current, null, null, Double.NaN, timer)) {
            ++iteration;
            this.getNewParameters(iteration, weightsPerSeq, componentWeights);
            previous = current;
            logPrior = this.getLogPriorTerm();
            current = this.getNewWeights(dataWeights, componentWeights, weightsPerSeq);
            this.sostream.write(iteration + "\t" + timer.getElapsedTime() + "\t" + current + "\t " + logPrior + "\t");
            current += logPrior;
            this.sostream.writeln(current + "\t" + (current - previous));
        }
        return current;
    }

    /**
     * Performs exactly {@code iterations} further iterations.
     *
     * When a burn-in test is present, the iterations extend sampling run {@code start}, and every
     * score is reported to the burn-in test.
     *
     * @param dataWeights the sequence weights of the data
     * @param seqweights the sequence weights per component, or null to allocate new ones
     * @param iterations the number of iterations to perform
     * @param start the sampling run to extend (only used with a burn-in test)
     * @return the last value of getNewWeights plus the log-prior term
     * @throws OperationNotSupportedException if no internal data set is available
     */
    protected double continueIterations(double[] dataWeights, double[][] seqweights, int iterations, int start) throws Exception {
        if (this.burnInTest != null) {
            this.extendSampling(start);
            this.burnInTest.setCurrentSamplingIndex(start);
        }
        if (this.sample == null) {
            throw new OperationNotSupportedException("There is no reference to an internal data set, so you can not go on with training.");
        }
        double[] componentWeights = new double[this.dimension];
        double[][] weightsPerSeq = seqweights != null ? seqweights : this.createSeqWeightsArray();
        double logPrior = this.getLogPriorTerm();
        double previous = Double.NEGATIVE_INFINITY;
        double current = this.getNewWeights(dataWeights, componentWeights, weightsPerSeq);
        // the logged step number continues the counter of the current sampling run
        int step = this.burnInTest != null ? this.counter[this.samplingIndex] : 0;
        for (int done = 0; done < iterations; ++done, ++step) {
            this.sostream.write(step + "\t" + current + "\t " + logPrior + "\t");
            current += logPrior;
            this.sostream.writeln(current + "\t" + (current - previous));
            if (this.burnInTest != null) {
                this.burnInTest.setValue(current);
            }
            this.getNewParameters(done, weightsPerSeq, componentWeights);
            previous = current;
            logPrior = this.getLogPriorTerm();
            current = this.getNewWeights(dataWeights, componentWeights, weightsPerSeq);
        }
        if (this.burnInTest != null) {
            this.samplingStopped();
        }
        return current + logPrior;
    }

    /**
     * Re-estimates every component model from its row of {@code seqWeights}, then updates the
     * component probabilities from {@code w}.
     */
    protected void getNewParameters(int iteration, double[][] seqWeights, double[] w) throws Exception {
        int components = seqWeights.length;
        for (int c = 0; c < components; ++c) {
            this.getNewParametersForModel(c, iteration, 0, seqWeights[c]);
        }
        this.getNewComponentProbs(w);
    }

    /**
     * Re-estimates component {@code modelIndex} from the data set {@code sample[sampleIndex]} weighted
     * by {@code seqWeights}. Components that are not optimized are left untouched.
     *
     * EM: nested mixture components get one extra EM step (or their first iteration when
     * {@code iteration == 0}). All other components are trained directly.
     *
     * Gibbs sampling: new parameters are drawn and accepted.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    protected void getNewParametersForModel(int modelIndex, int iteration, int sampleIndex, double[] seqWeights) throws Exception {
        if (!this.optimizeModel[modelIndex]) {
            return;
        }
        switch (this.algorithm) {
            case EM: {
                TrainableStatisticalModel component = this.model[modelIndex];
                if (!(component instanceof AbstractMixtureTrainSM)) {
                    component.train(this.sample[sampleIndex], seqWeights);
                } else if (iteration == 0) {
                    this.usedWeights[modelIndex] = ((AbstractMixtureTrainSM) component).doFirstIteration(this.sample[sampleIndex], seqWeights);
                } else {
                    ((AbstractMixtureTrainSM) component).continueIterations(seqWeights, this.usedWeights[modelIndex], 1, 0);
                }
                break;
            }
            case GIBBS_SAMPLING: {
                GibbsSamplingModel sampler = (GibbsSamplingModel) ((Object) this.model[modelIndex]);
                sampler.drawParameters(this.sample[sampleIndex], seqWeights);
                sampler.acceptParameters();
                break;
            }
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /**
     * Computes new sequence weights for the current parameters and fills {@code w} for the
     * component-probability update.
     *
     * @return a score; callers add the log-prior term to it and use the sum as the objective
     */
    protected abstract double getNewWeights(double[] dataWeights, double[] w, double[][] seqweights) throws Exception;

    /**
     * Normalises the log-values in {@code w} in place and returns their log-sum.
     *
     * For Gibbs sampling, one index is drawn from the normalised values, and {@code w} is replaced
     * by the corresponding 0/1 indicator vector.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    protected double modifyWeights(double[] w) {
        double logNorm;
        switch (this.algorithm) {
            case EM:
                logNorm = Normalisation.logSumNormalisation(w, 0, w.length, w, 0);
                break;
            case GIBBS_SAMPLING: {
                logNorm = Normalisation.logSumNormalisation(w, 0, w.length, w, 0);
                int chosen = AbstractMixtureTrainSM.draw(w, 0);
                Arrays.fill(w, 0.0);
                w[chosen] = 1.0;
                break;
            }
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        return logNorm;
    }

    /** Copies the first {@code dimension} component hyperparameters into {@code w}. */
    protected void initWithPrior(double[] w) {
        final int count = this.dimension;
        System.arraycopy(this.componentHyperParams, 0, w, 0, count);
    }

    /**
     * Returns the log-probability of the whole sequence {@code s} under component {@code component}
     * (EM only).
     *
     * @throws OperationNotSupportedException for Gibbs sampling
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    public double getLogProbFor(int component, Sequence s) throws Exception {
        switch (this.algorithm) {
            case EM: {
                int lastPosition = s.getLength() - 1;
                return this.getLogProbUsingCurrentParameterSetFor(component, s, 0, lastPosition);
            }
            case GIBBS_SAMPLING:
                throw new OperationNotSupportedException();
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /**
     * Returns the log-probability that component {@code component} yields for {@code s} from
     * {@code startpos} to {@code endpos}, using the current parameter set.
     *
     * Callers pass {@code 0} and {@code s.getLength() - 1} for the whole sequence, so presumably
     * {@code endpos} is inclusive - confirm in subclasses.
     */
    protected abstract double getLogProbUsingCurrentParameterSetFor(int component, Sequence s, int startpos, int endpos) throws Exception;

    /**
     * Returns the log-probability of {@code sequence} between {@code startpos} and {@code endpos}.
     *
     * EM: log-sum over the components. Gibbs sampling: log of the average mixture probability over
     * all stored post-burn-in parameter sets.
     *
     * @throws NotTrainedException if the model is not initialized
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    @Override
    public final double getLogProbFor(Sequence sequence, int startpos, int endpos) throws Exception {
        if (!this.isInitialized()) {
            throw new NotTrainedException();
        }
        switch (this.algorithm) {
            case EM: {
                for (int c = 0; c < this.dimension; ++c) {
                    this.compProb[c] = this.getLogProbUsingCurrentParameterSetFor(c, sequence, startpos, endpos);
                }
                return Normalisation.getLogSum(this.compProb);
            }
            case GIBBS_SAMPLING: {
                int burnIn = this.burnInTest.getLengthOfBurnIn();
                double logTotal = Double.NEGATIVE_INFINITY;
                int parameterSets = 0;
                for (int run = 0; run < this.starts; ++run) {
                    for (boolean more = this.parseParameterSet(run, burnIn); more; more = this.parseNextParameterSet()) {
                        for (int c = 0; c < this.dimension; ++c) {
                            this.compProb[c] = this.getLogProbUsingCurrentParameterSetFor(c, sequence, startpos, endpos);
                        }
                        logTotal = Normalisation.getLogSum(logTotal, Normalisation.getLogSum(this.compProb));
                        ++parameterSets;
                    }
                }
                return logTotal - Math.log(parameterSets);
            }
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /**
     * Returns the log-score of every sequence in {@code data}.
     *
     * EM delegates to the superclass. Gibbs sampling averages the mixture probabilities over all
     * stored post-burn-in parameter sets, in log space.
     *
     * @throws NotTrainedException if the model is not initialized
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    @Override
    public final double[] getLogScoreFor(DataSet data) throws Exception {
        if (!this.isInitialized()) {
            throw new NotTrainedException();
        }
        switch (this.algorithm) {
            case EM:
                return super.getLogScoreFor(data);
            case GIBBS_SAMPLING: {
                int burnIn = this.burnInTest.getLengthOfBurnIn();
                Sequence[] sequences = data.getAllElements();
                double[] scores = new double[sequences.length];
                Arrays.fill(scores, Double.NEGATIVE_INFINITY);
                int parameterSets = 0;
                for (int run = 0; run < this.starts; ++run) {
                    for (boolean more = this.parseParameterSet(run, burnIn); more; more = this.parseNextParameterSet()) {
                        for (int k = 0; k < sequences.length; ++k) {
                            Sequence current = sequences[k];
                            for (int c = 0; c < this.dimension; ++c) {
                                this.compProb[c] = this.getLogProbUsingCurrentParameterSetFor(c, current, 0, current.getLength() - 1);
                            }
                            scores[k] = Normalisation.getLogSum(scores[k], Normalisation.getLogSum(this.compProb));
                        }
                        ++parameterSets;
                    }
                }
                double logCount = Math.log(parameterSets);
                for (int k = 0; k < scores.length; ++k) {
                    scores[k] -= logCount;
                }
                return scores;
            }
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /**
     * Returns the log-prior term.
     *
     * Gibbs sampling: 0. EM: the sum of the log-priors of all optimized component models plus the
     * prior of the component probabilities.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    @Override
    public double getLogPriorTerm() throws Exception {
        switch (this.algorithm) {
            case GIBBS_SAMPLING:
                return 0.0;
            case EM: {
                double total = 0.0;
                for (int idx = 0; idx < this.model.length; ++idx) {
                    if (this.optimizeModel[idx]) {
                        total += this.model[idx].getLogPriorTerm();
                    }
                }
                return total + this.getLogPriorTermForComponentProbs();
            }
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /**
     * Returns the log-prior of the component probabilities.
     *
     * The value is 0 unless the component probabilities are estimated and the first hyperparameter
     * is positive.
     */
    protected final double getLogPriorTermForComponentProbs() {
        if (!this.estimateComponentProbs || !(this.componentHyperParams[0] > 0.0)) {
            return 0.0;
        }
        double logPrior = 0.0;
        double hyperSum = 0.0;
        for (int c = 0; c < this.dimension; ++c) {
            double h = this.componentHyperParams[c];
            hyperSum += h;
            logPrior += (h + this.parametrization.getCount()) * this.logWeights[c] - Gamma.logOfGamma(h);
        }
        return logPrior + Gamma.logOfGamma(hyperSum);
    }

    /**
     * Returns the score of the best EM run.
     *
     * @throws NotTrainedException if the algorithm has not been run yet
     * @throws OperationNotSupportedException if the algorithm is not EM
     */
    public final double getScoreForBestRun() throws NotTrainedException, OperationNotSupportedException {
        if (!this.algorithmHasBeenRun()) {
            throw new NotTrainedException();
        }
        if (this.algorithm != Algorithm.EM) {
            throw new OperationNotSupportedException();
        }
        return this.best;
    }

    /**
     * Returns a name of the form {@code ClassName(component, ...[; weights]) algorithm}.
     *
     * The weights appear only when the component probabilities are fixed.
     */
    @Override
    public String getInstanceName() {
        StringBuffer name = new StringBuffer(this.getClass().getSimpleName() + "(");
        name.append(this.model[0].getInstanceName());
        for (int c = 1; c < this.model.length; ++c) {
            name.append(", ").append(this.model[c].getInstanceName());
        }
        if (!this.estimateComponentProbs) {
            name.append("; ").append(Arrays.toString(this.weights));
        }
        name.append(") ").append(this.getNameOfAlgorithm());
        return name.toString();
    }

    /**
     * Returns the index of the component with the highest log-probability for {@code s} (EM only).
     *
     * Ties are resolved in favour of the smaller index.
     *
     * @throws OperationNotSupportedException for Gibbs sampling
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    public int getIndexOfMaximalComponentFor(Sequence s) throws Exception {
        switch (this.algorithm) {
            case EM: {
                int argMax = 0;
                double maxLogProb = this.getLogProbFor(0, s);
                for (int c = 1; c < this.dimension; ++c) {
                    double candidate = this.getLogProbFor(c, s);
                    if (candidate > maxLogProb) {
                        maxLogProb = candidate;
                        argMax = c;
                    }
                }
                return argMax;
            }
            case GIBBS_SAMPLING:
                throw new OperationNotSupportedException();
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /** Returns deep copies of all component models. */
    public final TrainableStatisticalModel[] getModels() throws CloneNotSupportedException {
        Cloneable[] components = (Cloneable[]) this.model;
        return (TrainableStatisticalModel[]) ArrayHandler.clone(components);
    }

    /** Returns a copy of component model {@code i}. */
    public final TrainableStatisticalModel getModel(int i) throws CloneNotSupportedException {
        TrainableStatisticalModel component = this.model[i];
        return component.clone();
    }

    /**
     * Returns a human-readable name of the training algorithm.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    public String getNameOfAlgorithm() {
        String name;
        switch (this.algorithm) {
            case EM:
                name = "EM";
                break;
            case GIBBS_SAMPLING:
                name = "Gibbs Sampling";
                break;
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        return name;
    }

    /** Returns the number of mixture components. */
    public final int getNumberOfComponents() {
        final int components = this.dimension;
        return components;
    }

    /**
     * Collects the characteristics of all component models and adds the XML representation of this
     * model.
     *
     * Each component's results are preceded by a "model number" entry. Components without results
     * are skipped.
     */
    @Override
    public ResultSet getCharacteristics() throws Exception {
        LinkedList<Result> infos = new LinkedList<Result>();
        for (int i = 0; i < this.model.length; ++i) {
            ResultSet part = this.model[i].getCharacteristics();
            if (part == null || part.getNumberOfResults() <= 0) {
                continue;
            }
            // Integer.valueOf instead of the deprecated Integer(int) constructor
            infos.add(new NumericalResult("model number", "type of model " + this.model[i].getClass().getSimpleName(), Integer.valueOf(i)));
            for (int j = 0; j < part.getNumberOfResults(); ++j) {
                infos.add(part.getResultAt(j));
            }
        }
        infos.add(new StorableResult("model", "the xml representation of the model", this));
        return new ResultSet(infos);
    }

    /**
     * Collects the numerical characteristics of all component models.
     *
     * Each component's results are preceded by a "model number" entry. Components without results
     * are skipped.
     */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        LinkedList<NumericalResult> infos = new LinkedList<NumericalResult>();
        for (int i = 0; i < this.model.length; ++i) {
            NumericalResultSet part = this.model[i].getNumericalCharacteristics();
            if (part == null || part.getNumberOfResults() <= 0) {
                continue;
            }
            // Integer.valueOf instead of the deprecated Integer(int) constructor
            infos.add(new NumericalResult("model number", "type of model " + this.model[i].getClass().getSimpleName(), Integer.valueOf(i)));
            for (int j = 0; j < part.getNumberOfResults(); ++j) {
                infos.add(part.getResultAt(j));
            }
        }
        return new NumericalResultSet(infos);
    }

    /** Returns a copy of the component probabilities. */
    public final double[] getWeights() {
        double[] copy = this.weights.clone();
        return copy;
    }

    /** Returns whether iterate(...) has been executed at least once. */
    public boolean algorithmHasBeenRun() {
        final boolean hasRun = this.algorithmHasBeenRun;
        return hasRun;
    }

    /**
     * Returns whether the model can be used for scoring.
     *
     * EM: all component models must be initialized. Gibbs sampling: the algorithm must have been run.
     *
     * @throws IllegalArgumentException if the algorithm is unknown
     */
    @Override
    public boolean isInitialized() {
        switch (this.algorithm) {
            case EM:
                for (TrainableStatisticalModel component : this.model) {
                    if (!component.isInitialized()) {
                        return false;
                    }
                }
                return true;
            case GIBBS_SAMPLING:
                return this.algorithmHasBeenRun;
            default:
                throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
    }

    /**
     * Sets the parameter of the Dirichlet distribution used to initialise EM.
     *
     * @param alpha the new value; must be strictly positive
     * @throws IllegalArgumentException if {@code alpha} is not strictly positive (including NaN)
     */
    public final void setAlpha(double alpha) throws IllegalArgumentException {
        // negated comparison so that NaN, for which every comparison is false, is rejected as well
        if (!(alpha > 0.0)) {
            throw new IllegalArgumentException("alpha has to be strict positive.");
        }
        this.alpha = alpha;
    }

    /** Redirects progress output to {@code o}, wrapped into a SafeOutputStream. */
    public final void setOutputStream(OutputStream o) {
        SafeOutputStream wrapped = SafeOutputStream.getSafeOutputStream(o);
        this.sostream = wrapped;
    }

    /**
     * Updates the component probabilities, if they are estimated.
     *
     * EM: {@code weights} (the expected counts) are normalised. When {@code componentHyperParams[0]}
     * is non-zero, {@code parametrization.getCount()} is first added to every entry; note that
     * {@code weights} is modified in place.
     *
     * Gibbs sampling: new probabilities are drawn from a Dirichlet distribution with parameters
     * {@code weights} and appended to the parameter file of the current run.
     *
     * For Gibbs sampling, the sample counter of the current run is always incremented.
     *
     * @param weights expected counts / Dirichlet parameters per component
     * @throws IllegalArgumentException if an (adjusted) EM weight is negative or the algorithm is unknown
     */
    protected void getNewComponentProbs(double[] weights) throws Exception {
        if (this.estimateComponentProbs) {
            switch (this.algorithm) {
                case EM: {
                    boolean map = this.componentHyperParams[0] != 0.0;
                    double sum = 0.0;
                    for (int i = 0; i < this.dimension; ++i) {
                        if (map) {
                            weights[i] += this.parametrization.getCount();
                        }
                        sum += weights[i];
                        if (weights[i] < 0.0) {
                            throw new IllegalArgumentException("Every weight has to be at least 0. Violate at position " + i + ".");
                        }
                    }
                    for (int i = 0; i < this.dimension; ++i) {
                        this.weights[i] = weights[i] / sum;
                    }
                    break;
                }
                case GIBBS_SAMPLING: {
                    DirichletMRG.DEFAULT_INSTANCE.generate(this.weights, 0, this.dimension, new DirichletMRGParams(weights));
                    this.filewriter.write(this.counter[this.samplingIndex] + "\t");
                    // the loop index must start at 0; previously it was read uninitialized here
                    for (int i = 0; i < this.dimension; ++i) {
                        this.filewriter.write(this.weights[i] + "\t");
                    }
                    this.filewriter.write("\n");
                    this.filewriter.flush();
                    break;
                }
                default: {
                    throw new IllegalArgumentException("The type of algorithm is unknown.");
                }
            }
            for (int i = 0; i < this.dimension; ++i) {
                this.logWeights[i] = Math.log(this.weights[i]);
            }
        }
        if (this.algorithm == Algorithm.GIBBS_SAMPLING) {
            this.counter[this.samplingIndex]++;
        }
    }

    /**
     * Sets the component probabilities and their logarithms.
     *
     * @param weights one non-negative probability per component; they must sum to 1 (tolerance 1e-9)
     * @throws IllegalArgumentException if the number of weights is wrong, a weight is negative, or
     *         the weights do not sum to 1 (this includes NaN entries)
     */
    protected void setWeights(double ... weights) throws IllegalArgumentException {
        if (weights.length != this.dimension) {
            throw new IllegalArgumentException("The number of weights is incorrect");
        }
        double sum = 0.0;
        for (int i = 0; i < this.dimension; ++i) {
            sum += weights[i];
            if (weights[i] < 0.0) {
                throw new IllegalArgumentException("Every weight has to be at least 0. Violate at position " + i + ".");
            }
        }
        // negated comparison: a NaN entry makes sum NaN, and NaN > 1e-9 would be false
        if (!(Math.abs(1.0 - sum) <= 1.0E-9)) {
            throw new IllegalArgumentException("The weights do not sum to 1.");
        }
        if (this.weights == null) {
            this.weights = new double[this.dimension];
            this.logWeights = new double[this.dimension];
        }
        for (int i = 0; i < this.dimension; ++i) {
            this.weights[i] = weights[i];
            this.logWeights[i] = Math.log(this.weights[i]);
        }
    }

    /**
     * Returns the XML representation of this model.
     *
     * It holds the common settings, the algorithm-specific settings and, for Gibbs sampling, the
     * contents of the parameter files. The result is wrapped in a tag named after the runtime class.
     *
     * @throws RuntimeException if a parameter file cannot be read; the IOException is attached as cause
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendObjectWithTags(xml, this.length, "length");
        XMLParser.appendObjectWithTags(xml, this.dimension, "dimension");
        XMLParser.appendObjectWithTags(xml, this.starts, "starts");
        XMLParser.appendObjectWithTags(xml, this.estimateComponentProbs, "estimateComponentProbs");
        XMLParser.appendObjectWithTags(xml, this.componentHyperParams, "componentHyperParams");
        XMLParser.appendObjectWithTags(xml, this.model, "models");
        XMLParser.appendObjectWithTags(xml, this.optimizeModel, "optimizeModel");
        XMLParser.appendObjectWithTags(xml, this.algorithmHasBeenRun, "algorithmHasBeenRun");
        XMLParser.appendObjectWithTags(xml, this.weights, "weights");
        XMLParser.appendObjectWithTags(xml, (Object)this.algorithm, "algorithm");
        switch (this.algorithm) {
            case EM: {
                XMLParser.appendObjectWithTags(xml, this.alpha, "alpha");
                XMLParser.appendObjectWithTags(xml, this.tc, "terminationCondition");
                XMLParser.appendObjectWithTags(xml, (Object)this.parametrization, "parametrization");
                break;
            }
            case GIBBS_SAMPLING: {
                XMLParser.appendObjectWithTags(xml, this.initialIteration, "initialIteration");
                XMLParser.appendObjectWithTags(xml, this.stationaryIteration, "stationaryIteration");
                XMLParser.appendObjectWithTags(xml, this.burnInTest, "burnInTest");
                XMLParser.appendObjectWithTags(xml, this.file != null, "hasParameterFiles");
                if (this.file != null) {
                    XMLParser.appendObjectWithTags(xml, this.counter, "counter");
                    try {
                        for (int i = 0; i < this.counter.length; ++i) {
                            String content = this.file[i] != null ? FileManager.readFile(this.file[i]).toString() : "";
                            XMLParser.appendObjectWithTagsAndAttributes(xml, content, "fileContent", "pos=\"" + i + "\"");
                        }
                    }
                    catch (IOException e) {
                        // keep the original failure as cause instead of only copying its message
                        throw new RuntimeException(e.getMessage(), e);
                    }
                }
                break;
            }
        }
        XMLParser.appendObjectWithTags(xml, this.best, "best");
        xml.append(this.getFurtherInformation());
        XMLParser.addTags(xml, this.getClass().getSimpleName());
        return xml;
    }

    /**
     * Hook for subclasses to contribute additional XML; the returned content is appended by
     * {@code toXML()} just before the enclosing class tag is added.
     *
     * @return an empty buffer in this base implementation
     */
    protected StringBuffer getFurtherInformation() {
        final StringBuffer nothingToAdd = new StringBuffer(1);
        return nothingToAdd;
    }

    /**
     * Restores this mixture model from the XML produced by {@code toXML()}.
     * <p>
     * The common fields are read first. Then the algorithm-specific fields are read and the
     * whole configuration is validated through {@code set(...)}:
     * <ul>
     * <li>EM: a legacy {@code epsilon} tag is converted into a
     * {@link SmallDifferenceOfFunctionEvaluationsCondition}; otherwise the stored termination
     * condition is used.</li>
     * <li>Gibbs sampling: if parameter files were stored, their inlined contents are written
     * back to fresh temporary files; otherwise {@code file} is set to {@code null}.</li>
     * </ul>
     * Finally {@code best} is read, the alphabet is taken from the first component model, the
     * default output stream is installed and {@code extractFurtherInformation} is called.
     *
     * @param representation the XML representation
     * @throws NonParsableException if the XML cannot be parsed, the stored configuration is
     *         invalid, or a temporary parameter file cannot be written; failures that are not
     *         themselves {@code NonParsableException}s are attached as cause
     */
    @Override
    protected void fromXML(StringBuffer representation) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(representation, this.getClass().getSimpleName());
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.dimension = XMLParser.extractObjectForTags(xml, "dimension", Integer.TYPE);
        this.starts = XMLParser.extractObjectForTags(xml, "starts", Integer.TYPE);
        this.estimateComponentProbs = XMLParser.extractObjectForTags(xml, "estimateComponentProbs", Boolean.TYPE);
        this.componentHyperParams = XMLParser.extractObjectForTags(xml, "componentHyperParams", double[].class);
        this.model = XMLParser.extractObjectForTags(xml, "models", TrainableStatisticalModel[].class);
        this.optimizeModel = XMLParser.extractObjectForTags(xml, "optimizeModel", boolean[].class);
        this.algorithmHasBeenRun = XMLParser.extractObjectForTags(xml, "algorithmHasBeenRun", Boolean.TYPE);
        double[] w = XMLParser.extractObjectForTags(xml, "weights", double[].class);
        this.algorithm = XMLParser.extractObjectForTags(xml, "algorithm", Algorithm.class);
        try {
            switch (this.algorithm) {
                case EM: {
                    this.parametrization = XMLParser.extractObjectForTags(xml, "parametrization", Parameterization.class);
                    // older documents stored only an epsilon instead of a full termination condition
                    this.tc = XMLParser.hasTag(xml, "epsilon", null, null) ? new SmallDifferenceOfFunctionEvaluationsCondition(XMLParser.extractObjectForTags(xml, "epsilon", Double.TYPE)) : XMLParser.extractObjectForTags(xml, "terminationCondition", TerminationCondition.class);
                    this.set(this.model, this.optimizeModel, this.starts, w, this.estimateComponentProbs, this.componentHyperParams, this.algorithm, XMLParser.extractObjectForTags(xml, "alpha", Double.TYPE), this.tc, this.parametrization, 0, 0, null);
                    break;
                }
                case GIBBS_SAMPLING: {
                    this.set(this.model, this.optimizeModel, this.starts, w, this.estimateComponentProbs, this.componentHyperParams, this.algorithm, 0.0, null, Parameterization.LAMBDA, XMLParser.extractObjectForTags(xml, "initialIteration", Integer.TYPE), XMLParser.extractObjectForTags(xml, "stationaryIteration", Integer.TYPE), XMLParser.extractObjectForTags(xml, "burnInTest", BurnInTest.class));
                    if (XMLParser.extractObjectForTags(xml, "hasParameterFiles", Boolean.TYPE).booleanValue()) {
                        this.counter = XMLParser.extractObjectForTags(xml, "counter", int[].class);
                        this.file = new File[this.counter.length];
                        TreeMap<String, String> filter = new TreeMap<String, String>();
                        for (int i = 0; i < this.counter.length; ++i) {
                            filter.clear();
                            filter.put("pos", "" + i);
                            String content = XMLParser.extractObjectAndAttributesForTags(xml, "fileContent", null, filter, String.class);
                            if (content.isEmpty()) {
                                // toXML writes an empty string when no file existed for this run
                                continue;
                            }
                            this.file[i] = File.createTempFile("pi-", ".dat", null);
                            FileManager.writeFile(this.file[i], (CharSequence)new StringBuffer(content));
                        }
                    } else {
                        this.file = null;
                    }
                    break;
                }
                default: {
                    throw new IllegalArgumentException("The type of algorithm is unknown.");
                }
            }
        }
        catch (Exception e) {
            if (e instanceof NonParsableException) {
                // already the right type; do not wrap it a second time
                throw (NonParsableException)e;
            }
            NonParsableException n = new NonParsableException(e.getMessage());
            n.initCause(e);
            throw n;
        }
        this.best = XMLParser.extractObjectForTags(xml, "best", Double.TYPE);
        this.alphabets = this.model[0].getAlphabetContainer();
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
        this.extractFurtherInformation(xml);
    }

    /**
     * Hook for subclasses to read back additional content written by
     * {@code getFurtherInformation()}. It is called as the last step of {@code fromXML}.
     * This base implementation does nothing.
     *
     * @param xml the XML content inside this object's class tag
     * @throws NonParsableException if a subclass cannot parse its additional information
     */
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
    }

    /**
     * Validates and stores the complete configuration of this mixture.
     * <p>
     * Common checks:
     * <ul>
     * <li>{@code starts} is at least 1,</li>
     * <li>all models share the alphabet of {@code model[0]} and have a compatible length,</li>
     * <li>{@code optimizeModel} (if given) has one entry per model,</li>
     * <li>the weights are valid (uniform weights are used if {@code weights} is {@code null}),</li>
     * <li>given component hyperparameters are either all zero or all positive.</li>
     * </ul>
     * EM additionally requires a known parameterization, a valid {@code alpha} and a non-null
     * simple termination condition. Gibbs sampling requires positive component hyperparameters,
     * {@code 1 <= initialIteration}, {@code 1 <= stationaryIteration},
     * {@code initialIteration * starts <= stationaryIteration}, a burn-in test, and models
     * that support Gibbs sampling.
     *
     * @param model the component models
     * @param optimizeModel which models are adjusted during training, or {@code null} for all
     * @param starts the number of starts (must be at least 1)
     * @param weights the component weights, or {@code null} for uniform weights
     * @param estimateComponentProbs whether the component probabilities are estimated
     * @param componentHyperParams the hyperparameters for the component assignment (may be
     *        {@code null}; ignored if {@code estimateComponentProbs} is {@code false})
     * @param algorithm the training algorithm
     * @param alpha the EM-specific value passed to {@code setAlpha}
     * @param tc the EM termination condition
     * @param parametrization the EM parameterization
     * @param initialIteration the Gibbs sampling initial iterations per start
     * @param stationaryIteration the Gibbs sampling stationary iterations
     * @param burnInTest the Gibbs sampling burn-in test
     * @throws IllegalArgumentException if any argument is invalid
     * @throws WrongAlphabetException if the models do not share a consistent alphabet
     */
    private void set(TrainableStatisticalModel[] model, boolean[] optimizeModel, int starts, double[] weights, boolean estimateComponentProbs, double[] componentHyperParams, Algorithm algorithm, double alpha, TerminationCondition tc, Parameterization parametrization, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws IllegalArgumentException, WrongAlphabetException {
        if (starts < 1) {
            throw new IllegalArgumentException("The number of iterations has to be at least 1.");
        }
        this.starts = starts;
        AlphabetContainer abc = model[0].getAlphabetContainer();
        for (int i = 0; i < model.length; ++i) {
            if (i != 0 && !model[i].getAlphabetContainer().checkConsistency(abc)) {
                throw new WrongAlphabetException("The models have to have the same alphabet like the AbstractMixtureTrainSM. Violated at position " + i + ".");
            }
            if (model[i] instanceof AbstractMixtureTrainSM) {
                // silence the output stream of nested mixture models
                ((AbstractMixtureTrainSM)model[i]).setOutputStream(null);
            }
            this.checkLength(i, model[i].getLength());
        }
        if (optimizeModel == null) {
            this.optimizeModel = new boolean[model.length];
            Arrays.fill(this.optimizeModel, true);
        } else {
            if (optimizeModel.length != model.length) {
                throw new IllegalArgumentException("The dimension of the switch whether the individual models should be optimized/adjusted has wrong dimension.");
            }
            this.optimizeModel = new boolean[model.length];
            System.arraycopy(optimizeModel, 0, this.optimizeModel, 0, optimizeModel.length);
        }
        if (weights == null) {
            weights = new double[this.dimension];
            Arrays.fill(weights, 1.0 / (double)this.dimension);
        }
        this.setWeights(weights);
        this.model = model;
        this.alternativeModel = null;
        this.estimateComponentProbs = estimateComponentProbs;
        // true if the effective hyperparameters are all zero (no usable prior for sampling)
        boolean minValueOfUsedHyperParamIsZero;
        if (!estimateComponentProbs || componentHyperParams == null) {
            this.componentHyperParams = new double[this.dimension];
            minValueOfUsedHyperParamIsZero = estimateComponentProbs;
        } else {
            if (componentHyperParams.length != this.dimension) {
                throw new IllegalArgumentException("The dimension of the component assignment hyperparameter is not correct.");
            }
            this.componentHyperParams = new double[this.dimension];
            minValueOfUsedHyperParamIsZero = componentHyperParams[0] == 0.0;
            for (int i = 0; i < this.dimension; ++i) {
                // hyperparameters must be non-negative and either all zero or all positive
                if (componentHyperParams[i] < 0.0 || minValueOfUsedHyperParamIsZero && componentHyperParams[i] > 0.0 || !minValueOfUsedHyperParamIsZero && componentHyperParams[i] == 0.0) {
                    throw new IllegalArgumentException("The " + i + "-th component assignment hyperparameter is not correct.");
                }
                this.componentHyperParams[i] = componentHyperParams[i];
            }
        }
        this.best = Double.NEGATIVE_INFINITY;
        this.compProb = new double[this.dimension];
        this.usedWeights = new double[model.length][][];
        switch (algorithm) {
            case EM: {
                if (parametrization != Parameterization.THETA && parametrization != Parameterization.LAMBDA) {
                    throw new IllegalArgumentException("The type of parametrization is unknown.");
                }
                this.parametrization = parametrization;
                this.setAlpha(alpha);
                if (tc == null) {
                    throw new NullPointerException("The TerminationCondition must not be null.");
                }
                if (!tc.isSimple()) {
                    throw new IllegalArgumentException("The TerminationCondition has to be simple.");
                }
                this.tc = tc;
                break;
            }
            case GIBBS_SAMPLING: {
                if (minValueOfUsedHyperParamIsZero) {
                    throw new IllegalArgumentException("The component hyper parameters have to be set to positive values.");
                }
                if (initialIteration <= 0) {
                    throw new IllegalArgumentException("The given number of initial iterations has to be at least 1.");
                }
                // must precede the ratio check below, otherwise it could never be reached
                if (stationaryIteration <= 0) {
                    throw new IllegalArgumentException("The given number of iterations has to be at least 1.");
                }
                // long arithmetic avoids int overflow of the product
                if ((long)initialIteration * starts > stationaryIteration) {
                    throw new IllegalArgumentException("The given number of initial iterations has to be at most (stationaryIteration/starts).");
                }
                if (burnInTest == null) {
                    throw new IllegalArgumentException("You have to specify a burn in test.");
                }
                this.initialIteration = initialIteration;
                this.stationaryIteration = stationaryIteration;
                this.burnInTest = burnInTest;
                this.checkModelsForGibbsSampling();
                break;
            }
            default: {
                throw new IllegalArgumentException("The type of algorithm is unknown.");
            }
        }
        this.algorithm = algorithm;
    }

    /**
     * Ensures that every component model that is adjusted during training implements
     * {@code GibbsSamplingModel}; fixed components are not checked.
     *
     * @throws IllegalArgumentException naming the first offending component
     */
    protected void checkModelsForGibbsSampling() {
        int index = 0;
        for (TrainableStatisticalModel component : this.model) {
            boolean adjusted = this.optimizeModel[index];
            if (adjusted && !(component instanceof GibbsSamplingModel)) {
                throw new IllegalArgumentException("The model for component " + index + " doesn't implement the interface GibbsSamplingModel!");
            }
            index++;
        }
    }

    /**
     * Checks that a component's length is compatible with the length of this mixture.
     * A component length of 0 is always accepted.
     *
     * @param index the position of the component (used in the error message)
     * @param l the length reported by the component
     * @throws IllegalArgumentException if {@code l} is non-zero and differs from this length
     */
    protected void checkLength(int index, int l) {
        boolean anyLength = l == 0;
        if (anyLength || l == this.length) {
            return;
        }
        throw new IllegalArgumentException("The models have to use the same length like the AbstractMixtureTrainSM. Violated at position " + index + ".");
    }

    /**
     * Emits {@code n} sequences from this trained mixture.
     * <p>
     * For EM the current parameter set is used. For Gibbs sampling, each sequence is assigned
     * uniformly at random to one of the stored post-burn-in parameter sets of all starts. The
     * stored sets are then replayed in order, and each set emits the sequences assigned to it.
     *
     * @param n the number of sequences to emit
     * @param lengths no length, one length for all sequences, or one length per sequence
     * @return the emitted sequences
     * @throws NotTrainedException if the model has not been initialized
     * @throws IllegalArgumentException (Gibbs sampling) if more than one length is given but not
     *         exactly {@code n}; also if the algorithm is unknown
     * @throws IllegalStateException (Gibbs sampling) if {@code n > 0} but no parameter set
     *         remains after the burn-in
     * @throws Exception if emitting or reading stored parameter sets fails
     */
    @Override
    public DataSet emitDataSet(int n, int ... lengths) throws Exception {
        if (!this.isInitialized()) {
            throw new NotTrainedException();
        }
        Sequence[] seqs;
        switch (this.algorithm) {
            case EM: {
                seqs = this.emitDataSetUsingCurrentParameterSet(n, lengths);
                break;
            }
            case GIBBS_SAMPLING: {
                boolean perSequenceLengths = lengths != null && lengths.length > 1;
                if (perSequenceLengths && lengths.length != n) {
                    throw new IllegalArgumentException("The number of lengths (" + lengths.length + ") has to be 0, 1 or " + n + ".");
                }
                int burnIn = this.burnInTest.getLengthOfBurnIn();
                // anz[i]: cumulative number of post-burn-in parameter sets of starts 0..i
                int[] anz = new int[this.starts];
                int all = 0;
                for (int i = 0; i < this.starts; ++i) {
                    all += Math.max(0, this.counter[i] - burnIn);
                    anz[i] = all;
                }
                if (n > 0 && all == 0) {
                    // otherwise r.nextInt(0) would fail with an obscure message
                    throw new IllegalStateException("There is no sampled parameter set after the burn-in of " + burnIn + " iterations.");
                }
                // no[k]: number of sequences emitted from the k-th post-burn-in parameter set
                int[] no = new int[all];
                for (int i = 0; i < n; ++i) {
                    no[r.nextInt(all)]++;
                }
                seqs = new Sequence[n];
                int paramSet = 0;
                int j = 0;
                for (int i = 0; i < this.starts; ++i) {
                    this.parseParameterSet(i, burnIn);
                    for (; paramSet < anz[i]; ++paramSet) {
                        int count = no[paramSet];
                        if (count > 0) {
                            int[] len;
                            if (perSequenceLengths) {
                                len = new int[count];
                                System.arraycopy(lengths, j, len, 0, count);
                            } else {
                                len = lengths;
                            }
                            Sequence[] help = this.emitDataSetUsingCurrentParameterSet(count, len);
                            System.arraycopy(help, 0, seqs, j, count);
                            j += count;
                        }
                        this.parseNextParameterSet();
                    }
                }
                break;
            }
            default: {
                throw new IllegalArgumentException("The type of algorithm is unknown.");
            }
        }
        return new DataSet("sampled from " + this.getInstanceName(), seqs);
    }

    /**
     * Emits sequences using the parameter values currently held by this model.
     * {@code emitDataSet} calls it once for EM and, for Gibbs sampling, once per replayed
     * parameter set that has sequences assigned to it.
     *
     * @param var1 the number of sequences to emit
     * @param var2 the sequence lengths as passed to {@code emitDataSet}; for Gibbs sampling with
     *        per-sequence lengths, the slice belonging to the sequences of this call
     * @return the emitted sequences
     * @throws Exception if emitting fails
     */
    protected abstract Sequence[] emitDataSetUsingCurrentParameterSet(int var1, int ... var2) throws Exception;

    /**
     * Positions every adjusted component model, and (if estimated) the component weights, at
     * the stored parameter set of iteration {@code burnInIteration} of run {@code sampling}.
     * All parts are always asked to parse, even after one of them has failed.
     *
     * @param sampling the index of the sampling run
     * @param burnInIteration the iteration to position at
     * @return {@code true} if every part found its parameter set
     * @throws Exception if reading a parameter set fails
     */
    protected boolean parseParameterSet(int sampling, int burnInIteration) throws Exception {
        boolean allParsed = true;
        int idx = 0;
        while (idx < this.model.length) {
            if (this.optimizeModel[idx]) {
                SamplingComponent component = (SamplingComponent)((Object)this.model[idx]);
                // call first, then combine, so no component is skipped
                allParsed = component.parseParameterSet(sampling, burnInIteration) && allParsed;
            }
            idx++;
        }
        if (this.estimateComponentProbs) {
            allParsed = this.parseComponentParameterSet(sampling, burnInIteration) && allParsed;
        }
        return allParsed;
    }

    /**
     * (Re)opens the component-weight file of run {@code sampling} and scans forward to the
     * line whose leading tab-separated field equals {@code burnInIteration}. If found, the
     * weights on that line are loaded. The reader is left open so that
     * {@code parseNextParameterSet()} can continue from there.
     *
     * @param sampling the index of the sampling run
     * @param burnInIteration the iteration to search for
     * @return {@code true} if the iteration was found
     * @throws IOException if the file cannot be opened or read
     */
    private boolean parseComponentParameterSet(int sampling, int burnInIteration) throws IOException {
        if (this.filereader != null) {
            this.filereader.close();
        }
        this.filereader = new BufferedReader(new FileReader(this.file[sampling]));
        for (String line = this.filereader.readLine(); line != null; line = this.filereader.readLine()) {
            int firstTab = line.indexOf("\t");
            int iteration = Integer.parseInt(line.substring(0, firstTab));
            if (iteration == burnInIteration) {
                this.parse(line);
                return true;
            }
        }
        return false;
    }

    /**
     * Loads component weights (and their logarithms) from one tab-separated line of a
     * parameter file. Field 0 holds the iteration index; the weights follow it.
     * NOTE(review): reads {@code model.length} values although {@code weights} is sized by
     * {@code dimension}; confirm these agree for every subclass that samples weights.
     *
     * @param str the line to parse
     */
    private void parse(String str) {
        String[] fields = str.split("\t");
        for (int comp = 0; comp < this.model.length; ++comp) {
            double value = Double.parseDouble(fields[comp + 1]);
            this.weights[comp] = value;
            this.logWeights[comp] = Math.log(value);
        }
    }

    /**
     * Advances to the next stored parameter set. If component weights are estimated, their
     * next line is read first. The adjusted component models are then advanced in order,
     * stopping at the first one that has no further set.
     *
     * @return {@code false} if any part has no further parameter set
     * @throws Exception if reading fails
     */
    protected boolean parseNextParameterSet() throws Exception {
        if (this.estimateComponentProbs) {
            String line = this.filereader.readLine();
            if (line == null) {
                return false;
            }
            this.parse(line);
        }
        for (int i = 0; i < this.model.length; ++i) {
            if (this.optimizeModel[i] && !((SamplingComponent)((Object)this.model[i])).parseNextParameterSet()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Prepares the parameter files for {@code starts} sampling runs and initializes the
     * adjusted component models for sampling.
     * <p>
     * If a file array of the right size already exists, the existing files are truncated and
     * the counters reset. Otherwise the old files are deleted and new, empty arrays are
     * allocated.
     *
     * @param starts the number of sampling runs
     * @throws IOException if an existing file cannot be truncated
     */
    protected void initModelForSampling(int starts) throws IOException {
        boolean reuseFiles = this.file != null && this.file.length == starts;
        if (!reuseFiles) {
            this.deleteParameterFiles();
            this.file = new File[starts];
            this.counter = new int[starts];
        } else {
            for (int run = 0; run < starts; ++run) {
                if (this.file[run] != null) {
                    // opening in non-append mode truncates the file to zero length
                    new FileOutputStream(this.file[run]).close();
                }
                this.counter[run] = 0;
            }
        }
        for (int i = 0; i < this.model.length; ++i) {
            if (this.optimizeModel[i]) {
                ((SamplingComponent)((Object)this.model[i])).initForSampling(starts);
            }
        }
    }

    /**
     * Prepares run {@code sampling} for appending further sampled parameter sets.
     * <p>
     * If the run has no file yet, a temporary file is created. Otherwise the component weights
     * stored for iteration {@code counter[sampling] - 1} are loaded (if present) and the reader
     * is closed. An appending writer is then opened and the adjusted component models are
     * extended as well.
     *
     * @param sampling the index of the sampling run
     * @throws Exception if file handling or extending a component fails
     */
    protected void extendSampling(int sampling) throws Exception {
        File target = this.file[sampling];
        if (target != null) {
            this.parseComponentParameterSet(sampling, this.counter[sampling] - 1);
            this.filereader.close();
            this.filereader = null;
        } else {
            target = File.createTempFile("pi-", ".dat", null);
            this.file[sampling] = target;
        }
        this.filewriter = new BufferedWriter(new FileWriter(target, true));
        for (int i = 0; i < this.model.length; ++i) {
            if (this.optimizeModel[i]) {
                ((SamplingComponent)((Object)this.model[i])).extendSampling(sampling, true);
            }
        }
        this.samplingIndex = sampling;
    }

    /**
     * Notifies all adjusted component models that sampling has stopped and closes the
     * parameter file writer.
     * <p>
     * The writer is closed even if a component throws, so the parameter file is not left open.
     * Calling this method again after the writer has been closed is a no-op for the writer.
     *
     * @throws IOException if a component or the writer fails to close
     */
    protected void samplingStopped() throws IOException {
        try {
            for (int i = 0; i < this.model.length; ++i) {
                if (this.optimizeModel[i]) {
                    ((SamplingComponent)((Object)this.model[i])).samplingStopped();
                }
            }
        }
        finally {
            BufferedWriter writer = this.filewriter;
            // clear the field first so it never refers to a closed writer
            this.filewriter = null;
            if (writer != null) {
                writer.close();
            }
        }
    }

    /**
     * Tells whether this mixture is currently in sampling mode: every adjusted component model
     * must report sampling mode, and the parameter file writer must be open.
     *
     * @return {@code true} if in sampling mode
     */
    protected boolean isInSamplingMode() {
        for (int i = 0; i < this.model.length; ++i) {
            if (this.optimizeModel[i] && !((SamplingComponent)((Object)this.model[i])).isInSamplingMode()) {
                return false;
            }
        }
        return this.filewriter != null;
    }

    /**
     * Releases references, closes any open parameter file reader or writer, deletes the
     * temporary parameter files and delegates to the superclass.
     * <p>
     * Nested {@code finally} blocks ensure that a failure while closing one stream does not
     * skip closing the other, deleting the files, or calling {@code super.finalize()}.
     *
     * @throws Throwable if closing a stream or the superclass finalizer fails
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            this.alternativeModel = null;
            this.model = null;
            this.compProb = null;
            this.componentHyperParams = null;
            this.logWeights = null;
            this.weights = null;
            this.sample = null;
            this.counter = null;
            this.optimizeModel = null;
            this.usedWeights = null;
            try {
                if (this.filereader != null) {
                    this.filereader.close();
                }
            }
            finally {
                try {
                    if (this.filewriter != null) {
                        this.filewriter.close();
                    }
                }
                finally {
                    // delete only after both streams had a chance to close
                    this.deleteParameterFiles();
                }
            }
        }
        finally {
            super.finalize();
        }
    }

    /**
     * Deletes every existing temporary parameter file. The result of {@code File.delete()} is
     * ignored, and the array itself is left in place.
     */
    private void deleteParameterFiles() {
        if (this.file == null) {
            return;
        }
        for (File f : this.file) {
            if (f != null) {
                f.delete();
            }
        }
    }

    /**
     * Draws an index from the discrete distribution {@code w}, considering only positions from
     * {@code start} on. One uniform number from the shared generator {@code r} is consumed.
     * If the remaining mass runs out (for example because of rounding), the last index is
     * returned.
     *
     * @param w the probabilities
     * @param start the first index that may be drawn
     * @return the drawn index
     */
    public static final int draw(double[] w, int start) {
        double remaining = r.nextDouble();
        int idx = start;
        while (idx < w.length && remaining > w[idx]) {
            remaining -= w[idx];
            idx++;
        }
        return idx == w.length ? idx - 1 : idx;
    }

    /**
     * Returns the index of the largest value of {@code w} in the half-open range
     * {@code [start, end)}. Ties resolve to the earliest index. If the range is empty,
     * {@code start} is returned.
     *
     * @param w the values
     * @param start the first index (inclusive)
     * @param end the last index (exclusive)
     * @return the index of the maximum
     */
    public static final int max(double[] w, int start, int end) {
        int best = start;
        int i = start + 1;
        while (i < end) {
            if (w[i] > w[best]) {
                best = i;
            }
            i++;
        }
        return best;
    }

    /**
     * The training algorithms supported by this mixture.
     */
    public static enum Algorithm {
        /** Iterative training that stops according to a {@code TerminationCondition}. */
        EM,
        /** Gibbs sampling; sampled parameter sets are stored in temporary files. */
        GIBBS_SAMPLING
    }

    /**
     * The parameterization used by the EM algorithm. Each constant carries a fixed count value
     * that is available through {@code getCount()}.
     */
    public static enum Parameterization {
        THETA(-1.0),
        LAMBDA(0.0);

        // immutable per-constant value; never reassigned after construction
        private final double count;

        private Parameterization(double count) {
            this.count = count;
        }

        /**
         * @return the count value associated with this parameterization
         */
        double getCount() {
            return this.count;
        }
    }
}

