/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.LimitedMedianStartDistance;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.termination.AbstractTerminationCondition;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.CompositeLogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.UniformDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.UniformHomogeneousDiffSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel;
import de.jstacs.utils.SafeOutputStream;
import java.io.OutputStream;

public class DifferentiableStatisticalModelWrapperTrainSM
extends AbstractTrainableStatisticalModel {
    private SafeOutputStream out;
    protected DifferentiableStatisticalModel nsf;
    private double logNorm;
    private double lineps;
    private double startD;
    private AbstractTerminationCondition tc;
    private byte algo;
    private int threads;
    private static final String XML_TAG = "DifferentiableStatisticalModelWrapperTrainSM";

    public DifferentiableStatisticalModelWrapperTrainSM(DifferentiableStatisticalModel nsf, int threads, byte algo, AbstractTerminationCondition tc, double lineps, double startD) throws CloneNotSupportedException {
        super(nsf.getAlphabetContainer(), nsf.getLength());
        if (threads < 1) {
            throw new IllegalArgumentException("The number of threads has to be positive.");
        }
        this.threads = threads;
        this.tc = tc.clone();
        if (lineps < 0.0) {
            throw new IllegalArgumentException("The value of lineps has to be non-negative.");
        }
        this.lineps = lineps;
        if (startD <= 0.0) {
            throw new IllegalArgumentException("The value of startD has to be positive.");
        }
        this.startD = startD;
        this.algo = algo;
        this.nsf = (DifferentiableStatisticalModel)nsf.clone();
        this.logNorm = this.isInitialized() ? nsf.getLogNormalizationConstant() : Double.NEGATIVE_INFINITY;
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    public DifferentiableStatisticalModelWrapperTrainSM(StringBuffer stringBuff) throws NonParsableException {
        super(stringBuff);
    }

    @Override
    public DifferentiableStatisticalModelWrapperTrainSM clone() throws CloneNotSupportedException {
        DifferentiableStatisticalModelWrapperTrainSM clone = (DifferentiableStatisticalModelWrapperTrainSM)super.clone();
        clone.nsf = (DifferentiableStatisticalModel)this.nsf.clone();
        clone.tc = this.tc.clone();
        clone.setOutputStream(this.out.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
        return clone;
    }

    @Override
    public void train(DataSet data, double[] weights) throws Exception {
        if (!data.getAlphabetContainer().checkConsistency(this.alphabets)) {
            throw new WrongAlphabetException("The AlphabetConatainer of the sample and the model do not match.");
        }
        if (this.length != 0 && this.length != data.getElementLength()) {
            throw new WrongLengthException("The length of the elements of the sample is not suitable for the model.");
        }
        if (this.nsf instanceof IndependentProductDiffSM) {
            IndependentProductDiffSM ipsf = (IndependentProductDiffSM)this.nsf;
            DifferentiableStatisticalModel[] nsfs = ArrayHandler.cast(DifferentiableStatisticalModel.class, ipsf.getFunctions());
            DataSet[] part = new DataSet[1];
            DataSet[] packedData = new DataSet[]{data};
            double[][] packedWeights = new double[][]{weights};
            int i = 0;
            while (i < nsfs.length) {
                int a = ipsf.extractSequenceParts(i, packedData, part);
                double[][] partWeights = ipsf.extractWeights(a, packedWeights);
                nsfs[i] = this.train(part[0], partWeights[0], nsfs[i]);
                ++i;
            }
            this.nsf = new IndependentProductDiffSM(ipsf.getESS(), true, nsfs, ipsf.getIndices(), ipsf.getPartialLengths(), ipsf.getReverseSwitches());
        } else {
            this.nsf = this.train(data, weights, this.nsf);
        }
    }

    private DifferentiableStatisticalModel train(DataSet data, double[] weights, DifferentiableStatisticalModel nsf) throws Exception {
        if (!(nsf instanceof UniformDiffSM) && !(nsf instanceof UniformHomogeneousDiffSM)) {
            DataSet.WeightedDataSetFactory wsf = new DataSet.WeightedDataSetFactory(DataSet.WeightedDataSetFactory.SortOperation.NO_SORT, data, weights);
            DataSet small = wsf.getDataSet();
            double[] smallWeights = wsf.getWeights();
            DifferentiableSequenceScore best = null;
            double max = Double.NEGATIVE_INFINITY;
            double fac = data.getNumberOfElements();
            double ess = nsf.getESS();
            fac = fac / (ess + fac) * (ess == 0.0 ? 1.0 : 2.0);
            DifferentiableSequenceScore[] score = new DifferentiableStatisticalModel[]{(DifferentiableStatisticalModel)nsf.clone()};
            CompositeLogPrior prior = new CompositeLogPrior();
            double[] beta = LearningPrinciple.getBeta(ess == 0.0 ? LearningPrinciple.ML : LearningPrinciple.MAP);
            LogGenDisMixFunction f = new LogGenDisMixFunction(this.threads, score, new DataSet[]{small}, new double[][]{smallWeights}, prior, beta, true, false);
            NegativeDifferentiableFunction minusF = new NegativeDifferentiableFunction(f);
            LimitedMedianStartDistance sd = new LimitedMedianStartDistance(5, this.startD * fac);
            int i = 0;
            while (i < nsf.getNumberOfRecommendedStarts()) {
                this.out.writeln("start: " + i);
                score[0].initializeFunction(0, false, new DataSet[]{small}, new double[][]{smallWeights});
                f.reset(score);
                double[] params = f.getParameters(OptimizableFunction.KindOfParameter.PLUGIN);
                sd.reset();
                Optimizer.optimize(this.algo, minusF, params, this.tc, this.lineps * fac, sd, this.out);
                double current = f.evaluateFunction(params);
                if (current > max) {
                    best = score[0];
                    max = current;
                }
                score[0] = (DifferentiableStatisticalModel)nsf.clone();
                ++i;
            }
            this.out.writeln("best: " + max);
            nsf = best;
            this.logNorm = nsf.getLogNormalizationConstant();
            f.stopThreads();
            System.gc();
        }
        return nsf;
    }

    @Override
    public double getLogProbFor(Sequence sequence, int startpos, int endpos) throws NotTrainedException, Exception {
        if (!this.isInitialized()) {
            throw new NotTrainedException();
        }
        if (!sequence.getAlphabetContainer().checkConsistency(this.alphabets)) {
            throw new WrongAlphabetException("The AlphabetContainer of the sequence and the model do not match.");
        }
        if (startpos < 0) {
            throw new IllegalArgumentException("Check start position.");
        }
        if (endpos + 1 < startpos || endpos >= sequence.getLength()) {
            throw new IllegalArgumentException("Check end position.");
        }
        if (this.length != 0 && this.length != endpos - startpos + 1) {
            throw new WrongLengthException("Check length of the sequence.");
        }
        return this.nsf.getLogScoreFor(sequence, startpos) - this.logNorm;
    }

    @Override
    public double getLogPriorTerm() throws Exception {
        return this.nsf.getLogPriorTerm() - this.nsf.getESS() * this.logNorm;
    }

    @Override
    public String getInstanceName() {
        return "model using " + this.nsf.getInstanceName();
    }

    @Override
    public boolean isInitialized() {
        return this.nsf.isInitialized();
    }

    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    @Override
    public String toString() {
        return this.nsf.toString();
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer rep = XMLParser.extractForTag(xml, XML_TAG);
        this.nsf = XMLParser.extractObjectForTags(rep, "DifferentiableStatisticalModel", DifferentiableStatisticalModel.class);
        this.threads = XMLParser.extractObjectForTags(rep, "threads", Integer.TYPE);
        this.algo = XMLParser.extractObjectForTags(rep, "algorithm", Byte.TYPE);
        if (XMLParser.hasTag(rep, "tc", null, null)) {
            this.tc = (AbstractTerminationCondition)XMLParser.extractObjectForTags(rep, "tc");
        } else {
            try {
                this.tc = new SmallDifferenceOfFunctionEvaluationsCondition(XMLParser.extractObjectForTags(rep, "eps", Double.TYPE));
            }
            catch (Exception e) {
                NonParsableException n = new NonParsableException(e.getMessage());
                throw n;
            }
        }
        this.lineps = XMLParser.extractObjectForTags(rep, "lineps", Double.TYPE);
        this.startD = XMLParser.extractObjectForTags(rep, "startDistance", Double.TYPE);
        this.logNorm = this.isInitialized() ? this.nsf.getLogNormalizationConstant() : Double.NEGATIVE_INFINITY;
        this.alphabets = this.nsf.getAlphabetContainer();
        this.length = this.nsf.getLength();
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendObjectWithTags(xml, this.nsf, "DifferentiableStatisticalModel");
        XMLParser.appendObjectWithTags(xml, this.threads, "threads");
        XMLParser.appendObjectWithTags(xml, this.algo, "algorithm");
        XMLParser.appendObjectWithTags(xml, this.tc, "tc");
        XMLParser.appendObjectWithTags(xml, this.lineps, "lineps");
        XMLParser.appendObjectWithTags(xml, this.startD, "startDistance");
        XMLParser.addTags(xml, XML_TAG);
        return xml;
    }

    public final void setOutputStream(OutputStream o) {
        this.out = SafeOutputStream.getSafeOutputStream(o);
    }

    public DifferentiableStatisticalModel getFunction() throws CloneNotSupportedException {
        return (DifferentiableStatisticalModel)this.nsf.clone();
    }
}

