/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix;

import de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.OneDataSetLogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.parameters.SimpleParameter;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import java.lang.reflect.InvocationTargetException;
import java.util.Arrays;

/**
 * A {@link ScoreClassifier} whose objective function is built by {@link #getFunction(DataSet[], double[][])}.
 * It is a {@link LogGenDisMixFunction}, or a {@link OneDataSetLogGenDisMixFunction} when exactly one data set
 * is given, parameterized by the weight vector {@link #beta} and a {@link LogPrior}.
 *
 * <p>The convenience constructor taking {@code genBeta}, {@code disBeta} and {@code priorBeta} stores them as
 * {@code {disBeta, genBeta, priorBeta}}. XML parsing only sets up the prior when {@code beta[2] > 0}.
 */
public class GenDisMixClassifier
extends ScoreClassifier {
    /** The prior; {@link #setPrior(LogPrior)} replaces {@code null} with {@link DoesNothingLogPrior#defaultInstance}. */
    protected LogPrior prior;
    /** Not read or written in this class; presumably kept for use by subclasses. */
    protected LogGenDisMixFunction function;
    /** The weights of the objective, as returned by {@code LearningPrinciple.checkWeights}. */
    protected double[] beta;
    /** The XML tag under which this classifier is serialized. */
    private static final String XML_TAG = "gendismix-classifier";

    /**
     * Creates a classifier from arbitrary differentiable sequence scores.
     *
     * <p>NOTE: this constructor calls the overridable methods {@link #setWeights(double...)} and
     * {@link #setPrior(LogPrior)}; an overriding subclass sees its own fields still uninitialized.
     *
     * @param params the parameter set (also read for threads, normalization and free-parameter usage)
     * @param prior the prior; {@code null} is replaced by {@link DoesNothingLogPrior#defaultInstance}
     * @param lastScore the last score, forwarded to the {@link ScoreClassifier} constructor
     * @param beta the weights, validated by {@code LearningPrinciple.checkWeights}
     * @param score the scores forwarded to the {@link ScoreClassifier} constructor
     * @throws CloneNotSupportedException if the {@link ScoreClassifier} constructor throws it
     */
    protected GenDisMixClassifier(GenDisMixClassifierParameterSet params, LogPrior prior, double lastScore, double[] beta, DifferentiableSequenceScore ... score) throws CloneNotSupportedException {
        super(params, lastScore, score);
        this.setWeights(beta);
        this.setPrior(prior);
    }

    /**
     * Creates a classifier from differentiable statistical models.
     *
     * @param params the parameter set
     * @param prior the prior; {@code null} means no prior
     * @param lastScore the last score, forwarded to the {@link ScoreClassifier} constructor
     * @param beta the weights, validated by {@code LearningPrinciple.checkWeights}
     * @param score the models forwarded to the {@link ScoreClassifier} constructor
     * @throws CloneNotSupportedException if the {@link ScoreClassifier} constructor throws it
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet params, LogPrior prior, double lastScore, double[] beta, DifferentiableStatisticalModel ... score) throws CloneNotSupportedException {
        this(params, prior, lastScore, beta, (DifferentiableSequenceScore[])score);
    }

    /**
     * Creates a classifier with {@code lastScore = Double.NaN}.
     *
     * @param params the parameter set
     * @param prior the prior; {@code null} means no prior
     * @param beta the weights, validated by {@code LearningPrinciple.checkWeights}
     * @param score the models forwarded to the {@link ScoreClassifier} constructor
     * @throws CloneNotSupportedException if the {@link ScoreClassifier} constructor throws it
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet params, LogPrior prior, double[] beta, DifferentiableStatisticalModel ... score) throws CloneNotSupportedException {
        this(params, prior, Double.NaN, beta, score);
    }

    /**
     * Creates a classifier from individual weights.
     *
     * <p>The weights are stored in the order {@code {disBeta, genBeta, priorBeta}}, which differs from the
     * parameter order.
     *
     * @param params the parameter set
     * @param prior the prior; {@code null} means no prior
     * @param genBeta the weight stored at index 1
     * @param disBeta the weight stored at index 0
     * @param priorBeta the weight stored at index 2
     * @param score the models forwarded to the {@link ScoreClassifier} constructor
     * @throws CloneNotSupportedException if the {@link ScoreClassifier} constructor throws it
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet params, LogPrior prior, double genBeta, double disBeta, double priorBeta, DifferentiableStatisticalModel ... score) throws CloneNotSupportedException {
        this(params, prior, Double.NaN, new double[]{disBeta, genBeta, priorBeta}, score);
    }

    /**
     * Creates a classifier whose weights are given by {@code LearningPrinciple.getBeta(key)}.
     *
     * @param params the parameter set
     * @param prior the prior; {@code null} means no prior
     * @param key the learning principle selecting the weights
     * @param score the models forwarded to the {@link ScoreClassifier} constructor
     * @throws CloneNotSupportedException if the {@link ScoreClassifier} constructor throws it
     */
    public GenDisMixClassifier(GenDisMixClassifierParameterSet params, LogPrior prior, LearningPrinciple key, DifferentiableStatisticalModel ... score) throws CloneNotSupportedException {
        this(params, prior, Double.NaN, LearningPrinciple.getBeta(key), score);
    }

    /**
     * Restores a classifier from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public GenDisMixClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a clone of this classifier.
     *
     * <p>The clone gets its own prior (via {@code prior.getNewInstance()}) and its own copy of the weights.
     *
     * @return the clone
     * @throws CloneNotSupportedException if cloning fails
     */
    @Override
    public GenDisMixClassifier clone() throws CloneNotSupportedException {
        GenDisMixClassifier clone = (GenDisMixClassifier)super.clone();
        clone.prior = this.prior.getNewInstance();
        clone.beta = this.beta.clone();
        return clone;
    }

    /**
     * Builds the objective function for the given data and weights.
     *
     * @param data the data sets; if more than one is given a {@link LogGenDisMixFunction} is built,
     *        otherwise a {@link OneDataSetLogGenDisMixFunction} over {@code data[0]}
     * @param weights the sequence weights, forwarded to the function
     * @return the objective function
     * @throws Exception if the function cannot be created
     */
    @Override
    protected LogGenDisMixFunction getFunction(DataSet[] data, double[][] weights) throws Exception {
        GenDisMixClassifierParameterSet p = (GenDisMixClassifierParameterSet)this.params;
        if (data.length > 1) {
            return new LogGenDisMixFunction(p.getNumberOfThreads(), this.score, data, weights, this.prior, this.beta, p.shouldBeNormalized(), p.useOnlyFreeParameter());
        }
        return new OneDataSetLogGenDisMixFunction(p.getNumberOfThreads(), this.score, data[0], weights, this.prior, this.beta, p.shouldBeNormalized(), p.useOnlyFreeParameter());
    }

    /**
     * Sets the prior and marks the classifier as not optimized.
     *
     * @param prior the new prior; {@code null} is replaced by {@link DoesNothingLogPrior#defaultInstance}
     */
    public void setPrior(LogPrior prior) {
        this.prior = prior != null ? prior : DoesNothingLogPrior.defaultInstance;
        this.hasBeenOptimized = false;
    }

    /**
     * Sets the weights and marks the classifier as not optimized.
     *
     * @param beta the new weights, validated by {@code LearningPrinciple.checkWeights}
     * @throws IllegalArgumentException if the weights are rejected by the validation
     */
    public void setWeights(double ... beta) throws IllegalArgumentException {
        this.beta = LearningPrinciple.checkWeights(beta);
        this.hasBeenOptimized = false;
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Appends the weights (tag {@code beta}) and, unless the prior is a {@link DoesNothingLogPrior}, a
     * {@code prior} element holding the prior's class and its {@code toXML()} output.
     *
     * @return the extended XML of the superclass
     */
    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(xml, this.beta, "beta");
        if (!(this.prior instanceof DoesNothingLogPrior)) {
            StringBuffer pr = new StringBuffer(1000);
            pr.append("<prior>\n");
            XMLParser.appendObjectWithTags(pr, this.prior.getClass(), "class");
            pr.append(this.prior.toXML());
            pr.append("\t</prior>\n");
            xml.append(pr);
        }
        return xml;
    }

    /**
     * Restores the weights and the prior written by {@link #getFurtherClassifierInfos()}.
     *
     * <p>The prior is instantiated through its public {@code (StringBuffer)} constructor. If there is no
     * {@code prior} tag, {@link DoesNothingLogPrior#defaultInstance} is used. If {@code beta[2] > 0}, the
     * prior is then set up with the scores of this classifier.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML or the prior cannot be parsed
     */
    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(xml);
        this.beta = LearningPrinciple.checkWeights(XMLParser.extractObjectForTags(xml, "beta", double[].class));
        StringBuffer pr = XMLParser.extractForTag(xml, "prior");
        if (pr != null) {
            Class<?> clazz = XMLParser.extractObjectForTags(pr, "class", Class.class);
            try {
                this.prior = (LogPrior)clazz.getConstructor(StringBuffer.class).newInstance(pr);
            }
            catch (NoSuchMethodException e) {
                throw toNonParsable("You must provide a constructor " + clazz.getSimpleName() + "(StringBuffer).", e);
            }
            catch (InvocationTargetException e) {
                // The prior's own constructor failed. The reflective wrapper has a null message,
                // so report the exception it wraps instead.
                Throwable target = e.getCause() != null ? e.getCause() : e;
                throw toNonParsable("problem at " + clazz.getSimpleName() + ": " + target.getMessage(), target);
            }
            catch (Exception e) {
                throw toNonParsable("problem at " + clazz.getSimpleName() + ": " + e.getMessage(), e);
            }
        } else {
            this.prior = DoesNothingLogPrior.defaultInstance;
        }
        if (this.beta[2] > 0.0) {
            try {
                this.prior.set(((GenDisMixClassifierParameterSet)this.params).useOnlyFreeParameter(), this.score);
            }
            catch (Exception e) {
                throw toNonParsable("problem when setting the kind of parameter: " + e.getMessage(), e);
            }
        }
    }

    /**
     * Wraps {@code cause} in a {@link NonParsableException} carrying {@code message}. The cause's stack
     * trace is copied and the cause is chained where possible.
     */
    private static NonParsableException toNonParsable(String message, Throwable cause) {
        NonParsableException n = new NonParsableException(message);
        n.setStackTrace(cause.getStackTrace());
        try {
            n.initCause(cause);
        }
        catch (IllegalStateException ignored) {
            // the exception's constructor already fixed its cause; the copied stack trace still points at the origin
        }
        return n;
    }

    /**
     * Creates one classifier for every combination of candidate models.
     *
     * <p>{@code functions[i]} holds the candidates for position {@code i}. The combinations are enumerated
     * with the index of position 0 changing fastest. All classifiers receive the same {@code params},
     * {@code prior} and {@code weights} references.
     *
     * @param params the parameter set for every classifier
     * @param prior the prior for every classifier; {@code null} means no prior
     * @param weights the weights for every classifier
     * @param functions the candidate models per position
     * @return the classifiers; empty if any position has no candidates
     * @throws CloneNotSupportedException if a classifier cannot be created
     */
    public static GenDisMixClassifier[] create(GenDisMixClassifierParameterSet params, LogPrior prior, double[] weights, DifferentiableStatisticalModel[] ... functions) throws CloneNotSupportedException {
        int combinations = 1;
        for (DifferentiableStatisticalModel[] candidates : functions) {
            combinations *= candidates.length;
        }
        GenDisMixClassifier[] erg = new GenDisMixClassifier[combinations];
        int[] current = new int[functions.length];
        for (int c = 0; c < erg.length; ++c) {
            // fresh array per classifier so a later iteration never mutates an array already handed out
            DifferentiableStatisticalModel[] sf = new DifferentiableStatisticalModel[functions.length];
            for (int i = 0; i < sf.length; ++i) {
                sf[i] = functions[i][current[i]];
            }
            erg[c] = new GenDisMixClassifier(params, prior, weights, sf);
            // advance like an odometer (position 0 fastest); stays in bounds even for empty 'functions'
            for (int i = 0; i < current.length; ++i) {
                if (++current[i] < functions[i].length) {
                    break;
                }
                current[i] = 0;
            }
        }
        return erg;
    }

    /**
     * Returns the superclass instance name extended by the weights and, unless the prior is
     * {@link DoesNothingLogPrior#defaultInstance}, the prior's instance name.
     */
    @Override
    public String getInstanceName() {
        return super.getInstanceName() + " with weights=" + Arrays.toString(this.beta) + (this.prior == null || this.prior == DoesNothingLogPrior.defaultInstance ? "" : " and with " + this.prior.getInstanceName());
    }

    /**
     * Returns the number of threads stored in the parameter set.
     *
     * @return the number of threads
     */
    public int getNumberOfThreads() {
        return ((GenDisMixClassifierParameterSet)this.params).getNumberOfThreads();
    }

    /**
     * Returns a textual listing of every score followed by the class weights.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(this.score.length * 5000);
        for (int i = 0; i < this.score.length; ++i) {
            sb.append("function ").append(i);
            sb.append('\n').append(this.score[i].toString()).append('\n');
        }
        sb.append("class weights: ");
        for (int i = 0; i < this.getNumberOfClasses(); ++i) {
            sb.append(this.getClassWeight(i)).append(' ');
        }
        sb.append('\n');
        return sb.toString();
    }

    /**
     * Sets the number of threads in the parameter set.
     *
     * @param threads the number of threads
     * @throws SimpleParameter.IllegalValueException if the parameter set rejects the value
     */
    public void setNumberOfThreads(int threads) throws SimpleParameter.IllegalValueException {
        ((GenDisMixClassifierParameterSet)this.params).setNumberOfThreads(threads);
    }
}

