#include <stdlib.h>
#include <math.h>
#include <stdio.h>

#include "hmm-gene.h"
#include "ga.h"
#include "hmm.h"


int baum_welch(int num_seq, SEQ *seq_str, int seq_length, 
	       int num_state, int *tuple, int **connect, HMM_CND hmm_cnd, 
	       double ***a_ij, double **pi_i, TPL_TBL **b_ij, 
	       double **score)
{
    int i, j, k, l;
    int itr, t, id;
    int idum;
    double freq;
    double pav = L_NEGATIVE;
    double cav;
    double ddum1, ddum2;
    double **dum_a;
    double *dum_pi;
    double **dum_scale;
    double *dum_likelihd;
    TPL_TBL *dum_b;
    int *idum_ptr;
    char *sw;

    HMM_LEARN_PARAM *param_str;

    int alloc_parameter2(int, int, int, HMM_LEARN_PARAM **, double ***,
			 double **, TPL_TBL **, int *, double ***, double **);
    int free_parameter2(int, int, int, HMM_LEARN_PARAM *, double **,
			double *, TPL_TBL *, int *, double **, double *);
    int forward(int, SEQ *, int, int, int *, int **, HMM_CND, 
		double ***, double **, TPL_TBL **, HMM_LEARN_PARAM **,
		double ***, double **);
    int backward(int, SEQ *, int, int, int *, int **, HMM_CND, 
		 double ***, double **, TPL_TBL **, HMM_LEARN_PARAM **,
		 double **);
    int print_vector(double *, int);
    int print_matrix(double **, int, int);
    int vector_times(double *, double, int, double*);
    int matrix_times(double **, double, int, int, double **);
    int frequency(char *, int, int, TPL_TBL **, int, double *);
    int is_zero(double);
    int *ivector(long, long);
    unsigned char *cvector(long, long);
    void free_ivector(int *, long, long);
    void free_cvector(unsigned char *, long, long);
    int int2n(int, int, int, int *);


    /*  */
#ifdef DEBUG
    fprintf(stderr, "--- allocating memory of parameters...\n");
#endif
    if(alloc_parameter2(num_seq, num_state, seq_length, 
			&param_str, &dum_a, &dum_pi, &dum_b, tuple,
			&dum_scale, &dum_likelihd) != 0){
	fprintf(stderr, "error occures in alloc_parameter2()...\n");
	free_parameter2(num_seq, num_state, seq_length, param_str,
			dum_a, dum_pi, dum_b, tuple, dum_scale, 
			dum_likelihd);
	return(-1);
    }


    /* b_ijѤΥƥݥѿ */
    for(i = 1; i <= num_state; i++){
	idum = (int) pow((double)NUCL, (double)(tuple[i]));

	/* Ĺlength[i]tuple */
	idum_ptr = ivector(0, (long)tuple[i]-1);
	for(j = 1; j <= idum; j++){

	    /* (j-1)4ʿѴ */
	    if(int2n(j-1, tuple[i], NUCL, idum_ptr) != 0){
		fprintf(stderr, "error occures in int2n()...\n");
		free_ivector(idum_ptr, 0, (long)(tuple[i]-1));
		free_parameter2(num_seq, num_state, seq_length, param_str,
				dum_a, dum_pi, dum_b, tuple, dum_scale, 
				dum_likelihd);
		return(-1);
	    }
	    for(k = 0; k < tuple[i]; k++){
		if(idum_ptr[k] == 0)
		    ((dum_b + i)->tuple)[j][k+1] = 'A';
		else if(idum_ptr[k] == 1)
		    ((dum_b + i)->tuple)[j][k+1] = 'T';
		else if(idum_ptr[k] == 2)
		    ((dum_b + i)->tuple)[j][k+1] = 'C';
		else if(idum_ptr[k] == 3)
		    ((dum_b + i)->tuple)[j][k+1] = 'G';
		else{
		    free_ivector(idum_ptr, 0, (long)(tuple[i]-1));
		    free_parameter2(num_seq, num_state, seq_length, param_str,
				    dum_a, dum_pi, dum_b, tuple, dum_scale, 
				    dum_likelihd);
		    return(-1);
		}
	    }

	}
	free_ivector(idum_ptr, 0, (long)(tuple[i]-1));

    }


    /* Baum-Welch Algorithmˤѥ᡼ꡥ */
    sw = "off";
    for(itr = 1; itr <= hmm_cnd.max_itr; itr++){
#ifdef DEBUG
	fprintf(stderr, "number of iteration = %d...\n", itr);
#endif
	fprintf(stderr, "%d ", itr);
	fflush(stderr);

	/* Forward algorithmˤ_n(i,t)򻻽Ф롥 */
#ifdef DEBUG
	fprintf(stderr, "--- calculating alpha_n(i,t) using forward algorithm...\n");
#endif
	if(forward(num_seq, seq_str, seq_length, num_state, tuple, 
		   connect, hmm_cnd, a_ij, pi_i, b_ij, &param_str,
		   &dum_scale, &dum_likelihd) != 0){
	    fprintf(stderr, "error occures in forward()...\n");
	    free_parameter2(num_seq, num_state, seq_length, param_str,
			    dum_a, dum_pi, dum_b, tuple, dum_scale, 
			    dum_likelihd);
	    return(-1);
	}
	for(id = 1; id <= num_seq; id++){
	    if(is_zero(((param_str + id)->alpha_ij)[num_state][seq_length])
	       == TRUE){
		sw = "on";
		break;
	    }
	}


	/* Backward algorithmˤ_n(i,t)򻻽Ф롥 */
#ifdef DEBUG
	fprintf(stderr, "--- calculating beta_n(i,t) using backward algorithm...\n");
#endif
	if(backward(num_seq, seq_str, seq_length, num_state, tuple, 
		    connect, hmm_cnd, a_ij, pi_i, b_ij, &param_str,
		    dum_scale) != 0){
	    fprintf(stderr, "error occures in backward()...\n");
	    free_parameter2(num_seq, num_state, seq_length, param_str,
			    dum_a, dum_pi, dum_b, tuple, dum_scale,
			    dum_likelihd);
	    return(-1);
	}


	/* ܳΨa_ijꤹ롥 */
#ifdef DEBUG
	fprintf(stderr, "--- estimating a_ij...\n");
#endif
	for(i = 1; i <= num_state; i++){
	    for(j = 1; j <= num_state; j++){
		ddum1 = 0.0;
		for(t = 1; t <= (seq_length-1); t++){
		    for(id = 1; id <= num_seq; id++){
			frequency((char *)((seq_str + id)->a_seq), t+1, 
				  tuple[j], b_ij, j, &freq);
			ddum1 += ((param_str + id)->alpha_ij)[i][t] * 
			    (*a_ij)[i][j] * freq * 
				((param_str + id)->beta_ij)[j][t+1];
		    }
		}
		dum_a[i][j] = ddum1;
	    }

	    ddum2 = 0.0;
	    for(j = 1; j <= num_state; j++){
		ddum2 += dum_a[i][j];
	    }
	    if(is_zero(ddum2) == FALSE){
		for(j = 1; j <= num_state; j++){
		    dum_a[i][j] = dum_a[i][j] / ddum2;
		}
	    }
	    else{
		/*
		for(j = 1; j <= num_state; j++){
		    dum_a[i][j] = L_POSITIVE;
		}
		sw = "on";
		*/
		for(j = 1; j <= num_state; j++){
		    dum_a[i][j] = 0.0;
		}
	    }
	}
#ifdef DEBUG
	print_matrix(dum_a, num_state, num_state);
#endif


	/* _n(i,j,t)򻻽С */
#ifdef DEBUG
	fprintf(stderr, "--- calculating gamma_ijt...\n");
#endif
	for(id = 1; id <= num_seq; id++){
	    for(t = 1; t <= (seq_length-1); t++){
		for(j = 1; j <= num_state; j++){
		    ddum1 = 0.0;
		    for(k = 1; k <= num_state; k++){
			frequency((char *)((seq_str + id)->a_seq), t+1, 
				  tuple[k], b_ij, k, &freq);
			ddum1 += ((param_str + id)->alpha_ij)[j][t] * 
			    (*a_ij)[j][k] * freq * 
				((param_str + id)->beta_ij)[k][t+1];
		    }
		    ((param_str + id)->gamma_ijt)[j][t] = ddum1;
		}
		for(j = 1; j <= num_state; j++){
		    ((param_str + id)->gamma_ijt)[j][seq_length] = 
			((param_str + id)->alpha_ij)[j][seq_length];
		}
	    }
	}


	/* Ψʬb_ijꤹ롥 */
#ifdef DEBUG
	fprintf(stderr, "--- estimating b_ij...\n");
#endif
	for(j = 1; j <= num_state; j++){
	    ddum1 = pow((double)NUCL, (double)tuple[j]);
	    for(k = 1; k <= (int)ddum1; k++){
		unsigned char *key1;
		unsigned char *key2;

		key1 = cvector((long)1, (long)tuple[j]);
		key2 = cvector((long)1, (long)tuple[j]);

		ddum2 = 0.0;
		idum_ptr = ivector(0, (long)tuple[j]-1);

		/* k4ʿѴ */
		if(int2n(k-1, tuple[j], NUCL, idum_ptr) != 0){
		    fprintf(stderr, "error occures in int2n()...\n");
		    free_ivector(idum_ptr, 0, (long)(tuple[j]-1));
		    free_parameter2(num_seq, num_state, seq_length, param_str,
				    dum_a, dum_pi, dum_b, tuple, dum_scale, 
				    dum_likelihd);
		    free_cvector(key1, (long)1, (long)tuple[j]);
		    free_cvector(key2, (long)1, (long)tuple[j]);
		    return(-1);
		}

		/* kɽtuplekey1 */
		for(l = 0; l < tuple[j]; l++){
		    if(idum_ptr[l] == 0)
			key1[l+1] = 'A';
		    else if(idum_ptr[l] == 1)
			key1[l+1] = 'T';
		    else if(idum_ptr[l] == 2)
			key1[l+1] = 'C';
		    else if(idum_ptr[l] == 3)
			key1[l+1] = 'G';
		    else{
			free_ivector(idum_ptr, 0, (long)(tuple[j]-1));
			free_parameter2(num_seq, num_state, seq_length, 
					param_str, dum_a, dum_pi, dum_b, 
					tuple, dum_scale, dum_likelihd);
			free_cvector(key1, (long)1, (long)tuple[j]);
			free_cvector(key2, (long)1, (long)tuple[j]);
			return(-1);
		    }
		}

		for(t = 1; t <= seq_length; t++){
		    for(id = 1; id <= num_seq; id++){
			int sensor = 0;

			/* idΰtˤĹtuple[j]ʸ
			   key2 */
			for(l = 1; l <= tuple[j]; l++){
			    key2[l] = ((seq_str + id)->a_seq)[(t-1)+(l-1)];
			}

			/* key1key2Ʊ? */
			for(l = 1; l <= tuple[j]; l++){
			    if(key1[l] != key2[l]){
				sensor++;
				break;
			    }
			}
			if(sensor == 0){
			    ddum2 += ((param_str + id)->gamma_ijt)[j][t];
			}
		    }
		}
		((dum_b + j)->frq)[k] = ddum2;
		free_ivector(idum_ptr, 0, (long)(tuple[j]-1));

		free_cvector(key1, (long)1, (long)tuple[j]);
		free_cvector(key2, (long)1, (long)tuple[j]);
	    }
	}
	for(j = 1; j <= num_state; j++){
	    ddum1 = pow((double)NUCL, (double)tuple[j]);
	    ddum2 = 0.0;
	    for(i = 1; i <= (int)ddum1; i++){
		ddum2 += ((dum_b + j)->frq)[i];
	    }
	    if(is_zero(ddum2) == FALSE){
		for(k = 1; k <= (int)ddum1; k++){
		    ((dum_b + j)->frq)[k] = ((dum_b + j)->frq)[k] / ddum2;
		}
	    }
	    else{
		/*
		for(k = 1; k <= (int)ddum1; k++){
		    ((dum_b + j)->frq)[k] = L_POSITIVE;
		}
		sw = "on";
		*/
		for(k = 1; k <= (int)ddum1; k++){
		    ((dum_b + j)->frq)[k] = 0.0;
		}
	    }
	}
#ifdef DEBUG
	for(i = 1; i <= num_state; i++){
	    idum = (int) pow((double)NUCL, (double)(tuple[i]));
	    fprintf(stderr, "\tstate=%d  freq.=...\n", i);
	    for(j = 1; j <= idum; j++){
		fprintf(stderr, "\t");
		for(k = 1; k <= tuple[i]; k++){
		    fprintf(stderr, "%c", ((dum_b + i)->tuple)[j][k]);
		}
		fprintf(stderr, "\t%e\n", ((dum_b + i)->frq)[j]);
	    }
	}
#endif


	/* ʬ۾pi_iꤹ롥 */
#ifdef DEBUG
	fprintf(stderr, "--- estimating pi_i...\n");
#endif
	for(i = 1; i <= num_state; i++){
	    ddum1 = 0.0;
	    for(j = 1; j <= num_state; j++){
		for(id = 1; id <= num_seq; id++){
		    ddum1 += ((param_str + id)->gamma_ijt)[j][1];
		}
	    }
	    ddum2 = 0.0;
	    for(id = 1; id <= num_seq; id++){
		ddum2 += ((param_str + id)->gamma_ijt)[i][1];
	    }
	    if(is_zero(ddum1) == FALSE){
		dum_pi[i] = ddum2 / ddum1;
	    }
	    else{
		/*
		dum_pi[i] = L_POSITIVE;
		sw = "on";
		*/
		dum_pi[i] = 0.0;
	    }
	}
#ifdef DEBUG
	print_vector(dum_pi, num_state);
#endif


	/* ѥѥ᡼饳ԡ */
	vector_times(dum_pi, 1.0, num_state, *pi_i);
	matrix_times(dum_a, 1.0, num_state, num_state, *a_ij);
	for(j = 1; j <= num_state; j++){
	    ddum1 = pow((double)NUCL, (double)tuple[j]);
	    for(k = 1; k <= (int)ddum1; k++){
		((*b_ij + j)->frq)[k] = ((dum_b + j)->frq)[k];
	    }
	}


	/* 򥳥ԡ */
	vector_times(dum_likelihd, 1.0, num_seq, *score);


	/* Сե եƤ
	   ؽߤ */
	if(strcmp(sw, "on") == 0){
#ifdef DEBUG
	    fprintf(stderr, "!!! give up to learn !!!\n");
#endif
	    break;
	}


	/* «Ƚ */
	ddum1 = 0.0;
	for(j = 1; j <= num_seq; j++){
	    ddum1 += dum_likelihd[j];
	}
	cav = ddum1/(double)num_seq;
#ifdef DEBUG
	fprintf(stderr, "av. of likelihd = %e\n", cav);
	fprintf(stderr, "normarized diff. of likelihd = %e\n", 
		fabs((cav-pav)/cav));
#endif
	if(fabs((cav-pav)/cav) < hmm_cnd.eps){
	    break;
	}
	pav = cav;

    }
    fprintf(stderr, "\n");
    fflush(stderr);


    /* 곫 */
#ifdef DEBUG
    fprintf(stderr, "--- releaseing memory of parameters...\n");
#endif
    if(free_parameter2(num_seq, num_state, seq_length, param_str,
		       dum_a, dum_pi, dum_b, tuple, dum_scale, 
		       dum_likelihd) != 0){
	fprintf(stderr, "error occures in free_parameter2()...\n");
	return(-1);
    }


    return(0);
}


