/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels;

import de.jstacs.Storable;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.directedGraphicalModels.BNDiffSMParameter;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;

public class BNDiffSMParameterTree
implements Cloneable,
Storable {
    /** Sequence position modeled by this tree; leaf parameters are selected by the symbol at {@code pos + start}. */
    private int pos;
    /** Context positions the tree branches on; entry {@code d} is the position tested by nodes at depth {@code d}. */
    private int[] contextPoss;
    /** Root node of the context tree. */
    private TreeElement root;
    /** Alphabet used to size children/parameter arrays; {@code null} after XML parsing until {@code setAlphabet} is called. */
    private AlphabetContainer alphabet;
    /** Index (into the {@code trees} array passed to forward/backward) of the first parent tree, or -1 if there is none. */
    private int firstParent;
    /** Indices (into the {@code trees} array) of the first-children trees; an empty array marks this tree as a leaf. */
    private int[] firstChildren;
    // NOTE(review): not referenced in the visible part of this file — confirm it is still used before removing.
    private static Random r = new Random();

    /**
     * Creates the parameter tree for position {@code pos}. The complete node structure is built
     * immediately: one depth level per entry of {@code contextPoss}, with empty parameter slots at
     * the leaves.
     *
     * @param pos the modeled position
     * @param contextPoss the context positions, one per tree depth
     * @param alphabet the alphabet container used to size the nodes
     * @param firstParent index of the first parent tree, or -1 if there is none
     * @param firstChildren indices of the first-children trees
     */
    public BNDiffSMParameterTree(int pos, int[] contextPoss, AlphabetContainer alphabet, int firstParent, int[] firstChildren) {
        this.firstParent = firstParent;
        this.firstChildren = firstChildren;
        this.alphabet = alphabet;
        this.pos = pos;
        this.contextPoss = contextPoss;
        // built last: node construction reads pos and contextPoss
        this.root = new TreeElement(0, alphabet);
    }

    /**
     * Reconstructs a tree from the XML written by {@link #toXML()}.
     * The alphabet is not part of the XML and is left {@code null}; it has to be supplied via
     * {@code setAlphabet} before calling methods that use it.
     *
     * @param source XML expected to be enclosed in {@code parameterTree} tags
     * @throws NonParsableException propagated from {@code XMLParser} if parsing fails
     */
    public BNDiffSMParameterTree(StringBuffer source) throws NonParsableException {
        source = XMLParser.extractForTag(source, "parameterTree");
        this.pos = XMLParser.extractObjectForTags(source, "pos", Integer.TYPE);
        this.contextPoss = XMLParser.extractObjectForTags(source, "contextPoss", int[].class);
        this.root = new TreeElement(XMLParser.extractForTag(source, "root"));
        this.alphabet = null;
        this.firstParent = XMLParser.extractObjectForTags(source, "firstParent", Integer.TYPE);
        this.firstChildren = XMLParser.extractObjectForTags(source, "firstChildren", int[].class);
    }

    /**
     * Installs the alphabet container; required after XML construction, which leaves it unset.
     *
     * @param alphabet the alphabet container to use from now on
     */
    void setAlphabet(AlphabetContainer alphabet) {
        this.alphabet = alphabet;
    }

    /**
     * Returns a deep copy of this tree. The context positions, the first-children indices and the
     * node structure, including cloned leaf parameters, are duplicated. The alphabet reference is
     * shared.
     *
     * @return the copy
     * @throws CloneNotSupportedException if cloning fails
     */
    @Override
    public BNDiffSMParameterTree clone() throws CloneNotSupportedException {
        BNDiffSMParameterTree copy = (BNDiffSMParameterTree) super.clone();
        copy.contextPoss = this.contextPoss.clone();
        copy.cloneRoot();
        copy.firstChildren = this.firstChildren.clone();
        return copy;
    }

    /** Replaces {@link #root} by a newly built node structure that deep-copies the current one. */
    private void cloneRoot() throws CloneNotSupportedException {
        TreeElement original = this.root;
        TreeElement fresh = new TreeElement(original.contNum, this.alphabet);
        this.root = fresh;
        fresh.cloneRest(original);
    }

    /**
     * Renders the conditional probabilities of this position as plain text, one line per leaf context.
     *
     * @param nf the format used for the probability values
     * @return the textual representation
     */
    public String toString(NumberFormat nf) {
        StringBuffer text = new StringBuffer("Probabilities at position " + this.pos + ":\n");
        this.root.appendToBuffer(text, "", nf);
        return text.toString();
    }

    /**
     * Adds the context-weighted leaf probabilities of this tree to {@code probs}.
     *
     * @param probs accumulator, modified in place
     * @throws Exception propagated from the node computation
     */
    public void insertProbs(double[] probs) throws Exception {
        TreeElement start = this.root;
        start.insertProbs(probs);
    }

    /**
     * Collects the parameters of all leaves into a new list. Children are visited in symbol order
     * and each leaf's parameters in index order.
     *
     * @return a new list containing every leaf parameter of this tree
     */
    public LinkedList<BNDiffSMParameter> linearizeParameters() {
        // typed list instead of the raw LinkedList, which needed an unchecked conversion
        return this.root.linearizeParameters(new LinkedList<BNDiffSMParameter>());
    }

    /**
     * Tells whether this tree has no first children.
     *
     * @return {@code true} if the first-children array is empty
     */
    public boolean isLeaf() {
        return 0 == firstChildren.length;
    }

    /**
     * Returns the number of context positions (parents) of this tree.
     *
     * @return the number of parents
     */
    public int getNumberOfParents() {
        int[] parents = contextPoss;
        return parents.length;
    }

    /** Dumps this tree to standard output. */
    public void print() {
        System.out.println("tree " + pos + ": ");
        root.print();
    }

    /**
     * Looks up the leaf parameter selected by the symbols of {@code seq} relative to {@code start}.
     *
     * @param seq the sequence
     * @param start the offset added to all positions
     * @return the selected parameter
     */
    public BNDiffSMParameter getParameterFor(Sequence seq, int start) {
        TreeElement top = root;
        return top.getParameterFor(seq, start);
    }

    /**
     * Stores {@code par} in the leaf slot {@code symbol} of every leaf reached through {@code context}.
     *
     * @param symbol the leaf slot index
     * @param context per-depth symbol lists (entry 0 of each row is skipped by the node code)
     * @param par the parameter to store
     */
    public void setParameterFor(int symbol, int[][] context, BNDiffSMParameter par) {
        final int rootDepth = 0;
        root.setParameterFor(rootDepth, symbol, context, par);
    }

    /** Clears all cached normalizers and T-values in this tree. */
    public void invalidateNormalizers() {
        root.invalidateNormalizers();
    }

    /**
     * Starts the forward pass at this tree, which must not have parents, and returns the log
     * normalizer obtained for an empty context.
     *
     * @param trees all trees, indexed as referenced by firstParent/firstChildren
     * @return the log normalizer
     * @throws RuntimeException if this tree has parents
     */
    public double forward(BNDiffSMParameterTree[] trees) throws RuntimeException {
        if (this.getNumberOfParents() != 0) {
            throw new RuntimeException("Forward can only be started at roots.");
        }
        int[][] noContext = new int[0][2];
        return this.getLogZ(noContext, trees);
    }

    /** Computes the log normalizer for {@code context}, using a fresh scratch context sized for this tree. */
    private double getLogZ(int[][] context, BNDiffSMParameterTree[] trees) throws RuntimeException {
        int[][] scratch = new int[this.getNumberOfParents() + 1][2];
        return this.root.getLogZ(context, scratch, trees, 0);
    }

    /** Computes the log T-value for {@code context}, using a scratch context sized for the first parent (if any). */
    private double getLogT(int[][] context, BNDiffSMParameterTree[] trees, int[][] order) throws RuntimeException {
        int[][] scratch;
        if (this.firstParent > -1) {
            scratch = new int[trees[this.firstParent].contextPoss.length + 1][2];
        } else {
            scratch = new int[0][2];
        }
        return this.root.getLogT(context, scratch, trees, order, 0);
    }

    /**
     * Starts the backward pass at this tree, which must be a leaf (no first children).
     *
     * @param trees all trees of the network
     * @param order per-position ordering information used when marginalizing parent contexts
     * @throws RuntimeException if this tree is not a leaf
     */
    public void backward(BNDiffSMParameterTree[] trees, int[][] order) throws RuntimeException {
        if (!this.isLeaf()) {
            throw new RuntimeException("Backward can only be started at leaves.");
        }
        int[][] scratch = new int[this.getNumberOfParents() + 1][2];
        this.root.startBackward(scratch, trees, order, 0);
    }

    /**
     * Adds {@code count} to the parameter selected by {@code seq} at offset {@code start}.
     *
     * @param seq the sequence
     * @param start the offset
     * @param count the amount to add
     */
    public void addCount(Sequence seq, int start, double count) {
        BNDiffSMParameter target = this.getParameterFor(seq, start);
        target.addCount(count);
    }

    /** Sets each leaf's parameters to the log relative frequencies of their counts (uniform if no counts). */
    public void normalizePlugInParameters() {
        root.normalizePlugInParameters();
    }

    /** Normalizes the parameters of every leaf (see the node implementation). */
    public void normalizeParameters() {
        root.normalizeParameters();
    }

    /** Shifts every leaf's parameters by the value of that leaf's last parameter. */
    public void divideByUnfree() {
        root.divideByUnfree();
    }

    /**
     * Serializes this tree as XML. The entries {@code pos}, {@code contextPoss}, {@code root},
     * {@code firstParent} and {@code firstChildren} are written in that order and enclosed in
     * {@code parameterTree} tags. The alphabet is not written.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer source = new StringBuffer();
        XMLParser.appendObjectWithTags(source, this.pos, "pos");
        XMLParser.appendObjectWithTags(source, this.contextPoss, "contextPoss");
        XMLParser.appendObjectWithTags(source, this.root, "root");
        XMLParser.appendObjectWithTags(source, this.firstParent, "firstParent");
        XMLParser.appendObjectWithTags(source, this.firstChildren, "firstChildren");
        XMLParser.addTags(source, "parameterTree");
        return source;
    }

    /**
     * Returns the index of the first parent tree.
     *
     * @return the first parent's index, or -1 if there is none
     */
    public int getFirstParent() {
        return firstParent;
    }

    /**
     * Accumulates into {@code kls[startIdx..endIdx)} sampled KL-divergence terms between Dirichlet
     * draws centered on {@code contrast} and {@code contrast} itself, scaled by {@code weight}.
     */
    public void drawKLDivergences(double weight, double[] kls, int startIdx, int endIdx, double[][][] contrast, double samples) {
        final int rootContext = 0;
        final int rootDepth = 0;
        root.drawKLDivergences(weight, kls, startIdx, endIdx, contrast, samples, rootContext, rootDepth);
    }

    /**
     * Returns the context-weighted KL divergence between this tree's conditional distributions and {@code ds}.
     *
     * @param ds reference distributions indexed by depth and context
     * @return the weighted divergence
     */
    public double getKLDivergence(double[][][] ds) {
        final int rootContext = 0;
        final int rootDepth = 0;
        return root.getWeightedKLDivergence(ds, rootContext, rootDepth);
    }

    /**
     * Returns the context-weighted KL divergence to the {@code weight}-mixture of {@code distribution}.
     *
     * @param weight mixture weights
     * @param distribution mixture components
     * @return the weighted divergence
     */
    public double getKLDivergence(double[] weight, double[][][][] distribution) {
        final int rootContext = 0;
        final int rootDepth = 0;
        return root.getWeightedKLDivergence(weight, distribution, rootContext, rootDepth);
    }

    /** Accumulates sampled KL-divergence terms against the {@code weights}-mixture of {@code contrast} into {@code kls}. */
    public void drawKLDivergences(double[] kls, double[] weights, double[][][][] contrast, double samples) {
        final int rootContext = 0;
        final int rootDepth = 0;
        root.drawKLDivergences(kls, weights, contrast, samples, rootContext, rootDepth);
    }

    /** Passes the {@code weight}-mixture marginal of {@code distribution} to every leaf's {@code fill} method. */
    public void fill(double[] weight, double[][][][] distribution) {
        final int rootContext = 0;
        final int rootDepth = 0;
        root.setNewParameters(weight, distribution, rootContext, rootDepth);
    }

    /** Delegates to the root node's {@code copy} with the other tree's root (presumably copies parameter values). */
    public void copy(BNDiffSMParameterTree parameterTree) {
        TreeElement source = parameterTree.root;
        root.copy(source);
    }

    /** Delegates to the root node's {@code initializeRandomly} with the given equivalent sample size. */
    public void initializeRandomly(double ess) {
        root.initializeRandomly(ess);
    }

    /** Delegates to the root node's {@code computeGammaNorm}. */
    public Double computeGammaNorm() {
        return root.computeGammaNorm();
    }

    /** Delegates to the root node's {@code getProbFor}, starting at offset 0. */
    public double getProbFor(Sequence sequence) {
        final int offset = 0;
        return root.getProbFor(sequence, offset);
    }

    /** Delegates to the root node's {@code getNumberOfSamplingSteps}. */
    int getNumberOfSamplingSteps() {
        return root.getNumberOfSamplingSteps();
    }

    /** Delegates to the root node's {@code getNumberOfParameters}. */
    int getNumberOfParameters() {
        return root.getNumberOfParameters();
    }

    /** Delegates to the root node's {@code getParameterIndexesForSamplingStep}. */
    public int[] getParameterIndexesForSamplingStep(int step, int offset) {
        return root.getParameterIndexesForSamplingStep(step, offset);
    }

    /** Delegates to the root node's {@code emitSymbol}. */
    void emitSymbol(int[] content) {
        root.emitSymbol(content);
    }

    /** Delegates to the root node's {@code getMaximalMarkovOrder}, starting from order 0. */
    public byte getMaximalMarkovOrder() {
        final byte initialOrder = 0;
        return root.getMaximalMarkovOrder(initialOrder);
    }

    /** Delegates to the root node's {@code getMaximumScore}. */
    public double getMaximumScore() {
        return root.getMaximumScore();
    }

    /**
     * Renders the conditional probabilities of this position as an HTML fragment. It contains a
     * bold heading and a table header with an optional "context" column plus one column per
     * symbol at this position. The table body is appended by the root node's
     * {@code appendHtmlToBuffer}.
     *
     * @param nf the format used for the probability values
     * @return the HTML fragment
     */
    public String toHtml(NumberFormat nf) {
        StringBuffer all = new StringBuffer();
        // The heading must close its <strong> element; previously a second <strong> was opened instead.
        all.append("<p><strong>Probabilities at position " + this.pos + ":</strong><br/>");
        all.append("<table border=\"1\"><tr>");
        if (this.getNumberOfParents() > 0) {
            all.append("<th>context</th>");
        }
        double numberOfSymbols = this.alphabet.getAlphabetLengthAt(this.pos);
        for (int i = 0; i < numberOfSymbols; ++i) {
            all.append("<th>")
                    .append(((DiscreteAlphabet) this.alphabet.getAlphabetAt(this.pos)).getSymbolAt(i))
                    .append("</th>");
        }
        all.append("</tr>");
        this.root.appendHtmlToBuffer(all, "", nf);
        all.append("</table></p>");
        return all.toString();
    }

    public class TreeElement
    implements Storable,
    Cloneable {
        /** Context position tested at this node; -1 at leaves. */
        private int contextPos;
        /** One child per symbol of {@link #contextPos}; {@code null} at leaves. */
        private TreeElement[] children;
        /** Leaf parameters, indexed by the symbol of the modeled position; {@code null} at inner nodes. */
        private BNDiffSMParameter[] pars;
        /** Cached log normalizer computed by getLogZ at leaves; {@code null} when not (yet) valid. */
        private Double fullNormalizer;
        /** Per-parameter cache of getLogT results at leaves; {@code null} entries are not yet computed. */
        private Double[] symT;
        /** Depth of this node (0 for the root). */
        private int contNum;

        /**
         * Builds the subtree rooted at depth {@code contNum}.
         * Nodes above the last context position get one child per symbol of their context position.
         * Nodes at depth {@code contextPoss.length} become leaves, holding one (initially empty)
         * parameter slot per symbol of the modeled position.
         *
         * @param contNum the depth of this node
         * @param alphabet the alphabet container used to size children and parameter arrays
         */
        private TreeElement(int contNum, AlphabetContainer alphabet) {
            this.contNum = contNum;
            int[] positions = BNDiffSMParameterTree.this.contextPoss;
            if (contNum >= positions.length) {
                this.contextPos = -1;
                this.pars = new BNDiffSMParameter[(int) alphabet.getAlphabetLengthAt(BNDiffSMParameterTree.this.pos)];
                this.fullNormalizer = null;
                this.symT = new Double[this.pars.length];
                return;
            }
            this.contextPos = positions[contNum];
            double branching = alphabet.getAlphabetLengthAt(this.contextPos);
            this.children = new TreeElement[(int) branching];
            for (int symbol = 0; symbol < branching; ++symbol) {
                this.children[symbol] = new TreeElement(contNum + 1, alphabet);
            }
        }

        /**
         * Appends one text line per leaf of this subtree.
         * Each line lists the normalized probabilities P(X_pos = symbol | context), separated by tabs.
         *
         * @param all the output buffer
         * @param after textual description of the context fixed so far
         * @param nf the number format for the probabilities
         */
        private void appendToBuffer(StringBuffer all, String after, NumberFormat nf) {
            if (this.children == null) {
                int count = this.pars.length;
                double[] logScores = new double[count];
                for (int k = 0; k < count; ++k) {
                    logScores[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
                }
                double logNorm = Normalisation.getLogSum(logScores);
                for (int k = 0; k < count; ++k) {
                    double theta = Math.exp(logScores[k] - logNorm);
                    all.append("P(X_" + BNDiffSMParameterTree.this.pos + " = " + BNDiffSMParameterTree.this.alphabet.getSymbol(BNDiffSMParameterTree.this.pos, k) + " | " + after + "c)=" + nf.format(theta));
                    if (k < count - 1) {
                        all.append("\t");
                    }
                }
                all.append("\n");
                return;
            }
            for (int k = 0; k < this.children.length; ++k) {
                String extended = after + "X_" + this.contextPos + " = " + BNDiffSMParameterTree.this.alphabet.getSymbol(this.contextPos, k) + ", ";
                this.children[k].appendToBuffer(all, extended, nf);
            }
        }

        /**
         * Normalizes the leaf parameters of this subtree.
         * At each leaf, every value becomes {@code value + logZ - logSum(value + logZ)}. If the
         * last parameter is not free, all values are then shifted so that the last one becomes 0.
         */
        private void normalizeParameters() {
            if (this.children != null) {
                for (TreeElement child : this.children) {
                    child.normalizeParameters();
                }
                return;
            }
            int count = this.pars.length;
            double[] logScores = new double[count];
            for (int k = 0; k < count; ++k) {
                logScores[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
            }
            double logNorm = Normalisation.getLogSum(logScores);
            for (int k = 0; k < count; ++k) {
                this.pars[k].setValue(logScores[k] - logNorm);
            }
            BNDiffSMParameter reference = this.pars[count - 1];
            if (!reference.isFree()) {
                // the reference value only changes in the final iteration, so reading it once is equivalent
                double shift = reference.getValue();
                for (BNDiffSMParameter p : this.pars) {
                    p.setValue(p.getValue() - shift);
                }
            }
        }

        /**
         * Adds the normalized leaf probabilities of this subtree to {@code probs}.
         * At each leaf they are weighted by {@code getContextProbability()}.
         *
         * @param probs accumulator, modified in place
         * @throws Exception propagated from the computation
         */
        private void insertProbs(double[] probs) throws Exception {
            if (this.children == null) {
                double[] local = new double[this.pars.length];
                for (int k = 0; k < local.length; ++k) {
                    local[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
                }
                // presumably normalizes the log scores in place into probabilities — NOTE(review): confirm against Normalisation
                Normalisation.logSumNormalisation(local);
                // NOTE(review): iterates over probs but indexes local, so probs must not be longer than pars
                for (int k = 0; k < probs.length; ++k) {
                    probs[k] += this.getContextProbability() * local[k];
                }
                return;
            }
            for (TreeElement child : this.children) {
                child.insertProbs(probs);
            }
        }

        /**
         * Walks every context path of this subtree, recording (position, symbol) pairs in
         * {@code newContext}. At each leaf, {@code getLogT} is called on {@code trees[pos]} once per
         * parameter symbol. The results are discarded, so the calls serve only their caching side
         * effects.
         */
        private void startBackward(int[][] newContext, BNDiffSMParameterTree[] trees, int[][] order, int depth) throws RuntimeException {
            int[] slot = newContext[depth];
            if (this.children == null) {
                slot[0] = BNDiffSMParameterTree.this.pos;
                for (BNDiffSMParameter p : this.pars) {
                    slot[1] = p.symbol;
                    trees[BNDiffSMParameterTree.this.pos].getLogT(newContext, trees, order);
                }
                return;
            }
            slot[0] = this.contextPos;
            for (int s = 0; s < this.children.length; ++s) {
                slot[1] = s;
                this.children[s].startBackward(newContext, trees, order, depth + 1);
            }
        }

        /**
         * Returns {@code value + logT} for the leaf parameter selected by {@code context}.
         * The logT part is stored in the parameter via {@code setLogT}.
         * <ul>
         * <li>Inner nodes find the context entry for {@link #contextPos}, copy it into
         * {@code newContext[depth]} and descend into the matching child.</li>
         * <li>At the leaf, the entry for {@code pos} selects the parameter by symbol.</li>
         * <li>With no first parent, logT is 0 and the plain value is returned. This case is not
         * cached in {@link #symT}.</li>
         * <li>If the first parent tree has fewer context positions than this tree, logT is the
         * parent's logT plus the logZ of the parent's other first children.</li>
         * <li>Otherwise the same quantity is log-summed over the symbols of the parent's context
         * position with the smallest {@code order[position][1]}.</li>
         * </ul>
         *
         * @throws RuntimeException if the context lacks a required position or no parameter matches
         */
        private double getLogT(int[][] context, int[][] newContext, BNDiffSMParameterTree[] trees, int[][] order, int depth) throws RuntimeException {
            if (this.children != null) {
                for (int i = 0; i < context.length; ++i) {
                    if (context[i][0] != this.contextPos) continue;
                    newContext[depth][0] = context[i][0];
                    newContext[depth][1] = context[i][1];
                    return this.children[context[i][1]].getLogT(context, newContext, trees, order, depth + 1);
                }
                throw new RuntimeException("Correct context not found for depth " + depth + " at position " + BNDiffSMParameterTree.this.pos + ".");
            }
            for (int i = 0; i < context.length; ++i) {
                if (context[i][0] != BNDiffSMParameterTree.this.pos) continue;
                for (int j = 0; j < this.pars.length; ++j) {
                    if (this.pars[j].symbol != context[i][1]) continue;
                    if (this.symT[j] != null) {
                        return this.symT[j];
                    }
                    int fp = BNDiffSMParameterTree.this.firstParent;
                    if (fp == -1) {
                        this.pars[j].setLogT(0.0);
                        return this.pars[j].getValue();
                    }
                    if (trees[fp].contextPoss.length < BNDiffSMParameterTree.this.contextPoss.length) {
                        // parent context is fully determined by newContext: no marginalization needed
                        double temp = trees[fp].getLogT(newContext, trees, order);
                        int[] fcoffp = trees[fp].firstChildren;
                        for (int k = 0; k < fcoffp.length; ++k) {
                            if (fcoffp[k] == BNDiffSMParameterTree.this.pos) continue;
                            temp += trees[fcoffp[k]].getLogZ(newContext, trees);
                        }
                        this.pars[j].setLogT(temp);
                        this.symT[j] = this.pars[j].getValue() + temp;
                        return this.symT[j];
                    }
                    // pick the parent's context position with the smallest order[.][1] and marginalize over it
                    int[] cp = trees[fp].contextPoss;
                    int lowestOrder = Integer.MAX_VALUE;
                    int lowestOrderIndex = -1;
                    for (int k = 0; k < cp.length; ++k) {
                        if (order[cp[k]][1] >= lowestOrder) continue;
                        lowestOrder = order[cp[k]][1];
                        lowestOrderIndex = cp[k];
                    }
                    newContext[depth][0] = lowestOrderIndex;
                    int al = (int)BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(lowestOrderIndex);
                    double[] temp = new double[al];
                    // Plain int counter: the former byte-wrapping increment jumped from 127 to -128 and
                    // failed with ArrayIndexOutOfBoundsException for alphabets of 128 or more symbols.
                    for (int a = 0; a < al; ++a) {
                        newContext[depth][1] = a;
                        temp[a] = trees[fp].getLogT(newContext, trees, order);
                        int[] fcoffp = trees[fp].firstChildren;
                        for (int k = 0; k < fcoffp.length; ++k) {
                            if (fcoffp[k] == BNDiffSMParameterTree.this.pos) continue;
                            temp[a] += trees[fcoffp[k]].getLogZ(newContext, trees);
                        }
                    }
                    double temp2 = Normalisation.getLogSum(temp);
                    this.pars[j].setLogT(temp2);
                    this.symT[j] = this.pars[j].getValue() + temp2;
                    return this.symT[j];
                }
            }
            throw new RuntimeException("BNDiffSMParameter value not defined in context.");
        }

        /**
         * Returns the log normalizer of the leaf selected by {@code context}, caching it in {@link #fullNormalizer}.
         * For every leaf parameter, logZ is the sum of the first children's log normalizers given
         * (pos, symbol). The result is the log-sum of {@code value + logZ}.
         *
         * @throws RuntimeException if the context lacks a required position or the first children are undefined
         */
        private double getLogZ(int[][] context, int[][] newContext, BNDiffSMParameterTree[] trees, int depth) throws RuntimeException {
            if (this.children != null) {
                for (int[] entry : context) {
                    if (entry[0] == this.contextPos) {
                        newContext[depth][0] = entry[0];
                        newContext[depth][1] = entry[1];
                        return this.children[entry[1]].getLogZ(context, newContext, trees, depth + 1);
                    }
                }
                throw new RuntimeException("Correct context could not be found at position " + BNDiffSMParameterTree.this.pos + " and depth " + depth);
            }
            if (this.fullNormalizer == null) {
                double[] scores = new double[this.pars.length];
                for (int k = 0; k < this.pars.length; ++k) {
                    int[] fc = BNDiffSMParameterTree.this.firstChildren;
                    BNDiffSMParameter p = this.pars[k];
                    if (fc == null) {
                        throw new RuntimeException("First children of parameter " + p.getIndex() + " not defined.");
                    }
                    newContext[depth][0] = p.getPosition();
                    newContext[depth][1] = p.symbol;
                    double childSum = 0.0;
                    for (int child : fc) {
                        childSum += trees[child].getLogZ(newContext, trees);
                    }
                    p.setLogZ(childSum);
                    scores[k] = p.getValue() + childSum;
                }
                this.fullNormalizer = Normalisation.getLogSum(scores);
            }
            return this.fullNormalizer;
        }

        /** Clears cached normalizers and T-values in this subtree, including those held by the leaf parameters. */
        private void invalidateNormalizers() {
            if (this.children == null) {
                for (int k = 0; k < this.pars.length; ++k) {
                    this.pars[k].invalidateNormalizers();
                    this.symT[k] = null;
                }
            } else {
                for (TreeElement child : this.children) {
                    child.invalidateNormalizers();
                }
            }
            this.fullNormalizer = null;
        }

        /**
         * Turns this freshly constructed node into a deep copy of {@code original}.
         * It copies the context position, rebuilds and recursively copies all children, and clones
         * every leaf parameter. Cached normalizers are reset rather than copied.
         *
         * NOTE(review): which branches run is decided by the shape of {@code this} (as built by the
         * constructor), not by {@code original}; both are expected to have the same shape. Each child
         * is first built by the constructor (a whole subtree) and then rebuilt again by the recursive
         * call, so subtrees are constructed repeatedly.
         *
         * @param original the node to copy
         * @throws CloneNotSupportedException propagated from cloning a parameter
         */
        private void cloneRest(TreeElement original) throws CloneNotSupportedException {
            int i;
            this.contextPos = original.contextPos;
            if (this.children != null) {
                this.children = new TreeElement[this.children.length];
                for (i = 0; i < this.children.length; ++i) {
                    this.children[i] = new TreeElement(original.children[i].contNum, BNDiffSMParameterTree.this.alphabet);
                    this.children[i].cloneRest(original.children[i]);
                }
            } else {
                // redundant (already null), kept as in the original
                this.children = null;
            }
            if (this.pars != null) {
                this.pars = new BNDiffSMParameter[this.pars.length];
                for (i = 0; i < this.pars.length; ++i) {
                    this.pars[i] = original.pars[i].clone();
                }
                this.fullNormalizer = null;
                this.symT = new Double[this.pars.length];
            }
        }

        /**
         * Subtracts each leaf's last parameter value from all parameters of that leaf.
         * Any result that is NaN or infinite is replaced by 0.
         */
        private void divideByUnfree() {
            if (this.pars == null) {
                for (TreeElement child : this.children) {
                    child.divideByUnfree();
                }
                return;
            }
            double div = this.pars[this.pars.length - 1].getValue();
            for (BNDiffSMParameter p : this.pars) {
                double shifted = p.getValue() - div;
                boolean invalid = Double.isNaN(shifted) || Double.isInfinite(shifted);
                p.setValue(invalid ? 0.0 : shifted);
            }
        }

        /**
         * Appends the leaf parameters of this subtree to {@code list}, children in symbol order.
         *
         * @param list the target list, modified in place
         * @return {@code list}
         */
        private LinkedList<BNDiffSMParameter> linearizeParameters(LinkedList<BNDiffSMParameter> list) {
            if (this.children == null) {
                list.addAll(Arrays.asList(this.pars));
            } else {
                for (TreeElement child : this.children) {
                    child.linearizeParameters(list);
                }
            }
            return list;
        }

        /**
         * Reconstructs a node from the XML written by {@link #toXML()}.
         * The {@code children} entry carries the node's subtree. If parameters are present, the
         * caches are initialized empty.
         *
         * @param representation XML expected to be enclosed in {@code treeElement} tags
         * @throws NonParsableException propagated from {@code XMLParser} if parsing fails
         */
        public TreeElement(StringBuffer representation) throws NonParsableException {
            representation = XMLParser.extractForTag(representation, "treeElement");
            this.contNum = XMLParser.extractObjectForTags(representation, "contNum", Integer.TYPE);
            this.contextPos = XMLParser.extractObjectForTags(representation, "contextPos", Integer.TYPE);
            // The enclosing tree is passed along — presumably so the parser can instantiate the inner class; verify against XMLParser.
            this.children = XMLParser.extractObjectAndAttributesForTags(representation, "children", null, null, TreeElement[].class, BNDiffSMParameterTree.class, BNDiffSMParameterTree.this);
            this.pars = XMLParser.extractObjectForTags(representation, "pars", BNDiffSMParameter[].class);
            if (this.pars != null) {
                this.symT = new Double[this.pars.length];
                this.fullNormalizer = null;
            }
        }

        /**
         * Stores {@code par} at slot {@code symbol} of every leaf reached through {@code context}.
         * At depth {@code d} the node descends into each child listed in {@code context[d][1..]};
         * entry 0 of each row is skipped.
         */
        private void setParameterFor(int depth, int symbol, int[][] context, BNDiffSMParameter par) {
            if (this.children == null) {
                this.pars[symbol] = par;
                return;
            }
            int[] selected = context[depth];
            for (int k = 1; k < selected.length; ++k) {
                this.children[selected[k]].setParameterFor(depth + 1, symbol, context, par);
            }
        }

        /** Dumps this node's context position and then its children or parameters to standard output. */
        private void print() {
            System.out.println(this.contextPos);
            if (this.children == null) {
                for (BNDiffSMParameter p : this.pars) {
                    p.print();
                }
                return;
            }
            for (int k = 0; k < this.children.length; ++k) {
                System.out.println("child " + k + ":");
                this.children[k].print();
            }
        }

        /**
         * Sets each leaf parameter to the log of its relative count.
         * If a leaf's total count is not positive, all its parameters get {@code -log(n)} instead.
         */
        private void normalizePlugInParameters() {
            if (this.children != null) {
                for (TreeElement child : this.children) {
                    child.normalizePlugInParameters();
                }
                return;
            }
            double total = 0.0;
            for (BNDiffSMParameter p : this.pars) {
                total += p.getCounts();
            }
            if (total > 0.0) {
                for (BNDiffSMParameter p : this.pars) {
                    p.setValue(Math.log(p.getCounts() / total));
                }
            } else {
                double uniform = -Math.log(this.pars.length);
                for (BNDiffSMParameter p : this.pars) {
                    p.setValue(uniform);
                }
            }
        }

        /**
         * Follows the symbols of {@code seq} (shifted by {@code start}) down to a leaf and returns
         * the parameter for the symbol at {@code pos + start}.
         */
        private BNDiffSMParameter getParameterFor(Sequence seq, int start) {
            TreeElement node = this;
            while (node.children != null) {
                node = node.children[seq.discreteVal(node.contextPos + start)];
            }
            return node.pars[seq.discreteVal(BNDiffSMParameterTree.this.pos + start)];
        }

        /**
         * Serializes this node as XML. The entries {@code contNum}, {@code contextPos},
         * {@code children} and {@code pars} are enclosed in {@code treeElement} tags. The caches
         * ({@code fullNormalizer}, {@code symT}) are not written.
         *
         * @return the XML representation
         */
        @Override
        public StringBuffer toXML() {
            StringBuffer source = new StringBuffer();
            XMLParser.appendObjectWithTags(source, this.contNum, "contNum");
            XMLParser.appendObjectWithTags(source, this.contextPos, "contextPos");
            XMLParser.appendObjectWithTags(source, this.children, "children");
            XMLParser.appendObjectWithTags(source, this.pars, "pars");
            XMLParser.addTags(source, "treeElement");
            return source;
        }

        /**
         * At each leaf, draws {@code endIdx - startIdx} samples from a Dirichlet whose parameters are
         * {@code samples * contextProbability * ds[depth][context]}. For each draw, the weighted KL
         * term {@code sum p log(p / ds)} is added to the matching entry of {@code kls}. Below the
         * last level of {@code ds}, context and depth stop advancing.
         */
        private void drawKLDivergences(double weight, double[] kls, int startIdx, int endIdx, double[][][] ds, double samples, int context, int depth) {
            if (this.children != null) {
                boolean descend = depth < ds.length - 1;
                for (int k = 0; k < this.children.length; ++k) {
                    if (descend) {
                        this.children[k].drawKLDivergences(weight, kls, startIdx, endIdx, ds, samples, context + k * ds[depth].length, depth + 1);
                    } else {
                        this.children[k].drawKLDivergences(weight, kls, startIdx, endIdx, ds, samples, context, depth);
                    }
                }
                return;
            }
            double[] dist = ds[depth][context];
            double[] alpha = new double[this.pars.length];
            double contextProb = this.getContextProbability();
            for (int k = 0; k < alpha.length; ++k) {
                alpha[k] = samples * contextProb * dist[k];
            }
            DirichletMRGParams params = new DirichletMRGParams(alpha);
            double[] draw = new double[alpha.length];
            for (int s = startIdx; s < endIdx; ++s) {
                DirichletMRG.DEFAULT_INSTANCE.generate(draw, 0, draw.length, params);
                for (int k = 0; k < draw.length; ++k) {
                    if (draw[k] > 0.0) {
                        kls[s] += weight * contextProb * draw[k] * Math.log(draw[k] / dist[k]);
                    }
                }
            }
        }

        /**
         * Walks the subtree, encoding the context as {@code context + symbol * a^depth}.
         * Each leaf's {@code fill} receives the mixture marginal computed by {@code getMarginal}.
         */
        private void setNewParameters(double[] weight, double[][][][] distribution, int context, int depth) {
            int a = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            if (this.children == null) {
                double[] marginal = this.getMarginal(weight, distribution, context, depth);
                this.fill(marginal);
                return;
            }
            int stride = (int) Math.pow(a, depth);
            for (int k = 0; k < this.children.length; ++k) {
                this.children[k].setNewParameters(weight, distribution, context + k * stride, depth + 1);
            }
        }

        /**
         * Returns the sum over leaves of {@code contextProbability * KL(p || ds[depth][context])}.
         * Here {@code p} is the normalized leaf distribution. Below the last level of {@code ds},
         * context and depth stop advancing.
         */
        private double getWeightedKLDivergence(double[][][] ds, int context, int depth) {
            if (this.children != null) {
                boolean descend = depth < ds.length - 1;
                double total = 0.0;
                for (int k = 0; k < this.children.length; ++k) {
                    total += descend
                            ? this.children[k].getWeightedKLDivergence(ds, context + k * ds[depth].length, depth + 1)
                            : this.children[k].getWeightedKLDivergence(ds, context, depth);
                }
                return total;
            }
            int count = this.pars.length;
            double[] logScores = new double[count];
            for (int k = 0; k < count; ++k) {
                logScores[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
            }
            double logNorm = Normalisation.getLogSum(logScores);
            double divergence = 0.0;
            for (int k = 0; k < count; ++k) {
                double p = Math.exp(logScores[k] - logNorm);
                if (p > 0.0) {
                    divergence += p * Math.log(p / ds[depth][context][k]);
                }
            }
            return divergence * this.getContextProbability();
        }

        /**
         * Returns the sum over leaves of {@code contextProbability * KL(p || m)}.
         * Here {@code p} is the normalized leaf distribution and {@code m} is the
         * {@code weight}-mixture marginal of {@code distribution} computed by {@code getMarginal}.
         * Contexts are encoded as {@code context + symbol * a^depth}.
         */
        private double getWeightedKLDivergence(double[] weight, double[][][][] distribution, int context, int depth) {
            int a = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            if (this.children != null) {
                int stride = (int) Math.pow(a, depth);
                double kl = 0.0;
                for (int i = 0; i < this.children.length; ++i) {
                    kl += this.children[i].getWeightedKLDivergence(weight, distribution, context + i * stride, depth + 1);
                }
                return kl;
            }
            // Loop counter declared and initialized here; the decompiled version read an unassigned
            // `i` in this branch, which does not compile (no definite assignment).
            double[] norms = new double[this.pars.length];
            for (int i = 0; i < this.pars.length; ++i) {
                norms[i] = this.pars[i].getValue() + this.pars[i].getLogZ();
            }
            double logNorm = Normalisation.getLogSum(norms);
            double[] marginal = this.getMarginal(weight, distribution, context, depth);
            double kl = 0.0;
            for (int i = 0; i < this.pars.length; ++i) {
                double p = Math.exp(norms[i] - logNorm);
                if (p > 0.0) {
                    kl += p * Math.log(p / marginal[i]);
                }
            }
            return kl * this.getContextProbability();
        }

        /**
         * Returns the {@code weight}-weighted mixture of the component distributions for this leaf.
         * Components with fewer levels than {@code depth} use their deepest level, with the context
         * reduced modulo {@code a^d}.
         */
        private double[] getMarginal(double[] weight, double[][][][] distribution, int context, int depth) {
            int a = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            double[] marginal = new double[this.pars.length];
            for (int comp = 0; comp < weight.length; ++comp) {
                double[][][] component = distribution[comp];
                int level;
                int ctx;
                if (depth >= component.length) {
                    level = component.length - 1;
                    ctx = context % (int) Math.pow(a, level);
                } else {
                    level = depth;
                    ctx = context;
                }
                for (int j = 0; j < marginal.length; ++j) {
                    marginal[j] += weight[comp] * component[level][ctx][j];
                }
            }
            return marginal;
        }

        /**
         * At each leaf, fills every entry of {@code kls} with one sampled KL term (added to its
         * current value).
         * <ul>
         * <li>A Dirichlet draw is made per mixture component, with parameters
         * {@code samples * contextProb * weight[i] * distribution[i][d][c]}.</li>
         * <li>The draws are combined with {@code weight} into a mixture.</li>
         * <li>{@code contextProb * KL(mixture || marginal)} is added, where {@code marginal} is the
         * {@code weight}-mixture of the components.</li>
         * </ul>
         * Contexts are encoded as {@code context + symbol * a^depth}.
         */
        private void drawKLDivergences(double[] kls, double[] weight, double[][][][] distribution, double samples, int context, int depth) {
            int a = (int) BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos);
            if (this.children != null) {
                int stride = (int) Math.pow(a, depth);
                for (int i = 0; i < this.children.length; ++i) {
                    this.children[i].drawKLDivergences(kls, weight, distribution, samples, context + i * stride, depth + 1);
                }
                return;
            }
            int n = this.pars.length;
            double[] marginal = new double[n];
            double contextProb = this.getContextProbability();
            DirichletMRGParams[] p = new DirichletMRGParams[weight.length];
            for (int i = 0; i < weight.length; ++i) {
                int d;
                int c;
                if (depth < distribution[i].length) {
                    d = depth;
                    c = context;
                } else {
                    d = distribution[i].length - 1;
                    c = context % (int) Math.pow(a, d);
                }
                // Fresh array per component. The previous code reused one array for all components,
                // which is only correct if DirichletMRGParams copies its argument (not visible here).
                double[] weightedEss = new double[n];
                for (int j = 0; j < n; ++j) {
                    marginal[j] += weight[i] * distribution[i][d][c][j];
                    weightedEss[j] = samples * contextProb * weight[i] * distribution[i][d][c][j];
                }
                p[i] = new DirichletMRGParams(weightedEss);
            }
            double[] part = new double[n];
            double[] marginalDrawn = new double[n];
            for (int s = 0; s < kls.length; ++s) {
                Arrays.fill(marginalDrawn, 0.0);
                for (int i = 0; i < p.length; ++i) {
                    DirichletMRG.DEFAULT_INSTANCE.generate(part, 0, n, p[i]);
                    for (int j = 0; j < n; ++j) {
                        marginalDrawn[j] += weight[i] * part[j];
                    }
                }
                for (int j = 0; j < n; ++j) {
                    // 0 * log(0 / x) counts as 0. Test the mixture value itself; the old check looked only
                    // at the last drawn component and dropped non-zero mixture terms.
                    if (marginalDrawn[j] > 0.0) {
                        kls[s] += contextProb * marginalDrawn[j] * Math.log(marginalDrawn[j] / marginal[j]);
                    }
                }
            }
        }

        /**
         * Returns the probability mass associated with the context of this node.
         * <p>
         * For an inner node, this is the sum of its children's values.
         * <p>
         * For a leaf, {@code logNorm} is first computed as
         * {@code Normalisation.getLogSum(value_k + logZ_k)} over all parameters. The leaf then
         * returns {@code exp(Normalisation.getLogSum(value_k + logZ_k - logNorm + logT_k))}.
         * {@code Normalisation.getLogSum} is assumed to be log-sum-exp.
         *
         * @return the context probability of this subtree
         */
        private double getContextProbability() {
            if (this.children != null) {
                double total = 0.0;
                for (TreeElement child : this.children) {
                    total += child.getContextProbability();
                }
                return total;
            }
            final int n = this.pars.length;
            double[] logUnnormalised = new double[n];
            for (int k = 0; k < n; ++k) {
                logUnnormalised[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
            }
            final double logNorm = Normalisation.getLogSum(logUnnormalised);
            double[] logTerms = new double[n];
            for (int k = 0; k < n; ++k) {
                logTerms[k] = this.pars[k].getValue() + this.pars[k].getLogZ() - logNorm + this.pars[k].getLogT();
            }
            return Math.exp(Normalisation.getLogSum(logTerms));
        }

        /**
         * Entry point for {@link #fill(double[][], int, int, int)}.
         * <p>
         * Starts at context index 0 with an index multiplier of 1.
         *
         * @param fillEmptyWith one distribution per context index
         * @param contextLength number of tree levels to descend before filling
         */
        private void findAndFill(double[][] fillEmptyWith, int contextLength) {
            final int startContext = 0;
            final int startPower = 1;
            this.fill(fillEmptyWith, startContext, startPower, contextLength);
        }

        /**
         * Descends {@code contextLength} levels and then fills the reached subtree.
         * <p>
         * While descending, child {@code i} receives {@code context + i * power}. The multiplier
         * grows by {@code fillEmptyWith[0].length} per level. Once {@code contextLength} reaches 0,
         * {@code fillEmptyWith[context]} is written into the subtree.
         * <p>
         * NOTE(review): the index arithmetic assumes every context position has an alphabet of
         * size {@code fillEmptyWith[0].length}; confirm for heterogeneous alphabets.
         *
         * @param fillEmptyWith one distribution per context index
         * @param context       context index accumulated so far
         * @param power         multiplier for the next level's child index
         * @param contextLength remaining levels to descend
         */
        private void fill(double[][] fillEmptyWith, int context, int power, int contextLength) {
            if (contextLength <= 0) {
                this.fill(fillEmptyWith[context]);
                return;
            }
            final int remaining = contextLength - 1;
            for (int i = 0; i < this.children.length; ++i) {
                int childContext = context + i * power;
                int nextPower = power * fillEmptyWith[0].length;
                this.children[i].fill(fillEmptyWith, childContext, nextPower, remaining);
            }
        }

        /**
         * Writes the distribution {@code distr} into every leaf below this node.
         * <p>
         * If the last parameter of a leaf is free, every parameter is set to {@code log(distr[i])}.
         * Otherwise parameters are expressed relative to the last symbol, i.e.
         * {@code log(distr[i]) - log(distr[last])}, and the last parameter is not modified.
         *
         * @param distr one value per symbol; must have exactly as many entries as each leaf has
         *              parameters
         * @throws IndexOutOfBoundsException if {@code distr.length} differs from the number of
         *                                   parameters of a leaf
         */
        private void fill(double[] distr) {
            if (this.children != null) {
                for (int i = 0; i < this.children.length; ++i) {
                    this.children[i].fill(distr);
                }
                return;
            }
            // validate for both parametrisations; previously only the fully-free branch checked,
            // so surplus entries were silently ignored in the relative branch
            if (distr.length != this.pars.length) {
                throw new IndexOutOfBoundsException("Different number of values (" + distr.length + ") than free parameters (" + this.pars.length + ").");
            }
            int last = this.pars.length - 1;
            if (this.pars[last].isFree()) {
                for (int i = 0; i < this.pars.length; ++i) {
                    this.pars[i].setValue(Math.log(distr[i]));
                }
            } else {
                double logLast = Math.log(distr[last]);
                for (int i = 0; i < last; ++i) {
                    this.pars[i].setValue(Math.log(distr[i]) - logLast);
                }
            }
        }

        /**
         * Copies the parameter values of {@code node} into this subtree. Four cases apply.
         * <ul>
         * <li>Both nodes have children: children are copied pairwise.</li>
         * <li>Only this node has children: {@code node} is copied into every child.</li>
         * <li>Both are leaves: parameter values are copied one to one.</li>
         * <li>This is a leaf and {@code node} is not: each parameter is set to
         * {@code node.getLogSum(i)}, then all are shifted by
         * {@code Normalisation.getLogSum} of those values.</li>
         * </ul>
         *
         * @param node the source subtree
         * @throws IndexOutOfBoundsException if the numbers of children or of parameters differ
         */
        private void copy(TreeElement node) {
            if (this.children == null) {
                if (node.pars == null) {
                    final int n = this.pars.length;
                    double[] logSums = new double[n];
                    for (int k = 0; k < n; ++k) {
                        logSums[k] = node.getLogSum(k);
                        this.pars[k].setValue(logSums[k]);
                    }
                    final double logNorm = Normalisation.getLogSum(logSums);
                    for (int k = 0; k < n; ++k) {
                        this.pars[k].setValue(this.pars[k].getValue() - logNorm);
                    }
                    return;
                }
                if (this.pars.length != node.pars.length) {
                    throw new IndexOutOfBoundsException("Different number of parameters.");
                }
                for (int k = 0; k < this.pars.length; ++k) {
                    this.pars[k].setValue(node.pars[k].getValue());
                }
                return;
            }
            if (node.children == null) {
                for (TreeElement child : this.children) {
                    child.copy(node);
                }
                return;
            }
            if (this.children.length != node.children.length) {
                throw new IndexOutOfBoundsException("Different number of children.");
            }
            for (int k = 0; k < this.children.length; ++k) {
                this.children[k].copy(node.children[k]);
            }
        }

        /**
         * Returns, for a leaf, {@code value_idx + logT_idx}.
         * <p>
         * For an inner node, it returns {@code Normalisation.getLogSum} over the children's values.
         * Assuming log-sum-exp, this is the log of the summed {@code exp(value_idx + logT_idx)}
         * over all leaves below.
         *
         * @param idx parameter (symbol) index
         * @return the aggregated log value for {@code idx}
         */
        private double getLogSum(int idx) {
            if (this.children == null) {
                return this.getLogSumForLeaf(idx);
            }
            double[] childLogSums = new double[this.children.length];
            int k = 0;
            for (TreeElement child : this.children) {
                childLogSums[k++] = child.getLogSum(idx);
            }
            return Normalisation.getLogSum(childLogSums);
        }

        /**
         * Returns {@code value + logT} of parameter {@code idx} of this leaf.
         *
         * @param idx parameter (symbol) index
         * @return the leaf's log contribution for {@code idx}
         */
        private double getLogSumForLeaf(int idx) {
            final double logValue = this.pars[idx].getValue();
            final double logT = this.pars[idx].getLogT();
            return logValue + logT;
        }

        /**
         * Randomly initialises the parameters of this subtree.
         * <p>
         * Inner nodes pass {@code ess} divided by the alphabet size at {@code contextPos} to each
         * child.
         * <p>
         * At a leaf:
         * <ul>
         * <li>a non-positive {@code ess} is replaced by the alphabet size at the first parameter's
         * position;</li>
         * <li>counts are drawn from a Dirichlet with hyperparameters {@code ess / alphabetSize};</li>
         * <li>the plug-in parameters are normalised;</li>
         * <li>if the last parameter is not free, {@code divideByUnfree()} is applied.</li>
         * </ul>
         *
         * @param ess equivalent sample size for this subtree
         */
        private void initializeRandomly(double ess) {
            if (this.pars == null) {
                for (TreeElement child : this.children) {
                    child.initializeRandomly(ess / BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.contextPos));
                }
                return;
            }
            double alpha = ess;
            if (alpha <= 0.0) {
                alpha = BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.pars[0].getPosition());
            }
            final int n = this.pars.length;
            double[] hyperParams = new double[n];
            for (int k = 0; k < n; ++k) {
                hyperParams[k] = alpha / BNDiffSMParameterTree.this.alphabet.getAlphabetLengthAt(this.pars[k].getPosition());
            }
            double[] drawn = DirichletMRG.DEFAULT_INSTANCE.generate(n, new DirichletMRGParams(hyperParams));
            for (int k = 0; k < n; ++k) {
                this.pars[k].count = drawn[k];
            }
            this.normalizePlugInParameters();
            boolean lastIsFixed = !this.pars[n - 1].isFree();
            if (lastIsFixed) {
                this.divideByUnfree();
            }
        }

        /**
         * Sums, over all leaves below this node, the Dirichlet log normalising term
         * {@code logGamma(sum_k alpha_k) - sum_k logGamma(alpha_k)}.
         * <p>
         * Here {@code alpha_k} are the parameters' pseudo-counts, so each leaf contributes
         * {@code -log B(alpha)}.
         *
         * @return the accumulated log normalising term
         */
        private double computeGammaNorm() {
            double result = 0.0;
            if (this.children != null) {
                for (TreeElement child : this.children) {
                    result += child.computeGammaNorm();
                }
                return result;
            }
            double alphaSum = 0.0;
            for (int k = 0; k < this.pars.length; ++k) {
                final double alpha = this.pars[k].getPseudoCount();
                result -= Gamma.logOfGamma(alpha);
                alphaSum += alpha;
            }
            return result + Gamma.logOfGamma(alphaSum);
        }

        /**
         * Evaluates this subtree for {@code sequence}, whose last symbol is the emitted one.
         * <p>
         * While {@code offset} is before the last position, the child selected by the symbol at
         * {@code offset} is followed. Past that point, all children are summed (marginalised).
         * <p>
         * A leaf returns {@code getContextProbability()} times {@code getExpValue()} of the
         * parameter for the sequence's last symbol.
         * NOTE(review): whether {@code getExpValue()} is normalised is not visible here.
         *
         * @param sequence the context symbols followed by the emitted symbol
         * @param offset   current position in {@code sequence}
         * @return the (marginalised) value for {@code sequence}
         */
        private double getProbFor(Sequence sequence, int offset) {
            if (this.children == null) {
                double contextProb = this.getContextProbability();
                int lastSymbol = sequence.discreteVal(sequence.getLength() - 1);
                return contextProb * this.pars[lastSymbol].getExpValue();
            }
            int next = offset + 1;
            if (offset >= sequence.getLength() - 1) {
                double total = 0.0;
                for (TreeElement child : this.children) {
                    total += child.getProbFor(sequence, next);
                }
                return total;
            }
            return this.children[sequence.discreteVal(offset)].getProbFor(sequence, next);
        }

        /**
         * Counts the free parameters in all leaves below this node.
         *
         * @return the number of parameters whose {@code isFree()} is {@code true}
         */
        private int getNumberOfParameters() {
            int count = 0;
            if (this.children == null) {
                for (int k = 0; k < this.pars.length; ++k) {
                    if (this.pars[k].isFree()) {
                        ++count;
                    }
                }
            } else {
                for (TreeElement child : this.children) {
                    count += child.getNumberOfParameters();
                }
            }
            return count;
        }

        /**
         * Returns the number of sampling steps in this subtree.
         * <p>
         * Each leaf counts as exactly one step.
         *
         * @return the number of leaves below (and including) this node
         */
        private int getNumberOfSamplingSteps() {
            if (this.children == null) {
                return 1;
            }
            int steps = 0;
            for (TreeElement child : this.children) {
                steps += child.getNumberOfSamplingSteps();
            }
            return steps;
        }

        /**
         * Locates the leaf for sampling step {@code step} and returns its parameter indexes.
         * <p>
         * Children are scanned in order. Each skipped child reduces {@code step} by its number of
         * sampling steps and advances {@code offset} by its number of free parameters.
         * <p>
         * The reached leaf returns {@code pars.length - 1} consecutive indexes starting at the
         * accumulated offset.
         *
         * @param step   sampling step relative to this subtree
         * @param offset global index of this subtree's first parameter
         * @return the parameter indexes, or {@code null} if {@code step} is out of range
         */
        private int[] getParameterIndexesForSamplingStep(int step, int offset) {
            if (this.children == null) {
                int[] indexes = new int[this.pars.length - 1];
                for (int k = 0; k < indexes.length; ++k) {
                    indexes[k] = offset + k;
                }
                return indexes;
            }
            int remainingSteps = step;
            int currentOffset = offset;
            for (TreeElement child : this.children) {
                int childSteps = child.getNumberOfSamplingSteps();
                if (remainingSteps < childSteps) {
                    return child.getParameterIndexesForSamplingStep(remainingSteps, currentOffset);
                }
                remainingSteps -= childSteps;
                currentOffset += child.getNumberOfParameters();
            }
            return null;
        }

        /**
         * Samples a symbol for position {@code pos} given the context already stored in
         * {@code content}, and writes it to {@code content[pos]}.
         * <p>
         * Inner nodes follow the child selected by {@code content[contextPos]}. At the leaf, the
         * log values {@code value + logZ} are passed to {@code Normalisation.logSumNormalisation}.
         * That call is expected to turn them into normalised probabilities in place, which the
         * sampling loop relies on. A uniform draw then picks the symbol.
         *
         * @param content symbol indexes per position; {@code content[pos]} is overwritten
         */
        private void emitSymbol(int[] content) {
            if (this.children != null) {
                this.children[content[this.contextPos]].emitSymbol(content);
                return;
            }
            double[] temp = new double[this.pars.length];
            for (int i = 0; i < temp.length; ++i) {
                temp[i] = this.pars[i].getValue() + this.pars[i].getLogZ();
            }
            Normalisation.logSumNormalisation(temp);
            double v = r.nextDouble();
            int lastPossible = -1;
            for (int i = 0; i < temp.length; ++i) {
                if (v - temp[i] <= 0.0) {
                    content[BNDiffSMParameterTree.this.pos] = i;
                    return;
                }
                v -= temp[i];
                if (temp[i] > 0.0) {
                    lastPossible = i;
                }
            }
            // Rounding can make the probabilities sum to slightly less than 1; without this
            // fallback content[pos] would silently keep a stale value.
            if (lastPossible >= 0) {
                content[BNDiffSMParameterTree.this.pos] = lastPossible;
            }
        }

        /**
         * Returns {@code i} plus the number of levels from this node down to a leaf.
         * <p>
         * The descent always follows the first child. The arithmetic is done in {@code byte},
         * so it wraps on overflow.
         *
         * @param i order accumulated so far
         * @return the accumulated order at the reached leaf
         */
        private byte getMaximalMarkovOrder(byte i) {
            byte order = i;
            TreeElement node = this;
            while (node.children != null) {
                order = (byte)(order + 1);
                node = node.children[0];
            }
            return order;
        }

        /**
         * Returns the largest raw parameter value of this leaf.
         * <p>
         * The result is {@code Double.NEGATIVE_INFINITY} if there are no parameters. NaN values
         * never win the comparison.
         *
         * @return the maximum of {@code getValue()} over this leaf's parameters
         * @throws RuntimeException if called on an inner node (not implemented)
         */
        private double getMaximumScore() {
            if (this.children != null) {
                throw new RuntimeException("Not implemented");
            }
            double best = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < this.pars.length; ++k) {
                final double candidate = this.pars[k].getValue();
                if (candidate > best) {
                    best = candidate;
                }
            }
            return best;
        }

        /**
         * Appends one HTML table row per leaf to {@code all}.
         * <p>
         * Inner nodes extend the context label {@code after} with {@code "X_<contextPos> = <symbol>"}
         * and recurse.
         * <p>
         * A leaf writes a row containing:
         * <ul>
         * <li>the label cell, but only if the enclosing tree has parents;</li>
         * <li>one cell per symbol holding the normalised probability
         * {@code exp(value + logZ - logNorm)}, formatted with {@code nf}.</li>
         * </ul>
         * Symbols are not HTML-escaped.
         *
         * @param all   buffer receiving the HTML
         * @param after context label accumulated so far
         * @param nf    number format for the probabilities
         */
        private void appendHtmlToBuffer(StringBuffer all, String after, NumberFormat nf) {
            if (this.children != null) {
                for (int i = 0; i < this.children.length; ++i) {
                    String separator = after.length() == 0 ? "" : ", ";
                    String label = after + separator + "X_" + this.contextPos + " = " + BNDiffSMParameterTree.this.alphabet.getSymbol(this.contextPos, i);
                    this.children[i].appendHtmlToBuffer(all, label, nf);
                }
                return;
            }
            final int n = this.pars.length;
            double[] logUnnormalised = new double[n];
            for (int k = 0; k < n; ++k) {
                logUnnormalised[k] = this.pars[k].getValue() + this.pars[k].getLogZ();
            }
            final double logNorm = Normalisation.getLogSum(logUnnormalised);
            all.append("<tr>");
            if (BNDiffSMParameterTree.this.getNumberOfParents() > 0) {
                all.append("<td>").append(after).append("</td>");
            }
            for (int k = 0; k < n; ++k) {
                double theta = Math.exp(this.pars[k].getValue() + this.pars[k].getLogZ() - logNorm);
                all.append("<td>").append(nf.format(theta)).append("</td>");
            }
            all.append("</tr>");
        }
    }
}

