/*
 *
 * This variant computes the LZ78 trie of a sequence. It keeps only LZ words
 * that appear more than a specified number of times (good words)
 * and uses just those to compress the text. To do that, each "bad" LZ word is
 * decomposed into good words. 
 * It then computes the matrices in DFS order on the LZ trie.
 * Finally, it runs Viterbi's algorithm on the compressed sequence,
 * using the matrices.
 *
 * Takes 3 arguments: the number of states in the HMM, the occurrence
 * threshold (a word is good if it occurs more than this many times), and
 * the name of the file holding the DNA sequence.
 *
 * Shay Mozes, 2007
 *
 */
//#include "trie.h"
#include <stdio.h>
#include <string>
#include <iostream>
#include <fstream>
#include <math.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <assert.h>
#include <new>

using namespace std;

// Number of symbols in the alphabet (get_seq encodes A, C, G, T as 0..3).
// The value ALPH_SIZE itself is used as the end-of-sequence marker.
#define ALPH_SIZE 4
// Large sentinel value; not referenced by the code in this file.
#define INFTY 9e99


/*
 * A node of the LZ78 trie. Trie::addword / Trie::LZ fill the fields in.
 * The constructor gives every field a defined value, so a node that has not
 * been filled in yet is never read uninitialized (e.g. Trie::root before
 * LZ() runs).
 */
typedef struct _node{
  struct _node** edges;  // ALPH_SIZE child pointers, or NULL while the node has no children
  double** mat;          // k x k score matrix, NULL until built by Trie::buildMat(s)
  struct _node* parent;  // NULL for the root
  char c;                // symbol on the edge from the parent to this node
  int size;              // occurrence counter maintained by Trie::addword

  _node() : edges(NULL), mat(NULL), parent(NULL), c(0), size(0) {}
} node;


/*
 * Parameters of a k-state HMM. The Viterbi code combines scores with + and
 * max, so they are presumably log-probabilities (main fills them with values
 * in (-1, 0]). Confirm this if real models are plugged in.
 */
struct _hmm {
  int k;        // number of states
  double** e;   // e[state][symbol]: k x ALPH_SIZE emission scores
  double** t;   // t[from][to]: k x k transition scores
};
typedef struct _hmm hmm;

//assume seq is a sequence of chars in range [0...ALPH_SIZE-1], ending with ALPH_SIZE
class Trie{

public:
  node root;
  hmm model;
  double *v,*tmpv;
  int thresh;
  int ngood;


  Trie(int k , double** e, double** t, int tr){
    model.k = k;
    model.e = e;
    model.t = t;
    root.size = 0;
    thresh = tr;
    ngood = 0;

  };

  double realViterbi(char* seq) {
    int k = model.k;

    v = new double[model.k];
    tmpv = new double[model.k];
    double* tmp;

    int i;
    for (i=0; i<k;i++) v[i] = 0;

    i=0;
    while (seq[i]!=ALPH_SIZE){

      for (int j=0;j<k;j++){
	tmpv[j] = v[0] + model.e[0][seq[i]] + model.t[0][j];
	for (int l=1;l<k;l++){
	  if (tmpv[j] < v[l] + model.e[l][seq[i]] + model.t[l][j]){
	    tmpv[j] = v[l] + model.e[l][seq[i]] + model.t[l][j];
	  }
	}   
      } 

      tmp = v;
      v = tmpv;
      tmpv = tmp;
      ///      for (int j=0;j<k;j++) printf("%5.3g ",v[j]);
      ///cout << endl;
      i++;
    }

    double res = v[0];
    for(int j=1;j<k;j++){
      if (res < v[j]) res = v[j];
    }

    return(res);
    
  }


  //build LZ from seq
  node** LZ(char* seq , const int len, int& num){

    int i=0,l=0;
    root.edges = new node*[ALPH_SIZE];
    for (;i<ALPH_SIZE;i++) root.edges[i]=NULL;
    root.parent = NULL;
    root.c = ALPH_SIZE;

    i=0;
    num=0;
    node* nd=NULL;
    node **res = new node*[len];
    while (seq[i]<ALPH_SIZE) {
      res[num] = addword(seq, i);
      num ++;
      ///      for (int j=0;j<model.k;j++) printf("%5.3g ",v[j]);
      ///cout << endl;
    }
    cout << "number of LZ words is: " << num << endl;

    //go over res, convert each node that is not good into a bunch of shorter good nodes
    node** newres = new node*[len];
    int j=0;
    char* str = new char[2000];
    for (i=0;i<num;i++){

      if (res[i]->size > thresh) {
	//	cout << "." << flush;
	newres[j] = res[i];
	j++;
      }
      else{
	//	cout << "+" << flush;

	//traceback string
	l=0;
	nd = res[i];
	while (nd->size <= thresh){
	  //	  cout << "x" << flush;
	  str[l] = nd->c;
	  l++;
	  nd = nd->parent;
	}

	//	cout << "found prefix" << endl;

	//encode good prefix
	newres[j] = nd;
	j++;

	l--;
	//	cout << "l is: " << l << endl;
	nd = &root;
	//encode reversed suffix
	while (l>=0){
	  if ((NULL != nd->edges[str[l]]) && (nd->edges[str[l]]->size > thresh)){
	    nd = nd->edges[str[l]];
	    l--;
	  }
	  else {
	    newres[j] = nd;
	    j++;
	    nd = &root;
	  }
	}
	newres[j] = nd;
	j++;
      }
    }

    num = j;

    delete[] res;


    node**tmp = new node*[num];
    for (i=0;i<num;i++) tmp[i] = newres[i];

    delete[] newres;

    return tmp;
  }

  void buildMats(){

    for (int c=0;c<ALPH_SIZE;c++){

      if (NULL == root.edges[c]) continue;
      node* n = root.edges[c];
      
      //build Matrix for child
      n->mat = new double*[model.k];
      for (int i=0;i<model.k;i++){
	n->mat[i] = new double[model.k];
	for (int j=0;j<model.k;j++){
	  n->mat[i][j] = model.e[i][n->c] + model.t[i][j];
	}
      }

      ///      cout << "in buildMats, char is: " << (char) (n->c+'0') << endl;
      
      //recurse on child's childrens.
      for (int cc=0;cc<ALPH_SIZE;cc++){
	if (NULL != n->edges[cc]) buildMat(n->edges[cc]);
      }
    }
    
  }

  double Viterbi(node** seq, const int len){
    int p,j,l,i;
    int k = model.k;
    v = new double[k];
    tmpv = new double[k];

    //assume uniform initial state dist.
    for (i=0;i<model.k;i++){
      v[i] = 0;
    }

    double**mat;
    double *tmp;

    for (p=0; p<len; p++){
      //do one large viterbi step
      mat = seq[p]->mat;
      for (j=0;j<k;j++){
	tmpv[j] = v[0] + mat[0][j];
	for (l=1;l<k;l++) {
	  if (tmpv[j] < v[l] + mat[l][j]) {
	    tmpv[j] = v[l] + mat[l][j];
	  }
	}
      }
      
      tmp = v;
      v = tmpv;
      tmpv = tmp;
    }

    double res = v[0];
    for (j=1;j<k;j++) {
      if (res<v[j]) res = v[j];
    }
    return res;
  };


  node* addword(char* seq, int& i){

    node* cur = &root;
    
    //find longest word already in trie
    while( seq[i]<ALPH_SIZE && cur->edges != NULL && cur->edges[seq[i]] != NULL) {
      cur->size++;
      cur = cur->edges[seq[i]];
      ///cout << (int) seq[i] << flush;
      i++;
    }

    if (seq[i] == ALPH_SIZE) {
      return cur;
    }

    //add new char
    if (NULL == cur->edges){
      cur->edges = new node*[ALPH_SIZE];
      for (int j=0;j<ALPH_SIZE;j++) cur->edges[j]=NULL;
    }  
    cur->edges[seq[i]] = new node;

    node* nd = cur->edges[seq[i]];
    nd->parent = cur;
    nd->c = seq[i];
    nd->size = 1;

    nd->edges = NULL;
    nd->mat = NULL;

    //    buildMat(*nd);
    i++;
    return nd;


    ///cout << (int) seq[i] << endl;
      ///cout << "added char " << (char) ('0'+nd->c) << endl;
    ///for (int j=0;j<k;j++) printf("%5.3g ",v[j]);
    ///cout << endl;

  };



   //build MPP matrices for a node in the trie
  //mat[i][j] = prob to start in state i, emit the seq associated
  // with the node and end in the state j.
  void buildMat(node* n){

    ///    cout << "in buildmat node char is: " << (char) ('0'+n->c) << endl;

    int i;

    if (n->size<thresh) return;
    ngood++;

    //build matrix for current node
    int k= model.k;
    n->mat = new double*[k];
    for (i=0;i<k;i++){
      n->mat[i] = new double[k];
      for (int j=0;j<k;j++){
	n->mat[i][j] = n->parent->mat[i][0] + model.e[0][n->c] + model.t[0][j];
	for (int l=1;l<k;l++){
	  if (n->mat[i][j] < n->parent->mat[i][l] + model.e[l][n->c] + model.t[l][j])
	    n->mat[i][j] = n->parent->mat[i][l] + model.e[l][n->c] + model.t[l][j];
	}
      }
    }

    //recurse on children
    if (NULL != n->edges){
      for (int c=0; c<ALPH_SIZE; c++){
	if (NULL != n->edges[c]) buildMat(n->edges[c]);
      }
    }
  }

  
};


/*
 * Returns a new[]-allocated random sequence of len symbols followed by the
 * ALPH_SIZE terminator. Each symbol is (int)(u^4 * ALPH_SIZE) for u drawn
 * from drand48(), so the distribution is skewed toward symbol 0.
 */
char* randSeq(int len){
  char* out = new char[len + 1];

  int pos = 0;
  while (pos < len) {
    const double u = drand48();
    const double skewed = u * u * u * u;   // in [0, 1), biased toward 0
    out[pos] = (char) (skewed * ALPH_SIZE);
    ++pos;
  }

  out[len] = ALPH_SIZE;  // end-of-sequence marker
  return out;
}

/*
 * Reads a DNA sequence from a text file and encodes it for the Trie:
 * A/a -> 0, C/c -> 1, G/g -> 2, T/t -> 3.
 * Line breaks are dropped. Any other character is reported as "garble" and
 * skipped, so a FASTA header line produces garble messages.
 *
 * filename: path of the file to read; the program exits if it cannot be opened.
 * len:      out-parameter, receives the number of encoded bases.
 * returns:  a new[]-allocated array of *len + 1 chars.
 *           - Entries 0..*len-1 hold the encoded bases.
 *           - Entry *len is a spare slot for the caller's end-of-sequence
 *             marker (main writes seq[len] = ALPH_SIZE).
 */
char* get_seq(const char* filename, int *len){

  std::ifstream file(filename);
  if (!file) {
    std::cerr << "cannot open file " << filename << std::endl;
    std::exit(EXIT_FAILURE);
  }

  //find size of file
  int begin = (int) file.tellg();
  file.seekg (0, std::ios::end);
  int end = (int) file.tellg();

  //allocate raw buffer: the whole file always fits, so getline never truncates
  char* seq = new (std::nothrow) char[end-begin+1000];
  assert(seq!=NULL);
  std::cout << "length is: " << end-begin+1000 << std::endl;

  const int cap = end-begin+999;
  int cnt=0;
  file.seekg(begin);
  // Stop at the first failed read (EOF or error) instead of polling eof().
  // eof() never becomes true on a stream that failed for another reason.
  while (cnt < cap && file.getline(seq + cnt, cap - cnt)) {
    cnt += (int) std::strlen(seq+cnt);
  }

  // one extra slot so the caller can append the end-of-sequence marker
  char* iseq = new (std::nothrow) char[cnt + 1];
  assert(iseq!=NULL);

  int l=0;
  for (int i=0;i<cnt;i++) {
    if (seq[i]=='A' || seq[i]=='a') iseq[l++]=0;
    else if (seq[i]=='C' || seq[i]=='c') iseq[l++]=1;
    else if (seq[i]=='G' || seq[i]=='g') iseq[l++]=2;
    else if (seq[i]=='T' || seq[i]=='t') iseq[l++]=3;
    else std::cout << "garble at position " << i << std::endl;
  }
  delete[] seq;

  *len = l;

  return iseq;
}

/*
char* get_seq(const char* filename, int *len){

  ifstream file(filename);

  char stam[1000];

  //skip first line
  file.getline((char*) &stam,1000);

  //find size of file
  int begin = (int) file.tellg();
  file.seekg (0, ios::end);
  int end = (int) file.tellg();
  
  //allocate and read seq
  char* seq = new char[end-begin+1];

  int cnt=0;
  *len = end-begin;
  file.seekg(begin);
  while (! file.eof()) {    
    file.getline(seq + cnt, *len-cnt);
    ///    cout << seq+cnt << endl;
    cnt += strlen(seq+cnt);
  }

  *len = cnt;

  for (int i=0;i<cnt;i++) {
    if (seq[i]=='A') seq[i]=0;
    else if (seq[i]=='C') seq[i]=1;
    else if (seq[i]=='G') seq[i]=2;
    else if (seq[i]=='T') seq[i]=3;
    else cout << "garble at position " << i << endl;
  }

  seq[cnt] = 4;

  return seq;
}
*/

int main(int argc, char* argv[] ){

  cout << "LZ separately, build mat by DFS, calculate viterbi as you along compressed seq." << endl;

  //  char seq[] = {0,1,0,0,0,0,0,1,0,1,0,1,4};

  
  int len;
  //  char* seq =  get_seq("yeast/chr4.fa",&len);
  char* seq =  get_seq(argv[3],&len);

  for (int i=0;i<100;i++)
    cout << (int) seq[i] ;
  cout << endl;

  //  len = 500000;
  seq[len] = ALPH_SIZE;


  cout << "read " << len << " bases." << endl;


  time_t tmp;
  time(&tmp); 
  srand48(tmp); 
  
  /*
  int len = 1000000;
  //char* seq = randSeq(100);
  char* seq = new char[len+1];
  for (int i=0;i<len;i++) seq[i] = 0;
  seq[len] = ALPH_SIZE;

  for (int i=0;i<15;i++) cout << (char)('0'+seq[i]);
  cout << endl;
  */


  /*
  double e[3][4];
  double tt[3][3];

  e[0][0] = e[0][1] = log2(0.1);
  e[0][2] = e[0][3] = log2(0.4);

  e[1][0] = e[1][1] = log2(0.4);
  e[1][2] = e[1][3] = log2(0.1);

  e[2][0] = e[2][3] = log2(0.4);
  e[2][2] = e[2][1] = log2(0.1);

  e[3][0] = e[3][2] = log2(0.4);
  e[3][3] = e[3][1] = log2(0.1);

  tt[0][0] = tt[0][3] = log2(0.49); tt[0][1] = tt[0][2] = log2(0.01);
  tt[1][0] = tt[1][2] = log2(0.01); tt[1][1] = tt[1][3] = log2(0.49);
  tt[2][2] = tt[2][1] = log2(0.49); tt[2][0] = tt[2][3] = log2(0.01);
  tt[3][3] = tt[3][2] = log2(0.49); tt[3][1] = tt[3][0] = log2(0.01);
  */
 
 
  int k=atoi(argv[1]);
  cout << "k = " << k << endl;
  double **ee = new double*[k];
  double **ttt = new double*[k];

  for (int i=0; i<k; i++) {
    ee[i] = new double[ALPH_SIZE];
    ttt[i] = new double[k];

    for (int j=0;j<4;j++){
      //      ee[i][j] = (double) ((int)(100*drand48()))/128.0;
      ee[i][j] = -1.0*drand48();
    }
    for (int j=0;j<k;j++){
      //      ttt[i][j] = (double) ((int)(100*drand48()))/128.0;
      ttt[i][j] = -1.0*drand48();
    }
  }


  Trie* t = new Trie(k,ee,ttt, atoi(argv[2]));

  node** lzseq;
  int lzlen;

  double ourvit, realvit;
  
  clock_t cl1 = clock();
  lzseq = t->LZ(seq, len,lzlen);
  clock_t cl2 = clock();
  cout << "after LZ, len is: " << lzlen << endl;
  t->buildMats();
  cout << "number of good mats is: " << t->ngood << endl;
  clock_t cl3 = clock();
  ourvit = t->Viterbi(lzseq,lzlen);
  
  clock_t cl4 = clock();
  realvit = t->realViterbi(seq);
  clock_t cl5 = clock();

  
  cout << "LZ time: " << cl2 - cl1 << endl;
  cout << "builmat time: " << cl3 - cl2 << endl;
  cout << "our viterbi returned: " << ourvit << endl;
  cout << "our time: " << cl4 - cl3 << endl;
  
  cout << "real viterbi returned: " << realvit << endl;
  cout << "real viterbi time: " << cl5-cl4 << endl;
  
  double ratio = (lzlen+0.0)/len;
  cout << "compression ratio is: " << ratio << endl;
  cout << "theoretical ratio of ours to real (k*ratio): " << k*ratio << endl;
  cout << "actual times ratio is: " << (cl4-cl2+0.0) / (cl5-cl4) << endl;
  
}


