Main Page | Modules | Data Structures | Directories | File List | Data Fields | Globals | Related Pages

gms.c

Go to the documentation of this file.
00001 
00021 /*
00022  * Copyright (c) 1991-2006 Kawahara Lab., Kyoto University
00023  * Copyright (c) 2000-2005 Shikano Lab., Nara Institute of Science and Technology
00024  * Copyright (c) 2005-2006 Julius project team, Nagoya Institute of Technology, Nagoya Institute of Technology
00025  * All rights reserved
00026  */
00027 
00028 /*
00029   Implementation of Gaussian Mixture Selection (old doc...)
00030   
00031   It is called from gs_calc_selected_mixture_and_cache_{safe,heu,beam} in
00032   the first pass for each frame.  It calculates the outprobs of all GS
00033   HMM states for the given input frame and gets the N-best GS HMM states. Then,
00034        for the selected (N-best) states:
00035            calculate the corresponding codebook,
00036            and set fallback_score[t][book] to LOG_ZERO.
00037        else:
00038            set fallback_score[t][book] to the GS HMM outprob.
00039   Later, when calculating state outprobs, the fallback_score[t][book]
00040   is consulted and,
00041        if fallback_score[t][book] == LOG_ZERO:
00042            it means it has been selected, so calculate the outprob with
00043            the corresponding codebook and its weights.
00044        else:
00045            it means it was pruned, so use the fallback_score[t][book]
00046            as its outprob.
00047 
00048            
00049   For triphone, GS HMMs should be assigned to each state.
00050   So the fallback_score[][] is kept according to the GS state ID,
00051   and corresponding GS HMM state id for each triphone state id should be
00052   kept beforehand.
00053   GS HMM Calculation:
00054        for the selected (N-best) GS HMM states:
00055            set fallback_score[t][gs_stateid] to LOG_ZERO.
00056        else:
00057            set fallback_score[t][gs_stateid] to the GS HMM outprob.
00058   triphone HMM probabilities are assigned as:
00059        if fallback_score[t][state2gs[tri_stateid]] == LOG_ZERO:
00060            it has been selected, so calculate the original outprob.
00061        else:
00062            as it was pruned, re-use fallback_score[t][state2gs[tri_stateid]]
00063            as its outprob.
00064 */
00065 
00066 
00067 #include <sent/stddefs.h>
00068 #include <sent/htk_hmm.h>
00069 #include <sent/htk_param.h>
00070 #include <sent/hmm.h>
00071 #include <sent/gprune.h>
00072 #include "globalvars.h"
00073 
00074 #undef NORMALIZE_GS_SCORE       /* normalize score (ad-hoc) */
00075 
00076   /* GS HMMs must be defined at STATE level using "~s NAME" macro,
00077      where NAMES are like "i:4m", "s2m", etc. */
00078 
00079 /* variables for GMS */
00080 static int my_nbest;            /* number of best-scoring GS HMM states to select per frame (stored by gms_init) */
00081 static int allocframenum;       /* number of frames the caches are currently allocated for; -1 until first allocation */
00082 
00083 /* GMS info */
00084 static GS_SET *gsset;           /* table of GS HMM states, indexed by GS HMM state ID */
00085 static int gsset_num;           /* number of entries in gsset */
00086 static int *state2gs; /* GS HMM state ID for each state ID of OP_hmminfo; -1 while unassigned */
00087 
00088 /* results */
00089 static boolean *is_selected;    /* per frame: TRUE once do_gms() has been run for that frame */
00090 static LOGPROB **fallback_score = NULL; /* [frame][GS state ID]: fallback score, or LOG_ZERO for selected states */
00091 
00092 /* for calculation */
00093 static int *gsindex;            /* GS HMM state IDs, permuted in place by the heap sort */
00094 static LOGPROB *t_fs;           /* row of fallback_score for the frame currently being processed */
00095 
00096 
00101 static void
00102 build_gsset()
00103 {
00104   HTK_HMM_State *st;
00105 
00106   /* allocate */
00107   gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * OP_gshmm->totalstatenum);
00108   gsset_num = OP_gshmm->totalstatenum;
00109   /* make ID */
00110   for(st = OP_gshmm->ststart; st; st=st->next) {
00111     gsset[st->id].state = st;
00112   }
00113 }
00114 
00115 #define MAXHMMNAMELEN 40        
00116 
00117 
00122 static boolean
00123 build_state2gs()
00124 {
00125   HTK_HMM_Data *dt;
00126   HTK_HMM_State *st, *cr;
00127   int i;
00128   char gstr[MAXHMMNAMELEN], cbuf[MAXHMMNAMELEN];
00129   boolean ok_p = TRUE;
00130 
00131   /* initialize */
00132   state2gs = (int *)mymalloc(sizeof(int) * OP_hmminfo->totalstatenum);
00133   for(i=0;i<OP_hmminfo->totalstatenum;i++) state2gs[i] = -1;
00134 
00135   /* parse through all HMM macro to register their state */
00136   for(dt = OP_hmminfo->start; dt; dt=dt->next) {
00137     if (strlen(dt->name) >= MAXHMMNAMELEN - 2) {
00138       j_printerr("Error: too long hmm name (>%d): \"%s\"\n",
00139                  MAXHMMNAMELEN-3, dt->name);
00140       ok_p = FALSE;
00141       continue;
00142     }
00143     for(i=1;i<dt->state_num-1;i++) { /* for all state */
00144       st = dt->s[i];
00145       /* skip if already assigned */
00146       if (state2gs[st->id] != -1) continue;
00147       /* set corresponding gshmm name */
00148       sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00149       /* look up the state in OP_gshmm */
00150       if ((cr = state_lookup(OP_gshmm, gstr)) == NULL) {
00151         j_printerr("Error: GS HMM \"%s\" not defined\n", gstr);
00152         ok_p = FALSE;
00153         continue;
00154       }
00155       /* store its ID */
00156       state2gs[st->id] = cr->id;
00157     }
00158   }
00159 #ifdef PARANOIA
00160   {
00161     HTK_HMM_State *st;
00162     for(st=OP_hmminfo->ststart; st; st=st->next) {
00163       printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00164              (gsset[state2gs[st->id]].state)->name);
00165     }
00166   }
00167 #endif
00168   return ok_p;
00169 }
00170 
00171 
/* sort to find N-best states */
/* Helper macros for the heap sort over gsindex[].  A is a 1-based heap
   position.  Arguments are parenthesized so that any expression can be
   passed safely. */
#define SD(A) (gsindex[(A)-1])          /* GS state ID stored at heap position A */
#define SCOPY(D,S) ((D) = (S))          /* copy heap element S to D */
#define SVAL(A) (t_fs[gsindex[(A)-1]])  /* score of the element at heap position A */
#define STVAL (t_fs[s])                 /* score of the element held in the local variable s */
00177 
00178 
00184 static void
00185 sort_gsindex_upward(int neednum, int totalnum)
00186 {
00187   int n,root,child,parent;
00188   int s;
00189   for (root = totalnum/2; root >= 1; root--) {
00190     SCOPY(s, SD(root));
00191     parent = root;
00192     while ((child = parent * 2) <= totalnum) {
00193       if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00194         child++;
00195       }
00196       if (STVAL >= SVAL(child)) {
00197         break;
00198       }
00199       SCOPY(SD(parent), SD(child));
00200       parent = child;
00201     }
00202     SCOPY(SD(parent), s);
00203   }
00204   n = totalnum;
00205   while ( n > totalnum - neednum) {
00206     SCOPY(s, SD(n));
00207     SCOPY(SD(n), SD(1));
00208     n--;
00209     parent = 1;
00210     while ((child = parent * 2) <= n) {
00211       if (child < n && SVAL(child) < SVAL(child+1)) {
00212         child++;
00213       }
00214       if (STVAL >= SVAL(child)) {
00215         break;
00216       }
00217       SCOPY(SD(parent), SD(child));
00218       parent = child;
00219     }
00220     SCOPY(SD(parent), s);
00221   }
00222 }
00223 
00228 static void
00229 do_gms()
00230 {
00231   int i;
00232   
00233   /* compute all gshmm scores (in gs_score.c) */
00234   compute_gs_scores(gsset, gsset_num, t_fs);
00235   /* sort and select */
00236   sort_gsindex_upward(my_nbest, gsset_num);
00237   for(i=gsset_num - my_nbest;i<gsset_num;i++) {
00238     /* set scores of selected states to LOG_ZERO */
00239     t_fs[gsindex[i]] = LOG_ZERO;
00240   }
00241 
00242   /* power e -> 10 */
00243 #ifdef NORMALIZE_GS_SCORE
00244   /* normalize other fallback scores (rate of max) */
00245   for(i=0;i<gsset_num;i++) {
00246     if (t_fs[i] != LOG_ZERO) {
00247       t_fs[i] = t_fs[i] * 0.975;
00248     }
00249   }
00250 #endif
00251 }  
00252 
00253 
00261 boolean
00262 gms_init(int nbest)
00263 {
00264   int i;
00265   
00266   /* Check gshmm type */
00267   if (OP_gshmm->is_triphone) {
00268     j_printerr("Error: GS HMM should be a monophone model\n");
00269     return FALSE;
00270   }
00271   if (OP_gshmm->is_tied_mixture) {
00272     j_printerr("Error: GS HMM should not be a tied mixture model\n");
00273     return FALSE;
00274   }
00275 
00276   /* store as local info */
00277   my_nbest = nbest;
00278 
00279   /* Register all GS HMM states in GS_SET */
00280   build_gsset();
00281   /* Make correspondence of all triphone states to GS HMM states */
00282   j_printerr("Mapping HMM states to GS HMM...");
00283   if (build_state2gs() == FALSE) {
00284     j_printerr("Error: failed in assigning GS HMM state for each state\n");
00285     return FALSE;
00286   }
00287   j_printerr("done\n");
00288 
00289   /* prepare index buffer for heap sort */
00290   gsindex = (int *)mymalloc(sizeof(int) * gsset_num);
00291   for(i=0;i<gsset_num;i++) gsindex[i] = i;
00292 
00293   /* init cache status */
00294   fallback_score = NULL;
00295   is_selected = NULL;
00296   allocframenum = -1;
00297 
00298   /* initialize gms_gprune functions */
00299   gms_gprune_init(OP_hmminfo, gsset_num);
00300   
00301   return TRUE;
00302 }
00303 
00311 boolean
00312 gms_prepare(int framenum)
00313 {
00314   LOGPROB *tmp;
00315   int t;
00316 
00317   /* allocate cache */
00318   if (allocframenum < framenum) {
00319     if (fallback_score != NULL) {
00320       free(fallback_score[0]);
00321       free(fallback_score);
00322       free(is_selected);
00323     }
00324     fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00325     tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gsset_num * framenum);
00326     for(t=0;t<framenum;t++) {
00327       fallback_score[t] = &(tmp[gsset_num * t]);
00328     }
00329     is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00330     allocframenum = framenum;
00331   }
00332   /* clear */
00333   for(t=0;t<framenum;t++) is_selected[t] = FALSE;
00334 
00335   /* prepare gms_gprune functions */
00336   gms_gprune_prepare();
00337   
00338   return TRUE;
00339 }
00340 
00351 LOGPROB
00352 gms_state()
00353 {
00354   LOGPROB gsprob;
00355   if (OP_last_time != OP_time) { /* different frame */
00356     /* set current buffer */
00357     t_fs = fallback_score[OP_time];
00358     /* select state if not yet */
00359     if (!is_selected[OP_time]) {
00360       do_gms();
00361       is_selected[OP_time] = TRUE;
00362     }
00363   }
00364   if ((gsprob = t_fs[state2gs[OP_state_id]]) != LOG_ZERO) {
00365     /* un-selected: return the fallback value */
00366     return(gsprob);
00367   }
00368   /* selected: calculate the real outprob of the state */
00369   return(calc_outprob());
00370 }

Generated on Tue Mar 28 16:01:39 2006 for Julius by  doxygen 1.4.2