/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.assessment;

import de.jstacs.classifiers.AbstractClassifier;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.classifiers.assessment.ClassifierAssessment;
import de.jstacs.classifiers.assessment.ClassifierAssessmentAssessParameterSet;
import de.jstacs.classifiers.assessment.KFoldCrossValidationAssessParameterSet;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.ListResult;
import de.jstacs.results.MeanResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ProgressUpdater;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Classifier assessment by k-fold cross validation.
 * <p>
 * The data of each class is split into {@code k} parts. In iteration {@code i}
 * (for {@code i = 0, ..., k-1}) part {@code i} of every class is held out as test
 * data, and the union of the remaining parts is used as training data. The
 * classifiers are then trained (inherited {@code train}) on the training data and
 * evaluated (inherited {@code test}) on the held-out data.
 */
public class KFoldCrossValidation
extends ClassifierAssessment<KFoldCrossValidationAssessParameterSet> {

    /**
     * Creates a new assessment; all arguments are forwarded unchanged to the
     * corresponding superclass constructor.
     */
    protected KFoldCrossValidation(AbstractClassifier[] aCs, TrainableStatisticalModel[][] aMs, boolean buildClassifiersByCrossProduct, boolean checkAlphabetConsistencyAndLength) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(aCs, aMs, buildClassifiersByCrossProduct, checkAlphabetConsistencyAndLength);
    }

    /**
     * Creates a new assessment for the given classifiers; forwards to the
     * corresponding superclass constructor.
     */
    public KFoldCrossValidation(AbstractClassifier ... aCs) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(aCs);
    }

    /**
     * Creates a new assessment from the given models; forwards to the
     * corresponding superclass constructor.
     */
    public KFoldCrossValidation(boolean buildClassifiersByCrossProduct, TrainableStatisticalModel[] ... aMs) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(buildClassifiersByCrossProduct, aMs);
    }

    /**
     * Creates a new assessment from the given classifiers and models; forwards to
     * the corresponding superclass constructor.
     */
    public KFoldCrossValidation(AbstractClassifier[] aCs, boolean buildClassifiersByCrossProduct, TrainableStatisticalModel[] ... aMs) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException, ClassDimensionException {
        super(aCs, buildClassifiersByCrossProduct, aMs);
    }

    /**
     * Partitions the data of every class into {@code k} parts and runs the cross
     * validation on these parts.
     *
     * @param mp the performance measures, passed on to the inherited {@code test}
     * @param assessPS supplies {@code k}, the partition method, and the element length
     * @param s one data set per class
     * @param weights one weight array per class (the whole array may be {@code null})
     * @param pU receives progress updates (one step per fold); may be {@code null}
     * @throws IllegalArgumentException if partitioning yields an empty part, i.e. some
     *         class has too few elements for {@code k} folds
     * @throws Exception forwarded from partitioning, training, or testing
     */
    @Override
    protected void evaluateClassifier(NumericalPerformanceMeasureParameterSet mp, KFoldCrossValidationAssessParameterSet assessPS, DataSet[] s, double[][] weights, ProgressUpdater pU) throws IllegalArgumentException, Exception {
        DataSet.PartitionMethod splitMethod = assessPS.getDataSplitMethod();
        int k = assessPS.getK();
        DataSet[][] sInParts = new DataSet[s.length][];
        double[][][] weightsInParts = new double[s.length][][];
        try {
            for (int i = 0; i < s.length; i++) {
                double[] classWeights = weights == null ? null : weights[i];
                Pair<DataSet[], double[][]> p = s[i].partition(classWeights, splitMethod, k);
                sInParts[i] = p.getFirstElement();
                weightsInParts[i] = p.getSecondElement();
            }
        }
        catch (EmptyDataSetException e) {
            // keep the original exception as cause so the failing class/part stays traceable
            throw new IllegalArgumentException("Given DataSet s seems to contain too few elements for a " + k + "-fold crossvalidation since at least one empty subset occurred " + "during splitting given data into " + k + " non-overlapping parts.", e);
        }
        this.evaluate(mp, assessPS, pU, sInParts, weightsInParts);
    }

    /**
     * Runs the folds on already split data.
     *
     * @param mp the performance measures, passed on to the inherited {@code test}
     * @param caaps supplies the element length used to resize the test data and the
     *        flag controlling whether non-computable measures raise an exception
     * @param pU receives progress updates (one step per fold); may be {@code null}
     * @param splitData {@code splitData[c][i]} is part {@code i} of class {@code c};
     *        all classes must have the same number of parts
     * @param splitWeights weights aligned with {@code splitData}; the whole array, a
     *        class row, or an individual entry may be {@code null}
     * @throws IllegalArgumentException if the classes have different numbers of parts
     * @throws Exception forwarded from union, resize, training, or testing
     */
    private void evaluate(NumericalPerformanceMeasureParameterSet mp, ClassifierAssessmentAssessParameterSet caaps, ProgressUpdater pU, DataSet[][] splitData, double[][][] splitWeights) throws Exception {
        int subSeqL = caaps.getElementLength();
        boolean exceptionIfMPNotComputable = caaps.getExceptionIfMPNotComputable();
        int clazz = splitData.length;
        int k = splitData[0].length;
        for (int c = 1; c < clazz; c++) {
            if (splitData[c].length != k) {
                throw new IllegalArgumentException("Please check the number of predefined splits per class. Compare class 0 with class " + c);
            }
        }
        // index 0: training data per class, index 1: test data per class
        DataSet[][] sTrainTestClassWise = new DataSet[2][clazz];
        double[][][] weightsTrainTestClassWise = new double[2][clazz][];
        // inTraining[i] == true <=> part i belongs to the training data of the current fold
        boolean[] inTraining = new boolean[k];
        Arrays.fill(inTraining, true);
        if (pU != null) {
            pU.setMax(k);
        }
        for (int fold = 0; fold < k; fold++) {
            inTraining[fold] = false;
            for (int c = 0; c < clazz; c++) {
                // a null class row means "no weights" for that class; avoid dereferencing it
                double[][] classWeights = splitWeights == null ? null : splitWeights[c];
                Pair<DataSet, double[]> p = DataSet.union(splitData[c], classWeights, inTraining);
                sTrainTestClassWise[0][c] = p.getFirstElement();
                weightsTrainTestClassWise[0][c] = p.getSecondElement();
                double[] testWeights = classWeights == null ? null : classWeights[fold];
                p = splitData[c][fold].resize(testWeights, subSeqL);
                sTrainTestClassWise[1][c] = p.getFirstElement();
                weightsTrainTestClassWise[1][c] = p.getSecondElement();
            }
            inTraining[fold] = true;
            this.train(sTrainTestClassWise[0], weightsTrainTestClassWise[0]);
            this.test(mp, exceptionIfMPNotComputable, sTrainTestClassWise[1], weightsTrainTestClassWise[1]);
            if (pU != null) {
                pU.setValue(fold + 1);
            }
        }
    }

    /**
     * Assesses all classifiers of this assessment on user-supplied splits instead of
     * partitioning the data internally.
     *
     * @param mp the performance measures, passed on to the inherited {@code test}
     * @param caaps the assessment parameters; their annotation is added to the result
     * @param pU receives progress updates (one step per fold); may be {@code null}
     * @param splitData {@code splitData[c][i]} is part {@code i} of class {@code c};
     *        its length must equal the number of classes of the first classifier
     * @param splitWeights weights aligned with {@code splitData}, or {@code null} for
     *        no weights; if non-null its length must equal the number of classes
     * @return a {@link ListResult} annotated with the kind of assessment, the
     *         parameters of {@code caaps}, and the annotations of the splits, holding
     *         one {@link MeanResultSet} per classifier
     * @throws IllegalArgumentException if the number of classes of {@code splitData}
     *         or {@code splitWeights} does not match the classifier, or the classes
     *         have different numbers of parts
     * @throws Exception forwarded from training or testing
     */
    public ListResult assessWithPredefinedSplits(NumericalPerformanceMeasureParameterSet mp, ClassifierAssessmentAssessParameterSet caaps, ProgressUpdater pU, DataSet[][] splitData, double[][][] splitWeights) throws Exception {
        int clazz = this.myAbstractClassifier[0].getNumberOfClasses();
        if (splitData.length != clazz) {
            throw new IllegalArgumentException("The number of classes in the data array and the classifier differs.");
        }
        if (splitWeights == null) {
            // rows stay null; evaluate treats a null row as "no weights" for that class
            splitWeights = new double[clazz][][];
        } else if (splitWeights.length != clazz) {
            throw new IllegalArgumentException("The number of classes in the weights array and the classifier differs.");
        }
        this.myTempMeanResultSets = new MeanResultSet[this.myAbstractClassifier.length];
        for (int i = 0; i < this.myAbstractClassifier.length; i++) {
            this.myTempMeanResultSets[i] = new MeanResultSet(this.myAbstractClassifier[i].getClassifierAnnotation());
        }
        this.evaluate(mp, caaps, pU, splitData, splitWeights);
        LinkedList<Result> annotation = new LinkedList<Result>();
        annotation.add(new CategoricalResult("kind of assessment", "a description or name of the assessment", this.getNameOfAssessment()));
        annotation.addAll(caaps.getAnnotation());
        StringBuilder sb = new StringBuilder(1000);
        sb.append('[').append(DataSet.getAnnotation(splitData[0]));
        for (int c = 1; c < splitData.length; c++) {
            sb.append(", ").append(DataSet.getAnnotation(splitData[c]));
        }
        sb.append(']');
        annotation.add(new CategoricalResult("samples", "annotation of used samples", "predefined splits: " + sb));
        return new ListResult(this.getNameOfAssessment(), "the results of a " + this.getNameOfAssessment() + " of predefined splits", new ResultSet(annotation), this.myTempMeanResultSets);
    }

    /**
     * Returns a new, default {@link KFoldCrossValidationAssessParameterSet}.
     *
     * @return a freshly constructed parameter set
     * @throws Exception declared for compatibility with the superclass
     */
    @Override
    public KFoldCrossValidationAssessParameterSet getAssessParameterSet() throws Exception {
        return new KFoldCrossValidationAssessParameterSet();
    }
}

