/*
 * Decompiled with CFR 0.152.
 */
import de.jstacs.data.DNADataSet;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.data.sequences.annotation.SplitSequenceAnnotationParser;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.REnvironment;
import de.jstacs.utils.ToolBox;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;

/**
 * Plots the mean weighted mutual information (MI) between neighbouring positions of a set of
 * annotated DNA sequences, as a function of a minimum-weight threshold.
 *
 * <p>Usage: {@code PlotLogoAndMI2 <file>}. The file is loaded as a {@link DNADataSet} with
 * {@code '>'} passed as the header character, and its annotations are parsed with
 * {@code SplitSequenceAnnotationParser(":", ";")}. Every sequence must carry numeric
 * {@code weight} and {@code pval} annotations. The plot is written to
 * {@code <file>_meanmi.pdf} through an R session ({@link REnvironment}).
 */
public class PlotLogoAndMI2 {

    /**
     * Value recorded for a threshold combination that leaves no usable sequences. MI is never
     * negative, so the R code maps every negative entry to {@code NA} before plotting.
     */
    private static final double MISSING = -1.0;

    /**
     * Entry point; see the class comment for the expected input and the produced output.
     *
     * @param args {@code args[0]} is the path of the annotated sequence file
     * @throws IllegalArgumentException if no file is given or a sequence lacks a required
     *         annotation
     * @throws Exception if the file cannot be loaded or the R session fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("usage: PlotLogoAndMI2 <annotated sequence file>");
        }
        final String file = args[0];
        final double[] wt = new double[]{0.5, 0.6, 0.7, 0.8, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99};
        final double[] pt = new double[]{0.001};

        REnvironment re = new REnvironment();
        try {
            // NOTE(review): plot() below is base R; these libraries look like leftovers.
            re.voidEval("library(seqLogo);library(gplots);");
            re.createVector("wt", wt);

            // Load and parse the file once; the thresholds only select subsets of it.
            DataSet all = new DNADataSet(file, '>', new SplitSequenceAnnotationParser(":", ";"));
            int n = all.getNumberOfElements();
            double[] weights = new double[n];
            double[] pvals = new double[n];
            for (int i = 0; i < n; i++) {
                Sequence seq = all.getElementAt(i);
                weights[i] = parseAnnotation(seq, "weight");
                pvals[i] = parseAnnotation(seq, "pval");
            }

            // One entry per (p-value, weight) threshold pair, with the p-value in the outer loop.
            DoubleList means = new DoubleList();
            for (double maxPval : pt) {
                for (double minWeight : wt) {
                    means.add(meanNeighbourMI(all, weights, pvals, minWeight, maxPval));
                }
            }

            re.createVector("mea", means.toArray());
            re.voidEval("mea[mea<0]<-NA;");
            // plot(wt, mea) requires length(mea) == length(wt), i.e. exactly one p-value threshold.
            re.plotToPDF("plot(wt,mea,t=\"l\")", 8.0, 5.0, file + "_meanmi.pdf", true);
        } finally {
            // Release the R connection even if loading, parsing or plotting fails.
            re.close();
        }
    }

    /**
     * Parses the first annotation of the given type attached to {@code seq} as a double.
     *
     * @param seq  the annotated sequence
     * @param type the annotation type, e.g. {@code "weight"} or {@code "pval"}
     * @return the numeric value of the annotation's identifier
     * @throws IllegalArgumentException if {@code seq} has no annotation of that type
     * @throws NumberFormatException if the identifier is not a number
     */
    private static double parseAnnotation(Sequence seq, String type) {
        SequenceAnnotation annotation = seq.getSequenceAnnotationByType(type, 0);
        // Assumes Jstacs returns null for an absent annotation — TODO confirm.
        if (annotation == null) {
            throw new IllegalArgumentException("sequence has no \"" + type + "\" annotation");
        }
        return Double.parseDouble(annotation.getIdentifier());
    }

    /**
     * Computes the mean weighted MI over all neighbouring position pairs (i, i+1). Only sequences
     * whose weight is at least {@code minWeight} and whose p-value is at most {@code maxPval}
     * are used, and each selected sequence is weighted by its own weight.
     *
     * @param all       all loaded sequences
     * @param weights   weight of each element of {@code all}
     * @param pvals     p-value of each element of {@code all}
     * @param minWeight inclusive lower bound on the weight
     * @param maxPval   inclusive upper bound on the p-value
     * @return the mean MI in nats, or {@link #MISSING} if no sequence passes the thresholds or
     *         the selected sequences do not share at least two positions
     * @throws Exception if the filtered {@link DataSet} cannot be built
     */
    private static double meanNeighbourMI(DataSet all, double[] weights, double[] pvals,
            double minWeight, double maxPval) throws Exception {
        List<Sequence> selected = new ArrayList<Sequence>();
        DoubleList selectedWeights = new DoubleList();
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] >= minWeight && pvals[i] <= maxPval) {
                selected.add(all.getElementAt(i));
                selectedWeights.add(weights[i]);
            }
        }
        if (selected.isEmpty()) {
            return MISSING;
        }
        DataSet ds = new DataSet("", selected.toArray(new Sequence[0]));
        int length = ds.getElementLength();
        // With fewer than two positions there are no neighbours. Jstacs presumably also reports 0
        // when the selected sequences differ in length (verify), which would otherwise lead to a
        // negative array size below.
        if (length < 2) {
            return MISSING;
        }
        double[] w = selectedWeights.toArray();
        double[] neighbourMI = new double[length - 1];
        for (int i = 0; i < length - 1; i++) {
            // Argument order (i + 1, i) matches the lower-triangle entry computeMIs fills, so the
            // floating-point result is identical to reading mis[i][i + 1] from that matrix.
            neighbourMI[i] = computeMI(ds, w, i + 1, i);
        }
        return ToolBox.sum(neighbourMI) / (double) neighbourMI.length;
    }

    /**
     * Computes the symmetric matrix of weighted MIs between all pairs of positions of
     * {@code ds}. The diagonal is left at 0. Not called from {@link #main}.
     *
     * @param ds sequences, expected to share the length {@code ds.getElementLength()}
     * @param w  per-sequence weights, indexed like the elements of {@code ds}
     * @return a {@code length x length} matrix with {@code mis[i][j] == mis[j][i]}
     */
    private static double[][] computeMIs(DataSet ds, double[] w) {
        int length = ds.getElementLength();
        double[][] mis = new double[length][length];
        for (int i = 0; i < length; i++) {
            for (int j = 0; j < i; j++) {
                mis[i][j] = computeMI(ds, w, i, j);
                mis[j][i] = mis[i][j];
            }
        }
        return mis;
    }

    /**
     * Estimates weighted first-order conditional position weight matrices from {@code ds}.
     *
     * <p>{@code result[a][j][b]} is the weighted relative frequency of symbol {@code b} at
     * position {@code j}, among sequences carrying symbol {@code a} at position {@code j - 1}.
     * Position 0 has no predecessor and therefore receives no counts. Rows that received no
     * counts are expected to normalise to NaN and are then replaced by the uniform distribution
     * (0.25 each). Not called from {@link #main}.
     *
     * @param ds sequences, assumed to be over a DNA alphabet (discrete values 0-3) — TODO confirm
     * @param w  per-sequence weights, indexed like the elements of {@code ds}
     * @return an array of shape {@code [4][ds.getElementLength()][4]}
     */
    private static double[][][] getConditionalPWMs(DataSet ds, double[] w) {
        double[][][] pwms = new double[4][ds.getElementLength()][4];
        for (int i = 0; i < ds.getNumberOfElements(); i++) {
            Sequence seq = ds.getElementAt(i);
            for (int j = 0; j < seq.getLength() - 1; j++) {
                pwms[seq.discreteVal(j)][j + 1][seq.discreteVal(j + 1)] += w[i];
            }
        }
        for (double[][] context : pwms) {
            for (double[] row : context) {
                Normalisation.sumNormalisation(row);
                if (Double.isNaN(row[0])) {
                    Arrays.fill(row, 0.25);
                }
            }
        }
        return pwms;
    }

    /**
     * Computes the weighted mutual information (natural logarithm, i.e. in nats) between the
     * symbols at positions {@code p1} and {@code p2} of the sequences in {@code ds}.
     *
     * <p>Each sequence contributes its weight to a 4x4 joint count table. The joint table is
     * divided by the total weight, and the marginals are normalised to sum to 1. Cells with
     * zero probability contribute nothing.
     *
     * @param ds sequences, assumed to be over a DNA alphabet (discrete values 0-3) — TODO confirm
     * @param w  per-sequence weights, indexed like the elements of {@code ds}
     * @param p1 first position
     * @param p2 second position
     * @return the MI between the two positions
     */
    private static double computeMI(DataSet ds, double[] w, int p1, int p2) {
        double[][] joint = new double[4][4];
        for (int i = 0; i < ds.getNumberOfElements(); i++) {
            Sequence seq = ds.getElementAt(i);
            joint[seq.discreteVal(p1)][seq.discreteVal(p2)] += w[i];
        }
        double[] marginal1 = new double[4];
        double[] marginal2 = new double[4];
        for (int a = 0; a < joint.length; a++) {
            for (int b = 0; b < joint[a].length; b++) {
                marginal1[a] += joint[a][b];
                marginal2[b] += joint[a][b];
            }
        }
        double total = ToolBox.sum(marginal1);
        Normalisation.sumNormalisation(marginal1);
        Normalisation.sumNormalisation(marginal2);
        double mi = 0.0;
        for (int a = 0; a < joint.length; a++) {
            for (int b = 0; b < joint[a].length; b++) {
                double pab = joint[a][b] / total;
                if (pab > 0.0) {
                    mi += pab * Math.log(pab / (marginal1[a] * marginal2[b]));
                }
            }
        }
        return mi;
    }
}

