/*
 * Decompiled with CFR 0.152.
 */
package challenges.dream6;

import challenges.dream6.SingleGaussianDiffSM;
import de.jstacs.DataType;
import de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.SimpleGaussianSumLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.msp.MSPClassifier;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.FileManager;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.io.XMLParser;
import de.jstacs.parameters.Parameter;
import de.jstacs.parameters.ParameterSet;
import de.jstacs.parameters.ParameterSetTagger;
import de.jstacs.parameters.SimpleParameter;
import de.jstacs.parameters.validation.NumberValidator;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.UniformDiffSS;
import de.jstacs.sequenceScores.differentiable.logistic.LogisticConstraint;
import de.jstacs.sequenceScores.differentiable.logistic.LogisticDiffSS;
import de.jstacs.sequenceScores.differentiable.logistic.ProductConstraint;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.IndependentProductDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.SafeOutputStream;
import de.jstacs.utils.ToolBox;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.util.AbstractCollection;
import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Random;

/**
 * Trains (or loads) one classifier per patient subsample, combines their per-patient outputs with a
 * second-level logistic classifier and writes one score per patient to standard output.
 */
public class Dream6C4_final {
    /** Random source used by {@link #split(int[][], double, HashSet)}. */
    private static final Random r = new Random();

    /** Number of subsamples (CSV files) per patient; one first-level classifier is trained per subsample. */
    private static final int SUBSAMPLES_PER_PATIENT = 8;

    /**
     * Sequence length of the first-level classifiers (presumably the number of values per CSV line
     * -- confirm against the data).
     */
    private static final int SEQUENCE_LENGTH = 7;

    /**
     * Parses the command line, trains or loads the classifiers and prints a score for every patient.
     *
     * <p>Patients 1-179 are reported as training patients (together with their label), patients
     * 180-359 as test patients.
     *
     * @param args parameters of the form {@code tag=value}, see {@link Dream6C4Parameters}
     * @throws Exception if reading the data, training or (de)serialising the classifiers fails
     */
    public static void main(String[] args) throws Exception {
        ParameterSetTagger params = new ParameterSetTagger(Dream6C4Parameters.PREFIX, new Dream6C4Parameters());
        params.fillParameters("=", args);
        System.out.println("parameters:");
        System.out.println(params);
        System.out.println("_________________________________");
        if (!params.hasDefaultOrIsSet()) {
            System.out.println("Some of the required parameters are not specified.");
            System.exit(1);
        }
        String home = params.getValueFromTag(Dream6C4Parameters.HOME, String.class);
        int threads = params.getValueFromTag(Dream6C4Parameters.THREADS, Integer.class);
        double eps = params.getValueFromTag(Dream6C4Parameters.EPS, Double.class);
        SafeOutputStream sostream = SafeOutputStream.getSafeOutputStream(System.out);
        HashSet<Integer> aml = new HashSet<Integer>();
        int[][] idx = read(home + File.separator + params.getValueFromTag(Dream6C4Parameters.DESC, String.class), aml, params.getValueFromTag(Dream6C4Parameters.KEY, String.class));
        File classifierFile = new File(home + File.separator + params.getValueFromTag(Dream6C4Parameters.CLS, String.class));

        GenDisMixClassifier[] cl;
        GenDisMixClassifier clOutCl;
        if (!params.getValueFromTag(Dream6C4Parameters.LOAD, Boolean.class).booleanValue()) {
            cl = build(0.5 * eps, threads);
            train(cl, idx, home, sostream);
            // second level: a logistic classifier on the per-subsample outputs of each training patient
            DataSet[] clOut = getTrainingPred(cl, idx, home, sostream);
            int l = clOut[0].getElementLength();
            AlphabetContainer con = cl[0].getAlphabetContainer();
            GenDisMixClassifierParameterSet ps = new GenDisMixClassifierParameterSet(con, l, 10, eps, 0.1 * eps, 1.0, false, OptimizableFunction.KindOfParameter.PLUGIN, true, 1);
            clOutCl = new MSPClassifier(ps, (LogPrior)new SimpleGaussianSumLogPrior(1.0), Double.NaN, getLogistic(con, l, false), new UniformDiffSS(con, l));
            clOutCl.setOutputStream(sostream);
            clOutCl.train(clOut);
            StringBuffer sb = new StringBuffer();
            XMLParser.appendObjectWithTags(sb, cl, "classifier");
            XMLParser.appendObjectWithTags(sb, clOutCl, "combined_classifier");
            FileManager.writeFile(classifierFile, (CharSequence)sb);
        } else {
            StringBuffer sb = FileManager.readFile(classifierFile);
            cl = (GenDisMixClassifier[])XMLParser.extractObjectForTags(sb, "classifier");
            clOutCl = (GenDisMixClassifier)XMLParser.extractObjectForTags(sb, "combined_classifier");
        }
        sostream.writeln("patient\thealthy\ttest");
        evaluate(1, 179, 1, home, aml, cl, clOutCl, sostream);
        sostream.writeln("");
        sostream.writeln("patient\thealthy\ttest");
        evaluate(180, 359, 0, home, aml, cl, clOutCl, sostream);
    }

    /**
     * Writes one tab-separated line per patient in {@code [start, end]}: patient id, label column,
     * the {@code train} flag, the sigmoid of every first-level output and the combined score.
     *
     * @param start first patient id (inclusive)
     * @param end last patient id (inclusive)
     * @param train {@code 1} to fill the label column ({@code 0} if the patient is in {@code aml},
     *              {@code 1} otherwise); any other value leaves the label column empty
     * @param home data directory
     * @param aml ids of the unhealthy patients
     * @param cl first-level classifiers, one per subsample
     * @param clOutCl second-level classifier; if {@code null}, the first-level outputs are averaged instead
     * @param sostream destination of the output
     * @throws Exception if reading or scoring fails
     */
    private static void evaluate(int start, int end, int train, String home, HashSet<Integer> aml, GenDisMixClassifier[] cl, GenDisMixClassifier clOutCl, SafeOutputStream sostream) throws Exception {
        double[] score = new double[2];
        double[] vals = new double[cl.length];
        String a = "";
        for (int patient = start; patient <= end; patient++) {
            if (train == 1) {
                a = "" + (aml.contains(patient) ? 0 : 1);
            }
            sostream.write(patient + "\t" + a + "\t" + train);
            Sequence seq = getClOut(cl, home, patient);
            for (int g = 0; g < cl.length; g++) {
                vals[g] = sigmoid(seq.continuousVal(g));
                sostream.write("\t" + vals[g]);
            }
            double combined;
            if (clOutCl != null) {
                score[0] = clOutCl.getScore(seq, 0);
                score[1] = clOutCl.getScore(seq, 1);
                Normalisation.logSumNormalisation(score);
                combined = score[0];
            } else {
                // NOTE(review): if ToolBox.sum treats its end index as exclusive, only the first
                // cl.length - 1 outputs are averaged here -- confirm this is intended.
                combined = ToolBox.sum(0, cl.length - 1, vals) / (double)(cl.length - 1);
            }
            sostream.writeln("\t" + combined);
        }
    }

    /** Logistic function {@code 1 - 1 / (1 + exp(x))}, mapping a logit back to a value in [0, 1]. */
    private static double sigmoid(double x) {
        return 1.0 - 1.0 / (1.0 + Math.exp(x));
    }

    /**
     * Creates the untrained first-level classifiers: one clone per subsample of an MSP classifier
     * with independent-product foreground and background models.
     *
     * @param eps stopping threshold for the optimisation
     * @param threads number of threads for the objective function
     * @return {@link #SUBSAMPLES_PER_PATIENT} independent copies of the classifier
     * @throws Exception if a model cannot be created or cloned
     */
    private static GenDisMixClassifier[] build(double eps, int threads) throws Exception {
        AlphabetContainer con = new AlphabetContainer((Alphabet)new ContinuousAlphabet());
        double ess = 1.0;
        DifferentiableSequenceScore fg = getSF(con, ess, SEQUENCE_LENGTH);
        DifferentiableSequenceScore bg = getSF(con, 8.0 * ess, SEQUENCE_LENGTH);
        GenDisMixClassifierParameterSet ps = new GenDisMixClassifierParameterSet(con, SEQUENCE_LENGTH, 10, eps, eps * 0.1, 1.0, false, OptimizableFunction.KindOfParameter.PLUGIN, true, threads);
        MSPClassifier cl = new MSPClassifier(ps, (LogPrior)DoesNothingLogPrior.defaultInstance, Double.NaN, fg, bg);
        return (GenDisMixClassifier[])ArrayHandler.createArrayOf((Cloneable)cl, SUBSAMPLES_PER_PATIENT);
    }

    /**
     * Builds an independent product of univariate Gaussian models. Position 0 uses its own
     * parameterisation; all other positions share a second model instance.
     *
     * @param con alphabet of the sequences
     * @param ess equivalent sample size
     * @param length number of positions
     * @return the product model
     * @throws Exception if a model cannot be created
     */
    private static DifferentiableSequenceScore getSF(AlphabetContainer con, double ess, int length) throws Exception {
        int[] partL = new int[length];
        Arrays.fill(partL, 1);
        boolean alwaysRandomly = true;
        SingleGaussianDiffSM first = new SingleGaussianDiffSM(con, ess, 6.0, 0.25, 1.0, alwaysRandomly);
        SingleGaussianDiffSM rest = new SingleGaussianDiffSM(con, ess, -0.5, 4.0, 1.0, alwaysRandomly);
        DifferentiableStatisticalModel[] nsfs = new DifferentiableStatisticalModel[length];
        if (length > 0) {
            nsfs[0] = first;
            Arrays.fill(nsfs, 1, length, rest);
        }
        return new IndependentProductDiffSM(ess, true, nsfs, partL);
    }

    /**
     * Reads the patient description file. The first element is skipped (presumably a header line).
     * Every other line must start with the patient id followed by a comma.
     *
     * @param fName path of the description file
     * @param aml receives the ids of all lines ending with {@code unhealthy}
     * @param unhealthy suffix that marks an unhealthy patient
     * @return {@code {unhealthy ids, other ids}}
     * @throws Exception if the file cannot be read or an id cannot be parsed
     */
    private static int[][] read(String fName, HashSet<Integer> aml, String unhealthy) throws Exception {
        SparseStringExtractor sse = new SparseStringExtractor(fName, '#');
        sse.nextElement();
        IntList pos = new IntList();
        IntList neg = new IntList();
        while (sse.hasMoreElements()) {
            String line = sse.nextElement();
            int idx = Integer.parseInt(line.substring(0, line.indexOf(',')));
            if (line.endsWith(unhealthy)) {
                pos.add(idx);
                aml.add(idx);
            } else {
                neg.add(idx);
            }
        }
        return new int[][]{pos.toArray(), neg.toArray()};
    }

    /**
     * Randomly splits every index group into a first part of {@code ceil(p * n)} entries and a
     * second part with the remainder.
     *
     * @param idx index groups to split
     * @param p fraction assigned to the first part
     * @param training cleared, then filled with all entries of the first parts
     * @return {@code res[0][i]} is the first part and {@code res[1][i]} the second part of group {@code i}
     */
    private static int[][][] split(int[][] idx, double p, HashSet<Integer> training) throws FileNotFoundException, IOException {
        int[][][] res = new int[2][idx.length][];
        training.clear();
        for (int i = 0; i < idx.length; i++) {
            int n = idx[i].length;
            int l = (int)Math.ceil(p * (double)n);
            res[0][i] = new int[l];
            res[1][i] = new int[n - l];
            int[] help = new int[n];
            for (int j = 0; j < n; j++) {
                help[j] = j;
            }
            // draw without replacement: the last still-available index is moved into the drawn slot,
            // so afterwards the undrawn indices occupy help[0 .. n - l - 1]
            for (int j = 0; j < l; j++) {
                int current = r.nextInt(n - j);
                res[0][i][j] = idx[i][help[current]];
                training.add(res[0][i][j]);
                help[current] = help[n - j - 1];
            }
            for (int j = 0; j < n - l; j++) {
                res[1][i][j] = idx[i][help[j]];
            }
            System.out.println(i + "\t" + Arrays.toString(res[0][i]));
        }
        return res;
    }

    /**
     * Trains classifier {@code c} on all lines of subsample {@code c} of the patients in
     * {@code split[0]} (class 0) and {@code split[1]} (class 1).
     *
     * @param cl classifiers to train, one per subsample
     * @param split patient ids per class
     * @param home data directory
     * @param sostream destination of progress output
     * @throws Exception if reading or training fails
     */
    private static void train(GenDisMixClassifier[] cl, int[][] split, String home, SafeOutputStream sostream) throws Exception {
        AlphabetContainer con = cl[0].getAlphabetContainer();
        ArrayList<Sequence> seqs = new ArrayList<Sequence>();
        DataSet[] data = new DataSet[2];
        for (int c = 0; c < cl.length; c++) {
            sostream.writeln(c + " ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~");
            for (int i = 0; i < data.length; i++) {
                seqs.clear();
                for (int j = 0; j < split[i].length; j++) {
                    SparseStringExtractor sse = new SparseStringExtractor(getFileName(home, split[i][j], c), '#');
                    sse.nextElement();
                    while (sse.hasMoreElements()) {
                        seqs.add(create(con, sse.nextElement(), ","));
                    }
                }
                data[i] = new DataSet("class " + i, seqs.toArray(new Sequence[0]));
                sostream.writeln(i + "\t#(patients) = " + split[i].length + "\t#(seqs) = " + data[i].getNumberOfElements());
            }
            cl[c].setOutputStream(sostream);
            cl[c].train(data);
        }
    }

    /**
     * Computes the first-level outputs for every patient of both classes. These are the training
     * data of the second-level classifier.
     *
     * @param cl trained first-level classifiers
     * @param split patient ids per class
     * @param home data directory
     * @param sostream destination of progress output
     * @return one data set per class with one sequence per patient
     * @throws Exception if reading or scoring fails
     */
    private static DataSet[] getTrainingPred(GenDisMixClassifier[] cl, int[][] split, String home, SafeOutputStream sostream) throws Exception {
        DataSet[] clOut = new DataSet[2];
        for (int i = 0; i < clOut.length; i++) {
            Sequence[] seqs = new Sequence[split[i].length];
            for (int j = 0; j < split[i].length; j++) {
                seqs[j] = getClOut(cl, home, split[i][j]);
            }
            clOut[i] = new DataSet("", seqs);
            sostream.writeln(i + "\t# = " + clOut[i].getNumberOfElements());
        }
        return clOut;
    }

    /**
     * Builds the feature sequence of one patient: for each first-level classifier, the logit of the
     * smoothed fraction of lines it assigns to class 0.
     *
     * @param cl first-level classifiers
     * @param home data directory
     * @param patient patient id
     * @return the continuous feature sequence
     * @throws Exception if reading or scoring fails
     */
    private static Sequence getClOut(GenDisMixClassifier[] cl, String home, int patient) throws Exception {
        DoubleList list = new DoubleList();
        for (int i = 0; i < cl.length; i++) {
            getScore(cl[i], i, home, patient, list);
        }
        double[] vals = list.toArray();
        for (int i = 0; i < vals.length; i++) {
            // logit; finite because getScore keeps every fraction strictly inside (0, 1)
            vals[i] = Math.log(vals[i]) - Math.log(1.0 - vals[i]);
        }
        return new ArbitrarySequence(cl[0].getAlphabetContainer(), vals);
    }

    /**
     * Scores every line of one subsample file of one patient. For each threshold in the grid, it
     * appends to {@code listCl} the Laplace-smoothed fraction of lines whose class-0 probability
     * exceeds the threshold. Does nothing if {@code classifier} is {@code null}.
     *
     * @param classifier first-level classifier of this subsample
     * @param subsample subsample index
     * @param home data directory
     * @param patient patient id
     * @param listCl receives one value per threshold
     * @throws Exception if reading or scoring fails
     */
    private static void getScore(GenDisMixClassifier classifier, int subsample, String home, int patient, DoubleList listCl) throws Exception {
        if (classifier == null) {
            return;
        }
        double[] score = new double[2];
        SparseStringExtractor sse = new SparseStringExtractor(getFileName(home, patient, subsample), '#');
        sse.nextElement();
        double[] grid = new double[]{0.5};
        // Laplace smoothing: one pseudo-line above and one below every threshold. This gives
        // (k + 1) / (n + 2), strictly inside (0, 1). The previous (k + 1) / (n + 1) reached 1 when all
        // lines were above the threshold, and its logit became +Infinity.
        int[] counts = new int[grid.length];
        Arrays.fill(counts, 1);
        int anz = 2;
        while (sse.hasMoreElements()) {
            Sequence seq = create(classifier.getAlphabetContainer(), sse.nextElement(), ",");
            score[0] = classifier.getScore(seq, 0);
            score[1] = classifier.getScore(seq, 1);
            Normalisation.logSumNormalisation(score);
            for (int j = 0; j < grid.length; j++) {
                if (score[0] > grid[j]) {
                    counts[j]++;
                }
            }
            anz++;
        }
        for (int j = 0; j < grid.length; j++) {
            listCl.add((double)counts[j] / (double)anz);
        }
    }

    /**
     * Returns the path of the CSV file holding subsample {@code subsample} of {@code patient}. Files
     * are numbered consecutively with {@link #SUBSAMPLES_PER_PATIENT} files per patient, and the
     * number is zero-padded to at least four digits.
     */
    private static String getFileName(String home, int patient, int subsample) {
        StringBuilder number = new StringBuilder(Integer.toString(SUBSAMPLES_PER_PATIENT * (patient - 1) + 1 + subsample));
        while (number.length() < 4) {
            number.insert(0, '0');
        }
        return home + "/csv/" + number + ".CSV";
    }

    /**
     * Parses one delimited line into a continuous sequence of natural logarithms. Non-positive values
     * are mapped to {@code -Double.MAX_VALUE}.
     *
     * @param con alphabet of the sequence
     * @param line the line to parse
     * @param delim delimiter regular expression
     * @return the log-transformed sequence
     */
    private static Sequence create(AlphabetContainer con, String line, String delim) throws WrongAlphabetException, WrongSequenceTypeException {
        String[] tokens = line.split(delim);
        double[] val = new double[tokens.length];
        for (int i = 0; i < val.length; i++) {
            double raw = Double.parseDouble(tokens[i]);
            val[i] = raw <= 0.0 ? -Double.MAX_VALUE : Math.log(raw);
        }
        return new ArbitrarySequence(con, val);
    }

    /**
     * Creates a logistic score with one constraint per position and, if {@code prod} is set, one
     * additional constraint per pair of positions.
     *
     * @param con alphabet of the sequences
     * @param length sequence length
     * @param prod whether to add the pairwise product constraints
     * @return the logistic score
     */
    private static DifferentiableSequenceScore getLogistic(AlphabetContainer con, int length, boolean prod) throws CloneNotSupportedException {
        ArrayList<ProductConstraint> constraints = new ArrayList<ProductConstraint>();
        for (int i = 0; i < length; i++) {
            constraints.add(new ProductConstraint(i));
            if (prod) {
                for (int j = i + 1; j < length; j++) {
                    constraints.add(new ProductConstraint(i, j));
                }
            }
        }
        return new LogisticDiffSS(con, length, constraints.toArray(new LogisticConstraint[0]));
    }

    /** Command-line parameters of {@link Dream6C4_final}; {@link #PREFIX} lists the tags in parameter order. */
    private static class Dream6C4Parameters
    extends ParameterSet {
        public static final String HOME = "home";
        public static final String DESC = "desc";
        public static final String THREADS = "threads";
        public static final String KEY = "key";
        public static final String CLS = "classifier";
        public static final String LOAD = "load";
        public static final String EPS = "epsilon";
        private static final String[] PREFIX = new String[]{HOME, DESC, KEY, CLS, LOAD, THREADS, EPS};

        public Dream6C4Parameters() throws Exception {
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.STRING, "home directory", "the path to the data directory", true, "./")});
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.STRING, "description file", "file containing the description of patients for the training data", true, "DREAM6_AML_TrainingSet.csv")});
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.STRING, "unhealthy key", "key in the description file that indicates unhealthy patient, all other keys treated as healthy", true, "AML")});
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.STRING, CLS, "path to the classifier; if load=true, this classifier is loaded from the file, otherwise it is stored to that file", true, "final-classifiers.xml")});
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.BOOLEAN, LOAD, "load the classifier (instead of storing it)", true, false)});
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.INT, "compute threads", "the number of threads that are used to evaluate the objective function and its gradient", true, new NumberValidator<Integer>(1, 128), AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors())});
            this.parameters.add(new Parameter[]{new SimpleParameter(DataType.DOUBLE, EPS, "the difference between two function evaluations to stop the training of the classifier", true, new NumberValidator<Double>(0.0, 1.0), 1.0E-6)});
        }
    }
}

