00001 
00044 
00045 
00046 
00047 
00048 
00049 
00050 
00051 #include <sent/stddefs.h>
00052 #include <sent/htk_hmm.h>
00053 #include <sent/htk_param.h>
00054 #include <sent/hmm.h>
00055 #include <sent/hmm_calc.h>
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 
00068 
00069 
00070 
00071 
00072 
00073 
00074 
00075 
00076 
00077 
00078 
00079 
00080 
00081 
00082 
00083 
00084 
00085 
00086 
00087 
00088 
00089 
00090 
00091 
00092 
00093 
00094 
00095 
00096 
00097 
00098 
00099 
00100 
00101 
00102 
00103 
00104 
00111 static void
00112 init_backmax(HMMWork *wrk)
00113 {
00114   int i;
00115   for(i=0;i<wrk->backmax_num;i++) wrk->backmax[i] = 0;
00116 }
00117 
00124 
00125 
00126 
00127 
00128 
00129 
00130 
00131 
00132 static void
00133 make_backmax(HMMWork *wrk)
00134 {
00135   int i;
00136 
00137   wrk->backmax[wrk->backmax_num-1] = 0.0;
00138   
00139   for(i=wrk->backmax_num-2;i>=0;i--) {
00140     wrk->backmax[i] += wrk->backmax[i+1];
00141   }
00142   
00143 
00144 
00145 }
00146 
00160 static LOGPROB
00161 compute_g_heu_updating(HMMWork *wrk, HTK_HMM_Dens *binfo)
00162 {
00163   VECT tmp, x, sum = 0.0;
00164   VECT *mean;
00165   VECT *var;
00166   VECT *bm = wrk->backmax;
00167   VECT *vec = wrk->OP_vec;
00168   short veclen = wrk->OP_veclen;
00169 
00170   if (binfo == NULL) return(LOG_ZERO);
00171   mean = binfo->mean;
00172   var = binfo->var->vec;
00173 
00174   tmp = 0.0;
00175   for (; veclen > 0; veclen--) {
00176     x = *(vec++) - *(mean++);
00177     tmp = x * x * *(var++);
00178     sum += tmp;
00179     if ( *bm < tmp) *bm = tmp;
00180     bm++;
00181   }
00182   return((sum + binfo->gconst) * -0.5);
00183 }
00184 
00199 static LOGPROB
00200 compute_g_heu_pruning(HMMWork *wrk, HTK_HMM_Dens *binfo, LOGPROB thres)
00201 {
00202   VECT tmp, x;
00203   VECT *mean;
00204   VECT *var;
00205   VECT *bm = wrk->backmax;
00206   VECT *vec = wrk->OP_vec;
00207   short veclen = wrk->OP_veclen;
00208   LOGPROB fthres;
00209 
00210   if (binfo == NULL) return(LOG_ZERO);
00211   mean = binfo->mean;
00212   var = binfo->var->vec;
00213   fthres = thres * (-2.0);
00214 
00215   tmp = 0.0;
00216   bm++;
00217   for (; veclen > 0; veclen--) {
00218     x = *(vec++) - *(mean++);
00219     tmp += x * x * *(var++);
00220     if ( tmp + *bm > fthres) {
00221       return LOG_ZERO;
00222     }
00223     bm++;
00224   }
00225   return((tmp + binfo->gconst) * -0.5);
00226 }
00227 
00228 
00236 boolean
00237 gprune_heu_init(HMMWork *wrk)
00238 {
00239   int i;
00240 
00241   
00242   wrk->OP_calced_maxnum = wrk->OP_hmminfo->maxmixturenum * wrk->OP_nstream;
00243   wrk->OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->OP_calced_maxnum);
00244   wrk->OP_calced_id = (int *)mymalloc(sizeof(int) * wrk->OP_calced_maxnum);
00245   wrk->mixcalced = (boolean *)mymalloc(sizeof(int) * wrk->OP_calced_maxnum);
00246   for(i=0;i<wrk->OP_calced_maxnum;i++) wrk->mixcalced[i] = FALSE;
00247   wrk->backmax_num = wrk->OP_hmminfo->opt.vec_size + 1;
00248   wrk->backmax = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->backmax_num);
00249 
00250   return TRUE;
00251 }
00252 
00259 void
00260 gprune_heu_free(HMMWork *wrk)
00261 {
00262   free(wrk->OP_calced_score);
00263   free(wrk->OP_calced_id);
00264   free(wrk->mixcalced);
00265   free(wrk->backmax);
00266 }
00267 
00294 void
00295 gprune_heu(HMMWork *wrk, HTK_HMM_Dens **g, int gnum, int *last_id, int lnum)
00296 {
00297   int i, j, num = 0;
00298   LOGPROB score, thres;
00299 
00300   if (last_id != NULL) {        
00301     
00302     init_backmax(wrk);
00303     
00304     for (j=0; j<lnum; j++) {
00305       i = last_id[j];
00306       score = compute_g_heu_updating(wrk, g[i]);
00307       num = cache_push(wrk, i, score, num);
00308       wrk->mixcalced[i] = TRUE;      
00309     }
00310     
00311     make_backmax(wrk);
00312     
00313     thres = wrk->OP_calced_score[num-1];
00314     for (i = 0; i < gnum; i++) {
00315       
00316       if (wrk->mixcalced[i]) {
00317         wrk->mixcalced[i] = FALSE;
00318         continue;
00319       }
00320       
00321       score = compute_g_heu_pruning(wrk, g[i], thres);
00322       if (score > LOG_ZERO) {
00323         num = cache_push(wrk, i, score, num);
00324         thres = wrk->OP_calced_score[num-1];
00325       }
00326     }
00327   } else {                      
00328     
00329     
00330     thres = LOG_ZERO;
00331     for (i = 0; i < gnum; i++) {
00332       if (num < wrk->OP_gprune_num) {
00333         score = compute_g_base(wrk, g[i]);
00334       } else {
00335         score = compute_g_safe(wrk, g[i], thres);
00336         if (score <= thres) continue;
00337       }
00338       num = cache_push(wrk, i, score, num);
00339       thres = wrk->OP_calced_score[num-1];
00340     }
00341   }
00342   wrk->OP_calced_num = num;
00343 }