/************************************************************************
 **  prob.c: C part of probability calculation.
 **
 **  Copyright (C) 1998
 **    Taisuke Sato, Yoshitaka Kameya, Yasushi Hagiwara, Nobuhisa Ueda,
 **      Dept. of Computer Science, Tokyo Institute of Technology.
 ************************************************************************/
 
#include "prism.h"

/* Global variables */

/* Explanation trie built for the "prob" goal */
struct swlink *PRoot;       /* first sibling list (top level) of the trie */
struct swlink *PCrr;        /* insertion cursor moved by get_Ptrie() */
int PCount;                 /* The number of explanations for "prob" goal */
int PMaxDepth;              /* Depth of the TRIE */
double *PPstack;            /* calc_Pdb(): saved path products, index 0 unused */
struct swlink **PAstack;    /* calc_Pdb(): saved sibling cells, index 0 unused */

/* Global variables of the prob command.
 *  [NOTE] They correspond respectively to idNum, idNum1, ..., theta.
 */
int pidNum;                 /* largest C-side switch index (set_Pvalues) */
int pidNum1;                /* pidNum + 1 */
int pvalueNum;              /* largest value index (set_Pvalues) */
int pvalueNum1;             /* pvalueNum + 1: value slots per switch */
int pvalueNum2;             /* pvalueNum + 2 */
int pivNum11;               /* pidNum1 * pvalueNum1 */
struct switchInfo *Pswitch_info;  /* per-switch info, indexed by C-side index */
int *P_rev_switch_info;     /* Prolog switch ID -> C-side index (-1 if none) */
double *ptheta;             /* parameters, index*pvalueNum1 + value */

/* [NOTE]
 * When only probabilities are computed, the Gamma part is not needed.
 * Therefore, unlike the table used for EM learning, the explanation
 * table of the prob command has no Gamma column.
 */

/*
 * For initialization, cancellation,..
 */

/* (Prolog name) really_init_prob
 *
 * Reset the prob-command state.  The trie root, the cursor, the switch
 * table and the parameter table pointers are cleared (not freed).  The
 * switch-registration counter of set_PidInfo() is also reset.
 */
void initProb(void)
{
  PRoot=LNULL;
  PCrr=LNULL;
  Pswitch_info=(struct switchInfo *)NULL;
  ptheta=(double *)NULL;
  /* A negative id only resets set_PidInfo's counter; maxval is unused, so
   * pass a plain 0 instead of converting a null pointer to int. */
  set_PidInfo(-2,0);
}

/* (Prolog name) init_prob_table
 *
 * Release the tables allocated by set_Pvalues() (Pswitch_info, ptheta)
 * and prepare_P_rev_idInfo() (P_rev_switch_info).  The pointers are
 * cleared, so a repeated call is harmless.  The set_PidInfo() registration
 * counter is also reset.
 */
void freePMemory(void)
{
  /* free(NULL) is a no-op, so no NULL guards are needed. */
  free(Pswitch_info);
  Pswitch_info=(struct switchInfo *)NULL;

  free(P_rev_switch_info);
  P_rev_switch_info=(int *)NULL;

  free(ptheta);
  ptheta=(double *)NULL;

  /* A negative id only resets the counter; maxval is unused. */
  set_PidInfo(-2,0);
}

/* (Prolog name) prepare_prob_trie
 *
 * Start a fresh explanation trie.  This resets get_Ptrie()'s "ignore"
 * state and allocates an empty root cell, which becomes the current
 * insertion point (PCrr).
 * Returns 1 on success, or 0 if the root cell cannot be allocated.
 * NOTE(review): a previous trie under PRoot is not released here;
 * presumably callers run free_Ptrie() first -- confirm.
 */
int prepare_prob_trie(void)
{
  /* Flag -1 only resets get_Ptrie's static variable `pignore` to 0;
   * the other arguments are unused. */
  get_Ptrie(-1,0,0,0);

  /* [NOTE] PRoot is a pointer to a struct swlink. */
  if ((PRoot = lalloc())==LNULL) return(0);

  PRoot->Sw=SNULL;       /* empty cell: the first switch is stored here */
  PRoot->NextSw=LNULL;
  PCrr=PRoot;
  return(1);
}

/*
 *  prepare_P_rev_idInfo() -- build a lookup table of switch IDs.
 *
 *  Switch ID in Prolog is not always in [0,idNum], thus sometimes it is
 *  not the same as that in C. So we need a lookup table of both switch
 *  IDs in Prolog and C.
 *
 *  On success, P_rev_switch_info[prolog_id] holds the C-side index of the
 *  switch, or -1 for Prolog IDs that belong to no switch, and 1 is
 *  returned.  Returns 0 in three cases: the switch table is missing, a
 *  Prolog ID is negative (it cannot be used as an index), or the
 *  allocation fails.  A table left over from a previous call is released
 *  first.
 *
 *  created by kame on Aug/6/1998
 */
int prepare_P_rev_idInfo(void)
{
  int i,max=0;

  /* Release a table from an earlier call instead of leaking it. */
  free(P_rev_switch_info);
  P_rev_switch_info=(int *)NULL;

  if (pidNum >= 0 && Pswitch_info == (struct switchInfo *)NULL) return(0);

  /* Get maximum idInProlog; a negative ID would index out of bounds. */
  for(i=0;i<=pidNum;i++) {
	if (Pswitch_info[i].idInProlog < 0) return(0);
	if (max < Pswitch_info[i].idInProlog) max = Pswitch_info[i].idInProlog;
  }

  if ((P_rev_switch_info = (int *)malloc(sizeof(int)*(max+1))) == (int *)NULL)
	return(0);

  /* -1 marks Prolog IDs that are not used by any switch. */
  for(i=0;i<=max;i++) P_rev_switch_info[i] = -1;

  for(i=0;i<=pidNum;i++)
	P_rev_switch_info[Pswitch_info[i].idInProlog]=i;

  return(1);
}

/* Debug printer: print one line per switch, giving its Prolog ID and its
 * C-side index. */
void show_P_rev_idInfo(void)
{
  int k = 0;

  while (k <= pidNum) {
	printf("  P_rev_switch_info[%d]=%d\n", Pswitch_info[k].idInProlog, k);
	k++;
  }
}

/* (Prolog name) cancel_prob
 * added by kame on Nov/20/1997.
 *
 * Abandon the current prob computation.  This drops the trie root, resets
 * the explanation count and the insertion cursor, and releases the switch
 * tables via freePMemory().
 * NOTE(review): only the root cell itself is passed to free(); cells below
 * it are not released here (free_Ptrie() uses freeChild() for that).
 * Confirm whether the trie can be non-empty at this point.
 */
void cancelProb(void)
{
  struct swlink *old_root = PRoot;

  PRoot = LNULL;
  PCrr = LNULL;
  PCount = 0;
  free(old_root);

  freePMemory();
}

/* Move the trie insertion cursor PCrr back to the top-level list. */
void return_to_prob_root(void)
{
  struct swlink *top = PRoot;
  PCrr = top;
}

/* Adjust the explanation counter PCount by n.  get_Ptrie() passes
 * 1 - (number of deleted explanations), so n can be zero or negative. */
void add_prob_count(int n)
{
  PCount = PCount + n;
}

/*
 * get_Ptrie -- insert one switch instance of an explanation into the trie.
 *
 * An explanation is fed in one switch at a time, starting at the root.
 * PCrr is the insertion cursor, and each switch is identified by
 * (g_id, t_id, val).  The flag says what follows:
 *   flag ==  1 : more switches follow; descend into the child list.
 *   flag ==  0 : this is the last switch; go back to the root.
 *   flag == -1 : only reset the internal "ignore" state (other args unused).
 *
 * If a stored explanation is a prefix of the one being inserted, it
 * subsumes the new one, and the rest of the input is ignored.  If the new
 * explanation is a prefix of stored ones, it subsumes them, and their
 * subtrees are deleted.  PCount is kept equal to the number of stored
 * explanations.
 *
 * Returns 1 on success, or 0 on allocation failure or an invalid flag.
 */
int get_Ptrie(int flag, int g_id, int t_id, int val)
{
  /* Nonzero while skipping the tail of an input explanation that is
   * already subsumed by a shorter stored one. */
  static int pignore=0;

  if (flag==-1) { pignore=0; return(1); }

  if (pignore)
	switch(flag){
	case 1: return(1);                                    /* keep skipping */
	case 0: pignore=0; return_to_prob_root(); return(1);  /* end of input */
	default: return(0);
	}

  /* Validate the flag before touching the trie, so an invalid call can
   * never leave a half-initialised node behind. */
  if (flag!=0 && flag!=1) {
	printf("{PRISM INTERNAL ERROR: get_Ptrie(%d,_,_,_) -- ",flag);
	printf("%d must be 0 or 1.}",flag);
	return(0);
  }

  if (PCrr->Sw==SNULL) {
	/* Empty list (fresh root or fresh child list): store the switch in
	 * this very cell. */
	if((PCrr->Sw=salloc())==SNULL) return(0);

	PCrr->Sw->G_id = (short)g_id;
	PCrr->Sw->T_id = (short)t_id;
	PCrr->Sw->Val  = (short)val;
	PCrr->Sw->Table = (int *)NULL;

	if (flag==1) {
	  if ((PCrr->Sw->Child=lalloc())==LNULL) return(0);
	  PCrr->Sw->Child->Sw=SNULL;
	  PCrr->Sw->Child->NextSw=LNULL;
	  PCrr=PCrr->Sw->Child;
	}
	else {
	  PCrr->Sw->Child=LNULL;
	  add_prob_count(1);
	  return_to_prob_root();
	}
	return(1);
  }

  /* Walk the sibling list looking for the same switch instance. */
  while(1){
	if (compare_sw(PCrr->Sw,g_id,t_id,val)) {
	  /* This node is shared with an explanation inserted earlier. */
	  if (PCrr->Sw->Child==LNULL) {
		if (flag==1)
		  /* A stored explanation equals the input up to here.  Do nothing
		   * until the call with flag 0 arrives. */
		  pignore=1;
		else
		  /* An explanation identical to the input already exists.  Do
		   * nothing except return to PRoot.
		   * (modified by kame on Dec/13/1997) */
		  return_to_prob_root();
	  }
	  else if (flag==1) {
		PCrr=PCrr->Sw->Child;       /* keep following the existing trie */
	  }
	  else {
		/* The new explanation subsumes the stored ones below this node,
		 * so delete them.
		 * <ex.>
		 *   Given B & C v B & D, a new input B gives
		 *   B & C v B & D v B <=> B, so C and D must be deleted, leaving
		 *   just B.
		 *
		 * The explanation count is reduced by (number of deleted
		 * children - 1), using add_prob_count().
		 */
		add_prob_count(1-freeChild(PCrr->Sw->Child));
		PCrr->Sw->Child=LNULL;
		return_to_prob_root();
	  }
	  return(1);
	}

	if (PCrr->NextSw==LNULL) {
	  /* Not found among the siblings: append a new cell. */
	  if ((PCrr->NextSw = lalloc())==LNULL) return(0);
	  PCrr->NextSw->NextSw=LNULL;
	  if ((PCrr->NextSw->Sw = salloc())==SNULL) return(0);

	  PCrr->NextSw->Sw->G_id = (short)g_id;
	  PCrr->NextSw->Sw->T_id = (short)t_id;
	  PCrr->NextSw->Sw->Val  = (short)val;
	  PCrr->NextSw->Sw->Table = (int *)NULL;

	  if (flag==1) {
		if ((PCrr->NextSw->Sw->Child = lalloc())==LNULL) return(0);
		PCrr=PCrr->NextSw->Sw->Child;
		PCrr->Sw=SNULL;
		PCrr->NextSw=LNULL;
	  }
	  else {
		PCrr->NextSw->Sw->Child=LNULL;
		add_prob_count(1);
		return_to_prob_root();
	  }
	  return(1);
	}

	PCrr=PCrr->NextSw;
  }
}

/*
 * free_Ptrie -- discard the explanation trie.
 * The root list is handed to freeChild() for release.  Then the root, the
 * insertion cursor and the explanation count are reset.
 */
void free_Ptrie(void)
{
  if (PRoot != LNULL) {
	/* freeChild() returns the number of deleted explanations; it is not
	 * needed here. */
	(void)freeChild(PRoot);
	PRoot=LNULL;
  }
  PCount=0;
  PCrr=LNULL;
}

int count_expls_P(void)
{
  int *ebuf;
  int i,td;
  
  if (prepare_P_rev_idInfo()>0) {
	/* [NOTE] ebuf will be linked to leaf node, so do not free(ebuf)! */
	if ((ebuf=(int *)malloc(sizeof(int)*pivNum11))==(int *)NULL)
	  return(0);
	for(i=0;i<pivNum11;i++) ebuf[i]=0;
  
	if (PRoot != LNULL)
	  PMaxDepth = count_an_expl_P(1,PRoot,ebuf);
  
	if (PMaxDepth>0) {
	  if (((PPstack=(double *)malloc(sizeof(double)*(PMaxDepth+1)))
		   ==(double *)NULL)
		  ||
		  ((PAstack
			=(struct swlink **)malloc(sizeof(struct sw *)*(PMaxDepth+1)))
		   ==NULL))
		return(0);
	  return(1);
	} else return(0);
  }
  else return(0);
}

/*
 * count_an_expl_P -- depth-first walk of the sibling list starting at swlp.
 *
 * depth : depth of swlp's list (count_expls_P starts the root at 1).
 * ebuf  : count buffer (pivNum11 ints) for the path leading to swlp.  It is
 *         updated in place, and at a leaf it is linked as swp->Table.
 *
 * Returns the maximum depth reached below swlp and its siblings, or 0 if a
 * buffer copy could not be allocated.  A successful result is never smaller
 * than depth, so for depth >= 1 a result below depth signals failure and is
 * propagated instead of being hidden by the max.
 */
int count_an_expl_P(int depth, struct swlink *swlp, int *ebuf)
{
  struct sw *swp=swlp->Sw;
  int *ebufcopy=(int *)NULL;
  int i,mdepth1,mdepth2;

  /* [NOTE] If there is a branch, the buffer must be copied before it is
   *        updated and the child nodes are conquered.
   */
  if(swlp->NextSw != LNULL){
	/* [NOTE] ebufcopy will be linked to a leaf node,
	 *        so do not free it once it has been handed on!
	 */
	if((ebufcopy=(int *)malloc(sizeof(int)*pivNum11))==(int *)NULL)
	  return(0);
	for(i=0;i<pivNum11;i++) ebufcopy[i]=ebuf[i];
  }

  /* Count this switch value on the current path. */
  ebuf[P_rev_switch_info[swp->G_id]*pvalueNum1+swp->Val]++;

  /* Conquer the child node before branching (depth-first). */
  if (swp->Child == LNULL) {
	swp->Table = ebuf;
	mdepth1=depth;
  }
  else {
	mdepth1=count_an_expl_P(depth+1,swp->Child,ebuf);
	if (mdepth1 < depth) {
	  /* Failure below; ebufcopy has not been linked anywhere yet. */
	  free(ebufcopy);
	  return(0);
	}
  }

  /* Branching */
  if (swlp->NextSw != LNULL) {
	mdepth2 = count_an_expl_P(depth,swlp->NextSw,ebufcopy);
	if (mdepth2 < depth) return(0);
  }
  else mdepth2=depth;

  return (mdepth1>mdepth2)? mdepth1: mdepth2;
}

/*
 * set_Pvalues -- record the table dimensions and allocate the tables.
 *   id      : largest C-side switch index (switches 0..id)
 *   mmvalue : largest value index; each switch gets mmvalue+1 slots
 * Allocates Pswitch_info (id+1 entries) and ptheta ((id+1)*(mmvalue+1)
 * doubles).  Returns 1 on success, or 0 if either allocation fails.
 */
int set_Pvalues(int id, int mmvalue)
{
  pidNum     = id;
  pidNum1    = pidNum + 1;
  pvalueNum  = mmvalue;
  pvalueNum1 = mmvalue + 1;
  pvalueNum2 = mmvalue + 2;
  pivNum11   = pidNum1 * pvalueNum1;

  Pswitch_info =
	(struct switchInfo *)malloc(sizeof(struct switchInfo) * pidNum1);
  if (Pswitch_info == (struct switchInfo *)NULL)
	return(0);

  ptheta = (double *)malloc(pivNum11 * sizeof(double));
  return (ptheta == NULL) ? 0 : 1;
}

/* $B3F(B switch $B$K4X$9$k>pJs$r@_Dj(B */
int set_PidInfo(int id, int maxval)
{
  /* Prolog $BB&$GEO$5$l$k(B switch $B$N(B G_id $B$OO"B3$7$?(B 0,...,pidNum $B$G(B
   * $B$"$k$H$O8B$i$J$$$N$G(B static $BJQ?t(B num $B$G4IM}$9$k(B	 
   */
  static int num=0; 

  if(id < 0){
	num=0; return(1);                /* id < 0 $B$N$H$-$OFCJL$K=i4|2=(B*/
  }

  if (num > pidNum) {
	printf("{PRISM INTERNAL ERROR: set_PidInfo(%d,%d)",id,maxval);
	printf("-- num(%d) exceeds idNum(%d)\n",num,pidNum);
	return(0);
  }

  Pswitch_info[num].idInProlog = id;
  Pswitch_info[num].maxValue   = maxval;
  Pswitch_info[num].Tnum       = 0;  /* $B;HMQ$7$J$$$N$G(B 0 $B$H$$$&?t$K0UL#$J$7(B */
  Pswitch_info[num].fixed      = NOTFIXED;

  num++;
  return(1);
}

/*
 * set_Ptheta -- store `param` as the parameter of value v of the switch
 * whose Prolog-side id is `id`.  ptheta is laid out as
 * index*pvalueNum1 + v.
 * Returns 1 on success.  Returns 0, after printing a message, if the
 * switch is unknown or v is outside 0..pvalueNum.
 */
int set_Ptheta(int id, int v, double param)
{
  int i,g_id=-1;

  /* Scan from the end, so the last matching entry wins as before. */
  for(i=pidNum;i>=0;i--)
	if(id==Pswitch_info[i].idInProlog){ g_id=i; break; }

  if (g_id < 0) {
	printf("{PRISM INTERNAL ERROR: Unknown switch %d occurs in trie.}\n",id);
	return(0);
  }

  /* Each switch owns pvalueNum1 slots.  Any other v would overwrite
   * another switch's parameters or write past the end of ptheta. */
  if (v < 0 || v >= pvalueNum1) {
	printf("{PRISM INTERNAL ERROR: value %d of switch %d out of range.}\n",v,id);
	return(0);
  }

  ptheta[g_id*pvalueNum1+v]=param;
  return(1);
}

double calc_Pdb(void)
{
  double Pf, Pdb;
  int index;
  struct swlink *swlp, *swlp_next;
  struct sw *swp;

  /* Initialize buffers and indexes */
  Pdb=0.0;  Pf=1.0;
  index=0;  /* PPstack[0], PAstack[0] is not used */

  for(swlp=PRoot; swlp != LNULL; swlp=swlp_next){

	if (swlp->NextSw != LNULL) {
	  PPstack[++index] = Pf;
	  PAstack[index] = swlp->NextSw;
	}

	swp = swlp->Sw;
	Pf *= PTheta((P_rev_switch_info[swp->G_id]),(swp->Val));

	if (swp->Child == NULL) {
	  Pdb += Pf;
	  if (index > 0) {
		Pf = PPstack[index];
		swlp_next = PAstack[index--];
	  }
	  else swlp_next = LNULL;   /* No branch exists */
	}
	else swlp_next = swp->Child;
  }

  return Pdb;
}

/*
 *   PRINTER FUNCTIONS:
 */
/* Print the explanation trie via show_expl_more(), starting at level 0.
 * Prints nothing when there is no trie. */
void show_Ptrie(void)
{
  if (PRoot == LNULL)
	return;
  show_expl_more(0,PRoot);
}
