/*
 *
 * This variant splits the sequence into words of size b, precomputes a
 * matrix for every possible word, and runs Viterbi's algorithm on the
 * resulting sequence of n/b words using those matrices.
 *
 * Takes 3 arguments: the number of states in the model, the word size b, and
 * the name of a file containing the DNA sequence.
 *
 * NOTE: of the Trie class, only realViterbi() is used by this variant (as the
 * plain per-symbol Viterbi reference); its LZ-trie machinery is not.
 *
 * Shay Mozes, 2007
 *
 */

#include <stdio.h>
#include <string>
#include <iostream>
#include <fstream>
#include <math.h>
#include <stdlib.h>
#include <time.h>
#include <assert.h>

using namespace std;

// Size of the DNA alphabet. Symbols are encoded as 0..ALPH_SIZE-1, and the
// value ALPH_SIZE itself marks the end of a symbol sequence.
#define ALPH_SIZE 4
//#define INFTY 9e99999


// One node of the LZ trie built by class Trie.
struct _node {
  struct _node** edges;  // ALPH_SIZE child pointers; NULL while the node has no children
  double** mat;          // model.k x model.k score matrix for this node's word (NULL until built)
  struct _node* parent;  // parent node; NULL for the root
  char c;                // symbol on the edge from the parent (ALPH_SIZE for the root)
};
typedef struct _node node;


// HMM parameters. The Viterbi code combines these scores by addition, so they
// are presumably log-probabilities (TODO confirm).
struct _hmm {
  int k;       // number of hidden states
  double** e;  // k by ALPH_SIZE emission scores
  double** t;  // k by k transition scores
};
typedef struct _hmm hmm;

//assume seq is a sequence of chars in range [0...ALPH_SIZE-1], ending with ALPH_SIZE
/*
 * Trie: an LZ78 phrase trie over a symbol sequence, together with Viterbi
 * (max-sum) routines for the HMM in `model`.
 *
 * Scores from model.e / model.t are combined by addition and maximisation, so
 * they are presumably log-probabilities (TODO confirm with the caller).
 * model.e and model.t are borrowed, not owned. Trie nodes and their matrices
 * are allocated with new by LZ()/addword()/buildMats() and are never freed by
 * this class. The score vectors v/tmpv are allocated on first use and then
 * reused by later calls.
 */
class Trie{

public:
  node root;
  hmm model;
  // Viterbi score vectors (model.k entries each). After realViterbi() or
  // Viterbi() returns, v holds the final score of every state.
  double *v,*tmpv;

  // k: number of states; e: k x ALPH_SIZE emission scores; t: k x k
  // transition scores. Starts with an empty trie and no score vectors, so
  // every member has a defined value before any other method is called.
  Trie(int k , double** e, double** t){
    model.k = k;
    model.e = e;
    model.t = t;

    root.edges = NULL;
    root.mat = NULL;
    root.parent = NULL;
    root.c = ALPH_SIZE;

    v = NULL;
    tmpv = NULL;
    scratchK = 0;
  }

  // Plain per-symbol Viterbi over seq (terminated by ALPH_SIZE), starting
  // from score 0 in every state. One step is
  //   new[j] = max_l ( old[l] + e[l][sym] + t[l][j] ).
  // Returns the best final score over all states (0 for an empty sequence,
  // -HUGE_VAL if model.k <= 0). On return, v holds the per-state scores.
  double realViterbi(char* seq) {
    const int k = model.k;
    prepareScratch();

    for (int j=0;j<k;j++) v[j] = 0;

    for (int i=0; seq[i]!=ALPH_SIZE; i++){
      const int s = seq[i];
      for (int j=0;j<k;j++){
        double best = v[0] + model.e[0][s] + model.t[0][j];
        for (int l=1;l<k;l++){
          const double cand = v[l] + model.e[l][s] + model.t[l][j];
          if (best < cand) best = cand;
        }
        tmpv[j] = best;
      }
      double* swp = v;
      v = tmpv;
      tmpv = swp;
    }

    return maxEntry(v, k);
  }


  // Parse seq into LZ78 phrases, (re)building the trie rooted at `root`.
  // Returns a new[]'d array of `num` node pointers, one per phrase in order.
  // If seq ends in the middle of a phrase, the last entry is the deepest
  // existing node matched. At most `len` phrases are recorded; the sequence
  // length is always a sufficient bound. A previous trie is not freed.
  node** LZ(char* seq , const int len, int& num){

    root.edges = new node*[ALPH_SIZE];
    for (int c=0;c<ALPH_SIZE;c++) root.edges[c]=NULL;
    root.parent = NULL;
    root.c = ALPH_SIZE;

    const int cap = (len > 0) ? len : 0;
    int i=0;
    num=0;
    node **res = new node*[cap];
    // addword() advances i past each phrase it consumes
    while (num < cap && seq[i]<ALPH_SIZE) {
      res[num] = addword(seq, i);
      num++;
    }

    std::cout << "number of LZ words is: " << num << std::endl;

    // shrink the result to exactly num entries
    node** words = new node*[num];
    for (int w=0;w<num;w++) words[w] = res[w];
    delete[] res;

    return words;
  }

  // Build the score matrix of every node in the trie. Each depth-1 node gets
  //   mat[i][j] = e[i][c] + t[i][j],
  // and deeper nodes are extended from their parent by buildMat(). Does
  // nothing if LZ() has not built a trie yet.
  void buildMats(){
    if (NULL == root.edges) return;

    for (int c=0;c<ALPH_SIZE;c++){
      node* n = root.edges[c];
      if (NULL == n) continue;

      const int s = n->c;
      n->mat = new double*[model.k];
      for (int i=0;i<model.k;i++){
        n->mat[i] = new double[model.k];   // allocate each row exactly once
        for (int j=0;j<model.k;j++){
          n->mat[i][j] = model.e[i][s] + model.t[i][j];
        }
      }

      // a leaf has edges == NULL (see addword), so check before indexing
      if (NULL != n->edges){
        for (int cc=0;cc<ALPH_SIZE;cc++){
          if (NULL != n->edges[cc]) buildMat(n->edges[cc]);
        }
      }
    }
  }

  // Viterbi over a sequence of trie nodes (e.g. the phrases returned by LZ()),
  // using each node's matrix for one large step:
  //   new[j] = max_l ( old[l] + mat[l][j] ).
  // Starts from score 0 in every state. Returns the best final score
  // (-HUGE_VAL if model.k <= 0). On return, v holds the per-state scores.
  // Requires buildMats() to have filled in mat for every node in seq.
  double Viterbi(node** seq, const int len){
    const int k = model.k;
    prepareScratch();

    for (int j=0;j<k;j++) v[j] = 0;

    for (int p=0; p<len; p++){
      double** mat = seq[p]->mat;
      for (int j=0;j<k;j++){
        double best = v[0] + mat[0][j];
        for (int l=1;l<k;l++) {
          if (best < v[l] + mat[l][j]) best = v[l] + mat[l][j];
        }
        tmpv[j] = best;
      }
      double* swp = v;
      v = tmpv;
      tmpv = swp;
    }

    return maxEntry(v, k);
  }


  // Walk from the root along seq[i], seq[i+1], ... while the trie has a
  // matching edge, then add one new child for the next symbol. Advances i
  // past every symbol consumed (including the new one) and returns the new
  // node. If the terminator (or any value >= ALPH_SIZE) is reached first,
  // returns the deepest node matched without adding anything.
  node* addword(char* seq, int& i){

    node* cur = &root;

    //find longest word already in trie
    while( seq[i]<ALPH_SIZE && cur->edges != NULL && cur->edges[seq[i]] != NULL) {
      cur = cur->edges[seq[i]];
      i++;
    }

    // >= rather than == so out-of-alphabet values never index edges[]
    if (seq[i] >= ALPH_SIZE) {
      return cur;
    }

    //add new char
    if (NULL == cur->edges){
      cur->edges = new node*[ALPH_SIZE];
      for (int c=0;c<ALPH_SIZE;c++) cur->edges[c]=NULL;
    }

    node* nd = new node;
    nd->parent = cur;
    nd->c = seq[i];
    nd->edges = NULL;
    nd->mat = NULL;
    cur->edges[seq[i]] = nd;

    i++;
    return nd;
  }


  // Build the score matrix for node n from its parent's matrix, then recurse
  // into n's children:
  //   mat[i][j] = max_l ( parent->mat[i][l] + e[l][c] + t[l][j] )
  // i.e. the best score of starting in state i, emitting the word spelled
  // from the root to n, and ending in state j. Requires n->parent->mat.
  void buildMat(node* n){
    const int k = model.k;
    const int s = n->c;
    double** pm = n->parent->mat;

    n->mat = new double*[k];
    for (int i=0;i<k;i++){
      n->mat[i] = new double[k];
      for (int j=0;j<k;j++){
        double best = pm[i][0] + model.e[0][s] + model.t[0][j];
        for (int l=1;l<k;l++){
          const double cand = pm[i][l] + model.e[l][s] + model.t[l][j];
          if (best < cand) best = cand;
        }
        n->mat[i][j] = best;
      }
    }

    //recurse on children
    if (NULL != n->edges){
      for (int c=0; c<ALPH_SIZE; c++){
        if (NULL != n->edges[c]) buildMat(n->edges[c]);
      }
    }
  }

private:
  int scratchK;  // number of doubles currently allocated in v and tmpv

  // Make v and tmpv hold model.k doubles each, reusing the existing buffers
  // when they already have that size (avoids leaking a pair per call).
  void prepareScratch() {
    const int n = (model.k > 0) ? model.k : 0;
    if (v != NULL && tmpv != NULL && scratchK == n) return;
    delete[] v;
    delete[] tmpv;
    v = new double[n];
    tmpv = new double[n];
    scratchK = n;
  }

  // Largest of a[0..n-1]; -HUGE_VAL when n <= 0.
  static double maxEntry(const double* a, int n) {
    if (n <= 0) return -HUGE_VAL;
    double best = a[0];
    for (int j=1;j<n;j++) {
      if (best < a[j]) best = a[j];
    }
    return best;
  }
};


// Generate `len` random symbols followed by the ALPH_SIZE terminator.
// Each symbol is floor(u^4 * ALPH_SIZE) with u = drand48() in [0,1), so the
// distribution is heavily skewed towards small symbols. The caller owns the
// returned buffer (delete[]).
char* randSeq(int len){
  char* out = new char[len+1];

  int i = 0;
  while (i < len) {
    const double u = drand48();
    out[i++] = static_cast<char>(u*u*u*u*ALPH_SIZE);
  }
  out[len] = ALPH_SIZE;

  return out;
}



/*
 * Read a DNA sequence from `filename` and return it encoded as
 *   A/a -> 0, C/c -> 1, G/g -> 2, T/t -> 3,
 * followed by the ALPH_SIZE terminator. The caller owns the buffer (delete[]).
 *
 * Newline characters are skipped. Any other character is reported as
 * "garble at position <p>" and dropped. Here p counts the characters read,
 * excluding newlines. On return *len is the number of bases stored, without
 * the terminator.
 *
 * If the file cannot be opened, an error goes to stderr and an empty sequence
 * (just the terminator, *len == 0) is returned. Previously this case looped
 * forever on a failed stream.
 */
char* get_seq(const char* filename, int *len){

  std::ifstream file;
  if (filename != NULL) file.open(filename);
  if (!file.is_open()) {
    std::cerr << "cannot open sequence file: "
              << (filename != NULL ? filename : "(null)") << std::endl;
    char* empty = new char[1];
    empty[0] = ALPH_SIZE;
    *len = 0;
    return empty;
  }

  // Probe the file size: used as a capacity hint and for the diagnostic line
  // (kept in its historical form, file size + 1000, for output compatibility).
  const std::streamoff begin = file.tellg();
  file.seekg(0, std::ios::end);
  const std::streamoff end = file.tellg();
  file.clear();
  file.seekg(0, std::ios::beg);
  std::cout << "length is: " << end-begin+1000 << std::endl;

  std::string bases;
  if (end > begin) bases.reserve(static_cast<std::string::size_type>(end - begin));

  unsigned int pos = 0;
  char ch;
  while (file.get(ch)) {
    if (ch == '\n') continue;
    switch (ch) {
      case 'A': case 'a': bases.push_back(static_cast<char>(0)); break;
      case 'C': case 'c': bases.push_back(static_cast<char>(1)); break;
      case 'G': case 'g': bases.push_back(static_cast<char>(2)); break;
      case 'T': case 't': bases.push_back(static_cast<char>(3)); break;
      default: std::cout << "garble at position " << pos << std::endl;
    }
    pos++;
  }

  const int l = static_cast<int>(bases.size());
  char* iseq = new char[l+1];
  bases.copy(iseq, bases.size());
  iseq[l] = ALPH_SIZE;
  *len = l;

  return iseq;
}


/*
 * Plain Viterbi (max-sum) over seq, starting from the scores already in v.
 *
 * seq  : symbols in [0, ALPH_SIZE), terminated by ALPH_SIZE.
 * v    : in/out, k entries. On entry the initial score of each state; on
 *        return the final score of each state.
 * tmpv : scratch, k entries (overwritten).
 * e, t : k x ALPH_SIZE emission and k x k transition scores.
 *
 * One step: new[j] = max_l ( old[l] + e[l][sym] + t[l][j] ).
 *
 * The two buffers alternate between steps, so after an odd number of steps
 * the latest scores are in the caller's tmpv. They are copied back into v so
 * the result is always in v; in that case tmpv also still holds them.
 */
void tmpViterbi(char* seq, double* v, double* tmpv, int k, double** e, double** t) {

  double* const out = v;   // the caller's result buffer
  double* cur = v;
  double* next = tmpv;

  for (int i = 0; seq[i] != ALPH_SIZE; i++) {
    const int s = seq[i];
    for (int j=0;j<k;j++){
      double best = cur[0] + e[0][s] + t[0][j];
      for (int l=1;l<k;l++){
        const double cand = cur[l] + e[l][s] + t[l][j];
        if (best < cand) best = cand;
      }
      next[j] = best;
    }

    double* swp = cur;
    cur = next;
    next = swp;
  }

  if (cur != out) {
    for (int j=0;j<k;j++) out[j] = cur[j];
  }
}

/*
 * Pack each consecutive word of b symbols of seq into one integer code.
 *
 * Word w (symbols seq[w*b] .. seq[w*b+b-1]) is read as a base-ALPH_SIZE
 * number whose FIRST symbol is the least significant digit:
 *   res[w] = sum_i seq[w*b+i] * ALPH_SIZE^i
 * This is the same order in which main enumerates the precomputed b-mer
 * matrices (incrementing the first symbol fastest), so res[w] indexes them.
 *
 * Only complete words are encoded; res needs room for len/b entries. A
 * trailing partial word is ignored. Nothing is written when b <= 0.
 */
void compress(char* seq, int len,  int b, int* res){
  if (b <= 0) return;

  for (int ind=0; ind <= len - b; ind+=b){
    // Horner's rule, from the most significant (last) symbol down
    int code = 0;
    for (int i = b-1; i >= 0; i--) {
      code = code*ALPH_SIZE + seq[ind+i];
    }
    res[ind/b] = code;
  }
}
    
    

/*
 * Viterbi (max-sum) over a sequence of word codes using precomputed matrices.
 *
 * seq  : len word codes; each must be a valid index into mats.
 * mats : mats[w] is a k x k matrix; mats[w][i][j] is the best score of being
 *        in state i before word w and in state j after it.
 * k    : number of states.
 *
 * One step: new[j] = max_l ( old[l] + mats[w][l][j] ).
 * Every state starts with score 0 (uniform initial distribution). Returns the
 * best final score over all states: 0 when len == 0, -HUGE_VAL when k <= 0.
 */
double modViterbi(int* seq, const int len, double *** mats, int k){
  if (k <= 0) return -HUGE_VAL;

  double* v = new double[k];
  double* tmpv = new double[k];

  //assume uniform initial state dist.
  for (int i=0;i<k;i++) v[i] = 0;

  for (int p=0; p<len; p++){
    //do one large viterbi step
    double** mat = mats[seq[p]];
    for (int j=0;j<k;j++){
      double best = v[0] + mat[0][j];
      for (int l=1;l<k;l++) {
        if (best < v[l] + mat[l][j]) best = v[l] + mat[l][j];
      }
      tmpv[j] = best;
    }

    double* swp = v;
    v = tmpv;
    tmpv = swp;
  }

  double res = v[0];
  for (int j=1;j<k;j++) {
    if (res<v[j]) res = v[j];
  }

  // release the scratch buffers allocated above
  delete[] v;
  delete[] tmpv;
  return res;
}






int main(int argc, char* argv[] ){
  cout <<"split into bmers" << endl;
  
  int len;
  char* seq =  get_seq(argv[3],&len);

  seq[len] = ALPH_SIZE;


  cout << "read " << len << " bases." << endl;


  time_t tmp;
  time(&tmp); 
  srand48(tmp); 
  

  int k=atoi(argv[1]);
  cout << "k = " << k << endl;
  double **ee = new double*[k];
  double **ttt = new double*[k];

  for (int i=0; i<k; i++) {
    ee[i] = new double[ALPH_SIZE];
    ttt[i] = new double[k];

    for (int j=0;j<4;j++){
      //      ee[i][j] = (double) ((int)(100*drand48()))/128.0;
      ee[i][j] = -1.0*drand48();
    }
    for (int j=0;j<k;j++){
      //      ttt[i][j] = (double) ((int)(100*drand48()))/128.0;
      ttt[i][j] = -1.0*drand48();
    }
  }

  //  cout << "in here!" << endl;

  clock_t cl1 = clock();

  int b = atoi(argv[2]);
  cout << "b is : " << b << endl;
  len = b*(len/b);
  cout << "len is now: " << len << endl;
  seq[len] = ALPH_SIZE;

 
  //construct precomputed matrices

  int nmats = (int) pow(ALPH_SIZE,b);
  double*** mats = new double**[nmats];
  double *v = new double[k];
  double *tmpv = new double[k];
  for (int i=0; i<k; i++) v[i] = -99999;

  char*  tmpvec = new char[b+1];
  for (int i=0; i<b; i++) tmpvec[i] = 0;
  tmpvec[b] = ALPH_SIZE;

  int cur=0;
  int ind=0;
  while (tmpvec[b]==ALPH_SIZE) {


    mats[cur] = new double*[k];
    for (int i=0;i<k;i++) {
      v[i] = 0;
      mats[cur][i] = new double[k];

      tmpViterbi(tmpvec, v, tmpv, k, ee, ttt);
      for (int j=0; j<k;j++) mats[cur][i][j] = v[j];
      v[i] = -9999;
    }

    tmpvec[0]++;
    ind = 0;
    while (ind<b && tmpvec[ind]==ALPH_SIZE){
      tmpvec[ind] = 0;
      tmpvec[ind+1]++;
      ind++;
    }
    cur++;
  }

  clock_t cl2 = clock();

  //compress seq;
  int* cseq = new int[len/b];
  compress(seq,len,b, cseq);
  cout << "compressed length: " << len/b << endl;

  
  clock_t cl3 = clock();

  //run modified viterbi
  double ourvit, realvit;
  ourvit = modViterbi(cseq, len/b, mats, k);

  clock_t cl4 = clock();

  Trie* t = new Trie(k,ee,ttt);
  realvit = t->realViterbi(seq);
  clock_t cl5 = clock();

  
  cout << "buildmat time: " << cl2 - cl1 << endl;
  cout << "compress time: " << cl3 - cl2 << endl;
  cout << "our viterbi returned: " << ourvit << endl;
  cout << "our time: " << cl4 - cl3 << endl;
  
  cout << "real viterbi returned: " << realvit << endl;
  cout << "real viterbi time: " << cl5-cl4 << endl;
  cout << "ratio is: " << (cl4-cl3+0.0)/(cl5-cl4) << endl;


}


