libsent/src/phmm/gms.c

Go to the documentation of this file.
00001 
00021 /*
00022  * Copyright (c) 1991-2006 Kawahara Lab., Kyoto University
00023  * Copyright (c) 2000-2005 Shikano Lab., Nara Institute of Science and Technology
00024  * Copyright (c) 2005-2006 Julius project team, Nagoya Institute of Technology
00025  * All rights reserved
00026  */
00027 
00028 /*
00029   Implementation of Gaussian Mixture Selection (old doc...)
00030   
00031   It is called from gs_calc_selected_mixture_and_cache_{safe,heu,beam} in
00032   the first pass for each frame.  It calculates the outprobs of all GS HMM
00033   states for the given input frame and gets the N-best GS HMM states. Then,
00034        for the selected (N-best) states:
00035            calculate the corresponding codebook,
00036            and set fallback_score[t][book] to LOG_ZERO.
00037        else:
00038            set fallback_score[t][book] to the GS HMM outprob.
00039   Later, when calculating state outprobs, the fallback_score[t][book]
00040   is consulted and,
00041        if fallback_score[t][book] == LOG_ZERO:
00042            it means it has been selected, so calculate the outprob with
00043            the corresponding codebook and its weights.
00044        else:
00045            it means it was pruned, so use the fallback_score[t][book]
00046            as its outprob.
00047 
00048            
00049   For triphone, GS HMMs should be assigned to each state.
00050   So the fallback_score[][] is kept according to the GS state ID,
00051   and corresponding GS HMM state id for each triphone state id should be
00052   kept beforehand.
00053   GS HMM Calculation:
00054        for the selected (N-best) GS HMM states:
00055            set fallback_score[t][gs_stateid] to LOG_ZERO.
00056        else:
00057            set fallback_score[t][gs_stateid] to the GS HMM outprob.
00058   triphone HMM probabilities are assigned as:
00059        if fallback_score[t][state2gs[tri_stateid]] == LOG_ZERO:
00060            it has been selected, so calculate the original outprob.
00061        else:
00062            as it was pruned, re-use fallback_score[t][state2gs[tri_stateid]]
00063            as its outprob.
00064 */
00065 
00066 
00067 #include <sent/stddefs.h>
00068 #include <sent/htk_hmm.h>
00069 #include <sent/htk_param.h>
00070 #include <sent/hmm.h>
00071 #include <sent/gprune.h>
00072 #include "globalvars.h"
00073 
00074 #undef NORMALIZE_GS_SCORE       /* normalize score (ad-hoc) */
00075 
00076   /* GS HMMs must be defined at STATE level using "~s NAME" macro,
00077      where NAMES are like "i:4m", "s2m", etc. */
00078 
00079 /* variables for GMS */
00080 static int my_nbest;            
00081 static int allocframenum;       
00082 
00083 /* GMS info */
00084 static GS_SET *gsset;           
00085 static int gsset_num;           
00086 static int *state2gs; 
00087 
00088 /* results */
00089 static boolean *is_selected;    
00090 static LOGPROB **fallback_score = NULL; 
00091 
00092 /* for calculation */
00093 static int *gsindex;            
00094 static LOGPROB *t_fs;           
00095 
00096 
00101 static void
00102 build_gsset()
00103 {
00104   HTK_HMM_State *st;
00105 
00106   /* allocate */
00107   gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * OP_gshmm->totalstatenum);
00108   gsset_num = OP_gshmm->totalstatenum;
00109   /* make ID */
00110   for(st = OP_gshmm->ststart; st; st=st->next) {
00111     gsset[st->id].state = st;
00112   }
00113 }
00114 
00119 static void
00120 free_gsset()
00121 {
00122   free(gsset);
00123 }
00124 
00125 #define MAXHMMNAMELEN 40        
00126 
00127 
00132 static boolean
00133 build_state2gs()
00134 {
00135   HTK_HMM_Data *dt;
00136   HTK_HMM_State *st, *cr;
00137   int i;
00138   char gstr[MAXHMMNAMELEN], cbuf[MAXHMMNAMELEN];
00139   boolean ok_p = TRUE;
00140 
00141   /* initialize */
00142   state2gs = (int *)mymalloc(sizeof(int) * OP_hmminfo->totalstatenum);
00143   for(i=0;i<OP_hmminfo->totalstatenum;i++) state2gs[i] = -1;
00144 
00145   /* parse through all HMM macro to register their state */
00146   for(dt = OP_hmminfo->start; dt; dt=dt->next) {
00147     if (strlen(dt->name) >= MAXHMMNAMELEN - 2) {
00148       j_printerr("Error: too long hmm name (>%d): \"%s\"\n",
00149                  MAXHMMNAMELEN-3, dt->name);
00150       ok_p = FALSE;
00151       continue;
00152     }
00153     for(i=1;i<dt->state_num-1;i++) { /* for all state */
00154       st = dt->s[i];
00155       /* skip if already assigned */
00156       if (state2gs[st->id] != -1) continue;
00157       /* set corresponding gshmm name */
00158       sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00159       /* look up the state in OP_gshmm */
00160       if ((cr = state_lookup(OP_gshmm, gstr)) == NULL) {
00161         j_printerr("Error: GS HMM \"%s\" not defined\n", gstr);
00162         ok_p = FALSE;
00163         continue;
00164       }
00165       /* store its ID */
00166       state2gs[st->id] = cr->id;
00167     }
00168   }
00169 #ifdef PARANOIA
00170   {
00171     HTK_HMM_State *st;
00172     for(st=OP_hmminfo->ststart; st; st=st->next) {
00173       printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00174              (gsset[state2gs[st->id]].state)->name);
00175     }
00176   }
00177 #endif
00178   return ok_p;
00179 }
00180 
00185 static void
00186 free_state2gs()
00187 {
00188   free(state2gs);
00189 }
00190 
00191 
00192 /* sort to find N-best states */
00193 #define SD(A) gsindex[A-1]      
00194 #define SCOPY(D,S) D = S        
00195 #define SVAL(A) (t_fs[gsindex[A-1]]) 
00196 #define STVAL (t_fs[s]) 
00197 
00198 
00204 static void
00205 sort_gsindex_upward(int neednum, int totalnum)
00206 {
00207   int n,root,child,parent;
00208   int s;
00209   for (root = totalnum/2; root >= 1; root--) {
00210     SCOPY(s, SD(root));
00211     parent = root;
00212     while ((child = parent * 2) <= totalnum) {
00213       if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00214         child++;
00215       }
00216       if (STVAL >= SVAL(child)) {
00217         break;
00218       }
00219       SCOPY(SD(parent), SD(child));
00220       parent = child;
00221     }
00222     SCOPY(SD(parent), s);
00223   }
00224   n = totalnum;
00225   while ( n > totalnum - neednum) {
00226     SCOPY(s, SD(n));
00227     SCOPY(SD(n), SD(1));
00228     n--;
00229     parent = 1;
00230     while ((child = parent * 2) <= n) {
00231       if (child < n && SVAL(child) < SVAL(child+1)) {
00232         child++;
00233       }
00234       if (STVAL >= SVAL(child)) {
00235         break;
00236       }
00237       SCOPY(SD(parent), SD(child));
00238       parent = child;
00239     }
00240     SCOPY(SD(parent), s);
00241   }
00242 }
00243 
00248 static void
00249 do_gms()
00250 {
00251   int i;
00252   
00253   /* compute all gshmm scores (in gs_score.c) */
00254   compute_gs_scores(gsset, gsset_num, t_fs);
00255   /* sort and select */
00256   sort_gsindex_upward(my_nbest, gsset_num);
00257   for(i=gsset_num - my_nbest;i<gsset_num;i++) {
00258     /* set scores of selected states to LOG_ZERO */
00259     t_fs[gsindex[i]] = LOG_ZERO;
00260   }
00261 
00262   /* power e -> 10 */
00263 #ifdef NORMALIZE_GS_SCORE
00264   /* normalize other fallback scores (rate of max) */
00265   for(i=0;i<gsset_num;i++) {
00266     if (t_fs[i] != LOG_ZERO) {
00267       t_fs[i] = t_fs[i] * 0.975;
00268     }
00269   }
00270 #endif
00271 }  
00272 
00273 
00281 boolean
00282 gms_init(int nbest)
00283 {
00284   int i;
00285   
00286   /* Check gshmm type */
00287   if (OP_gshmm->is_triphone) {
00288     j_printerr("Error: GS HMM should be a monophone model\n");
00289     return FALSE;
00290   }
00291   if (OP_gshmm->is_tied_mixture) {
00292     j_printerr("Error: GS HMM should not be a tied mixture model\n");
00293     return FALSE;
00294   }
00295 
00296   /* store as local info */
00297   my_nbest = nbest;
00298 
00299   /* Register all GS HMM states in GS_SET */
00300   build_gsset();
00301   /* Make correspondence of all triphone states to GS HMM states */
00302   j_printerr("Mapping HMM states to GS HMM...");
00303   if (build_state2gs() == FALSE) {
00304     j_printerr("Error: failed in assigning GS HMM state for each state\n");
00305     return FALSE;
00306   }
00307   j_printerr("done\n");
00308 
00309   /* prepare index buffer for heap sort */
00310   gsindex = (int *)mymalloc(sizeof(int) * gsset_num);
00311   for(i=0;i<gsset_num;i++) gsindex[i] = i;
00312 
00313   /* init cache status */
00314   fallback_score = NULL;
00315   is_selected = NULL;
00316   allocframenum = -1;
00317 
00318   /* initialize gms_gprune functions */
00319   gms_gprune_init(OP_hmminfo, gsset_num);
00320   
00321   return TRUE;
00322 }
00323 
00331 boolean
00332 gms_prepare(int framenum)
00333 {
00334   LOGPROB *tmp;
00335   int t;
00336 
00337   /* allocate cache */
00338   if (allocframenum < framenum) {
00339     if (fallback_score != NULL) {
00340       free(fallback_score[0]);
00341       free(fallback_score);
00342       free(is_selected);
00343     }
00344     fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00345     tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gsset_num * framenum);
00346     for(t=0;t<framenum;t++) {
00347       fallback_score[t] = &(tmp[gsset_num * t]);
00348     }
00349     is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00350     allocframenum = framenum;
00351   }
00352   /* clear */
00353   for(t=0;t<framenum;t++) is_selected[t] = FALSE;
00354 
00355   /* prepare gms_gprune functions */
00356   gms_gprune_prepare();
00357   
00358   return TRUE;
00359 }
00360 
00365 void
00366 gms_free()
00367 {
00368   free_gsset();
00369   free_state2gs();
00370   free(gsindex);
00371   if (fallback_score != NULL) {
00372     free(fallback_score[0]);
00373     free(fallback_score);
00374     free(is_selected);
00375   }
00376   gms_gprune_free();
00377 }
00378 
00379 
00380 
00391 LOGPROB
00392 gms_state()
00393 {
00394   LOGPROB gsprob;
00395   if (OP_last_time != OP_time) { /* different frame */
00396     /* set current buffer */
00397     t_fs = fallback_score[OP_time];
00398     /* select state if not yet */
00399     if (!is_selected[OP_time]) {
00400       do_gms();
00401       is_selected[OP_time] = TRUE;
00402     }
00403   }
00404   if ((gsprob = t_fs[state2gs[OP_state_id]]) != LOG_ZERO) {
00405     /* un-selected: return the fallback value */
00406     return(gsprob);
00407   }
00408   /* selected: calculate the real outprob of the state */
00409   return(calc_outprob());
00410 }

Generated on Tue Dec 26 12:53:22 2006 for Julian by  doxygen 1.5.0