/*
 *
 * This variant computes the LZ78 trie of a sequence and computes the
 * matrices corresponding to the LZ words.
 * It then runs Viterbi's algorithm on the compressed sequence
 * using the matrices.
 *
 * Change the parameters manually in the main function.
 *
 * Shay Mozes, 2007
 *
 */
//#include "trie.h"

#include <stdio.h>
#include <string>
#include <iostream>
#include <fstream>
#include <math.h>
#include <stdlib.h>
#include <time.h>

using namespace std;

#define ALPH_SIZE 4
#define INFTY 9e99


// One node of the LZ78 trie. The path from the root to a node spells one
// LZ word; `mat` is the matrix that advances the Viterbi vector over that
// whole word in a single step.
typedef struct _node{
  struct _node** edges;   // ALPH_SIZE child pointers, or NULL while the node has no children
  double** mat;           // k by k matrix of the node's word (never built for the root)
  struct _node* parent;   // NULL for the root
  char c;                 // symbol on the edge from the parent (ALPH_SIZE for the root)
} node;


// Model parameters. All scores are combined by addition, so they are
// expected to be log-space (or otherwise additive) scores.
typedef struct _hmm{
  int k;       // number of states
  double** e;  // k by Alph_Size emission probs
  double** t;  // k by k transition probs
} hmm;

//assume seq is a sequence of chars in range [0...ALPH_SIZE-1], ending with ALPH_SIZE
//
// Any char outside [0...ALPH_SIZE-1] is treated as the end of the sequence,
// by Viterbi(), addword() and realViterbi() alike.
//
// The object owns the trie nodes, their matrices and the two score vectors
// v / tmpv; it does NOT own the model matrices e and t passed to the
// constructor. Requires k >= 1.
class Trie{

public:
  node root;
  hmm model;
  double *v,*tmpv;   // current score vector and scratch vector (length k)


  // Stores the model (k states, emission scores e, transition scores t).
  // e and t are borrowed and must outlive the Trie.
  Trie(int k , double** e, double** t){
    model.k = k;
    model.e = e;
    model.t = t;

    // Start from a well-defined empty state so that every later
    // delete/NULL check is valid.
    root.edges = NULL;
    root.mat = NULL;
    root.parent = NULL;
    root.c = ALPH_SIZE;
    v = NULL;
    tmpv = NULL;
  }

  // Releases the trie, all word matrices and the score vectors.
  ~Trie(){
    clearTrie();
    delete[] v;
    delete[] tmpv;
  }

  // Plain Viterbi: one step per symbol of seq.
  // Returns the maximum entry of the final score vector (0 for an empty
  // sequence, because the initial vector is all zeros).
  double realViterbi(char* seq) {
    const int k = model.k;
    resetVectors();

    for (int i=0; isSymbol(seq[i]); i++){
      const int s = seq[i];

      for (int j=0;j<k;j++){
        double best = v[0] + model.e[0][s] + model.t[0][j];
        for (int l=1;l<k;l++){
          const double cand = v[l] + model.e[l][s] + model.t[l][j];
          if (best < cand) best = cand;
        }
        tmpv[j] = best;
      }

      swapVectors();
    }

    return maxEntry(v);
  }



  //build LZ from seq
  //
  // Parses seq into LZ78 words, building the trie and one matrix per word,
  // and advances the score vector by one matrix step per word. Any trie left
  // over from an earlier call is discarded first. Prints the number of LZ
  // words to stdout and returns the maximum entry of the final score vector
  // (the same value realViterbi() computes).
  double Viterbi(char* seq){
    clearTrie();
    root.edges = new node*[ALPH_SIZE];
    for (int a=0;a<ALPH_SIZE;a++) root.edges[a]=NULL;
    root.mat = NULL;
    root.parent = NULL;
    root.c = ALPH_SIZE;

    //assume uniform initial state dist.
    resetVectors();

    int i=0;
    int num=0;
    while (isSymbol(seq[i])) {
      addword(seq, i);   // always consumes at least one symbol here
      num++;
    }

    std::cout << "number of LZ words is: " << num << std::endl;

    // v always holds the final vector, including for an empty sequence.
    return maxEntry(v);
  }


  // Consumes the next LZ word starting at seq[i] and advances i past it.
  //
  // Normally the word is the longest prefix already in the trie plus one new
  // symbol; a node and matrix are created for it and v is advanced by that
  // matrix. If the sequence ends while still inside a known word, v is
  // advanced by that word's existing matrix instead.
  //
  // Returns the maximum entry of v if the end of the sequence was reached,
  // and 0 otherwise.
  double addword(char* seq, int& i){
    if (v == NULL) resetVectors();   // allow a direct call before Viterbi()

    node* cur = &root;

    //find longest word already in trie
    while (isSymbol(seq[i]) && cur->edges != NULL && cur->edges[(int) seq[i]] != NULL) {
      cur = cur->edges[(int) seq[i]];
      i++;
    }

    if (!isSymbol(seq[i])) {
      // The sequence ended inside a known word. The root has no matrix, so
      // if nothing was consumed there is no step to take.
      if (cur != &root) step(cur->mat);
      return maxEntry(v);
    }

    //add new char
    if (NULL == cur->edges){
      cur->edges = new node*[ALPH_SIZE];
      for (int j=0;j<ALPH_SIZE;j++) cur->edges[j]=NULL;
    }

    // `new node` leaves the members indeterminate, so set every field before
    // the node becomes reachable from the trie.
    node* nd = new node;
    nd->edges = NULL;
    nd->mat = NULL;
    nd->parent = cur;
    nd->c = seq[i];
    cur->edges[(int) seq[i]] = nd;
    buildMat(*nd);

    //do one large viterbi step
    step(nd->mat);

    i++;
    if (!isSymbol(seq[i])) return maxEntry(v);

    return 0;
  }



  //build MPP matrices for a node in the trie
  //mat[i][j] = prob to start in state i, emit the seq associated
  // with the node and end in the state j.
  //
  // Does nothing for the root (a node without a parent). For a child of the
  // root the matrix is the single-symbol step matrix; for deeper nodes it is
  // the max-plus product of the parent's matrix with the step matrix of the
  // node's symbol.
  void buildMat(node& n){
    if (n.parent == NULL) return;

    const int k = model.k;
    const int s = n.c;
    const bool firstLevel = (n.parent->parent == NULL);

    n.mat = new double*[k];
    for (int i=0;i<k;i++){
      n.mat[i] = new double[k];
      for (int j=0;j<k;j++){
        if (firstLevel){
          n.mat[i][j] = model.e[i][s] + model.t[i][j];
          continue;
        }
        const double* prow = n.parent->mat[i];
        double best = prow[0] + model.e[0][s] + model.t[0][j];
        for (int l=1;l<k;l++){
          const double cand = prow[l] + model.e[l][s] + model.t[l][j];
          if (best < cand) best = cand;
        }
        n.mat[i][j] = best;
      }
    }
  }

private:
  // The class owns raw memory, so copying is disabled (declared, not defined).
  Trie(const Trie&);
  Trie& operator=(const Trie&);

  // True iff ch is a real symbol, i.e. in [0...ALPH_SIZE-1]. The unsigned
  // cast makes negative chars count as "not a symbol" as well.
  static bool isSymbol(char ch){
    return static_cast<unsigned char>(ch) < ALPH_SIZE;
  }

  // (Re)allocates v and tmpv and sets v to the all-zero initial vector.
  void resetVectors(){
    delete[] v;
    delete[] tmpv;
    v = new double[model.k];
    tmpv = new double[model.k];
    for (int s=0;s<model.k;s++) v[s] = 0;
  }

  void swapVectors(){
    double* hold = v;
    v = tmpv;
    tmpv = hold;
  }

  // One max-plus step: v[j] <- max over l of (v[l] + mat[l][j]).
  void step(double** mat){
    const int k = model.k;
    for (int j=0;j<k;j++){
      double best = v[0] + mat[0][j];
      for (int l=1;l<k;l++){
        const double cand = v[l] + mat[l][j];
        if (best < cand) best = cand;
      }
      tmpv[j] = best;
    }
    swapVectors();
  }

  // Largest entry of a length-k vector.
  double maxEntry(const double* vec) const {
    double best = vec[0];
    for (int j=1;j<model.k;j++){
      if (best < vec[j]) best = vec[j];
    }
    return best;
  }

  // Frees n and everything below it. Recursion depth equals the length of
  // the longest LZ word, which stays small compared to the sequence length.
  void freeSubtree(node* n){
    if (n == NULL) return;
    if (n->edges != NULL){
      for (int a=0;a<ALPH_SIZE;a++) freeSubtree(n->edges[a]);
      delete[] n->edges;
    }
    if (n->mat != NULL){
      for (int r=0;r<model.k;r++) delete[] n->mat[r];
      delete[] n->mat;
    }
    delete n;
  }

  // Frees every node below the root and leaves the root without children.
  void clearTrie(){
    if (root.edges == NULL) return;
    for (int a=0;a<ALPH_SIZE;a++) freeSubtree(root.edges[a]);
    delete[] root.edges;
    root.edges = NULL;
  }

};

// Allocates a sequence of `len` pseudo-random symbols followed by the
// terminator ALPH_SIZE, and returns it. Each symbol is x^4 * ALPH_SIZE
// truncated to a char, where x comes from drand48(); raising x to the fourth
// power biases the draw toward the small symbols. The array is allocated
// with new[] and handed to the caller.
char* randSeq(int len){

  char* result = new char[len+1];

  int pos = 0;
  while (pos < len){
    const double draw = drand48();
    result[pos] = (char) (draw*draw*draw*draw*ALPH_SIZE);
    ++pos;
  }
  result[len] = ALPH_SIZE;

  return result;
}


// Reads a FASTA-style file: the first line is skipped and the remaining
// lines are concatenated into one sequence.
//
// 'A','C','G','T' are translated to 0,1,2,3. Line breaks ('\n' and '\r') are
// dropped. Any other character is reported on stdout as "garble at position
// <index>" and kept unchanged in the sequence. The sequence is terminated
// with the value 4.
//
// filename: path of the file to read.
// len:      receives the number of symbols stored before the terminator.
//
// Returns an array allocated with new[] (the caller releases it with
// delete[]). If the file cannot be opened or read, an error is written to
// stderr, *len is set to 0 and an array holding only the terminator is
// returned. Sequences longer than an int can count are not supported.
char* get_seq(const char* filename, int *len){

  const char kTerminator = 4;   // end marker; must equal the alphabet size

  *len = 0;
  std::ifstream file(filename);

  //skip first line
  std::string header;
  std::getline(file, header);

  //find size of file
  const std::ifstream::pos_type beginPos = file.tellg();
  file.seekg (0, std::ios::end);
  const std::ifstream::pos_type endPos = file.tellg();

  const std::ifstream::pos_type invalid(-1);
  if (file.fail() || beginPos == invalid || endPos == invalid
      || (endPos - beginPos) < 0) {
    std::cerr << "get_seq: cannot read sequence from " << filename << std::endl;
    char* empty = new char[1];
    empty[0] = kTerminator;
    return empty;
  }

  //allocate and read seq
  const std::streamoff size = endPos - beginPos;
  char* seq = new char[static_cast<size_t>(size) + 1];

  // Read the whole remainder in one go; gcount() reports how much actually
  // arrived, so a short read cannot leave unread garbage in the buffer.
  file.seekg(beginPos);
  file.read(seq, static_cast<std::streamsize>(size));
  const std::streamsize got = file.gcount();

  // Translate and compact in place: cnt never runs ahead of r.
  int cnt = 0;
  for (std::streamsize r = 0; r < got; r++) {
    const char ch = seq[r];
    if (ch == '\n' || ch == '\r') continue;

    char code = ch;
    if (ch=='A') code=0;
    else if (ch=='C') code=1;
    else if (ch=='G') code=2;
    else if (ch=='T') code=3;
    else std::cout << "garble at position " << cnt << std::endl;

    seq[cnt] = code;
    cnt++;
  }

  seq[cnt] = kTerminator;
  *len = cnt;

  return seq;
}


int main(){

  //  char seq[] = {0,1,0,0,0,0,0,1,0,1,0,1,4};

  int len;
  char* seq =  get_seq("chromFa/chr4.fa",&len);

  for (int i=0;i<100;i++)
    cout << (int) seq[i] ;
  cout << endl;


  cout << "read " << len << " bases." << endl;

  time_t tmp;
  time(&tmp); 
  srand48(tmp); 
  /*
  char* seq = randSeq(10000000);

  for (int i=0;i<15;i++) cout << (char)('0'+seq[i]);
  cout << endl;
  */


  /*
  double e[3][4];
  double tt[3][3];

  e[0][0] = e[0][1] = log2(0.1);
  e[0][2] = e[0][3] = log2(0.4);

  e[1][0] = e[1][1] = log2(0.4);
  e[1][2] = e[1][3] = log2(0.1);

  e[2][0] = e[2][3] = log2(0.4);
  e[2][2] = e[2][1] = log2(0.1);

  e[3][0] = e[3][2] = log2(0.4);
  e[3][3] = e[3][1] = log2(0.1);

  tt[0][0] = tt[0][3] = log2(0.49); tt[0][1] = tt[0][2] = log2(0.01);
  tt[1][0] = tt[1][2] = log2(0.01); tt[1][1] = tt[1][3] = log2(0.49);
  tt[2][2] = tt[2][1] = log2(0.49); tt[2][0] = tt[2][3] = log2(0.01);
  tt[3][3] = tt[3][2] = log2(0.49); tt[3][1] = tt[3][0] = log2(0.01);
  */
 
  int k=5;
  double **ee = new double*[k];
  double **ttt = new double*[k];

  for (int i=0; i<k; i++) {
    ee[i] = new double[ALPH_SIZE];
    ttt[i] = new double[k];

    for (int j=0;j<4;j++){
      ee[i][j] = (double) ((int)(100*drand48()))/128.0;
    }
    for (int j=0;j<k;j++){
      ttt[i][j] = (double) ((int)(100*drand48()))/128.0;
    }
  }


  Trie* t = new Trie(k,ee,ttt);

  //  seq[100000] = ALPH_SIZE;

  clock_t cl2 = clock();
  cout << "viterbi returned: " << t->Viterbi(seq) << endl;
  clock_t cl3 = clock();
  cout << "real viterbi returned: " << t->realViterbi(seq) << endl;
  clock_t cl1 = clock();

  cout << "realViterbi time: " << cl1-cl3 << endl;
  cout << "our time: " << cl3 - cl2 << endl;
  


}


