%%%
%%% Bayesian network learning
%%%
%%% Copyright (C) 1997
%%%   SATO Taisuke and KAMEYA Yoshitaka,
%%%     Dept. of Computer Science, Tokyo Institute of Technology

% The learning process is started by
% | ?- mylearn(100).
% or
% | ?- learn.
% (the latter only if the teacher data file "world.dat" has already been built).

% Control declarations:
% Names world/2 (the observable projection defined below) as the target
% predicate whose instances make up the training data.
% NOTE(review): PRISM control declaration -- confirm exact semantics
% against the PRISM version in use.
target(world,2).
% File holding the teacher data used by learn/0; mylearn/1 writes its
% generated samples to this same file name ('world.dat').
data('world.dat').

% Only Smoke and Report are observable.
% Observable view of the joint distribution: only Smoke and Report are
% exposed; Tampering, Fire, Alarm and Leaving remain hidden.
world(Smoke,Report) :-
    world(_Tampering,_Fire,_Alarm,Smoke,_Leaving,Report).

% Joint distribution:
% Full joint model over all six variables. Root variables (Fire,
% Tampering) are chosen first, then each child given its parents.
% Argument order in the head is (Tampering,Fire,Alarm,Smoke,Leaving,Report).
world(Tampering,Fire,Alarm,Smoke,Leaving,Report) :-
    fire(Fire),
    tampering(Tampering),
    c_smoke(Smoke,Fire),
    c_alarm(Alarm,Fire,Tampering),
    c_leaving(Leaving,Alarm),
    c_report(Report,Leaving).

% Marginal probabilities:
% Each marginal first chooses values for the variable's parents (via the
% parents' own marginals or root switches) and then the variable itself
% conditioned on those parent values.
smoke(Smoke) :-
    fire(Fire),
    c_smoke(Smoke,Fire).
alarm(Alarm) :-
    fire(Fire),
    tampering(Tampering),
    c_alarm(Alarm,Fire,Tampering).
leaving(Leaving) :-
    alarm(Alarm),
    c_leaving(Leaving,Alarm).
report(Report) :-
    leaving(Leaving),
    c_report(Report,Leaving).

% Conditional probabilities:
% Each yes/no variable is mapped onto a binary switch via bsw(Switch,none,V):
% outcome 1 encodes `yes` and outcome 0 encodes `no`.
% NOTE(review): bsw/3 is presumably PRISM's binary-switch primitive -- confirm
% against the PRISM version in use.
%
% Root variables: one switch each.
tampering(yes):- bsw(ta,none,1).
tampering(no) :- bsw(ta,none,0).
fire(yes):- bsw(fi,none,1).
fire(no) :- bsw(fi,none,0).
% Child variables: the switch name is parameterized by the parent values,
% so there is a separate switch (i.e. a separate CPT row) for each parent
% configuration, e.g. sm(yes) and sm(no).
c_smoke(yes,Fi):- bsw(sm(Fi),none,1).
c_smoke(no,Fi) :- bsw(sm(Fi),none,0).
% Alarm has two parents, hence switches al(Fi,Ta) for all four combinations.
c_alarm(yes,Fi,Ta):- bsw(al(Fi,Ta),none,1).
c_alarm(no,Fi,Ta) :- bsw(al(Fi,Ta),none,0).
c_leaving(yes,Al):- bsw(le(Al),none,1).
c_leaving(no,Al) :- bsw(le(Al),none,0).
c_report(yes,Le):- bsw(re(Le),none,1).
c_report(no,Le) :- bsw(re(Le),none,0).

% Utility program:

% print the distribution of the world
% Query the probability of every (Smoke,Report) combination of world/2
% and print each on its own line with six decimal places. All four
% probabilities are computed before any output is produced.
print_world :-
    prob(world(yes,yes),PYesYes),
    prob(world(yes,no),PYesNo),
    prob(world(no,yes),PNoYes),
    prob(world(no,no),PNoNo),
    format("Pb(world(yes,yes))= ~6f",[PYesYes]), nl,
    format("Pb(world(yes,no)) = ~6f",[PYesNo]), nl,
    format("Pb(world(no,yes)) = ~6f",[PNoYes]), nl,
    format("Pb(world(no,no))  = ~6f",[PNoNo]), nl.

% Generate teacher data and write it to "world.dat"
% before calling learn/0.

% mylearn(+NumSamples): sample NumSamples observations into 'world.dat',
% then run learn/0 on them.
mylearn(NumSamples) :-
    write_world(NumSamples,'world.dat'),
    learn.

% write_world(+NumSamples,+File): generate NumSamples world/2 samples and
% write them to File, one clause per line. Output is redirected to File
% with tell/1 and restored with told/0.
write_world(NumSamples,File) :-
    gen_world(NumSamples,Samples),
    tell(File),
    write_world2(Samples),
    told.

% write_world2(+Samples): print each world(Smoke,Report) term in the list
% to the current output as a Prolog fact terminated by ". " and a newline.
write_world2([world(Smoke,Report)|Rest]) :-
    write(world(Smoke,Report)),
    write('.'),
    nl,
    write_world2(Rest).
write_world2([]).

% gen_world(+N,-Samples): draw N independent samples from the generating
% model and keep only their observable (Smoke,Report) parts as world/2 terms.
gen_world(N,[world(Smoke,Report)|Rest]) :-
    N > 0,
    Remaining is N-1,
    gen_world(_,_,_,Smoke,_,Report),
    gen_world(Remaining,Rest).
gen_world(0,[]).

% Draw one complete assignment of all six variables from the fixed
% generating distribution, roots first, then children given their parents.
% The order of the draws is kept as-is so random-number consumption matches.
gen_world(Tampering,Fire,Alarm,Smoke,Leaving,Report) :-
    gen_fire(Fire),
    gen_tampering(Tampering),
    gen_c_smoke(Smoke,Fire),
    gen_c_alarm(Alarm,Fire,Tampering),
    gen_c_leaving(Leaving,Alarm),
    gen_c_report(Report,Leaving).

% Samplers for the generating ("true") distribution used to build the
% teacher data. Each draws yes/no via dice_multi(Values,Probs,X).
% NOTE(review): dice_multi/3 is assumed to pick one element of Values
% according to the matching weight in Probs -- confirm in the PRISM manual.
%
% Root variables.
gen_fire(Fi) :- dice_multi([yes,no],[0.1,0.9],Fi).
% The sample must be bound to the head argument Ta. If it is bound to a
% fresh variable instead, Ta stays unbound and gen_c_alarm/3 binds it to
% `yes` by clause order, so every sample would report tampering.
gen_tampering(Ta) :- dice_multi([yes,no],[0.15,0.85],Ta).
% Smoke given Fire.
gen_c_smoke(Sm,yes):- dice_multi([yes,no],[0.95,0.05],Sm).
gen_c_smoke(Sm,no) :- dice_multi([yes,no],[0.05,0.95],Sm).
% Alarm given (Fire,Tampering).
gen_c_alarm(Al,yes,yes):- dice_multi([yes,no],[0.50,0.50],Al).
gen_c_alarm(Al,yes,no) :- dice_multi([yes,no],[0.90,0.10],Al).
gen_c_alarm(Al,no,yes) :- dice_multi([yes,no],[0.85,0.15],Al).
gen_c_alarm(Al,no,no)  :- dice_multi([yes,no],[0.05,0.95],Al).
% Leaving given Alarm.
gen_c_leaving(Le,yes) :- dice_multi([yes,no],[0.88,0.12],Le).
gen_c_leaving(Le,no) :- dice_multi([yes,no],[0.01,0.99],Le).
% Report given Leaving.
gen_c_report(Re,yes) :- dice_multi([yes,no],[0.75,0.25],Re).
gen_c_report(Re,no) :- dice_multi([yes,no],[0.10,0.90],Re).
