00001 
00021 
00022 
00023 
00024 
00025 
00026 
00027 
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 
00038 
00039 
00040 
00041 
00042 
00043 
00044 
00045 
00046 
00047 
00048 
00049 
00050 
00051 
00052 
00053 
00054 
00055 
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 #include <sent/stddefs.h>
00068 #include <sent/htk_hmm.h>
00069 #include <sent/htk_param.h>
00070 #include <sent/hmm.h>
00071 #include <sent/gprune.h>
00072 #include "globalvars.h"
00073 
00074 #undef NORMALIZE_GS_SCORE       
00075 
00076   
00077 
00078 
00079 
00080 static int my_nbest;            
00081 static int allocframenum;       
00082 
00083 
00084 static GS_SET *gsset;           
00085 static int gsset_num;           
00086 static int *state2gs; 
00087 
00088 
00089 static boolean *is_selected;    
00090 static LOGPROB **fallback_score = NULL; 
00091 
00092 
00093 static int *gsindex;            
00094 static LOGPROB *t_fs;           
00095 
00096 
00101 static void
00102 build_gsset()
00103 {
00104   HTK_HMM_State *st;
00105 
00106   
00107   gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * OP_gshmm->totalstatenum);
00108   gsset_num = OP_gshmm->totalstatenum;
00109   
00110   for(st = OP_gshmm->ststart; st; st=st->next) {
00111     gsset[st->id].state = st;
00112   }
00113 }
00114 
00119 static void
00120 free_gsset()
00121 {
00122   free(gsset);
00123 }
00124 
00125 #define MAXHMMNAMELEN 40        
00126 
00127 
00132 static boolean
00133 build_state2gs()
00134 {
00135   HTK_HMM_Data *dt;
00136   HTK_HMM_State *st, *cr;
00137   int i;
00138   char gstr[MAXHMMNAMELEN], cbuf[MAXHMMNAMELEN];
00139   boolean ok_p = TRUE;
00140 
00141   
00142   state2gs = (int *)mymalloc(sizeof(int) * OP_hmminfo->totalstatenum);
00143   for(i=0;i<OP_hmminfo->totalstatenum;i++) state2gs[i] = -1;
00144 
00145   
00146   for(dt = OP_hmminfo->start; dt; dt=dt->next) {
00147     if (strlen(dt->name) >= MAXHMMNAMELEN - 2) {
00148       j_printerr("Error: too long hmm name (>%d): \"%s\"\n",
00149                  MAXHMMNAMELEN-3, dt->name);
00150       ok_p = FALSE;
00151       continue;
00152     }
00153     for(i=1;i<dt->state_num-1;i++) { 
00154       st = dt->s[i];
00155       
00156       if (state2gs[st->id] != -1) continue;
00157       
00158       sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00159       
00160       if ((cr = state_lookup(OP_gshmm, gstr)) == NULL) {
00161         j_printerr("Error: GS HMM \"%s\" not defined\n", gstr);
00162         ok_p = FALSE;
00163         continue;
00164       }
00165       
00166       state2gs[st->id] = cr->id;
00167     }
00168   }
00169 #ifdef PARANOIA
00170   {
00171     HTK_HMM_State *st;
00172     for(st=OP_hmminfo->ststart; st; st=st->next) {
00173       printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00174              (gsset[state2gs[st->id]].state)->name);
00175     }
00176   }
00177 #endif
00178   return ok_p;
00179 }
00180 
00185 static void
00186 free_state2gs()
00187 {
00188   free(state2gs);
00189 }
00190 
00191 
00192 
00193 #define SD(A) gsindex[A-1]      
00194 #define SCOPY(D,S) D = S        
00195 #define SVAL(A) (t_fs[gsindex[A-1]]) 
00196 #define STVAL (t_fs[s]) 
00197 
00198 
00204 static void
00205 sort_gsindex_upward(int neednum, int totalnum)
00206 {
00207   int n,root,child,parent;
00208   int s;
00209   for (root = totalnum/2; root >= 1; root--) {
00210     SCOPY(s, SD(root));
00211     parent = root;
00212     while ((child = parent * 2) <= totalnum) {
00213       if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00214         child++;
00215       }
00216       if (STVAL >= SVAL(child)) {
00217         break;
00218       }
00219       SCOPY(SD(parent), SD(child));
00220       parent = child;
00221     }
00222     SCOPY(SD(parent), s);
00223   }
00224   n = totalnum;
00225   while ( n > totalnum - neednum) {
00226     SCOPY(s, SD(n));
00227     SCOPY(SD(n), SD(1));
00228     n--;
00229     parent = 1;
00230     while ((child = parent * 2) <= n) {
00231       if (child < n && SVAL(child) < SVAL(child+1)) {
00232         child++;
00233       }
00234       if (STVAL >= SVAL(child)) {
00235         break;
00236       }
00237       SCOPY(SD(parent), SD(child));
00238       parent = child;
00239     }
00240     SCOPY(SD(parent), s);
00241   }
00242 }
00243 
00248 static void
00249 do_gms()
00250 {
00251   int i;
00252   
00253   
00254   compute_gs_scores(gsset, gsset_num, t_fs);
00255   
00256   sort_gsindex_upward(my_nbest, gsset_num);
00257   for(i=gsset_num - my_nbest;i<gsset_num;i++) {
00258     
00259     t_fs[gsindex[i]] = LOG_ZERO;
00260   }
00261 
00262   
00263 #ifdef NORMALIZE_GS_SCORE
00264   
00265   for(i=0;i<gsset_num;i++) {
00266     if (t_fs[i] != LOG_ZERO) {
00267       t_fs[i] = t_fs[i] * 0.975;
00268     }
00269   }
00270 #endif
00271 }  
00272 
00273 
00281 boolean
00282 gms_init(int nbest)
00283 {
00284   int i;
00285   
00286   
00287   if (OP_gshmm->is_triphone) {
00288     j_printerr("Error: GS HMM should be a monophone model\n");
00289     return FALSE;
00290   }
00291   if (OP_gshmm->is_tied_mixture) {
00292     j_printerr("Error: GS HMM should not be a tied mixture model\n");
00293     return FALSE;
00294   }
00295 
00296   
00297   my_nbest = nbest;
00298 
00299   
00300   build_gsset();
00301   
00302   j_printerr("Mapping HMM states to GS HMM...");
00303   if (build_state2gs() == FALSE) {
00304     j_printerr("Error: failed in assigning GS HMM state for each state\n");
00305     return FALSE;
00306   }
00307   j_printerr("done\n");
00308 
00309   
00310   gsindex = (int *)mymalloc(sizeof(int) * gsset_num);
00311   for(i=0;i<gsset_num;i++) gsindex[i] = i;
00312 
00313   
00314   fallback_score = NULL;
00315   is_selected = NULL;
00316   allocframenum = -1;
00317 
00318   
00319   gms_gprune_init(OP_hmminfo, gsset_num);
00320   
00321   return TRUE;
00322 }
00323 
00331 boolean
00332 gms_prepare(int framenum)
00333 {
00334   LOGPROB *tmp;
00335   int t;
00336 
00337   
00338   if (allocframenum < framenum) {
00339     if (fallback_score != NULL) {
00340       free(fallback_score[0]);
00341       free(fallback_score);
00342       free(is_selected);
00343     }
00344     fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00345     tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gsset_num * framenum);
00346     for(t=0;t<framenum;t++) {
00347       fallback_score[t] = &(tmp[gsset_num * t]);
00348     }
00349     is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00350     allocframenum = framenum;
00351   }
00352   
00353   for(t=0;t<framenum;t++) is_selected[t] = FALSE;
00354 
00355   
00356   gms_gprune_prepare();
00357   
00358   return TRUE;
00359 }
00360 
00365 void
00366 gms_free()
00367 {
00368   free_gsset();
00369   free_state2gs();
00370   free(gsindex);
00371   if (fallback_score != NULL) {
00372     free(fallback_score[0]);
00373     free(fallback_score);
00374     free(is_selected);
00375   }
00376   gms_gprune_free();
00377 }
00378 
00379 
00380 
00391 LOGPROB
00392 gms_state()
00393 {
00394   LOGPROB gsprob;
00395   if (OP_last_time != OP_time) { 
00396     
00397     t_fs = fallback_score[OP_time];
00398     
00399     if (!is_selected[OP_time]) {
00400       do_gms();
00401       is_selected[OP_time] = TRUE;
00402     }
00403   }
00404   if ((gsprob = t_fs[state2gs[OP_state_id]]) != LOG_ZERO) {
00405     
00406     return(gsprob);
00407   }
00408   
00409   return(calc_outprob());
00410 }