/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.scoringFunctions.mix;

import java.util.Arrays;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.scoringFunctions.AbstractNormalizableScoringFunction;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/**
 * This is the main abstract class for any mixture (e.g. &quot;real&quot; mixture, strand mixture, hidden motif, ...).
 * The potential for the hidden variables is parameterized depending on the parameterization of the given NormalizableScoringFunctions.
 * If these are already normalized (see {@link de.jstacs.scoringFunctions.NormalizableScoringFunction#isNormalized()}) the
 * potential is parameterized using the Meila-parameterization, otherwise it is parameterized using the unnormalized MRF-parameterization.
 * 
 * @author Jens Keilwagen
 */
public abstract class AbstractMixtureScoringFunction extends AbstractNormalizableScoringFunction
{
	/**
	 * Determines the position of the largest entry of an array. If the maximal value occurs
	 * several times, the first position is reported; for an empty array <code>0</code> is returned.
	 * 
	 * @param w
	 *            the array to be scanned
	 * 
	 * @return the index of the (first) maximal entry
	 */
	protected static final int getMaxIndex( double[] w )
	{
		int best = 0;
		int pos = 1;
		while( pos < w.length )
		{
			// the strict comparison keeps the earliest index in case of ties
			if( w[pos] > w[best] )
			{
				best = pos;
			}
			pos++;
		}
		return best;
	}
	
	/**
	 * The number of recommended starts of an optimization.
	 * 
	 * @see AbstractMixtureScoringFunction#getNumberOfRecommendedStarts()
	 */
	private int starts;

	/**
	 * This array contains the references/indices for the parameters. Only the start index for each function is
	 * stored. The entry at index <code>function.length</code> is the start index of the hidden parameters and
	 * the last entry is the total number of parameters. The array is <code>null</code> as long as the
	 * number of parameters of at least one function is unknown.
	 */
	protected int[] paramRef;

	/**
	 * This boolean indicates whether to optimize the hidden variables of this instance. (It is not used recursively.)
	 */
	protected boolean optimizeHidden;
	
	/**
	 * This boolean indicates whether the free parameterization (<code>true</code>) or all parameters
	 * (<code>false</code>) are used.
	 */
	private boolean freeParams;
	
	/**
	 * This boolean indicates whether to use a plug-in strategy to initialize the instance.
	 */
	protected boolean plugIn;

	/**
	 * This array contains the internal functions that are used to determine the score.
	 */
	protected NormalizableScoringFunction[] function;

	/**
	 * This array contains the hidden parameters of the instance.
	 */
	protected double[] hiddenParameter;
	
	/**
	 * This array contains the logarithm of the hidden potentials of the instance.
	 */
	protected double[] logHiddenPotential;
	
	/**
	 * This array contains the hidden potentials of the instance.
	 */
	protected double[] hiddenPotential;
	
	/**
	 * This array is used while computing the score. It stores the scores of the components and is used to avoid
	 * creating a new array every time.
	 */
	protected double[] componentScore;
	
	/**
	 * This array contains the partial normalization constants, i.e. the normalization constant for each component.
	 */
	protected double[] partNorm;

	/**
	 * This <code>double</code> contains the normalization constant of the instance. A negative value
	 * indicates that the constant has to be (re)computed.
	 * 
	 * @see AbstractMixtureScoringFunction#getNormalizationConstant()
	 */
	protected double norm;
	
	/**
	 * This <code>double</code> contains the logarithm of the normalization constant of the hidden parameters
	 * of the instance.
	 */
	protected double logHiddenNorm;
	
	/**
	 * This <code>double</code> contains the sum of the logarithms of the gamma functions used in the prior.
	 * 
	 * @see AbstractMixtureScoringFunction#computeLogGammaSum()
	 */
	protected double logGammaSum;

	/**
	 * This array contains some {@link DoubleList}s that are used while computing the partial derivation.
	 */
	protected DoubleList[] dList;

	/**
	 * This array contains some {@link IntList}s that are used while computing the partial derivation.
	 */
	protected IntList[] iList;
	
	/**
	 * This boolean indicates whether this instance is a normalized one or not.
	 */
	protected boolean isNormalized;

	/**
	 * This constructor creates an AbstractMixtureScoringFunction.
	 * 
	 * @param length
	 *            the sequence length that should be modeled
	 * @param starts
	 *            the number of starts that should be done in an optimization; has to be at least 1
	 * @param dimension
	 *            the number of different mixture components; has to be at least 1
	 * @param optimizeHidden
	 *            whether the parameters for the hidden variables should be optimized; this is only
	 *            respected if there is more than one component
	 * @param plugIn
	 *            whether the initial parameters for an optimization should be related to the data or randomly drawn
	 * @param function
	 *            the ScoringFunctions; the array is cloned and the alphabet container of the first entry is
	 *            used for this instance
	 * 
	 * @throws CloneNotSupportedException
	 *             if the given functions could not be cloned
	 * @throws IllegalArgumentException
	 *             if <code>starts</code> or <code>dimension</code> is not positive
	 */
	public AbstractMixtureScoringFunction( int length, int starts, int dimension, boolean optimizeHidden,
			boolean plugIn, NormalizableScoringFunction... function ) throws CloneNotSupportedException
	{
		super( function[0].getAlphabetContainer(), length );
		// validate the cheap arguments before doing any work
		if( starts < 1 )
		{
			throw new IllegalArgumentException( "The number of recommended starts has to be positive." );
		}
		// negative values have to be rejected as well, otherwise the array allocations below
		// would fail with a NegativeArraySizeException
		if( dimension < 1 )
		{
			throw new IllegalArgumentException( "The number of components has to be positive." );
		}
		this.starts = starts;
		this.function = ArrayHandler.clone( function );
		
		isNormalized = isNormalized( function );
		hiddenParameter = new double[dimension];
		logHiddenPotential = new double[dimension];
		hiddenPotential = new double[dimension];
		partNorm = new double[dimension];
		// all hidden parameters are 0 at this point, i.e. all hidden potentials are equal
		setHiddenParameters( hiddenParameter, 0 );
		
		componentScore = new double[dimension];
		// a single component does not have any hidden parameter that could be optimized
		this.optimizeHidden = optimizeHidden && dimension > 1;
		this.plugIn = plugIn;
		paramRef = null;
		// freeParams is not assigned in this constructor, so its current value is passed on
		init( freeParams );
	}

	/**
	 * Precomputes {@link AbstractMixtureScoringFunction#logGammaSum}, the part of the prior that consists
	 * of logarithms of gamma functions: the log-gamma of the sum of all hyperparameters of the hidden
	 * parameters minus the sum of the log-gammas of the individual hyperparameters. The value is
	 * <code>0</code> if there is only one component or if the ESS is not positive.
	 */
	protected void computeLogGammaSum()
	{
		logGammaSum = 0;
		final int components = getNumberOfComponents();
		if( components <= 1 || !( getEss() > 0 ) )
		{
			// nothing to add in these cases
			return;
		}
		double hyperSum = 0;
		for( int k = 0; k < components; k++ )
		{
			final double hyper = getHyperparameterForHiddenParameter( k );
			hyperSum += hyper;
			logGammaSum -= Gamma.logOfGamma( hyper );
		}
		logGammaSum += Gamma.logOfGamma( hyperSum );
	}
	
	/**
	 * This is the constructor for {@link de.jstacs.Storable}. It passes the XML representation
	 * unchanged to the constructor of the super class.
	 * 
	 * @param xml the XML representation
	 * 
	 * @throws NonParsableException if the representation could not be parsed
	 */
	public AbstractMixtureScoringFunction( StringBuffer xml ) throws NonParsableException
	{
		super( xml );
	}

	/**
	 * Creates a deep copy of this instance. The internal functions are copied using
	 * {@link AbstractMixtureScoringFunction#cloneFunctions(NormalizableScoringFunction[])}, the arrays
	 * holding the state of the hidden variables are duplicated, and the index structure is rebuilt
	 * for the copy.
	 * 
	 * @return the copy
	 * 
	 * @throws CloneNotSupportedException if the instance could not be cloned
	 */
	public AbstractMixtureScoringFunction clone() throws CloneNotSupportedException
	{
		AbstractMixtureScoringFunction copy = (AbstractMixtureScoringFunction) super.clone();
		copy.cloneFunctions( function );
		// the copy must not share any array with this instance
		copy.partNorm = partNorm.clone();
		copy.componentScore = componentScore.clone();
		copy.hiddenPotential = hiddenPotential.clone();
		copy.logHiddenPotential = logHiddenPotential.clone();
		copy.hiddenParameter = hiddenParameter.clone();
		// resetting both references makes init(...) create a new paramRef array and new helper lists
		copy.paramRef = null;
		copy.iList = null;
		copy.init( freeParams );
		return copy;
	}
	
	/**
	 * This method clones the given array of functions and stores the result in
	 * {@link AbstractMixtureScoringFunction#function}. It is a separate method so that subclasses can do
	 * some postprocessing. Within this class it is only used in {@link AbstractMixtureScoringFunction#clone()}.
	 * 
	 * @param originalFunctions the array of functions to be cloned
	 * 
	 * @throws CloneNotSupportedException if the functions could not be cloned
	 */
	protected void cloneFunctions( NormalizableScoringFunction[] originalFunctions ) throws CloneNotSupportedException
	{
		function = ArrayHandler.clone( originalFunctions );
	}

	/**
	 * This method returns the hyperparameter for the hidden parameter with index <code>index</code>.
	 * The hyperparameters are used in the prior term, in its gradient and for the initialization of
	 * the hidden parameters.
	 * 
	 * @param index the index of the hidden parameter
	 * 
	 * @return the hyperparameter
	 */
	public abstract double getHyperparameterForHiddenParameter( int index );
	
	/**
	 * Computes the logarithm of the prior term of this instance. The term consists of the contribution
	 * of the hidden parameters (each parameter weighted by its hyperparameter and, for normalized
	 * instances, corrected by the logarithm of the hidden normalization constant), the prior terms of
	 * all internal functions, and the precomputed {@link AbstractMixtureScoringFunction#logGammaSum}.
	 * 
	 * @return the logarithm of the prior term
	 */
	public double getLogPriorTerm()
	{
		double logPrior = 0, hyperSum = 0;
		for( int k = 0; k < hiddenParameter.length; k++ )
		{
			final double hyper = getHyperparameterForHiddenParameter( k );
			hyperSum += hyper;
			logPrior += hiddenParameter[k] * hyper;
		}
		if( isNormalized )
		{
			logPrior -= hyperSum * logHiddenNorm;
		}
		for( NormalizableScoringFunction f : function )
		{
			logPrior += f.getLogPriorTerm();
		}
		return logPrior + logGammaSum;
	}

	/**
	 * Adds the gradient of the logarithm of the prior term to the given array. Each internal function
	 * adds its own part at the offset given by {@link AbstractMixtureScoringFunction#paramRef};
	 * afterwards the part for the hidden parameters is added.
	 * 
	 * @param grad the array the gradient is added to
	 * @param start the index in <code>grad</code> where the parameters of this instance begin
	 * 
	 * @throws Exception if an internal function could not add its gradient
	 */
	public void addGradientOfLogPriorTerm( double[] grad, int start ) throws Exception
	{
		final int n = function.length;
		for( int k = 0; k < n; k++ )
		{
			function[k].addGradientOfLogPriorTerm( grad, start + paramRef[k] );
		}
		// the hidden parameters occupy the positions [paramRef[n], paramRef[n+1]) relative to start
		final int end = start + paramRef[n + 1];
		int pos = start + paramRef[n];
		final double ess = getEss();
		for( int comp = 0; pos < end; comp++, pos++ )
		{
			double delta = getHyperparameterForHiddenParameter( comp );
			if( isNormalized )
			{
				delta -= ess * hiddenPotential[comp];
			}
			grad[pos] += delta;
		}
	}

	/**
	 * Returns the index of the component that has the greatest impact on the complete score, i.e. the
	 * index of the maximal entry of the component scores that are computed for the given sequence.
	 * 
	 * @param seq
	 *            the sequence
	 * @param start
	 *            the start position
	 * 
	 * @return the index of the component
	 */
	public int getIndexOfMaximalComponentFor( Sequence seq, int start )
	{
		// the scores are written to the shared array componentScore
		fillComponentScores( seq, start );
		return getMaxIndex( componentScore );
	}

	/**
	 * Returns the current values of all parameters in one array. The parameters of the internal
	 * functions come first (at the offsets stored in {@link AbstractMixtureScoringFunction#paramRef}),
	 * followed by the hidden parameters.
	 * 
	 * @return a new array containing the current parameter values
	 * 
	 * @throws Exception if the number of parameters is still unknown or an internal function could not
	 *             return its parameters
	 */
	public double[] getCurrentParameterValues() throws Exception
	{
		final int total = this.getNumberOfParameters();
		if( total == UNKNOWN )
		{
			throw new Exception( "No parameters exists, yet." );
		}
		final double[] all = new double[total];
		final int n = function.length;
		for( int k = 0; k < n; k++ )
		{
			final double[] own = function[k].getCurrentParameterValues();
			System.arraycopy( own, 0, all, paramRef[k], own.length );
		}
		// copy only as many hidden parameters as are reserved in paramRef
		final int hiddenStart = paramRef[n];
		System.arraycopy( hiddenParameter, 0, all, hiddenStart, paramRef[n + 1] - hiddenStart );
		return all;
	}

	/**
	 * Returns the logarithmic score of the sequence. The scores of the components are computed by
	 * {@link AbstractMixtureScoringFunction#fillComponentScores(Sequence, int)} and combined by
	 * {@link Normalisation#getLogSum(double[])}.
	 * 
	 * @param seq
	 *            the sequence
	 * @param start
	 *            the start position
	 * 
	 * @return the logarithmic score
	 */
	public double getLogScore( Sequence seq, int start )
	{
		fillComponentScores( seq, start );
		return Normalisation.getLogSum( componentScore );
	}

	/**
	 * Returns the normalization constant of this instance. The constant is computed lazily: a negative
	 * value of {@link AbstractMixtureScoringFunction#norm} marks it as outdated and triggers
	 * {@link AbstractMixtureScoringFunction#precomputeNorm()}.
	 * 
	 * @return the normalization constant
	 */
	public final double getNormalizationConstant()
	{
		if( norm < 0 )
		{
			precomputeNorm();
		}
		return norm;
	}

	/**
	 * Returns the number of different components, i.e. the length of the array that holds the
	 * component scores.
	 * 
	 * @return the number of different components
	 */
	public final int getNumberOfComponents()
	{
		return componentScore.length;
	}

	/**
	 * Returns the number of parameters of this instance, which is stored as last entry of
	 * {@link AbstractMixtureScoringFunction#paramRef}.
	 * 
	 * @return the number of parameters or <code>UNKNOWN</code> if the index structure does not exist, yet
	 */
	public final int getNumberOfParameters()
	{
		return paramRef == null ? UNKNOWN : paramRef[paramRef.length - 1];
	}

	/**
	 * Returns the number of recommended starts of an optimization as given at construction or read
	 * from the XML representation.
	 * 
	 * @return the number of recommended starts
	 */
	public final int getNumberOfRecommendedStarts()
	{
		return starts;
	}

	/**
	 * Returns the probabilities for each component. The component scores are computed for start
	 * position <code>0</code> and normalized using
	 * {@link Normalisation#logSumNormalisation(double[], int, int, double[], int)}.
	 * 
	 * @param seq
	 *            the sequence
	 * 
	 * @return a new array containing the probability of component i (=p(i|class,seq)) in entry i
	 */
	public double[] getProbsForComponent( Sequence seq )
	{
		fillComponentScores( seq, 0 );
		// the result is written to a new array, so the caller cannot modify the internal scores
		double[] p = new double[componentScore.length];
		Normalisation.logSumNormalisation( componentScore, 0, p.length, p, 0 );
		return p;
	}

	/**
	 * Returns a deep copy of all internally used ScoringFunctions.
	 * 
	 * @return a deep copy of all internally used ScoringFunctions
	 * 
	 * @throws CloneNotSupportedException if the functions could not be cloned
	 * 
	 * @see AbstractMixtureScoringFunction#getFunctions()
	 */
	public NormalizableScoringFunction[] getScoringFunctions() throws CloneNotSupportedException
	{
		return ArrayHandler.clone( function );
	}

	/**
	 * Returns the size of the event space of the random variables that belong to the parameter with
	 * the given index. For a hidden parameter this is the number of hidden parameters; for any other
	 * parameter the request is passed on to the internal function that owns the parameter.
	 * 
	 * @param index the index of the parameter
	 * 
	 * @return the size of the event space
	 */
	public int getSizeOfEventSpaceForRandomVariablesOfParameter( int index )
	{
		final int[] ref = getIndices( index );
		final int owner = ref[0];
		// an owner index beyond the last function denotes the block of hidden parameters
		return owner == function.length
				? hiddenParameter.length
				: function[owner].getSizeOfEventSpaceForRandomVariablesOfParameter( ref[1] );
	}

	/**
	 * Initializes this instance. Depending on {@link AbstractMixtureScoringFunction#plugIn} either the
	 * data based initialization or the random initialization is used.
	 * 
	 * @param index the class index
	 * @param freeParams whether the (reduced) free parameterization is used
	 * @param data the data
	 * @param weights the weights
	 * 
	 * @throws Exception if the initialization could not be done
	 */
	public void initializeFunction( int index, boolean freeParams, Sample[] data, double[][] weights ) throws Exception
	{
		if( !plugIn )
		{
			// the random initialization rebuilds the index structure itself
			initializeFunctionRandomly( freeParams );
			return;
		}
		initializeUsingPlugIn( index, freeParams, data, weights );
		init( freeParams );
	}
	
	/**
	 * This method initializes the function using the data in some way. It is invoked by
	 * {@link AbstractMixtureScoringFunction#initializeFunction(int, boolean, Sample[], double[][])} if the
	 * plug-in strategy is used.
	 * 
	 * @param index the class index 
	 * @param freeParams if <code>true</code>, the (reduced) parameterization is used
	 * @param data the data
	 * @param weights the weights
	 * 
	 * @throws Exception if the initialization could not be done
	 * 
	 * @see de.jstacs.scoringFunctions.ScoringFunction#initializeFunction(int, boolean, Sample[], double[][])
	 */
	protected abstract void initializeUsingPlugIn( int index, boolean freeParams, Sample[] data, double[][] weights ) throws Exception;
	
	/**
	 * Initializes this instance randomly: first all internal functions, then (if the hidden variables
	 * are optimized) the hidden potential. Finally the index structure for the parameters is rebuilt.
	 * 
	 * @param freeParams whether the (reduced) free parameterization is used
	 * 
	 * @throws Exception if an internal function could not be initialized
	 */
	public void initializeFunctionRandomly( boolean freeParams ) throws Exception
	{
		for( NormalizableScoringFunction f : function )
		{
			f.initializeFunctionRandomly( freeParams );
		}
		if( optimizeHidden )
		{
			initializeHiddenPotentialRandomly();
		}
		init( freeParams );
	}
	
	/**
	 * This method initializes the hidden potential (and the corresponding parameters) randomly. The
	 * potential is drawn using {@link DirichletMRG}; the hyperparameters of the draw are the
	 * hyperparameters of the hidden parameters, or <code>1</code> for each component if the ESS is 0.
	 * The drawn values are then handed to
	 * {@link AbstractMixtureScoringFunction#computeHiddenParameter(double[])}.
	 */
	protected void initializeHiddenPotentialRandomly()
	{
		final double[] alpha = new double[this.getNumberOfComponents()];
		if( getEss() != 0 )
		{
			for( int k = 0; k < alpha.length; k++ )
			{
				alpha[k] = getHyperparameterForHiddenParameter( k );
			}
		}
		else
		{
			Arrays.fill( alpha, 1 );
		}
		DirichletMRG.DEFAULT_INSTANCE.generate( hiddenPotential, 0, hiddenPotential.length, new DirichletMRGParams( alpha ) );
		computeHiddenParameter( hiddenPotential );
	}
	
	/**
	 * Checks whether this instance is initialized, i.e. whether all internal functions are initialized
	 * and the index structure for the parameters exists.
	 * 
	 * @return <code>true</code> if the instance is initialized, otherwise <code>false</code>
	 */
	public boolean isInitialized()
	{
		for( int k = 0; k < function.length; k++ )
		{
			if( !function[k].isInitialized() )
			{
				return false;
			}
		}
		return paramRef != null;
	}
	
	/**
	 * Sets the parameters of this instance. The parameters of the internal functions are read at the
	 * offsets stored in {@link AbstractMixtureScoringFunction#paramRef}. The hidden parameters are only
	 * read if they are optimized; otherwise only the normalization constant is reset.
	 * 
	 * @param params the parameter vector
	 * @param start the index in <code>params</code> where the parameters of this instance begin
	 */
	public void setParameters( double[] params, int start )
	{
		final int n = function.length;
		for( int k = 0; k < n; k++ )
		{
			setParametersForFunction( k, params, start + paramRef[k] );
		}
		if( optimizeHidden )
		{
			// this also updates the normalization constant
			setHiddenParameters( params, start + paramRef[n] );
			return;
		}
		// a negative value marks the normalization constant as outdated
		norm = isNormalized ? 1 : -1;
	}
	
	/**
	 * This method initializes the hidden parameters of the instance uniformly. Internal functions that
	 * are mixtures themselves are initialized first. If the hidden variables of this instance are
	 * optimized, the parameter of component <code>i</code> is set to <code>log( d / Z_i )</code>, where
	 * <code>Z_i</code> is the normalization constant of component <code>i</code> and <code>d</code> is
	 * either the normalization constant of the last component (free parameterization, so that the
	 * parameter of the last component is 0) or 1. Finally the index structure for the parameters is
	 * rebuilt.
	 */
	public void initializeHiddenUniformly()
	{
		final int c = getNumberOfComponents();
		for( int i = 0; i < function.length; i++ )
		{
			if( function[i] instanceof AbstractMixtureScoringFunction )
			{
				((AbstractMixtureScoringFunction) function[i]).initializeHiddenUniformly();
			}
		}
		if( optimizeHidden )
		{
			double[] pars = new double[c];
			// In the free parameterization the last component (index c-1) is the reference whose
			// parameter is fixed to 0, hence its normalization constant is used as numerator.
			final double d = freeParams ? getNormalizationConstantForComponent( c - 1 ) : 1d;
			for( int i = 0; i < c; i++ )
			{
				pars[i] = Math.log( d / getNormalizationConstantForComponent( i ) );
			}
			setHiddenParameters( pars, 0 );
		}
		init( freeParams );
	}
	
	/**
	 * This method sets the hidden parameters of the model and updates the (logarithmic) hidden
	 * potentials. In the free parameterization the last hidden parameter is not read from
	 * <code>params</code> but fixed to 0. For normalized instances the potentials are additionally
	 * normalized to sum to 1 and used as partial normalization constants; otherwise the normalization
	 * constant is marked as outdated.
	 * 
	 * @param params the parameter vector
	 * @param start the start index of the hidden parameters in <code>params</code>
	 */
	protected void setHiddenParameters( double[] params, int start )
	{
		final int total = hiddenParameter.length;
		// number of parameters that are actually read from params
		final int read = freeParams ? total - 1 : total;
		double z = 0;
		for( int k = 0; k < read; k++ )
		{
			final double value = params[start + k];
			hiddenParameter[k] = value;
			logHiddenPotential[k] = value;
			hiddenPotential[k] = Math.exp( value );
			z += hiddenPotential[k];
		}
		if( freeParams )
		{
			// the last component is the reference of the free parameterization
			hiddenParameter[read] = 0;
			logHiddenPotential[read] = 0;
			hiddenPotential[read] = 1;
			z++;
		}
		if( !isNormalized )
		{
			// the normalization constant has to be recomputed on demand
			norm = -1;
			return;
		}
		logHiddenNorm = Math.log( z );
		for( int k = 0; k < total; k++ )
		{
			logHiddenPotential[k] -= logHiddenNorm;
			hiddenPotential[k] /= z;
			partNorm[k] = hiddenPotential[k];
		}
		norm = 1;
	}
	
	/**
	 * This method allows to set the parameters for a specific function. It is used by
	 * {@link AbstractMixtureScoringFunction#setParameters(double[], int)} for each internal function and
	 * simply passes the parameters on to the function.
	 * 
	 * @param index the function index
	 * @param params the parameter vector
	 * @param start the start index of the parameters of the function in <code>params</code>
	 */
	protected void setParametersForFunction( int index, double[] params, int start )
	{
		function[index].setParameters( params, start );
	}

	/**
	 * Creates the XML representation of this instance. The representation contains the length, the
	 * number of starts, the parameterization flag, the internal functions, the flags
	 * <code>optimizeHidden</code> and <code>plugIn</code>, the hidden parameters, and the further
	 * information of subclasses. Everything is enclosed in the tag returned by
	 * {@link AbstractMixtureScoringFunction#getXMLTag()}.
	 * 
	 * @return the XML representation
	 * 
	 * @see AbstractMixtureScoringFunction#fromXML(StringBuffer)
	 */
	public final StringBuffer toXML()
	{
		StringBuffer b = new StringBuffer( 10000 );
		XMLParser.appendIntWithTags( b, length, "length" );
		XMLParser.appendIntWithTags( b, starts, "starts" );
		XMLParser.appendBooleanWithTags( b, freeParams, "freeParams" );
		XMLParser.appendStorableArrayWithTags( b, function, "function" );
		XMLParser.appendBooleanWithTags( b, optimizeHidden, "optimizeHidden" );
		XMLParser.appendBooleanWithTags( b, plugIn, "plugIn" );
		XMLParser.appendDoubleArrayWithTags( b, hiddenParameter, "hiddenParameter" );
		// hook for subclasses
		b.append( getFurtherInformation() );
		XMLParser.addTags( b, getXMLTag() );
		return b;
	}

	/**
	 * Restores the state of the instance from its XML representation. The tags are read in the same
	 * order in which {@link AbstractMixtureScoringFunction#toXML()} writes them. All arrays that
	 * depend on the number of hidden parameters are recreated, the hidden potentials are recomputed,
	 * and finally the index structure and the sum of the log-gamma terms are rebuilt.
	 * 
	 * @param b the XML representation
	 * 
	 * @throws NonParsableException if the representation could not be parsed
	 */
	protected final void fromXML( StringBuffer b ) throws NonParsableException
	{
		StringBuffer xml = XMLParser.extractForTag( b, getXMLTag() );
		length = XMLParser.extractIntForTag( xml, "length" );
		starts = XMLParser.extractIntForTag( xml, "starts" );
		freeParams = XMLParser.extractBooleanForTag( xml, "freeParams" );
		function = (NormalizableScoringFunction[]) ArrayHandler.cast( XMLParser.extractStorableArrayForTag( xml, "function" ) );
		// the alphabet container is not stored separately but taken from the first function
		alphabets = function[0].getAlphabetContainer();
		isNormalized = isNormalized( function );
		optimizeHidden = XMLParser.extractBooleanForTag( xml, "optimizeHidden" );
		plugIn = XMLParser.extractBooleanForTag( xml, "plugIn" );
		hiddenParameter = XMLParser.extractDoubleArrayForTag( xml, "hiddenParameter" );
		hiddenPotential = new double[hiddenParameter.length];
		logHiddenPotential = new double[hiddenParameter.length];
		partNorm = new double[logHiddenPotential.length];
		// recompute the potentials from the stored parameters
		setHiddenParameters( hiddenParameter, 0 );
		componentScore = new double[logHiddenPotential.length];
		
		// hook for subclasses; it is called before the index structure is built
		extractFurtherInformation( xml );
		init( freeParams );
		computeLogGammaSum();
	}

	/**
	 * This method is used to append further information of the instance to the XML representation.
	 * This method is designed to allow subclasses to add information. This implementation returns an
	 * empty {@link StringBuffer}.
	 * 
	 * @return the further information as XML in a {@link StringBuffer}
	 */
	protected StringBuffer getFurtherInformation()
	{
		return new StringBuffer( 1 );
	}

	/**
	 * This method is the opposite of {@link AbstractMixtureScoringFunction#getFurtherInformation()}.
	 * It is invoked while parsing the XML representation. This implementation does nothing.
	 * 
	 * @param xml the StringBuffer containing the information
	 * 
	 * @throws NonParsableException if the StringBuffer could not be parsed
	 */
	protected void extractFurtherInformation( StringBuffer xml ) throws NonParsableException
	{
		// nothing to extract in this class
	}

	/**
	 * This method computes the relative indices of a parameter index, i.e. the index of the block the
	 * parameter belongs to (an internal function, or <code>function.length</code> for the hidden
	 * parameters) and the index of the parameter within this block.
	 * 
	 * @param index the parameter index
	 * 
	 * @return an array with two entries: the index of the block and the index within the block
	 * 
	 * @see AbstractMixtureScoringFunction#paramRef
	 */
	protected int[] getIndices( int index )
	{
		int block = 0;
		// search for the first block that starts behind the index ...
		while( index >= paramRef[block] )
		{
			block++;
		}
		// ... and step back to the block that contains it
		block--;
		return new int[]{ block, index - paramRef[block] };
	}

	/**
	 * This method returns the XML tag of the instance that is used to build an XML representation.
	 * This implementation uses the simple name of the class of the instance.
	 * 
	 * @return the XML tag of the instance
	 */
	protected String getXMLTag()
	{
		return getClass().getSimpleName();
	}

	/**
	 * This method creates the underlying structure for the parameters. It uses an index array with
	 * one entry per internal function plus two further entries.
	 * 
	 * @param freeParams whether to use free parameters or all
	 * 
	 * @see AbstractMixtureScoringFunction#initWithLength(boolean, int)
	 */
	protected void init( boolean freeParams )
	{
		initWithLength( freeParams, function.length + 2 );
	}
	
	/**
	 * This method is used to create the underlying structure, e.g.
	 * {@link AbstractMixtureScoringFunction#paramRef}. It also creates the helper lists
	 * {@link AbstractMixtureScoringFunction#iList} and {@link AbstractMixtureScoringFunction#dList} if
	 * they do not exist. If the number of parameters of any internal function is unknown,
	 * <code>paramRef</code> is set to <code>null</code> and the parameterization flag of the
	 * instance is left unchanged.
	 * 
	 * @param freeParams whether to use free parameters or all
	 * @param len the length of the paramRef array
	 */
	protected final void initWithLength( boolean freeParams, int len )
	{
		// a new array is only allocated if necessary; paramRef[0] is never written and therefore stays 0
		if( paramRef == null || paramRef.length != len )
		{
			paramRef = new int[len];
		}
		int h, i = 0;
		if( iList == null )
		{
			// one pair of lists for each function or hidden parameter, whichever is more
			iList = new IntList[Math.max( function.length, hiddenParameter.length )];
			dList = new DoubleList[iList.length];
			for( ; i < iList.length; i++ )
			{
				iList[i] = new IntList();
				dList[i] = new DoubleList();
			}
		}
		// cumulative start indices of the parameters of the internal functions
		for( i = 0; i < function.length; i++ )
		{
			h = function[i].getNumberOfParameters();
			if( h != UNKNOWN )
			{
				paramRef[i + 1] = paramRef[i] + function[i].getNumberOfParameters();
			}
			else
			{
				// the structure can not be built yet
				paramRef = null;
				return;
			}
		}
		// here i == function.length: reserve the positions for the hidden parameters (if optimized)
		if( optimizeHidden )
		{
			paramRef[i + 1] = paramRef[i] + hiddenParameter.length - (freeParams ? 1 : 0);
		}
		else
		{
			paramRef[i + 1] = paramRef[i];
		}
		this.freeParams = freeParams;
	}

	/**
	 * This method has to be invoked during an initialization. It adds the hyperparameters of the
	 * hidden parameters to the given statistic (the array is modified) and sets the hidden parameters
	 * to the logarithm of the result, relative to the last component (free parameterization) or to
	 * the sum of all entries (otherwise). Finally the hidden potentials are updated.
	 * 
	 * @param statistic a statistic for the initialization of the hidden parameters
	 * 
	 * @see de.jstacs.scoringFunctions.ScoringFunction#initializeFunction(int, boolean, Sample[], double[][])
	 */
	protected void computeHiddenParameter( double[] statistic )
	{
		final int n = hiddenParameter.length;
		for( int k = 0; k < n; k++ )
		{
			statistic[k] += getHyperparameterForHiddenParameter( k );
		}
		if( freeParams )
		{
			// the last component is the reference, its parameter is 0
			final int last = n - 1;
			final double logReference = Math.log( statistic[last] );
			for( int k = 0; k < last; k++ )
			{
				hiddenParameter[k] = Math.log( statistic[k] ) - logReference;
			}
			hiddenParameter[last] = 0;
		}
		else
		{
			double total = 0;
			for( int k = 0; k < n; k++ )
			{
				total += statistic[k];
			}
			final double logTotal = Math.log( total );
			for( int k = 0; k < n; k++ )
			{
				hiddenParameter[k] = Math.log( statistic[k] ) - logTotal;
			}
		}
		setHiddenParameters( hiddenParameter, 0 );
	}

	/**
	 * Precomputes the normalization constant. The partial normalization constant of each component
	 * is the product of its hidden potential and its own normalization constant; the normalization
	 * constant of the instance is the sum of these partial constants.
	 */
	protected void precomputeNorm()
	{
		norm = 0;
		for( int i = 0; i < logHiddenPotential.length; i++ )
		{
			partNorm[i] = hiddenPotential[i] * getNormalizationConstantForComponent( i );
			norm += partNorm[i];
		}
	}

	/**
	 * Computes the normalization constant for the component <code>i</code>. The value does not
	 * include the hidden potential of the component, which is multiplied in
	 * {@link AbstractMixtureScoringFunction#precomputeNorm()}.
	 * 
	 * @param i
	 *            the index of the component
	 * 
	 * @return the normalization constant of the component
	 */
	protected abstract double getNormalizationConstantForComponent( int i );

	/**
	 * Fills the internal array <code>componentScore</code> with the log scores of the components.
	 * The array is shared by all methods of this class that evaluate a sequence.
	 * 
	 * @param seq
	 *            the sequence
	 * @param start
	 *            the start position
	 */
	protected abstract void fillComponentScores( Sequence seq, int start );
	
	/**
	 * Returns whether this instance is a normalized one. The value is determined from the internal
	 * functions when the instance is created or parsed.
	 * 
	 * @return <code>true</code> if the instance is normalized, otherwise <code>false</code>
	 */
	public boolean isNormalized()
	{
		return isNormalized;
	}
	
	/**
	 * This method returns a clone of a specific internal function, so changes to the returned object
	 * do not affect this instance.
	 * 
	 * @param index the index of the function
	 * 
	 * @return a clone of the function 
	 * 
	 * @throws CloneNotSupportedException if the function could not be cloned
	 */
	public NormalizableScoringFunction getFunction( int index ) throws CloneNotSupportedException
	{
		return (NormalizableScoringFunction) function[index].clone();
	}
	
	/**
	 * This method returns an array of clones of the internally used functions.
	 * 
	 * @return an array of clones of the internally used functions
	 * 
	 * @throws CloneNotSupportedException if at least one function could not be cloned
	 * 
	 * @see AbstractMixtureScoringFunction#getScoringFunctions()
	 */
	public NormalizableScoringFunction[] getFunctions() throws CloneNotSupportedException
	{
		return ArrayHandler.clone( function );
	}
}
