Implementation of a homogeneous Markov model of order 0 based on AbstractModel

From Jstacs
Revision as of 13:37, 5 September 2008 by Grau (talk | contribs)
(diff) ← Older revision | Latest revision (diff) | Newer revision → (diff)
Jump to navigation | Jump to search
The printable version is no longer supported and may have rendering errors. Please update your browser bookmarks and please use the default browser print function instead.
import java.util.Arrays;

import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.XMLParser;
import de.jstacs.models.AbstractModel;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;



/**
 * Homogeneous Markov model of order 0: every position of a sequence is scored
 * independently with the same per-symbol probability.
 *
 * <p>The parameters are stored as natural-log probabilities and are estimated
 * from (optionally weighted) symbol counts by {@link #train(Sample, double[])}.
 */
public class HomogeneousMarkovModel extends AbstractModel {

	/** Log-probability of each symbol, indexed by the symbol's discrete value. */
	private double[] logProbs;
	/** Whether the parameters in {@link #logProbs} come from a completed training run. */
	private boolean isTrained;

	/**
	 * Creates an untrained model over the given alphabet.
	 *
	 * @param alphabets the alphabet container; must be simple and discrete
	 * @throws IllegalArgumentException if the alphabet is not simple and discrete
	 * @throws Exception declared for whatever the superclass constructor may throw
	 */
	public HomogeneousMarkovModel( AlphabetContainer alphabets ) throws Exception {
		super( alphabets, 0 ); //we have a homogeneous model, hence the length is set to 0
		//a homogeneous model can only handle simple alphabets
		if( !( alphabets.isSimple() && alphabets.isDiscrete() ) ) {
			throw new IllegalArgumentException( "Only simple and discrete alphabets allowed" );
		}
		//one parameter per symbol of the (single) alphabet
		this.logProbs = new double[(int) alphabets.getAlphabetLengthAt( 0 )];
		this.isTrained = false; //we have not trained the model, yet
	}

	/**
	 * Restores a model from its XML representation as produced by {@link #toXML()}.
	 *
	 * @param stringBuff the XML representation
	 * @throws NonParsableException if the XML could not be parsed
	 */
	public HomogeneousMarkovModel( StringBuffer stringBuff ) throws NonParsableException {
		super( stringBuff );
	}

	/**
	 * Reads all fields of this model from the given XML.
	 * Presumably invoked by the superclass's {@code StringBuffer} constructor;
	 * that call is not visible in this file.
	 *
	 * @param xml XML containing a {@code homogeneousMarkovModel} element
	 * @throws NonParsableException if a required tag is missing or malformed
	 */
	protected void fromXML( StringBuffer xml ) throws NonParsableException {
		//extract our XML-code
		xml = XMLParser.extractForTag( xml, "homogeneousMarkovModel" );
		//extract all the variables using XMLParser
		alphabets = (AlphabetContainer) XMLParser.extractStorableForTag( xml, "alphabets" );
		length = XMLParser.extractIntForTag( xml, "length" );
		logProbs = XMLParser.extractDoubleArrayForTag( xml, "logProbs" );
		isTrained = XMLParser.extractBooleanForTag( xml, "isTrained" );
	}

	/**
	 * Serializes the alphabet, length, parameters and training flag, wrapped in a
	 * {@code homogeneousMarkovModel} element.
	 *
	 * @return the XML representation of this model
	 */
	public StringBuffer toXML() {
		StringBuffer buf = new StringBuffer();
		//pack all the variables using XMLParser
		XMLParser.appendStorableWithTags( buf, alphabets, "alphabets" );
		XMLParser.appendIntWithTags( buf, length, "length" );
		XMLParser.appendDoubleArrayWithTags( buf, logProbs, "logProbs" );
		XMLParser.appendBooleanWithTags( buf, isTrained, "isTrained" );
		//add our own tag
		XMLParser.addTags( buf, "homogeneousMarkovModel" );
		return buf;
	}

	/**
	 * @return a human-readable name of this model type
	 */
	public String getInstanceName() {
		return "Homogeneous Markov model of order 0";
	}

	/**
	 * @return always 0, because the parameters are maximum-likelihood estimates
	 *         and no prior is used
	 */
	public double getLogPriorTerm() throws Exception {
		//we use ML-estimation, hence no prior term
		return 0;
	}

	/**
	 * @return a result set holding a single entry: the number of stored parameters
	 */
	public NumericalResultSet getNumericalCharacteristics() throws Exception {
		//we do not have much to tell here
		return new NumericalResultSet( new NumericalResult( "Number of parameters", "The number of parameters this model uses", logProbs.length ) );
	}

	/**
	 * Computes the log-probability of the symbols of {@code sequence} between
	 * {@code startpos} and {@code endpos} (both inclusive) as the sum of the
	 * per-symbol log-probabilities. An empty range ({@code endpos < startpos})
	 * yields 0.
	 *
	 * @param sequence the sequence to score
	 * @param startpos index of the first scored position
	 * @param endpos index of the last scored position (inclusive)
	 * @return the summed log-probability
	 * @throws NotTrainedException if the model has not been trained yet
	 */
	public double getLogProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException, Exception {
		//an untrained model holds all-zero parameters, which would silently
		//score every sequence with probability 1
		if( !isTrained ) {
			throw new NotTrainedException();
		}
		double seqLogProb = 0.0;
		for( int i = startpos; i <= endpos; i++ ) {
			//directly access the array by the numerical representation of the symbols
			seqLogProb += logProbs[sequence.discreteVal( i )];
		}
		return seqLogProb;
	}

	/**
	 * Computes the probability of the symbols of {@code sequence} between
	 * {@code startpos} and {@code endpos} (both inclusive).
	 *
	 * @param sequence the sequence to score
	 * @param startpos index of the first scored position
	 * @param endpos index of the last scored position (inclusive)
	 * @return {@code exp} of {@link #getLogProbFor(Sequence, int, int)}
	 * @throws NotTrainedException if the model has not been trained yet
	 */
	public double getProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException, Exception {
		return Math.exp( getLogProbFor( sequence, startpos, endpos ) );
	}

	/**
	 * @return {@code true} if the model holds trained parameters
	 */
	public boolean isTrained() {
		return isTrained;
	}

	/**
	 * Estimates the symbol probabilities as normalized, weighted symbol counts
	 * over all positions of all sequences in {@code data}.
	 *
	 * <p>The new parameters are computed in a scratch array and only installed
	 * once training has succeeded, so a failed call leaves the model in its
	 * previous state. Symbols that never occur receive log-probability
	 * {@code -Infinity}.
	 *
	 * @param data the training sequences
	 * @param weights one weight per sequence, or {@code null} to weight every
	 *            sequence with 1
	 * @throws IllegalArgumentException if {@code weights} has fewer entries than
	 *             {@code data} has sequences, or if the total weighted symbol
	 *             count is not positive (e.g. empty data)
	 */
	public void train( Sample data, double[] weights ) throws Exception {
		int numberOfSequences = data.getNumberOfElements();
		if( weights != null && weights.length < numberOfSequences ) {
			throw new IllegalArgumentException( "Expected at least " + numberOfSequences
					+ " weights but got " + weights.length );
		}

		//accumulate in a scratch array so that a failure cannot corrupt the current parameters
		double[] counts = new double[logProbs.length];
		for( int i = 0; i < numberOfSequences; i++ ) {
			Sequence seq = data.getElementAt( i );
			//if we do have any weights, use them; otherwise each sequence counts once
			double w = weights == null ? 1.0 : weights[i];
			int seqLength = seq.getLength();
			for( int j = 0; j < seqLength; j++ ) {
				//count symbols, weighted by weights
				counts[seq.discreteVal( j )] += w;
			}
		}

		//compute normalization
		double norm = 0.0;
		for( int i = 0; i < counts.length; i++ ) {
			norm += counts[i];
		}
		//written as !(norm > 0) so that NaN is rejected as well
		if( !( norm > 0.0 ) ) {
			throw new IllegalArgumentException( "Cannot train: the total weighted symbol count is " + norm );
		}

		//normalize counts to proper probabilities and move to log-space
		for( int i = 0; i < counts.length; i++ ) {
			counts[i] = Math.log( counts[i] / norm );
		}

		//install the new parameters; now the model is trained
		logProbs = counts;
		isTrained = true;
	}

}