/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.btMeasures;

import de.jstacs.NonParsableException;
import de.jstacs.algorithms.graphs.MST;
import de.jstacs.data.Sample;
import de.jstacs.io.XMLParser;
import de.jstacs.scoringFunctions.directedGraphicalModels.BayesianNetworkScoringFunction;
import de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.Measure;

/**
 * Structure learning {@link Measure} that computes a maximum spanning tree based on mutual information and uses the resulting
 * tree structure as structure of a Bayesian tree (special case of a Bayesian network) in a {@link BayesianNetworkScoringFunction}.
 * 
 * @author Jan Grau
 *
 */
public class BTMutualInformation extends Measure {

	/**
	 * The data set(s) used for computing the mutual information, one of {@link #FG}, {@link #BG}, {@link #BOTH}.
	 */
	private int clazz;
	/**
	 * The equivalent sample sizes; index 0 is used for the foreground, index 1 for the background.
	 */
	private double[] ess;
	
	/**
	 * Compute mutual information only from foreground data
	 */
	public static final int FG=0;
	/**
	 * Compute mutual information only from background data
	 */
	public static final int BG=1;
	/**
	 * Use both data sets to compute the mutual information
	 */
	public static final int BOTH=2;
	
	/**
	 * Re-creates a {@link BTMutualInformation} from its XML-representation as returned by {@link BTMutualInformation#toXML()}.
	 * @param buf the XML-representation
	 * @throws NonParsableException is thrown if the XML-code could not be parsed
	 */
	public BTMutualInformation(StringBuffer buf) throws NonParsableException{
		buf = XMLParser.extractForTag(buf, "btMutualInformation");
		clazz = XMLParser.extractIntForTag(buf, "clazz");
		// NOTE(review): clazz and ess are not validated here, unlike in the other constructor -- consider
		// rejecting out-of-range values with a NonParsableException.
		ess = XMLParser.extractDoubleArrayForTag(buf, "ess");
	}
	
	/**
	 * Creates a new mutual information Bayesian tree {@link Measure}.
	 * @param clazz the classes used for computation of mutual information, one of {@link BTMutualInformation#FG}, {@link BTMutualInformation#BG}, {@link BTMutualInformation#BOTH}
	 * @param ess the equivalent sample sizes for both classes; index 0 is used for the foreground, index 1 for the background
	 *            (at least one entry is required for {@link BTMutualInformation#FG}, at least two otherwise). The array is copied.
	 * @throws Exception thrown (as an {@link IllegalArgumentException}) if <code>clazz</code> is not one of the allowed values
	 *                   or if <code>ess</code> is <code>null</code> or too short for <code>clazz</code>
	 */
	public BTMutualInformation(int clazz, double[] ess) throws Exception {
		if(clazz != FG && clazz != BG && clazz != BOTH){
			throw new IllegalArgumentException("Value of clazz not allowed.");
		}
		// getParents reads ess[0] for FG, ess[1] for BG and both entries for BOTH; fail early instead of
		// with an ArrayIndexOutOfBoundsException during structure learning.
		int required = (clazz == FG) ? 1 : 2;
		if(ess == null || ess.length < required){
			throw new IllegalArgumentException("ess must contain at least " + required + " value(s) for the chosen clazz.");
		}
		this.clazz = clazz;
		this.ess = ess.clone();
	}
	
	/**
	 * Returns a copy of this measure; the equivalent sample size array is deep-copied so the clone does not share it.
	 * @return the clone
	 * @throws CloneNotSupportedException if the super class does not support cloning
	 */
	@Override
	public BTMutualInformation clone() throws CloneNotSupportedException{
		BTMutualInformation clone = (BTMutualInformation) super.clone();
		clone.ess = ess.clone();
		return clone;
	}

	/**
	 * Returns a human-readable name that states which data set(s) the mutual information is computed from.
	 * @return the instance name
	 */
	@Override
	public String getInstanceName() {
		String str = "Bayesian tree with mutual information of";
		if(clazz == FG){
			return str+" foreground";
		}else if(clazz == BG){
			return str+" background";
		}else{
			return str+" foreground and background";
		}
	}

	/**
	 * Learns a tree structure from the mutual information between positions and returns it as parent arrays.
	 * <p>
	 * Depending on <code>clazz</code>, the foreground data, the background data, or the union of both (with the summed
	 * equivalent sample sizes) is used. A spanning tree is computed over the mutual information matrix.
	 * Every edge <code>{u, v}</code> of that tree makes <code>u</code> the parent of <code>v</code>.
	 * </p>
	 * 
	 * @param fg the foreground data
	 * @param bg the background data
	 * @param weightsFg the weights of the foreground sequences
	 * @param weightsBg the weights of the background sequences
	 * @param length the number of positions
	 * @return an array of length <code>length</code>; entry <code>i</code> is either <code>{i}</code> (no parent)
	 *         or <code>{parent, i}</code>, i.e. the position itself is always the last element
	 * @throws Exception if one of the underlying computations (data union, statistics, mutual information, spanning tree) fails
	 */
	@Override
	public int[][] getParents(Sample fg, Sample bg, double[] weightsFg, double[] weightsBg, int length) throws Exception {
		Sample data;
		double[] weights;
		double ess2;
		if(clazz == FG){
			data = fg;
			weights = weightsFg;
			ess2 = ess[0];
		}else if(clazz == BG){
			data = bg;
			weights = weightsBg;
			ess2 = ess[1];
		}else{
			data = Sample.union(fg,bg);
			weights = union(new double[][]{weightsFg,weightsBg});
			ess2 = ess[0] + ess[1];
		}
		double[][][][] stat = getStatistics(data,weights, length, ess2);
		double[][] mi = getMI(stat, sum(weights) + ess2);
		
		// NOTE(review): the class documentation speaks of a *maximum* spanning tree while MST.kruskal is used here;
		// presumably getMI returns values oriented accordingly -- confirm against Measure#getMI.
		int[][] edges = MST.kruskal( mi );
		
		// Start with every position being parentless; its own index is the only entry.
		int[][] parents = new int[length][];
		for(int i=0;i<length;i++){
			parents[i] = new int[]{i};
		}
		// NOTE(review): assumes MST.kruskal orients the edges so that each position occurs at most once as
		// edges[i][1]; otherwise a later edge would overwrite an earlier parent -- confirm.
		for(int[] edge : edges){
			int child = edge[1];
			parents[child] = new int[]{edge[0], child};
		}
		return parents;
	}

	/**
	 * Returns the XML-representation of this measure: <code>clazz</code> and <code>ess</code> wrapped in a
	 * <code>btMutualInformation</code> tag, as parsed by {@link #BTMutualInformation(StringBuffer)}.
	 * @return the XML-representation
	 */
	public StringBuffer toXML() {
		StringBuffer buf = new StringBuffer();
		XMLParser.appendIntWithTags(buf, clazz, "clazz");
		XMLParser.appendDoubleArrayWithTags(buf, ess, "ess");
		XMLParser.addTags(buf, "btMutualInformation");
		return buf;
	}

}
