/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif;

import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.motif.DurationDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.HashSet;

public class MixtureDurationDiffSM
extends DurationDiffSM {
    private DurationDiffSM[] function;
    private double[] hiddenParams;
    private double[] scores;
    private double logNorm;
    private int[] paramRef;
    private int[] partDerOffset;
    private int starts;
    private IntList help;
    private static String XML_TAG = "MixtureDurationDiffSM";

    private static double getESS(DurationDiffSM ... function) {
        double ess = function[0].getESS();
        boolean noESS = ess == 0.0;
        for (int i = 1; i < function.length; ++i) {
            double e = function[i].getESS();
            if (noESS) {
                if (!(e > 0.0)) continue;
                throw new IllegalArgumentException("The ESS of duration " + i + " has to be zero.");
            }
            ess += e;
        }
        return ess;
    }

    public MixtureDurationDiffSM(int starts, DurationDiffSM ... function) throws WrongAlphabetException, CloneNotSupportedException, IllegalArgumentException {
        super(function[0].getMin(), function[0].getMax(), MixtureDurationDiffSM.getESS(function));
        if (starts <= 0) {
            throw new IllegalArgumentException("The number of recommended starts should be positive.");
        }
        this.starts = starts;
        this.function = new DurationDiffSM[function.length];
        this.scores = new double[function.length];
        this.hiddenParams = new double[function.length];
        this.paramRef = null;
        this.partDerOffset = new int[function.length];
        for (int i = 0; i < function.length; ++i) {
            if (!this.alphabets.checkConsistency(function[i].getAlphabetContainer())) {
                throw new WrongAlphabetException("All durations have to have the same alphabet: Violated at position " + i);
            }
            this.function[i] = (DurationDiffSM)function[i].clone();
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        this.setParamRef(false);
        this.help = new IntList();
    }

    private void setParamRef(boolean freeParams) {
        int i;
        if (this.paramRef == null || this.paramRef.length != this.function.length + 2) {
            this.paramRef = new int[this.function.length + 2];
        }
        boolean unknown = false;
        for (i = 0; i < this.function.length; ++i) {
            int n = this.function[i].getNumberOfParameters();
            unknown |= n < 0;
            this.paramRef[i + 1] = this.paramRef[i] + n;
        }
        this.paramRef[i + 1] = unknown ? -1 : this.paramRef[i] + this.scores.length - (freeParams ? 1 : 0);
    }

    public MixtureDurationDiffSM(StringBuffer source) throws NonParsableException {
        super(source);
    }

    @Override
    public MixtureDurationDiffSM clone() throws CloneNotSupportedException {
        MixtureDurationDiffSM clone = (MixtureDurationDiffSM)super.clone();
        clone.function = (DurationDiffSM[])ArrayHandler.clone((Cloneable[])this.function);
        clone.scores = (double[])this.scores.clone();
        clone.paramRef = (int[])this.paramRef.clone();
        clone.partDerOffset = (int[])this.partDerOffset.clone();
        clone.hiddenParams = (double[])this.hiddenParams.clone();
        clone.help = new IntList();
        return clone;
    }

    @Override
    public void adjust(int[] length, double[] weight) {
        Class<?> c;
        int i;
        double[][] assignedWeights = new double[this.function.length][];
        double[] stat = new double[this.hiddenParams.length];
        double all = 0.0;
        HashSet names = new HashSet();
        for (i = 0; i < this.function.length && !names.contains(c = this.function[i].getClass()); ++i) {
            names.add(c);
        }
        boolean init = i < this.function.length;
        for (i = 0; i < this.function.length; ++i) {
            if (init) {
                try {
                    this.function[i].initializeFunctionRandomly(false);
                }
                catch (Exception e) {
                    throw new RuntimeException();
                }
            } else {
                this.function[i].adjust(length, weight);
            }
            this.hiddenParams[i] = 0.0;
            assignedWeights[i] = new double[weight.length];
            stat[i] = this.function[i].getESS();
            all += stat[i];
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        int[] values = new int[1];
        for (int l = 0; l < length.length; ++l) {
            values[0] = length[l];
            for (i = 0; i < this.function.length; ++i) {
                this.scores[i] = this.hiddenParams[i] + this.function[i].getLogScore(values);
            }
            Normalisation.logSumNormalisation(this.scores);
            for (i = 0; i < this.function.length; ++i) {
                assignedWeights[i][l] = weight[l] * this.scores[i];
                int n = i;
                stat[n] = stat[n] + weight[l] * this.scores[i];
            }
            all += weight[l];
        }
        for (i = 0; i < this.function.length; ++i) {
            this.function[i].adjust(length, assignedWeights[i]);
            this.hiddenParams[i] = Math.log(stat[i] / all);
        }
        this.logNorm = 0.0;
    }

    @Override
    public double getLogScore(int ... values) {
        for (int i = 0; i < this.function.length; ++i) {
            this.scores[i] = this.hiddenParams[i] + this.function[i].getLogScore(values);
        }
        return Normalisation.getLogSum(this.scores) - this.logNorm;
    }

    @Override
    public double getLogScoreAndPartialDerivation(IntList indices, DoubleList partialDer, int ... values) {
        int j;
        int i;
        int o = partialDer.length();
        for (i = 0; i < this.function.length; ++i) {
            this.help.clear();
            this.scores[i] = this.hiddenParams[i] + this.function[i].getLogScoreAndPartialDerivation(this.help, partialDer, values);
            this.partDerOffset[i] = partialDer.length();
            for (j = 0; j < this.help.length(); ++j) {
                indices.add(this.paramRef[i] + this.help.get(j));
            }
        }
        double logScore = Normalisation.logSumNormalisation(this.scores);
        for (i = 0; i < this.function.length; ++i) {
            partialDer.multiply(o, this.partDerOffset[i], this.scores[i]);
            o = this.partDerOffset[i];
        }
        i = 0;
        j = this.paramRef[this.function.length];
        while (j < this.paramRef[this.function.length + 1]) {
            indices.add(j);
            partialDer.add(this.scores[i] - Math.exp(this.hiddenParams[i] - this.logNorm));
            ++j;
            ++i;
        }
        return logScore - this.logNorm;
    }

    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        if (this.ess > 0.0) {
            for (int i = 0; i < this.function.length; ++i) {
                this.function[i].addGradientOfLogPriorTerm(grad, this.paramRef[i] + start);
            }
            int j = 0;
            int i = this.paramRef[this.function.length];
            while (i < this.paramRef[this.function.length + 1]) {
                grad[i + start] = this.function[j].getESS() - this.ess * Math.exp(this.hiddenParams[j] - this.logNorm);
                ++i;
                ++j;
            }
        }
    }

    @Override
    public double getLogPriorTerm() {
        double lp = 0.0;
        if (this.ess > 0.0) {
            for (int i = 0; i < this.function.length; ++i) {
                lp += this.function[i].getLogPriorTerm() + this.function[i].getESS() * this.hiddenParams[i];
            }
            lp -= this.ess * this.logNorm;
        }
        return lp;
    }

    @Override
    public double[] getCurrentParameterValues() throws Exception {
        int n = this.getNumberOfParameters();
        if (n > 0) {
            double[] params = new double[n];
            for (int i = 0; i < this.function.length; ++i) {
                double[] current = this.function[i].getCurrentParameterValues();
                System.arraycopy(current, 0, params, this.paramRef[i], current.length);
            }
            int j = 0;
            int i = this.paramRef[this.function.length];
            while (i < this.paramRef[this.function.length + 1]) {
                params[i] = this.hiddenParams[j];
                ++i;
                ++j;
            }
            return params;
        }
        throw new RuntimeException();
    }

    @Override
    public String getInstanceName() {
        String name = "mixture duration(" + this.function[0].getInstanceName();
        for (int i = 1; i < this.function.length; ++i) {
            name = name + ", " + this.function[i].getInstanceName();
        }
        return name + ")";
    }

    @Override
    public int getNumberOfParameters() {
        return this.paramRef[this.function.length + 1];
    }

    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        int i;
        double w = 1.0;
        double all = 0.0;
        for (i = 0; i < this.function.length; ++i) {
            this.hiddenParams[i] = this.function[i].getESS();
            all += this.hiddenParams[i];
        }
        DirichletMRGParams params = new DirichletMRGParams(this.hiddenParams);
        double[] current = new double[this.function.length];
        if (weights == null) {
            weights = new double[data.length][];
        }
        double[] help = weights[index];
        double[][] componentWeights = new double[this.function.length][data[index].getNumberOfElements()];
        for (int j = 0; j < componentWeights[0].length; ++j) {
            DirichletMRG.DEFAULT_INSTANCE.generate(current, 0, this.function.length, params);
            if (help != null) {
                w = help[j];
            }
            all += w;
            for (i = 0; i < this.function.length; ++i) {
                componentWeights[i][j] = w * current[i];
                int n = i;
                this.hiddenParams[n] = this.hiddenParams[n] + componentWeights[i][j];
            }
        }
        for (i = 0; i < this.function.length; ++i) {
            weights[index] = componentWeights[i];
            this.hiddenParams[i] = Math.log(this.hiddenParams[i] / all);
        }
        this.logNorm = 0.0;
        weights[index] = help;
        this.setParamRef(freeParams);
    }

    @Override
    public void initializeUniformly() {
        for (int i = 0; i < this.function.length; ++i) {
            this.function[i].initializeUniformly();
            this.hiddenParams[i] = 0.0;
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        this.setParamRef(false);
    }

    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        int i;
        boolean noPrior = this.ess == 0.0;
        double[] hyper = (double[])this.scores.clone();
        for (i = 0; i < this.function.length; ++i) {
            this.function[i].initializeFunctionRandomly(freeParams);
            hyper[i] = noPrior ? 1.0 : this.function[i].getESS();
        }
        DirichletMRG.DEFAULT_INSTANCE.generate(this.hiddenParams, 0, this.function.length, new DirichletMRGParams(hyper));
        for (i = 0; i < this.function.length; ++i) {
            this.hiddenParams[i] = Math.log(this.hiddenParams[i]);
        }
        this.logNorm = 0.0;
        this.setParamRef(freeParams);
    }

    @Override
    public boolean isInitialized() {
        int i;
        for (i = 0; i < this.function.length && this.function[i].isInitialized(); ++i) {
        }
        return i == this.function.length;
    }

    @Override
    public boolean isNormalized() {
        return true;
    }

    @Override
    public void setParameters(double[] params, int start) {
        for (int i = 0; i < this.function.length; ++i) {
            this.function[i].setParameters(params, start + this.paramRef[i]);
        }
        int j = 0;
        int i = this.paramRef[this.function.length];
        while (i < this.paramRef[this.function.length + 1]) {
            this.hiddenParams[j] = params[start + i];
            ++i;
            ++j;
        }
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
    }

    @Override
    protected String getRNotation(String distributionName, NumberFormat nf) {
        String r = "";
        String sum = null;
        for (int i = 0; i < this.function.length; ++i) {
            r = r + this.function[i].getRNotation(distributionName + i, nf) + "\n";
            sum = sum == null ? distributionName + " = " : sum + " + ";
            sum = sum + nf.format(Math.exp(this.hiddenParams[i] - this.logNorm)) + " * " + distributionName + i;
        }
        return r + sum + ";";
    }

    @Override
    public void modify(int delta) {
        if (delta != 0) {
            super.modify(delta);
            for (int i = 0; i < this.function.length; ++i) {
                this.function[i].modify(delta);
            }
        }
    }

    @Override
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    @Override
    protected void fromXML(StringBuffer rep) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(rep, XML_TAG);
        super.fromXML(xml);
        this.function = XMLParser.extractObjectForTags(xml, "components", DurationDiffSM[].class);
        this.hiddenParams = XMLParser.extractObjectForTags(xml, "hiddenParams", double[].class);
        this.starts = XMLParser.extractObjectForTags(xml, "starts", Integer.TYPE);
        this.scores = new double[this.function.length];
        this.paramRef = null;
        this.partDerOffset = new int[this.function.length];
        this.logNorm = Normalisation.getLogSum(this.hiddenParams);
        this.setParamRef(false);
        this.help = new IntList();
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer xml = super.toXML();
        XMLParser.appendObjectWithTags(xml, this.function, "components");
        XMLParser.appendObjectWithTags(xml, this.hiddenParams, "hiddenParams");
        XMLParser.appendObjectWithTags(xml, this.starts, "starts");
        XMLParser.addTags(xml, XML_TAG);
        return xml;
    }
}

