/*
 * Decompiled with CFR 0.152.
 */
package projects.kmermotifs;

import de.jstacs.algorithms.optimization.ConstantStartDistance;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.CompositeLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.data.DNADataSet;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModelFactory;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.HomogeneousDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.HomogeneousMMDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.MixtureDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.StrandDiffSM;
import de.jstacs.utils.Pair;
import de.jstacs.utils.SafeOutputStream;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;
import javax.naming.OperationNotSupportedException;
import projects.kmermotifs.DeconvolvablePWMDiffSM;
import projects.kmermotifs.DeconvolvableStrandDiffSM;
import projects.kmermotifs.DeconvolvedDiffSM;
import projects.kmermotifs.KMerStatistic;
import projects.kmermotifs.PositionWrapperDiffSM;
import projects.kmermotifs.UniformPositionStatisticsFunction;

public class KMerMain {
    /**
     * Runs the complete k-mer motif training pipeline on two hard-coded DNA data sets
     * (a "positives" and a "negatives" file), and prints the trained foreground and
     * background models to standard output.
     *
     * <p>Steps, in order: load both data sets, build unit-weight {@link KMerStatistic}s for
     * each, derive per-position equivalent sample sizes, assemble a mixture foreground model
     * (strand PWM components plus a homogeneous flanking component) wrapped in a
     * {@link PositionWrapperDiffSM}, pick start parameters from random restarts, and optimise
     * a {@link LogGenDisMixFunction}.
     *
     * @param args ignored
     * @throws Exception if loading, model construction, or optimisation fails
     */
    public static void main(String[] args) throws Exception {
        DNADataSet data = new DNADataSet("/Users/dev/Desktop/ChIP-seq/positives_min7.fa");
        DNADataSet bgData = new DNADataSet("/Users/dev/Desktop/ChIP-seq/negatives.fa");

        // Fixed configuration of this run.
        final int numMotifs = 10;
        final int length = 6;
        final int shift = 3;
        final double motifESS = 40.0;
        final double baseESS = motifESS * (data.getAverageElementLength() - (double) length + 1.0);
        System.out.println(baseESS + " " + motifESS);
        final boolean initRandomly = true;

        // Every foreground sequence gets a single weight of 1.
        double[][] weights = new double[data.getNumberOfElements()][1];
        for (double[] row : weights) {
            row[0] = 1.0;
        }
        UniformPositionStatisticsFunction posStat = new UniformPositionStatisticsFunction();
        int[] offsets = new int[data.getNumberOfElements()];
        KMerStatistic statData = new KMerStatistic(data, weights, offsets, length, 2, posStat);

        // Same unit weighting for the background sequences.
        double[][] weightsBg = new double[bgData.getNumberOfElements()][1];
        for (double[] row : weightsBg) {
            row[0] = 1.0;
        }
        int[] offsetsBg = new int[bgData.getNumberOfElements()];
        KMerStatistic statBg = new KMerStatistic(bgData, weightsBg, offsetsBg, length, 2, posStat);

        Pair<DataSet, double[]> pair = statData.getWeightedDataSet();
        Pair<DataSet, double[]> pair2 = statBg.getWeightedDataSet();

        // Per-position ESS: start from an even share per motif, then reduce the outer
        // positions on both ends (positions closer to the border are reduced more often)
        // and accumulate the running total in rem.
        double[] esss = new double[length];
        Arrays.fill(esss, motifESS / (double) numMotifs);
        final double perStepReduction = motifESS / (((double) shift * 2.0 + 1.0) * (double) numMotifs);
        double rem = 0.0;
        for (int outer = 0; outer < shift; outer++) {
            for (int inner = 0; inner <= outer; inner++) {
                esss[inner] = esss[inner] - perStepReduction;
                int mirrored = esss.length - inner - 1;
                esss[mirrored] = esss[mirrored] - perStepReduction;
                rem += 2.0 * motifESS / ((double) shift * 2.0 + 1.0);
            }
        }
        System.out.println(Arrays.toString(esss));

        HomogeneousMMDiffSM flanking = new HomogeneousMMDiffSM(
                data.getAlphabetContainer(), 1, baseESS - motifESS + rem / (double) length, length);
        // pwm and diffSM are not referenced again below; their construction is retained as in
        // the original code.
        DeconvolvablePWMDiffSM pwm = new DeconvolvablePWMDiffSM(
                data.getAlphabetContainer(), length, esss, motifESS / (double) numMotifs);
        DeconvolvableStrandDiffSM diffSM = new DeconvolvableStrandDiffSM(
                pwm, 0.5, 1, true, StrandDiffSM.InitMethod.INIT_FORWARD_STRAND);

        // numMotifs + 1 copies of a strand PWM model; the last slot is replaced by the
        // flanking model.
        DifferentiableStatisticalModel[] models = ArrayHandler.cast(
                DifferentiableStatisticalModel.class,
                ArrayHandler.createArrayOf(
                        (Cloneable) DifferentiableStatisticalModelFactory.createStrandModel(
                                DifferentiableStatisticalModelFactory.createPWM(
                                        data.getAlphabetContainer(), length, motifESS / (double) numMotifs)),
                        (int) (numMotifs + 1)));
        models[models.length - 1] = flanking;

        AbstractDifferentiableStatisticalModel fg = new MixtureDiffSM(1, true, models);
        fg = new PositionWrapperDiffSM(fg, posStat);
        HomogeneousMMDiffSM bg = new HomogeneousMMDiffSM(data.getAlphabetContainer(), 1, baseESS, length);

        DataSet[] allData = new DataSet[]{pair.getFirstElement(), pair2.getFirstElement()};
        double[][] allWeights = new double[][]{pair.getSecondElement(), pair2.getSecondElement()};

        double[] initPars = null;
        if (initRandomly) {
            initPars = KMerMain.initRandomly(fg, (HomogeneousDiffSM) bg, allData, (double[][]) allWeights, 100);
        }

        LogGenDisMixFunction fun = new LogGenDisMixFunction(
                2,
                new DifferentiableSequenceScore[]{fg, bg},
                allData,
                allWeights,
                new CompositeLogPrior(),
                LearningPrinciple.getBeta(LearningPrinciple.MSP),
                true,
                false);
        fun.reset();
        NegativeDifferentiableFunction neg = new NegativeDifferentiableFunction(fun);
        Optimizer.optimize(
                (byte) 20,
                neg,
                initPars,
                new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-10),
                1.0E-10,
                new ConstantStartDistance(1.0E-4),
                SafeOutputStream.getSafeOutputStream(SafeOutputStream.DEFAULT_STREAM));

        // The parameter vector starts with two leading entries, followed by the foreground
        // and then the background parameters.
        fg.setParameters(initPars, 2);
        bg.setParameters(initPars, 2 + fg.getNumberOfParameters());
        System.out.println(fg);
        System.out.println(bg);
    }

    /**
     * Checks whether {@code seq2}, or its reverse complement, is a shifted variant of any
     * sequence already in {@code seqs}, testing both argument orders of
     * {@link #isShiftedVariant(Sequence, Sequence)}.
     *
     * <p>Prints {@code "testing: " + seq2} to standard output before checking.
     *
     * @param seqs the sequences collected so far
     * @param seq2 the candidate sequence
     * @return {@code true} as soon as one element of {@code seqs} matches, otherwise {@code false}
     * @throws WrongAlphabetException propagated from the sequence comparison
     * @throws OperationNotSupportedException propagated from {@code reverseComplement()}
     */
    private static boolean containsShiftedVariant(LinkedList<Sequence> seqs, Sequence seq2) throws WrongAlphabetException, OperationNotSupportedException {
        System.out.println("testing: " + seq2);
        Sequence rc = seq2.reverseComplement();
        // Iterate instead of calling seqs.get(i): indexed access on a LinkedList walks the
        // list on every call, which made the original loop quadratic.
        for (Sequence seq1 : seqs) {
            if (KMerMain.isShiftedVariant(seq1, seq2)
                    || KMerMain.isShiftedVariant(seq2, seq1)
                    || KMerMain.isShiftedVariant(seq1, rc)
                    || KMerMain.isShiftedVariant(rc, seq1)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Tests whether {@code seq1}, shifted by some offset, closely matches the start of
     * {@code seq2}.
     *
     * <p>For each offset from 0 up to (but excluding) half the length of {@code seq1}, the
     * sub-sequence {@code seq1.getSubSequence(offset)} is compared with
     * {@code seq2.getSubSequence(0, seq2.getLength() - offset)}. The Hamming distance is
     * divided by the length of the first sub-sequence, and a value below 0.2 counts as a match.
     *
     * @param seq1 the sequence whose leading positions are dropped
     * @param seq2 the sequence whose trailing positions are dropped
     * @return {@code true} if any tested offset gives a normalised distance below 0.2
     * @throws WrongAlphabetException propagated from the Hamming distance computation
     */
    private static boolean isShiftedVariant(Sequence seq1, Sequence seq2) throws WrongAlphabetException {
        for (int offset = 0; offset < seq1.getLength() / 2; offset++) {
            Sequence shifted = seq1.getSubSequence(offset);
            Sequence truncated = seq2.getSubSequence(0, seq2.getLength() - offset);
            double mismatches = (double) shifted.getHammingDistance(truncated);
            double normedDist = mismatches / (double) shifted.getLength();
            if (normedDist < 0.2) {
                return true;
            }
        }
        return false;
    }

    /**
     * Initialises the flanking and background models uniformly, optimises them against each
     * other on the given data, and returns the optimised parameters split into three parts.
     *
     * @param flanking the flanking model (first class of the objective)
     * @param background the background model (second class of the objective)
     * @param data one data set per class
     * @param weights one weight array per class
     * @return an array of three vectors: {@code [0]} the flanking model's parameters,
     *         {@code [1]} the background model's parameters, and {@code [2]} the first two
     *         entries of the optimised parameter vector
     * @throws Exception if model initialisation or optimisation fails
     */
    private static double[][] initFlanking(HomogeneousDiffSM flanking, HomogeneousDiffSM background, DataSet[] data, double[][] weights) throws Exception {
        flanking.initializeUniformly(false);
        background.initializeUniformly(false);

        // Layout: two leading entries, then the flanking parameters, then the background ones.
        double[] pars = new double[2 + flanking.getNumberOfParameters() + background.getNumberOfParameters()];

        DifferentiableSequenceScore[] scores = new DifferentiableSequenceScore[]{flanking, background};
        LogGenDisMixFunction fun = new LogGenDisMixFunction(
                2,
                scores,
                data,
                weights,
                new CompositeLogPrior(),
                LearningPrinciple.getBeta(LearningPrinciple.MSP),
                true,
                false);
        fun.reset();
        NegativeDifferentiableFunction neg = new NegativeDifferentiableFunction(fun);
        Optimizer.optimize(
                (byte) 20,
                neg,
                pars,
                new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-10),
                1.0E-10,
                new ConstantStartDistance(1.0E-4),
                SafeOutputStream.getSafeOutputStream(SafeOutputStream.DEFAULT_STREAM));

        double[] flankingPars = new double[flanking.getNumberOfParameters()];
        System.arraycopy(pars, 2, flankingPars, 0, flankingPars.length);

        double[] backgroundPars = new double[background.getNumberOfParameters()];
        System.arraycopy(pars, 2 + flankingPars.length, backgroundPars, 0, backgroundPars.length);

        double[] leading = new double[]{pars[0], pars[1]};

        return new double[][]{flankingPars, backgroundPars, leading};
    }

    /*
     * Unable to fully structure code
     */
    private static double[] initFromSequences(DeconvolvedDiffSM fg, HomogeneousDiffSM bg, KMerStatistic.KMerEntry[] diffs, int num, DataSet[] data, double[][] weights) throws Exception {
        fg.initializeFunctionRandomly(false);
        bg.initializeFunctionRandomly(false);
        pars = new double[2 + fg.getNumberOfParameters() + bg.getNumberOfParameters()];
        temp = KMerMain.initFlanking(fg.getFlankingModel(), bg, data, weights);
        fg.setFlankingParameters(temp[0], 0);
        bg.setParameters(temp[1], 0);
        seqs = new LinkedList<Sequence>();
        seqs.add(diffs[0].getKMer());
        r = new Random();
        i = 1;
        j = 1;
        ** GOTO lbl24
        {
            System.out.println(diffs[j].getKMer() + " " + diffs[j].getN());
            ++j;
            do {
                if (j < diffs.length && KMerMain.containsShiftedVariant(seqs, diffs[j].getKMer())) continue block0;
                if (j == diffs.length) {
                    j = r.nextInt(diffs.length);
                }
                System.out.println("found: " + diffs[j].getKMer());
                seqs.add(diffs[j].getKMer());
                ++i;
lbl24:
                // 2 sources

            } while (i < num);
        }
        System.out.println(seqs);
        ar = seqs.toArray(new Sequence[0]);
        i = 0;
        while (i < fg.getNumberOfMotifs()) {
            System.out.println(String.valueOf(i) + " " + ar[i]);
            fg.initializeMotif(new DataSet("", new Sequence[]{ar[i]}), new double[]{fg.getMotifESS(i) * 3.0}, i);
            System.out.println(fg);
            ++i;
        }
        pars = new double[2 + fg.getNumberOfParameters() + bg.getNumberOfParameters()];
        System.arraycopy(fg.getCurrentParameterValues(), 0, pars, 2, fg.getNumberOfParameters());
        System.arraycopy(bg.getCurrentParameterValues(), 0, pars, 2 + fg.getNumberOfParameters(), bg.getNumberOfParameters());
        return pars;
    }

    /**
     * Picks start parameters by random restarts: the models are randomly initialised
     * {@code num} times and the parameter vector with the highest objective value is returned.
     *
     * <p>The objective is a {@link LogGenDisMixFunction} built with
     * {@link DoesNothingLogPrior#defaultInstance} and the {@code MCL} learning principle.
     * Each restart prints {@code "<index> <value>"} to standard output. Both models are
     * left in a freshly randomised state on return.
     *
     * @param fg the foreground model
     * @param bg the background model
     * @param data one data set per class
     * @param weights one weight array per class
     * @param num the number of random restarts to evaluate
     * @return a copy of the best-scoring parameter vector, or {@code null} if no restart
     *         produced a value greater than negative infinity (including {@code num <= 0})
     * @throws Exception if initialisation or evaluation fails
     */
    private static double[] initRandomly(DifferentiableStatisticalModel fg, HomogeneousDiffSM bg, DataSet[] data, double[][] weights, int num) throws Exception {
        fg.initializeFunctionRandomly(false);
        bg.initializeFunctionRandomly(false);

        LogGenDisMixFunction fun = new LogGenDisMixFunction(
                2,
                new DifferentiableSequenceScore[]{fg, bg},
                data,
                weights,
                DoesNothingLogPrior.defaultInstance,
                LearningPrinciple.getBeta(LearningPrinciple.MCL),
                true,
                false);
        fun.reset();

        // Working vector; entries from index 2 on are overwritten in every restart.
        double[] pars = fun.getParameters(OptimizableFunction.KindOfParameter.PLUGIN);
        double[] bestPars = null;
        double bestScore = Double.NEGATIVE_INFINITY;

        for (int attempt = 0; attempt < num; attempt++) {
            int fgCount = fg.getNumberOfParameters();
            System.arraycopy(fg.getCurrentParameterValues(), 0, pars, 2, fgCount);
            System.arraycopy(bg.getCurrentParameterValues(), 0, pars, 2 + fgCount, bg.getNumberOfParameters());

            fun.reset();
            double val = fun.evaluateFunction(pars);
            System.out.println(attempt + " " + val);
            if (val > bestScore) {
                bestScore = val;
                bestPars = pars.clone();
            }

            // Draw the next candidate (also done after the final restart).
            fg.initializeFunctionRandomly(false);
            bg.initializeFunctionRandomly(false);
        }
        return bestPars;
    }

    /**
     * Random-restart initialisation for a {@link DeconvolvedDiffSM} foreground: the flanking
     * and background parameters are fitted once via
     * {@link #initFlanking(HomogeneousDiffSM, HomogeneousDiffSM, DataSet[], double[][])} and
     * re-applied in every restart, while the foreground model is re-randomised between
     * restarts.
     *
     * <p>The first two entries of the returned vector come from the third part of the
     * {@code initFlanking} result. Each restart prints {@code "<index> <value>"} to standard
     * output.
     *
     * @param fg the foreground model
     * @param bg the background model
     * @param data one data set per class
     * @param weights one weight array per class
     * @param num the number of random restarts to evaluate
     * @return a copy of the best-scoring parameter vector, or {@code null} if no restart
     *         produced a value greater than negative infinity (including {@code num <= 0})
     * @throws Exception if initialisation, optimisation, or evaluation fails
     */
    private static double[] initRandomly(DeconvolvedDiffSM fg, HomogeneousDiffSM bg, DataSet[] data, double[][] weights, int num) throws Exception {
        fg.initializeFunctionRandomly(false);
        bg.initializeFunctionRandomly(false);

        double[] pars = new double[2 + fg.getNumberOfParameters() + bg.getNumberOfParameters()];
        double[][] temp = KMerMain.initFlanking(fg.getFlankingModel(), bg, data, weights);
        double[] flankingPars = temp[0];
        double[] backgroundPars = temp[1];
        pars[0] = temp[2][0];
        pars[1] = temp[2][1];

        double[] bestPars = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        LogGenDisMixFunction fun = new LogGenDisMixFunction(
                2,
                new DifferentiableSequenceScore[]{fg, bg},
                data,
                weights,
                DoesNothingLogPrior.defaultInstance,
                LearningPrinciple.getBeta(LearningPrinciple.MCL),
                true,
                false);

        for (int attempt = 0; attempt < num; attempt++) {
            // Restore the fitted flanking/background parameters before every evaluation.
            fg.setFlankingParameters(flankingPars, 0);
            bg.setParameters(backgroundPars, 0);

            int fgCount = fg.getNumberOfParameters();
            System.arraycopy(fg.getCurrentParameterValues(), 0, pars, 2, fgCount);
            System.arraycopy(bg.getCurrentParameterValues(), 0, pars, 2 + fgCount, bg.getNumberOfParameters());

            fun.reset();
            double val = fun.evaluateFunction(pars);
            System.out.println(attempt + " " + val);
            if (val > bestScore) {
                bestScore = val;
                bestPars = pars.clone();
            }

            // Only the foreground is re-randomised between restarts.
            fg.initializeFunctionRandomly(false);
        }
        return bestPars;
    }
}

