/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.scoringFunctions.homogeneous;

import java.util.Arrays;

import de.jstacs.NonParsableException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.XMLParser;
import de.jstacs.models.discrete.inhomogeneous.MEMConstraint;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/**
 * This scoring function implements a homogeneous Markov model of order zero (hMM(0)) for a fixed sequence length.
 * 
 * <p>
 * Each symbol <code>a</code> of the (single-position) alphabet has a parameter <code>lambda_a</code>. The
 * unnormalized log-score of a sequence is the sum of the lambdas of its symbols. The normalization constant for a
 * sequence of length <code>L</code> is <code>norm^L</code>, where <code>norm = sum_a exp(lambda_a)</code>.
 * </p>
 * 
 * <p>
 * Instances are not safe for concurrent use:
 * {@link #getLogScoreAndPartialDerivation(Sequence, int, int, IntList, DoubleList)} writes into a shared internal
 * counter array.
 * </p>
 * 
 * @author Jens Keilwagen
 */
public class HMM0ScoringFunction extends HomogeneousScoringFunction
{
	/**
	 * ess: the equivalent sample size.
	 * norm: sum of exp(lambda_i) over all symbols.
	 * sumOfHyperParams: total pseudocount of the prior (ess*length initially).
	 * logGammaSum: constant part of the log prior.
	 */
	private double ess, norm, sumOfHyperParams, logGammaSum;

	/** Scratch array of per-symbol counts used when computing partial derivations. */
	private int[] counter;

	private boolean freeParams, plugIn, optimize;

	/** The per-symbol parameters (lambdas). */
	private MEMConstraint params;

	/** The number of parameters exposed to the optimizer (0 if parameters are not optimized). */
	private int anz;

	/**
	 * The main constructor that creates an instance of a homogeneous Markov model of order 0.
	 *  
	 * @param alphabets the AlphabetContainer of the model
	 * @param length the length of sequences respectively the model
	 * @param ess the equivalent sample size (ess)
	 * @param plugIn whether to use a plug-in strategy to initialize the parameters
	 * @param optimize whether to optimize the parameters or not after they have been initialized
	 * 
	 * @throws IllegalArgumentException if <code>ess</code> is negative
	 */
	public HMM0ScoringFunction( AlphabetContainer alphabets, int length, double ess, boolean plugIn, boolean optimize )
	{
		super( alphabets, length );
		if( ess < 0 )
		{
			throw new IllegalArgumentException( "The ess has to be non-negative." );
		}
		this.ess = ess;
		sumOfHyperParams = ess*length;
		params = new MEMConstraint( new int[]{ 0 }, new int[]{ (int) alphabets.getAlphabetLengthAt( 0 ) } );
		this.plugIn = plugIn;
		this.optimize = optimize;
		setFreeParams( false );
		// uniform initialization: exp(lambda_i) = 1/|alphabet|, so the exponentials sum to 1
		setUniformLambdas();
		norm = 1;
		computeConstantsOfLogPrior();
	}

	/**
	 * This is the constructor for {@link de.jstacs.Storable}.
	 * 
	 * @param xml the xml representation
	 * 
	 * @throws NonParsableException if the representation could not be parsed.
	 */
	public HMM0ScoringFunction( StringBuffer xml ) throws NonParsableException
	{
		super( xml );
	}

	/**
	 * Returns a copy of this instance with its own parameter object and counter array.
	 * 
	 * @return the clone
	 * 
	 * @throws CloneNotSupportedException if the superclass cannot be cloned
	 */
	public HMM0ScoringFunction clone() throws CloneNotSupportedException
	{
		HMM0ScoringFunction clone = (HMM0ScoringFunction) super.clone();
		clone.params = params.clone();
		clone.counter = counter.clone();
		return clone;
	}

	/**
	 * Returns the name of this instance.
	 * 
	 * @return <code>"hMM(0)"</code>
	 */
	public String getInstanceName()
	{
		return "hMM(0)";
	}

	/**
	 * Returns the unnormalized log-score of a subsequence.
	 * 
	 * @param seq the sequence
	 * @param start the first position to score
	 * @param length the number of positions to score
	 * 
	 * @return the sum of the lambdas of the symbols at positions <code>start, ..., start+length-1</code>
	 */
	public double getLogScore( Sequence seq, int start, int length )
	{
		double erg = 0;
		for( int l = 0; l < length; l++ )
		{
			erg += params.getLambda( params.satisfiesSpecificConstraint( seq, start+l ) );
		}
		return erg;
	}

	/**
	 * Returns the unnormalized log-score of a subsequence and appends the non-zero partial derivations.
	 * 
	 * <p>
	 * The derivation with respect to parameter <code>i</code> is the number of occurrences of symbol
	 * <code>i</code>. Only parameters with index below {@link #getNumberOfParameters()} are reported.
	 * </p>
	 * 
	 * @param seq the sequence
	 * @param start the first position to score
	 * @param length the number of positions to score
	 * @param indices receives the indices of the parameters with non-zero derivation
	 * @param dList receives the corresponding derivations
	 * 
	 * @return the unnormalized log-score
	 */
	public double getLogScoreAndPartialDerivation( Sequence seq, int start, int length, IntList indices, DoubleList dList )
	{
		Arrays.fill( counter, 0 );
		for( int l = 0; l < length; l++ )
		{
			counter[params.satisfiesSpecificConstraint( seq, start+l )]++;
		}
		double erg = 0;
		for( int s = 0; s < counter.length; s++ )
		{
			if( counter[s] > 0 )
			{
				erg += counter[s] * params.getLambda( s );
				// symbols whose parameter is fixed (not optimized) contribute no derivation
				if( s < anz )
				{
					indices.add( s );
					dList.add( counter[s] );
				}
			}
		}
		return erg;
	}

	/**
	 * Returns the number of parameters exposed to the optimizer.
	 * 
	 * @return 0 if the parameters are not optimized. Otherwise the alphabet size, minus one if the parameters
	 *         were initialized as free parameters.
	 */
	public int getNumberOfParameters()
	{
		return anz;
	}

	/**
	 * Sets the optimizable parameters from an array and updates the normalizer. Does nothing if parameters are
	 * not optimized.
	 * 
	 * @param params the array holding the new values
	 * @param start the offset of the first value of this scoring function in <code>params</code>
	 */
	public void setParameters( double[] params, int start )
	{
		if( optimize )
		{
			norm = 0;
			for( int i = 0; i < anz; i++ )
			{
				this.params.setLambda( i, params[start+i] );
				norm += this.params.getExpLambda( i );
			}
			// with free parameters the last lambda is fixed and not part of params
			if( anz < counter.length )
			{
				norm += this.params.getExpLambda( anz );
			}
		}
	}

	/**
	 * Returns an XML representation of this instance.
	 * 
	 * @return the XML representation
	 */
	public StringBuffer toXML()
	{
		StringBuffer b = new StringBuffer( 1000 );
		XMLParser.appendIntWithTags( b, length, "length" );
		XMLParser.appendStorableWithTags( b, alphabets, "alphabets" );
		XMLParser.appendDoubleWithTags( b, ess, "ess" );
		XMLParser.appendDoubleWithTags( b, sumOfHyperParams, "sumOfHyperParams" );
		XMLParser.appendStorableWithTags( b, params, "params" );
		XMLParser.appendBooleanWithTags( b, freeParams, "freeParams" );
		XMLParser.appendBooleanWithTags( b, plugIn, "plugIn" );
		XMLParser.appendBooleanWithTags( b, optimize, "optimize" );
		XMLParser.addTags( b, getClass().getSimpleName() );
		return b;
	}

	/**
	 * Returns the current values of the optimizable parameters.
	 * 
	 * @return an array of length {@link #getNumberOfParameters()} holding the lambdas
	 */
	public double[] getCurrentParameterValues()
	{
		double[] erg = new double[anz];
		for( int i = 0; i < anz; i++ )
		{
			erg[i] = params.getLambda( i );
		}
		return erg;
	}

	/**
	 * Returns the normalized symbol distribution <code>exp(lambda_i)/norm</code>.
	 * 
	 * @return the probability of each symbol
	 */
	public double[] getStationarySymbolDistribution()
	{
		double[] erg = new double[params.getNumberOfSpecificConstraints()];
		double z = getNormalizationConstant( 1 );
		for( int i = 0; i < erg.length; i++ )
		{
			erg[i] = params.getExpLambda( i ) / z;
		}
		return erg;
	}

	/**
	 * Initializes the parameters.
	 * 
	 * <p>
	 * With plug-in, the parameters are estimated from the symbol counts of <code>data[index]</code> together
	 * with the hyperparameters. Otherwise a uniform distribution is used.
	 * </p>
	 * 
	 * @param index the index of the class whose data is used
	 * @param freeParams whether to fix the last parameter and expose only the remaining ones
	 * @param data the data per class (may be <code>null</code>)
	 * @param weights the weights per class and sequence. If <code>weights</code> or <code>weights[index]</code>
	 *        is <code>null</code>, every sequence has weight 1.
	 */
	public void initializeFunction( int index, boolean freeParams, Sample[] data, double[][] weights )
	{
		params.reset();
		if( plugIn )
		{
			if( data != null && data[index] != null )
			{
				double[] w = weights == null ? null : weights[index];
				Sequence seq;
				for( int i = 0; i < data[index].getNumberOfElements(); i++ )
				{
					seq = data[index].getElementAt( i );
					double weight = w == null ? 1 : w[i];
					for( int k = 0, l = seq.getLength(); k < l; k++ )
					{
						params.add( seq.discreteVal( k ), weight );
					}
				}
			}
			params.estimate( sumOfHyperParams );
			for( int i = 0; i < counter.length; i++ )
			{
				params.setExpLambda( i, params.getFreq( i ) );
			}
		}
		else
		{
			setUniformLambdas();
		}
		norm = 1;
		// setFreeParams recomputes norm if it shifts the lambdas
		setFreeParams( freeParams );
	}

	/**
	 * Initializes the parameters with a distribution drawn from a symmetric Dirichlet distribution.
	 * 
	 * <p>
	 * Each Dirichlet hyperparameter is <code>sumOfHyperParams/n</code>, or 1 if there are no pseudocounts.
	 * </p>
	 * 
	 * @param freeParams whether to fix the last parameter and expose only the remaining ones
	 */
	public void initializeFunctionRandomly( boolean freeParams )
	{
		int n = counter.length;
		double[] p = DirichletMRG.DEFAULT_INSTANCE.generate( n, new FastDirichletMRGParams( sumOfHyperParams==0?1:(sumOfHyperParams/(double)n) ) );
		for( int i = 0; i < n; i++ )
		{
			params.setExpLambda( i, p[i] );
		}
		norm = 1;
		// setFreeParams recomputes norm if it shifts the lambdas
		setFreeParams( freeParams );
	}

	/**
	 * Restores this instance from its XML representation.
	 * 
	 * @param xml the XML representation
	 * 
	 * @throws NonParsableException if the representation could not be parsed
	 */
	protected void fromXML( StringBuffer xml ) throws NonParsableException
	{
		StringBuffer b = XMLParser.extractForTag( xml, getClass().getSimpleName() );
		length = XMLParser.extractIntForTag( b, "length" );
		alphabets = (AlphabetContainer) XMLParser.extractStorableForTag( b, "alphabets" );
		ess = XMLParser.extractDoubleForTag( b, "ess" );
		sumOfHyperParams = XMLParser.extractDoubleForTag( b, "sumOfHyperParams" );
		params = (MEMConstraint) XMLParser.extractStorableForTag( b, "params" );
		plugIn = XMLParser.extractBooleanForTag( b, "plugIn" );
		optimize = XMLParser.extractBooleanForTag( b, "optimize" );
		setFreeParams( XMLParser.extractBooleanForTag( b, "freeParams" ) );
		computeNorm();
		computeConstantsOfLogPrior();
	}

	/**
	 * Sets every lambda to <code>-log(|alphabet|)</code>, so that all exp(lambda_i) equal
	 * <code>1/|alphabet|</code>.
	 */
	private void setUniformLambdas()
	{
		double d = -Math.log( alphabets.getAlphabetLengthAt( 0 ) );
		for( int i = 0; i < params.getNumberOfSpecificConstraints(); i++ )
		{
			params.setLambda( i, d );
		}
	}

	/**
	 * Recomputes {@link #norm} as the sum of exp(lambda_i) over all symbols.
	 */
	private void computeNorm()
	{
		norm = 0;
		for( int i = 0; i < params.getNumberOfSpecificConstraints(); i++ )
		{
			norm += params.getExpLambda( i );
		}
	}

	/**
	 * Sets the free-parameter mode, reallocates the counter array and determines the number of optimizable
	 * parameters.
	 * 
	 * <p>
	 * In free-parameter mode all lambdas are shifted so that the last one is 0. This changes the sum of the
	 * exponentials, so the normalizer is recomputed afterwards.
	 * </p>
	 * 
	 * @param freeParams whether to use free parameters
	 */
	private void setFreeParams( boolean freeParams )
	{
		this.freeParams = freeParams;
		int n = params.getNumberOfSpecificConstraints();
		counter = new int[n];
		anz = optimize ? n - (freeParams ? 1 : 0) : 0;
		if( freeParams )
		{
			double last = params.getLambda( n - 1 );
			for( int i = 0; i < n; i++ )
			{
				params.setLambda( i, params.getLambda( i ) - last );
			}
			// the shift rescales every exp(lambda_i) by exp(-last); keep norm consistent
			computeNorm();
		}
	}

	/**
	 * Returns the number of values of the random variable belonging to a parameter.
	 * 
	 * @param index the parameter index
	 * 
	 * @return the alphabet size
	 * 
	 * @throws IndexOutOfBoundsException if <code>index</code> is not in <code>[0, getNumberOfParameters())</code>
	 */
	public int getSizeOfEventSpaceForRandomVariablesOfParameter( int index )
	{
		if( index < 0 || index >= anz )
		{
			throw new IndexOutOfBoundsException( "parameter index " + index + " not in [0, " + anz + ")" );
		}
		return params.getNumberOfSpecificConstraints();
	}

	/**
	 * Returns the normalization constant <code>norm^length</code>.
	 * 
	 * @param length the sequence length
	 * 
	 * @return the normalization constant
	 * 
	 * @throws IllegalArgumentException if <code>length</code> is 0
	 */
	public double getNormalizationConstant( int length )
	{
		if( length == 0 )
		{
			throw new IllegalArgumentException( "The normalization constant can not be computed for length 0." );
		}
		return Math.pow( norm, length );
	}

	/**
	 * Returns the partial derivation of the normalization constant with respect to one parameter:
	 * <code>length * norm^(length-1) * exp(lambda_parameterIndex)</code>.
	 * 
	 * @param parameterIndex the parameter index
	 * @param length the sequence length
	 * 
	 * @return the partial derivation
	 * 
	 * @throws IndexOutOfBoundsException if <code>parameterIndex</code> is not in
	 *         <code>[0, getNumberOfParameters())</code>
	 */
	public double getPartialNormalizationConstant( int parameterIndex, int length ) throws Exception
	{
		if( parameterIndex < 0 || parameterIndex >= anz )
		{
			throw new IndexOutOfBoundsException( "parameter index " + parameterIndex + " not in [0, " + anz + ")" );
		}
		return length * Math.pow( norm, length-1 ) * params.getExpLambda( parameterIndex );
	}

	/**
	 * Returns the equivalent sample size.
	 * 
	 * @return the ess
	 */
	public double getEss()
	{
		return ess;
	}

	/**
	 * Returns the symbol probabilities as tab-separated <code>symbol: probability</code> pairs.
	 * 
	 * @return the string representation
	 */
	public String toString()
	{
		StringBuffer info = new StringBuffer(100);
		info.append( alphabets.getSymbol( 0, 0 ) + ": " + (params.getExpLambda( 0 )/norm) );
		for( int i = 1; i < params.getNumberOfSpecificConstraints(); i++ )
		{
			info.append( "\t" + alphabets.getSymbol( 0, i ) + ": " + (params.getExpLambda( i )/norm) );
		}
		return info.toString();
	}

	/**
	 * Returns the log prior term <code>(sumOfHyperParams/n) * sum_i lambda_i + logGammaSum</code>.
	 * 
	 * <p>
	 * Returns 0 if the parameters are not optimized or if there are no pseudocounts
	 * (<code>sumOfHyperParams == 0</code>).
	 * </p>
	 * 
	 * @return the log prior term
	 */
	public double getLogPriorTerm()
	{
		if( !optimize || sumOfHyperParams == 0 )
		{
			return 0;
		}
		int n = params.getNumberOfSpecificConstraints();
		double val = 0;
		for( int i = 0; i < n; i++ )
		{
			val += params.getLambda( i );
		}
		return (val * sumOfHyperParams / (double) n) + logGammaSum;
	}

	/**
	 * Computes the constant part of the log prior:
	 * <code>logGamma(sumOfHyperParams) - n * logGamma(sumOfHyperParams/n)</code>.
	 * 
	 * <p>
	 * The Gamma function has a pole at 0, so this constant would not be finite without pseudocounts. In that
	 * case it is set to 0.
	 * </p>
	 */
	private void computeConstantsOfLogPrior()
	{
		if( sumOfHyperParams == 0 )
		{
			logGammaSum = 0;
		}
		else
		{
			int n = params.getNumberOfSpecificConstraints();
			logGammaSum = Gamma.logOfGamma( sumOfHyperParams )
						- n * Gamma.logOfGamma( sumOfHyperParams / (double) n );
		}
	}

	/**
	 * Adds the gradient of the log prior term, <code>sumOfHyperParams/n</code> per optimizable parameter, to
	 * <code>grad</code>.
	 * 
	 * @param grad the gradient array
	 * @param start the offset of the first parameter of this scoring function in <code>grad</code>
	 */
	public void addGradientOfLogPriorTerm( double[] grad, int start )
	{
		double d = sumOfHyperParams / (double) params.getNumberOfSpecificConstraints();
		for( int i = 0; i < anz; i++ )
		{
			grad[start+i] += d;
		}
	}

	/**
	 * Returns whether this instance is initialized.
	 * 
	 * @return always <code>true</code>
	 */
	public boolean isInitialized()
	{
		return true;
	}

	/**
	 * Returns the Markov order.
	 * 
	 * @return 0
	 */
	public int getMaximalMarkovOrder()
	{
		return 0;
	}

	/**
	 * Sets the total pseudocount as <code>sum_i length[i]*weight[i]</code> and updates the prior constant.
	 * 
	 * <p>
	 * The inputs are validated completely before any state is modified.
	 * </p>
	 * 
	 * @param length the sequence lengths
	 * @param weight the corresponding weights
	 * 
	 * @throws IllegalArgumentException if the arrays differ in length or contain a negative entry
	 */
	public void setStatisticForHyperparameters( int[] length, double[] weight ) throws Exception
	{
		if( weight.length != length.length )
		{
			throw new IllegalArgumentException( "The length of both arrays (length, weight) have to be identical." );
		}
		double sum = 0;
		for( int i = 0; i < length.length; i++ )
		{
			if( weight[i] < 0 || length[i] < 0 )
			{
				throw new IllegalArgumentException( "check length and weight for entry " + i );
			}
			sum += length[i]*weight[i];
		}
		sumOfHyperParams = sum;
		computeConstantsOfLogPrior();
	}
}
