00001 
00043 
00044 
00045 
00046 
00047 
00048 
00049 
00050 #include <sent/stddefs.h>
00051 #include <sent/htk_hmm.h>
00052 #include <sent/htk_param.h>
00053 #include <sent/hmm.h>
00054 #include <sent/gprune.h>
00055 #include "globalvars.h"
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 
00068 
00069 
00070 
00071 
00072 
00073 
00074 
00075 
00076 
00077 
00078 
00079 
00080 
00081 
00082 
00083 
00084 
00085 
00086 
00087 
00088 
00089 
00090 
00091 
00092 
00093 
00094 
00095 
00096 
00097 
00098 
00099 
00100 
00101 
00102 
00103 
00104 
00105 
00106 static LOGPROB *backmax;        
00107 static int backmax_num;         
00108 
00109 static boolean *mixcalced;      
00110 
00115 static void
00116 init_backmax()
00117 {
00118   int i;
00119   for(i=0;i<backmax_num;i++) backmax[i] = 0;
00120 }
00121 
00127 
00128 
00129 
00130 
00131 
00132 
00133 
00134 
00135 static void
00136 make_backmax()
00137 {
00138   int i;
00139   backmax[backmax_num-1] = 0.0;
00140   
00141   for(i=backmax_num-2;i>=0;i--) {
00142     backmax[i] += backmax[i+1];
00143   }
00144   
00145 
00146 
00147 }
00148 
00161 static LOGPROB
00162 compute_g_heu_updating(HTK_HMM_Dens *binfo)
00163 {
00164   VECT tmp, x, sum = 0.0;
00165   VECT *mean;
00166   VECT *var;
00167   VECT *bm = backmax;
00168   VECT *vec = OP_vec;
00169   short veclen = OP_veclen;
00170 
00171   if (binfo == NULL) return(LOG_ZERO);
00172   mean = binfo->mean;
00173   var = binfo->var->vec;
00174 
00175   tmp = 0.0;
00176   for (; veclen > 0; veclen--) {
00177     x = *(vec++) - *(mean++);
00178     tmp = x * x * *(var++);
00179     sum += tmp;
00180     if ( *bm < tmp) *bm = tmp;
00181     bm++;
00182   }
00183   return((sum + binfo->gconst) * -0.5);
00184 }
00185 
00199 static LOGPROB
00200 compute_g_heu_pruning(HTK_HMM_Dens *binfo, LOGPROB thres)
00201 {
00202   VECT tmp, x;
00203   VECT *mean;
00204   VECT *var;
00205   VECT *bm = backmax;
00206   VECT *vec = OP_vec;
00207   short veclen = OP_veclen;
00208   LOGPROB fthres;
00209 
00210   if (binfo == NULL) return(LOG_ZERO);
00211   mean = binfo->mean;
00212   var = binfo->var->vec;
00213   fthres = thres * (-2.0);
00214 
00215   tmp = 0.0;
00216   bm++;
00217   for (; veclen > 0; veclen--) {
00218     x = *(vec++) - *(mean++);
00219     tmp += x * x * *(var++);
00220     if ( tmp + *bm > fthres) {
00221       return LOG_ZERO;
00222     }
00223     bm++;
00224   }
00225   return((tmp + binfo->gconst) * -0.5);
00226 }
00227 
00228 
00234 boolean
00235 gprune_heu_init()
00236 {
00237   int i;
00238   
00239   OP_calced_maxnum = OP_hmminfo->maxmixturenum;
00240   OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * OP_gprune_num);
00241   OP_calced_id = (int *)mymalloc(sizeof(int) * OP_gprune_num);
00242   mixcalced = (boolean *)mymalloc(sizeof(int) * OP_calced_maxnum);
00243   for(i=0;i<OP_calced_maxnum;i++) mixcalced[i] = FALSE;
00244   backmax_num = OP_hmminfo->opt.vec_size + 1;
00245   backmax = (LOGPROB *)mymalloc(sizeof(LOGPROB) * backmax_num);
00246 
00247   return TRUE;
00248 }
00249 
00254 void
00255 gprune_heu_free()
00256 {
00257   free(OP_calced_score);
00258   free(OP_calced_id);
00259   free(mixcalced);
00260   free(backmax);
00261 }
00262 
00287 void
00288 gprune_heu(HTK_HMM_Dens **g, int gnum, int *last_id)
00289 {
00290   int i, j, num = 0;
00291   LOGPROB score, thres;
00292 
00293   if (last_id != NULL) {        
00294     
00295     init_backmax();
00296     
00297     for (j=0; j<OP_gprune_num; j++) {
00298       i = last_id[j];
00299       score = compute_g_heu_updating(g[i]);
00300       num = cache_push(i, score, num);
00301       mixcalced[i] = TRUE;      
00302     }
00303     
00304     make_backmax();
00305     
00306     thres = OP_calced_score[num-1];
00307     for (i = 0; i < gnum; i++) {
00308       
00309       if (mixcalced[i]) {
00310         mixcalced[i] = FALSE;
00311         continue;
00312       }
00313       
00314       score = compute_g_heu_pruning(g[i], thres);
00315       if (score > LOG_ZERO) {
00316         num = cache_push(i, score, num);
00317         thres = OP_calced_score[num-1];
00318       }
00319     }
00320   } else {                      
00321     
00322     
00323     thres = LOG_ZERO;
00324     for (i = 0; i < gnum; i++) {
00325       if (num < OP_gprune_num) {
00326         score = compute_g_base(g[i]);
00327       } else {
00328         score = compute_g_safe(g[i], thres);
00329         if (score <= thres) continue;
00330       }
00331       num = cache_push(i, score, num);
00332       thres = OP_calced_score[num-1];
00333     }
00334   }
00335   OP_calced_num = num;
00336 }