/************************************************************************
 **                                                                    **
 **      em.c:  built-in EM routine of PRISM system                    **
 **                                                                    **
 ** Copyright (C) 1998                                                 **
 **   Taisuke Sato, Yoshitaka Kameya, Yasushi Hagiwara, Nobuhisa Ueda, **
 **     Dept. of Computer Science, Tokyo Institute of Technology.      **
 **                                                                    **
 ************************************************************************/

#include "prism.h"

#include <stdlib.h>

/* Global Variables: */
/* for EM learning */
int idNum;                   /* msw(i, d, x) : 0 <= i <= idNum (largest C-side switch index) */
int idNum1;                  /* always equals idNum+1 (number of switch_info entries) */
int valueNum;                /* msw(i, d, x) : 0 <= x <= valueNum */
int valueNum1;               /* always equals valueNum+1 (row stride of the tables) */
int valueNum2;               /* always equals valueNum+2          */
int ivNum11;                 /* always equals (idNum+1)*(valueNum+1) */
int ivNum12;                 /* always equals (idNum+1)*(valueNum+2) (size of a count buffer) */
int goals;                   /* the number of different goals     */
struct switchInfo *switch_info;  /* idNum1 per-switch records, allocated in set_values() */
int *rev_switch_info;        /* Prolog switch ID -> C index (-1 if unused), built in prepare_rev_idInfo() */
/* [NOTE] Prolog part of PRISM must refer to theta (not theta_tmp). */
double *theta;               /* master parameters */
double *theta_tmp;           /* temporary parameters written by update_theta() before copying */
double epsilon = EPSILON;    /* c_EM_loop() stops when loglike - loglike_bak < epsilon */
double beta;                 /* NOTE(review): not referenced in this part of the file */

double loglike, loglike_bak; /* log-likelihood of the current / previous EM iteration */
double *Pstack;              /* Stacks for calculation of Pf and ON */
struct swlink **Astack;      /* (Pstack: saved probabilities, Astack: resume nodes; see ON_loop()) */

extern struct glink *Root;   /* head of the goal list, defined elsewhere */


/* [NOTE]
 *  To allocate the memory consistently, the functions set_values(),
 *  set_idInfo(), set_fixed(), prepare_rev_idInfo(), set_goals() and
 *  set_data() must be called in an order that satisfies the following
 *  conditions:
 *
 *  - set_values < set_idInfo
 *  - set_values < set_fixed
 *  - set_idInfo < set_fixed   (set_idInfo() resets the `fixed' flag and
 *                              provides the IDs set_fixed() looks up)
 *  - set_values < prepare_rev_idInfo
 *  - set_idInfo < prepare_rev_idInfo
 *  - set_goals  < set_data
 *
 *  where the relation "f < g" means that function f() must be called
 *  before g() is called.
 */

/* set_epsilon() -- replace the convergence threshold used by c_EM_loop(). */
void set_epsilon(double new_epsilon)
{
  double threshold = new_epsilon;

  epsilon = threshold;
}

/* show_epsilon() -- print the current convergence threshold to stdout. */
void show_epsilon(void)
{
  double current = epsilon;

  printf("Epsilon is set to %e.\n", current);
}

/* get_epsilon() -- return the current convergence threshold. */
double get_epsilon(void)
{
  double current = epsilon;

  return current;
}

/* set_values(id, value) -- record the table dimensions and allocate the
 * per-switch tables.
 * [NOTE]
 *   id = idNum (= max ID of occurring switches)
 *   value = valueNum (= max of max value of occurring switches)
 *
 *   Sets idNum, valueNum and the derived sizes, then allocates
 *   switch_info (idNum+1 entries) and theta / theta_tmp
 *   ((idNum+1)*(valueNum+1) doubles each).
 *
 *   Returns 1 on success and 0 on failure.  Negative arguments are
 *   rejected before anything is modified.  If an allocation fails, the
 *   tables allocated by this call are released and their pointers reset
 *   to NULL, so no half-built set of tables is left behind.
 */
int set_values(int id, int value)
{
  size_t cells;

  /* A negative size would be converted to a huge size_t by malloc(). */
  if (id < 0 || value < 0)
    return(0);

  idNum=id;  idNum1=idNum+1;
  valueNum=value; valueNum1=valueNum+1; valueNum2=valueNum+2;
  ivNum11=idNum1*valueNum1;
  ivNum12=idNum1*valueNum2;

  /* Compute the byte counts in size_t rather than int arithmetic. */
  cells = (size_t)idNum1 * (size_t)valueNum1;

  if ((switch_info = malloc((size_t)idNum1 * sizeof *switch_info)) == NULL)
    return(0);
  if ((theta = malloc(cells * sizeof *theta)) == NULL)
    goto fail_theta;
  if ((theta_tmp = malloc(cells * sizeof *theta_tmp)) == NULL)
    goto fail_theta_tmp;
  return(1);

fail_theta_tmp:
  free(theta);
  theta = NULL;
fail_theta:
  free(switch_info);
  switch_info = NULL;
  return(0);
}

/* set_goals(g) -- record the number of different goal patterns in the
 * global `goals'.  Always returns 1.
 *
 *   (rewritten by kame, Jul/23/1998)
 */
int set_goals(int g)
{
  goals = g;
  return 1;
}

/* set_idInfo(id,maxval,tnum) -- setting for each of occurring switches
 * [NOTE]
 *   id = (ID of the occurring switch, as used in Prolog)
 *   maxval = (max value of the switch "id"),
 *     i.e., for msw(id, d, v), 0<=v<=maxval.
 *   tnum = (the number of different T_id (2nd argument))
 *
 *   Successive calls fill switch_info[0], switch_info[1], ... in call
 *   order and mark each switch NOTFIXED.  A negative id only resets the
 *   fill position to 0 (maxval and tnum are ignored).
 *
 *   Returns 1 on success, 0 if switch_info is not allocated or already
 *   holds idNum+1 entries.
 */
int set_idInfo(int id, int maxval, int tnum)
{
  /* Since the IDs received from Prolog are not always the ordered
   * integers 0,...,idNum, entries are placed at the position held in the
   * static variable "num" instead.
   */
  static int num=0;

  if (id < 0) {
    num=0; return(1);      /* Initialize (maxval and tnum are ignored.) */
  }
  /* switch_info holds idNum1 entries (indices 0..idNum); num == idNum1
   * would already be one past the end. */
  if (switch_info == NULL || num >= idNum1)
    return(0);
  switch_info[num].idInProlog = id;
  switch_info[num].maxValue = maxval;
  switch_info[num].Tnum = tnum;
  switch_info[num].fixed = NOTFIXED;
  num++;
  return(1);
}
  
/* set_fixed(idInProlog, val, Prob) -- set flag on for fixed switches.
 *
 *   Every switch whose Prolog ID equals idInProlog is marked FIXED and
 *   its parameter for value val is set to Prob.  initTheta(),
 *   update_theta() and copyTheta() skip FIXED switches, so the value
 *   stays as given here.
 *
 *   A val outside 0..maxValue of the switch is ignored: it is not a value
 *   of that switch, and writing it would touch a theta cell that does not
 *   belong to the switch (NOTE(review): assumes Theta() addresses row i
 *   with stride valueNum1, as the ON tables do).
 */
void set_fixed(int idInProlog, int val, double Prob)
{
  int i;

  for(i=0;i<=idNum;i++){
    if (switch_info[i].idInProlog != idInProlog)
      continue;
    if (val < 0 || val > switch_info[i].maxValue)
      continue;
    switch_info[i].fixed = FIXED;
    Theta(i, val, theta) = Prob;
  }
}

/* set_data(t, num) -- setting for teacher data
 * [NOTE]
 *   t = (ID of the goal), num = (the weight of the goal)
 *
 *   Only the first goal in the Root list whose Goal equals t receives the
 *   weight (stored in ObsNum).  An unknown t is silently ignored.
 */
void set_data(int t, int num)
{
  struct glink *g = Root;

  while (g != GNULL && g->Goal != t)
    g = g->NextG;

  if (g != GNULL)
    g->ObsNum = num;
}

/*
 *  prepare_rev_idInfo() -- build a lookup table of switch IDs.
 *
 *  Switch ID in Prolog is not always in [0,idNum], thus sometimes it is
 *  not the same as that in C. So we need a lookup table of both switch
 *  IDs in Prolog and C.
 *
 *  After the call, rev_switch_info[p] is the C index of the switch whose
 *  Prolog ID is p, or -1 if no switch has that ID.  A table left by an
 *  earlier call is released first.  Negative Prolog IDs are skipped,
 *  because they cannot index the table.
 *
 *  Returns 1 on success, 0 if the table cannot be allocated (then
 *  rev_switch_info is NULL).
 *
 *  created by kame on Jul/11/1997
 */
int prepare_rev_idInfo(void)
{
  int i, max = 0;

  /* Avoid leaking the table of a previous call (free(NULL) is a no-op). */
  free(rev_switch_info);
  rev_switch_info = NULL;

  /* Get maximum idInProlog */
  for (i = 0; i <= idNum; i++)
    if (switch_info[i].idInProlog > max)
      max = switch_info[i].idInProlog;

  if ((rev_switch_info = malloc(((size_t)max + 1) * sizeof *rev_switch_info)) == NULL)
    return(0);

  for (i = 0; i <= max; i++)
    rev_switch_info[i] = -1;

  for (i = 0; i <= idNum; i++)
    if (switch_info[i].idInProlog >= 0)
      rev_switch_info[switch_info[i].idInProlog] = i;

  return(1);
}

/* show_rev_idInfo() -- print the Prolog-ID -> C-index mapping, one line
 * per switch, in order of C index.
 */
void show_rev_idInfo(void)
{
  int k = 0;

  while (k <= idNum) {
    printf("  rev_switch_info[%d]=%d\n", switch_info[k].idInProlog, k);
    k++;
  }
}

/* count_expls(): count switches and make entries of the EM table,
 *                and build Pstack and Astack (stacks for the calculation
 *                of Pdb).
 *
 *   Builds rev_switch_info.  Then, for every goal with explanations, it
 *   hands a zeroed buffer of ivNum12 ints to count_an_expl() and records
 *   the returned depth in glp->MaxDepth.  A goal without explanations
 *   gets MaxDepth 0 and no buffer.
 *
 *   Pstack/Astack are then (re)allocated with mdepth+1 slots, where
 *   mdepth is the largest depth; stacks from an earlier call are released.
 *
 *   Returns 1 on success.  Returns 0 on allocation failure or when no
 *   goal has any explanation.
 *
 *    re-created by kame on Jul/23/1998
 */
int count_expls(void){

  struct glink *glp;
  int *ebuf;
  int mdepth = 0;   /* mdepth is maximum depth of the TRIE */
  int td;

  if (prepare_rev_idInfo() <= 0)
    return(0);

  for(glp=Root; glp != GNULL; glp=glp->NextG){
    td = 0;         /* depth of a goal without explanations */
    if (glp->Expl != LNULL) {
      /* [NOTE] ebuf will be linked to leaf node, so do not free(ebuf)! */
      if ((ebuf = calloc((size_t)ivNum12, sizeof *ebuf)) == NULL)
        return(0);
      td = count_an_expl(1,glp->Expl,ebuf);
    }
    glp->MaxDepth = td;
    if (td > mdepth)
      mdepth = td;
  }

  if (mdepth <= 0)
    return(0);

  /* Build Pstack and Astack (indices 1..mdepth are used; slot 0 is not). */
  free(Pstack);
  free(Astack);
  Pstack = malloc(((size_t)mdepth + 1) * sizeof *Pstack);
  Astack = malloc(((size_t)mdepth + 1) * sizeof *Astack);
  if (Pstack == NULL || Astack == NULL)
    return(0);
  return(1);
}

/* count_an_expl(): count switches and make an entry of the EM table,
 *                  and return the maximum depth of expls.
 *
 *   depth : depth of the sibling list swlp in the TRIE (top level is 1,
 *           as called from count_expls()); children get depth+1.
 *   swlp  : first node of the sibling list to process.
 *   ebuf  : count buffer (ivNum12 ints) for the path leading to swlp.
 *           Cell [i*valueNum1 + v] counts value v of switch i on the path;
 *           cell [ivNum11 + i] counts uses of switch i and is turned into
 *           Gamma = Tnum - count when a leaf is reached.
 *
 *   Each leaf keeps its buffer in swp->Table, so buffers are never freed
 *   here.  Before a branching node is updated, ebuf is copied so that the
 *   next sibling starts from the counts of the common prefix.
 *
 *   Returns the maximum depth reached by this sibling list and everything
 *   below it, or 0 if a buffer copy cannot be allocated.
 *   NOTE(review): a 0 from a nested call is absorbed by the max() at the
 *   end, so callers cannot reliably detect that failure.
 *
 *    re-created by kame on Jul/23/1998
 */
int count_an_expl(int depth, struct swlink *swlp, int *ebuf){

  struct sw *swp;
  int *ebufcopy;   /* buffer for the next sibling; only set when one exists */
  int i,j,r1,r2,mdepth1,mdepth2;   /* NOTE(review): r1, r2 are unused */

  swp=swlp->Sw;

  /* [NOTE] If there is a branch, buffers must be copied
   *        before updating them and conquering child nodes.
   */
  if (swlp->NextSw != LNULL){
	/* [NOTE] ebufcopy will be linked to leaf node,
     *        so do not free(ebufcopy)!
     */
	if ((ebufcopy=(int *)malloc(sizeof(int)*ivNum12))==(int *)NULL) return(0);
	for(j=0;j<ivNum12;j++)
	  ebufcopy[j]=ebuf[j];
  }

  /* Updating Buffers */
  /* NOTE(review): assumes swp->G_id was registered via set_idInfo();
   * otherwise rev_switch_info gives -1 and the writes go out of bounds. */
  i = rev_switch_info[swp->G_id];
  ebuf[i*valueNum1+(swp->Val)]++;        /* Count up the switch */
  ebuf[ivNum11+i]++;                     /* used for Gamma  */

  /* Conquer child node before branching (depth-first) */
  if (swp->Child == LNULL) {
	/* If the switch is leaf node,
         associate it with an entry of EM table  */
	/* Convert the per-switch use counts into Gamma = Tnum - count. */
	for(i=0;i<=idNum;i++)
	  ebuf[ivNum11+i] = switch_info[i].Tnum - ebuf[ivNum11+i]; /* Gamma */
    swp->Table = ebuf;
	mdepth1=depth;
  }
  else mdepth1=count_an_expl(depth+1,swp->Child,ebuf);  /* Recursion */

  /* Branching */
  if (swlp->NextSw != LNULL) {
	/* Recursion: the sibling continues from the copy taken above */
	mdepth2 = count_an_expl(depth,swlp->NextSw,ebufcopy);
  }
  else mdepth2=depth;
  
  return (mdepth1>mdepth2)? mdepth1: mdepth2;
}

/* initTheta() -- give every non-fixed switch random initial parameters.
 *
 * Each value of a switch first receives random_int(800000)+100000.  The
 * row is then divided by its sum, so the parameters of the switch add up
 * to 1.  Switches marked FIXED by set_fixed() are left untouched.
 */
void initTheta(void)
{
  int sw, val;
  long sum;

  for (sw = 0; sw <= idNum; sw++) {
    if (switch_info[sw].fixed != NOTFIXED)
      continue;

    sum = 0;
    for (val = 0; val <= switch_info[sw].maxValue; val++) {
      Theta(sw, val, theta) = random_int(800000) + 100000;
      sum += Theta(sw, val, theta);
    }

    /* Guard against a zero denominator (e.g. a switch with no values). */
    if (sum <= 0.0)
      printf("{PRISM INTERNAL WARNING: initTheta() -- denominator `total' is 0.0}\n");

    for (val = 0; val <= switch_info[sw].maxValue; val++)
      Theta(sw, val, theta) /= (sum <= 0.0) ? 0.000000001 : sum;
  }
}

/* ON_loop():  using no recursive calls
 *
 *   created by kame on Jul/24/1998.
 *
 *  [NOTE]
 *    - ONtab is a table each of whose cells contains ON_i(v).  The
 *      caller supplies ivNum11 doubles, indexed as ONtab[i*valueNum1 + v].
 *    - ON^t_i(v) is calculated depth-first along paths in TRIE
 *      structured explanations.
 *    - For efficiency, `for' loops are used instead of recursive calls.
 *    - If there is a branch, stack up the probability and the return
 *      address on Pstack and Astack respectively.
 *    - Pstack and Astack have already been built in count_expls().
 *
 *  Side effects: stores each goal's probability in glp->Pdb and adds
 *  glp->ObsNum * log(Pdb) to the global loglike.
 *  Returns 1 on success.  Returns 0 if the per-goal work buffer cannot
 *  be allocated; ONtab has already been zeroed in that case.
 */
int ON_loop(double *ONtab)
{
  int i,j,v,gamma;
  int *tab;          /* count table stored at the current leaf */
  double Pdb,Pf;     /* goal probability / probability of the current path prefix */
  double *onbuf;     /* contributions of the current goal, same layout as ONtab */
  struct glink *glp;
  struct swlink *swlp, *swlp_next;
  struct sw *swp;
  int index;  /* index for Pstack, Astack */

  /* Initialize ONtab */
  for(i=0;i<ivNum11;i++) ONtab[i]=0.0;

  /* [NOTE] loglike is already initialized to 0.0
   *        in update_theta().
   */
  
  if ((onbuf=(double *)malloc(sizeof(double)*ivNum11))==(double *)NULL)
	return(0);

  for(glp = Root; glp != GNULL; glp=(glp->NextG)){

	/* Initialize buffers and indexes */
	for(i=0;i<ivNum11;i++) onbuf[i]=0.0;
	Pdb=0.0; Pf=1.0;
	index=0;  /* Pstack[0], Astack[0] is not used */

	for(swlp=glp->Expl; swlp != LNULL; swlp=swlp_next){
	  
	  /* printf("swlp = %8x Pf = %.12f\n",swlp,Pf); */

	  /* Save the sibling to resume at later, together with the prefix
	   * probability as it is *before* this node's parameter is applied. */
	  if (swlp->NextSw != LNULL){
		Pstack[++index] = Pf;
		Astack[index] = swlp->NextSw;
	  }

	  /*
	  printf("Pstack:");
	  for(i=1;i<=index;i++)	printf(" %.12f",Pstack[i]);
	  printf("\n");
	  printf("Astack:");
	  for(i=1;i<=index;i++)	printf(" %8x",Astack[i]);
	  printf("\n");
	  */

	  swp = swlp->Sw;
	  Pf *= Theta((rev_switch_info[swp->G_id]),(swp->Val),theta);

	  if (swp->Child == LNULL){
		/* Leaf: Pf is now the probability of one complete explanation. */
		Pdb += Pf;
		/* Accumulate Pf * (count_i(v) + Gamma_i * theta_i(v)) from the
		 * table count_an_expl() stored at this leaf. */
		for(i=0; i<=idNum; i++){
		  j = i*valueNum1;
		  tab = swp->Table;
		  gamma = *(tab+ivNum11+i);
		  for(v=0; v<=switch_info[i].maxValue; v++)
			onbuf[j+v] += Pf*(*(tab+j+v) + gamma*Theta(i,v,theta));
		}
		/* Backtrack to the most recently saved sibling, if any. */
		if (index > 0) {
		  Pf = Pstack[index];
		  swlp_next = Astack[index--];
		}
		else swlp_next = LNULL;  /* No branch exists */
	  }
	  else swlp_next = swp->Child;
	}

	/* Keep log() and the division below finite. */
	if (Pdb <= 0) {
	  printf("{PRISM INTERNAL WARNING: ON_loop() -- Pdb is 0.0.}\n");
	  Pdb = 0.000000000000000001;
	}
	/* printf("Pdb = %.12f\n",Pdb); */
	glp->Pdb = Pdb;
	loglike += glp->ObsNum*log(Pdb);

	/* Normalize this goal's contributions by Pdb and weight them by ObsNum. */
	for(i=0; i<=idNum; i++){
	  j = i*valueNum1;
	  for(v=0; v<=switch_info[i].maxValue; v++)
		ONtab[j+v] += glp->ObsNum*(onbuf[j+v]/Pdb);
	}
  }
  free(onbuf);
  return(1);
}

/* copyTheta(theta_src, theta_dest) -- copy the parameters of every
 * non-fixed switch from theta_src to theta_dest.  The cells of FIXED
 * switches in theta_dest are not touched.
 */
void copyTheta(double *theta_src, double *theta_dest)
{
  int sw;

  for (sw = 0; sw <= idNum; sw++) {
    int val;

    if (switch_info[sw].fixed != NOTFIXED)
      continue;
    for (val = 0; val <= switch_info[sw].maxValue; val++)
      Theta(sw, val, theta_dest) = Theta(sw, val, theta_src);
  }
}

/* update_theta() -- update theta *once*.
 *
 *   Moves loglike into loglike_bak and recomputes loglike and the ON
 *   table with ON_loop().  Each non-fixed switch is then re-estimated as
 *   ON_i(v) / sum_v ON_i(v).  The new values are built in theta_tmp and
 *   then copied into theta.
 *
 *   Returns 1 on success and 0 if memory cannot be allocated.  On
 *   failure theta is left unchanged.
 *
 *   re-created by kame on Jul/23/1998.
 */
int  update_theta(void)
{
  double *ONtab;
  double ONsum;
  int i,j,v;

  loglike_bak = loglike;
  loglike = 0.0;                 /* ON_loop() accumulates the new value */

  if ((ONtab = malloc((size_t)ivNum11 * sizeof *ONtab)) == NULL) {
    printf("{PRISM INTERNAL ERROR: update_theta() -- Memory allocation failed.}\n");
    return(0);                   /* ON_loop() must not get a NULL table */
  }

  /* Get ON table */
  if (ON_loop(ONtab)==0) {
    free(ONtab);
    return(0);
  }

  for(i=0; i<=idNum; i++){
    if (switch_info[i].fixed != NOTFIXED)
      continue;                  /* do nothing for fixed switch */
    ONsum=0.0;
    j = i*valueNum1;
    for(v=0; v<=switch_info[i].maxValue; v++)
      ONsum += ONtab[j+v];
    if (ONsum <= 0.0)
      printf("{PRISM INTERNAL WARNING: update_theta() -- ONsum is 0.0.}\n");
    /* A tiny denominator avoids dividing by zero when ONsum <= 0. */
    for(v=0; v<=switch_info[i].maxValue; v++)
      Theta(i,v,theta_tmp) =
        ONtab[j+v]/((ONsum <= 0.0)? 0.0000000000001: ONsum);
  }
  free(ONtab);
  copyTheta(theta_tmp, theta);

  return(1);
}

/*
 *  c_EM_loop(): main routine of the EM algorithm.
 *
 *  Calls update_theta() repeatedly until the increase of the
 *  log-likelihood (loglike - loglike_bak) falls below epsilon, and shows
 *  progress with display().  The first update happens before the loop:
 *  at that point loglike_bak is only the initial 0.0, so the convergence
 *  test is applied from the second update on.
 *
 *  Returns the number of the last iteration (counting from 0).  Returns
 *  -1 if update_theta() fails; theta is then left as it was before the
 *  failed update.
 */
int c_EM_loop(void)
{
  int iteration=0;

  loglike=0.0; loglike_bak=0.0;

  if (update_theta() == 0)
    goto failed;
  display(iteration++);

  while(1){
    if (update_theta() == 0)
      goto failed;
    display(iteration++);
    if (loglike - loglike_bak < epsilon) break;
  }
  printf("\nloglike = %.12f\n",loglike);
  return iteration-1;

failed:
  printf("\n{PRISM INTERNAL ERROR: c_EM_loop() -- update_theta() failed.}\n");
  return(-1);
}

/* display(iteration) -- progress indicator for c_EM_loop().
 *
 * Prints the iteration number when it is a multiple of DISPLAYNUM.
 * Otherwise it prints a dot when the number is a multiple of DISPLAYDOT.
 * stdout is flushed after each mark so progress shows up immediately.
 */
void display(int iteration)
{
  if (iteration % DISPLAYNUM != 0) {
    if (iteration % DISPLAYDOT != 0)
      return;                      /* nothing to show this time */
    putchar('.');
  } else {
    printf("%d", iteration);
  }
  fflush(stdout);
}

/* initVars() -- reset switch_info, theta and theta_tmp to NULL (without
 * freeing them) and reset the fill position used by set_idInfo().
 */
void initVars(void) {

  switch_info=NULL;
  theta=NULL;
  theta_tmp=NULL;
  /* A negative id only resets set_idInfo()'s position; the other two
   * arguments are ignored, so plain 0 is passed (not (int)NULL, which
   * converts a null pointer constant to int). */
  set_idInfo(-2,0,0);

}

/* freeMemory() -- release the tables allocated by set_values(),
 * prepare_rev_idInfo() and count_expls(), and reset their pointers to
 * NULL.  It also resets the fill position of set_idInfo().  A notice is
 * printed if anything was actually released.
 * NOTE(review): the per-leaf count buffers linked from the explanation
 * TRIE (swp->Table) are not released here.
 */
void freeMemory(void)
{
  int flag=0;   /* set when at least one table was released */

  if (switch_info != NULL){
    free(switch_info);
    switch_info=NULL;
    flag=1;
  }
  if (rev_switch_info != NULL){
    free(rev_switch_info);
    rev_switch_info=NULL;
    flag=1;
  }
  if (theta != NULL){
    free(theta);
    theta=NULL;
    flag=1;
  }
  if (theta_tmp != NULL){
    free(theta_tmp);
    theta_tmp=NULL;
    flag=1;
  }
  /* Initialize -- reset of the static variable in set_idInfo()
   * (the other two arguments are ignored for a negative id). */
  set_idInfo(-2,0,0);

  if (Pstack != NULL){
    free(Pstack);
    Pstack=NULL;
    flag=1;
  }
  if (Astack != NULL){
    free(Astack);
    Astack=NULL;
    flag=1;
  }

  if (flag) printf("{Previous explanation table cleaned up.}\n");
}

/* get_theta(idInProlog, val) -- return Theta(i,val,theta), where i is the
 * C index of the first switch whose Prolog ID is idInProlog.
 *
 * Returns -1.0 if no switch has that ID.  It also returns -1.0 if val is
 * outside 0..maxValue of the switch; such a cell does not belong to the
 * switch.
 */
double get_theta(int idInProlog, int val){
  int i;

  for(i=0;i<=idNum;i++)
    if(switch_info[i].idInProlog == idInProlog){
      if (val < 0 || val > switch_info[i].maxValue)
        return -1.0;
      return Theta(i,val,theta);
    }
  return -1.0;
}

/* get_loglike() -- current value of loglike.  It is accumulated by
 * ON_loop() during update_theta().
 */
double get_loglike(void)
{
  double value = loglike;

  return value;
}

/*
 *
 * PRINTER FUNCTIONS:
 *
 */

/* show_sizes() -- print the table dimensions idNum, valueNum and goals. */
void show_sizes(void)
{
  int id_max = idNum, value_max = valueNum, goal_count = goals;

  printf("Table size:\n    idNum=%d valueNum=%d goals=%d\n",
         id_max, value_max, goal_count);
}

void show_sw(void){
  
  int i,v;

  printf("Switches recognized by C routine:\n");
  for(i=0; i<=idNum; i++){
	printf("    Switch %d (Code %d, Size %d, Tnum %d):",
		   i,switch_info[i].idInProlog,
		   (switch_info[i].maxValue)+1,switch_info[i].Tnum);
	for(v=0; v<=switch_info[i].maxValue; v++)
	  printf(" %d (%.6f)",v,Theta(i,v,theta));
	printf("\n");
  }
}

/* show_raw_sw() -- dump the parameters of every switch by C index,
 * one switch per line.
 */
void show_raw_sw(void)
{
  int sw, val;

  for (sw = 0; sw <= idNum; sw++) {
    val = 0;
    while (val <= switch_info[sw].maxValue) {
      printf("Th(%d,%d)=%.6f  ", sw, val, Theta(sw, val, theta));
      val++;
    }
    putchar('\n');
  }
}

/* show_goals() -- print every goal in the Root list with its number of
 * explanations, maximum depth, weight and probability.  The probability
 * is reported as "(unknown)" while Pdb is negative.
 */
void show_goals(void)
{
  struct glink *g;

  printf("Goals recognized in C routine:\n");

  for (g = Root; g != GNULL; g = g->NextG) {
    if (g->Pdb < 0) {
      printf("    Goal %d: num_of_expls %d, max_depth %d, weight %d, prob (unknown)\n",
             g->Goal, g->Count, g->MaxDepth, g->ObsNum);
      continue;
    }
    printf("    Goal %d: num_of_expls %d, max_depth %d, weight %d, prob %.6f\n",
           g->Goal, g->Count, g->MaxDepth, g->ObsNum, g->Pdb);
  }
}

/* show_ONtab(ONtab) -- debugging aid: print the ON(i,v) values held in
 * ONtab (indexed as ONtab[i*valueNum1 + v]), one switch per line.
 */
void show_ONtab(double *ONtab)
{
  int i,v;

  for(i=0; i<=idNum; i++){
    for(v=0; v<=switch_info[i].maxValue; v++)
      printf("ON(%d,%d)=%.6f  ",i,v,ONtab[i*valueNum1+v]);
    printf("\n");
  }
}
