/************************************************************************
 **                                                                    **
 **      em.c:  built-in EM routine of PRISM system                    **
 **                                                                    **
 ** Copyright (C) 1998                                                 **
 **   Taisuke Sato, Yoshitaka Kameya, Yasushi Hagiwara, Nobuhisa Ueda, **
 **     Dept. of Computer Science, Tokyo Institute of Technology.      **
 **                                                                    **
 ************************************************************************/

#include "prism.h"

/* Global Variables: */
/* for EM learning (the size variables below are all assigned by set_values()) */
int idNum;                   /* msw(i, d, x) : 0 <= i <= idNum    */
int idNum1;                  /* always equals idNum+1             */
int valueNum;                /* msw(i, d, x) : 0 <= x <= valueNum */
int valueNum1;               /* always equals valueNum+1          */
int valueNum2;               /* always equals valueNum+2          */
int ivNum11;                 /* always equals (idNum+1)*(valueNum+1) */
int ivNum12;                 /* always equals (idNum+1)*(valueNum+2) */
int goals;                   /* the number of different goals (assigned by set_goals()) */
struct goal *G;              /* goal array, allocated by set_goals() */
struct switchInfo *switch_info; /* one record per switch, allocated by set_values() */
/* [NOTE] Prolog part of PRISM must refer to theta (not theta_tmp). */
double *theta;               /* master parameters */
double *theta_tmp;           /* temporary parameters (written by update_theta()) */
/* Pf: added by kame on Nov/20/1997.
 * Allocated by prepare_Pf() with one slot per (goal, explanation) pair;
 * written by calc_Pdb() and read by ON(). */
double *Pf;
double boundary = 0.000001;  /* convergence threshold compared in c_EM_loop() */
double loglike, loglike_bak; /* current / previous log-likelihood (see update_theta()) */

/* Global variables used in trie.c
extern struct glink *Root;
extern struct glink *ERoot;
extern struct swlink *Crr;
*/

/* [NOTE]
 *  To allocate memory consistently, the calls to set_values(),
 *  set_idInfo(), set_fixed(), set_goals(), set_Gstatus(), set_table()
 *  and set_data() must satisfy the following ordering constraints.
 *  
 *  - set_values < set_idInfo
 *  - set_values < set_fixed
 *  - set_goals  < set_Gstatus
 *  - set_goals  < set_data
 *  - set_Gstatus < set_table
 * 
 *  where the relation "f < g" means that function f() must be called
 *  before g() is called. 
 */

/* set_epsilon(new_boundary) -- overwrites the global threshold `boundary'
 * that c_EM_loop() compares the log-likelihood difference against.
 */
void set_epsilon(double new_boundary)
{
  boundary = new_boundary;
}

/* show_epsilon() -- prints the current value of `boundary' to stdout. */
void show_epsilon(void)
{
  printf("Epsilon is set to %e.\n", boundary);
}

/* get_epsilon() -- returns the current value of `boundary'. */
double get_epsilon(void)
{
  return(boundary);
}

/* set_values(id, value) -- records the table dimensions and allocates
 * switch_info, theta and theta_tmp to match.
 * [NOTE]
 *   id = idNum (= max ID of occurring switches)
 *   value = valueNum (= max of max value of occurring switches)
 *
 *   Returns 1 on success, 0 on failure (negative argument or
 *   allocation failure).  Buffers obtained before a failure stay in
 *   the globals, where freeMemory() releases them.
 *
 *   All three buffers are zero-filled: several routines walk every
 *   slot 0..idNum (e.g. the search loops in set_fixed()/get_theta()),
 *   so a slot that was never assigned must not hold garbage.
 */
int set_values( int id, int value)
{   
  size_t cells;

  idNum=id;  idNum1=idNum+1;
  valueNum=value; valueNum1=valueNum+1; valueNum2=valueNum+2;
  ivNum11=idNum1*valueNum1;
  ivNum12=idNum1*valueNum2;

  /* A negative maximum means a zero or negative element count. */
  if (id < 0 || value < 0)
    return(0);

  /* theta and theta_tmp each hold idNum1 x valueNum1 doubles. */
  cells = (size_t)idNum1*(size_t)valueNum1;

  if ((switch_info = (struct switchInfo *)calloc((size_t)idNum1,
						 sizeof(struct switchInfo)))
      == NULL)
    return(0);
  if ((theta = (double *)calloc(cells, sizeof(double))) == NULL)
    return(0);
  if ((theta_tmp = (double *)calloc(cells, sizeof(double))) == NULL)
    return(0);
  return(1);
}

/* set_goals(g) -- allocates the goal array G and then calls
 * set_Gstatus().
 * [NOTE] g = the number of different goals
 *
 *   Returns 0 when g is negative or the allocation fails; otherwise
 *   returns the result of set_Gstatus().
 *
 *   G is zero-filled because freeMemory() calls free() on every
 *   G[t].table: an entry whose table was never assigned must not
 *   carry an indeterminate pointer.
 */
int set_goals(int g)
{
  goals = g;
  if (g < 0)
	return(0);
  if ((G = (struct goal *)calloc((size_t)g, sizeof(struct goal))) == NULL)
	return(0);
  return set_Gstatus();
}

/* set_idInfo(id,maxval,tnum) -- setting for each of occurring switches
 * [NOTE]
 *   id = (ID of the occurring switch)
 *   maxval = (max value of the switch "id"),
 *     i.e., for msw(id, d, v), 0<=v<=maxval.
 *   tnum = (the number of different T_id (2nd argument))
 *
 *   A negative id resets the internal slot counter (maxval and tnum
 *   are ignored in that case) and returns 1.
 *   Returns 1 when the switch was recorded, 0 when switch_info is not
 *   allocated or all of its idNum1 slots are already in use.
 */
int set_idInfo(int id, int maxval, int tnum) 
{                                    
  /* Since the variable "id" received from Prolog is not always one of
   * the ordered integers 0,...,idNum, we use the static variable "num"
   * as the next free slot in switch_info.
   */ 
  static int num=0;

  if (id < 0) {
	num=0; return(1);      /* Initialize (maxval is ignored.) */
  }
  /* switch_info has exactly idNum1 slots (indices 0..idNum), so slot
   * idNum1 is already out of bounds. */
  if (switch_info == NULL || num >= idNum1)
    return(0);
  switch_info[num].idInProlog = id;
  switch_info[num].maxValue = maxval;
  switch_info[num].Tnum = tnum;
  switch_info[num].fixed = NOTFIXED;
  num++;
  return(1);
}
  
/* set_fixed(idInProlog, val, Prob) -- marks as FIXED every switch whose
 * Prolog-side ID equals idInProlog and stores Prob as its parameter
 * for value val in theta.  Does nothing when no switch matches.
 */
void set_fixed(int idInProlog, int val, double Prob)
{
  int slot;

  for (slot = 0; slot <= idNum; slot++) {
    if (switch_info[slot].idInProlog != idInProlog)
      continue;
    switch_info[slot].fixed = FIXED;
    Theta(slot, val, theta) = Prob;
  }
}
/* set_data(t, num) -- setting for teacher data
 * [NOTE]
 *   t = (ID of the goal), num = (the number of observed G[t])
 */ 
void set_data( int t, int num)
{
  /* printf("set_data -- t:%d num:%d\n",t,num); */
  G[t].observedNum = num;        
  /* printf("G[%d].observedNum=%d\n",t,num); */
}

/*
 * set_table(goal,val) -- fill the each item of EM table
 * [NOTE]
 *   goal = (ID of the goal), val = (item for the goal)
int set_table( int goal, int val)
{
  static int index = 0, pre_goal = 0;
  int v;
                                         
  if(goal == -1){
	index += val+1;	return(1);           Skip (val+1) columns
  }
  else if(goal == -2){
	index = 0; pre_goal = 0; return(1);  Initialize (val is ignored.)
  }
  else if(goal < 0) return(0);

  if(goal != pre_goal) 
	if( index != idNum1*valueNum2*G[pre_goal].Snum){
	  return(0);
	} else index = 0;
  if(index > idNum1*valueNum2*G[goal].Snum)
    return(0);
  pre_goal = goal;
  *(G[goal].table+index) = val;
  index ++;
  return(1);
}
*/

/* created by kame on Nov/20/1997 */
int prepare_Pf(void)
{
  int t,total_s=0;

  for(t=0; t<goals; t++)
        total_s += G[t].Snum;

  if((Pf = (double *)malloc(total_s*sizeof(double)))==NULL) return(0);
  else return(1);
}

/* initTheta() -- gives every NOTFIXED switch a random parameter vector.
 *
 * Each value 0..maxValue first receives random_int(800000)+100000 as a
 * raw weight; the weights are then divided by their sum.  Switches
 * that are not NOTFIXED are left untouched.
 * (random_int() is declared elsewhere; its range is not visible here.)
 */
void initTheta(void)
{
  int sw, val;
  long total;
  double denom;

  for (sw = 0; sw <= idNum; sw++) {
    if (switch_info[sw].fixed != NOTFIXED)
      continue;

    /* pass 1: raw weights and their sum */
    total = 0;
    for (val = 0; val <= switch_info[sw].maxValue; val++) {
      Theta(sw,val,theta) = random_int(800000)+100000;
      total += Theta(sw,val,theta);
    }

    if (total <= 0.0)
      printf("{PRISM INTERNAL WARNING: initTheta() -- denominator `total' is 0.0}\n");

    /* pass 2: normalize; a non-positive sum is replaced by a tiny
     * constant so the division stays defined */
    denom = (total <= 0.0) ? 0.000000001 : total;
    for (val = 0; val <= switch_info[sw].maxValue; val++)
      Theta(sw,val,theta) = Theta(sw,val,theta)/denom;
  }
}

/* calc_Pdb() -- calculation of Pdb
 *
 * For every goal t and every explanation s of it, the product of
 * Theta(i,v,theta)^Seq1(t,s,i,v) over all switches i and values v is
 * stored in the next slot of Pf.  The sum of these products over s is
 * stored in G[t].Pdb, and G[t].observedNum*log(Pdb) is added to the
 * global loglike (which the caller is expected to have reset).
 *
 * modified by kame on Nov/20/1997 for global variable Pf
 * (to prevent the redundancy in calculating Pf).
 *
 * modified by kame on Dec/12/1997 for warning message on Pdb's value.
 */
void calc_Pdb(void){

  double Pdb,tmpPf;
  int t,s,i,v,tmps,ss=0;

  /* calculates Pdb(G[t] = 1) for each t */
  for(t=0; t<goals; t++){
	Pdb = 0.0;
	for(s=0; s<G[t].Snum; s++){
	  tmpPf = 1.0;
	  for( i=0; i<=idNum; i++){
		for( v=0; v<=switch_info[i].maxValue; v++){
		  /*
		   * [NOTE] a zero count contributes a factor of 1.0 for any
		   * parameter value, so pow() is not called in that case.
		   */
		  tmps = Seq1(t,s,i,v);
		  if (tmps != 0)
			tmpPf *= pow(Theta(i,v,theta),tmps);
		}
	  }
	  /* Store the finished product once, after all switches have
	   * been folded in. */
	  *(Pf+ss) = tmpPf;
	  ss++;
	  Pdb += tmpPf;
	}
	if (Pdb <= 0.0)
	  printf("{PRISM INTERNAL WARNING: calc_Pdb() -- `Pdb' is 0.0.}\n");
	G[t].Pdb = Pdb;
	/* a non-positive Pdb is replaced by a tiny constant before log() */
	loglike +=
	  G[t].observedNum*log(((Pdb <= 0.0)? 0.000000001: Pdb));
  }
}

/* ON() -- calculation of ON
 *
 * For switch `id' and value `val', sums over all goals t
 *
 *     observedNum(t) * ( sum over s of Pf[t,s] *
 *         (Seq1(t,s,id,val) + Gamma(t,s,id)*Theta(id,val,theta)) )
 *       / G[t].Pdb
 *
 * and returns the total.  Pf and G[t].Pdb are read as left by
 * calc_Pdb(); a non-positive G[t].Pdb is reported and replaced by a
 * tiny constant.
 *
 * modified by kame on Nov/20/1997 for global variable Pf
 * (to prevent the redundancy in calculating Pf).
 * modified by kame on Dec/12/1997 for warning message wrt Pdb's value.
 */
double ON(int id, int val)
{
  int t,s,ss=0;          /* ss walks Pf in the order calc_Pdb() filled it */
  double on_t,on_s,tmp_on;

  on_t = 0.0;
  for( t=0; t<goals; t++){
    on_s = 0.0;
    for( s=0; s<G[t].Snum; s++){
	  tmp_on = Pf[ss]*(Seq1(t,s,id,val)+(Gamma(t,s,id))*Theta(id,val,theta));
	  ss++;
	  on_s += tmp_on;
	}
	if (G[t].Pdb <= 0.0)
	  printf("{PRISM INTERNAL WARNING: ON() -- denominator `G[t].Pdb' is 0.0.}\n");
    on_t += (on_s/((G[t].Pdb<=0.0)? 0.000000001: G[t].Pdb))*G[t].observedNum;
  }
  return(on_t);
}

/* copyTheta(theta_src, theta_dest) -- copies the parameters of every
 * NOTFIXED switch (values 0..maxValue) from theta_src to theta_dest.
 * Entries of other switches in theta_dest are not modified.
 */
void copyTheta(double *theta_src, double *theta_dest)
{
  int sw, val;

  for (sw = 0; sw <= idNum; sw++) {
    if (switch_info[sw].fixed != NOTFIXED)
      continue;
    for (val = 0; val <= switch_info[sw].maxValue; val++) {
      Theta(sw,val,theta_dest) = Theta(sw,val,theta_src);
    }
  }
}

/* update_theta(): updates each theta _once_.
 *
 *   Saves loglike into loglike_bak, resets loglike and calls
 *   calc_Pdb().  Then, for every NOTFIXED switch, ON(i,v) is computed
 *   for each value v and divided by the sum over v; the results go to
 *   theta_tmp and are finally copied into theta by copyTheta().
 *
 *   The raw ON values are staged directly in theta_tmp, so no scratch
 *   buffer has to be allocated.  ON() as written reads theta, not
 *   theta_tmp, so staging there does not disturb later ON() calls.
 * 
 *   modified by kame on Nov/20/1997 
 *   modified by kame on Dec/12/1997 for warning message wrt ONsum.
 */
void update_theta(void)
{
  double ONsum, denom;
  int i,v;

  loglike_bak = loglike;
  loglike = 0.0;

  calc_Pdb();

  for( i=0; i <= idNum; i++){
	if (switch_info[i].fixed != NOTFIXED)
	  continue;                        /* do nothing for fixed switch. */

	ONsum = 0;
	for( v=0; v <= switch_info[i].maxValue; v++){
	  Theta(i,v,theta_tmp) = ON(i,v);
	  ONsum += Theta(i,v,theta_tmp);
	}
	if (ONsum <= 0.0)
	  printf("{PRISM INTERNAL WARNING: update_theta() -- denominator `ONsum' is 0.0.}\n");

	/* a non-positive sum is replaced by a tiny constant so the
	 * division stays defined */
	denom = (ONsum<=0.0)? 0.000000001: ONsum;
	for( v=0; v <= switch_info[i].maxValue; v++)
	  Theta(i,v,theta_tmp) = Theta(i,v,theta_tmp)/denom;
  }
  copyTheta(theta_tmp, theta);
}

/* c_EM_loop() -- repeats update_theta() until the log-likelihood
 * difference (loglike - loglike_bak) drops below `boundary'.
 *
 * update_theta() is always called at least twice: once before the
 * loop and once inside it before the first convergence test.
 * display() is called after every update with a running counter.
 * Returns that counter minus one.
 */
int c_EM_loop(void)
{
  int iteration = 0;

  loglike = 0.0;
  loglike_bak = 0.0;

  update_theta();
  display(iteration);
  iteration++;

  do {
    update_theta();
    display(iteration);
    iteration++;
  } while (!(loglike - loglike_bak < boundary));

  printf("\n");
  return iteration - 1;
}

/* display(iteration) -- progress indicator on stdout.
 * Prints the number itself when it is a multiple of 50, a single '.'
 * when it is another multiple of 5, and nothing otherwise.  The stream
 * is flushed whenever something was printed.
 */
void display(int iteration)
{
  if (iteration % 50 == 0)
    printf("%d", iteration);
  else if (iteration % 5 == 0)
    putchar('.');
  else
    return;

  fflush(stdout);
}

void initVars(void) {

  G=(struct goal *)NULL;
  switch_info=(struct switchInfo *)NULL;
  theta=(double *)NULL;
  theta_tmp=(double *)NULL;
  set_idInfo(-2,(int)NULL,(int)NULL);
  /* set_table(-2,(int)NULL); */

}

/* freeMemory() -- releases every buffer owned by this module.
 *
 * Frees each G[g].table and G itself, switch_info, theta, theta_tmp
 * and Pf, setting every pointer to NULL afterwards, and resets the
 * slot counter inside set_idInfo().  A one-line notice is printed
 * when at least one buffer was actually released.
 */
void freeMemory(void)
{
  int g,flag=0;      /* flag: set once anything has been freed */

  if (G != (struct goal *)NULL) {
	for(g=0; g<goals; g++){
	  free(G[g].table);
	  G[g].table=(int *)NULL;
	}
	free(G);
	G=(struct goal *)NULL;
	flag=1;
  }
  if (switch_info != (struct switchInfo *)NULL){
	free(switch_info);
	switch_info=(struct switchInfo *)NULL;
	flag=1;
  }
  if (theta != (double *)NULL){
	free(theta);
	theta=(double *)NULL;
	flag=1;
  }
  if (theta_tmp != (double *)NULL){
	free(theta_tmp);
	theta_tmp=(double *)NULL;
	flag=1;
  }
  /* Initialize -- reset of static variable.  Any negative id does
   * this; the remaining arguments are ignored, so plain 0 is passed
   * rather than a NULL cast to int. */
  set_idInfo(-2,0,0);
  /* set_table(-2,0);   Initialize -- reset of static variable */
  
  /* added by kame on Nov/20/1997 */
  if (Pf != (double *)NULL){
        free(Pf);
        Pf=(double *)NULL;
        flag=1;
  }

  if (flag) printf("{Previous explanation table cleaned up.}\n");
}

/* get_theta(idInProlog, val) -- looks up the first switch whose
 * Prolog-side ID equals idInProlog and returns Theta(i,val,theta) for
 * it.  Returns -1.0 when no switch matches.
 */
double get_theta(int idInProlog, int val)
{
  int slot;

  for (slot = 0; slot <= idNum; slot++) {
    if (switch_info[slot].idInProlog != idInProlog)
      continue;
    return Theta(slot,val,theta);
  }
  return -1.0;
}

/* get_loglike() -- returns the current value of the global loglike. */
double get_loglike(void)
{
  return(loglike);
}

/*
 *
 * PRINTER FUNCTIONS:
 *
 */

/* show_sizes() -- prints idNum, valueNum and goals to stdout. */
void show_sizes(void)
{
  printf("Table size:\n    idNum=%d valueNum=%d goals=%d\n",
         idNum, valueNum, goals);
}

/* show_sw() -- prints one line per switch slot 0..idNum: the slot
 * index, its Prolog-side ID, its number of values (maxValue+1), Tnum,
 * and every value together with its current parameter in theta.
 */
void show_sw(void)
{
  int sw, val;

  printf("Switches recognized by C routine:\n");

  sw = 0;
  while (sw <= idNum) {
    printf("    Switch %d (Code %d, Size %d, Tnum %d):",
           sw,
           switch_info[sw].idInProlog,
           (switch_info[sw].maxValue)+1,
           switch_info[sw].Tnum);

    for (val = 0; val <= switch_info[sw].maxValue; val++) {
      printf(" %d (%.6f)", val, Theta(sw,val,theta));
    }
    printf("\n");
    sw++;
  }
}

void show_goals(void)
{
  int t;

  printf("Goals recognized in C routine:\n");

  for(t=0;t<goals;t++)
	printf("    Goal %d: num_of_expls %d, weight %d\n",
		   t,G[t].Snum,G[t].observedNum);
}

void show_raw_table(void)
{
  int t,i;

  for(t=0;t<goals;t++){
	printf("G[%d]: ",t);
	for(i=0;i<idNum1*valueNum2*G[t].Snum;i++)
	  printf("%d ",*(G[t].table+i));
	printf("\n");
  }
}

/* show_table() -- prints the table in a bar-separated layout.
 *
 * After a two-line header, each (goal, explanation) pair gets one row:
 * the goal and explanation indices, then Seq1(t,s,i,v) for every
 * switch i and value v (values separated by blanks, switches by '|'),
 * then Gamma(t,s,i) for every switch i.
 */
void show_table(void)
{
  int t, s, i, v;

  printf("\n|Goal|Expl|\n+----+----+\n");

  for (t = 0; t < goals; t++) {
    for (s = 0; s < G[t].Snum; s++) {
      printf("|%4d|%4d||", t, s);

      /* Seq1 columns, grouped per switch */
      for (i = 0; i <= idNum; i++) {
        for (v = 0; v <= switch_info[i].maxValue; v++) {
          printf("%1d", Seq1(t,s,i,v));
          if (v != switch_info[i].maxValue)
            putchar(' ');
        }
        /* the last switch group is closed by a double bar */
        fputs((i == idNum) ? "||" : "|", stdout);
      }

      /* Gamma columns, one per switch */
      for (i = 0; i <= idNum; i++) {
        printf("%1d", Gamma(t,s,i));
        if (i != idNum)
          putchar(' ');
      }
      printf("|\n");
    }
  }
}
