/************************************************************************
 **  trie.c: build TRIE structure of the explanations.
 **
 **  Copyright (C) 1998
 **    Taisuke Sato, Yoshitaka Kameya, Yasushi Hagiwara, Nobuhisa Ueda,
 **      Dept. of Computer Science, Tokyo Institute of Technology.
 ************************************************************************/

#include "prism.h"

/* Global variables */
/* Head of the goal list.  Allocated by prepare_trie(); freeTrie() and the
   show_trie*() printers walk the list from here through NextG. */
struct glink *Root;
/* Goal node currently being worked on.  prepare_expl() moves it along the
   goal list, return_to_root() resets it to Root. */
struct glink *ERoot;
/* Cursor into the switch links of ERoot's explanations.  get_trie() moves
   it while an explanation is being inserted; return_to_expl_root() resets
   it to ERoot->Expl. */
struct swlink *Crr;

/* Global variables used in em.c */
/* NOTE(review): these are only declared here; the trie code reviewed in
   this file does not reference them -- confirm before removing. */
extern int idNum,idNum1,valueNum,valueNum1,valueNum2,ivNum12,ivNum11;
extern int goals;
extern struct switchInfo *switch_info;

/*
 *  trie.c: PRISM keeps explanations in the form of trie structure.
 *
 *  <ex.>
 *  hmm([a,b]) is explained by
 *
 *    bsw(init,nil,0) & bsw(obs(s0),1,0) & bsw(trans(s0),0) & ...
 *  v bsw(init,nil,0) & bsw(obs(s0),1,0) & bsw(trans(s0),1) & ...
 *  v ...
 *  v bsw(init,nil,1) & bsw(obs(s1),1,1) & bsw(trans(s1),1) & ...
 *
 *  An explanation is a conjunction of switches, e.g., bsw(init,nil,0) &
 *  bsw(obs(s0),1,0) & bsw(trans(s0),0) & ...
 *  PRISM keeps all explanations in the following form, which is called
 *  a trie structure.
 *
 *  bsw(init,nil,0) -+-> bsw(obs(s0),1,0) -+-> bsw(trans(s0),0) -->..
 *                   |                     |
 *                   |                     +-> bsw(trans(s0),1) -->..
 *                   |
 *                   +-> bsw(obs(s0),1,1) -+-> ...
 */  

/*
 * initRoot(): put the three trie cursors (Root, ERoot, Crr) into their
 * empty state.  Nothing is freed here; the pointers are simply overwritten.
 */
void initRoot(void) {
  Crr   = LNULL;
  ERoot = GNULL;
  Root  = GNULL;
}

/* Functions for preparation of trie structure */
/*
 * prepare_trie(): start a new trie.
 *
 * Resets get_trie()'s internal `ignore` state, allocates the first goal
 * node and makes both Root and ERoot point at it.  The node starts with no
 * explanations (Expl == LNULL), so the first prepare_expl() call fills it
 * in place.
 *
 * Returns 1 on success, 0 if the node could not be allocated.
 */
int prepare_trie(void) {

  /* flag -1 only clears get_trie()'s static `ignore`; the remaining
     arguments are not looked at, so plain zeros are passed. */
  get_trie(-1,0,0,0);

  Root = galloc();
  if (Root==GNULL) return(0);

  /* Give every field that show_trie*() prints a defined value, so that a
     dump taken before prepare_expl() runs does not read garbage. */
  Root->Goal=0;
  Root->Count=0;
  Root->MaxDepth=0;
  Root->ObsNum=0;
  Root->Pdb=-1.0;     /* negative means "unknown" to show_trie*() */
  Root->Expl=LNULL;
  Root->NextG=GNULL;

  ERoot=Root;
  return(1);
}

void cancel_trie(void) {

  free(Root);

}

int prepare_expl(int goal) {

  if (ERoot->Expl==LNULL) {

	ERoot->Goal=goal;
	ERoot->Count=0;
	ERoot->ObsNum=0;  /* added by kame on Jul/11/1998 */
	ERoot->Pdb=-1.0;  /* added by kame on Jul/11/1998 */
	ERoot->NextG=GNULL;

	if ((ERoot->Expl=lalloc())==LNULL) return(0);
	(ERoot->Expl)->Sw=SNULL;
	(ERoot->Expl)->NextSw=LNULL;
	Crr=ERoot->Expl;
	return(1);
  }
  else if (ERoot->NextG==GNULL) {

	if ((ERoot->NextG = galloc())==GNULL) return(0);

	ERoot=ERoot->NextG;

	ERoot->Goal=goal;
	ERoot->Count=0;
	ERoot->ObsNum=0;  /* added by kame on Jull/11/1998 */
	ERoot->Pdb=-1.0;  /* added by kame on Jull/11/1998 */
	ERoot->NextG=GNULL;

	if ((ERoot->Expl=lalloc())==LNULL) return(0);
	(ERoot->Expl)->Sw=SNULL;
	(ERoot->Expl)->NextSw=LNULL;
	Crr=ERoot->Expl;
	return(1);
  }
  else {
	ERoot=ERoot->NextG;
	return prepare_expl(goal);
  }
}

/* 'Return to root' functions */
/* return_to_root(): rewind the goal cursor ERoot to the head of the
   goal list (Root). */
void return_to_root(void) {
  ERoot = Root;
}

/* return_to_expl_root(): rewind the switch cursor Crr to the first
   explanation link of the current goal (ERoot->Expl). */
void return_to_expl_root(void) {
  struct swlink *first = ERoot->Expl;
  Crr = first;
}

/* add_expl_count(): adjust the explanation counter of the current goal
   (ERoot->Count) by n; n may be negative. */
void add_expl_count(int n) {
  ERoot->Count = ERoot->Count + n;
}

/* Memory allocation functions */
/* salloc(): allocate one uninitialised switch node; returns the null
   pointer when malloc() fails. */
struct sw *salloc(void) {
  struct sw *node = malloc(sizeof *node);
  return node;
}

/* lalloc(): allocate one uninitialised switch link; returns the null
   pointer when malloc() fails. */
struct swlink *lalloc(void) {
  struct swlink *link = malloc(sizeof *link);
  return link;
}

/* galloc(): allocate one uninitialised goal node; returns the null
   pointer when malloc() fails. */
struct glink *galloc(void) {
  struct glink *goal_node = malloc(sizeof *goal_node);
  return goal_node;
}

/*
 * PRISM $B$O$"$k%4!<%k(B(goal)$B$N@bL@<0(B(switch $B$N(B $BA*8@I8=`7A(B)$B$+$i(B switch $B$N(B
 * $B%H%i%$(B(trie)$B9=B$$r9=C[$9$k!%$?$@$7@bL@<0$NA*8@;R(B (switch $B$N(B $BO"8@(B) $B$O(B
 * switch $B$NBh(B1$B0z?t$HBh(B2$B0z?t$NAH$r%-!<$K$7$F(B  sort  $B$7$F$"$j!$(B $B=EJ#$9$k(B
 * switch $B$,=P8=$7$J$$$b$N$H$9$k!%(B
 *
 * get_trie() $B$O8D!9$N(B switch $B$r<u$1<h$j=g$K%H%i%$9=B$$r9=C[$9$k!%Nc$($P(B
 * sw(1,1,0) & sw(1,2,0) & sw(2,2,0) v sw(1,1,0) & sw(2,1,0) & sw(3,2,0)
 * $B$,M?$($i$l$?@bL@<0$G$"$C$?$H$-!$(B
 *
 *   get_trie(1,1,1,0);
 *   get_trie(1,1,2,1);
 *   get_trie(0,2,2,0);
 * 
 *   get_trie(1,1,1,0);
 *   get_trie(1,1,2,1);
 *   get_trie(0,3,2,0);
 *
 * $B$H8F$S=P$;$P$h$$!%(Bget_trie() $B$NBh(B1$B0z?t$r(B 0 $B$H$9$k$3$H$G0l$D$N(B switch
 * $B$NO"8@$,0l$D=*$o$C$?$3$H$rCN$i$;$k!%(B
 */
int get_trie(int flag, int g_id, int t_id, int val) {

  static int ignore=0;
  /*
   * $B@bL@<0$,(B A&B v A&B&C&D v B&C $B$N$H$-!$(BA&B&C&D $B$r:o=|$9$kI,MW$,$"$k!%(B
   * $B$b$C$H6qBNE*$K$$$&$H!$(Bget_trie() $B$,(B 2$BHVL\$N(B disjunct $BCf$N(B B $B$r<u$1(B
   * $B<h$C$?;~E@$G(B C $B0J2<$rL5;k$9$l$P$h$$(B(A $B$H(B B $B$O(B 1 $BHVL\$N(B disjunct $B$K(B
   * $B%^!<%8$5$l$F$$$k(B)$B!%$H$3$m$,(B Prolog $BB&$G$O2?$b9M$($:99$K(B C $B0J2<$rAw(B
   * $B$j$D$1$F$/$k$N$G(B static $BJQ?t(B ignore $B$rMQ0U$7!$(BB $B$r<u$1<h$C$?;~E@$G(B
   * ignore $B$r(B 1 $B$K%;%C%H$7!$(Bignore $B$,(B 1 $B$G$"$k8B$j2?$b9T$J$o$:$KL5;k$9(B
   * $B$k$h$&$K$9$k!%$?$@$7F~NO(B flag $B$,(B 0 $B$G$"$C$?>l9g!$$=$3$G0l$D$N(B
   * disjunct $B$,=*$o$C$?$3$H$K$J$k$N$G(B ignore $B$r(B 0 $B$K%j%;%C%H$7!$?7$?$J(B
   * disjunct B&C $B$rBT$D!%(B
   */
  struct sw *swp;

  /* printf("get_trie(%d,%d,%d,%d)\n",flag,g_id,t_id,val); */

  if (flag==-1) { /* -1 $B$N(B flag $B$r<u$1<h$C$?$i(B ignore $B$r(B 0 $B$K%;%C%H(B */
	ignore=0;
	return(1);
  }

  if (ignore)
	switch(flag){
	case 1:
	  return(1);           /* $B2?$b$7$J$$(B */
	case 0:
	  /* disjunct $B$,=*$o$C$?$N$G(B ignore $B$r%j%;%C%H$7!$(Broot $B$KLa$k(B */
	  ignore=0; return_to_expl_root(); return(1);
	default:
	  return(0);
	}

  if (Crr->Sw==SNULL) {
	/* printf("Crr->Sw==SNULL\n"); */

	if ((Crr->Sw = salloc())==SNULL) return(0);

	(Crr->Sw)->G_id = (short)g_id;
	(Crr->Sw)->T_id = (short)t_id;
	(Crr->Sw)->Val  = (short)val;
	(Crr->Sw)->Table = (int *)NULL;

	switch(flag){
	case 1:
	  if (((Crr->Sw)->Child = lalloc())==LNULL) return(0);
	  ((Crr->Sw)->Child)->Sw=SNULL;
	  ((Crr->Sw)->Child)->NextSw=LNULL;
	  Crr=(Crr->Sw)->Child;
	  break;
	case 0:
	  (Crr->Sw)->Child = LNULL;
	  add_expl_count(1);
	  return_to_expl_root();
	  break;
	default:
	  printf("{PRISM INTERNAL ERROR: get_trie(%d,_,_,_) --");
	  printf("%d must be 0 or 1}",flag);
	  return(0);
	}
	return(1);
  }
  else {
	
	while(1){
	  if (compare_sw(Crr->Sw,g_id,t_id,val)) { /* $B0JA0$N$b$N$H%N!<%I$r6&M-(B */
		if ((Crr->Sw)->Child==LNULL)
		  switch(flag){
		  case 1:     /* $BF~NO$HESCf$^$GF1$8(B explanation $B$,0JA0$K$bB8:_$7$?(B */
			ignore=1;          /* flag $B$,(B 0 $B$N8F$S=P$7$,$"$k$^$G2?$b$7$J$$(B */
			break;
		  case 0:         /* $BF~NO$HA4$/F1$8(B explanation $B$,0JA0$K$bB8:_$7$?(B */
			return_to_expl_root();       /* add_expl_count(1) $B$O9T$J$o$J$$(B */
			break;
		  default:
			printf("{PRISM INTERNAL ERROR: get_Ptrie(%d,_,_,_) -- ");
			printf("%d must be 0 or 1.}",flag);
			return(0);
		  }
		else {
		  switch(flag){
		  case 1:
			Crr=(Crr->Sw)->Child;
			break;
		  case 0:
			/* $B>C$7$?;R6!$N?t(B-1 $B$r0z$/(B */
			add_expl_count(1-freeChild((Crr->Sw)->Child));
			(Crr->Sw)->Child=LNULL;
			return_to_expl_root();
			break;
		  default:
			printf("{PRISM INTERNAL ERROR: get_trie(%d,_,_,_) --");
			printf("%d must be 0 or 1}",flag);
			return(0);
		  }
		}
		break;
	  }
	  else if (Crr->NextSw==LNULL) {

		if ((Crr->NextSw = lalloc())==LNULL) return(0);
		if (((Crr->NextSw)->Sw = salloc())==SNULL) return(0);
		(Crr->NextSw)->NextSw=LNULL;

		((Crr->NextSw)->Sw)->G_id = (short)g_id;
		((Crr->NextSw)->Sw)->T_id = (short)t_id;
		((Crr->NextSw)->Sw)->Val  = (short)val;
		((Crr->NextSw)->Sw)->Table  = (int *)NULL;

		switch(flag){
		case 1:
		  if ((((Crr->NextSw)->Sw)->Child = lalloc())==LNULL) return(0);
		  (((Crr->NextSw)->Sw)->Child)->Sw=SNULL;
		  (((Crr->NextSw)->Sw)->Child)->NextSw=LNULL;
		  Crr=((Crr->NextSw)->Sw)->Child;
		  break;
		case 0:
		  ((Crr->NextSw)->Sw)->Child = LNULL;
		  add_expl_count(1);
		  return_to_expl_root();
		  break;
		default:
		  printf("{PRISM INTERNAL ERROR: get_trie(%d,_,_,_) --");
		  printf("%d must be 0 or 1}",flag);
		  return(0);
		}
		break;
	  }
	  else {
		Crr=Crr->NextSw;
	  }
	}
	return(1);
  }
}

/*
 * freeChild(): free the link list starting at childp together with every
 * switch node and sub-trie hanging below it.
 *
 * Returns the number of leaves (i.e. explanations) that were deleted.
 * A link whose Sw is SNULL -- the empty placeholder created before the
 * first switch of a level is stored -- is freed without being counted.
 * Passing LNULL is allowed and returns 0.
 *
 * Modified by kame on Nov/26/1997
 */
int freeChild(struct swlink *childp){

  struct swlink *link;
  struct swlink *next;
  struct sw *swp;
  int deleted_child=0;

  for(link=childp; link != LNULL; link = next){

    next = link->NextSw;      /* read before the link is freed */
    swp = link->Sw;

    if (swp != SNULL) {
      if (swp->Child != LNULL) {
        deleted_child += freeChild(swp->Child);
      }
      else {
        /* Table is released for leaf nodes only, as before.
           NOTE(review): if inner nodes can own a Table too it is not
           freed here -- confirm against the code that fills Table. */
        free(swp->Table);
        deleted_child += 1;
      }
      free(swp);
    }

    free(link);
  }
  return deleted_child;
}

/* swp $B$N;X$9(B switch $B$N0z?t$H(B g_id, t_id, val $B$,0lCW$9$k$+(B(1 or 0)? */
int compare_sw(struct sw *swp, int g_id, int t_id, int val) {

  return (swp->G_id==(short)g_id
          && swp->T_id==(short)t_id
          && swp->Val==(short)val);

}

/*
 * freeTrie(): release the whole trie -- every goal node reachable from
 * Root and, through freeChild(), all of its explanations -- then clear
 * Root, ERoot and Crr.  A notice is printed only when there was something
 * to free.
 */
void freeTrie(void) {

  struct glink *node = Root;
  struct glink *next;

  if (node != GNULL) {
    while (node != GNULL) {
      next = node->NextG;          /* remember the successor first */
      (void)freeChild(node->Expl); /* number of deleted leaves is unused */
      node->Expl = LNULL;
      free(node);
      node = next;
    }
    Root = GNULL;
    printf("{Previous explanations cleaned up.}\n");
  }

  ERoot = GNULL;
  Crr = LNULL;

}


/* show_trie(): print the trie-structured explanations of every goal.
   For each goal a header line (goal id, number of explanations, maximum
   depth, weight and Pdb, or "unknown" while Pdb is negative) is followed
   by the explanations as printed by show_expl().
   modified by kame on Jul/11/1998
*/
void show_trie(void){

  struct glink *goal_node = Root;

  while (goal_node != GNULL) {
    if (!(goal_node->Pdb < 0)) {
      printf("[Goal %d] -- %d expls, max_depth %d, weight %d (Pdb = %.6f)\n",
             goal_node->Goal,goal_node->Count,goal_node->MaxDepth,
             goal_node->ObsNum,goal_node->Pdb);
    } else {
      printf("[Goal %d] -- %d expls, max_depth %d, weight %d (Pdb = unknown)\n",
             goal_node->Goal,goal_node->Count,goal_node->MaxDepth,
             goal_node->ObsNum);
    }
    if (goal_node->Expl != LNULL) {
      show_expl(0,goal_node->Expl);
    }
    goal_node = goal_node->NextG;
  }
}

/* show_trie_more(): print the trie-structured explanations of every goal
   in the verbose form of show_expl_more().  Each goal gets a header line
   with its id, number of explanations, observation count and Pdb
   ("unknown" while Pdb is negative).
   created by kame on Jul/22/1998
*/
void show_trie_more(void){

  struct glink *goal_node = Root;

  while (goal_node != GNULL) {
    if (!(goal_node->Pdb < 0)) {
      printf("[Goal %d] -- %d expls, %d times observed (Pdb = %.6f)\n",
             goal_node->Goal,goal_node->Count,goal_node->ObsNum,
             goal_node->Pdb);
    } else {
      printf("[Goal %d] -- %d expls, %d times observed (Pdb = unknown)\n",
             goal_node->Goal,goal_node->Count,goal_node->ObsNum);
    }
    if (goal_node->Expl != LNULL) {
      show_expl_more(0,goal_node->Expl);
    }
    goal_node = goal_node->NextG;
  }
}

/* show_trie_less(): print the trie-structured explanations of every goal
   in the terse form of show_expl_less().  Each goal gets a header line
   with its id, number of explanations, observation count and Pdb
   ("unknown" while Pdb is negative).
   created by kame on Jul/23/1998
*/
void show_trie_less(void){

  struct glink *goal_node = Root;

  while (goal_node != GNULL) {
    if (!(goal_node->Pdb < 0)) {
      printf("[Goal %d] -- %d expls, %d times observed (Pdb = %.6f)\n",
             goal_node->Goal,goal_node->Count,goal_node->ObsNum,
             goal_node->Pdb);
    } else {
      printf("[Goal %d] -- %d expls, %d times observed (Pdb = unknown)\n",
             goal_node->Goal,goal_node->Count,goal_node->ObsNum);
    }
    if (goal_node->Expl != LNULL) {
      show_expl_less(0,goal_node->Expl);
    }
    goal_node = goal_node->NextG;
  }
}

/* show_expl_more(): show explanations recursively
 *                   (with memory address)
 *
 * indent is the depth of linkp in the trie and selects how many blank
 * columns precede a sibling line.  Each switch is printed as
 * "<link address>-><switch address>[G_id|T_id|Val]"; a leaf is closed
 * with "$", and the end of a sibling list is marked with a line of "~".
 * Nothing is printed for an empty link (Sw == SNULL).
 *
 * The addresses are printed with %p, whose width is implementation-
 * defined, so the fixed-width indentation may no longer line up exactly
 * with the address columns.
 */
void show_expl_more(int indent, struct swlink *linkp){

  int i;

  if (linkp->Sw == SNULL) return;

  /* pointers must be printed with %p and passed as void * */
  printf("%p->%p[%2d|%2d|%1d]",
         (void *)linkp,(void *)linkp->Sw,
         (linkp->Sw)->G_id,(linkp->Sw)->T_id,(linkp->Sw)->Val);

  if ((linkp->Sw)->Child != LNULL){
    printf("->");
    show_expl_more(indent+1,(linkp->Sw)->Child);
  }
  else printf("$\n");

  for(i=0;i<indent;i++) printf("                             ");
  if (linkp->NextSw != LNULL){
    show_expl_more(indent,linkp->NextSw);
  } else {
    printf("~~~~~~~~             \n");
  }
}

/* show_expl(): show explanations recursively.
 * Every switch is printed as "[G_id|T_id|Val]"; children continue on the
 * same line, siblings start a new line indented by 9 blanks per level.
 * Printing stops at the first link whose Sw is SNULL.
 */
void show_expl(int indent, struct swlink *linkp){

  struct swlink *cur = linkp;
  int col;

  while (cur->Sw != SNULL) {
    printf("[%2d|%2d|%1d]",
           (cur->Sw)->G_id,(cur->Sw)->T_id,(cur->Sw)->Val);

    if ((cur->Sw)->Child == LNULL) {
      printf("\n");
    } else {
      show_expl(indent+1,(cur->Sw)->Child);
    }

    if (cur->NextSw == LNULL) break;

    for (col=0; col<indent; col++) printf("         ");
    cur = cur->NextSw;
  }
}

/* show_expl_less(): show explanations recursively, shape only.
 * Every switch is printed as "[]" without its arguments; children
 * continue on the same line, siblings start a new line indented by two
 * blanks per level.  Nothing is printed for an empty link (Sw == SNULL).
 */
void show_expl_less(int indent, struct swlink *linkp){

  int i;
  if(linkp->Sw != SNULL){
    /* the format has no conversions, so no arguments are passed */
    printf("[]");
    if ((linkp->Sw)->Child != LNULL)
      show_expl_less(indent+1,(linkp->Sw)->Child);
    else printf("\n");
    if (linkp->NextSw != LNULL){
      for(i=0;i<indent;i++) printf("  ");
      show_expl_less(indent,linkp->NextSw);
    }
  }
}
