MiMB Example: Difference between revisions
From Jstacs
Jump to navigationJump to search
No edit summary |
No edit summary |
||
Line 6: | Line 6: | ||
import de.jstacs.algorithms.optimization.Optimizer; | import de.jstacs.algorithms.optimization.Optimizer; | ||
import de.jstacs.classifier.AbstractScoreBasedClassifier; | |||
import de.jstacs.classifier.ConfusionMatrix; | import de.jstacs.classifier.ConfusionMatrix; | ||
import de.jstacs.classifier.MeasureParameters; | import de.jstacs.classifier.MeasureParameters; | ||
Line 85: | Line 86: | ||
//create generative classifier from the models defined before | //create generative classifier from the models defined before | ||
ModelBasedClassifier cl = new ModelBasedClassifier(fgModel,bgModel); | ModelBasedClassifier cl = new ModelBasedClassifier(fgModel, bgModel); | ||
cl.train( fgData, bgData ); | cl.train( fgData, bgData ); | ||
Line 92: | Line 93: | ||
//create set of parameters for foreground scoring function | //create set of parameters for foreground scoring function | ||
BayesianNetworkScoringFunctionParameterSet | BayesianNetworkScoringFunctionParameterSet parsD = new BayesianNetworkScoringFunctionParameterSet( | ||
fgData.getAlphabetContainer(),//used alphabets | fgData.getAlphabetContainer(),//used alphabets | ||
fgData.getElementLength(),//element length == sequence length of each sequence in the sample | fgData.getElementLength(),//element length == sequence length of each sequence in the sample | ||
Line 100: | Line 101: | ||
); | ); | ||
//create foreground scoring function from these parameters | //create foreground scoring function from these parameters | ||
BayesianNetworkScoringFunction fgFun = new BayesianNetworkScoringFunction( | BayesianNetworkScoringFunction fgFun = new BayesianNetworkScoringFunction(parsD); | ||
//analogously, create the background scoring function | //analogously, create the background scoring function | ||
BayesianNetworkScoringFunctionParameterSet | BayesianNetworkScoringFunctionParameterSet parsDbg = new BayesianNetworkScoringFunctionParameterSet( | ||
fgData.getAlphabetContainer(), | fgData.getAlphabetContainer(), | ||
fgData.getElementLength(), | fgData.getElementLength(), | ||
Line 109: | Line 110: | ||
true, | true, | ||
new InhomogeneousMarkov(0)); | new InhomogeneousMarkov(0)); | ||
BayesianNetworkScoringFunction bgFun = new BayesianNetworkScoringFunction( | BayesianNetworkScoringFunction bgFun = new BayesianNetworkScoringFunction(parsDbg); | ||
//create set of parameter for the discriminative classifier | //create set of parameter for the discriminative classifier | ||
Line 132: | Line 133: | ||
); | ); | ||
//train the discriminative classifier | |||
//cll.train( fgData, bgData ); TODO | |||
/* performance measures */ | /* performance measures */ | ||
Line 228: | Line 231: | ||
//create the parameters for the hold-out sampling | //create the parameters for the hold-out sampling | ||
RepeatedHoldOutAssessParameterSet | RepeatedHoldOutAssessParameterSet parsA = new RepeatedHoldOutAssessParameterSet( | ||
Sample.PartitionMethod.PARTITION_BY_NUMBER_OF_SYMBOLS,//defines the way of splitting the data | Sample.PartitionMethod.PARTITION_BY_NUMBER_OF_SYMBOLS,//defines the way of splitting the data | ||
fgData.getElementLength(), //defines the length of the elements in the test data set | fgData.getElementLength(), //defines the length of the elements in the test data set | ||
Line 234: | Line 237: | ||
1000,//the number of samplings | 1000,//the number of samplings | ||
new double[]{0.1,0.1}//the partition of the data of each class the will be used for testing | new double[]{0.1,0.1}//the partition of the data of each class the will be used for testing | ||
); | ); | ||
//creates an hold-out experiment for two classifiers (the generative and the discriminative classifier) | //creates an hold-out experiment for two classifiers (the generative and the discriminative classifier) | ||
Line 241: | Line 244: | ||
ListResult lr = exp.assess( | ListResult lr = exp.assess( | ||
mp,//the measures that will be computed | mp,//the measures that will be computed | ||
parsA,//the parameters for the experiment | |||
fgData,//the foreground data | fgData,//the foreground data | ||
bgData//the background data | bgData//the background data | ||
Line 251: | Line 254: | ||
/* prediction */ | /* prediction */ | ||
//train discriminative classifier | //re-train discriminative classifier | ||
cll.train( fgData, bgData ); | cll.train( fgData, bgData ); | ||
//load data for prediction | //load data for prediction | ||
Sample | Sample promoters = new DNASample(home+"human_promoters.fa"); | ||
//find best possible binding site | //find best possible binding site | ||
Line 261: | Line 264: | ||
double llr, max=Double.NEGATIVE_INFINITY; | double llr, max=Double.NEGATIVE_INFINITY; | ||
PrintWriter out = new PrintWriter( home+"/allscores.txt" ); | |||
int i = 0; | |||
//check all sequence | //check all sequence | ||
for( | for( Sequence seq : promoters ){ | ||
//check each possible start position | //check each possible start position | ||
for(int l=0;l<seq.getLength()-cll.getLength()+1;l++){ | for(int l=0;l<seq.getLength()-cll.getLength()+1;l++){ | ||
Line 269: | Line 274: | ||
//compute likelihood ratio | //compute likelihood ratio | ||
llr = cll.getScore( sub, 0 ) - cll.getScore( sub, 1 ); | llr = cll.getScore( sub, 0 ) - cll.getScore( sub, 1 ); | ||
out.print( llr + "\t" ); | |||
if(llr > max){ | if(llr > max){ | ||
//set new best likelihood ratio, sequence, and site | //set new best likelihood ratio, sequence, and site | ||
Line 276: | Line 282: | ||
} | } | ||
} | } | ||
out.println(); | |||
i++; | |||
} | } | ||
out.close(); | |||
//write a file for the best prediction | //write a file for the best prediction | ||
Sequence bestSequence = | Sequence bestSequence = promoters.getElementAt( si ); | ||
out = new PrintWriter(home+"/scores.txt"); | |||
//write the sequence | //write the sequence | ||
out.println(bestSequence.toString("\t", id-30,id+30)); | out.println(bestSequence.toString("\t", id-30,id+30)); |
Revision as of 23:23, 18 November 2009
This class contains all code snippets used in the chapter Probabilistic Approaches to Transcription Factor Binding Site Prediction of the book Computational Biology of Transcription Factors of the series Methods in Molecular Biology:
package supplementary.codeExamples;
import java.io.PrintWriter;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.classifier.AbstractScoreBasedClassifier;
import de.jstacs.classifier.ConfusionMatrix;
import de.jstacs.classifier.MeasureParameters;
import de.jstacs.classifier.AbstractScoreBasedClassifier.DoubleTableResult;
import de.jstacs.classifier.MeasureParameters.Measure;
import de.jstacs.classifier.assessment.RepeatedHoldOutAssessParameterSet;
import de.jstacs.classifier.assessment.RepeatedHoldOutExperiment;
import de.jstacs.classifier.modelBased.ModelBasedClassifier;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction.KindOfParameter;
import de.jstacs.classifier.scoringFunctionBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifier.scoringFunctionBased.logPrior.CompositeLogPrior;
import de.jstacs.classifier.scoringFunctionBased.msp.MSPClassifier;
import de.jstacs.data.DNASample;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.models.Model;
import de.jstacs.models.VariableLengthWrapperModel;
import de.jstacs.models.discrete.inhomogeneous.BayesianNetworkModel;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner.LearningType;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner.ModelType;
import de.jstacs.models.discrete.inhomogeneous.parameters.BayesianNetworkModelParameterSet;
import de.jstacs.results.ListResult;
import de.jstacs.results.ResultSet;
import de.jstacs.scoringFunctions.directedGraphicalModels.BayesianNetworkScoringFunction;
import de.jstacs.scoringFunctions.directedGraphicalModels.BayesianNetworkScoringFunctionParameterSet;
import de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.InhomogeneousMarkov;
import de.jstacs.utils.REnvironment;
/**
* This class implements a main that shows some features of Jstacs including models for generative and discriminative learning,
* creation of classifiers, evaluation of classifiers, hold-out sampling, and binding site prediction.
*
* @author Jan Grau, Jens Keilwagen
*/
public class MiMBExample {
/**
* @param args only the first parameter will be used; it determines the home directory
*/
public static void main( String[] args ) throws Exception {
String home = args[0];
/* read data */
//read foreground and background data set form FastA-files
Sample fgData = new DNASample(home+"argrpr.fa");
Sample bgData = new DNASample(home+"2ndexons.fa");
/* generative part */
//create set of parameters for foreground model
BayesianNetworkModelParameterSet pars = new BayesianNetworkModelParameterSet(
fgData.getAlphabetContainer(),//used alphabets
fgData.getElementLength(),//element length == sequence length of each sequence in the sample
4,//ESS == equivalent sample size (has to be non-negative)
"fg model",//user description of the model
ModelType.IMM,//type of statistical model, here an inhomogeneous Markov model (IMM)
(byte)0,// model order, here 0, so we get an IMM(0) == PWM = position weight matrix
LearningType.ML_OR_MAP//how to learn the parameters, depends on ESS; for ESS=0 it is ML otherwise MAP
);
//create foreground model from these parameters
Model fgModel = new BayesianNetworkModel(pars);
//analogously, create the background model
BayesianNetworkModelParameterSet pars2 = new BayesianNetworkModelParameterSet(
fgData.getAlphabetContainer(),
fgData.getElementLength(),
1024,
"bg model",
ModelType.IMM,
(byte)0,
LearningType.ML_OR_MAP
);
Model bgModel = new BayesianNetworkModel(pars2);
bgModel = new VariableLengthWrapperModel(bgModel);
//create generative classifier from the models defined before
ModelBasedClassifier cl = new ModelBasedClassifier(fgModel, bgModel);
cl.train( fgData, bgData );
/* discriminative part */
//create set of parameters for foreground scoring function
BayesianNetworkScoringFunctionParameterSet parsD = new BayesianNetworkScoringFunctionParameterSet(
fgData.getAlphabetContainer(),//used alphabets
fgData.getElementLength(),//element length == sequence length of each sequence in the sample
4,//ESS == equivalent sample size (has to be non-negative)
true,//use MAP-parameters as start values of the optimization
new InhomogeneousMarkov(0) //the statistical model, here an IMM(0) == PWM
);
//create foreground scoring function from these parameters
BayesianNetworkScoringFunction fgFun = new BayesianNetworkScoringFunction(parsD);
//analogously, create the background scoring function
BayesianNetworkScoringFunctionParameterSet parsDbg = new BayesianNetworkScoringFunctionParameterSet(
fgData.getAlphabetContainer(),
fgData.getElementLength(),
1024,
true,
new InhomogeneousMarkov(0));
BayesianNetworkScoringFunction bgFun = new BayesianNetworkScoringFunction(parsDbg);
//create set of parameter for the discriminative classifier
GenDisMixClassifierParameterSet clPars = new GenDisMixClassifierParameterSet(
fgData.getAlphabetContainer(),//used alphabets
fgData.getElementLength(),//element length == sequence length of each sequence in the samples
Optimizer.QUASI_NEWTON_BFGS,//determines the algorithm for numerical optimization
1E-6,//epsilon to stop the numerical optimization
1E-6,//epsilon to stop the line search within the numerical optimization
1,//start step width in the numerical optimization
false,//a switch that decides whether to use only the free parameters or all parameters
KindOfParameter.PLUGIN,//a switch to decide which start parameters to choose
true,//a switch that states the objective function will be normalized
1//number of threads used during optimization
);
//create discriminative classifier from the parameters and the scoring function defined before
MSPClassifier cll = new MSPClassifier(
clPars,//the parameters of the classifier
new CompositeLogPrior(),//the used prior, to obtain MCL use null
fgFun,//the scoring function for the foreground class
bgFun//the scoring function for the background class
);
//train the discriminative classifier
//cll.train( fgData, bgData ); TODO
/* performance measures */
//partition data
Sample[] fgSplit = bisect( fgData, fgData.getElementLength() );
Sample[] bgSplit = bisect( bgData, fgData.getElementLength() );
Sample fgTest = fgSplit[1];
Sample bgTest = bgSplit[1];
//train the generative classifier
cl.train( fgSplit[0], bgSplit[0] );
//fill a confusion matrix
ConfusionMatrix confMatrix = cl.test( fgTest, bgTest );
//read the entries of the table
double tp = confMatrix.getCountsFor(0, 0);
double fn = confMatrix.getCountsFor(1, 0);
double tn = confMatrix.getCountsFor(1, 1);
double fp = confMatrix.getCountsFor(0, 1);
double p = tp+fn;
double barp = tp+fp;
double n = tn+fp;
double barn = tn+fn;
System.out.println("TP = "+tp+"\t\tFP = "+fp+"\t\tbarp = "+barp+"\n" +
"FN = "+fn+"\t\tTN = "+tn+"\t\tbarn = "+barn+"\n" +
"p = "+p+"\t\tn = "+n+"\t\tN' = "+(n+p));
//compute the measures
double sn = tp/p;
double ppv = tp/barp;
double fpr = fp/n;
double sp = tn/n;
double cr = (tp+tn)/(n+p);
System.out.println("cr = "+cr+"\nSn = "+sn+"\nppv = "+ppv+"\nSp = "+sp+"\nfpr = "+fpr+"\n");
//define the measures that shall be evaluated
MeasureParameters mp = new MeasureParameters(
true,//evaluate all performance measures
0.999,//use specificity of 0.999 to measure the sensitivity
0.95,//use sensitivity of 0.95 to measure the false positive rate
0.95//use sensitivity of 0.95 to measure the positive predictive value
);
//evaluates the classifier
ResultSet rs = cl.evaluateAll(
mp,//defines the measures that will be evaluated
true,//allows to throw an exception if a measure can not be computed
fgTest,//the test data for the foreground class
bgTest//the test data for the background class
);
System.out.println(rs);
//plot ROC and PR curve
DoubleTableResult roc = (DoubleTableResult)rs.getResultAt( rs.findColumn( Measure.ReceiverOperatingCharacteristicCurve.getNameString() ) );
DoubleTableResult pr = (DoubleTableResult)rs.getResultAt( rs.findColumn( Measure.PrecisionRecallCurve.getNameString() ) );
REnvironment re = null;
//you need to have a Rserve running
try {
re = new REnvironment(
"localhost",//server name
"",//user name
""//password
);
String snfpr = "points( " + fpr + ", " + sn + ", col=" + 1 + ", pch=" +4 + ", cex=2, lwd=3 );\n";
String ppvsn = "points( " + sn + ", " + ppv + ", col=" + 1 + ", pch=" +4 + ", cex=2, lwd=3 );\n";
re.voidEval( "p<-palette();p[8]<-\"gray66\";palette(p);" );
re.plotToPDF( DoubleTableResult.getPlotCommands( re, null, new int[]{8}, roc ).toString()+"\n"+snfpr,4,4.5, home+"roc.pdf",true);
re.plotToPDF( DoubleTableResult.getPlotCommands( re, null, new int[]{8}, pr ).toString()+"\n"+ppvsn, 4,4.5, home+"pr.pdf",true);
} catch( Exception e ) {
System.out.println( "could not plot the curves" );
} finally {
if( re != null ) {
re.close();
}
}
separator();
/* hold-out sampling */
//define the measures that shall be evaluated
mp = new MeasureParameters(
false,//only evaluate numerical performance measures
0.999,//use specificity of 0.999 to measure the sensitivity
0.95,//use sensitivity of 0.95 to measure the false positive rate
0.95//use sensitivity of 0.95 to measure the positive predictive value
);
//create the parameters for the hold-out sampling
RepeatedHoldOutAssessParameterSet parsA = new RepeatedHoldOutAssessParameterSet(
Sample.PartitionMethod.PARTITION_BY_NUMBER_OF_SYMBOLS,//defines the way of splitting the data
fgData.getElementLength(), //defines the length of the elements in the test data set
true,//a switch that decides whether to throw an exception if a performance measure can not be evaluated
1000,//the number of samplings
new double[]{0.1,0.1}//the partition of the data of each class the will be used for testing
);
//creates an hold-out experiment for two classifiers (the generative and the discriminative classifier)
RepeatedHoldOutExperiment exp = new RepeatedHoldOutExperiment(cl,cll);
//does the experiment a stores the results in a ListResult
ListResult lr = exp.assess(
mp,//the measures that will be computed
parsA,//the parameters for the experiment
fgData,//the foreground data
bgData//the background data
);
System.out.println(lr);
separator();
/* prediction */
//re-train discriminative classifier
cll.train( fgData, bgData );
//load data for prediction
Sample promoters = new DNASample(home+"human_promoters.fa");
//find best possible binding site
int si=0,id=0;
double llr, max=Double.NEGATIVE_INFINITY;
PrintWriter out = new PrintWriter( home+"/allscores.txt" );
int i = 0;
//check all sequence
for( Sequence seq : promoters ){
//check each possible start position
for(int l=0;l<seq.getLength()-cll.getLength()+1;l++){
Sequence sub = seq.getSubSequence( l, cll.getLength() );
//compute likelihood ratio
llr = cll.getScore( sub, 0 ) - cll.getScore( sub, 1 );
out.print( llr + "\t" );
if(llr > max){
//set new best likelihood ratio, sequence, and site
max = llr;
si = i;
id = l;
}
}
out.println();
i++;
}
out.close();
//write a file for the best prediction
Sequence bestSequence = promoters.getElementAt( si );
out = new PrintWriter(home+"/scores.txt");
//write the sequence
out.println(bestSequence.toString("\t", id-30,id+30));
//write the log likelihood ratio that can be used to plot a profile
for(int l=id-30;l<id+30;l++){
Sequence site = bestSequence.getSubSequence( l, cll.getLength() );
out.print((cll.getScore( site, 0 ) - cll.getScore( site, 1 ))+"\t");
}
out.println();
out.close();
}
//method for obtaining reproduceably the same split of some given data
private static Sample[] bisect(Sample data, int l) throws Exception {
int mid = data.getNumberOfElements()/2;
return new Sample[]{
getSubSample(data, 0, mid, "train",l),
getSubSample(data, mid, data.getNumberOfElements(), "test",l)
};
}
//creates a sample from a specific part of the data
private static Sample getSubSample( Sample data, int start, int end, String annotation, int l ) throws Exception {
//copy the sequences into an array
Sequence[] seqs = new Sequence[end-start];
for(int i=0;i<seqs.length;i++){
seqs[i] = data.getElementAt( i+start );
}
return new Sample( new Sample( annotation, seqs ), l );
}
//prints a separator
private static void separator() {
for( int i = 0; i < 50; i++) {
System.out.print("=");
}
System.out.println();
}
}