/*
 * Decompiled with CFR 0.152.
 */
package projects.kmermotifs;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;
import projects.kmermotifs.DeconvolvableDiffSM;

/**
 * Position weight matrix (PWM) implemented as an {@link AbstractDifferentiableStatisticalModel}
 * that also implements {@link DeconvolvableDiffSM}.
 *
 * <p>The methods taking {@code ignoreLeft}/{@code ignoreRight} restrict scoring and normalization
 * to positions {@code [ignoreLeft, length - ignoreRight)}.
 *
 * <p>Parameters are stored per position as (not necessarily normalized) log-values
 * {@code pars[i][a]}. {@link #setParameters(double[], int)} caches
 * {@code Normalisation.getLogSum(pars[i])} in {@code norm[i]} and the sum of all {@code norm[i]}
 * in {@code globalNorm}. Consequently {@link #isNormalized()} returns {@code false}.
 *
 * <p>XML serialization is not implemented: {@link #toXML()} returns {@code null} and
 * {@link #fromXML(StringBuffer)} does nothing.
 */
public class DeconvolvablePWMDiffSM extends AbstractDifferentiableStatisticalModel
        implements DeconvolvableDiffSM {
    /** Log-parameters, indexed {@code [position][symbol]}. */
    private double[][] pars;
    /** Per-position log normalization constant (log-sum of {@code pars[i]}). */
    private double[] norm;
    /** Equivalent sample size per position, used for the prior and for pseudocounts. */
    private double[] posEss;
    /** Equivalent sample size reported by {@link #getESS()}. */
    private double classEss;
    private boolean isInitialized;
    /** Sum of all entries of {@code norm}. */
    private double globalNorm;

    /**
     * Creates a new, uninitialized PWM.
     *
     * @param alphabet the alphabet; the alphabet size at position 0 is used for every position
     * @param length the number of positions of the PWM
     * @param positionDependentEss equivalent sample size per position (at least {@code length}
     *     entries); the array is copied
     * @param classEss the equivalent sample size returned by {@link #getESS()}
     * @throws IllegalArgumentException if {@code positionDependentEss} has fewer than
     *     {@code length} entries
     */
    public DeconvolvablePWMDiffSM(AlphabetContainer alphabet, int length, double[] positionDependentEss, double classEss) {
        super(alphabet, length);
        // Fail fast: every per-position loop below reads posEss[i] for i < length.
        if (positionDependentEss.length < length) {
            throw new IllegalArgumentException("positionDependentEss has " + positionDependentEss.length
                    + " entries, but the model has " + length + " positions");
        }
        this.posEss = positionDependentEss.clone();
        this.classEss = classEss;
        this.pars = new double[length][(int) alphabet.getAlphabetLengthAt(0)];
        this.norm = new double[length];
    }

    /** Deep copy: parameter, normalization and ESS arrays are copied, not shared. */
    @Override
    public DeconvolvablePWMDiffSM clone() throws CloneNotSupportedException {
        DeconvolvablePWMDiffSM clone = (DeconvolvablePWMDiffSM) super.clone();
        clone.norm = this.norm.clone();
        clone.pars = (double[][]) ArrayHandler.clone((Cloneable[]) this.pars);
        clone.posEss = this.posEss.clone();
        return clone;
    }

    /** Every parameter belongs to a random variable over the alphabet at position 0. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return (int) this.alphabets.getAlphabetLengthAt(0);
    }

    /**
     * Returns the log normalization constant restricted to positions
     * {@code [ignoreLeft, length - ignoreRight)}, i.e. the sum of the cached per-position
     * constants over that range.
     */
    @Override
    public double getLogNormalizationConstant(int ignoreLeft, int ignoreRight) {
        double part = 0.0;
        for (int i = ignoreLeft; i < this.pars.length - ignoreRight; i++) {
            part += this.norm[i];
        }
        return part;
    }

    /** Returns the cached log normalization constant over all positions. */
    @Override
    public double getLogNormalizationConstant() {
        return this.globalNorm;
    }

    /**
     * Returns the log of the partial derivative of the normalization constant with respect to
     * parameter {@code parameterIndex} (row-major index {@code position * alphabetSize + symbol}).
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        int row = parameterIndex / this.pars[0].length;
        int col = parameterIndex % this.pars[0].length;
        // Z = prod_i sum_a exp(pars[i][a])  =>  dZ/dpars[r][c] = Z / Z_r * exp(pars[r][c])
        return this.globalNorm - this.norm[row] + this.pars[row][col];
    }

    /**
     * Like {@link #getLogPartialNormalizationConstant(int)}, but for the normalization constant
     * restricted to positions {@code [ignoreLeft, length - ignoreRight)}.
     *
     * @return {@link Double#NEGATIVE_INFINITY} if the parameter's position lies outside that range
     */
    @Override
    public double getLogPartialNormalizationConstant(int ignoreLeft, int ignoreRight, int parameterIndex) {
        int row = parameterIndex / this.pars[0].length;
        int col = parameterIndex % this.pars[0].length;
        if (row < ignoreLeft || row >= this.pars.length - ignoreRight) {
            return Double.NEGATIVE_INFINITY;
        }
        return this.getLogNormalizationConstant(ignoreLeft, ignoreRight) - this.norm[row] + this.pars[row][col];
    }

    /** Returns {@code sum_i sum_a pars[i][a] * posEss[i] / alphabetSize}. */
    @Override
    public double getLogPriorTerm() {
        double lp = 0.0;
        for (int i = 0; i < this.pars.length; i++) {
            for (int j = 0; j < this.pars[i].length; j++) {
                lp += this.pars[i][j] * this.posEss[i] / (double) this.pars[i].length;
            }
        }
        return lp;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} to {@code grad}, starting at {@code start}
     * (one entry per parameter, row-major).
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int i = 0; i < this.pars.length; i++) {
            double perSymbol = this.posEss[i] / (double) this.pars[i].length;
            for (int j = 0; j < this.pars[i].length; j++) {
                grad[start++] += perSymbol;
            }
        }
    }

    @Override
    public double getESS() {
        return this.classEss;
    }

    /**
     * Initializes the parameters from weighted symbol counts in {@code data[index]}.
     *
     * <p>Each position starts with pseudocount {@code posEss[i] / alphabetSize} per symbol, adds
     * the weights of the observed symbols, and is then normalized and log-transformed.
     *
     * @param weights per-data-set sequence weights; if {@code weights} or {@code weights[index]}
     *     is {@code null}, every sequence gets weight 1
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        for (int i = 0; i < this.pars.length; i++) {
            Arrays.fill(this.pars[i], this.posEss[i] / (double) this.pars[i].length);
        }
        double[] seqWeights = (weights == null) ? null : weights[index];
        DataSet current = data[index];
        for (int n = 0; n < current.getNumberOfElements(); n++) {
            Sequence seq = current.getElementAt(n);
            double w = (seqWeights == null) ? 1.0 : seqWeights[n];
            for (int j = 0; j < seq.getLength(); j++) {
                this.pars[j][seq.discreteVal(j)] += w;
            }
        }
        for (double[] row : this.pars) {
            Normalisation.sumNormalisation(row);
            for (int a = 0; a < row.length; a++) {
                row[a] = Math.log(row[a]);
            }
        }
        // The log-parameters are now normalized, so all cached log constants are 0.
        Arrays.fill(this.norm, 0.0);
        this.globalNorm = 0.0;
        this.isInitialized = true;
    }

    /**
     * Draws the log-parameters of each position with {@link DirichletMRG} using hyperparameter
     * {@code posEss[i] / alphabetSize}.
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (int i = 0; i < this.pars.length; i++) {
            DirichletMRGParams params = new DirichletMRGParams(this.posEss[i] / (double) this.pars[i].length, this.pars[i].length);
            DirichletMRG.DEFAULT_INSTANCE.generateLog(this.pars[i], 0, this.pars[i].length, params);
        }
        // Assumes generateLog writes the log of a normalized Dirichlet sample, so the log
        // normalization constants are 0 (as in the original code) -- NOTE(review): confirm.
        Arrays.fill(this.norm, 0.0);
        this.globalNorm = 0.0;
        this.isInitialized = true;
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        return this.getLogProbAndPartialDerivation(seq, start, 0, 0, indices, partialDer);
    }

    /** Returns {@code length * alphabetSize}. */
    @Override
    public int getNumberOfParameters() {
        return this.pars.length * this.pars[0].length;
    }

    /** Returns a row-major copy of the log-parameters. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] p = new double[this.getNumberOfParameters()];
        int k = 0;
        for (double[] row : this.pars) {
            for (double value : row) {
                p[k++] = value;
            }
        }
        return p;
    }

    /**
     * Copies the row-major log-parameters from {@code params[start...]} and recomputes the
     * cached per-position and global log normalization constants.
     */
    @Override
    public void setParameters(double[] params, int start) {
        this.globalNorm = 0.0;
        for (int i = 0; i < this.pars.length; i++) {
            for (int j = 0; j < this.pars[i].length; j++) {
                this.pars[i][j] = params[start++];
            }
            this.norm[i] = Normalisation.getLogSum(this.pars[i]);
            this.globalNorm += this.norm[i];
        }
    }

    @Override
    public String getInstanceName() {
        return "PWM";
    }

    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.getLogProbFor(seq, start, 0, 0);
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /** XML serialization is not implemented; always returns {@code null}. */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /**
     * Returns the sum of the log-parameters of the symbols at {@code start + i} for
     * {@code i} in {@code [ignoreLeft, length - ignoreRight)}. No normalization is subtracted.
     */
    @Override
    public double getLogProbFor(Sequence seq, int start, int ignoreLeft, int ignoreRight) {
        double score = 0.0;
        for (int i = ignoreLeft; i < this.pars.length - ignoreRight; i++) {
            score += this.pars[i][seq.discreteVal(start + i)];
        }
        return score;
    }

    /**
     * Same score as {@link #getLogProbFor(Sequence, int, int, int)}; additionally appends, for
     * each scored position, the row-major index of the used parameter to {@code indices} and
     * the derivative {@code 1.0} to {@code partialDer}.
     */
    @Override
    public double getLogProbAndPartialDerivation(Sequence seq, int start, int ignoreLeft, int ignoreRight, IntList indices, DoubleList partialDer) {
        double score = 0.0;
        int offset = ignoreLeft * this.pars[0].length;
        for (int i = ignoreLeft; i < this.pars.length - ignoreRight; i++) {
            int val = seq.discreteVal(start + i);
            score += this.pars[i][val];
            indices.add(offset + val);
            partialDer.add(1.0);
            offset += this.pars[i].length;
        }
        return score;
    }

    /** XML deserialization is not implemented; this method does nothing. */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
    }

    @Override
    public boolean isNormalized() {
        return false;
    }

    /**
     * Returns a tab-separated table of the per-position probabilities.
     *
     * <p>The header row holds the alphabet symbols. Each further row holds the position index
     * followed by {@code exp(pars[i][a] - logSum(pars[i]))} for every symbol.
     *
     * @param nf format for the probabilities; if {@code null}, {@link Double#toString(double)}
     *     is used
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder();
        int alphabetSize = (int) this.alphabets.getAlphabetLengthAt(0);
        for (int a = 0; a < alphabetSize; a++) {
            sb.append('\t').append(this.alphabets.getSymbol(0, a));
        }
        sb.append('\n');
        for (int i = 0; i < this.pars.length; i++) {
            sb.append(i);
            // Normalize locally so the output is meaningful even before setParameters was called.
            double logSum = Normalisation.getLogSum(this.pars[i]);
            for (int a = 0; a < this.pars[i].length; a++) {
                double p = Math.exp(this.pars[i][a] - logSum);
                sb.append('\t').append(nf == null ? Double.toString(p) : nf.format(p));
            }
            sb.append('\n');
        }
        return sb.toString();
    }
}

