/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.neuralNetworks.neurons;

import de.jstacs.classifiers.neuralNetworks.activationFunctions.ActivationFunction;
import de.jstacs.classifiers.neuralNetworks.neurons.Neuron;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import java.util.Arrays;
import java.util.Random;

/**
 * A neuron whose output is computed from the outputs of other neurons.
 *
 * <p>The neuron stores one weight per predecessor plus one additional weight
 * in the last array position. The output for a sequence is the activation
 * function applied to the weighted sum of the predecessor outputs minus that
 * last weight.
 *
 * <p>The weighted sum, the output and the error are cached. The cache is keyed
 * on the <em>identity</em> ({@code ==}) of the most recently evaluated
 * {@link Sequence}; it is cleared by {@link #reset()} or when
 * {@link #getOutput(Sequence)} is called with a different sequence instance.
 * NOTE(review): {@link #adaptWeights(double)} does not clear the cache, so
 * callers that re-evaluate the same sequence instance after a weight update
 * presumably have to call {@link #reset()} themselves - confirm against the
 * calling network code.
 *
 * <p>The predecessor array is not part of the XML representation written by
 * {@link #getFurtherInformation()}.
 */
public class InnerNeuron extends Neuron {
    /** Shared random number generator used by {@link #initializeRandomly()}. */
    private static final Random RANDOM = new Random();

    /** The neurons whose outputs are the inputs of this neuron. */
    protected Neuron[] predecessors;
    /** One weight per predecessor, followed by the weight that is subtracted from the weighted sum. */
    protected double[] weights;
    /** The index that descendants use to look up their weight for this neuron. */
    protected int index;
    /** The function mapping the weighted sum to the output of this neuron. */
    protected ActivationFunction activationFunction;
    /** Cached weighted sum (before activation), or {@code null} if not computed. */
    protected Double h = null;
    /** Cached output, or {@code null} if not computed. */
    protected Double output = null;
    /** Cached error, or {@code null} if not computed. */
    protected Double error = null;
    /** Accumulated weight changes; same layout as {@link #weights}. */
    protected double[] delta;
    /** The sequence instance the cached values belong to. */
    protected Sequence last;

    /**
     * Creates a neuron with the given predecessors and registers it as a
     * descendant of each of them. All weights are initially zero.
     *
     * @param activationFunction the activation function of this neuron
     * @param index the index passed to
     *        {@link #getWeightForPredecessor(int)} of the descendants when the
     *        error of this neuron is computed
     * @param predecessors the neurons providing the inputs of this neuron; the
     *        array is copied
     */
    public InnerNeuron(ActivationFunction activationFunction, int index, Neuron ... predecessors) {
        this.attachPredecessors(predecessors);
        this.weights = new double[predecessors.length + 1];
        this.delta = new double[this.weights.length];
        this.activationFunction = activationFunction;
        this.index = index;
    }

    /**
     * Sets the predecessors of this neuron and registers it as a descendant of
     * each of them. This is only allowed while no predecessors are set.
     *
     * @param predecessors the neurons providing the inputs of this neuron; the
     *        array is copied
     * @throws IllegalArgumentException if predecessors have already been set
     */
    public final void setPredecessors(Neuron ... predecessors) {
        if (this.predecessors != null) {
            throw new IllegalArgumentException("Can set predecessors only initially");
        }
        this.attachPredecessors(predecessors);
    }

    /** Stores a copy of the given array and registers this neuron as descendant of every element. */
    private void attachPredecessors(Neuron[] predecessors) {
        this.predecessors = predecessors.clone();
        for (Neuron predecessor : this.predecessors) {
            predecessor.addDescendant(this);
        }
    }

    /**
     * Returns the number of weights of this neuron.
     *
     * @return the length of the weight array
     */
    @Override
    public int getNumberOfWeights() {
        return this.weights.length;
    }

    /**
     * Returns the output of this neuron for the given sequence.
     *
     * <p>If the sequence is a different instance than the one evaluated last,
     * the cached values are discarded first. The result is cached until the
     * next reset.
     *
     * @param input the sequence to evaluate
     * @return the activation function applied to the weighted sum of the
     *         predecessor outputs minus the last weight
     */
    @Override
    public double getOutput(Sequence input) {
        // The cache is keyed on object identity, not on equals().
        if (this.last != null && input != this.last) {
            this.reset();
        }
        if (this.output == null) {
            double sum = 0.0;
            for (int i = 0; i < this.predecessors.length; i++) {
                sum += this.weights[i] * this.predecessors[i].getOutput(input);
            }
            // The last weight has no predecessor; it is subtracted from the sum.
            sum -= this.weights[this.weights.length - 1];
            this.h = sum;
            this.output = this.activationFunction.getValue(sum);
            this.last = input;
        }
        return this.output;
    }

    /**
     * Discards the cached weighted sum, output and error. The remembered
     * sequence and the accumulated weight changes are kept.
     */
    @Override
    public void reset() {
        this.output = null;
        this.error = null;
        this.h = null;
    }

    /** Sets every weight to an independent draw from a standard normal distribution. */
    @Override
    public void initializeRandomly() {
        for (int i = 0; i < this.weights.length; i++) {
            this.weights[i] = RANDOM.nextGaussian();
        }
    }

    /**
     * Returns the error of this neuron for the given sequence.
     *
     * <p>The error is the sum over all descendants of the descendant's error
     * times the descendant's weight for this neuron, multiplied by the
     * derivation of the activation function at the cached weighted sum. When
     * the error is computed (i.e. not served from the cache) the weight
     * changes are accumulated in {@link #delta}.
     *
     * @param input the sequence; must be the same instance that was passed to
     *        the most recent {@link #getOutput(Sequence)} call
     * @param weight passed on unchanged to the descendants
     * @param desiredOutputs passed on unchanged to the descendants
     * @return the error of this neuron
     * @throws RuntimeException if {@code input} is not the sequence instance
     *         evaluated last
     */
    @Override
    public double getError(Sequence input, double weight, double[] desiredOutputs) {
        if (this.last != input) {
            throw new RuntimeException(
                    "getError must be called with the sequence instance most recently passed to getOutput");
        }
        if (this.error == null) {
            double descendantSum = 0.0;
            for (int i = 0; i < this.getNumberOfDescendants(); i++) {
                descendantSum += this.getDescendant(i).getError(input, weight, desiredOutputs)
                        * this.getDescendant(i).getWeightForPredecessor(this.index);
            }
            // Keep a primitive copy so the boxed field is not unboxed on every iteration below.
            double err = descendantSum * this.activationFunction.getDerivation(this.h);
            this.error = err;
            for (int i = 0; i < this.predecessors.length; i++) {
                this.delta[i] += err * this.predecessors[i].getOutput(input);
            }
            // The last weight enters the weighted sum with a negative sign.
            this.delta[this.delta.length - 1] -= err;
        }
        return this.error;
    }

    /**
     * Adds the accumulated weight changes, scaled by {@code eta}, to the
     * weights and clears the accumulated changes afterwards.
     *
     * @param eta the factor the accumulated changes are multiplied with
     */
    @Override
    public void adaptWeights(double eta) {
        for (int i = 0; i < this.weights.length; i++) {
            this.weights[i] += eta * this.delta[i];
        }
        Arrays.fill(this.delta, 0.0);
    }

    /**
     * Returns the weight stored at the given position.
     *
     * @param index the position in the weight array
     * @return the weight at that position
     */
    public double getWeightForPredecessor(int index) {
        return this.weights[index];
    }

    /**
     * Restores the activation function, the accumulated weight changes, the
     * index and the weights from the given XML representation. The
     * predecessors are not restored.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML could not be parsed
     */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        this.activationFunction = XMLParser.extractObjectForTags(xml, "activationFunction", ActivationFunction.class);
        this.delta = XMLParser.extractObjectForTags(xml, "delta", double[].class);
        this.index = XMLParser.extractObjectForTags(xml, "index", Integer.TYPE);
        this.weights = XMLParser.extractObjectForTags(xml, "weights", double[].class);
    }

    /**
     * Writes the activation function, the accumulated weight changes, the
     * index and the weights to a new XML representation.
     *
     * @return the XML representation
     */
    @Override
    protected StringBuffer getFurtherInformation() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.activationFunction, "activationFunction");
        XMLParser.appendObjectWithTags(xml, this.delta, "delta");
        XMLParser.appendObjectWithTags(xml, this.index, "index");
        XMLParser.appendObjectWithTags(xml, this.weights, "weights");
        return xml;
    }

    /**
     * Returns the weights of this neuron as a string.
     *
     * @return the weight array formatted by {@link Arrays#toString(double[])}
     */
    @Override
    public String toString() {
        return Arrays.toString(this.weights);
    }
}

