/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.classifier.scoringFunctionBased.cll;

import java.util.Arrays;

import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifier.scoringFunctionBased.logPrior.LogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;

/**
 * This class implements the (optionally normalized) log conditional likelihood of a set of
 * {@link ScoringFunction}s, one per class, plus the value of a {@link LogPrior}. It can be used to
 * maximize the parameters of a classifier.
 * 
 * <p>
 * The parameter vector starts with the logarithmic class weights, followed by the parameters of each
 * {@link ScoringFunction} in class order. If only free parameters are used, the last class has no class
 * parameter and its logarithmic class weight is fixed to 0.
 * </p>
 * 
 * <p>
 * The offsets of the parameter blocks of the {@link ScoringFunction}s are computed in
 * {@link #reset(ScoringFunction[])}. Until that method has been called, {@link #getDimensionOfScope()}
 * returns 0, so it has to be called before the function is evaluated.
 * </p>
 * 
 * @author Jens Keilwagen
 */
public class NormConditionalLogLikelihood extends OptimizableFunction
{
	/** The scoring function of each class. */
	private final ScoringFunction[] score;

	/**
	 * Offsets into the parameter vector. {@code shortcut[0]} is the number of class parameters,
	 * {@code shortcut[i]} is the index of the first parameter of {@code score[i]}, and {@code shortcut[cl]}
	 * is the total number of parameters. The entries {@code 1..cl} are set in
	 * {@link #reset(ScoringFunction[])}.
	 */
	private final int[] shortcut;

	/** The data, one {@link Sample} per class. */
	private final Sample[] data;

	/** {@code weights[i][j]} is the weight of the j-th sequence of class i. */
	private final double[][] weights;

	/** Scratch array of length {@code cl} that holds the per-class log scores and, later, the class posteriors. */
	private final double[] helpArray;

	/** The current logarithmic class weights. */
	private final double[] logClazz;

	/** {@code sum[i]} is the total weight of class i; {@code sum[cl]} is the total weight of all classes. */
	private final double[] sum;

	/** Scratch lists for the values of the partial derivations, one per class. */
	private final DoubleList[] dList;

	/** Scratch lists for the indices of the partial derivations, one per class. */
	private final IntList[] iList;

	/** The number of classes. */
	private final int cl;

	/** The prior; never {@code null}. */
	private final LogPrior prior;

	/** Whether the value and gradient are divided by the total weight ({@code norm}) and whether only free parameters are used. */
	private final boolean norm, freeParams;

	/**
	 * The constructor creates an instance of the log conditional likelihood without a prior.
	 * 
	 * @param score the ScoringFunctions, one per class
	 * @param data the data, one sample per class
	 * @param weights the weights of the sequences, one array per class matching the size of the corresponding
	 *            sample; if {@code null}, every sequence gets the weight 1
	 * @param norm the switch for using the normalization (division by the total weight of all sequences)
	 * @param freeParams the switch for using only the free parameters
	 * 
	 * @throws IllegalArgumentException if the arguments are inconsistent (see
	 *             {@link #NormConditionalLogLikelihood(ScoringFunction[], Sample[], double[][], LogPrior, boolean, boolean)})
	 * @throws WrongAlphabetException declared for compatibility; not thrown by the current implementation
	 */
	public NormConditionalLogLikelihood( ScoringFunction[] score, Sample[] data, double[][] weights, boolean norm,
			boolean freeParams ) throws IllegalArgumentException, WrongAlphabetException
	{
		this( score, data, weights, null, norm, freeParams );
	}

	/**
	 * The constructor creates an instance using the given prior.
	 * 
	 * @param score the ScoringFunctions, one per class
	 * @param data the data, one sample per class
	 * @param weights the weights of the sequences, one array per class matching the size of the corresponding
	 *            sample; if {@code null}, every sequence gets the weight 1
	 * @param prior the prior; if {@code null}, {@link DoesNothingLogPrior#defaultInstance} is used
	 * @param norm the switch for using the normalization (division by the total weight of all sequences)
	 * @param freeParams the switch for using only the free parameters
	 * 
	 * @throws IllegalArgumentException if {@code score} or {@code data} is {@code null}, if there are fewer than
	 *             two classes, if {@code score} and {@code data} differ in length, or if the weights do not match
	 *             the data
	 * @throws WrongAlphabetException declared for compatibility; not thrown by the current implementation
	 */
	public NormConditionalLogLikelihood( ScoringFunction[] score, Sample[] data, double[][] weights, LogPrior prior,
			boolean norm, boolean freeParams ) throws IllegalArgumentException, WrongAlphabetException
	{
		if( score == null || data == null )
		{
			throw new IllegalArgumentException( "The scoring functions and the data must not be null." );
		}
		cl = score.length;
		if( cl < 2 || cl != data.length )
		{
			throw new IllegalArgumentException(
					"The number of classes is not correct. Check the length of the array of scoring functions as well as the length of the data array." );
		}
		this.score = score;
		this.data = data;
		this.weights = checkWeights( data, weights );
		this.prior = (prior == null) ? DoesNothingLogPrior.defaultInstance : prior;
		this.norm = norm;
		this.freeParams = freeParams;

		// with free parameters the last class weight is fixed and therefore has no parameter;
		// shortcut[1..cl] depend on the scoring functions and are set in reset(ScoringFunction[])
		shortcut = new int[cl + 1];
		shortcut[0] = freeParams ? cl - 1 : cl;

		helpArray = new double[cl];
		logClazz = new double[cl];
		dList = new DoubleList[cl];
		iList = new IntList[cl];
		sum = new double[cl + 1];
		for( int i = 0; i < cl; i++ )
		{
			dList[i] = new DoubleList();
			iList[i] = new IntList();
			for( double w : this.weights[i] )
			{
				sum[i] += w;
			}
			sum[cl] += sum[i];
		}
	}

	/**
	 * Checks that the weights match the data and returns the weights to be used.
	 * 
	 * @param data the data, one sample per class
	 * @param weights the weights, or {@code null} for unit weights
	 * 
	 * @return {@code weights} itself, or newly created unit weights if {@code weights} is {@code null}
	 * 
	 * @throws IllegalArgumentException if a sample is {@code null}, or if the weights do not have one array per
	 *             class whose length equals the number of sequences of that class
	 */
	private static double[][] checkWeights( Sample[] data, double[][] weights ) throws IllegalArgumentException
	{
		for( int i = 0; i < data.length; i++ )
		{
			if( data[i] == null )
			{
				throw new IllegalArgumentException( "The data of class " + i + " must not be null." );
			}
		}
		if( weights == null )
		{
			double[][] unit = new double[data.length][];
			for( int i = 0; i < data.length; i++ )
			{
				unit[i] = new double[data[i].getNumberOfElements()];
				Arrays.fill( unit[i], 1d );
			}
			return unit;
		}
		if( weights.length != data.length )
		{
			throw new IllegalArgumentException( "The number of weight arrays (" + weights.length
					+ ") differs from the number of classes (" + data.length + ")." );
		}
		for( int i = 0; i < data.length; i++ )
		{
			int n = data[i].getNumberOfElements();
			if( weights[i] == null || weights[i].length != n )
			{
				throw new IllegalArgumentException( "The weights of class " + i + " do not match its " + n
						+ " sequences." );
			}
		}
		return weights;
	}

	/**
	 * Computes the gradient of the (optionally normalized) log conditional likelihood plus the prior.
	 * 
	 * <p>
	 * For a sequence s of class c with weight w, the contribution to the derivative with respect to the log
	 * class weight of class k is {@code w * (delta(k,c) - P(k|s))}. The contribution to a parameter of
	 * {@code score[k]} is the same factor times the partial derivation of the log score of class k.
	 * </p>
	 * 
	 * @param x the parameters
	 * 
	 * @return the gradient, of length {@link #getDimensionOfScope()}
	 * 
	 * @throws DimensionException if {@code x} is {@code null} or has the wrong length
	 * @throws EvaluationException if the gradient could not be evaluated
	 */
	public double[] evaluateGradientOfFunction( double[] x ) throws DimensionException, EvaluationException
	{
		setParams( x );
		double[] grad = new double[shortcut[cl]];
		for( int c = 0; c < cl; c++ )
		{
			int n = data[c].getNumberOfElements();
			for( int j = 0; j < n; j++ )
			{
				Sequence s = data[c].getElementAt( j );
				double weight = weights[c][j];

				// log class weight + log score of every class, collecting the partial derivations of the log scores
				for( int k = 0; k < cl; k++ )
				{
					iList[k].clear();
					dList[k].clear();
					helpArray[k] = logClazz[k]
							+ score[k].getLogScoreAndPartialDerivation( s, 0, iList[k], dList[k] );
				}
				// normalize in log space, so that helpArray[k] holds the class posterior P(k|s)
				Normalisation.logSumNormalisation( helpArray, 0, helpArray.length, helpArray, 0 );

				for( int k = 0; k < cl; k++ )
				{
					// derivative of log P(c|s) with respect to the log score (or log class weight) of class k
					double factor = (k == c) ? 1d - helpArray[k] : -helpArray[k];
					if( k < shortcut[0] )
					{
						grad[k] += weight * factor;
					}
					int offset = shortcut[k];
					for( int m = 0; m < iList[k].length(); m++ )
					{
						grad[offset + iList[k].get( m )] += weight * dList[k].get( m ) * factor;
					}
				}
			}
		}

		// prior
		prior.addGradientFor( x, grad );

		// normalization
		if( norm )
		{
			for( int i = 0; i < grad.length; i++ )
			{
				grad[i] /= sum[cl];
			}
		}

		return grad;
	}

	/**
	 * Computes the (optionally normalized) weighted log conditional likelihood plus the prior.
	 * 
	 * @param x the parameters
	 * 
	 * @return the value of the function
	 * 
	 * @throws DimensionException if {@code x} is {@code null} or has the wrong length
	 * @throws EvaluationException if the value is {@code NaN}
	 */
	public double evaluateFunction( double[] x ) throws DimensionException, EvaluationException
	{
		setParams( x );

		double cll = 0;
		for( int c = 0; c < cl; c++ )
		{
			int n = data[c].getNumberOfElements();
			for( int j = 0; j < n; j++ )
			{
				Sequence s = data[c].getElementAt( j );
				for( int k = 0; k < cl; k++ )
				{
					// class weight + class score
					helpArray[k] = logClazz[k] + score[k].getLogScore( s, 0 );
				}
				// log P(c|s) = log joint score of the true class - log of the summed joint scores
				cll += weights[c][j] * (helpArray[c] - Normalisation.getLogSum( helpArray ));
			}
		}

		double pr = prior.evaluateFunction( x );

		if( Double.isNaN( cll + pr ) )
		{
			System.out.println( "params " + Arrays.toString( x ) );
			System.out.flush();
			throw new EvaluationException( "Evaluating the function gives: " + cll + " + " + pr );
		}
		// optional normalization by the total weight of all sequences
		return norm ? (cll + pr) / sum[cl] : cll + pr;
	}

	/**
	 * Returns the total number of parameters.
	 * 
	 * @return the number of parameters; 0 until {@link #reset(ScoringFunction[])} has been called
	 */
	public int getDimensionOfScope()
	{
		return shortcut[cl];
	}

	/**
	 * This method enables the user to get the start parameters without creating a new array.
	 * 
	 * <p>
	 * If {@code plugIn} is {@code true}, the class parameters are set to the logarithmic relative class weights
	 * and the parameters of the scoring functions are copied from their current values. Otherwise {@code erg}
	 * is left unchanged.
	 * </p>
	 * 
	 * @param plugIn a switch to decide whether to use plug-in parameters or not
	 * @param erg the array for the start parameters
	 * 
	 * @throws Exception if the array is null or does not have the correct length
	 * 
	 * @see NormConditionalLogLikelihood#getStartParams(boolean)
	 */
	public void getStartParams( boolean plugIn, double[] erg ) throws Exception
	{
		if( erg == null || erg.length != getDimensionOfScope() )
		{
			throw new Exception( "Null argument or length do not match." );
		}
		if( plugIn )
		{
			// reference weight: the last class (whose class weight is fixed) for free parameters, all classes otherwise
			double logRef = Math.log( freeParams ? sum[cl - 1] : sum[cl] );
			for( int i = 0; i < cl; i++ )
			{
				if( i < shortcut[0] )
				{
					erg[i] = Math.log( sum[i] ) - logRef;
				}
				System.arraycopy( score[i].getCurrentParameterValues(), 0, erg, shortcut[i],
						score[i].getNumberOfParameters() );
			}
		}
	}

	/**
	 * Returns the start parameters in a new array.
	 * 
	 * @param plugIn a switch to decide whether to use plug-in parameters or not; if {@code false}, the
	 *            returned array contains only zeros
	 * 
	 * @return the start parameters
	 * 
	 * @throws Exception if the start parameters could not be determined
	 * 
	 * @see NormConditionalLogLikelihood#getStartParams(boolean, double[])
	 */
	public double[] getStartParams( boolean plugIn ) throws Exception
	{
		double[] temp = new double[getDimensionOfScope()];
		getStartParams( plugIn, temp );
		return temp;
	}

	/**
	 * Sets the logarithmic class weights and the parameters of the scoring functions.
	 * 
	 * @param params the parameters, of length {@link #getDimensionOfScope()}
	 * 
	 * @throws DimensionException if {@code params} is {@code null} or has the wrong length
	 */
	public void setParams( double[] params ) throws DimensionException
	{
		if( params == null )
		{
			throw new DimensionException( 0, shortcut[cl] );
		}
		if( params.length != shortcut[cl] )
		{
			throw new DimensionException( params.length, shortcut[cl] );
		}
		for( int i = 0; i < cl; i++ )
		{
			// classes without a class parameter (only the last one, with free parameters) have log weight 0
			logClazz[i] = (i < shortcut[0]) ? params[i] : 0;
			score[i].setParameters( params, shortcut[i] );
		}
	}

	/**
	 * Extracts the logarithmic class weights from a parameter vector.
	 * 
	 * @param params the parameters
	 * 
	 * @return a new array of length {@code cl} with the logarithmic class weights; with free parameters the
	 *         last entry is 0
	 */
	public double[] getClassParams( double[] params )
	{
		double[] res = new double[cl];
		System.arraycopy( params, 0, res, 0, shortcut[0] );
		if( freeParams )
		{
			res[shortcut[0]] = 0;
		}
		return res;
	}

	/**
	 * Returns the maximum over the numbers of recommended starts of all scoring functions.
	 * 
	 * @return the number of starts
	 */
	@Override
	public int getNumberOfStarts()
	{
		int starts = score[0].getNumberOfRecommendedStarts();
		for( int i = 1; i < score.length; i++ )
		{
			starts = Math.max( starts, score[i].getNumberOfRecommendedStarts() );
		}
		return starts;
	}

	/**
	 * Replaces the scoring functions and recomputes the offsets of their parameters. It also passes the new
	 * scoring functions to the prior. This method has to be called before the function is evaluated.
	 * 
	 * @param funs the new scoring functions, one per class
	 * 
	 * @throws IllegalArgumentException if the number of scoring functions does not match the number of classes
	 * @throws Exception if the prior could not be set
	 */
	public void reset( ScoringFunction[] funs ) throws Exception
	{
		if( funs.length != cl )
		{
			throw new IllegalArgumentException( "Could not reset." );
		}
		for( int i = 0; i < cl; i++ )
		{
			score[i] = funs[i];
			shortcut[i + 1] = shortcut[i] + score[i].getNumberOfParameters();
		}
		// prior is never null: the constructor substitutes DoesNothingLogPrior.defaultInstance
		prior.set( freeParams, score );
	}
}
