/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.differentiableSequenceScoreBased;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.results.Result;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.ToolBox;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Rank-based objective function for the multi-threaded numerical optimization of a single
 * {@link DifferentiableSequenceScore}.
 *
 * <p>Every data set {@code data[d]} forms an independent group. With
 * {@code s_i = exp(c * logScore(x_i))} for the i-th sequence of a group, the value is
 *
 * <pre>
 *   f(theta) = ( sum_d sum_{i &gt; j} w_ij * s_i / s_j  +  con * ||theta||^2 ) / sum[cl]
 * </pre>
 *
 * where {@code sum[cl]} is the inherited normalizer and the pair weights {@code w_ij} are
 * precomputed once per data set from the {@code "localWeight2"} annotation of both sequences and
 * the sequence weights (see {@link #precomputeIndexes()}). A pair contributes more the higher the
 * later sequence scores relative to the earlier one, so the data sets are presumably expected to be
 * ordered by decreasing rank (e.g. as produced by {@link #splitByTagAndSort}) — confirm with callers.
 *
 * <p>Work is distributed over the threads by whole data sets only: every worker must cover
 * complete data sets ({@code startSeq == 0} and {@code endSeq} equal to the size of
 * {@code endClass}).
 */
public class RankBasedOptimizableFunction3LW
extends AbstractMultiThreadedOptimizableFunction {
    /** Annotation type holding the per-sequence value used to build the pair weights. */
    private static final String LOCAL_WEIGHT_TAG = "localWeight2";

    /** Factor applied to the log-scores before exponentiation: {@code s = exp(c * logScore)}. */
    private double c = 1.0;
    /** Factor of the quadratic parameter penalty {@code con * ||theta||^2}. */
    private double con = 1.0;
    /** One score instance per worker; index 0 is the instance passed to the constructor. */
    private DifferentiableSequenceScore[] scores;
    /** {@code sij[t][d][i]}: current {@code exp(c * logScore)} of sequence i in the d-th data set of worker t. */
    private double[][][] sij;
    /** {@code wijk[t][d][i][j]} with {@code j < i}: precomputed weight of pair (i, j) in the d-th data set of worker t. */
    private double[][][][] wijk;
    /** Per-worker gradient accumulators. */
    private double[][] grads;
    /** Per-worker partial function values (without penalty and normalization). */
    private double[] vals;
    /** Per-worker scratch list receiving the parameter indices from getLogScoreAndPartialDerivation. */
    private IntList[] indices;
    /** Per-worker scratch list receiving the partial derivatives matching {@link #indices}. */
    private DoubleList[] partDers;

    /**
     * Creates the objective for {@code score} on the given data sets.
     *
     * @param score the score to optimize; further per-thread copies are created in {@link #reset()}
     * @param threads the number of worker threads
     * @param data one data set per group
     * @param weights the sequence weights, indexed like {@code data}
     * @throws IllegalArgumentException if the arguments are rejected, e.g. if a sequence lacks a
     *     parsable {@code "localWeight2"} annotation
     */
    public RankBasedOptimizableFunction3LW(DifferentiableSequenceScore score, int threads, DataSet[] data, double[][] weights) throws IllegalArgumentException {
        super(threads, data, weights, false, false);
        this.scores = new DifferentiableSequenceScore[threads];
        this.scores[0] = score;
        this.precomputeIndexes();
    }

    /** Returns the number of parameters of the optimized score. */
    @Override
    public int getDimensionOfScope() {
        return this.scores[0].getNumberOfParameters();
    }

    /**
     * Computes this worker's share of the (un-normalized, un-penalized) gradient into
     * {@code grads[index]}.
     *
     * <p>For sequence i the derivative of {@code sum_{a > b} w_ab * s_a / s_b} w.r.t. the parameters is
     * {@code c * dLogScore_i * (sum_{j < i} w_ij s_i / s_j - sum_{j > i} w_ji s_j / s_i)}.
     */
    @Override
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        this.checkWholeDataSets(startSeq, endClass, endSeq);
        try {
            // Refreshes sij[index] for the current parameters (and vals[index] as a side effect).
            this.evaluateFunction(index, startClass, startSeq, endClass, endSeq);
        }
        catch (EvaluationException e) {
            // Best effort: evaluateFunction refreshes all of sij[index] before it can throw,
            // so the gradient below is still computed from up-to-date scores.
            e.printStackTrace();
        }
        double[] grad = this.grads[index];
        IntList idx = this.indices[index];
        DoubleList der = this.partDers[index];
        Arrays.fill(grad, 0.0);
        for (int d = startClass; d <= endClass; d++) {
            double[] s = this.sij[index][d - startClass];
            double[][] w = this.wijk[index][d - startClass];
            int n = this.data[d].getNumberOfElements();
            for (int i = 0; i < n; i++) {
                idx.clear();
                der.clear();
                double si = Math.exp(this.c * this.scores[index].getLogScoreAndPartialDerivation(this.data[d].getElementAt(i), idx, der));
                double ws = 0.0;
                for (int j = 0; j < i; j++) {
                    ws += si / s[j] * w[i][j];
                }
                for (int j = i + 1; j < n; j++) {
                    ws -= s[j] / si * w[j][i];
                }
                for (int k = 0; k < idx.length(); k++) {
                    grad[idx.get(k)] += ws * this.c * der.get(k);
                }
            }
        }
    }

    /**
     * Sums the per-worker gradients, adds the penalty gradient {@code 2 * con * theta} and divides
     * by {@code sum[cl]}.
     *
     * <p>Accumulates into {@code grads[0]} in place and returns a copy of it.
     */
    @Override
    protected double[] joinGradients() throws EvaluationException {
        double[] total = this.grads[0];
        for (int t = 1; t < this.grads.length; t++) {
            double[] part = this.grads[t];
            for (int j = 0; j < total.length; j++) {
                total[j] += part[j];
            }
        }
        double norm = this.sum[this.cl];
        for (int i = 0; i < this.params.length; i++) {
            total[i] += 2.0 * this.params[i] * this.con;
            total[i] /= norm;
        }
        return total.clone();
    }

    /**
     * Computes this worker's share {@code sum_{i > j} w_ij * s_i / s_j} over its data sets into
     * {@code vals[index]} and stores the current {@code s_i} in {@code sij[index]}.
     *
     * @throws EvaluationException if the partial sum becomes NaN; {@code sij[index]} is fully
     *     refreshed even then, {@code vals[index]} is left unchanged
     */
    @Override
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        this.checkWholeDataSets(startSeq, endClass, endSeq);
        // Refresh every score before summing, so an exception below never leaves stale entries
        // in sij that evaluateGradientOfFunction would then read.
        for (int d = startClass; d <= endClass; d++) {
            double[] s = this.sij[index][d - startClass];
            int n = this.data[d].getNumberOfElements();
            for (int i = 0; i < n; i++) {
                s[i] = Math.exp(this.c * this.scores[index].getLogScoreFor(this.data[d].getElementAt(i)));
            }
        }
        double val = 0.0;
        for (int d = startClass; d <= endClass; d++) {
            double[] s = this.sij[index][d - startClass];
            double[][] w = this.wijk[index][d - startClass];
            int n = this.data[d].getNumberOfElements();
            for (int i = 1; i < n; i++) {
                for (int j = 0; j < i; j++) {
                    val += s[i] / s[j] * w[i][j];
                    if (Double.isNaN(val)) {
                        throw new EvaluationException(String.valueOf(s[i]) + " " + s[j] + " " + w[i][j] + " " + Arrays.toString(this.params));
                    }
                }
            }
        }
        this.vals[index] = val;
    }

    /**
     * Sums the per-worker values, adds {@code con * ||theta||^2} and divides by {@code sum[cl]}.
     *
     * @throws EvaluationException if the result is NaN
     */
    @Override
    protected double joinFunction() throws EvaluationException, DimensionException {
        double val = ToolBox.sum(this.vals);
        for (int i = 0; i < this.params.length; i++) {
            val += this.params[i] * this.params[i] * this.con;
        }
        val /= this.sum[this.cl];
        if (Double.isNaN(val)) {
            throw new EvaluationException("NaN");
        }
        return val;
    }

    /**
     * Returns the weight of pair (i, j) of data set {@code d}:
     * {@code |w_i - w_j| * (max(w_i, w_j) - minWeight) * weights[d][i] * weights[d][j]}, where
     * {@code w} are the parsed {@code "localWeight2"} values.
     *
     * <p>{@code meanWeight} and {@code sdWeight} are only printed for diagnostics when the
     * result is NaN.
     */
    private double computeWeight(int d, int i, int j, double[] local, double minWeight, double meanWeight, double sdWeight) {
        double w1 = local[i];
        double w2 = local[j];
        double wlw = Math.abs(w1 - w2);
        double wlw2 = Math.max(w1, w2) - minWeight;
        double res = wlw * wlw2 * this.weights[d][i] * this.weights[d][j];
        if (Double.isNaN(res)) {
            System.out.println(String.valueOf(w1) + "\t" + w2 + "\t" + i + "\t" + j + "\t" + wlw + "\t" + wlw2 + "\t" + this.weights[d][i] + "\t" + this.weights[d][j] + "\t" + meanWeight + "\t" + sdWeight);
        }
        return res;
    }

    /** No thread-independent state; each worker's score receives the parameters in {@link #setParams(int)}. */
    @Override
    protected void setThreadIndependentParameters() throws DimensionException {
    }

    /** Replaces data and weights and rebuilds all per-worker buffers and pair weights. */
    @Override
    public void setDataAndWeights(DataSet[] data, double[][] weights) throws IllegalArgumentException {
        super.setDataAndWeights(data, weights);
        this.precomputeIndexes();
    }

    /**
     * Allocates the per-worker buffers and precomputes the pair weights of every data set
     * assigned to each worker. Does nothing while no workers exist.
     *
     * @throws IllegalArgumentException if a sequence lacks a parsable {@code "localWeight2"} annotation
     * @throws IllegalStateException if a worker does not cover whole data sets
     */
    public void precomputeIndexes() {
        if (this.worker == null) {
            return;
        }
        int numWorkers = this.worker.length;
        this.vals = new double[numWorkers];
        this.sij = new double[numWorkers][][];
        this.wijk = new double[numWorkers][][][];
        this.indices = new IntList[numWorkers];
        this.partDers = new DoubleList[numWorkers];
        this.grads = new double[numWorkers][this.getDimensionOfScope()];
        for (int t = 0; t < numWorkers; t++) {
            int[] range = this.worker[t].getIndices();
            int startClass = range[0];
            int endClass = range[2];
            this.checkWholeDataSets(range[1], endClass, range[3]);
            this.sij[t] = new double[endClass - startClass + 1][];
            this.wijk[t] = new double[endClass - startClass + 1][][];
            for (int d = startClass; d <= endClass; d++) {
                this.sij[t][d - startClass] = new double[this.data[d].getNumberOfElements()];
                this.wijk[t][d - startClass] = this.computePairWeights(d);
            }
            this.indices[t] = new IntList();
            this.partDers[t] = new DoubleList();
        }
    }

    /** Builds the lower-triangular pair-weight matrix ({@code result[i].length == i}) of data set {@code d}. */
    private double[][] computePairWeights(int d) {
        // Parse each annotation once instead of twice per pair.
        double[] local = readLocalWeights(this.data[d]);
        int n = local.length;
        double minWeight = Double.POSITIVE_INFINITY;
        double meanWeight = 0.0;
        double sdWeight = 0.0;
        double count = n;
        for (double mw : local) {
            if (mw < minWeight) {
                minWeight = mw;
            }
            meanWeight += mw / count;
            sdWeight += mw * mw / count;
        }
        sdWeight = Math.sqrt(sdWeight - meanWeight * meanWeight);
        double[][] pairs = new double[n][];
        for (int i = 0; i < n; i++) {
            pairs[i] = new double[i];
            for (int j = 0; j < i; j++) {
                pairs[i][j] = this.computeWeight(d, i, j, local, minWeight, meanWeight, sdWeight);
            }
        }
        return pairs;
    }

    /** Parses the {@code "localWeight2"} annotation of every sequence of {@code ds}. */
    private static double[] readLocalWeights(DataSet ds) {
        double[] local = new double[ds.getNumberOfElements()];
        for (int k = 0; k < local.length; k++) {
            local[k] = Double.parseDouble(requireAnnotation(ds.getElementAt(k), LOCAL_WEIGHT_TAG).getIdentifier());
        }
        return local;
    }

    /** Fails if a worker range does not consist of whole data sets, which this function requires. */
    private void checkWholeDataSets(int startSeq, int endClass, int endSeq) {
        if (startSeq != 0 || endSeq != this.data[endClass].getNumberOfElements()) {
            throw new IllegalStateException("Workers must cover whole data sets, but got startSeq=" + startSeq + ", endClass=" + endClass + ", endSeq=" + endSeq);
        }
    }

    /** Sets the current parameters on the score instance of worker {@code index}. */
    @Override
    protected void setParams(int index) throws DimensionException {
        this.scores[index].setParameters(this.params, 0);
    }

    /** Copies the current parameters of the master score into {@code erg}; {@code kind} is ignored. */
    @Override
    public void getParameters(OptimizableFunction.KindOfParameter kind, double[] erg) throws Exception {
        double[] temp = this.scores[0].getCurrentParameterValues();
        System.arraycopy(temp, 0, erg, 0, temp.length);
    }

    /** Replaces the score instances of workers 1..n-1 by fresh clones of the master score. */
    @Override
    public void reset() throws Exception {
        for (int i = 1; i < this.scores.length; i++) {
            this.scores[i] = this.scores[0].clone();
        }
    }

    /**
     * Splits the data sets into contiguous ranges, one per worker, and creates or re-targets the
     * workers.
     *
     * <p>Each data set costs {@code n^2 * sqrt(averageLength)}. Ranges are filled greedily up to
     * the remaining average share, while always leaving at least one data set for each
     * remaining worker. The last worker takes everything left.
     *
     * @throws IllegalArgumentException if there are more workers than data sets
     */
    @Override
    protected void prepareThreads() {
        int numSets = this.data.length;
        int numWorkers = this.worker.length;
        if (numWorkers > numSets) {
            throw new IllegalArgumentException("Cannot distribute " + numSets + " data set(s) over " + numWorkers + " threads; every thread needs at least one whole data set");
        }
        double[] sizes = new double[numSets];
        for (int d = 0; d < numSets; d++) {
            double n = this.data[d].getNumberOfElements();
            sizes[d] = n * n * Math.sqrt(this.data[d].getAverageElementLength());
        }
        double remaining = ToolBox.sum(sizes);
        double part = remaining / numWorkers;
        int startClass = 0;
        for (int w = 0; w < numWorkers; w++) {
            // Keep one data set in reserve for every worker still to come.
            int lastAllowed = numSets - 1 - (numWorkers - w - 1);
            int endClass = startClass;
            double curr = sizes[endClass];
            while (endClass < lastAllowed && curr + sizes[endClass + 1] <= part) {
                endClass++;
                curr += sizes[endClass];
            }
            remaining -= curr;
            if (w == numWorkers - 1) {
                endClass = numSets - 1;
            }
            if (this.worker[w] != null) {
                if (!this.worker[w].isWaiting()) {
                    this.stopThreads();
                    throw new IllegalStateException("Worker " + w + " is not waiting and cannot be re-assigned");
                }
                this.worker[w].setIndices(startClass, 0, endClass, this.data[endClass].getNumberOfElements());
            } else {
                this.worker[w] = new AbstractMultiThreadedOptimizableFunction.Worker(w, startClass, 0, endClass, this.data[endClass].getNumberOfElements());
                this.worker[w].start();
            }
            startClass = endClass + 1;
            if (w < numWorkers - 1) {
                part = remaining / (numWorkers - w - 1);
            }
        }
    }

    /**
     * Groups the sequences of {@code data} by the identifier of their {@code splitTag} annotation.
     * Each group is sorted by decreasing numeric value of the {@code sortTag} annotation, and every
     * sequence gets an {@code "intgroup"} annotation holding its group index. Groups are numbered
     * in the iteration order of a {@link HashMap} keyed by the split identifiers.
     *
     * <p>If {@code filter} is set, a sequence whose {@code "mask"} annotation contains "X" is
     * always kept. Any other sequence is compared position by position against its
     * {@code "reference"} sequence and kept only if it has at most 3 mismatches. A reference
     * position with a symbol other than "HD", "NI", "NG" or "NN" stops the comparison and keeps
     * the sequence.
     *
     * <p>If {@code numThreads > 1}, the data sets are assigned, in the order returned by
     * {@code ToolBox.order(sizes, true)}, to the currently least-loaded of {@code numThreads} bins
     * (cost {@code n^2 * sqrt(averageLength)}). The bins are then concatenated.
     *
     * @return one data set per group
     * @throws IllegalArgumentException if a required annotation is missing or a tag value is not a number
     * @throws EmptyDataSetException presumably if filtering leaves a group empty (from the DataSet constructor)
     * @throws WrongAlphabetException propagated from the alphabet/sequence operations
     */
    public static DataSet[] splitByTagAndSort(int numThreads, DataSet data, String splitTag, String sortTag, boolean filter) throws EmptyDataSetException, WrongAlphabetException {
        Map<String, List<Sequence>> groups = new HashMap<String, List<Sequence>>();
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            String key = requireAnnotation(seq, splitTag).getIdentifier();
            List<Sequence> group = groups.get(key);
            if (group == null) {
                group = new ArrayList<Sequence>();
                groups.put(key, group);
            }
            group.add(seq);
        }
        DataSet[] ds = new DataSet[groups.size()];
        int g = 0;
        for (Map.Entry<String, List<Sequence>> entry : groups.entrySet()) {
            Sequence[] seqs = sortByTagDescending(entry.getValue(), sortTag);
            for (int j = 0; j < seqs.length; j++) {
                seqs[j] = seqs[j].annotate(true, new SequenceAnnotation("intgroup", String.valueOf(g), (Result[][])new Result[0][]));
            }
            if (filter) {
                List<Sequence> kept = new ArrayList<Sequence>();
                for (Sequence seq : seqs) {
                    SequenceAnnotation mask = seq.getSequenceAnnotationByType("mask", 0);
                    if (mask != null && mask.getIdentifier().contains("X")) {
                        kept.add(seq);
                        continue;
                    }
                    ReferenceSequenceAnnotation an = (ReferenceSequenceAnnotation)requireAnnotation(seq, "reference");
                    Sequence ref = an.getReferenceSequence();
                    AlphabetContainer rvds = ref.getAlphabetContainer();
                    int nmm = 0;
                    for (int k = 0; k < ref.getLength(); k++) {
                        // seq is read one position ahead of ref; presumably seq has one extra
                        // leading position and codes 0..3 are A,C,G,T — TODO confirm.
                        double rvd = ref.discreteVal(k);
                        boolean mismatch;
                        if (rvd == rvds.getCode(k, "HD")) {
                            mismatch = seq.discreteVal(k + 1) != 1;
                        } else if (rvd == rvds.getCode(k, "NI")) {
                            mismatch = seq.discreteVal(k + 1) != 0;
                        } else if (rvd == rvds.getCode(k, "NG")) {
                            mismatch = seq.discreteVal(k + 1) != 3;
                        } else if (rvd == rvds.getCode(k, "NN")) {
                            int b = seq.discreteVal(k + 1);
                            mismatch = b != 0 && b != 2;
                        } else {
                            // Unknown symbol: give up on comparing and keep the sequence.
                            nmm = 0;
                            break;
                        }
                        if (mismatch) {
                            nmm++;
                        }
                    }
                    if (nmm <= 3) {
                        kept.add(seq);
                    }
                }
                seqs = kept.toArray(new Sequence[0]);
            }
            ds[g] = new DataSet("", seqs);
            g++;
        }
        if (numThreads > 1) {
            double[] sizes = new double[ds.length];
            for (int d = 0; d < ds.length; d++) {
                double n = ds[d].getNumberOfElements();
                sizes[d] = n * n * Math.sqrt(ds[d].getAverageElementLength());
            }
            int[] order = ToolBox.order(sizes, true);
            IntList[] bins = new IntList[numThreads];
            for (int t = 0; t < numThreads; t++) {
                bins[t] = new IntList();
            }
            double[] load = new double[numThreads];
            for (int o : order) {
                int idx = ToolBox.getMinIndex(load);
                bins[idx].add(o);
                load[idx] += sizes[o];
            }
            DataSet[] reordered = new DataSet[ds.length];
            int k = 0;
            for (IntList bin : bins) {
                for (int l = 0; l < bin.length(); l++) {
                    reordered[k++] = ds[bin.get(l)];
                }
            }
            ds = reordered;
        }
        return ds;
    }

    /** Returns the sequences of {@code group} sorted by decreasing numeric {@code sortTag} value (stable). */
    private static Sequence[] sortByTagDescending(List<Sequence> group, String sortTag) {
        @SuppressWarnings("unchecked")
        ComparableElement<Sequence, Double>[] keyed = new ComparableElement[group.size()];
        int k = 0;
        for (Sequence seq : group) {
            double w = Double.parseDouble(requireAnnotation(seq, sortTag).getIdentifier());
            // Negated key: the ascending natural-order sort then yields decreasing tag values.
            keyed[k++] = new ComparableElement<Sequence, Double>(seq, -w);
        }
        Arrays.sort(keyed);
        Sequence[] sorted = new Sequence[keyed.length];
        for (int i = 0; i < keyed.length; i++) {
            sorted[i] = (Sequence)keyed[i].getElement();
        }
        return sorted;
    }

    /** Returns the first annotation of the given type, failing with a clear message if there is none. */
    private static SequenceAnnotation requireAnnotation(Sequence seq, String type) {
        SequenceAnnotation an = seq.getSequenceAnnotationByType(type, 0);
        if (an == null) {
            throw new IllegalArgumentException("Sequence lacks required annotation of type \"" + type + "\"");
        }
        return an;
    }
}

