/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix;

import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.Normalisation;
import java.util.Arrays;

/**
 * A {@link LogGenDisMixFunction} in which all classes share one single {@link DataSet}.
 * <p>
 * Instead of one data set per class, every class {@code c} works on the same sequences
 * {@code data[0]}. Class membership is expressed only through the weights:
 * {@code weights[c][j]} is the weight of sequence {@code j} of the shared data set for class {@code c}.
 */
public class OneDataSetLogGenDisMixFunction
extends LogGenDisMixFunction {

    /**
     * Creates a new function working on a single data set shared by all classes.
     *
     * @param threads    the number of threads, passed on to the superclass
     * @param score      the class-specific scoring functions, passed on to the superclass
     * @param data       the single data set shared by all classes
     * @param weights    the per-class weights; {@code weights[c][j]} weights sequence {@code j} of
     *                   {@code data} for class {@code c}
     * @param prior      the prior, passed on to the superclass
     * @param beta       the weights of the learning principles, passed on to the superclass
     * @param norm       passed on to the superclass
     * @param freeParams passed on to the superclass
     * @throws IllegalArgumentException if the superclass constructor rejects the arguments
     */
    public OneDataSetLogGenDisMixFunction(int threads, DifferentiableSequenceScore[] score, DataSet data, double[][] weights, LogPrior prior, double[] beta, boolean norm, boolean freeParams) throws IllegalArgumentException {
        super(threads, score, new DataSet[]{data}, weights, prior, beta, norm, freeParams);
    }

    /**
     * Sets the (single) data set and the per-class weights and recomputes the weight sums.
     * <p>
     * All arguments are validated before any field is modified, so a rejected call leaves
     * {@code data}, {@code weights} and {@code sum} untouched. After a successful call
     * {@code sum[c]} holds the total weight of class {@code c} and {@code sum[cl]} the total
     * weight over all classes.
     *
     * @param data    an array containing exactly one non-{@code null} data set
     * @param weights {@code cl} non-{@code null} weight arrays, each with one entry per sequence
     *                of {@code data[0]}
     * @throws IllegalArgumentException if {@code data} does not hold exactly one non-{@code null}
     *         data set, if {@code weights} is {@code null} or does not have {@code cl} rows, or if
     *         a row is {@code null} or its length differs from the number of sequences
     */
    @Override
    public void setDataAndWeights(DataSet[] data, double[][] weights) throws IllegalArgumentException {
        if (data == null || data.length != 1 || data[0] == null || weights == null || weights.length != this.cl) {
            throw new IllegalArgumentException("The dimension of the data set or weights (array) is not correct.");
        }
        int numberOfSequences = data[0].getNumberOfElements();
        // validate every row first so that a failure cannot leave the object half-updated
        for (int i = 0; i < this.cl; ++i) {
            if (weights[i] == null || weights[i].length != numberOfSequences) {
                throw new IllegalArgumentException("The dimension of the " + i + "-th weights (array) is not correct.");
            }
        }
        this.data = data;
        this.weights = weights;
        this.sum[this.cl] = 0.0;
        for (int i = 0; i < this.cl; ++i) {
            double classSum = 0.0;
            for (double w : weights[i]) {
                classSum += w;
            }
            this.sum[i] = classSum;
            this.sum[this.cl] += classSum;
        }
        if (this.worker != null) {
            // worker threads already exist, so they must be refreshed with the new data
            this.prepareThreads();
        }
    }

    /**
     * Returns one data-set entry per class, each referring to the single shared data set.
     *
     * @return a new array of length {@code weights.length} whose elements are all {@code data[0]}
     */
    @Override
    public DataSet[] getData() {
        DataSet[] d = new DataSet[this.weights.length];
        Arrays.fill(d, this.data[0]);
        return d;
    }

    /**
     * Accumulates the gradient of the {@code ll} part into {@code llGrad[index]} and of the
     * {@code cll} part into {@code cllGrad[index]} over the sequences {@code startSeq}
     * (inclusive) to {@code endSeq} (exclusive) of the shared data set.
     * <p>
     * {@code startClass} and {@code endClass} are not used: with a single data set every
     * sequence contributes to every class through its class-specific weight.
     * The {@code ll} part is only computed if {@code beta[1] != 0}, the {@code cll} part only if
     * {@code beta[0] != 0}.
     *
     * @param index      the worker index selecting the per-thread scratch arrays and score copies
     * @param startClass ignored
     * @param startSeq   the first sequence (inclusive)
     * @param endClass   ignored
     * @param endSeq     the last sequence (exclusive)
     */
    @Override
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        double[] llG = this.llGrad[index];
        double[] cllG = this.cllGrad[index];
        double[] post = this.helpArray[index];
        Arrays.fill(llG, 0.0);
        Arrays.fill(cllG, 0.0);
        // gradient indices below shortcut[0] belong to the class parameters;
        // shortcut[c] is the offset of class c's score parameters in the gradient vectors
        int numClassParams = this.shortcut[0];
        for (int seq = startSeq; seq < endSeq; ++seq) {
            Sequence s = this.data[0].getElementAt(seq);
            for (int c = 0; c < this.cl; ++c) {
                this.iList[index][c].clear();
                this.dList[index][c].clear();
                post[c] = this.logClazz[c] + this.score[index][c].getLogScoreAndPartialDerivation(s, 0, this.iList[index][c], this.dList[index][c]);
            }
            // normalises the joint log scores in place; presumably post[c] afterwards holds the
            // (non-log) class posterior, which the (1 - p) / p terms below rely on
            Normalisation.logSumNormalisation(post, 0, post.length, post, 0);
            for (int c = 0; c < this.cl; ++c) {
                double weight = this.weights[c][seq];
                if (this.beta[1] != 0.0) {
                    // ll part: only the parameters of class c itself receive the weight of class c
                    if (c < numClassParams) {
                        llG[c] += weight;
                    }
                    for (int k = 0; k < this.iList[index][c].length(); ++k) {
                        llG[this.shortcut[c] + this.iList[index][c].get(k)] += weight * this.dList[index][c].get(k);
                    }
                }
                if (this.beta[0] == 0.0) {
                    continue;
                }
                // cll part: the labelled class c gets weight * (1 - p), every other class c1 gets -weight * p
                for (int c1 = 0; c1 < numClassParams; ++c1) {
                    if (c1 == c) {
                        cllG[c1] += weight * (1.0 - post[c1]);
                    } else {
                        cllG[c1] -= weight * post[c1];
                    }
                }
                for (int c1 = 0; c1 < this.cl; ++c1) {
                    int offset = this.shortcut[c1];
                    for (int k = 0; k < this.iList[index][c1].length(); ++k) {
                        int pos = offset + this.iList[index][c1].get(k);
                        if (c1 == c) {
                            cllG[pos] += weight * this.dList[index][c1].get(k) * (1.0 - post[c1]);
                        } else {
                            cllG[pos] -= weight * this.dList[index][c1].get(k) * post[c1];
                        }
                    }
                }
            }
        }
    }

    /**
     * Evaluates the weighted {@code ll} and {@code cll} terms over the sequences
     * {@code startSeq} (inclusive) to {@code endSeq} (exclusive) of the shared data set.
     * <p>
     * The results are not returned but stored in the per-thread scratch array:
     * {@code helpArray[index][0]} receives {@code ll} and {@code helpArray[index][1]} receives
     * {@code cll}. The log-sum offset subtracted in {@code cll} is only computed if
     * {@code beta[0] != 0}; otherwise it stays {@code 0}.
     * NOTE(review): writing slot 1 assumes {@code helpArray[index]} has at least two entries.
     *
     * @param index      the worker index selecting the per-thread scratch arrays and score copies
     * @param startClass ignored
     * @param startSeq   the first sequence (inclusive)
     * @param endClass   ignored
     * @param endSeq     the last sequence (exclusive)
     * @throws EvaluationException declared for the superclass contract
     */
    @Override
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        double[] joint = this.helpArray[index];
        double cll = 0.0;
        double ll = 0.0;
        double offset = 0.0;
        for (int seq = startSeq; seq < endSeq; ++seq) {
            Sequence s = this.data[0].getElementAt(seq);
            for (int c = 0; c < this.cl; ++c) {
                joint[c] = this.logClazz[c] + this.score[index][c].getLogScoreFor(s, 0);
            }
            if (this.beta[0] != 0.0) {
                offset = Normalisation.getLogSum(joint);
            }
            for (int c = 0; c < this.cl; ++c) {
                double w = this.weights[c][seq];
                cll += w * (joint[c] - offset);
                ll += w * joint[c];
            }
        }
        joint[0] = ll;
        joint[1] = cll;
    }
}

