/*
 * Decompiled with CFR 0.152.
 */
package supplementary.cookbook.recipes;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.util.Arrays;

/**
 * A position weight matrix (PWM) implemented as a differentiable statistical model.
 * <p>
 * For every position {@code i} and symbol {@code a} the model stores one free, log-scale (unnormalized)
 * parameter {@code parameters[i][a]}. The log-score of a sequence is the sum of the parameters of the
 * observed symbols. The log-normalization constant is the sum over all positions of
 * {@code log(sum_a exp(parameters[i][a]))}. In the flat parameter vector the parameters are laid out
 * position-major, i.e. parameter {@code (i, a)} has index {@code i * |alphabet| + a}.
 * <p>
 * Only simple, discrete alphabets (the same discrete alphabet at every position) are supported.
 */
public class PositionWeightMatrixDiffSM
extends AbstractDifferentiableStatisticalModel {
    /** Log-scale parameters, indexed by position and then by symbol. */
    private double[][] parameters;
    /** Equivalent sample size used for the pseudo-counts and the prior term. */
    private double ess;
    /** Set by {@code initializeFunction}/{@code initializeFunctionRandomly}, or restored from XML. */
    private boolean isInitialized;
    /** Cached log-normalization constant; {@code null} whenever the parameters have changed. */
    private Double norm;

    /**
     * Creates a new, uninitialized PWM.
     *
     * @param alphabets the alphabet container; must be simple and discrete
     * @param length    the number of positions of the PWM
     * @param ess       the equivalent sample size
     * @throws IllegalArgumentException if {@code alphabets} is not simple or not discrete
     */
    public PositionWeightMatrixDiffSM(AlphabetContainer alphabets, int length, double ess) throws IllegalArgumentException {
        super(alphabets, length);
        if (!alphabets.isSimple() || !alphabets.isDiscrete()) {
            throw new IllegalArgumentException("This PWM can handle only discrete alphabets with the same alphabet at each position.");
        }
        this.parameters = new double[length][(int) alphabets.getAlphabetLengthAt(0)];
        this.ess = ess;
        this.isInitialized = false;
        this.norm = null;
    }

    /**
     * Restores a PWM from its XML representation as produced by {@link #toXML()}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    public PositionWeightMatrixDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /** Returns the number of symbols of the (position-independent) alphabet. */
    private int getAlphabetSize() {
        return (int) this.alphabets.getAlphabetLengthAt(0);
    }

    /**
     * Creates a deep copy of this PWM. The parameter matrix is copied row by row, so changing the
     * parameters of the clone does not affect this instance (and vice versa).
     */
    @Override
    public PositionWeightMatrixDiffSM clone() throws CloneNotSupportedException {
        PositionWeightMatrixDiffSM copy = (PositionWeightMatrixDiffSM) super.clone();
        copy.parameters = new double[this.parameters.length][];
        for (int i = 0; i < this.parameters.length; i++) {
            copy.parameters[i] = this.parameters[i].clone();
        }
        return copy;
    }

    /**
     * Returns the alphabet size. Every parameter belongs to a random variable over the alphabet.
     * Using the alphabet instead of {@code parameters[0]} also works for length-0 models.
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return getAlphabetSize();
    }

    /**
     * Returns {@code sum_i log(sum_a exp(parameters[i][a]))}. The value is cached until the
     * parameters change.
     */
    @Override
    public double getLogNormalizationConstant() {
        if (this.norm == null) {
            // accumulate in a primitive to avoid re-boxing on every iteration
            double logZ = 0.0;
            for (double[] row : this.parameters) {
                logZ += Normalisation.getLogSum(row);
            }
            this.norm = logZ;
        }
        return this.norm;
    }

    /**
     * Returns the log of the partial derivative of the normalization constant {@code Z} with respect
     * to the parameter with the given flat index. {@code Z} factorizes over positions, so this is
     * {@code log Z - log(sum_a exp(parameters[p][a])) + parameters[p][s]} for position {@code p}
     * and symbol {@code s}.
     *
     * @param parameterIndex the flat (position-major) parameter index
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        double logZ = getLogNormalizationConstant();
        int alphabetSize = getAlphabetSize();
        int position = parameterIndex / alphabetSize;
        int symbol = parameterIndex % alphabetSize;
        return logZ - Normalisation.getLogSum(this.parameters[position]) + this.parameters[position][symbol];
    }

    /**
     * Returns the parameter-dependent part of the log-prior, {@code sum_{i,a} (ess/|alphabet|) * parameters[i][a]}.
     * NOTE(review): any normalization contribution of the prior is presumably handled by the framework
     * through {@link #getESS()} — confirm against the Jstacs documentation.
     */
    @Override
    public double getLogPriorTerm() {
        double weight = this.ess / this.alphabets.getAlphabetLengthAt(0);
        double logPrior = 0.0;
        for (double[] row : this.parameters) {
            for (double p : row) {
                logPrior += weight * p;
            }
        }
        return logPrior;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} to {@code grad}, starting at index {@code start}.
     * The gradient is the constant {@code ess/|alphabet|} for every parameter.
     *
     * @param grad  the gradient vector to add to; existing entries are kept
     * @param start the index of this model's first parameter in {@code grad}
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        double weight = this.ess / this.alphabets.getAlphabetLengthAt(0);
        int k = start;
        for (double[] row : this.parameters) {
            for (int j = 0; j < row.length; j++, k++) {
                // accumulate: the caller may already have stored other contributions here
                grad[k] += weight;
            }
        }
    }

    /** Returns the equivalent sample size. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Initializes the parameters from data: weighted symbol counts per position plus a pseudo-count
     * of {@code ess/|alphabet|}. The counts are normalized per position, and the log is stored.
     *
     * @param index      the index of the data set in {@code data} (and of its weights in {@code weights}) to use
     * @param freeParams not used by this implementation
     * @param data       the data sets
     * @param weights    the sequence weights; if {@code weights} or {@code weights[index]} is {@code null},
     *                   every sequence gets weight 1
     * @throws IllegalArgumentException if the alphabet or the element length of the data does not match this model
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        DataSet ds = data[index];
        if (!ds.getAlphabetContainer().checkConsistency(this.alphabets) || ds.getElementLength() != this.length) {
            throw new IllegalArgumentException("Alphabet or length do not match.");
        }
        double pseudoCount = this.ess / this.alphabets.getAlphabetLengthAt(0);
        for (double[] row : this.parameters) {
            Arrays.fill(row, pseudoCount);
        }
        double[] w = (weights == null) ? null : weights[index];
        int n = ds.getNumberOfElements();
        for (int i = 0; i < n; i++) {
            Sequence seq = ds.getElementAt(i);
            double weight = (w == null) ? 1.0 : w[i];
            for (int j = 0; j < seq.getLength(); j++) {
                this.parameters[j][seq.discreteVal(j)] += weight;
            }
        }
        for (double[] row : this.parameters) {
            Normalisation.sumNormalisation(row);
            for (int j = 0; j < row.length; j++) {
                row[j] = Math.log(row[j]);
            }
        }
        this.norm = null;
        this.isInitialized = true;
    }

    /**
     * Initializes the parameters randomly. For each position it draws a probability vector from
     * {@code DirichletMRG} with hyper-parameters {@code ess/|alphabet|} and stores its log.
     *
     * @param freeParams not used by this implementation
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        int al = getAlphabetSize();
        DirichletMRGParams pars = new DirichletMRGParams(this.ess / (double) al, al);
        for (int i = 0; i < this.parameters.length; i++) {
            double[] row = DirichletMRG.DEFAULT_INSTANCE.generate(al, pars);
            for (int j = 0; j < row.length; j++) {
                row[j] = Math.log(row[j]);
            }
            this.parameters[i] = row;
        }
        this.norm = null;
        this.isInitialized = true;
    }

    /**
     * Returns the unnormalized log-score of {@code seq}, starting at {@code start}, i.e. the sum of the
     * parameters of the symbols observed at positions {@code start .. start + length - 1}.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        double score = 0.0;
        for (int i = 0; i < this.parameters.length; i++) {
            score += this.parameters[i][seq.discreteVal(i + start)];
        }
        return score;
    }

    /**
     * Returns the same score as {@link #getLogScoreFor(Sequence, int)}. In addition, for each position
     * it appends the flat index of the used parameter to {@code indices} and the partial derivative
     * 1.0 to {@code partialDer}.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double score = 0.0;
        int offset = 0;
        for (int i = 0; i < this.parameters.length; i++) {
            int symbol = seq.discreteVal(i + start);
            score += this.parameters[i][symbol];
            indices.add(offset + symbol);
            // the score is linear in each parameter, so every used parameter has derivative 1
            partialDer.add(1.0);
            offset += this.parameters[i].length;
        }
        return score;
    }

    /** Returns the total number of parameters (positions times alphabet size). */
    @Override
    public int getNumberOfParameters() {
        int num = 0;
        for (double[] row : this.parameters) {
            num += row.length;
        }
        return num;
    }

    /** Returns a fresh copy of the current parameters in flat, position-major order. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] pars = new double[getNumberOfParameters()];
        int k = 0;
        for (double[] row : this.parameters) {
            System.arraycopy(row, 0, pars, k, row.length);
            k += row.length;
        }
        return pars;
    }

    /**
     * Copies this model's parameters from {@code params}, starting at index {@code start}, in flat
     * position-major order, and invalidates the cached normalization constant.
     */
    @Override
    public void setParameters(double[] params, int start) {
        int k = start;
        for (double[] row : this.parameters) {
            System.arraycopy(params, k, row, 0, row.length);
            k += row.length;
        }
        this.norm = null;
    }

    @Override
    public String getInstanceName() {
        return "Position weight matrix";
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Serializes the alphabets, length, parameters, initialization state and ESS into a
     * {@code PWM}-tagged XML representation. The cached normalization constant is not stored.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, this.length, "length");
        XMLParser.appendObjectWithTags(xml, this.parameters, "parameters");
        XMLParser.appendObjectWithTags(xml, this.isInitialized, "isInitialized");
        XMLParser.appendObjectWithTags(xml, this.ess, "ess");
        XMLParser.addTags(xml, "PWM");
        return xml;
    }

    /**
     * Restores the state written by {@link #toXML()}. The cached normalization constant is reset,
     * so it is recomputed from the restored parameters.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "PWM");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(xml, "alphabets");
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.parameters = (double[][]) XMLParser.extractObjectForTags(xml, "parameters");
        this.isInitialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.norm = null;
    }
}

