/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.shared;

import de.jstacs.algorithms.graphs.tensor.SymmetricTensor;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.classifiers.trainSMBased.TrainSMBasedClassifier;
import de.jstacs.data.DataSet;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.FSDAGTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.StructureLearner;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.parameters.BayesianNetworkTrainSMParameterSet;

/**
 * Classifier whose class models are {@link FSDAGTrainSM} instances that are all
 * trained with one common structure.
 *
 * <p>During {@link #train(DataSet[], double[][])} one {@link SymmetricTensor} is
 * computed per class by a {@link StructureLearner}. The per-class tensors are
 * combined into a single tensor, a structure is derived from the combined
 * tensor, and that structure is handed to
 * {@code FSDAGTrainSM.train} together with all class models.
 */
public class SharedStructureClassifier
extends TrainSMBasedClassifier {
    /** Model type passed to {@code StructureLearner.getStructure}. */
    private StructureLearner.ModelType model;
    /** Order passed to the structure learner; checked to be non-negative in the main constructor. */
    private byte order;
    /** Learning type passed to {@code StructureLearner.getTensor}. */
    private StructureLearner.LearningType method;
    /** Helper that computes the per-class tensors; not part of the XML representation, rebuilt on clone and on XML load. */
    private StructureLearner sl;

    /**
     * Creates a classifier from the given class models.
     *
     * @param length the sequence length handed to the internal {@link StructureLearner}
     * @param model the model type used when deriving the shared structure
     * @param order the order used for tensor computation and structure derivation; must be non-negative
     * @param method the learning type used for tensor computation
     * @param models the class models
     * @throws IllegalArgumentException if {@code order} is negative
     * @throws CloneNotSupportedException declared by the super constructor
     * @throws ClassDimensionException declared by the super constructor
     */
    public SharedStructureClassifier(int length, StructureLearner.ModelType model, byte order, StructureLearner.LearningType method, FSDAGTrainSM ... models) throws IllegalArgumentException, CloneNotSupportedException, ClassDimensionException {
        super(true, (TrainableStatisticalModel[])models);
        this.model = model;
        if (order < 0) {
            throw new IllegalArgumentException("The value of order has to be non-negative.");
        }
        this.order = order;
        this.method = method;
        this.sl = new StructureLearner(this.getAlphabetContainer(), length);
    }

    /**
     * Restores a classifier from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML could not be parsed
     */
    public SharedStructureClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Clones this classifier. The clone receives its own {@link StructureLearner}
     * so that it does not share that helper with this instance.
     */
    @Override
    public SharedStructureClassifier clone() throws CloneNotSupportedException {
        SharedStructureClassifier clone = (SharedStructureClassifier)super.clone();
        clone.sl = new StructureLearner(this.getAlphabetContainer(), this.getLength());
        return clone;
    }

    /**
     * Trains all class models with one structure derived from all data sets.
     *
     * @param data the data sets; entry {@code i} is used for class model {@code i}
     * @param weights the weights; row {@code i} is passed on unchanged together with {@code data[i]}
     * @throws IllegalArgumentException if {@code data} or {@code weights} is {@code null}
     *         or has fewer entries than there are class models
     * @throws Exception if tensor computation, structure derivation or model training fails
     */
    @Override
    public void train(DataSet[] data, double[][] weights) throws IllegalArgumentException, Exception {
        int dimension = this.models.length;
        // Validate before touching any state: the loop below indexes both arrays
        // for every class model and mutates the shared structure learner.
        if (data == null || data.length < dimension) {
            throw new IllegalArgumentException("Expected one data set for each of the " + dimension + " classes, but got " + (data == null ? "null" : String.valueOf(data.length)) + ".");
        }
        if (weights == null || weights.length < dimension) {
            throw new IllegalArgumentException("Expected one weight array for each of the " + dimension + " classes, but got " + (weights == null ? "null" : String.valueOf(weights.length)) + ".");
        }
        SymmetricTensor[] parts = new SymmetricTensor[dimension];
        double[] w = new double[dimension];
        for (int i = 0; i < dimension; ++i) {
            // The learner is shared across classes, so the ESS of the current
            // class model has to be set before its tensor is computed.
            this.sl.setESS(((FSDAGTrainSM)this.models[i]).getESS());
            parts[i] = this.sl.getTensor(data[i], weights[i], this.order, this.method);
            // Every per-class tensor enters the combined tensor with the same factor.
            w[i] = 1.0;
        }
        FSDAGTrainSM.train(this.models, StructureLearner.getStructure(new SymmetricTensor(parts, w), this.model, this.order), weights, data);
    }

    /**
     * Returns the fixed instance name of this classifier.
     */
    @Override
    public String getInstanceName() {
        return "shared-structure classifier";
    }

    /**
     * Reads model type, order and learning type from the XML representation
     * (tags {@code model}, {@code order}, {@code method}) after the super class
     * has extracted its own information, and rebuilds the {@link StructureLearner}.
     */
    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(xml);
        this.model = XMLParser.extractObjectForTags(xml, "model", StructureLearner.ModelType.class);
        this.order = XMLParser.extractObjectForTags(xml, "order", Byte.TYPE);
        this.method = XMLParser.extractObjectForTags(xml, "method", StructureLearner.LearningType.class);
        // The learner is not stored in the XML, so it is recreated here.
        this.sl = new StructureLearner(this.getAlphabetContainer(), this.getLength());
    }

    /**
     * Appends model type, order and learning type (tags {@code model},
     * {@code order}, {@code method}) to the information written by the super class.
     */
    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(xml, (Object)this.model, "model");
        XMLParser.appendObjectWithTags(xml, this.order, "order");
        XMLParser.appendObjectWithTags(xml, (Object)this.method, "method");
        return xml;
    }

    /**
     * Returns one result naming the classifier followed by one result per class
     * model; the latter carries the model instance name built from model type,
     * order, learning type and the ESS of that class model.
     */
    @Override
    public CategoricalResult[] getClassifierAnnotation() {
        CategoricalResult[] res = new CategoricalResult[this.models.length + 1];
        res[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", this.getInstanceName());
        for (int i = 0; i < this.models.length; ++i) {
            // Slot 0 holds the classifier entry, so class i is stored at i + 1.
            res[i + 1] = new CategoricalResult("class info " + i, "some information about the class", BayesianNetworkTrainSMParameterSet.getModelInstanceName(this.model, this.order, this.method, ((FSDAGTrainSM)this.models[i]).getESS()));
        }
        return res;
    }
}

