Train classifiers using GenDisMix (a hybrid learning principle)

From Jstacs
Jump to navigation · Jump to search
//read FastA-files: the first file is paired with pwmFg, the second with pwmBg (see below)
Sample[] data = {
         new DNASample( args[0] ),
         new DNASample( args[1] )
};
AlphabetContainer container = data[0].getAlphabetContainer();
int length = data[0].getElementLength();

//equivalent sample size (ESS) for foreground and background
double essFg = 4, essBg = 4;
//create ScoringFunctions, here PWMs (inhomogeneous Markov models of order 0)
NormalizableScoringFunction pwmFg = new BayesianNetworkScoringFunction( container, length, essFg, true, new InhomogeneousMarkov(0) );
NormalizableScoringFunction pwmBg = new BayesianNetworkScoringFunction( container, length, essBg, true, new InhomogeneousMarkov(0) );

//create parameters of the classifier
GenDisMixClassifierParameterSet cps = new GenDisMixClassifierParameterSet(
		container,//the used alphabets
		length,//sequence length that can be modeled/classified
		Optimizer.QUASI_NEWTON_BFGS, 1E-9, 1E-11, 1,//optimization parameter
		false,//use free parameters or all
		KindOfParameter.PLUGIN,//how to start the numerical optimization
		true,//use a normalized objective function
		AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors()//number of compute threads
);

//create classifiers: one per elementary learning principle, plus one extra slot
//at the end for a classifier with user-defined weights
LearningPrinciple[] lp = LearningPrinciple.values();
GenDisMixClassifier[] cl = new GenDisMixClassifier[lp.length+1];
//elementary learning principles occupy indices 0 .. lp.length-1
for( int i = 0; i < lp.length; i++ ){
	System.out.println( "classifier " + i + " uses " + lp[i] );
	cl[i] = new GenDisMixClassifier( cps, new CompositeLogPrior(), lp[i], pwmFg, pwmBg );
}

//use some weighted version of log conditional likelihood, log likelihood, and log prior;
//this classifier goes into the last slot, addressed explicitly instead of relying on
//the value a loop counter happens to have after the loop above
int mixedIndex = lp.length;
double[] beta = {0.3,0.3,0.4};//the weights sum to 1
System.out.println( "classifier " + mixedIndex + " uses the weights " + Arrays.toString( beta ) );
cl[mixedIndex] = new GenDisMixClassifier( cps, new CompositeLogPrior(), beta, pwmFg, pwmBg );

//do whatever you like

//e.g., train every classifier on the same data
for( GenDisMixClassifier classifier : cl ){
	classifier.train( data );
}

//e.g., evaluate (normally done on a separate test data set; evaluating on the
//training data, as here, gives optimistic results)
MeasureParameters mp = new MeasureParameters( false, 0.95, 0.999, 0.999 );
for( GenDisMixClassifier classifier : cl ){
	System.out.println( classifier.evaluate( mp, true, data ) );
}