package de.jstacs.models.hmm.states.emissions.discrete;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Map;
import java.util.TreeMap;

import javax.naming.OperationNotSupportedException;

import de.jstacs.NonParsableException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.FileManager;
import de.jstacs.io.XMLParser;
import de.jstacs.models.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.models.hmm.states.emissions.SamplingEmission;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DiMRGParams;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import de.jstacs.utils.random.FastDirichletMRGParams;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;


/**
 * The abstract super class of discrete emissions.
 * 
 * @author Jens Keilwagen, Michael Scharfe, Jan Grau
 */
public abstract class AbstractConditionalDiscreteEmission  implements SamplingEmission, DifferentiableEmission {

	/**
	 * One value per emission index, used by {@link #getNodeLabel(double, String, NumberFormat)} as the
	 * first component of the background color of a table cell.
	 * NOTE(review): the values are fractions of 360, so they are presumably hue angles of an HSV color - confirm.
	 */
	private static double[] colors = new double[]{120.0/360.0,240.0/360.0,40.0/360.0,0,80.0/360.0,160.0/360.0,200.0/360.0,280.0/360.0,320.0/360.0};
	
	/**
	 * The files for saving the parameters during the sampling.
	 */
	protected File[] paramsFile;

	/**
	 * The counter for the sampling steps of each sampling.
	 */
	protected int[] counter;

	/**
	 * The index of the current sampling.
	 */
	protected int samplingIndex;

	/**
	 * The writer for the <code>paramsFile</code> in a sampling.
	 */
	protected BufferedWriter writer;

	/**
	 * The reader for the <code>paramsFile</code> after a sampling.
	 */
	protected BufferedReader reader;
	
	/**
	 * The offset of the parameter indexes
	 */
	protected int offset;
	
	/**
	 * The alphabet of the emissions
	 */
	protected AlphabetContainer con;
	
	/**
	 * The parameters of the emission
	 */
	protected double[][] params;
	
	/**
	 * The parameters transformed to probabilities
	 */
	protected double[][] probs;
	
	/**
	 * The hyper-parameters for the prior on the parameters
	 */
	protected double[][] hyperParams;
	
	/**
	 * The array for storing the statistics for 
	 * each parameter
	 */
	protected double[][] statistic;
	
	/**
	 * The array for storing the gradients for
	 * each parameter
	 */
	protected double[][] grad;
	
	/**
	 * The log-normalization constants for each condition
	 */
	protected double[] logNorm;
	
	/**
	 * The equivalent sample sizes for each condition
	 */
	protected double[] ess;

	/**
	 * The hyper-parameters for initializing the parameters
	 */
	private double[][] initHyperParams;
	
	/**
	 * The graphviz shape of the node using this emission; <code>null</code> selects the default
	 * shape computed in {@link #getNodeShape(boolean)}.
	 */
	private String shape;
	
	/**
	 * If <code>true</code>, probabilities are used unchanged when computing cell colors; otherwise
	 * a logistic transformation is applied (see <code>transformProbs</code>).
	 */
	private boolean linear;

	/**
	 * Creates the hyper-parameters for all parameters from one overall equivalent sample size (ess).
	 * The ess is first split evenly across the conditions and then, within each condition,
	 * evenly across the emissions.
	 * 
	 * @param ess the overall equivalent sample size
	 * @param numConditions the number of conditions
	 * @param numEmissions the number of emissions, assumed to be equal for all conditions
	 * @return the hyper-parameters, indexed by condition and emission
	 */
	protected static double[][] getHyperParams(double ess, int numConditions, int numEmissions){
		final double essPerCondition = ess/(double)numConditions;
		double[] conditionEss = new double[numConditions];
		for( int c = 0; c < conditionEss.length; c++ ) {
			conditionEss[c] = essPerCondition;
		}
		return getHyperParams( conditionEss, numEmissions );
	}
	
	/**
	 * Creates the hyper-parameters for given equivalent sample sizes per condition.
	 * Each row <code>i</code> holds <code>number</code> entries, all equal to <code>ess[i]/number</code>.
	 * 
	 * @param ess the equivalent sample size of each condition
	 * @param number the number of emissions per condition
	 * @return the hyper-parameters, indexed by condition and emission
	 */
	private static double[][] getHyperParams( double[] ess, int number ){
		double[][] hyper = new double[ess.length][number];
		int cond = 0;
		for( double[] row : hyper ) {
			Arrays.fill( row, ess[cond++]/(double) number );
		}
		return hyper;
	}
	
	/**
	 * This is a simple constructor for a {@link AbstractConditionalDiscreteEmission} based on the equivalent sample size.
	 * The hyper-parameters are created by {@link #getHyperParams(double, int, int)}, using the alphabet length
	 * at position 0 of <code>con</code> as the number of emissions.
	 * 
	 * @param con the {@link AlphabetContainer} of this emission
	 * @param numberOfConditions the number of conditions
	 * @param ess the equivalent sample size (ess) of this emission that is equally distributed over all parameters
	 * 
	 * @see #AbstractConditionalDiscreteEmission(AlphabetContainer, double[][])
	 */
	protected AbstractConditionalDiscreteEmission( AlphabetContainer con, int numberOfConditions, double ess ) {
		this( con, getHyperParams( ess, numberOfConditions, (int) con.getAlphabetLengthAt( 0 )));
	}

	
	/**
	 * This is a simple constructor for a {@link AbstractConditionalDiscreteEmission} defining the individual hyper parameters.
	 * The same hyper parameters are used for training and for random initialization.
	 * 
	 * @param con the {@link AlphabetContainer} of this emission
	 * @param hyperParams the individual hyper parameters for each parameter
	 * 
	 * @see #AbstractConditionalDiscreteEmission(AlphabetContainer, double[][], double[][])
	 */
	protected AbstractConditionalDiscreteEmission( AlphabetContainer con, double[][] hyperParams ) {
		this(con,hyperParams,hyperParams);
	}
	
	/**
	 * This constructor creates a {@link AbstractConditionalDiscreteEmission} defining the individual hyper parameters for the
	 * prior used during training and initialization.
	 * 
	 * @param con the {@link AlphabetContainer} of this emission
	 * @param hyperParams the individual hyper parameters for each parameter (used during training)
	 * @param initHyperParams the individual hyper parameters for each parameter used in {@link #initializeFunctionRandomly()}
	 */
	protected AbstractConditionalDiscreteEmission( AlphabetContainer con, double[][] hyperParams, double[][] initHyperParams ) {
		this.con = con;
		ess = new double[hyperParams.length];
		this.hyperParams = new double[hyperParams.length][hyperParams[0].length];
		if(hyperParams == initHyperParams){
			this.initHyperParams = this.hyperParams;
		}else{
			this.initHyperParams = new double[initHyperParams.length][initHyperParams[0].length];
		}
		for( int i = 0; i < hyperParams.length; i++ ) {
			for( int j = 0; j < hyperParams[i].length; j++ ) {
				if( hyperParams[i][j] < 0 ) {
					throw new IllegalArgumentException( "Please check the hyper-parameter (" + i + ", " + j + ")." );
				}
				this.hyperParams[i][j] = hyperParams[i][j];
				if(this.hyperParams != this.initHyperParams){
					this.initHyperParams[i][j] = initHyperParams[i][j];
				}
				ess[i] += hyperParams[i][j];
			}
		}
		params = new double[hyperParams.length][hyperParams[0].length];
		probs = new double[hyperParams.length][hyperParams[0].length];
		statistic = new double[hyperParams.length][hyperParams[0].length];
		grad = new double[hyperParams.length][hyperParams[0].length];
		logNorm = new double[hyperParams.length];
		precompute();
	}
	
	/**
	 * Creates a {@link AbstractConditionalDiscreteEmission} from its XML representation.
	 * All work is delegated to {@link #fromXML(StringBuffer)}.
	 * 
	 * @param xml the XML representation.
	 * @throws NonParsableException if the XML representation could not be parsed
	 */
	protected AbstractConditionalDiscreteEmission( StringBuffer xml ) throws NonParsableException {
		fromXML( xml );
	}
	
	/**
	 * Creates a copy of this emission whose numeric arrays (parameters, probabilities, hyper-parameters,
	 * statistics, gradients, normalization constants, equivalent sample sizes and counters) are
	 * independent of the original.
	 * 
	 * @return the clone
	 * 
	 * @throws CloneNotSupportedException if the instance or one of its arrays could not be cloned
	 */
	public AbstractConditionalDiscreteEmission clone() throws CloneNotSupportedException {
		AbstractConditionalDiscreteEmission clone = (AbstractConditionalDiscreteEmission) super.clone();
		clone.params = ArrayHandler.clone( params );
		clone.probs = ArrayHandler.clone( probs );
		clone.hyperParams = ArrayHandler.clone( hyperParams );
		// initHyperParams may be the very same array as hyperParams (see constructor); keep that
		// relationship in the clone, otherwise copy it so that the clone does not share it
		if( initHyperParams == hyperParams ) {
			clone.initHyperParams = clone.hyperParams;
		} else if( initHyperParams != null ) {
			clone.initHyperParams = ArrayHandler.clone( initHyperParams );
		}
		clone.statistic = ArrayHandler.clone( statistic );
		if(grad != null){
			// grad is two-dimensional: grad.clone() would only copy the outer array and
			// leave the rows shared between original and clone
			clone.grad = ArrayHandler.clone( grad );
		}
		if(logNorm != null){
			clone.logNorm = logNorm.clone();
		}
		if(ess != null){
			clone.ess = ess.clone();
		}
		if(counter != null){
			clone.counter = counter.clone();
		}
		// NOTE(review): paramsFile, writer and reader are still shared with the original;
		// finalize() of either instance closes the streams and deletes the files for both.
		return clone;
	}
	
	/**
	 * Sets the graphviz shape of the node that uses this emission to some non-standard value.
	 * If no shape is set (<code>null</code>), {@link #getNodeShape(boolean)} uses &quot;house&quot; for
	 * reverse complementable alphabets and &quot;box&quot; otherwise.
	 * 
	 * @param shape the shape of the node
	 */
	public void setShape(String shape){
		this.shape = shape;
	}
	
	/**
	 * Adds the gradient of the log prior term to the given gradient array. For each parameter
	 * the value <code>hyperParams - ess*probs</code> is added, starting at index
	 * <code>offset</code> plus the parameter offset of this emission.
	 * 
	 * @param gradient the array the partial derivatives are added to
	 * @param offset the external offset in <code>gradient</code>
	 */
	public void addGradientOfLogPriorTerm(double[] gradient, int offset) {
		int target = offset + this.offset;
		for( int cond = 0; cond < params.length; cond++ ) {
			final int symbols = params[cond].length;
			for( int sym = 0; sym < symbols; sym++ ) {
				gradient[target++] += hyperParams[cond][sym] - ess[cond]*probs[cond][sym];
			}
		}
	}

	/**
	 * Computes the log prior term of the current parameters. Only conditions with a positive
	 * equivalent sample size contribute <code>-ess*logNorm + sum_j hyperParams_j*params_j</code>.
	 * 
	 * @return the log prior term
	 */
	public double getLogPriorTerm() {
		double logPrior = 0;
		for( int cond = 0; cond < params.length; cond++ ) {
			if( !( ess[cond] > 0 ) ) {
				// conditions without pseudo counts do not contribute
				continue;
			}
			logPrior += -ess[cond]*logNorm[cond];
			final int symbols = params[cond].length;
			for( int sym = 0; sym < symbols; sym++ ) {
				logPrior += hyperParams[cond][sym] * params[cond][sym];
			}
		}
		return logPrior;
	}

	/**
	 * Computes the log probability of emitting the given part of the sequence and appends the
	 * partial derivatives with respect to all parameters of this emission to <code>partDer</code>
	 * (and their parameter indices to <code>indices</code>).
	 * If a negative condition index is encountered, {@link Double#NEGATIVE_INFINITY} is returned
	 * immediately and nothing is appended to the lists.
	 * 
	 * @param forward whether to use the sequence itself or its reverse complement
	 * @param startPos the start position (inclusive)
	 * @param endPos the end position (inclusive)
	 * @param indices the list the parameter indices are appended to
	 * @param partDer the list the partial derivatives are appended to
	 * @param seq the sequence
	 * 
	 * @return the log probability
	 * 
	 * @throws OperationNotSupportedException declared because <code>seq.reverseComplement()</code> is called for the reverse strand
	 */
	public double getLogProbAndPartialDerivationFor( boolean forward, int startPos, int endPos,
			IntList indices, DoubleList partDer, Sequence seq) throws OperationNotSupportedException {
		final Sequence strand;
		final int first, last;
		if( !forward ) {
			strand = seq.reverseComplement();
			final int length = strand.getLength();
			first = length - endPos -1;
			last = length - startPos -1;
		} else {
			strand = seq;
			first = startPos;
			last = endPos;
		}
		
		// start from a clean gradient
		for( double[] row : grad ) {
			Arrays.fill( row, 0 );
		}
		
		double logProb = 0;
		for( int pos = first; pos <= last; pos++ ) {
			final int cond = getConditionIndex( forward, pos, seq );
			if( cond < 0 ) {
				return Double.NEGATIVE_INFINITY;
			}
			final int symbol = strand.discreteVal( pos );
			logProb -= logNorm[cond];
			final double[] condGrad = grad[cond];
			for( int k = 0; k < condGrad.length; k++ ) {
				condGrad[k] -= probs[cond][k];
			}
			logProb += params[cond][symbol];
			condGrad[symbol]++;
		}
		
		// report all partial derivatives together with their global parameter index
		int paramIndex = offset;
		for( double[] row : grad ) {
			for( double derivative : row ) {
				indices.add( paramIndex++ );
				partDer.add( derivative );
			}
		}
		return logProb;
	}

	/**
	 * Computes the log probability of emitting the given part of the sequence.
	 * If a negative condition index is encountered, {@link Double#NEGATIVE_INFINITY} is returned.
	 * 
	 * @param forward whether to use the sequence itself or its reverse complement
	 * @param startPos the start position (inclusive)
	 * @param endPos the end position (inclusive)
	 * @param seq the sequence
	 * 
	 * @return the log probability
	 * 
	 * @throws OperationNotSupportedException declared because <code>seq.reverseComplement()</code> is called for the reverse strand
	 */
	public double getLogProbFor( boolean forward, int startPos, int endPos, Sequence seq) throws OperationNotSupportedException {
		final Sequence strand;
		final int first, last;
		if( !forward ) {
			strand = seq.reverseComplement();
			final int length = strand.getLength();
			first = length - endPos -1;
			last = length - startPos -1;
		} else {
			strand = seq;
			first = startPos;
			last = endPos;
		}
		
		double logProb = 0;
		for( int pos = first; pos <= last; pos++ ) {
			final int cond = getConditionIndex( forward, pos, seq );
			if( cond < 0 ) {
				return Double.NEGATIVE_INFINITY;
			}
			logProb -= logNorm[cond];
			logProb += params[cond][strand.discreteVal( pos )];
		}
		return logProb;
	}

	/**
	 * Initializes the parameters randomly. For each condition the probabilities are drawn using
	 * {@link DirichletMRG}: with the initial hyper-parameters of the condition, or with
	 * <code>FastDirichletMRGParams(1d)</code> if these hyper-parameters sum to zero.
	 * Afterwards the parameters are set to the logarithms of the probabilities and all
	 * log-normalization constants to zero.
	 */
	public void initializeFunctionRandomly() {
		for( int cond = 0; cond < probs.length; cond++ ) {
			double total = 0;
			for( double alpha : initHyperParams[cond] ) {
				total += alpha;
			}
			DiMRGParams drawParams;
			if( total != 0 ) {
				drawParams = new DirichletMRGParams( initHyperParams[cond] );
			} else {
				drawParams = new FastDirichletMRGParams(1d);
			}
			DirichletMRG.DEFAULT_INSTANCE.generate( probs[cond], 0, probs[cond].length, drawParams );
		}
		
		// the drawn probabilities are already normalized
		Arrays.fill( logNorm, 0 );
		for( int cond = 0; cond < params.length; cond++ ) {
			final double[] row = params[cond];
			for( int sym = 0; sym < row.length; sym++ ) {
				row[sym] = Math.log( probs[cond][sym] );
			}
		}
	}


	/**
	 * Recomputes the log-normalization constant of each condition (the log-sum of its parameters)
	 * and the resulting normalized probabilities.
	 * 
	 * @see #logNorm
	 * @see #probs
	 */
	protected void precompute() {
		Arrays.fill( logNorm, 0 );
		int cond = 0;
		for( double[] row : params ) {
			final double norm = Normalisation.getLogSum( row );
			logNorm[cond] = norm;
			for( int sym = 0; sym < row.length; sym++ ) {
				probs[cond][sym] = Math.exp( row[sym] - norm );
			}
			cond++;
		}
	}

	/**
	 * The tag enclosing the XML representation written by {@link #toXML()} and read by {@link #fromXML(StringBuffer)}.
	 */
	private static final String XML_TAG = "ConditionalDiscreteEmission"; 
	
	/**
	 * Returns the XML representation of this emission. Besides the parameters, hyper-parameters,
	 * statistics and display settings, the content of all existing parameter files of previous
	 * samplings is stored.
	 * 
	 * @return the XML representation
	 * 
	 * @throws RuntimeException if the instance is currently in sampling mode (an open writer exists)
	 *             or if a parameter file could not be read
	 */
	public StringBuffer toXML() {
		// fail fast: nothing can be stored while a sampling is running
		if( writer != null ) {
			throw new RuntimeException( "could not parse AbstractConditionalDiscreteEmission to XML while sampling" );
		}
		
		StringBuffer xml = new StringBuffer();
		XMLParser.appendObjectWithTags( xml, params, "params" );
		XMLParser.appendObjectWithTags( xml, offset, "offset" );
		XMLParser.appendObjectWithTags( xml, con, "alphabetContainer" );
		XMLParser.appendObjectWithTags( xml, hyperParams, "hyperParams" );
		XMLParser.appendObjectWithTags( xml, initHyperParams, "initHyperParams" );
		XMLParser.appendObjectWithTags( xml, statistic, "statistic" );
		XMLParser.appendObjectWithTags( xml, ess, "ess" );
		XMLParser.appendObjectWithTags( xml, shape, "shape" );
		XMLParser.appendObjectWithTags( xml, linear, "linear" );
		
		XMLParser.appendObjectWithTags( xml, paramsFile != null, "hasParameters" );
		if( paramsFile != null ) {
			String content;
			try {
				XMLParser.appendObjectWithTags( xml, counter, "counter" );
				
				for( int i = 0; i < paramsFile.length; i++ ) {
					if( paramsFile[i] != null ) {
						content = FileManager.readFile( paramsFile[i] ).toString();
					} else {
						// no file for this start: stored as empty content
						content = "";
					}
					XMLParser.appendObjectWithTagsAndAttributes( xml, content, "fileContent", "pos=\"" + i + "\"" );
				}
			} catch ( IOException e ) {
				// keep the original exception as cause instead of only copying its stack trace
				throw new RuntimeException( e.getMessage(), e );
			}
		}
		
		appendFurtherInformation( xml );
		XMLParser.addTags( xml, XML_TAG );
		return xml;
	}
	
	/**
	 * This method appends further information to the XML representation. It allows subclasses to save further parameters that are not defined in the superclass.
	 * It is called by {@link #toXML()} before the enclosing tags are added; this implementation appends nothing.
	 * 
	 * @param xml the XML representation
	 */
	protected void appendFurtherInformation( StringBuffer xml ) {
	}

	/**
	 * This method is internally used by the constructor {@link #AbstractConditionalDiscreteEmission(StringBuffer)}.
	 * It restores all fields written by {@link #toXML()}; stored parameter file contents are written to
	 * new temporary files. Representations without <code>initHyperParams</code>, <code>shape</code> or
	 * <code>linear</code> are accepted and the respective defaults are used (a copy of the hyper-parameters,
	 * <code>null</code>, and <code>false</code>).
	 * 
	 * @param xml the {@link StringBuffer} containing the xml representation of an instance
	 * 
	 * @throws NonParsableException if the {@link StringBuffer} is not parsable
	 * 
	 * @see #AbstractConditionalDiscreteEmission(StringBuffer)
	 */
	protected void fromXML( StringBuffer xml ) throws NonParsableException {
		xml = XMLParser.extractForTag( xml, XML_TAG );
		
		params = (double[][]) XMLParser.extractObjectForTags( xml, "params" );
		probs = new double[params.length][params[0].length];
		grad = new double[params.length][params[0].length];
		logNorm = new double[params.length];
		precompute();
		
		offset = (Integer) XMLParser.extractObjectForTags( xml, "offset" );
		con = (AlphabetContainer) XMLParser.extractObjectForTags( xml, "alphabetContainer" );
		hyperParams = (double[][]) XMLParser.extractObjectForTags( xml, "hyperParams" );
		try{
			initHyperParams = (double[][]) XMLParser.extractObjectForTags( xml, "initHyperParams" );
		}catch(NonParsableException e){
			// representation without separate initial hyper-parameters: fall back to a copy of the hyper-parameters
			try{
				initHyperParams = ArrayHandler.clone( hyperParams );
			}catch(CloneNotSupportedException ex){
				// do not continue silently with initHyperParams == null
				NonParsableException n = new NonParsableException( "Could not copy the hyper-parameters: " + ex.getMessage() );
				n.setStackTrace( ex.getStackTrace() );
				throw n;
			}
		}
		statistic = (double[][]) XMLParser.extractObjectForTags( xml, "statistic" );
		ess = (double[]) XMLParser.extractObjectForTags( xml, "ess" );
		// shape and linear are optional and parsed independently, so that a missing
		// "linear" entry does not discard a successfully parsed shape
		try{
			shape = XMLParser.extractObjectForTags( xml, "shape", String.class );
		}catch(NonParsableException e){
			shape = null;
		}
		try{
			linear = XMLParser.extractObjectForTags( xml, "linear",boolean.class );
		}catch(NonParsableException e){
			linear = false;
		}
		
		if( XMLParser.extractObjectForTags( xml, "hasParameters", boolean.class ) ) {
			counter = XMLParser.extractObjectForTags( xml, "counter", int[].class );
			paramsFile = new File[counter.length];
			try {
				String content;
				Map<String,String> filter = new TreeMap<String, String>();
				for( int i = 0; i < paramsFile.length; i++ ) {
					filter.clear();
					filter.put( "pos", ""+i );
					content = XMLParser.extractObjectAndAttributesForTags( xml, "fileContent", null, filter, String.class );
					// empty content marks a start without parameter file
					if( !content.equalsIgnoreCase( "" ) ) {
						paramsFile[i] = File.createTempFile( "samplingDEmission-", ".dat", null );
						FileManager.writeFile( paramsFile[i], new StringBuffer( content ) );
					}
				}
			} catch ( IOException e ) {
				NonParsableException n = new NonParsableException( e.getMessage() );
				n.setStackTrace( e.getStackTrace() );
				throw n;
			}
		} else {
			counter = null;
			paramsFile = null;
		}
		writer = null;
		reader = null;
		
		extractFurtherInformation( xml );
	}
	
	/**
	 * This method extracts further information from the XML representation. It allows subclasses to restore further parameters that are not defined in the superclass.
	 * It is called as the last step of {@link #fromXML(StringBuffer)}; this implementation extracts nothing.
	 * 
	 * @param xml the XML representation
	 *  
	 * @throws NonParsableException if the information could not be reconstructed out of the {@link StringBuffer} <code>xml</code>
	 */
	protected void extractFurtherInformation( StringBuffer xml ) throws NonParsableException {
	}

	/**
	 * Adds the given weight to the sufficient statistic of every (condition, symbol) pair observed
	 * in the given part of the sequence.
	 * NOTE(review): unlike {@link #getLogProbFor(boolean, int, int, Sequence)}, a negative condition
	 * index is not handled here.
	 * 
	 * @param forward whether to use the sequence itself or its reverse complement
	 * @param startPos the start position (inclusive)
	 * @param endPos the end position (inclusive)
	 * @param weight the weight added for each position
	 * @param seq the sequence
	 * 
	 * @throws OperationNotSupportedException declared because <code>seq.reverseComplement()</code> is called for the reverse strand
	 */
	public void addToStatistic( boolean forward, int startPos, int endPos, double weight, Sequence seq ) throws OperationNotSupportedException {
		final Sequence strand;
		final int first, last;
		if( !forward ) {
			strand = seq.reverseComplement();
			final int length = strand.getLength();
			first = length - endPos -1;
			last = length - startPos -1;
		} else {
			strand = seq;
			first = startPos;
			last = endPos;
		}
		
		for( int pos = first; pos <= last; pos++ ) {
			final int cond = getConditionIndex( forward, pos, seq);
			statistic[cond][strand.discreteVal( pos )] += weight;
		}
	}

	/**
	 * This method returns an index encoding the condition. The index is used as first index into
	 * the parameter arrays of this class. Within this class, a negative return value makes
	 * {@link #getLogProbFor(boolean, int, int, Sequence)} return {@link Double#NEGATIVE_INFINITY}.
	 * 
	 * @param forward a switch to decide whether to use the forward or the reverse complementary strand (e.g. for DNA sequences)
	 * @param seqPos the position in the sequence <code>seq</code>
	 * @param seq the sequence
	 * 
	 * @return the index encoding the condition
	 */
	protected abstract int getConditionIndex( boolean forward, int seqPos, Sequence seq );

	/**
	 * Estimates the parameters from the current sufficient statistic. For each condition the
	 * probabilities are the statistics divided by their sum; if the sum is zero, the statistics of
	 * this condition are set to 1 first (yielding the uniform distribution). The parameters become
	 * the logarithms of the probabilities and all log-normalization constants are set to zero.
	 */
	public void estimateFromStatistic() {
		for( int cond = 0; cond < statistic.length; cond++ ) {
			final double[] counts = statistic[cond];
			double total = 0;
			for( double c : counts ) {
				total += c;
			}
			if( total == 0 ) {
				// nothing observed and no pseudo counts: use uniform counts
				Arrays.fill( counts, 1 );
				total = counts.length;
			}
			for( int sym = 0; sym < counts.length; sym++ ) {
				final double p = counts[sym] / total;
				probs[cond][sym] = p;
				params[cond][sym] = Math.log( p );
			}
		}
		Arrays.fill( logNorm, 0 );
	}

	/**
	 * Resets the sufficient statistic of every condition to the hyper-parameters of that condition.
	 */
	public void resetStatistic() {
		int cond = 0;
		for( double[] prior : hyperParams ) {
			System.arraycopy( prior, 0, statistic[cond++], 0, prior.length );
		}
	}
	
	
	
	/**
	 * Returns a {@link String} representation of this emission; subclasses have to provide it.
	 * 
	 * @return a {@link String} representation of this emission
	 */
	public abstract String toString();
	
	/**
	 * Sets the parameters of this emission from the given array and recomputes the normalization
	 * constants and probabilities. Reading starts at index <code>offset</code> plus the parameter
	 * offset of this emission.
	 * 
	 * @param params the array containing the new parameters
	 * @param offset the external offset in <code>params</code>
	 */
	public void setParameter( double[] params, int offset ) {
		int source = this.offset + offset;
		for( double[] row : this.params ) {
			for( int sym = 0; sym < row.length; sym++ ) {
				row[sym] = params[source++];
			}
		}
		precompute();
	}
	
	/**
	 * Returns the alphabet of this emission.
	 * 
	 * @return the {@link AlphabetContainer} of this emission
	 */
	public AlphabetContainer getAlphabetContainer() {
		return con;
	}
	
	/**
	 * Copies the current parameters of this emission into the given array, starting at the
	 * parameter offset of this emission.
	 * 
	 * @param params the array to be filled
	 */
	public void fillCurrentParameter( double[] params ) {
		int target = offset;
		for( double[] row : this.params ) {
			for( double value : row ) {
				params[target++] = value;
			}
		}
	}
	
	/**
	 * Sets the parameter offset of this emission.
	 * 
	 * @param offset the index of the first parameter of this emission
	 * 
	 * @return the given offset increased by the number of parameters of this emission
	 */
	public int setParameterOffset( int offset ) {
		this.offset = offset;
		int next = offset;
		for( double[] row : params ) {
			next += row.length;
		}
		return next;
	}
	
    /**
     * Draws new probabilities for each condition using {@link DirichletMRG} with the current
     * statistic of that condition as parameters, sets the parameters to the logarithms of the
     * drawn probabilities and the log-normalization constants to zero.
     * 
     * @throws Exception if the parameters could not be drawn
     */
    public void drawParametersFromStatistic()  throws Exception {
    	for( int cond = 0; cond < probs.length; cond++ ) {
    		final double[] counts = statistic[cond];
    		DirichletMRG.DEFAULT_INSTANCE.generate( probs[cond], 0, counts.length, new DirichletMRGParams( counts ) );
    		for( int sym = 0; sym < counts.length; sym++ ) {
    			params[cond][sym] = Math.log(probs[cond][sym]);
    		}
    		logNorm[cond] = 0;
    	}
	}

    /**
     * Computes a score from the current statistic using log-gamma functions. With hyper-parameters
     * <code>h</code> obtained by distributing the equivalent sample size of each condition evenly
     * over the alphabet, the score is the sum over all conditions of
     * <code>lnGamma(sum_i h_i) - lnGamma(sum_i statistic_i) + sum_i ( lnGamma(statistic_i) - lnGamma(h_i) )</code>.
     * NOTE(review): the evenly distributed hyper-parameters are used here, not {@link #hyperParams};
     * both agree only for uniform hyper-parameters - confirm that this is intended.
     * 
     * @return the score summed over all conditions (0 if there is no condition)
     */
    public double getLogGammaScoreFromStatistic() {

    	double[][] hyper = getHyperParams( ess, (int) con.getAlphabetLengthAt( 0 ) );
        double res = 0;
        for(int j=0;j<ess.length;j++){
        	
        	double hyperSum = 0;
        	for (double h : hyper[j]) hyperSum += h;

        	double statSum = 0;
        	for (double s : statistic[j]) statSum += s;

        	// accumulate (not overwrite) the contribution of each condition
        	res += Gamma.logOfGamma(hyperSum) - Gamma.logOfGamma(statSum);
        	// iterate over the emissions of condition j, not over the conditions
        	for(int i = 0; i < hyper[j].length; i++) {
        		res += Gamma.logOfGamma(statistic[j][i]) - Gamma.logOfGamma(hyper[j][i]);
        	}
        }
        return res;
    }
	
	
	/**
	 * Writes the current parameters as one tab-separated line to the parameter file of the current
	 * sampling. The line starts with the current value of the sampling step counter, which is
	 * incremented afterwards. The writer is flushed.
	 * 
	 * @throws IOException if the parameters could not be written
	 */
	public void acceptParameters() throws IOException {
		StringBuilder line = new StringBuilder();
		line.append( counter[samplingIndex]++ );
		for( double[] row : params ) {
			for( double value : row ) {
				line.append( "\t" ).append( value );
			}
		}
		line.append( "\n" );
		writer.write( line.toString() );
		writer.flush();
	}
	
	/**
	 * Computes the sum over all parameters of the statistic multiplied by the normalized
	 * log-parameter (<code>params - logNorm</code>).
	 * 
	 * @return the weighted sum of normalized log-parameters
	 */
	public double getLogPosteriorFromStatistic() {
		double logPosterior = 0;
		for( int cond = 0; cond < params.length; cond++ ) {
			final double[] row = params[cond];
			for( int sym = 0; sym < row.length; sym++ ) {
				logPosterior += statistic[cond][sym] * (row[sym]-logNorm[cond]);
			}
		}
		return logPosterior;
	}

	/**
	 * Prepares the sampling with the given index. If no parameter file exists for this sampling,
	 * a temporary file is created. Otherwise, when appending, the last stored parameter set is
	 * loaded; when not appending, the step counter is reset. Finally a writer on the parameter
	 * file is opened, which puts this instance into sampling mode.
	 * 
	 * @param start the index of the sampling
	 * @param append whether new parameter sets are appended to the existing file
	 * 
	 * @throws IOException if the file could not be created, read or opened
	 */
	public void extendSampling(int start, boolean append) throws IOException {
		File target = paramsFile[start];
		if( target == null ) {
			target = File.createTempFile( "samplingDEmission-", ".dat", null );
			paramsFile[start] = target;
		} else if( append ) {
			// continue from the last accepted parameter set
			parseParameterSet( start, counter[start] - 1 );
			reader.close();
			reader = null;
		} else {
			counter[start] = 0;
		}
		writer = new BufferedWriter( new FileWriter( target, append ) );
		samplingIndex = start;
	}

	/**
	 * Prepares this emission for the given number of samplings. If parameter files for exactly this
	 * number of samplings already exist, the existing files are emptied and the counters are reset;
	 * otherwise all existing parameter files are deleted and new (empty) bookkeeping arrays are created.
	 * 
	 * @param starts the number of samplings
	 * 
	 * @throws IOException if an existing parameter file could not be emptied
	 * @throws IllegalArgumentException if a hyper-parameter that is not NaN is not positive
	 */
	public void initForSampling( int starts ) throws IOException {

		for( int cond = 0; cond < hyperParams.length; cond++ ) {
			for( double alpha : hyperParams[cond] ) {
				if(!Double.isNaN(alpha) && alpha <= 0){
					throw new IllegalArgumentException( "All (not NAN) hyper-parameters must have a value > 0. Please check the hyper-parameter " + cond + "." );
				}
			}
		}
		
		if( paramsFile == null || paramsFile.length != starts ) {
			deleteParameterFiles();
			paramsFile = new File[starts];
			counter = new int[starts];
			return;
		}
		
		for( int start = 0; start < starts; start++ ) {
			if( paramsFile[start] != null ) {
				// opening and closing an output stream truncates the file
				FileOutputStream truncate = new FileOutputStream( paramsFile[start] );
				truncate.close();
			}
			counter[start] = 0;
		}
	}
	
	/**
	 * Deletes all existing parameter files of this emission. The array of files itself is kept.
	 */
	private void deleteParameterFiles() {
		if( paramsFile == null ) {
			return;
		}
		for( File file : paramsFile ) {
			if( file != null ) {
				file.delete();
			}
		}
	}

	/**
	 * Indicates whether a sampling is currently running, i.e., whether a writer for a parameter file is open.
	 * 
	 * @return <code>true</code> if a writer is open, <code>false</code> otherwise
	 */
	public boolean isInSamplingMode() {
		return writer != null;
	}

	/**
	 * Reads the next line from the currently open parameter file and sets the parameters accordingly.
	 * 
	 * @return <code>true</code> if a parameter set was read; <code>false</code> if the instance is in
	 *         sampling mode, no reader is open, the end of the file is reached or the file could not be read
	 */
	public boolean parseNextParameterSet() {
		// nothing can be read while sampling or before a parameter file was opened
		if( writer != null || reader == null ) {
			return false;
		}
		String str;
		try {
			str = reader.readLine();
		} catch ( IOException e ) {
			// an unreadable file is reported like the end of the file
			return false;
		}
		if( str == null ) {
			return false;
		}

		parse( str );
		return true;
	}

	/**
	 * Opens the parameter file of the given sampling and searches for the line whose leading
	 * number (the text before the first tab) equals <code>n</code>. If found, the parameters are set
	 * from this line. A previously open reader is closed; the new reader stays open so that
	 * {@link #parseNextParameterSet()} can continue reading.
	 * 
	 * @param start the index of the sampling
	 * @param n the number of the parameter set
	 * 
	 * @return <code>true</code> if the parameter set was found, <code>false</code> otherwise
	 * 
	 * @throws IOException if the file could not be read
	 */
	public boolean parseParameterSet(int start, int n) throws IOException {
		if( reader != null ) {
			reader.close();
		}
		reader = new BufferedReader( new FileReader( paramsFile[start] ) );
		for( String line = reader.readLine(); line != null; line = reader.readLine() ) {
			final int number = Integer.parseInt( line.substring( 0, line.indexOf( "\t" ) ) );
			if( number == n ) {
				parse( line );
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Sets the parameters from one tab-separated line of a parameter file. The first field (the
	 * number of the parameter set) is skipped; the remaining fields are read as parameters in
	 * row-major order. Probabilities are set to the exponentials of the parameters and the
	 * log-normalization constants to zero.
	 * 
	 * @param str the line of the parameter file
	 */
	private void parse( String str ) {
		final String[] fields = str.split( "\t" );
		int next = 1;
		for( int cond = 0; cond < params.length; cond++ ) {
			final double[] row = params[cond];
			for( int sym = 0; sym < row.length; sym++ ) {
				final double value = Double.parseDouble(fields[next++]);
				row[sym] = value;
				probs[cond][sym] = Math.exp(value);
			}
			logNorm[cond] = 0;
		}
	}

	/**
	 * Ends the current sampling by closing the writer of the parameter file, if one is open.
	 * 
	 * @throws IOException if the writer could not be closed
	 */
	public void samplingStopped() throws IOException {
		if( writer == null ) {
			return;
		}
		writer.close();
		writer = null;
	}
	
	/**
	 * Closes the open writer and reader and deletes the parameter files. Each clean-up step is
	 * executed even if a previous step fails.
	 * 
	 * @throws Throwable if one of the clean-up steps fails
	 */
	protected void finalize() throws Throwable {
		try {
			try {
				if( writer != null ) {
					writer.close();
				}
			} finally {
				// a failing writer must not keep the reader open
				if( reader != null ) {
					reader.close();
				}
			}
		} finally {
			try {
				deleteParameterFiles();
			} finally {
				super.finalize();
			}
		}
	}
	
	/**
	 * Returns the graphviz shape of the node using this emission. A shape set via
	 * {@link #setShape(String)} is returned in quotes. Otherwise, &quot;house&quot; with an
	 * orientation of -90 (forward) or 90 is used for reverse complementable alphabets and
	 * &quot;box&quot; for all others.
	 * 
	 * @param forward whether the forward strand is considered
	 * 
	 * @return the shape specification
	 */
	@Override
	public String getNodeShape(boolean forward) {
		if( shape != null ) {
			return "\""+shape+"\"";
		}
		if( !getAlphabetContainer().isReverseComplementable() ) {
			return "\"box\"";
		}
		return "\"house\", orientation=" + ( forward ? "-" : "" ) + "90";
	}

	/**
	 * Returns the graphviz label of the node using this emission. For a negative weight the label is
	 * just the quoted name. Otherwise an HTML-like table is created with one row per condition and
	 * one cell per emission: without a number format the cells are colored according to the
	 * probabilities, with a number format the cells contain the formatted probabilities.
	 * 
	 * @param weight the weight of the node; negative values yield the plain name
	 * @param name the name of the node
	 * @param nf the number format for the probabilities, or <code>null</code> for colored cells
	 * 
	 * @return the label
	 */
	@Override
	public String getNodeLabel( double weight, String name, NumberFormat nf ) {
		if(weight < 0){
			return "\""+name+"\"";
		}else{
			StringBuffer buf = new StringBuffer();
			if(nf == null){
				String namelabel = name;
				if(weight < 0.5){
					namelabel = "<font color=\"white\">"+namelabel+"</font>";
				}
				buf.append( "<<table border=\"0\" cellspacing=\"0\"><tr><td colspan=\""+probs[0].length+"\">"+namelabel+"</td></tr>" );
				for(int i=0;i<probs.length;i++){
					buf.append( "<tr>" );
					double[] trans =  transformProbs( probs[i] );
					double en = (getInformationContent( probs[i] )+2.0)/3.0;
					for(int j=0;j<probs[i].length;j++){
						buf.append( "<td border=\"1\" width=\"25\" height=\"25\"" );
						// only a fixed number of colors is defined; reuse them cyclically for
						// larger alphabets instead of running out of bounds
						buf.append( " bgcolor=\""+colors[j % colors.length]+" "+trans[j]+" "+en+"\"" );
						buf.append( "></td>" );
					}
					buf.append( "</tr>" );
				}
				buf.append( "</table>>" );
			}else{
				buf.append( "<<table border=\"0\" cellspacing=\"0\"><tr><td colspan=\""+probs[0].length+"\">"+name+"</td></tr>" );
				for(int i=0;i<probs.length;i++){
					buf.append( "<tr>" );
					for(int j=0;j<probs[i].length;j++){
						buf.append( "<td border=\"1\">" );
						buf.append( nf.format( probs[i][j]) );
						buf.append( "</td>" );
					}
					buf.append( "</tr>" );
				}
				buf.append( "</table>>" );
			}
			return buf.toString();
		}
		
	}

	
	/**
	 * If set to <code>true</code>, the probabilities are mapped to colors directly; otherwise
	 * a logistic mapping is used to emphasize deviations from the uniform distribution.
	 * 
	 * @param linear whether to map the probabilities linearly
	 */
	public void setLinear( boolean linear ) {
		this.linear = linear;
	}
	
	/**
	 * Computes <code>(max - entropy)/max</code> for the given distribution, where <code>entropy</code>
	 * is its entropy (natural logarithm) and <code>max</code> is the logarithm of the number of
	 * probabilities. Entries that are not positive are ignored.
	 * 
	 * @param probs the probabilities
	 * 
	 * @return the entropy reduction relative to <code>max</code>
	 */
	private double getInformationContent(double[] probs){
		final double maxEntropy = Math.log( probs.length );
		double entropy = 0.0;
		for( double p : probs ) {
			if( p > 0 ) {
				entropy -= p * Math.log( p );
			}
		}
		return (maxEntropy-entropy)/maxEntropy;
	}
	
	/**
	 * Transforms the probabilities for the color mapping. In linear mode a copy of the probabilities
	 * is returned; otherwise each probability <code>p</code> is mapped to
	 * <code>1/(1+exp(-15*p+4))</code>.
	 * 
	 * @param probs the probabilities
	 * 
	 * @return a new array containing the transformed probabilities
	 */
	private double[] transformProbs(double[] probs){
		if(linear){
			return probs.clone();
		}
		final double slope=15, shift=4;
		double[] mapped = new double[probs.length];
		int idx = 0;
		for( double p : probs ) {
			mapped[idx++] = 1.0/(1.0+Math.exp( -slope*p + shift ));
		}
		return mapped;
	}

	/**
	 * Adds one group of parameter indices per condition to the given list. The indices of the
	 * parameters are consecutive and start at <code>parameterOffset</code> plus the parameter
	 * offset of this emission.
	 * 
	 * @param parameterOffset the external offset added to each index
	 * @param list the list the groups are added to
	 */
	@Override
	public void fillSamplingGroups( int parameterOffset, LinkedList<int[]> list ) {
		int next = offset + parameterOffset;
		for( double[] row : params ) {
			int[] group = new int[row.length];
			for( int k = 0; k < group.length; k++ ) {
				group[k] = next++;
			}
			list.add( group );
		}
	}

	/**
	 * Returns the number of parameters, computed as the number of conditions times the number of
	 * parameters of the first condition (i.e., all conditions are assumed to have the same length).
	 * 
	 * @return the number of parameters
	 */
	@Override
	public int getNumberOfParameters() {
		return params.length*params[0].length;
	}

	/**
	 * Returns the size of the event space, computed as the number of conditions times the number of
	 * parameters of the first condition (the same value as {@link #getNumberOfParameters()}).
	 * 
	 * @return the size of the event space
	 */
	@Override
	public int getSizeOfEventSpace() {
		return params.length*params[0].length;
	}
	
	
	
}
