Performing a 10-fold cross validation

From Jstacs
Jump to navigation · Jump to search
The printable version is no longer supported and may have rendering errors. Please update your browser bookmarks and please use the default browser print function instead.
//create a Sample for each class from the input data, using the DNA alphabet
Sample[] data = new Sample[2];
data[0] = new DNASample( args[0] );

//the length of our input sequences
int length = data[0].getElementLength();

data[1] = new Sample( new DNASample( args[1] ), length );
 
AlphabetContainer container = data[0].getAlphabetContainer();

//parameters of a PWM, i.e. an inhomogeneous Markov model (IMM) of order 0
BayesianNetworkModelParameterSet pwmParameters = new BayesianNetworkModelParameterSet(
		//alphabet and length of the model
		container, length,
		//equivalent sample size used to compute the hyper-parameters
		4,
		//identifier of the model
		"my PWM",
		//IMM of order 0
		ModelType.IMM, (byte) 0,
		//estimate the MAP-parameters
		LearningType.ML_OR_MAP );

//create the PWM from these parameters
BayesianNetworkModel pwm = new BayesianNetworkModel( pwmParameters );
 
//the two components of the mixture, both given as the PWM
Model[] mixtureComponents = { pwm, pwm };
//one equivalent sample size per component
double[] componentEss = { pwm.getESS(), pwm.getESS() };
//stopping criterion
SmallDifferenceOfFunctionEvaluationsCondition emStop = new SmallDifferenceOfFunctionEvaluationsCondition( 1E-6 );

//create a new mixture model using 2 PWMs
MixtureModel mixPwms = new MixtureModel(
		//length of the mixture model
		length,
		mixtureComponents,
		//number of starts of the EM
		10,
		componentEss,
		//hyper-parameters to draw the initial sequence-specific component weights (hidden variables)
		1,
		emStop,
		//LAMBDA is the parameterization by log-probabilities
		Parameterization.LAMBDA );
 
//parameters of an inhomogeneous Markov model of order 3 (equivalent sample size 256)
BayesianNetworkModelParameterSet mmParameters = new BayesianNetworkModelParameterSet(
		container, length, 256, "my iMM(3)", ModelType.IMM, (byte) 3, LearningType.ML_OR_MAP );

//create the Markov model from these parameters
BayesianNetworkModel mm = new BayesianNetworkModel( mmParameters );
 
//structure of a PWM: an inhomogeneous Markov model of order 0
InhomogeneousMarkov pwmStructure = new InhomogeneousMarkov( 0 );

//PWM scoring function over the given alphabet and length;
//4 is the equivalent sample size for the plug-in parameters, true switches plug-in parameters on
BayesianNetworkScoringFunction dPwm = new BayesianNetworkScoringFunction( container, length, 4, true, pwmStructure );
 
//mixture scoring function with two PWM components:
//2 starts, plug-in parameters switched on (true), and the PWM scoring function passed for both components
MixtureScoringFunction dMixPwms = new MixtureScoringFunction( 2, true, dPwm, dPwm );
 
//structure of an inhomogeneous Markov model of order 3
InhomogeneousMarkov mmStructure = new InhomogeneousMarkov( 3 );

//scoring function with that structure, set up with the same remaining arguments as the PWM scoring function
BayesianNetworkScoringFunction dMm = new BayesianNetworkScoringFunction(
		container, length, 4, true, mmStructure );
 
//number of processors reported as available, handed to the numerical optimization below
int threads = AbstractMultiThreadedOptimizableFunction.getNumberOfAvailableProcessors();

//model based classifier with mixture model and Markov model
ModelBasedClassifier modelBasedClassifier = new ModelBasedClassifier( mixPwms, mm );

//settings of the conditional likelihood based classifier:
//method for optimizing the conditional likelihood and other parameters of the numerical optimization
GenDisMixClassifierParameterSet mspParameters = new GenDisMixClassifierParameterSet( container, length,
		Optimizer.QUASI_NEWTON_BFGS, 1E-2, 1E-2, 1, true, KindOfParameter.PLUGIN, false, threads );

//conditional likelihood based classifier with mixture scoring function and Markov model scoring function
MSPClassifier mspClassifier = new MSPClassifier( mspParameters, dMixPwms, dMm );

//both classifiers are assessed together
AbstractScoreBasedClassifier[] classifiers = { modelBasedClassifier, mspClassifier };
 
//create a new k-fold cross validation using the classifiers above
KFoldCrossValidation cv = new KFoldCrossValidation( classifiers );

//performance measures: a specificity of 0.999 is used to compute the sensitivity,
//and a sensitivity of 0.95 is used to compute FPR and PPV
MeasureParameters mp = new MeasureParameters( false, 0.999, 0.95, 0.95 );

//10-fold cross validation; the data is partitioned by means of the number of symbols
KFoldCVAssessParameterSet cvpars = new KFoldCVAssessParameterSet(
		PartitionMethod.PARTITION_BY_NUMBER_OF_SYMBOLS,
		length,
		true,
		10 );

//run the cross validation and print its result to System.out
System.out.println( cv.assess( mp, cvpars, data ) );