50#include "EST_SCFG_Chart.h"
51#include "EST_simplestats.h"
53#include "EST_TVector.h"
// Two file-local EST_bracketed_string objects: a const one (def_val_s) and a
// mutable one (error_return_s).
// NOTE(review): presumably these back the default-value / error-return slots
// of the EST_TVector template instantiated in this file -- confirm.
55static const EST_bracketed_string def_val_s;
56static EST_bracketed_string error_return_s;
61#if defined(INSTANTIATE_TEMPLATES)
62#include "../base_class/EST_TVector.cc"
// Fill corpus b from the LISP list examples: b is resized to the list length
// reported by siod_llength(), then element i is given the i-th list item via
// set_bracketed_string().
// NOTE(review): the declarations of i and e are not visible in this listing.
67void set_corpus(EST_Bcorpus &b, LISP examples)
72 b.
resize(siod_llength(examples));
// Advance the list cursor e and the corpus index i together until the list
// is exhausted.
74 for (i=0,e=examples; e != NIL; e=cdr(e),i++)
75 b.a_no_check(i).set_bracketed_string(car(e));
// EST_bracketed_string::init(): only the signature appears in this listing;
// the body is not visible here.
78void EST_bracketed_string::init()
// Default constructor: only the signature appears in this listing; the body
// is not visible here.
87EST_bracketed_string::EST_bracketed_string()
// Construct from a LISP value by passing it to set_bracketed_string().
// NOTE(review): anything done before that call is not visible in this listing.
92EST_bracketed_string::EST_bracketed_string(LISP
string)
96 set_bracketed_string(
string);
// Destructor: frees every row valid_spans[i] for i in [0, p_length) and then
// the array of row pointers itself.
// NOTE(review): the listing skips several lines after the signature, so any
// other cleanup performed here is not visible.
99EST_bracketed_string::~EST_bracketed_string()
105 for (i=0; i < p_length; i++)
106 delete [] valid_spans[i];
107 delete [] valid_spans;
// Install a new bracketed string.
// Visible steps: p_length is set from find_num_nodes(string); symbols is
// allocated with p_length LISP slots and filled by set_leaf_indices();
// valid_spans is allocated as length() rows of length()+1 ints.
// NOTE(review): release of any previously held symbols/valid_spans is not
// visible in this listing -- confirm before calling this twice on one object.
110void EST_bracketed_string::set_bracketed_string(LISP
string)
116 p_length = find_num_nodes(
string);
117 symbols =
new LISP[p_length];
119 set_leaf_indices(
string,0,symbols);
124 valid_spans =
new int*[length()];
125 for (i=0; i < length(); i++)
127 valid_spans[i] =
new int[length()+1];
// Only entries with j > i are zeroed here; entries with j <= i in each row
// are left as allocated.
128 for (j=i+1; j <= length(); j++)
129 valid_spans[i][j] = 0;
// Recursive count over a LISP structure. In the visible branch, a cons cell
// yields the sum of the counts of its car and its cdr.
// NOTE(review): the preceding (non-cons) branches and the final fallback are
// not visible in this listing.
138int EST_bracketed_string::find_num_nodes(LISP
string)
143 else if (CONSP(
string))
144 return find_num_nodes(car(
string))+
145 find_num_nodes(cdr(
string));
// Recursive walk over a LISP structure that threads a running index i through
// the recursion and returns an int.
// NOTE(review): the first branch, the statement(s) between the two visible
// branches' conditions and returns, and the last argument of the final call
// are not visible in this listing.
150int EST_bracketed_string::set_leaf_indices(LISP
string,
int i,LISP *syms)
// car is not a cons cell: continue with the rest of the list at index i+1.
154 else if (!CONSP(car(
string)))
157 return set_leaf_indices(cdr(
string),i+1,syms);
// Otherwise recurse into car starting at i, and use the index that call
// returns as the starting index for the rest of the list.
161 return set_leaf_indices(cdr(
string),
162 set_leaf_indices(car(
string),i,syms),
// Mark spans in valid_spans for the LISP structure t whose first leaf is at
// position s. The method is const-qualified but writes through the
// valid_spans pointer.
// NOTE(review): the guard(s) that end the recursion are not visible in this
// listing.
167void EST_bracketed_string::find_valid(
int s,LISP t)
const
// c accumulates the leaf counts of the items of t in order; after each item
// the span starting at s and ending at c is flagged with 1.
174 for (c=s,l=t; l != NIL; l=cdr(l))
176 c += num_leafs(car(l));
177 valid_spans[s][c] = 1;
// Recurse into the first item (same start s) and into the remainder of the
// list, whose start is shifted by the first item's leaf count.
179 find_valid(s,car(t));
180 find_valid(s+num_leafs(car(t)),cdr(t));
// Recursive leaf count for the LISP structure t. The visible branch returns
// the sum of the counts for car(t) and cdr(t).
// NOTE(review): the base cases are not visible in this listing.
184int EST_bracketed_string::num_leafs(LISP t)
const
191 return num_leafs(car(t)) + num_leafs(cdr(t));
// Constructor: default-constructs the EST_SCFG base. The constructor body is
// not visible in this listing.
194EST_SCFG_traintest::EST_SCFG_traintest(
void) :
EST_SCFG()
// Destructor: only the signature appears in this listing; the body is not
// visible here.
202EST_SCFG_traintest::~EST_SCFG_traintest(
void)
// Load the member corpus from filename: the value returned by
// vload(filename,1) is passed to set_corpus() to populate corpus.
207void EST_SCFG_traintest::load_corpus(
const EST_String &filename)
209 set_corpus(corpus,vload(filename,1));
// Compute the inside value for nonterminal p over span (i,k) of corpus
// sentence c and store it in inside[p][i][k].
// NOTE(review): the first branch's condition, the local declarations, the
// step that folds s into res, and the return statement are not visible in
// this listing.
213double EST_SCFG_traintest::f_I_cal(
int c,
int p,
int i,
int k)
// Unary case: probability of p rewriting to the terminal found at position i.
220 res = prob_U(p,terminal(corpus.a_no_check(c).symbol_at(i)));
// Binary case, taken only when the sentence's bracketing marks (i,k) as valid.
228 else if (corpus.a_no_check(c).valid(i,k) == TRUE)
// Sum over every pair of daughter nonterminals (q,r) ...
234 for (q = 0; q < num_nonterminals(); q++)
235 for (r = 0; r < num_nonterminals(); r++)
237 double pBpqr = prob_B(p,q,r);
// ... and every split point j strictly between i and k.
239 for (j=i+1; j < k; j++)
241 double in = f_I(c,q,i,j);
243 s += pBpqr * in * f_I(c,r,j,k);
// Store the result in the per-sentence cache.
251 inside[p][i][k] = res;
// Compute the outside value for nonterminal p over span (i,k) of corpus
// sentence c and store it in outside[p][i][k].
// NOTE(review): local declarations, the values assigned in the whole-sentence
// case, the loop header enclosing the first f_O/f_I pair, the step that
// combines s2/s3 into res, and the return statement are not visible in this
// listing.
259double EST_SCFG_traintest::f_O_cal(
int c,
int p,
int i,
int k)
// Whole-sentence span: the outcome depends on whether p is the distinguished
// symbol.
264 if ((i == 0) && (k == corpus.a_no_check(c).length()))
266 if (p == distinguished_symbol())
// Otherwise only spans marked valid by the sentence's bracketing are expanded.
271 else if (corpus.a_no_check(c).valid(i,k) == TRUE)
279 for (q = 0; q < num_nonterminals(); q++)
280 for (r = 0; r < num_nonterminals(); r++)
// p as the second daughter of q -> r p: outside of q over (j,k) times the
// inside of r over (j,i).
282 pBqrp = prob_B(q,r,p);
288 double out = f_O(c,q,j,k);
290 s2 += out * f_I(c,r,j,i);
// p as the first daughter of q -> p r: outside of q over (i,j) times the
// inside of r over (k,j), for every j to the right of k.
294 pBqpr = prob_B(q,p,r);
297 for (j=k+1;j <= corpus.a_no_check(c).length(); j++)
299 double out = f_O(c,q,i,j);
301 s3 += out * f_I(c,r,k,j);
// Store the result in the per-sentence cache.
312 outside[p][i][k] = res;
// Re-estimation step for the binary rule p -> q r (rule index ri) on corpus
// sentence c. Enumerates every triple i < j < k within the sentence and reads
// f_I(c,q,i,j), f_I(c,r,j,k) and f_O(c,p,i,k).
// NOTE(review): the accumulation that uses d1, d2, d3 and pBpqr, and the
// updates to any counters, are not visible in this listing.
317void EST_SCFG_traintest::reestimate_rule_prob_B(
int c,
int ri,
int p,
int q,
int r)
323 double pBpqr = prob_B(p,q,r);
327 for (i=0; i <= corpus.a_no_check(c).length()-2; i++)
328 for (j=i+1; j <= corpus.a_no_check(c).length()-1; j++)
// A zero factor makes this (i,j) / (i,j,k) combination irrelevant, so it is
// skipped before the more deeply nested work is done.
330 double d1 = f_I(c,q,i,j);
331 if (d1 == 0)
continue;
332 for (k=j+1; k <= corpus.a_no_check(c).length(); k++)
334 double d2 = f_I(c,r,j,k);
335 if (d2 == 0)
continue;
336 double d3 = f_O(c,p,i,k);
337 if (d3 == 0)
continue;
// Re-estimation step for the unary rule p -> m (rule index ri) on corpus
// sentence c. For each position whose terminal equals m, n2 accumulates
// prob_U(p,m) * f_O(c,p,i-1,i); d[ri] accumulates f_P(c,p) / fP.
// NOTE(review): the definition of fP, any guard on it being non-zero, and the
// update that consumes n2 are not visible in this listing.
357void EST_SCFG_traintest::reestimate_rule_prob_U(
int c,
int ri,
int p,
int m)
// NOTE(review): with i running from 1 while i < length(), symbol_at(i-1)
// covers positions 0 .. length()-2, so the last symbol of the sentence is
// never examined -- confirm whether `<=` was intended.
369 for (i=1; i < corpus.a_no_check(c).length(); i++)
370 if (m == terminal(corpus.a_no_check(c).symbol_at(i-1)))
371 n2 += prob_U(p,m) * f_O(c,p,i-1,i);
377 d[ri] += f_P(c,p) / fP;
// Value for corpus sentence c as a whole: the inside value of the
// distinguished symbol over the full span [0, length()].
381double EST_SCFG_traintest::f_P(
int c)
383 return f_I(c,distinguished_symbol(),0,corpus.a_no_check(c).length());
// Per-nonterminal total for sentence c: accumulates f_I(c,p,i,j) * f_O(c,p,i,j)
// into db over every span i < j within the sentence.
// NOTE(review): the initialisation of db and the return statement are not
// visible in this listing.
386double EST_SCFG_traintest::f_P(
int c,
int p)
391 for (i=0; i < corpus.a_no_check(c).length(); i++)
392 for (j=i+1; j <= corpus.a_no_check(c).length(); j++)
// A zero outside value contributes nothing, so the inside lookup is skipped.
394 double d1 = f_O(c,p,i,j);
395 if (d1 == 0)
continue;
396 db += f_I(c,p,i,j)*d1;
// Run re-estimation passes numbered startpass .. passes-1 over the corpus.
// Each pass visits the corpus sentences, calls the per-rule re-estimation
// routines, then replaces each rule's probability with n[ri]/d[ri] and
// prints a summary line.
// NOTE(review): the rest of the parameter list, local declarations, the
// per-pass reset of n/d, the action taken by the `spread` test, the `else`
// before the unary call, any cache cleanup, and the checkpoint save itself
// are not visible in this listing.
402void EST_SCFG_traintest::reestimate_grammar_probs(
int passes,
// One accumulator slot per rule.
417 n.resize(rules.length());
418 d.resize(rules.length());
420 for (pass = startpass; pass < passes; pass++)
427 set_rule_prob_cache();
// Per-sentence loop; lPc sums log probabilities and mC sums sentence lengths.
429 for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
// Sub-sampling test on the sentence index, shifted by the pass number.
432 if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
434 printf(
" %d",c); fflush(stdout);
// Empty sentences are skipped.
435 if (corpus.a_no_check(c).length() == 0)
continue;
436 init_io_cache(c,num_nonterminals());
// Walk the rule list with ri as the parallel index into n and d.
437 for (ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
439 if (rules(r).type() == est_scfg_binary_rule)
440 reestimate_rule_prob_B(c,ri,
442 rules(r).daughter1(),
443 rules(r).daughter2());
445 reestimate_rule_prob_U(c,
448 rules(r).daughter1());
450 lPc += safe_log(f_P(c));
451 mC += corpus.a_no_check(c).length();
// Apply the new estimates; se accumulates the squared change per rule.
456 for (se=0.0,ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
458 double n_prob = n[ri]/d[ri];
461 se += (n_prob-rules(r).prob())*(n_prob-rules(r).prob());
462 rules(r).set_prob(n_prob);
464 printf(
"pass %d cross entropy %g RMSE %f %f %d\n",
465 pass,-(lPc/mC),sqrt(se/rules.length()),
// Checkpointing is active unless checkpoint is -1; it fires on every
// checkpoint-th pass and builds a ".NNN" suffix from the pass number.
468 if (checkpoint != -1)
470 if ((pass % checkpoint) == checkpoint-1)
473 sprintf(cp,
".%03d",pass);
// Training entry point: forwards passes, startpass and checkpoint (plus
// further arguments not visible in this listing) to
// reestimate_grammar_probs().
482void EST_SCFG_traintest::train_inout(
int passes,
490 reestimate_grammar_probs(passes, startpass, checkpoint,
// Allocate the inside and outside tables for corpus sentence c as
// nt x mc x mc arrays of double, where mc is the sentence length plus one,
// and set every cell to -1.
// NOTE(review): -1 presumably marks "not yet computed" for f_I / f_O --
// confirm in those functions.
494void EST_SCFG_traintest::init_io_cache(
int c,
int nt)
498 int mc = corpus.a_no_check(c).length()+1;
500 inside =
new double**[nt];
501 outside =
new double**[nt];
502 for (i=0; i < nt; i++)
504 inside[i] =
new double*[mc];
505 outside[i] =
new double*[mc];
506 for (j=0; j < mc; j++)
508 inside[i][j] =
new double[mc];
509 outside[i][j] =
new double[mc];
510 for (k=0; k < mc; k++)
512 inside[i][j][k] = -1;
513 outside[i][j][k] = -1;
// Free the inside/outside tables for corpus sentence c. The row count is
// recomputed as the sentence length plus one, and the first dimension is
// taken from num_nonterminals(), so the tables must have been allocated with
// that same count.
// NOTE(review): the matching `delete [] inside[i]` and the release of the
// top-level inside/outside arrays are not visible in this listing.
519void EST_SCFG_traintest::clear_io_cache(
int c)
521 int mc = corpus.a_no_check(c).length()+1;
527 for (i=0; i < num_nonterminals(); i++)
529 for (j=0; j < mc; j++)
531 delete [] inside[i][j];
532 delete [] outside[i][j];
535 delete [] outside[i];
// Iterates over every corpus sentence, accumulating the sentence lengths
// into mC.
// NOTE(review): the per-sentence probability accumulation and the return
// expression are not visible in this listing.
545double EST_SCFG_traintest::cross_entropy()
550 for (c=0; c < corpus.length(); c++)
553 mC += corpus.a_no_check(c).length();
// Evaluate the grammar on the loaded corpus and print a cross-entropy figure,
// -(lPc/mC), together with the number of failed sentences out of the corpus
// size.
// NOTE(review): the per-rule initialisation body, the per-sentence
// probability computation, how `failed` is counted, and what the
// `corpus.length() > 50` branches emit are not visible in this listing.
559void EST_SCFG_traintest::test_corpus()
// One accumulator slot per rule.
568 n.resize(rules.length());
569 d.resize(rules.length());
570 for (i=0; i < rules.length(); i++)
// Per-sentence loop; mC sums sentence lengths.
573 for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
575 if (corpus.length() > 50)
580 init_io_cache(c,num_nonterminals());
587 mC += corpus.a_no_check(c).length();
591 if (corpus.length() > 50)
594 cout <<
"cross entropy " << -(lPc/mC) <<
" (" << failed <<
" failed out of " <<
595 corpus.length() <<
" sentences )" << endl;
void resize(int n, int set=1)
static const T * def_val
default value, used for filling matrix after resizing