/*
 * This file is part of Jstacs.
 * 
 * Jstacs is free software: you can redistribute it and/or modify it under the
 * terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 * 
 * Jstacs is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with
 * Jstacs. If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.models;

import java.io.OutputStream;

import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.LimitedMedianStartDistance;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.StartDistanceForecaster;
import de.jstacs.algorithms.optimization.termination.AbstractTerminationCondition;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction.KindOfParameter;
import de.jstacs.classifier.scoringFunctionBased.gendismix.LearningPrinciple;
import de.jstacs.classifier.scoringFunctionBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifier.scoringFunctionBased.logPrior.CompositeLogPrior;
import de.jstacs.classifier.scoringFunctionBased.logPrior.LogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.Sample.WeightedSampleFactory;
import de.jstacs.data.Sample.WeightedSampleFactory.SortOperation;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.scoringFunctions.IndependentProductScoringFunction;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.UniformScoringFunction;
import de.jstacs.scoringFunctions.homogeneous.UniformHomogeneousScoringFunction;
import de.jstacs.utils.SafeOutputStream;

/**
 * This model can be used to use a NormalizableScoringFunction as model.
 * It enables the user to train the NormalizableScoringFunction in a generative way.
 * 
 * @author Jens Keilwagen
 * 
 * @see NormalizableScoringFunction
 * @see LogGenDisMixFunction
 */
public class NormalizableScoringFunctionModel extends AbstractModel
{
	private SafeOutputStream out;
	
	/**
	 * The internally used {@link NormalizableScoringFunction}.
	 */
	protected NormalizableScoringFunction nsf;
	private double logNorm, lineps, startD;
	private AbstractTerminationCondition tc;
	private byte algo;
	private int threads;

	/**
	 * The main constructor that creates an instance with the user given parameters.
	 * 
	 * @param nsf the {@link NormalizableScoringFunction} that should be used
	 * @param threads the number of threads that should be used for optimization
	 * @param algo the algorithm that should be used for the optimization
	 * @param tc the {@link AbstractTerminationCondition} for stopping the optimization
	 * @param lineps the line epsilon for stopping the line search in the optimization
	 * @param startD the start distance that should be used initially
	 * 
	 * @throws CloneNotSupportedException if <code>nsf</code> can not be cloned
	 */
	public NormalizableScoringFunctionModel( NormalizableScoringFunction nsf, int threads, byte algo, AbstractTerminationCondition tc, double lineps, double startD ) throws CloneNotSupportedException
	{
		super( nsf.getAlphabetContainer(), nsf.getLength() );
		if( threads < 1 )
		{
			throw new IllegalArgumentException( "The number of threads has to be positive." );
		}
		this.threads = threads;
		this.tc = tc.clone();
		if( lineps < 0 )
		{
			throw new IllegalArgumentException( "The value of lineps has to be non-negative." );
		}
		this.lineps = lineps;
		if( startD <= 0 )
		{
			throw new IllegalArgumentException( "The value of startD has to be positive." );
		}
		this.startD = startD;
		this.algo = algo;
		this.nsf = (NormalizableScoringFunction) nsf.clone();
		if( isTrained() )
		{
			logNorm = nsf.getLogNormalizationConstant();
		}
		else
		{
			logNorm = Double.NEGATIVE_INFINITY;
		}
		setOutputStream( SafeOutputStream.DEFAULT_STREAM );
	}

	/**
	 * The standard constructor for the interface {@link de.jstacs.Storable}.
	 * Creates a new {@link NormalizableScoringFunctionModel} out of a {@link StringBuffer}.
	 * 
	 * @param stringBuff
	 *            the {@link StringBuffer} to be parsed
	 * 
	 * @throws NonParsableException
	 *             is thrown if the {@link StringBuffer} could not be parsed
	 */
	public NormalizableScoringFunctionModel( StringBuffer stringBuff ) throws NonParsableException
	{
		super( stringBuff );
	}

	public NormalizableScoringFunctionModel clone() throws CloneNotSupportedException
	{
		NormalizableScoringFunctionModel clone = (NormalizableScoringFunctionModel) super.clone();
		clone.nsf = (NormalizableScoringFunction) nsf.clone();
		clone.tc = tc.clone();
		clone.setOutputStream( out.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM );
		return clone;
	}
	
	public void train( Sample data, double[] weights ) throws Exception
	{
		if( !data.getAlphabetContainer().checkConsistency( alphabets ) )
		{
			throw new WrongAlphabetException( "The AlphabetConatainer of the sample and the model do not match." );
		}
		if( length != 0 && length != data.getElementLength() )
		{
			throw new WrongLengthException( "The length of the elements of the sample is not suitable for the model." );
		}
		
		if( nsf instanceof IndependentProductScoringFunction ) {
			IndependentProductScoringFunction ipsf = (IndependentProductScoringFunction) nsf;
			NormalizableScoringFunction[] nsfs = ipsf.getFunctions();
			Sample[] part = new Sample[1], packedData = { data };
			double[][] partWeights, packedWeights = { weights };
			for( int a, i = 0; i < nsfs.length; i++ ) {
				a = ipsf.extractSequenceParts( i, packedData, part );
				partWeights = ipsf.extractWeights( a, packedWeights );
				nsfs[i] = train( part[0], partWeights[0], nsfs[i] );
			}
			nsf = new IndependentProductScoringFunction( ipsf.getEss(), true, nsfs, ipsf.getIndices(), ipsf.getPartialLengths(), ipsf.getReverseSwitches() );
		} else {
			nsf = train( data, weights, nsf );
		}
	}
	
	private NormalizableScoringFunction train( Sample data, double[] weights, NormalizableScoringFunction nsf ) throws Exception {
		if( !(nsf instanceof UniformScoringFunction || nsf instanceof UniformHomogeneousScoringFunction ) ) {
			WeightedSampleFactory wsf = new WeightedSampleFactory( SortOperation.NO_SORT, data, weights );
			Sample small = wsf.getSample();
			double[] smallWeights = wsf.getWeights();  
			
			double[] params;
			NormalizableScoringFunction best = null;
			double current, max = Double.NEGATIVE_INFINITY, fac = data.getNumberOfElements(), ess = nsf.getEss();
			fac = fac / (ess+ fac) * (ess == 0 ? 1d : 2d);
			
			NormalizableScoringFunction[] score = { (NormalizableScoringFunction) nsf.clone() };
			LogPrior prior = new CompositeLogPrior();
			double[] beta = LearningPrinciple.getBeta( ess == 0 ? LearningPrinciple.ML : LearningPrinciple.MAP );
			LogGenDisMixFunction f = new LogGenDisMixFunction( threads, score, new Sample[]{small}, new double[][]{smallWeights}, prior, beta, true, false );
			NegativeDifferentiableFunction minusF = new NegativeDifferentiableFunction( f );
			StartDistanceForecaster sd =
				//new ConstantStartDistance( startD*fac );
				new LimitedMedianStartDistance( 5, startD*fac );
			for( int i = 0; i < nsf.getNumberOfRecommendedStarts(); i++ )
			{
				out.writeln( "start: " + i );
				//TODO freeParams???
				score[0].initializeFunction( 0, false, new Sample[]{small}, new double[][]{smallWeights} );
				f.reset( score );
				params = f.getParameters( KindOfParameter.PLUGIN );
				sd.reset();
				Optimizer.optimize( algo, minusF, params, tc, lineps*fac, sd, out );
				current = f.evaluateFunction( params );
				if( current > max )
				{
					best = score[0];
					max = current;
				}
				score[0] = (NormalizableScoringFunction) nsf.clone();
			}
			out.writeln( "best: " + max );
			nsf = best;
			logNorm = nsf.getLogNormalizationConstant();
			f.stopThreads();
			System.gc();
		}
		return nsf;
	}

	public double getProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException, Exception
	{
		return Math.exp( getLogProbFor( sequence, startpos, endpos ) );
	}

	public double getLogProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException, Exception
	{
		if( !isTrained() )
		{
			throw new NotTrainedException();
		}
		if( !sequence.getAlphabetContainer().checkConsistency(alphabets) )
		{
			throw new WrongAlphabetException( "The AlphabetContainer of the sequence and the model do not match." );
		}
		if( startpos < 0 )
		{
			throw new IllegalArgumentException( "Check start position." );
		}
		if( endpos+1 < startpos || endpos >= sequence.getLength() )
		{
			throw new IllegalArgumentException( "Check end position." );
		}
		if( length != 0 && length != endpos-startpos+1 )
		{
			throw new WrongLengthException( "Check length of the sequence." );
		}
		return nsf.getLogScore( sequence, startpos ) - logNorm;
	}
	
	public double getLogPriorTerm() throws Exception
	{
		return nsf.getLogPriorTerm() - nsf.getEss()*logNorm;
	}

	public String getInstanceName()
	{
		return "model using " + nsf.getInstanceName();
	}

	public boolean isTrained()
	{
		return nsf.isInitialized();
	}

	public NumericalResultSet getNumericalCharacteristics() throws Exception
	{
		return null;
	}

	public String toString()
	{
		return nsf.toString();
	}

	private static final String XML_TAG = "NormalizableScoringFunctionModel";
	
	protected void fromXML( StringBuffer xml ) throws NonParsableException
	{
		StringBuffer rep = XMLParser.extractForTag( xml, XML_TAG );
		nsf = XMLParser.extractObjectForTags( rep, "NormalizableScoringFunction", NormalizableScoringFunction.class );// TODO XMLP14CONV This and (possibly) the following lines have been converted automatically
		threads = XMLParser.extractObjectForTags( rep, "threads", int.class );
		algo = XMLParser.extractObjectForTags( rep, "algorithm", byte.class );
		if( XMLParser.hasTag( rep, "terminationCondition", null, null ) ) {
			tc = (AbstractTerminationCondition) XMLParser.extractObjectForTags( rep, "terminationCondition" );
		} else {
			try {
				tc = new SmallDifferenceOfFunctionEvaluationsCondition( XMLParser.extractObjectForTags( rep, "eps", double.class ) );
			} catch (Exception e) {
				NonParsableException n = new NonParsableException( e.getMessage() );
				throw n;
			}
		}
		lineps = XMLParser.extractObjectForTags( rep, "lineps", double.class );
		startD = XMLParser.extractObjectForTags( rep, "startDistance", double.class );
		if( isTrained() )
		{
			logNorm = nsf.getLogNormalizationConstant();
		}
		else
		{
			logNorm = Double.NEGATIVE_INFINITY;
		}
		alphabets = nsf.getAlphabetContainer();
		length = nsf.getLength();
		setOutputStream( SafeOutputStream.DEFAULT_STREAM );
	}
	
	public StringBuffer toXML()
	{
		StringBuffer xml = new StringBuffer( 100000 );
		XMLParser.appendObjectWithTags( xml, nsf, "NormalizableScoringFunction" );
		XMLParser.appendObjectWithTags( xml, threads, "threads" );
		XMLParser.appendObjectWithTags( xml, algo, "algorithm" );
		XMLParser.appendObjectWithTags( xml, tc, "tc" );
		XMLParser.appendObjectWithTags( xml, lineps, "lineps" );
		XMLParser.appendObjectWithTags( xml, startD, "startDistance" );
		XMLParser.addTags( xml, XML_TAG );
		return xml;
	}
	
	/**
	 * Sets the OutputStream that is used e.g. for writing information while training. It is possible to set
	 * <code>o=null</code>, than nothing will be written.
	 * 
	 * @param o
	 *            the OutputStream
	 */
	public final void setOutputStream( OutputStream o )
	{
		out = new SafeOutputStream( o );
	}
	
	/**
	 * Returns a copy of the internally used {@link NormalizableScoringFunction}.
	 * @return a copy of the internally used {@link NormalizableScoringFunction}
	 * @throws CloneNotSupportedException if the internal instance could not be cloned
	 */
	public NormalizableScoringFunction getFunction() throws CloneNotSupportedException
	{
		return (NormalizableScoringFunction) nsf.clone();
	}
}
