package de.jstacs.classifier.scoringFunctionBased.logPrior;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/**
 * This class implements a composite prior that can be used for {@link NormalizableScoringFunction}s. The prior for
 * each {@link NormalizableScoringFunction} should be the (transformed) prior of the corresponding generative model;
 * e.g., for a PWM one should use a product of Dirichlets. The component-specific part of the prior is implemented in
 * the {@link NormalizableScoringFunction} itself. For the class variables the prior uses a (transformed) Dirichlet
 * with hyperparameters equal to the ESS of the classes.
 * 
 * <br>
 * <br>
 * 
 * If this class uses only the free parameters, it implements a proper prior that is normalized to 1. If it uses all
 * parameters, the function is in general not normalized to 1. Fortunately, this is no problem, since it can be shown
 * that it makes no difference in the optimization.
 * 
 * @author Jan Grau, Jens Keilwagen
 * 
 * @see NormalizableScoringFunction
 * @see NormalizableScoringFunction#addGradientOfLogPriorTerm(double[], int)
 * @see NormalizableScoringFunction#getLogPriorTerm()
 */
public class CompositeLogPrior extends LogPrior
{

	/** The component functions, one per class; configured by {@link #set(boolean, ScoringFunction...)}. */
	private NormalizableScoringFunction[] function;

	/**
	 * {@code fullEss}: the sum of the ESS of all functions; {@code logGammaSum}: the Dirichlet normalization term
	 * {@code log Gamma(fullEss) - sum_i log Gamma(ess[i])}.
	 */
	private double fullEss, logGammaSum;

	/**
	 * {@code ess}: the ESS of each function; {@code expClass}: scratch buffer holding {@code exp} of the class
	 * parameters while the gradient is computed.
	 */
	private double[] ess, expClass;

	/**
	 * If {@code true}, the class parameter of the last function is not part of the parameter vector and is treated
	 * as fixed to 0.
	 */
	private boolean freeParameters;

	/**
	 * The main constructor. The instance has to be configured with {@link #set(boolean, ScoringFunction...)} before
	 * it can be evaluated.
	 */
	public CompositeLogPrior(){}

	/**
	 * The constructor for the {@link de.jstacs.Storable} interface. The content of {@code xml} is ignored, since this
	 * class stores no state in its XML representation (see {@link #toXML()}).
	 * 
	 * @param xml the StringBuffer
	 */
	public CompositeLogPrior( StringBuffer xml )
	{
		this();
	}

	/**
	 * Sets the functions this prior is defined on and precomputes the Dirichlet normalization term.
	 * <p>
	 * All arguments are validated before any field is changed, so a failed call leaves this prior unchanged.
	 * 
	 * @param freeParameters if {@code true}, only the free parameters are used, i.e. the class parameter of the last
	 *            function is not part of the parameter vector
	 * @param funs the functions; each has to be a {@link NormalizableScoringFunction} with a positive ESS
	 * @throws Exception if one of the functions is not a {@link NormalizableScoringFunction}
	 * @throws IllegalArgumentException if no function is given or the ESS of a function is not positive
	 */
	public void set( boolean freeParameters, ScoringFunction... funs ) throws Exception
	{
		if( funs == null || funs.length == 0 )
		{
			throw new IllegalArgumentException( "At least one function has to be given." );
		}
		NormalizableScoringFunction[] newFunction = new NormalizableScoringFunction[funs.length];
		double[] newEss = new double[funs.length];
		double newFullEss = 0, newLogGammaSum = 0;
		for( int i = 0; i < funs.length; i++ )
		{
			if( !(funs[i] instanceof NormalizableScoringFunction) )
			{
				throw new Exception( "Only NormalizableScoringFunction allowed." );
			}
			newFunction[i] = (NormalizableScoringFunction) funs[i];
			newEss[i] = newFunction[i].getEss();
			// Dirichlet hyperparameters must be strictly positive; the negated comparison also rejects NaN
			if( !(newEss[i] > 0) )
			{
				throw new IllegalArgumentException( "The ess of the function " + i + " is " + newEss[i]
						+ ", but should be positive." );
			}
			newFullEss += newEss[i];
			newLogGammaSum -= Gamma.logOfGamma( newEss[i] );
		}
		newLogGammaSum += Gamma.logOfGamma( newFullEss );

		// commit the new state only after everything has been validated
		function = newFunction;
		ess = newEss;
		expClass = new double[funs.length];
		fullEss = newFullEss;
		logGammaSum = newLogGammaSum;
		this.freeParameters = freeParameters;
	}

	/**
	 * Adds the gradient of the log prior with respect to {@code params} to {@code grad}.
	 * <p>
	 * The leading entries of {@code params} are the class parameters (one per function, or one less if only free
	 * parameters are used). They are followed by the parameters of the functions, in the order given to
	 * {@link #set(boolean, ScoringFunction...)}.
	 * 
	 * @param params the current parameters
	 * @param grad the array the gradient is added to (same layout as {@code params})
	 * @throws EvaluationException if the gradient could not be computed; the original exception is kept as cause
	 */
	public void addGradientFor( double[] params, double[] grad ) throws EvaluationException
	{
		try
		{
			int numClassParams = function.length - (freeParameters ? 1 : 0);
			double fullNorm = 0;
			double[] norms = new double[function.length];
			for( int k = 0; k < numClassParams; k++ )
			{
				expClass[k] = Math.exp( params[k] );
				norms[k] = expClass[k] * function[k].getNormalizationConstant();
				fullNorm += norms[k];
			}
			if( freeParameters )
			{
				// the class parameter of the last function is implicitly 0, hence exp(0) = 1
				expClass[numClassParams] = 1d;
				norms[numClassParams] = function[numClassParams].getNormalizationConstant();
				fullNorm += norms[numClassParams];
			}

			// partial derivatives with respect to the class parameters
			for( int k = 0; k < numClassParams; k++ )
			{
				grad[k] += ess[k] - fullEss * norms[k] / fullNorm;
			}

			// partial derivatives with respect to the parameters of each function, stored after the class parameters
			int offset = numClassParams;
			for( int j = 0; j < function.length; j++ )
			{
				function[j].addGradientOfLogPriorTerm( grad, offset );
				int numParams = function[j].getNumberOfParameters();
				for( int l = 0; l < numParams; l++, offset++ )
				{
					grad[offset] -= (fullEss * expClass[j] * function[j].getPartialNormalizationConstant( l ) / fullNorm );
				}
			}
		}
		catch( Exception e )
		{
			throw toEvaluationException( e );
		}
	}

	/**
	 * Evaluates the log prior at {@code x}.
	 * <p>
	 * The value is {@code logGammaSum - fullEss * log(norm) + sum_i (x[i] * ess[i] + logPriorTerm_i)}, where
	 * {@code norm = sum_i exp(x[i]) * Z_i} and {@code Z_i} is the normalization constant of function {@code i}. If only
	 * free parameters are used, the class parameter of the last function is taken to be 0.
	 * 
	 * @param x the current parameters (class parameters first, see {@link #addGradientFor(double[], double[])})
	 * @return the value of the log prior
	 * @throws DimensionException declared for the interface; not thrown by this implementation
	 * @throws EvaluationException if the value could not be computed; the original exception is kept as cause
	 */
	public double evaluateFunction( double[] x ) throws DimensionException, EvaluationException
	{
		try
		{
			int numClassParams = function.length - (freeParameters ? 1 : 0);
			double norm = 0, logProductPart = 0;
			for( int i = 0; i < numClassParams; i++ )
			{
				norm += Math.exp( x[i] ) * function[i].getNormalizationConstant();
				logProductPart += x[i] * ess[i] + function[i].getLogPriorTerm();
			}
			if( freeParameters )
			{
				// implicit class parameter 0: exp(0) = 1 and no x * ess contribution
				norm += function[numClassParams].getNormalizationConstant();
				logProductPart += function[numClassParams].getLogPriorTerm();
			}
			return logGammaSum - fullEss * Math.log( norm ) + logProductPart;
		}
		catch( Exception e )
		{
			// previously used e.getCause().getMessage(), which threw a NullPointerException whenever e had no cause
			throw toEvaluationException( e );
		}
	}

	/**
	 * Wraps {@code e} in an {@link EvaluationException} that carries {@code e} as its cause.
	 * 
	 * @param e the exception raised during evaluation
	 * @return the wrapping exception
	 */
	private static EvaluationException toEvaluationException( Exception e )
	{
		EvaluationException eva = new EvaluationException( e.getMessage() );
		eva.initCause( e );
		return eva;
	}

	/**
	 * Returns the number of parameters this prior is defined on: the class parameters plus the parameters of all
	 * functions.
	 * 
	 * @return the dimension, or {@code UNKNOWN} if the number of parameters of any function is {@code UNKNOWN}
	 */
	public int getDimensionOfScope()
	{
		int all = function.length - (freeParameters ? 1 : 0);
		for( int i = 0; i < function.length; i++ )
		{
			int current = function[i].getNumberOfParameters();
			if( current == UNKNOWN )
			{
				return UNKNOWN;
			}
			all += current;
		}
		return all;
	}

	/**
	 * Returns a new, unconfigured instance of this prior.
	 * 
	 * @return a new {@link CompositeLogPrior}; {@link #set(boolean, ScoringFunction...)} has to be called on it
	 * @throws CloneNotSupportedException declared for the interface; not thrown by this implementation
	 */
	public CompositeLogPrior getNewInstance() throws CloneNotSupportedException
	{
		return new CompositeLogPrior();
	}

	/**
	 * Returns the XML representation of this prior, which is empty since no state needs to be stored.
	 * 
	 * @return an empty {@link StringBuffer}
	 */
	public StringBuffer toXML()
	{
		return new StringBuffer( 1 );
	}

	/**
	 * Returns a human-readable name of this prior.
	 * 
	 * @return the name
	 */
	public String getInstanceName()
	{
		return "Composite log prior";
	}
}
