00001 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038 
00039 
00040 
00041 
00042 
00043 
00044 
00045 
00046 
00047 
00048 
00049 
00050 
00051 
00052 
00053 
00054 
00055 
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 
00068 #include <sent/stddefs.h>
00069 #include <sent/htk_hmm.h>
00070 #include <sent/htk_param.h>
00071 #include <sent/hmm.h>
00072 #include <sent/hmm_calc.h>
00073 
00074 #undef NORMALIZE_GS_SCORE       
00075 
00076   
00077 
00078 
00079 
00086 static void
00087 build_gsset(HMMWork *wrk)
00088 {
00089   HTK_HMM_State *st;
00090 
00091   
00092   wrk->gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * wrk->OP_gshmm->totalstatenum);
00093   wrk->gsset_num = wrk->OP_gshmm->totalstatenum;
00094   
00095   for(st = wrk->OP_gshmm->ststart; st; st=st->next) {
00096     wrk->gsset[st->id].state = st;
00097   }
00098 }
00099 
00106 static void
00107 free_gsset(HMMWork *wrk)
00108 {
00109   free(wrk->gsset);
00110 }
00111 
00119 static boolean
00120 build_state2gs(HMMWork *wrk)
00121 {
00122   HTK_HMM_Data *dt;
00123   HTK_HMM_State *st, *cr;
00124   int i;
00125   char gstr[MAX_HMMNAME_LEN], cbuf[MAX_HMMNAME_LEN];
00126   boolean ok_p = TRUE;
00127 
00128   
00129   wrk->state2gs = (int *)mymalloc(sizeof(int) * wrk->OP_hmminfo->totalstatenum);
00130   for(i=0;i<wrk->OP_hmminfo->totalstatenum;i++) wrk->state2gs[i] = -1;
00131 
00132   
00133   for(dt = wrk->OP_hmminfo->start; dt; dt=dt->next) {
00134     if (strlen(dt->name) >= MAX_HMMNAME_LEN - 2) {
00135       jlog("Error: gms: too long hmm name (>%d): \"%s\"\n",
00136            MAX_HMMNAME_LEN-3, dt->name);
00137       jlog("Error: gms: change value of MAX_HMMNAME_LEN\n");
00138       ok_p = FALSE;
00139       continue;
00140     }
00141     for(i=1;i<dt->state_num-1;i++) { 
00142       st = dt->s[i];
00143       
00144       if (wrk->state2gs[st->id] != -1) continue;
00145       
00146       sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00147       
00148       if ((cr = state_lookup(wrk->OP_gshmm, gstr)) == NULL) {
00149         jlog("Error: gms: GS HMM \"%s\" not defined\n", gstr);
00150         ok_p = FALSE;
00151         continue;
00152       }
00153       
00154       wrk->state2gs[st->id] = cr->id;
00155     }
00156   }
00157 #ifdef PARANOIA
00158   {
00159     HTK_HMM_State *st;
00160     for(st=wrk->OP_hmminfo->ststart; st; st=st->next) {
00161       printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00162              (wrk->gsset[wrk->state2gs[st->id]].state)->name);
00163     }
00164   }
00165 #endif
00166   return ok_p;
00167 }
00168 
00175 static void
00176 free_state2gs(HMMWork *wrk)
00177 {
00178   free(wrk->state2gs);
00179 }
00180 
00181 
00182 
00183 #define SD(A) idx[A-1]  
00184 #define SCOPY(D,S) D = S        
00185 #define SVAL(A) (fs[idx[A-1]]) 
00186 #define STVAL (fs[s]) 
00187 
00188 
00194 static void
00195 sort_gsindex_upward(HMMWork *wrk)
00196 {
00197   int n,root,child,parent;
00198   int s;
00199   int *idx;
00200   LOGPROB *fs;
00201   int neednum, totalnum;
00202 
00203   idx = wrk->gsindex;
00204   fs = wrk->t_fs;
00205   neednum = wrk->my_nbest;
00206   totalnum = wrk->gsset_num;
00207 
00208   for (root = totalnum/2; root >= 1; root--) {
00209     SCOPY(s, SD(root));
00210     parent = root;
00211     while ((child = parent * 2) <= totalnum) {
00212       if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00213         child++;
00214       }
00215       if (STVAL >= SVAL(child)) {
00216         break;
00217       }
00218       SCOPY(SD(parent), SD(child));
00219       parent = child;
00220     }
00221     SCOPY(SD(parent), s);
00222   }
00223   n = totalnum;
00224   while ( n > totalnum - neednum) {
00225     SCOPY(s, SD(n));
00226     SCOPY(SD(n), SD(1));
00227     n--;
00228     parent = 1;
00229     while ((child = parent * 2) <= n) {
00230       if (child < n && SVAL(child) < SVAL(child+1)) {
00231         child++;
00232       }
00233       if (STVAL >= SVAL(child)) {
00234         break;
00235       }
00236       SCOPY(SD(parent), SD(child));
00237       parent = child;
00238     }
00239     SCOPY(SD(parent), s);
00240   }
00241 }
00242 
00249 static void
00250 do_gms(HMMWork *wrk)
00251 {
00252   int i;
00253   
00254   
00255   compute_gs_scores(wrk);
00256   
00257   sort_gsindex_upward(wrk);
00258   for(i=wrk->gsset_num - wrk->my_nbest;i<wrk->gsset_num;i++) {
00259     
00260     wrk->t_fs[wrk->gsindex[i]] = LOG_ZERO;
00261   }
00262 
00263   
00264 #ifdef NORMALIZE_GS_SCORE
00265   
00266   for(i=0;i<wrk->gsset_num;i++) {
00267     if (wrk->t_fs[i] != LOG_ZERO) {
00268       wrk->t_fs[i] *= 0.975;
00269     }
00270   }
00271 #endif
00272 }  
00273 
00274 
00282 boolean
00283 gms_init(HMMWork *wrk)
00284 {
00285   int i;
00286   
00287   
00288   if (wrk->OP_gshmm->is_triphone) {
00289     jlog("Error: gms: GS HMM should be a monophone model\n");
00290     return FALSE;
00291   }
00292   if (wrk->OP_gshmm->is_tied_mixture) {
00293     jlog("Error: gms: GS HMM should not be a tied mixture model\n");
00294     return FALSE;
00295   }
00296 
00297   
00298   build_gsset(wrk);
00299   
00300   if (build_state2gs(wrk) == FALSE) {
00301     jlog("Error: gms: failed in assigning GS HMM state for each state\n");
00302     return FALSE;
00303   }
00304   jlog("Stat: gms: GS HMMs are mapped to HMM states\n");
00305 
00306   
00307   wrk->gsindex = (int *)mymalloc(sizeof(int) * wrk->gsset_num);
00308   for(i=0;i<wrk->gsset_num;i++) wrk->gsindex[i] = i;
00309 
00310   
00311   wrk->fallback_score = NULL;
00312   wrk->gms_is_selected = NULL;
00313   wrk->gms_allocframenum = -1;
00314 
00315   
00316   gms_gprune_init(wrk);
00317   
00318   return TRUE;
00319 }
00320 
00329 boolean
00330 gms_prepare(HMMWork *wrk, int framenum)
00331 {
00332   LOGPROB *tmp;
00333   int t;
00334 
00335   
00336   if (wrk->gms_allocframenum < framenum) {
00337     if (wrk->fallback_score != NULL) {
00338       free(wrk->fallback_score[0]);
00339       free(wrk->fallback_score);
00340       free(wrk->gms_is_selected);
00341     }
00342     wrk->fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00343     tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->gsset_num * framenum);
00344     for(t=0;t<framenum;t++) {
00345       wrk->fallback_score[t] = &(tmp[wrk->gsset_num * t]);
00346     }
00347     wrk->gms_is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00348     wrk->gms_allocframenum = framenum;
00349   }
00350   
00351   for(t=0;t<framenum;t++) wrk->gms_is_selected[t] = FALSE;
00352 
00353   
00354   gms_gprune_prepare(wrk);
00355   
00356   return TRUE;
00357 }
00358 
00365 void
00366 gms_free(HMMWork *wrk)
00367 {
00368   free_gsset(wrk);
00369   free_state2gs(wrk);
00370   free(wrk->gsindex);
00371   if (wrk->fallback_score != NULL) {
00372     free(wrk->fallback_score[0]);
00373     free(wrk->fallback_score);
00374     free(wrk->gms_is_selected);
00375   }
00376   gms_gprune_free(wrk);
00377 }
00378 
00379 
00380 
00393 LOGPROB
00394 gms_state(HMMWork *wrk)
00395 {
00396   LOGPROB gsprob;
00397   if (wrk->OP_last_time != wrk->OP_time) { 
00398     
00399     wrk->t_fs = wrk->fallback_score[wrk->OP_time];
00400     
00401     if (!wrk->gms_is_selected[wrk->OP_time]) {
00402       do_gms(wrk);
00403       wrk->gms_is_selected[wrk->OP_time] = TRUE;
00404     }
00405   }
00406   if ((gsprob = wrk->t_fs[wrk->state2gs[wrk->OP_state_id]]) != LOG_ZERO) {
00407     
00408     return(gsprob);
00409   }
00410   
00411   return((*(wrk->calc_outprob))(wrk));
00412 }