MiMB Example: Difference between revisions
From Jstacs
Jump to navigationJump to search
No edit summary |
mNo edit summary |
||
Line 54: | Line 54: | ||
//read foreground and background data set form FastA-files | //read foreground and background data set form FastA-files | ||
Sample fgData = new DNASample(home+" | Sample fgData = new DNASample(home+"foreground.fa"); | ||
Sample bgData = new DNASample(home+" | Sample bgData = new DNASample(home+"background.fa"); | ||
/* generative part */ | /* generative part */ |
Revision as of 21:25, 2 December 2009
This class contains all code snippets used in the chapter Probabilistic Approaches to Transcription Factor Binding Site Prediction of the book Computational Biology of Transcription Factors of the series Methods in Molecular Biology:
package supplementary.codeExamples;
import java.io.PrintWriter;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.classifier.AbstractScoreBasedClassifier;
import de.jstacs.classifier.ConfusionMatrix;
import de.jstacs.classifier.MeasureParameters;
import de.jstacs.classifier.AbstractScoreBasedClassifier.DoubleTableResult;
import de.jstacs.classifier.MeasureParameters.Measure;
import de.jstacs.classifier.assessment.RepeatedHoldOutAssessParameterSet;
import de.jstacs.classifier.assessment.RepeatedHoldOutExperiment;
import de.jstacs.classifier.modelBased.ModelBasedClassifier;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction.KindOfParameter;
import de.jstacs.classifier.scoringFunctionBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifier.scoringFunctionBased.logPrior.CompositeLogPrior;
import de.jstacs.classifier.scoringFunctionBased.msp.MSPClassifier;
import de.jstacs.data.DNASample;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.models.Model;
import de.jstacs.models.VariableLengthWrapperModel;
import de.jstacs.models.discrete.inhomogeneous.BayesianNetworkModel;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner.LearningType;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner.ModelType;
import de.jstacs.models.discrete.inhomogeneous.parameters.BayesianNetworkModelParameterSet;
import de.jstacs.results.ListResult;
import de.jstacs.results.ResultSet;
import de.jstacs.scoringFunctions.directedGraphicalModels.BayesianNetworkScoringFunction;
import de.jstacs.scoringFunctions.directedGraphicalModels.BayesianNetworkScoringFunctionParameterSet;
import de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.InhomogeneousMarkov;
import de.jstacs.utils.REnvironment;
/**
* This class implements a main that shows some features of Jstacs including models for generative and discriminative learning,
* creation of classifiers, evaluation of classifiers, hold-out sampling, and binding site prediction.
*
* @author Jan Grau, Jens Keilwagen
*/
public class MiMBExample {
/**
* @param args only the first parameter will be used; it determines the home directory
*/
public static void main( String[] args ) throws Exception {
String home = args[0];
/* read data */
//read foreground and background data set form FastA-files
Sample fgData = new DNASample(home+"foreground.fa");
Sample bgData = new DNASample(home+"background.fa");
/* generative part */
//create set of parameters for foreground model
BayesianNetworkModelParameterSet pars = new BayesianNetworkModelParameterSet(
fgData.getAlphabetContainer(),//used alphabets
fgData.getElementLength(),//element length == sequence length of each sequence in the sample
4,//ESS == equivalent sample size (has to be non-negative)
"fg model",//user description of the model
ModelType.IMM,//type of statistical model, here an inhomogeneous Markov model (IMM)
(byte)0,// model order, here 0, so we get an IMM(0) == PWM = position weight matrix
LearningType.ML_OR_MAP//how to learn the parameters, depends on ESS; for ESS=0 it is ML otherwise MAP
);
//create foreground model from these parameters
Model fgModel = new BayesianNetworkModel(pars);
//analogously, create the background model
BayesianNetworkModelParameterSet pars2 = new BayesianNetworkModelParameterSet(
fgData.getAlphabetContainer(),
fgData.getElementLength(),
1024,
"bg model",
ModelType.IMM,
(byte)0,
LearningType.ML_OR_MAP
);
Model bgModel = new BayesianNetworkModel(pars2);
bgModel = new VariableLengthWrapperModel(bgModel);
//create generative classifier from the models defined before
ModelBasedClassifier cl = new ModelBasedClassifier(fgModel, bgModel);
cl.train( fgData, bgData );
/* discriminative part */
//create set of parameters for foreground scoring function
BayesianNetworkScoringFunctionParameterSet parsD = new BayesianNetworkScoringFunctionParameterSet(
fgData.getAlphabetContainer(),//used alphabets
fgData.getElementLength(),//element length == sequence length of each sequence in the sample
4,//ESS == equivalent sample size (has to be non-negative)
true,//use MAP-parameters as start values of the optimization
new InhomogeneousMarkov(0) //the statistical model, here an IMM(0) == PWM
);
//create foreground scoring function from these parameters
BayesianNetworkScoringFunction fgFun = new BayesianNetworkScoringFunction(parsD);
//analogously, create the background scoring function
BayesianNetworkScoringFunctionParameterSet parsDbg = new BayesianNetworkScoringFunctionParameterSet(
fgData.getAlphabetContainer(),
fgData.getElementLength(),
1024,
true,
new InhomogeneousMarkov(0));
BayesianNetworkScoringFunction bgFun = new BayesianNetworkScoringFunction(parsDbg);
//create set of parameter for the discriminative classifier
GenDisMixClassifierParameterSet clPars = new GenDisMixClassifierParameterSet(
fgData.getAlphabetContainer(),//used alphabets
fgData.getElementLength(),//element length == sequence length of each sequence in the samples
Optimizer.QUASI_NEWTON_BFGS,//determines the algorithm for numerical optimization
1E-6,//epsilon to stop the numerical optimization
1E-6,//epsilon to stop the line search within the numerical optimization
1,//start step width in the numerical optimization
false,//a switch that decides whether to use only the free parameters or all parameters
KindOfParameter.PLUGIN,//a switch to decide which start parameters to choose
true,//a switch that states the objective function will be normalized
1//number of threads used during optimization
);
//create discriminative classifier from the parameters and the scoring function defined before
MSPClassifier cll = new MSPClassifier(
clPars,//the parameters of the classifier
new CompositeLogPrior(),//the used prior, to obtain MCL use null
fgFun,//the scoring function for the foreground class
bgFun//the scoring function for the background class
);
//train the discriminative classifier
cll.train( fgData, bgData );
/* performance measures */
//partition data
Sample[] fgSplit = bisect( fgData, fgData.getElementLength() );
Sample[] bgSplit = bisect( bgData, fgData.getElementLength() );
Sample fgTest = fgSplit[1];
Sample bgTest = bgSplit[1];
//train the generative classifier
cl.train( fgSplit[0], bgSplit[0] );
//fill a confusion matrix
ConfusionMatrix confMatrix = cl.test( fgTest, bgTest );
//read the entries of the table
double tp = confMatrix.getCountsFor(0, 0);
double fn = confMatrix.getCountsFor(1, 0);
double tn = confMatrix.getCountsFor(1, 1);
double fp = confMatrix.getCountsFor(0, 1);
double p = tp+fn;
double barp = tp+fp;
double n = tn+fp;
double barn = tn+fn;
System.out.println("TP = "+tp+"\t\tFP = "+fp+"\t\tbarp = "+barp+"\n" +
"FN = "+fn+"\t\tTN = "+tn+"\t\tbarn = "+barn+"\n" +
"p = "+p+"\t\tn = "+n+"\t\tN' = "+(n+p));
//compute the measures
double sn = tp/p;
double ppv = tp/barp;
double fpr = fp/n;
double sp = tn/n;
double cr = (tp+tn)/(n+p);
System.out.println("cr = "+cr+"\nSn = "+sn+"\nppv = "+ppv+"\nSp = "+sp+"\nfpr = "+fpr+"\n");
//define the measures that shall be evaluated
MeasureParameters mp = new MeasureParameters(
true,//evaluate all performance measures
0.999,//use specificity of 0.999 to measure the sensitivity
0.95,//use sensitivity of 0.95 to measure the false positive rate
0.95//use sensitivity of 0.95 to measure the positive predictive value
);
//evaluates the classifier
ResultSet rs = cl.evaluateAll(
mp,//defines the measures that will be evaluated
true,//allows to throw an exception if a measure can not be computed
fgTest,//the test data for the foreground class
bgTest//the test data for the background class
);
System.out.println(rs);
//plot ROC and PR curve
DoubleTableResult roc = (DoubleTableResult)rs.getResultAt( rs.findColumn( Measure.ReceiverOperatingCharacteristicCurve.getNameString() ) );
DoubleTableResult pr = (DoubleTableResult)rs.getResultAt( rs.findColumn( Measure.PrecisionRecallCurve.getNameString() ) );
REnvironment re = null;
//you need to have a Rserve running
try {
re = new REnvironment(
"localhost",//server name
"",//user name
""//password
);
String snfpr = "points( " + fpr + ", " + sn + ", col=" + 1 + ", pch=" +4 + ", cex=2, lwd=3 );\n";
String ppvsn = "points( " + sn + ", " + ppv + ", col=" + 1 + ", pch=" +4 + ", cex=2, lwd=3 );\n";
re.voidEval( "p<-palette();p[8]<-\"gray66\";palette(p);" );
re.plotToPDF( DoubleTableResult.getPlotCommands( re, null, new int[]{8}, roc ).toString()+"\n"+snfpr,4,4.5, home+"roc.pdf",true);
re.plotToPDF( DoubleTableResult.getPlotCommands( re, null, new int[]{8}, pr ).toString()+"\n"+ppvsn, 4,4.5, home+"pr.pdf",true);
} catch( Exception e ) {
System.out.println( "could not plot the curves" );
} finally {
if( re != null ) {
re.close();
}
}
separator();
/* hold-out sampling */
//define the measures that shall be evaluated
mp = new MeasureParameters(
false,//only evaluate numerical performance measures
0.999,//use specificity of 0.999 to measure the sensitivity
0.95,//use sensitivity of 0.95 to measure the false positive rate
0.95//use sensitivity of 0.95 to measure the positive predictive value
);
//create the parameters for the hold-out sampling
RepeatedHoldOutAssessParameterSet parsA = new RepeatedHoldOutAssessParameterSet(
Sample.PartitionMethod.PARTITION_BY_NUMBER_OF_SYMBOLS,//defines the way of splitting the data
fgData.getElementLength(), //defines the length of the elements in the test data set
true,//a switch that decides whether to throw an exception if a performance measure can not be evaluated
1000,//the number of samplings
new double[]{0.1,0.1}//the partition of the data of each class the will be used for testing
);
//creates an hold-out experiment for two classifiers (the generative and the discriminative classifier)
RepeatedHoldOutExperiment exp = new RepeatedHoldOutExperiment(cl,cll);
//does the experiment a stores the results in a ListResult
ListResult lr = exp.assess(
mp,//the measures that will be computed
parsA,//the parameters for the experiment
fgData,//the foreground data
bgData//the background data
);
System.out.println(lr);
separator();
/* prediction */
//re-train discriminative classifier
cll.train( fgData, bgData );
//load data for prediction
Sample promoters = new DNASample(home+"human_promoters.fa");
//find best possible binding site
int si=0,id=0;
double llr, max=Double.NEGATIVE_INFINITY;
PrintWriter out = new PrintWriter( home+"/allscores.txt" );
int i = 0;
//check all sequence
for( Sequence seq : promoters ){
//check each possible start position
for(int l=0;l<seq.getLength()-cll.getLength()+1;l++){
Sequence sub = seq.getSubSequence( l, cll.getLength() );
//compute likelihood ratio
llr = cll.getScore( sub, 0 ) - cll.getScore( sub, 1 );
out.print( llr + "\t" );
if(llr > max){
//set new best likelihood ratio, sequence, and site
max = llr;
si = i;
id = l;
}
}
out.println();
i++;
}
out.close();
//write a file for the best prediction
Sequence bestSequence = promoters.getElementAt( si );
out = new PrintWriter(home+"/scores.txt");
//write the sequence
out.println(bestSequence.toString("\t", id-30,id+30));
//write the log likelihood ratio that can be used to plot a profile
for(int l=id-30;l<id+30;l++){
Sequence site = bestSequence.getSubSequence( l, cll.getLength() );
out.print((cll.getScore( site, 0 ) - cll.getScore( site, 1 ))+"\t");
}
out.println();
out.close();
}
//method for obtaining reproduceably the same split of some given data
private static Sample[] bisect(Sample data, int l) throws Exception {
int mid = data.getNumberOfElements()/2;
return new Sample[]{
getSubSample(data, 0, mid, "train",l),
getSubSample(data, mid, data.getNumberOfElements(), "test",l)
};
}
//creates a sample from a specific part of the data
private static Sample getSubSample( Sample data, int start, int end, String annotation, int l ) throws Exception {
//copy the sequences into an array
Sequence[] seqs = new Sequence[end-start];
for(int i=0;i<seqs.length;i++){
seqs[i] = data.getElementAt( i+start );
}
return new Sample( new Sample( annotation, seqs ), l );
}
//prints a separator
private static void separator() {
for( int i = 0; i < 50; i++) {
System.out.print("=");
}
System.out.println();
}
}