package projects;

import java.io.File;
import java.util.Arrays;

import de.jstacs.DataType;
import de.jstacs.algorithms.optimization.ConstantStartDistance;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction.KindOfParameter;
import de.jstacs.classifier.scoringFunctionBased.cll.CLLClassifier;
import de.jstacs.classifier.scoringFunctionBased.cll.CLLClassifierParameterSet;
import de.jstacs.classifier.scoringFunctionBased.cll.NormConditionalLogLikelihood;
import de.jstacs.classifier.scoringFunctionBased.logPrior.CompositeLogPrior;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.alphabets.DNAAlphabet;
import de.jstacs.data.sequences.annotation.MotifAnnotation;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.FileManager;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.MotifDiscoverer;
import de.jstacs.motifDiscovery.MutableMotifDiscovererToolbox;
import de.jstacs.motifDiscovery.SignificantMotifOccurrencesFinder;
import de.jstacs.motifDiscovery.MutableMotifDiscovererToolbox.InitMethodForScoringFunction;
import de.jstacs.motifDiscovery.SignificantMotifOccurrencesFinder.RandomSeqType;
import de.jstacs.motifDiscovery.history.RestrictedRepeatHistory;
import de.jstacs.parameters.Parameter;
import de.jstacs.parameters.ParameterSet;
import de.jstacs.parameters.SimpleParameter;
import de.jstacs.parameters.validation.NumberValidator;
import de.jstacs.scoringFunctions.AbstractNormalizableScoringFunction;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.directedGraphicalModels.MutableMarkovModelScoringFunction;
import de.jstacs.scoringFunctions.directedGraphicalModels.structureLearning.measures.InhomogeneousMarkov;
import de.jstacs.scoringFunctions.homogeneous.HMMScoringFunction;
import de.jstacs.scoringFunctions.mix.StrandScoringFunction;
import de.jstacs.scoringFunctions.mix.StrandScoringFunction.InitMethod;
import de.jstacs.scoringFunctions.mix.motifSearch.DurationScoringFunction;
import de.jstacs.scoringFunctions.mix.motifSearch.HiddenMotifsMixture;
import de.jstacs.scoringFunctions.mix.motifSearch.SkewNormalLikeScoringFunction;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.SafeOutputStream;

/**
 * Discriminative de-novo motif discovery for single hidden motifs.
 * <p>
 * The program loads a foreground and a background data set of fixed-length DNA sequences,
 * models the foreground with a {@link HiddenMotifsMixture} and the background with a
 * homogeneous model, optimizes both discriminatively, stores the resulting classifier as XML
 * and, if a p-value is given, prints the predicted binding sites in the foreground sequences.
 * 
 * @author Jens Keilwagen
 */
public class DiPoMM {
	
	/**
	 * Sets the statistic that is used for the hyperparameters of a homogeneous model.
	 * <p>
	 * The statistic consists of all lengths <code>0, ..., length</code> together with a weight per length.
	 * Length <code>length</code> receives the weight <code>essNonMotif</code>, and each length
	 * <code>0, ..., length-motifL</code> receives <code>2*essMotif/(length-motifL+1)</code>.
	 * (Presumably the factor 2 accounts for the left and the right flanking piece around a motif at each of
	 * the <code>length-motifL+1</code> possible positions -- NOTE(review): confirm.)
	 * 
	 * @param func the function whose hyperparameter statistic is set
	 * @param essMotif the equivalent sample size that is spread over the flanking lengths
	 * @param essNonMotif the equivalent sample size assigned to the full length <code>length</code>
	 * @param length the length of the sequences
	 * @param motifL the length of the motif
	 * 
	 * @throws Exception if the statistic could not be set
	 */
	private static void setHyper( HMMScoringFunction func, double essMotif, double essNonMotif, int length, int motifL ) throws Exception
	{
		int[] len = new int[length+1];
		double[] weight = new double[len.length];
		
		int i = 0, m;
		for( ; i < len.length; i++ )
		{
			len[i] = i;
		}
		weight[length] = essNonMotif;
		
		// number of possible motif start positions
		m = (length-motifL + 1);
		
		double part = essMotif / m;
		for( i = 0; i < m; i++ )
		{
			weight[i] += 2 * part;
		}
		
		func.setStatisticForHyperparameters( len, weight );
	}
	
	/**
	 * This is the main of DiPoMM that starts the program. 
	 * <p>
	 * Steps:
	 * <ol>
	 * <li>parse the parameters</li>
	 * <li>load the foreground and background data</li>
	 * <li>create the models</li>
	 * <li>initialize the motif</li>
	 * <li>optimize</li>
	 * <li>write the classifier to an XML file</li>
	 * <li>optionally predict binding sites</li>
	 * </ol>
	 * 
	 * @param args the arguments for DiPoMM. Each argument has the form <code>name=value</code>.
	 * 
	 * @throws IllegalArgumentException if the data sets do not share a fixed sequence length, if the motif is
	 *         longer than the sequences, or if the initialization method is unknown
	 * @throws Exception if something went wrong.
	 */
	public static void main( String[] args ) throws Exception {
		
		DiPoMMParameterSet params = new DiPoMMParameterSet( args );
		System.out.println( "parameters:" );
		System.out.println( params );
		System.out.println("_________________________________");
		if( !params.hasDefaultOrIsSet() ) {
			System.out.println( "Some of the required parameters are not specified." );
			System.exit( 1 );
		}

		AlphabetContainer con = new AlphabetContainer( new DNAAlphabet() );

		// load data: index 0 is the foreground, index 1 the background
		String home = params.getPathHomeDirectory();
		char ignore = params.getIgnoreChar();
		Sample[] data = {
			new Sample( con, new SparseStringExtractor( new File( home + File.separatorChar + params.getFgFileName() ), ignore ) ),
			new Sample( con, new SparseStringExtractor( new File( home + File.separatorChar + params.getBgFileName() ), ignore ) ),
		};
		// every sequence gets weight 1
		double[][] weights = new double[2][];
		int sl = data[0].getElementLength();
		for( int i = 0; i < data.length; i++ )
		{
			if( data[i].getElementLength() != sl ) {
				throw new IllegalArgumentException( "The sequences in the foreground and in the background file have to have the same fixed length." );
			}
			weights[i] = new double[data[i].getNumberOfElements()];
			Arrays.fill( weights[i], 1 );
			System.out.println( i + "\t# = " + data[i].getNumberOfElements() + "\tlength = " + data[i].getElementLength() + "\t" + data[i].getAnnotation() );
		}

		// create functions
		int starts = 1, flOrder = params.getFlankingOrder(), fgOrder = params.getMotifOrder(), motifL = params.getMotifLength();
		// The models below need at least one possible motif position (sl - motifL >= 0). This also rejects
		// data sets whose reported element length is 0 (NOTE(review): Sample is believed to report 0 for
		// variable-length data -- confirm).
		if( motifL > sl ) {
			throw new IllegalArgumentException( "The initial motif length (" + motifL + ") must not exceed the sequence length (" + sl + "). All sequences have to have the same fixed length." );
		}
		double essMotif = 4, essNonMotif = 1;
		boolean free = false;
				
		HMMScoringFunction flanking = new HMMScoringFunction( con, flOrder, essNonMotif, new double[flOrder+1], true, true, starts );
		HMMScoringFunction bg = new HMMScoringFunction( con, flOrder, essMotif + essNonMotif, new double[flOrder+1], true, true, starts );
		setHyper( flanking, essMotif, essNonMotif, sl, motifL );
		setHyper( bg, 0, essMotif+essNonMotif, sl, motifL );
		
		SkewNormalLikeScoringFunction motifPenalty = new SkewNormalLikeScoringFunction(1,50,true, 10d,1,true,1,25,true,0,2,10);
		motifPenalty.setParameters( new double[]{-1.5, -2.5, 4.2}, 0 );
		System.out.println( motifPenalty );

		NormalizableScoringFunction motif = new MutableMarkovModelScoringFunction( con, motifL, essMotif, true, new InhomogeneousMarkov(fgOrder), motifPenalty );
		if( params.useBothStrands() ) {
			motif = new StrandScoringFunction( motif, 0.5, starts, true, InitMethod.INIT_FORWARD_STRAND );
		}
		motif.initializeFunctionRandomly( free );
		
		// position model over the possible motif start positions 0 ... sl-motifL
		double seqs = 1, sd = 150;
		DurationScoringFunction pos =
			new SkewNormalLikeScoringFunction(0,sl-motifL, true,(sl-motifL)/2d,sl/2d, true,seqs/2d,seqs/2d*(sd*sd), true,0,0.5, starts);
		
		HiddenMotifsMixture fg = new HiddenMotifsMixture( HiddenMotifsMixture.CONTAINS_SOMETIMES_A_MOTIF, sl, starts, true,	flanking, motif, pos, false );
		
		NormalizableScoringFunction[] score = { fg, bg };

		// optimize
		double eps = 1E-7, lineps = 1E-10, startD = 1;
		NormConditionalLogLikelihood func = new NormConditionalLogLikelihood( score, data, weights, new CompositeLogPrior(), true, false );
		SafeOutputStream stream = new SafeOutputStream( null );
		
		// the initialization method has the form <name>=<value>
		String initMethod = params.getInitializationMethod();
		String v = initMethod.substring( initMethod.indexOf( '=' ) + 1 );
		System.out.println( initMethod );
		NormConditionalLogLikelihood mcl = new NormConditionalLogLikelihood( score, data, weights, true, false );
		if( initMethod.startsWith( "best-random=" ) ) {
			ComparableElement<double[], Double>[] pars = MutableMotifDiscovererToolbox.getSortedInitialParameters( data, score, new InitMethodForScoringFunction[]{InitMethodForScoringFunction.PLUG_IN, InitMethodForScoringFunction.NOTHING}, mcl, Integer.parseInt( v ), stream );
			func.reset( score );
			// use the last element of the sorted array
			func.setParams( pars[pars.length-1].getElement() );
		} else {
			boolean sp = initMethod.startsWith( "specific=" );
			Sample s = null;
			int len;
			double[] w = null;
			if( sp ) {
				// the value is either an existing file of sequences or a single sequence
				File f = new File( v );
				if( f.exists() ) {
					s = new Sample( con, new SparseStringExtractor( f, ignore ) );
				} else {
					s = new Sample( "one sequence", Sequence.create( con, v.trim() ) );
					w = new double[]{ 6 };
				}
				len = s.getElementLength();
			} else if( initMethod.startsWith( "enum=" ) ) {
				len = Integer.parseInt( v );
			} else {
				throw new IllegalArgumentException( "Unknown initialization method choice." );
			}
			// temporarily resize the motif to the initialization length
			if( len != motifL ) {
				fg.modifyMotif( 0, 0, len - motifL );
			}
			
			if( !sp ) {
				w = new double[]{ 6 };
				s = new Sample( "best sequence", MutableMotifDiscovererToolbox.enumerate( data, score, 0, 0, w[0], mcl, System.out ) );
			}
			fg.initializeMotif( 0, s, w );

			// resize the motif back to motifL, distributing the difference over both ends
			if( len != motifL ) {
				len = -len + motifL;
				fg.modifyMotif( 0, -len/2, len -len/2  );
			}
		}
		System.out.println( fg );
		System.out.println("_________________________________");
		
		stream = new SafeOutputStream( System.out );
		boolean adjust = params.adjust();
		double[][] res = MutableMotifDiscovererToolbox.optimize( score, func, (byte) 10, eps, lineps, new ConstantStartDistance(startD),
				stream, false, new RestrictedRepeatHistory(true,adjust,adjust,false,1), KindOfParameter.PLUGIN );

		// show
		System.out.println("_________________________________");
		System.out.println( score[0] );
		System.out.println( "result: " + res[0][0] );
		
		// save classifier
		CLLClassifierParameterSet cps = new CLLClassifierParameterSet( con, sl, (byte) 10, eps, lineps, startD, free, KindOfParameter.PLUGIN, true );
		CLLClassifier cl = new CLLClassifier( cps, new CompositeLogPrior(), (AbstractNormalizableScoringFunction[]) ArrayHandler.cast( score ) );
		cl.setClassWeights( false, res[1] );
		StringBuffer sb = new StringBuffer( 100000 );
		XMLParser.appendStorableWithTags( sb, cl, "classifier" );
		FileManager.writeFile( new File( params.getXMLFileName() ), sb );
		
		//predict
		if( params.hasPValue() ) {
			System.out.println("_________________________________");
			printPredictions( (MotifDiscoverer) score[0], data[0], params.getPValue() );
		}
	}
	
	/**
	 * Prints the significant occurrences of motif 0 of <code>md</code> in every sequence of <code>sample</code>
	 * as a tab-separated table (sequence index, position, strand, binding site, p-value).
	 * 
	 * @param md the trained motif discoverer
	 * @param sample the sequences to be scanned
	 * @param pValue the p-value threshold passed to {@link SignificantMotifOccurrencesFinder}
	 * 
	 * @throws Exception if the occurrences could not be determined
	 */
	private static void printPredictions( MotifDiscoverer md, Sample sample, double pValue ) throws Exception {
		SignificantMotifOccurrencesFinder smof = new SignificantMotifOccurrencesFinder( md, RandomSeqType.PERMUTED, 1000, pValue );
		System.out.println( "prediction" );
		System.out.println();
		System.out.println( "sequence\tposition\tstrand\tbinding site\tp-value" );
		System.out.println( "------------------------------------------------------------------------");
		for( int i = 0; i < sample.getNumberOfElements(); i++ ) {
			Sequence seq = sample.getElementAt( i );
			MotifAnnotation[] ma = smof.findSignificantMotifOccurrences( 0, seq, 0 );
			if( ma == null ) {
				continue;
			}
			for( MotifAnnotation a : ma ) {
				System.out.println( i + "\t" + a.getPosition() + "\t" + a.getStrandedness() + "\t" + seq.getSubSequence( a.getPosition(), a.getLength() ) + "\t" + a.getAnnotations()[1].getResult() );
			}
		}
	}
}

/**
 * This class is a container for all parameters of DiPoMM. It also parses the parameters from Strings
 * of the form <code>name=value</code>.
 *  
 * @author Jens Keilwagen
 */
class DiPoMMParameterSet extends ParameterSet {

	/**
	 * The tags accepted on the command line. The i-th tag sets the i-th parameter created in
	 * {@link #loadParameters()}, so both must stay in the same order.
	 */
	private static final String[] PREFIX = { "home", "ignore", "fg", "bg", "length", "flankOrder", "motifOrder", "bothStrands", "init", "xml", "adjust", "p-val" }; 
	
	/**
	 * Creates the parameter set and sets the values given as <code>name=value</code> Strings.
	 * 
	 * @param args the tagged values
	 * 
	 * @throws IllegalArgumentException if a tag is unknown, a value cannot be set, or a parameter is specified more than once
	 * @throws Exception if the parameters could not be created
	 */
	public DiPoMMParameterSet( String... args ) throws Exception {
		loadParameters();
		// remembers which parameters have already been set, to detect duplicates
		boolean[] set = new boolean[parameters.size()];
		for( int i = 0; i < args.length; i++ ) {
			int idx = setParameter( args[i] );
			if( set[idx] ) {
				Parameter p = parameters.get( idx );
				throw new IllegalArgumentException( "Confusion: The parameter " + p.getName() + " (" + p.getComment() + ") has been specified at least twice." );
			}
			set[idx] = true;
		}
	}
	
	/**
	 * Sets the value of the parameter whose tag prefixes <code>taggedValue</code>.
	 * 
	 * @param taggedValue a String of the form <code>name=value</code>
	 * 
	 * @return the index of the parameter that has been set
	 * 
	 * @throws IllegalArgumentException if the tag is unknown or the value could not be set (the original exception is kept as cause)
	 */
	public int setParameter( String taggedValue ) throws IllegalArgumentException {
		int idx = 0;
		while( idx < PREFIX.length && !taggedValue.startsWith( PREFIX[idx] + "=" ) ) {
			idx++;
		}
		if( idx >= PREFIX.length || idx >= parameters.size() ) {
			throw new IllegalArgumentException( "Could not set specified value (" + taggedValue + ") since the tag is unknown.");
		}
		try {
			parameters.get( idx ).setValue( taggedValue.substring( PREFIX[idx].length() + 1 ) );
		} catch( Exception e ) {
			throw new IllegalArgumentException( "Could not set \"" + taggedValue + "\" for parameter\n" + parameters.get( idx ) + e.getClass().getName() + ": " + e.getMessage(), e );
		}
		return idx;
	}

	/**
	 * Creates the parameters. The order must match {@link #PREFIX}:
	 * 0 home
	 * 1 ignore
	 * 2 fg file
	 * 3 bg file
	 * 4 initial length
	 * 5 flanking order
	 * 6 foreground order
	 * 7 both strands
	 * 8 initialization (best-random=,specific=,...)
	 * 9 xml file name
	 * 10 adjust motif length
	 * 11 p-val
	 * 
	 * @throws Exception if a parameter could not be created
	 */
	protected void loadParameters() throws Exception {
		parameters = new ParameterList( 12 );
		parameters.add( new SimpleParameter( DataType.STRING, "home directory", "the path to the data directory", true, "./" ) );
		parameters.add( new SimpleParameter( DataType.CHAR, "the ignore char for the data files", "the char that is used to mask comment lines in data files, e.g., '>' in a FASTA-file", true, '>' ) );
		parameters.add( new SimpleParameter( DataType.STRING, "foreground file", "the file name of the foreground data file (the file containing sequences which are expected to contain binding sites of a common motif)", true ) );
		parameters.add( new SimpleParameter( DataType.STRING, "background file", "the file name of the background data file", true ) );
		parameters.add( new SimpleParameter( DataType.INT, "initial motif length", "the motif length that is used at the beginning", true, new NumberValidator<Integer>(1,50), 15 ) );
		parameters.add( new SimpleParameter( DataType.INT, "Markov order for flanking models", "The Markov order of the model for the flanking sequence and the background sequence", true, new NumberValidator<Integer>(0,5), 0 ) );
		parameters.add( new SimpleParameter( DataType.INT, "Markov order for motif model", "The Markov order of the motif model", true, new NumberValidator<Integer>(0,3), 0 ) );
		parameters.add( new SimpleParameter( DataType.BOOLEAN, "both strands", "a switch whether to use both strands or not", true, true ) );
		parameters.add( new SimpleParameter( DataType.STRING, "initialization method", "the method that is used for initialization, one of 'best-random=<number>', 'enum=<length>', and 'specific=<sequence or file of sequence>'", true ) ); //TODO
		parameters.add( new SimpleParameter( DataType.STRING, "classifier xml-file", "the file name of the xml file containing the classifier", true, "./classifier.xml" ) );
		parameters.add( new SimpleParameter( DataType.BOOLEAN, "adjust motif length", "a switch whether to adjust the motif length, i.e., either to shrink or expand", true, true ) );
		parameters.add( new SimpleParameter( DataType.DOUBLE, "p-value", "a p-value for predicting binding sites", false, new NumberValidator<Double>(0d,1d) ) );
	}
	
	/** @return the path of the data directory */
	public String getPathHomeDirectory() {
		return (String) parameters.get( 0 ).getValue();
	}
	
	/** @return the char that marks comment lines in the data files */
	public char getIgnoreChar() {
		return (Character) parameters.get( 1 ).getValue();
	}
	
	/** @return the file name of the foreground data */
	public String getFgFileName() {
		return (String) parameters.get( 2 ).getValue();
	}
	
	/** @return the file name of the background data */
	public String getBgFileName() {
		return (String) parameters.get( 3 ).getValue();
	}
	
	/** @return the initial motif length */
	public int getMotifLength() {
		return (Integer) parameters.get( 4 ).getValue();
	}
	
	/** @return the Markov order of the flanking and background models */
	public int getFlankingOrder() {
		return (Integer) parameters.get( 5 ).getValue();
	}
	
	/** @return the Markov order of the motif model */
	public int getMotifOrder() {
		return (Integer) parameters.get( 6 ).getValue();
	}
	
	/** @return whether both strands are used */
	public boolean useBothStrands() {
		return (Boolean) parameters.get( 7 ).getValue();
	}
	
	/** @return the initialization method as <code>name=value</code> String */
	public String getInitializationMethod() {
		return (String) parameters.get( 8 ).getValue();
	}
	
	/** @return the file name for the classifier XML output */
	public String getXMLFileName() {
		return (String) parameters.get( 9 ).getValue();
	}

	/** @return whether the motif length may be adjusted during optimization */
	public boolean adjust() {
		return (Boolean) parameters.get( 10 ).getValue();
	}
	
	/** @return whether a p-value for predictions has been specified */
	public boolean hasPValue() {
		return parameters.get( 11 ).getValue() != null;
	}
	
	/** @return the p-value for predictions, or <code>null</code> if none has been specified */
	public Double getPValue() {
		return (Double) parameters.get( 11 ).getValue();
	}
	
	/**
	 * Returns one line per parameter of the form <code>tag\t... parameter</code>, separated by newlines.
	 * An empty parameter list yields the empty String.
	 */
	@Override
	public String toString() {
		int p = parameters.size();
		StringBuilder sb = new StringBuilder( 75 * p );
		for( int i = 0; i < p; i++ ) {
			if( i > 0 ) {
				sb.append( '\n' );
			}
			sb.append( PREFIX[i] ).append( "\t... " ).append( parameters.get( i ).toString() );
		}
		return sb.toString();
	}
}