/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.scoringFunctions;

import java.util.AbstractList;
import java.util.Arrays;

import de.jstacs.NonParsableException;
import de.jstacs.Storable;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.models.discrete.ConstraintManager;
import de.jstacs.models.discrete.inhomogeneous.MEMConstraint;
import de.jstacs.models.discrete.inhomogeneous.SequenceIterator;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;

/**
 * This class implements the scoring function for any MRF.
 * 
 * @author Jens Keilwagen
 */
public final class MRFScoringFunction extends AbstractNormalizableScoringFunction
{
	private MEMConstraint[] constr;

	private String name;

	private boolean freeParams;

	private int[] offset, help;
	
	private double ess, norm;
	private double[][] partNorm;
	private SequenceIterator seqIt;

	/**
	 * This constructor creates an instance with ess 0.
	 * 
	 * @param alphabets the AlphabetContainer
	 * @param length the length of the sequences respectively the model
	 * @param constr the constraints that are used for the model, see {@link de.jstacs.models.discrete.ConstraintManager#extract(int, String)}
	 * 
	 * @see MRFScoringFunction#MRFScoringFunction(AlphabetContainer, int, double, String)
	 */
	public MRFScoringFunction( AlphabetContainer alphabets, int length, String constr )
	{
		this( alphabets, length, 0, constr );
	}
	
	/**
	 * This is the main constructor.
	 * 
	 * @param alphabets the AlphabetContainer
	 * @param length the length of the sequences respectively the model
	 * @param ess the equivalent sample size (ess)
	 * @param constr the constraints that are used for the model, see {@link de.jstacs.models.discrete.ConstraintManager#extract(int, String)}
	 */
	public MRFScoringFunction( AlphabetContainer alphabets, int length, double ess, String constr )
	{
		super( alphabets, length );
		if( !alphabets.isDiscrete() )
		{
			throw new IllegalArgumentException( "The AlphabetContainer has to be discrete." );
		}
		if( ess < 0 )
		{
			throw new IllegalArgumentException( "The ess has to be non-negative." );
		}
		this.ess = ess;
		int[] aLength = new int[length];
		for( int i = 0; i < length; aLength[i] = (int) alphabets.getAlphabetLengthAt( i++ ) );
		AbstractList<int[]> list = ConstraintManager.extract( length, constr );
		ConstraintManager.reduce( list );
		this.constr = ConstraintManager.createConstraints( list, aLength );
		this.name = constr;
		freeParams = false;
		getNumberOfParameters();
		init(-1);
	}

	/**
	 * This is the constructor for {@link Storable}.
	 * 
	 * @param source the xml representation
	 * 
	 * @throws NonParsableException if the representation could not be parsed.
	 */
	public MRFScoringFunction( StringBuffer source ) throws NonParsableException
	{
		super( source );
	}
	
	private void init( double n )
	{
		norm = n;
		if( partNorm == null )
		{
			partNorm = new double[constr.length][];
			for( int i = 0; i < partNorm.length; i++ )
			{
				partNorm[i] = new double[constr[i].getNumberOfSpecificConstraints()];
			}
			help = new int[2];
			int[] aLength = new int[length];
			for( int i = 0; i < length; aLength[i] = (int) alphabets.getAlphabetLengthAt( i++ ) );
			seqIt = new SequenceIterator( length );
			seqIt.setBounds( aLength );
		}
		else
		{
			for( int i = 0; i < partNorm.length; i++ )
			{
				Arrays.fill( partNorm[i], 0 );
			}
		}
	}

	protected void fromXML( StringBuffer representation ) throws NonParsableException
	{
		StringBuffer xml = XMLParser.extractForTag( representation, XML_TAG );
		length = XMLParser.extractIntForTag( xml, "length" );
		alphabets = (AlphabetContainer) XMLParser.extractStorableForTag( xml, "alphabets" );
		ess = XMLParser.extractDoubleForTag( xml, "ess" );
		name = XMLParser.extractStringForTag( xml, "name" );
		constr = (MEMConstraint[]) ArrayHandler.cast( XMLParser.extractStorableArrayForTag( xml, "constr" ) );
		freeParams = XMLParser.extractBooleanForTag( xml, "freeParams" );
		getNumberOfParameters();
		init(-1);
	}

	public MRFScoringFunction clone() throws CloneNotSupportedException
	{
		MRFScoringFunction clone = (MRFScoringFunction) super.clone();
		clone.constr = ArrayHandler.clone( constr );
		clone.init(-1);
		if( norm > 0 )
		{
			clone.norm = norm;
			for( int i = 0; i < partNorm.length; i++ )
			{
				System.arraycopy( partNorm[i], 0, clone.partNorm[i], 0, partNorm[i].length );
			}
		}
		clone.help = help.clone();
		clone.getNumberOfParameters();
		return clone;
	}

	public double getLogScore( Sequence seq, int start )
	{
		double erg = 0;
		for( int i = 0; i < constr.length; i++ )
		{
			erg += constr[i].getLambda( constr[i].satisfiesSpecificConstraint( seq, start ) );
		}
		return erg;
	}
	
	public double getLogScoreAndPartialDerivation( Sequence seq, int start, IntList indices, DoubleList partialDer )
	{
		double erg = 0;
		int i = 0, j, z;
		for( ; i < constr.length; i++ )
		{
			j = constr[i].satisfiesSpecificConstraint( seq, start );
			if( (z = offset[i] + j) < offset[i + 1] )
			{
				indices.add( z );
				partialDer.add( 1 );
			}
			erg += constr[i].getLambda( j );
		}
		return erg;
	}

	public int getNumberOfParameters()
	{
		if( offset == null )
		{
			int i = 0, anz = 0;
			offset = new int[constr.length + 1];
			while( i < constr.length )
			{
				anz += constr[i++].getNumberOfSpecificConstraints();
				if( freeParams )
				{
					anz -= 1;
				}
				offset[i] = anz;
			}
		}

		return offset[constr.length];
	}

	public String getInstanceName()
	{
		return name;
	}

	public void setParameters( double[] params, int start )
	{
		norm = -1;
		int i = 0, j, s = 0;
		for( ; i < constr.length; i++ )
		{
			constr[i].setLambda( constr[i].getNumberOfSpecificConstraints() - 1, 0 );
			for( j = 0; s < offset[i + 1]; s++, j++ )
			{
				constr[i].setLambda( j, params[start + s] );
			}
		}
	}
	
	/*
	public void showParameters()
	{
		int i = 0, j;
		for( ; i < constr.length; i++ )
		{
			for( j = 0; j < constr[i].getNumberOfSpecificConstraints(); j++ )
			{
				System.out.print( constr[i].getLambda( j ) + "\t" ); 
			}
			System.out.println();
		}
	}*/

	private static final String XML_TAG = "MRFScoringFunction";

	public StringBuffer toXML()
	{
		StringBuffer b = new StringBuffer( 10000 );
		XMLParser.appendIntWithTags( b, length, "length" );
		XMLParser.appendStorableWithTags( b, alphabets, "alphabets" );
		XMLParser.appendDoubleWithTags( b, ess, "ess" );
		XMLParser.appendStringWithTags( b, name, "name" );
		XMLParser.appendStorableArrayWithTags( b, constr, "constr" );
		XMLParser.appendBooleanWithTags( b, freeParams, "freeParams" );
		XMLParser.addTags( b, XML_TAG );
		return b;
	}

	public void initializeFunction( int index, boolean freeParams, Sample[] data, double[][] weights ) throws Exception
	{
		if( this.freeParams != freeParams )
		{
			offset = null;
			this.freeParams = freeParams;
			getNumberOfParameters();
		}
		double d = 0;
		for( int i = 0; i < length; i++ )
		{
			d -=Math.log( alphabets.getAlphabetLengthAt(i) );
		}
		d /= constr.length;
		for( int k, j, i = 0; i < constr.length; i++ )
		{
			k = constr[i].getNumberOfSpecificConstraints();
			for( j = 0; j < k; j++ )
			{
				constr[i].setLambda( j, d );
			}
		}
		/*
		MEMConstraint[] c2 = MEMConstraint.clone( constr );
		for( int i = 0; i < c2.length; i++ )
		{
			c2[i].reset();
		}
		double sum1=0, sum2 = 0;
		for( int i = 0; i < data.length; i++ )
		{
			if( i == index )
			{
				sum1 = ConstraintManager.countInhomogeneous( alphabets, length, data[index], weights[index], true, constr );
			}
			else
			{
				sum2 += ConstraintManager.countInhomogeneous( alphabets, length, data[i], weights[i], false, c2 );
			}
		}
		
		double d1, d2, ess1=1, ess2=1, pc1, pc2, s, a, b, cl1 = sum1/(sum1+sum2), cl2 = sum2/(sum1+sum2);
		for( int j, i = 0; i < constr.length; i++ )
		{
			System.out.println( constr[i] );
			pc1 = ess1/ constr[i].getNumberOfSpecificConstraints();
			pc2 = ess2/ c2[i].getNumberOfSpecificConstraints();
			for( j = 0; j < constr[i].getNumberOfSpecificConstraints(); j++ )
			{
				d1 = constr[i].getCount(j);
				d2 = c2[i].getCount(j);
				a = d1 * Math.log( cl1 ) +  d2*Math.log(cl2);
				b = Gamma.logOfGamma(pc1+pc2) - Gamma.logOfGamma(pc1) - Gamma.logOfGamma(pc2)
					+ Gamma.logOfGamma(d1+pc1) + Gamma.logOfGamma( d2 + pc2 ) - Gamma.logOfGamma(d1+d2+pc1+pc2); 
				System.out.println( d1 + "\t" + d2 + "\t" + b + "\t" + a + "\t\t" + (b-a) );
			}
			System.out.println();
		}
		*/
		
		// does nothing since we have fixed structure
	}
	
	public void initializeFunctionRandomly( boolean freeParams ) throws Exception
	{	
		if( this.freeParams != freeParams )
		{
			offset = null;
			this.freeParams = freeParams;
			getNumberOfParameters();
		}
		for( int k, j, i = 0; i < constr.length; i++ )
		{
			k = constr[i].getNumberOfSpecificConstraints();
			for( j = 0; j < k; j++ )
			{
				constr[i].setLambda( j, r.nextDouble() );
			}
		}
	}

	public double getNormalizationConstant()
	{
		if( norm < 0 )
		{
			precompute();
		}
		return norm;
	}

	public double getPartialNormalizationConstant( int parameterIndex ) throws Exception
	{
		if( norm < 0 )
		{
			precompute();
		}
		computeIndices( parameterIndex );
		return partNorm[help[0]][help[1]];
	}
	
	private void precompute()
	{
		init(0);
		//TODO current implementation is only the naive approach, so try to make this faster?!?
		seqIt.reset();
		int i = 0;
		int[] fulfilled = new int[constr.length];
		seqIt.reset();
		double s;
		do
		{
			s = getScore( fulfilled, seqIt );
			for( i = 0; i < constr.length; i++ )
			{
				partNorm[i][fulfilled[i]] += s;
			}
			norm += s;
		}while( seqIt.next() );
	}
	
	private double getScore( int[] fulfilled, SequenceIterator sequence )
	{
		double s = 1;
		for( int counter = 0; counter < constr.length; counter++ )
		{
			fulfilled[counter] = constr[counter].satisfiesSpecificConstraint( sequence );
			s *= constr[counter].getExpLambda( fulfilled[counter] );
		}
		return s;
	}

	public double getEss()
	{
		return ess;
	}

	public int getSizeOfEventSpaceForRandomVariablesOfParameter( int index )
	{
		computeIndices( index );
		return constr[help[0]].getNumberOfSpecificConstraints();
	}
	
	private void computeIndices( int index )
	{
		help[0] = 0;
		while( index >= offset[help[0]] )
		{
			help[0]++;
		}
		help[0]--;
		help[1] = index - offset[help[0]];
	}
	
	public double getLogPriorTerm()
	{
		double logPriorTerm = 0, d;
		int i = 0, j, s = 0;
		for( ; i < constr.length; i++ )
		{
			d = ess / constr[i].getNumberOfSpecificConstraints();
			for( j = 0; s < offset[i + 1]; s++, j++ )
			{
				logPriorTerm += constr[i].getLambda( j ) * d;
			}
		}
		return logPriorTerm;
	}

	public void addGradientOfLogPriorTerm( double[] grad, int start )
	{
		double d;
		int i = 0, j, s = 0;
		for( ; i < constr.length; i++ )
		{
			d = ess / constr[i].getNumberOfSpecificConstraints();
			for( j = 0; s < offset[i + 1]; s++, j++ )
			{
				grad[start+s] += d;
			}
		}		
	}
	
	public double[] getCurrentParameterValues() throws Exception
	{
		double[] start =  new double[offset[constr.length]];
		for( int k, j, i = 0, n = 0; i < constr.length; i++ )
		{
			k = constr[i].getNumberOfSpecificConstraints();
			for( j = 0; j < k; j++, n++ )
			{
				start[n] = constr[i].getLambda( j );
			}
		}
		return start;		
	}

	public boolean isInitialized()
	{
		return true;
	}	
}
