/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm;

import de.jstacs.Storable;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.State;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MultiThreadedTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.HigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.Transition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.SafeOutputStream;
import de.jstacs.utils.ToolBox;
import java.io.OutputStream;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;

/**
 * Common base class for hidden Markov models.
 * <p>
 * It stores the state names, the per-state emission indices and strand flags, the emissions, the
 * transition and the training parameters. On top of these it provides XML (de)serialisation,
 * cloning, a Graphviz export, state-posterior and Viterbi convenience methods.
 * <p>
 * Sub-classes implement the dynamic-programming hooks: {@link #fillFwdMatrix},
 * {@link #fillBwdMatrix}, {@link #fillLogStatePosteriorMatrix} and
 * {@link #getViterbiPathFor(int, int, Sequence)}. They also implement {@link #createStates()} and
 * {@link #createHelperVariables()}.
 */
public abstract class AbstractHMM
extends AbstractTrainableStatisticalModel
implements Cloneable,
Storable {
    /** The states of the model; (re-)built by {@link #createStates()}. */
    protected State[] states;
    /** The state names (unique, checked in the constructor), indexed like {@link #states}. */
    protected String[] name;
    /** For each state the index of its emission in {@link #emission}; defaults to the identity mapping. */
    protected int[] emissionIdx;
    /** For each state a strand flag (serialised under the XML tag {@code strand}); defaults to {@code true}. */
    protected boolean[] forward;
    /** The emissions; several states may refer to the same emission via {@link #emissionIdx}. */
    protected Emission[] emission;
    /** The transition of the model. */
    protected Transition transition;
    /** Forward dynamic-programming matrix; allocated and reset to -infinity by {@link #provideMatrix(int, int)}. */
    protected double[][] fwdMatrix;
    /** Backward dynamic-programming matrix; {@link #logProb(int, int, Sequence)} returns its entry [0][0]. */
    protected double[][] bwdMatrix;
    /** The training parameters (a private clone of the ones passed to the constructor). */
    protected HMMTrainingParameterSet trainingParameter;
    /** The stream used for output; set via {@link #setOutputStream(OutputStream)}. */
    protected SafeOutputStream sostream;
    /**
     * For each state whether it is regarded as final: the absorbing states of the transition, or,
     * if there is none, all non-silent states (see {@link #determineFinalStates()}).
     */
    protected boolean[] finalState;
    /** Number of threads: taken from {@link MultiThreadedTrainingParameterSet}s, otherwise 1. */
    protected int threads;
    /** Name of the artificial start node in the Graphviz representation. */
    public static final String START_NODE = "START";

    /**
     * Creates a new HMM.
     *
     * @param trainingParameterSet the training parameters; must be completely set, a clone is stored
     * @param name the state names, which must be unique
     * @param emissionIdx for each state the index of its emission, or {@code null} for the identity mapping
     * @param forward for each state the strand flag, or {@code null} to use {@code true} for every state
     * @param emission the emissions (at most one per state); they are cloned
     * @throws CloneNotSupportedException if the parameters or the emissions cannot be cloned
     * @throws WrongAlphabetException if the emissions use inconsistent alphabets
     * @throws IllegalArgumentException if the parameters are not set, the names are not unique, or
     *         the array lengths do not fit the number of states
     */
    protected AbstractHMM(HMMTrainingParameterSet trainingParameterSet, String[] name, int[] emissionIdx, boolean[] forward, Emission[] emission) throws CloneNotSupportedException, WrongAlphabetException {
        super(AbstractHMM.getAlphabetContainer(emission), 0);
        if (!trainingParameterSet.hasDefaultOrIsSet()) {
            throw new IllegalArgumentException("Please check the training parameters.");
        }
        this.trainingParameter = (HMMTrainingParameterSet)trainingParameterSet.clone();
        this.setThreads();
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
        int n = name.length;
        this.name = new String[n];
        HashSet<String> seen = new HashSet<String>();
        for (int i = 0; i < n; ++i) {
            // HashSet.add returns false if the name has been seen before
            if (!seen.add(name[i])) {
                throw new IllegalArgumentException("The state names should be unique. Please check: " + name[i]);
            }
            this.name[i] = name[i];
        }
        if (emissionIdx == null) {
            this.emissionIdx = new int[n];
            for (int e = 0; e < n; ++e) {
                this.emissionIdx[e] = e;
            }
        } else {
            if (n != emissionIdx.length) {
                throw new IllegalArgumentException("The number of emission indices (" + emissionIdx.length + ") differs from the number of states (" + n + ").");
            }
            this.emissionIdx = (int[])emissionIdx.clone();
        }
        if (forward == null) {
            this.forward = new boolean[n];
            Arrays.fill(this.forward, true);
        } else {
            if (n != forward.length) {
                throw new IllegalArgumentException("The number of strand flags (" + forward.length + ") differs from the number of states (" + n + ").");
            }
            this.forward = (boolean[])forward.clone();
        }
        if (emission.length > n) {
            throw new IllegalArgumentException("The number of emissions (" + emission.length + ") exceeds the number of states (" + n + ").");
        }
        this.emission = (Emission[])ArrayHandler.clone((Cloneable[])emission);
    }

    /** Derives {@link #threads} from {@link #trainingParameter}. */
    private void setThreads() {
        this.threads = this.trainingParameter instanceof MultiThreadedTrainingParameterSet ? ((MultiThreadedTrainingParameterSet)this.trainingParameter).getNumberOfThreads() : 1;
    }

    /**
     * Returns the common alphabet container of the given emissions.
     * <p>
     * Emissions returning a {@code null} container are ignored. All other containers must be
     * consistent with the first non-{@code null} one.
     *
     * @param e the emissions
     * @return the first non-{@code null} alphabet container, which must be simple
     * @throws WrongAlphabetException if two emissions use inconsistent alphabet containers
     * @throws IllegalArgumentException if no emission is given, no emission provides an alphabet
     *         container, or the container is not simple
     */
    private static AlphabetContainer getAlphabetContainer(Emission ... e) throws WrongAlphabetException {
        if (e == null || e.length == 0) {
            throw new IllegalArgumentException("At least one emission has to be given.");
        }
        AlphabetContainer con = null;
        int i = 0;
        while (con == null && i < e.length) {
            con = e[i++].getAlphabetContainer();
        }
        if (con == null) {
            throw new IllegalArgumentException("At least one emission has to provide an AlphabetContainer.");
        }
        while (i < e.length) {
            AlphabetContainer current = e[i++].getAlphabetContainer();
            if (current != null && !current.checkConsistency(con)) {
                throw new WrongAlphabetException("All emission should use the same AlphabetContainer.");
            }
        }
        if (!con.isSimple()) {
            throw new IllegalArgumentException("The AlphabetContainer has to be simple.");
        }
        return con;
    }

    /**
     * Creates an HMM from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    protected AbstractHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    /**
     * Creates {@link #transition} from the given transition elements and the silent flags of
     * {@link #states}.
     * <p>
     * A {@link HigherOrderTransition} is used if every element is a {@link TransitionElement}.
     * Otherwise a {@link BasicHigherOrderTransition} is used.
     *
     * @param te the transition elements
     * @throws Exception if the transition cannot be created
     */
    protected void initTransition(BasicHigherOrderTransition.AbstractTransitionElement ... te) throws Exception {
        boolean[] isSilent = new boolean[this.states.length];
        for (int i = 0; i < this.states.length; ++i) {
            isSilent[i] = this.states[i].isSilent();
        }
        if (te instanceof TransitionElement[]) {
            this.transition = new HigherOrderTransition(isSilent, (TransitionElement[])te);
        } else {
            // copy the leading TransitionElements; t ends at te.length iff all elements qualify
            int t = 0;
            TransitionElement[] help = new TransitionElement[te.length];
            while (t < help.length && te[t] instanceof TransitionElement) {
                help[t] = (TransitionElement)te[t];
                ++t;
            }
            this.transition = t == te.length ? new HigherOrderTransition(isSilent, help) : new BasicHigherOrderTransition(isSilent, te);
        }
    }

    /**
     * Returns the XML tag that encloses the representation of this model.
     *
     * @return the XML tag
     */
    protected abstract String getXMLTag();

    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.trainingParameter, "trainingParameter");
        XMLParser.appendObjectWithTags(xml, this.transition, "transition");
        XMLParser.appendObjectWithTags(xml, this.name, "name");
        XMLParser.appendObjectWithTags(xml, this.emissionIdx, "emissionIdx");
        XMLParser.appendObjectWithTags(xml, this.forward, "strand");
        XMLParser.appendObjectWithTags(xml, this.emission, "emission");
        this.appendFurtherInformation(xml);
        XMLParser.addTags(xml, this.getXMLTag());
        return xml;
    }

    /**
     * Restores the model from XML produced by {@link #toXML()}.
     * <p>
     * Sets the length to 0, re-derives the alphabet container from the emissions, and rebuilds
     * the states and the final-state flags.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        this.length = 0;
        xml = XMLParser.extractForTag(xml, this.getXMLTag());
        this.trainingParameter = (HMMTrainingParameterSet)XMLParser.extractObjectForTags(xml, "trainingParameter");
        this.setThreads();
        this.transition = (Transition)XMLParser.extractObjectForTags(xml, "transition");
        this.name = XMLParser.extractObjectForTags(xml, "name", String[].class);
        this.emissionIdx = XMLParser.extractObjectForTags(xml, "emissionIdx", int[].class);
        this.forward = XMLParser.extractObjectForTags(xml, "strand", boolean[].class);
        this.emission = XMLParser.extractObjectForTags(xml, "emission", Emission[].class);
        this.extractFurtherInformation(xml);
        try {
            this.alphabets = AbstractHMM.getAlphabetContainer(this.emission);
        }
        catch (WrongAlphabetException e) {
            throw new NonParsableException(e.getMessage());
        }
        this.createStates();
        this.determineFinalStates();
    }

    /**
     * Appends sub-class specific information to the XML representation.
     *
     * @param xml the buffer to append to
     */
    protected abstract void appendFurtherInformation(StringBuffer xml);

    /**
     * Extracts the sub-class specific information written by {@link #appendFurtherInformation(StringBuffer)}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the information cannot be parsed
     */
    protected abstract void extractFurtherInformation(StringBuffer xml) throws NonParsableException;

    /**
     * Returns a deep copy of this model.
     * <p>
     * Matrices and final-state flags that have not been allocated yet stay {@code null}.
     */
    @Override
    public AbstractHMM clone() throws CloneNotSupportedException {
        AbstractHMM clone = (AbstractHMM)super.clone();
        clone.name = (String[])this.name.clone();
        clone.emissionIdx = (int[])this.emissionIdx.clone();
        clone.forward = (boolean[])this.forward.clone();
        clone.emission = (Emission[])ArrayHandler.clone((Cloneable[])this.emission);
        clone.transition = this.transition.clone();
        clone.fwdMatrix = this.fwdMatrix == null ? null : (double[][])ArrayHandler.clone((Cloneable[])this.fwdMatrix);
        clone.bwdMatrix = this.bwdMatrix == null ? null : (double[][])ArrayHandler.clone((Cloneable[])this.bwdMatrix);
        clone.trainingParameter = (HMMTrainingParameterSet)this.trainingParameter.clone();
        clone.finalState = this.finalState == null ? null : (boolean[])this.finalState.clone();
        clone.createStates();
        // the clone gets the default stream unless this instance's stream discards all output
        clone.setOutputStream(this.sostream.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
        return clone;
    }

    /** (Re-)creates {@link #states} from the current emissions and state information. */
    protected abstract void createStates();

    /**
     * Fills {@link #fwdMatrix} for the given part of the sequence.
     *
     * @param startPos the first position
     * @param endPos the last position (inclusive)
     * @param seq the sequence
     * @throws Exception if the matrix cannot be filled
     */
    protected abstract void fillFwdMatrix(int startPos, int endPos, Sequence seq) throws Exception;

    /**
     * Fills {@link #bwdMatrix} for the given part of the sequence.
     * <p>
     * {@link #logProb(int, int, Sequence)} reads the result from entry [0][0].
     *
     * @param startPos the first position
     * @param endPos the last position (inclusive)
     * @param seq the sequence
     * @throws Exception if the matrix cannot be filled
     */
    protected abstract void fillBwdMatrix(int startPos, int endPos, Sequence seq) throws Exception;

    /**
     * Returns the number of threads used by this model.
     *
     * @return the number of threads
     */
    public int getNumberOfThreads() {
        return this.threads;
    }

    /**
     * Returns a Graphviz representation of this model without state frequencies and without ranking.
     *
     * @param nf the format for numbers
     * @return the Graphviz (dot) source
     */
    public String getGraphvizRepresentation(NumberFormat nf) {
        return this.getGraphvizRepresentation(nf, null, null, false);
    }

    /**
     * Returns a Graphviz representation of this model without state frequencies.
     *
     * @param nf the format for numbers
     * @param sameTypeSameRank if {@code true}, states whose names share the first character are
     *        drawn at the same rank
     * @return the Graphviz (dot) source
     */
    public String getGraphvizRepresentation(NumberFormat nf, boolean sameTypeSameRank) {
        return this.getGraphvizRepresentation(nf, null, null, sameTypeSameRank);
    }

    /**
     * Returns a Graphviz representation of this model.
     *
     * @param nf the format for numbers
     * @param data optional data used to compute state frequencies, may be {@code null}
     * @param weight optional weights for {@code data}, may be {@code null}
     * @param sameTypeSameRank if {@code true}, states whose names share the first character are
     *        drawn at the same rank; if {@code false}, no ranking is applied
     * @return the Graphviz (dot) source
     */
    public String getGraphvizRepresentation(NumberFormat nf, DataSet data, double[] weight, boolean sameTypeSameRank) {
        HashMap<String, String> rankPatterns = null;
        if (sameTypeSameRank) {
            rankPatterns = new HashMap<String, String>();
            for (int i = 0; i < this.name.length; ++i) {
                if (this.name[i].length() == 0) {
                    continue; // an empty name has no first character to group by
                }
                // \Q...\E makes regex meta characters at the start of a name match literally
                rankPatterns.put("\\Q" + this.name[i].charAt(0) + "\\E.*", "same");
            }
        }
        return this.getGraphvizRepresentation(nf, data, weight, rankPatterns);
    }

    /**
     * Returns a Graphviz representation of this model.
     *
     * @param nf the format for numbers
     * @param data optional data used to compute state frequencies, may be {@code null}
     * @param weight optional weights for {@code data}, may be {@code null}
     * @param rankPatterns maps regular expressions over state names to Graphviz rank values, or
     *        {@code null} for no ranking. Each state is grouped under the first matching pattern
     *        in the map's iteration order. A non-{@code null} map switches the layout to top-bottom.
     * @return the Graphviz (dot) source
     */
    public String getGraphvizRepresentation(NumberFormat nf, DataSet data, double[] weight, HashMap<String, String> rankPatterns) {
        double[] freq;
        double maxFreq;
        if (data != null) {
            try {
                freq = this.getStateFreq(data, weight);
                maxFreq = ToolBox.max(freq);
            }
            catch (Exception e) {
                // best effort: still draw the graph, with all frequencies 0
                e.printStackTrace();
                freq = new double[this.states.length];
                maxFreq = 0.0;
            }
        } else {
            // NOTE(review): -1 presumably tells State#getGraphvizNodeOptions that no frequencies are available
            freq = new double[this.states.length];
            Arrays.fill(freq, -1.0);
            maxFreq = -1.0;
        }
        StringBuilder sb = new StringBuilder();
        sb.append("digraph G {\n\trankdir=" + (rankPatterns != null ? "TB" : "LR") + "\n\n");
        sb.append("\tSTART[shape=point];\n\n");
        for (int s = 0; s < this.states.length; ++s) {
            sb.append("\t" + s + "[" + this.states[s].getGraphvizNodeOptions(freq[s], maxFreq, nf) + ",color=" + (this.finalState[s] ? "red" : "black") + "];\n");
        }
        if (rankPatterns != null) {
            // group the state indices by the first pattern their name matches
            HashMap<String, IntList> groups = new HashMap<String, IntList>();
            for (int s = 0; s < this.name.length; ++s) {
                for (String key : rankPatterns.keySet()) {
                    if (this.name[s].matches(key)) {
                        IntList members = groups.get(key);
                        if (members == null) {
                            members = new IntList();
                            groups.put(key, members);
                        }
                        members.add(s);
                        break;
                    }
                }
            }
            StringBuilder ranks = new StringBuilder();
            boolean startRanked = false;
            for (String key : groups.keySet()) {
                ranks.append("{rank=" + rankPatterns.get(key) + "; ");
                // the START node joins the first group whose pattern matches its name
                if (!startRanked && START_NODE.matches(key)) {
                    ranks.append("START ");
                    startRanked = true;
                }
                IntList members = groups.get(key);
                for (int i = 0; i < members.length(); ++i) {
                    ranks.append(members.get(i) + " ");
                }
                ranks.append(";}\n");
            }
            sb.append(ranks);
        }
        sb.append("\n");
        sb.append(this.transition.getGraphizNetworkRepresentation(nf, null, data != null));
        sb.append("}");
        return sb.toString();
    }

    /**
     * Computes the weighted average, over all sequences, of the summed posterior of each state.
     * <p>
     * Each sequence contributes the exponentiated log-sum of its state-posterior row.
     *
     * @param data the data, may be {@code null} (all frequencies 0)
     * @param weight the sequence weights, or {@code null} for weight 1 each
     * @return one value per state; all 0 if the total weight is 0
     * @throws Exception if the posteriors cannot be computed
     */
    private double[] getStateFreq(DataSet data, double[] weight) throws Exception {
        double[] res = new double[this.states.length];
        if (data == null) {
            return res;
        }
        double sum = 0.0;
        // one matrix large enough for the longest sequence is reused for all sequences
        double[][] current = this.createMatrixForStatePosterior(0, data.getMaximalElementLength() - 1);
        for (int i = 0; i < data.getNumberOfElements(); ++i) {
            Sequence seq = data.getElementAt(i);
            this.fillLogStatePosteriorMatrix(current, 0, seq.getLength() - 1, seq, false);
            double w = weight == null ? 1.0 : weight[i];
            sum += w;
            for (int s = 0; s < this.states.length; ++s) {
                res[s] += w * Math.exp(Normalisation.getLogSum(current[s]));
            }
        }
        if (sum != 0.0) {
            for (int s = 0; s < res.length; ++s) {
                res[s] /= sum;
            }
        }
        return res;
    }

    /**
     * Creates an intermediate matrix for state posteriors.
     * <p>
     * The matrix has one row per state and {@code endPos - startPos + 2} columns; the extra first
     * column is removed by {@link #getFinalStatePosterioriMatrix(double[][])}.
     *
     * @param startPos the first position
     * @param endPos the last position (inclusive)
     * @return the new matrix
     */
    protected double[][] createMatrixForStatePosterior(int startPos, int endPos) {
        return new double[this.states.length][endPos - startPos + 1 + 1];
    }

    /**
     * Fills the given intermediate matrix with log state posteriors.
     *
     * @param statePosterior the matrix created by {@link #createMatrixForStatePosterior(int, int)}
     * @param startPos the first position
     * @param endPos the last position (inclusive)
     * @param seq the sequence
     * @param silentZero a flag passed as {@code true} by
     *        {@link #getLogStatePosteriorMatrixFor(int, int, Sequence)} and as {@code false} when
     *        computing state frequencies (NOTE(review): presumably controls how silent states are
     *        treated; confirm with the implementations)
     * @throws Exception if the matrix cannot be filled
     */
    protected abstract void fillLogStatePosteriorMatrix(double[][] statePosterior, int startPos, int endPos, Sequence seq, boolean silentZero) throws Exception;

    /**
     * Returns the log state-posterior matrix (state x position) for the given part of a sequence.
     *
     * @param startPos the first position
     * @param endPos the last position (inclusive)
     * @param seq the sequence
     * @return the matrix without the auxiliary first column of the intermediate matrix
     * @throws Exception if the posteriors cannot be computed
     */
    public double[][] getLogStatePosteriorMatrixFor(int startPos, int endPos, Sequence seq) throws Exception {
        double[][] m = this.createMatrixForStatePosterior(startPos, endPos);
        this.fillLogStatePosteriorMatrix(m, startPos, endPos, seq, true);
        return this.getFinalStatePosterioriMatrix(m);
    }

    /**
     * Copies the intermediate matrix without its first column.
     *
     * @param intermediate the intermediate matrix
     * @return a new matrix whose rows are the rows of {@code intermediate} without index 0
     */
    protected double[][] getFinalStatePosterioriMatrix(double[][] intermediate) {
        double[][] res = new double[intermediate.length][];
        for (int i = 0; i < res.length; ++i) {
            res[i] = new double[intermediate[i].length - 1];
            System.arraycopy(intermediate[i], 1, res[i], 0, res[i].length);
        }
        return res;
    }

    /**
     * Returns the state-posterior matrix (state x position, not in log space) for a whole sequence.
     *
     * @param seq the sequence
     * @return the posterior matrix
     * @throws Exception if the posteriors cannot be computed
     */
    public double[][] getStatePosteriorMatrixFor(Sequence seq) throws Exception {
        double[][] matrix = this.getLogStatePosteriorMatrixFor(0, seq.getLength() - 1, seq);
        for (int i = 0; i < matrix.length; ++i) {
            for (int j = 0; j < matrix[i].length; ++j) {
                matrix[i][j] = Math.exp(matrix[i][j]);
            }
        }
        return matrix;
    }

    /**
     * Returns the log state-posterior matrices for all sequences of the data.
     *
     * @param data the data
     * @return one matrix per sequence
     * @throws Exception if the posteriors cannot be computed
     */
    public double[][][] getLogStatePosteriorMatrixFor(DataSet data) throws Exception {
        double[][][] matrix = new double[data.getNumberOfElements()][][];
        for (int i = 0; i < matrix.length; ++i) {
            Sequence s = data.getElementAt(i);
            matrix[i] = this.getLogStatePosteriorMatrixFor(0, s.getLength() - 1, s);
        }
        return matrix;
    }

    /**
     * Returns the state-posterior matrices (not in log space) for all sequences of the data.
     *
     * @param data the data
     * @return one matrix per sequence
     * @throws Exception if the posteriors cannot be computed
     */
    public double[][][] getStatePosteriorMatrixFor(DataSet data) throws Exception {
        double[][][] matrix = new double[data.getNumberOfElements()][][];
        for (int i = 0; i < matrix.length; ++i) {
            matrix[i] = this.getStatePosteriorMatrixFor(data.getElementAt(i));
        }
        return matrix;
    }

    /**
     * Returns the Viterbi path and its score for the given part of a sequence.
     *
     * @param startPos the first position
     * @param endPos the last position (inclusive)
     * @param seq the sequence
     * @return the state indices of the path and its score
     * @throws Exception if the path cannot be computed
     */
    public abstract Pair<IntList, Double> getViterbiPathFor(int startPos, int endPos, Sequence seq) throws Exception;

    /**
     * Returns the Viterbi path and its score for a whole sequence.
     *
     * @param seq the sequence
     * @return the state indices of the path and its score
     * @throws Exception if the path cannot be computed
     */
    public Pair<IntList, Double> getViterbiPathFor(Sequence seq) throws Exception {
        return this.getViterbiPathFor(0, seq.getLength() - 1, seq);
    }

    /**
     * Returns the Viterbi paths for all sequences of the data.
     *
     * @param data the data
     * @return one path per sequence
     * @throws Exception if a path cannot be computed
     */
    public Pair<IntList, Double>[] getViterbiPathsFor(DataSet data) throws Exception {
        // generic arrays cannot be created directly; the raw array only ever holds Pair<IntList, Double>
        @SuppressWarnings("unchecked")
        Pair<IntList, Double>[] paths = new Pair[data.getNumberOfElements()];
        for (int i = 0; i < paths.length; ++i) {
            paths[i] = this.getViterbiPathFor(data.getElementAt(i));
        }
        return paths;
    }

    /**
     * Translates a path of state indices into state names.
     *
     * @param path the state indices
     * @return the corresponding state names
     */
    public final String[] decodePath(IntList path) {
        String[] decoded = new String[path.length()];
        for (int i = 0; i < decoded.length; ++i) {
            decoded[i] = this.name[path.get(i)];
        }
        return decoded;
    }

    /**
     * Returns the log-probability of the given state path for a sequence.
     *
     * @param path the state indices
     * @param startPos the first position in the sequence
     * @param seq the sequence
     * @return the log-probability
     * @throws Exception if it cannot be computed
     */
    public abstract double getLogProbForPath(IntList path, int startPos, Sequence seq) throws Exception;

    /** Creates helper variables needed by the sub-class; called by {@link #provideMatrix(int, int)}. */
    protected abstract void createHelperVariables();

    /**
     * Ensures that the forward ({@code type == 0}) or backward ({@code type == 1}) matrix has at
     * least {@code length + 1} rows, and resets those rows to -infinity.
     * <p>
     * An existing matrix is reused if it is long enough. Otherwise a new one is allocated. In a new
     * matrix, row {@code l} has {@code transition.getNumberOfIndexes(l)} columns up to the maximal
     * Markov order; later rows reuse the last size.
     *
     * @param type 0 for {@link #fwdMatrix}, 1 for {@link #bwdMatrix}
     * @param length the length of the part of the sequence to be processed
     * @throws IllegalArgumentException if {@code type} is neither 0 nor 1
     */
    protected void provideMatrix(int type, int length) {
        this.createHelperVariables();
        int rows = length + 1;
        double[][] matrix;
        switch (type) {
            case 0: {
                matrix = this.fwdMatrix;
                break;
            }
            case 1: {
                matrix = this.bwdMatrix;
                break;
            }
            default: {
                throw new IllegalArgumentException("unknown matrix type");
            }
        }
        if (matrix == null || matrix.length < rows) {
            matrix = new double[rows][];
            int maxOrder = this.transition.getMaximalMarkovOrder();
            int dim = -1;
            int l = 0;
            for (; l <= maxOrder && l < rows; ++l) {
                dim = this.transition.getNumberOfIndexes(l);
                matrix[l] = new double[dim];
            }
            for (; l < rows; ++l) {
                matrix[l] = new double[dim];
            }
        }
        for (int l = 0; l < rows; ++l) {
            Arrays.fill(matrix[l], Double.NEGATIVE_INFINITY);
        }
        if (type == 0) {
            this.fwdMatrix = matrix;
        } else {
            this.bwdMatrix = matrix;
        }
    }

    /**
     * Returns the number of states.
     *
     * @return the number of states
     */
    public int getNumberOfStates() {
        return this.states.length;
    }

    /**
     * Returns the log-probability of the given part of a sequence.
     * <p>
     * The sequence's alphabet container must be consistent with the model's. If the model has a
     * non-zero length, the part must have exactly that length.
     */
    @Override
    public double getLogProbFor(Sequence sequence, int startpos, int endpos) throws Exception {
        int l = endpos - startpos + 1;
        int len = this.getLength();
        if (!sequence.getAlphabetContainer().checkConsistency(this.getAlphabetContainer())) {
            throw new WrongAlphabetException("The AlphabetContainer of the sequence and the model do not match.");
        }
        if (len != 0 && l != len) {
            throw new WrongLengthException("The given start position (" + startpos + ") and end position (" + endpos + ") yield an length of " + l + " which is not possible for the current model that models sequences of length " + len + ".");
        }
        return this.logProb(startpos, endpos, sequence);
    }

    /**
     * Converts an exception to a {@link RuntimeException}.
     * <p>
     * Runtime exceptions are returned unchanged. Other exceptions are wrapped; the wrapper keeps
     * their message, their stack trace and the exception itself as the cause.
     *
     * @param e the exception
     * @return the runtime exception
     */
    protected static RuntimeException getRunTimeException(Exception e) {
        if (e instanceof RuntimeException) {
            return (RuntimeException)e;
        }
        RuntimeException re = new RuntimeException(e.getMessage(), e);
        re.setStackTrace(e.getStackTrace());
        return re;
    }

    /**
     * Computes the log-probability by filling the backward matrix and returning its entry [0][0].
     *
     * @param startpos the first position
     * @param endpos the last position (inclusive)
     * @param sequence the sequence
     * @return the log-probability
     * @throws Exception declared for sub-classes; failures of {@link #fillBwdMatrix} are rethrown
     *         as {@link RuntimeException}s
     */
    protected double logProb(int startpos, int endpos, Sequence sequence) throws Exception {
        try {
            this.fillBwdMatrix(startpos, endpos, sequence);
        }
        catch (Exception e) {
            throw AbstractHMM.getRunTimeException(e);
        }
        return this.bwdMatrix[0][0];
    }

    /** Trains the model on unweighted data by delegating to {@code train(data, null)}. */
    @Override
    public void train(DataSet data) throws Exception {
        this.train(data, null);
    }

    /**
     * Sets the stream used for output.
     *
     * @param o the stream, wrapped by {@link SafeOutputStream#getSafeOutputStream(OutputStream)}
     */
    public final void setOutputStream(OutputStream o) {
        this.sostream = SafeOutputStream.getSafeOutputStream(o);
    }

    /** Releases references to the larger internal structures before finalisation. */
    @Override
    protected void finalize() throws Throwable {
        this.transition = null;
        this.states = null;
        this.trainingParameter = null;
        this.bwdMatrix = null;
        this.fwdMatrix = null;
        super.finalize();
    }

    /**
     * Determines {@link #finalState}.
     * <p>
     * The absorbing states of the transition are final. If no state is absorbing, all non-silent
     * states are final instead.
     */
    protected void determineFinalStates() {
        // NOTE(review): this writes into the array returned by isAbsoring(); confirm it is a copy
        this.finalState = this.transition.isAbsoring();
        boolean anyAbsorbing = false;
        for (int i = 0; i < this.finalState.length && !anyAbsorbing; ++i) {
            anyAbsorbing = this.finalState[i];
        }
        if (!anyAbsorbing) {
            for (int i = 0; i < this.finalState.length; ++i) {
                this.finalState[i] = !this.states[i].isSilent();
            }
        }
    }

    /**
     * Decodes state-posterior matrices by taking, at each position, the state with the highest
     * posterior.
     *
     * @param statePosterior one matrix per sequence, indexed [state][position]
     * @return for each sequence and position the index of the best state; ties go to the lowest index
     */
    public static int[][] decodeStatePosterior(double[][] ... statePosterior) {
        int[][] res = new int[statePosterior.length][];
        for (int s = 0; s < res.length; ++s) {
            double[][] posterior = statePosterior[s];
            int[] best = new int[posterior[0].length];
            for (int l = 0; l < best.length; ++l) {
                int arg = 0;
                for (int j = 1; j < posterior.length; ++j) {
                    if (posterior[arg][l] < posterior[j][l]) {
                        arg = j;
                    }
                }
                best[l] = arg;
            }
            res[s] = best;
        }
        return res;
    }

    /** Returns a textual representation of the transition followed by all states. */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder();
        sb.append("Transition:\n-----------\n").append(this.transition.toString(this.name, nf));
        sb.append("\nStates:\n-------\n");
        for (int e = 0; e < this.states.length; ++e) {
            sb.append(this.states[e].toString(nf)).append("\n");
        }
        return sb.toString();
    }
}

