Train classifiers using GenDisMix (a hybrid learning principle)
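This code example reads a foreground and a background data set from two FastA files, builds a position weight matrix (PWM) model for each class, and creates one classifier for each elementary learning principle as well as one GenDisMix classifier with user-defined weights. GenDisMix is a hybrid learning principle that maximizes a weighted combination of log conditional likelihood, log likelihood, and log prior, roughly of the form beta_1 * log P(c|x,lambda) + beta_2 * log P(x,c|lambda) + beta_3 * log P(lambda) (up to notation); the elementary generative and discriminative principles correspond to special choices of the weights. Finally, all classifiers are trained and evaluated, here on the training data for brevity. The snippet assumes that the required Jstacs classes are imported.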
//read FastA files (foreground and background data)
Sample[] data = {
    new DNASample( args[0] ),
    new DNASample( args[1] )
};
AlphabetContainer container = data[0].getAlphabetContainer();
int length = data[0].getElementLength();
//equivalent sample size (ESS), i.e., the weight of the prior relative to the data
double essFg = 4, essBg = 4;
//create scoring functions, here PWMs (inhomogeneous Markov models of order 0)
NormalizableScoringFunction pwmFg = new BayesianNetworkScoringFunction( container, length, essFg, true, new InhomogeneousMarkov(0) );
NormalizableScoringFunction pwmBg = new BayesianNetworkScoringFunction( container, length, essBg, true, new InhomogeneousMarkov(0) );
//create parameters of the classifier
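//(the same parameter set is shared by all classifiers created below)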
GenDisMixClassifierParameterSet cps = new GenDisMixClassifierParameterSet(
    container, //the used alphabets
    length, //sequence length that can be modeled/classified
    Optimizer.QUASI_NEWTON_BFGS, 1E-9, 1E-11, 1, //optimization parameters
    false, //use only the free parameters or all parameters
    KindOfParameter.PLUGIN, //how to start the numerical optimization
    true, //use a normalized objective function
    AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors() //number of compute threads
);
//create classifiers
LearningPrinciple[] lp = LearningPrinciple.values();
GenDisMixClassifier[] cl = new GenDisMixClassifier[lp.length+1];
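//one classifier for each elementary learning principle plus one additional
//classifier for a user-defined weighting (created further below)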
//elementary learning principles
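//(in Jstacs these include, e.g., the generative principles ML and MAP and the discriminative principles MCL and MSP)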
int i = 0;
for( ; i < cl.length-1; i++ ){
System.out.println( "classifier " + i + " uses " + lp[i] );
cl[i] = new GenDisMixClassifier( cps, new CompositeLogPrior(), lp[i], pwmFg, pwmBg );
}
//use some weighted version of log conditional likelihood, log likelihood, and log prior
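//beta[0] weights the log conditional likelihood (discriminative component),
//beta[1] the log likelihood (generative component), and beta[2] the log prior;
//the weights are typically chosen non-negative and summing to 1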
double[] beta = {0.3,0.3,0.4};
System.out.println( "classifier " + i + " uses the weights " + Arrays.toString( beta ) );
cl[i] = new GenDisMixClassifier( cps, new CompositeLogPrior(), beta, pwmFg, pwmBg );
//do whatever you like
//e.g., train
for( i = 0; i < cl.length; i++ ){
    cl[i].train( data );
}
//e.g., evaluate (normally done on a test data set)
MeasureParameters mp = new MeasureParameters( false, 0.95, 0.999, 0.999 );
for( i = 0; i < cl.length; i++ ){
    System.out.println( cl[i].evaluate( mp, true, data ) );
}
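As a further example of the last step, the trained classifiers can be applied to individual sequences. The following is a minimal sketch, not part of the original recipe: it assumes the Jstacs accessors Sample.getNumberOfElements() and Sample.getElementAt(int) as well as the classify(Sequence) method of the classifier, and simply prints the predicted class of every foreground sequence under the classifier trained with the user-defined weights.

//e.g., apply the classifier trained with the user-defined weights (the last entry)
//to single sequences (sketch, see the note above)
GenDisMixClassifier gendismix = cl[cl.length-1];
for( i = 0; i < data[0].getNumberOfElements(); i++ ){
    Sequence seq = data[0].getElementAt( i );
    System.out.println( seq + " -> predicted class " + gendismix.classify( seq ) );
}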