00001 
00045 
00046 
00047 
00048 
00049 
00050 
00051 
00052 #include <sent/stddefs.h>
00053 #include <sent/htk_hmm.h>
00054 #include <sent/htk_param.h>
00055 #include <sent/hmm.h>
00056 #include <sent/gprune.h>
00057 #include "globalvars.h"
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 
00066 
00067 
00068 
00069 
00070 
00071 
00072 
00073 
00074 
00075 
00076 
00077 
00078 
00079 
00080 
00081 
00082 
00083 
00084 
00085 
00086 
00087 
00088 
00089 
00090 
00091 
00092 
00093 
00094 
00095 
00096 
00097 
00098 
00099 
00100 
00101 
00102 
00103 
00104 
00105 
00106 
00107 
00108 
00109 
00110 
00111 
00112 
00113 
00114 
00115 
00116 
00117 static LOGPROB *dimthres;       
00118 static int dimthres_num;        
00119 
00120 static boolean *mixcalced;      
00121 
00126 static void
00127 clear_dimthres()
00128 {
00129   int i;
00130   for(i=0;i<dimthres_num;i++) dimthres[i] = 0.0;
00131 }
00132 
00138 static void
00139 set_dimthres()
00140 {
00141   int i;
00142   for(i=0;i<dimthres_num;i++) dimthres[i] += TMBEAMWIDTH;
00143 }
00144 
00157 static LOGPROB
00158 compute_g_beam_updating(HTK_HMM_Dens *binfo)
00159 {
00160   VECT tmp, x;
00161   VECT *mean;
00162   VECT *var;
00163   VECT *th = dimthres;
00164   VECT *vec = OP_vec;
00165   short veclen = OP_veclen;
00166 
00167   if (binfo == NULL) return(LOG_ZERO);
00168   mean = binfo->mean;
00169   var = binfo->var->vec;
00170 
00171   tmp = 0.0;
00172   for (; veclen > 0; veclen--) {
00173     x = *(vec++) - *(mean++);
00174     tmp += x * x * *(var++);
00175     if ( *th < tmp) *th = tmp;
00176     th++;
00177   }
00178   return((tmp + binfo->gconst) * -0.5);
00179 }
00180 
00193 static LOGPROB
00194 compute_g_beam_pruning(HTK_HMM_Dens *binfo)
00195 {
00196   VECT tmp, x;
00197   VECT *mean;
00198   VECT *var;
00199   VECT *th = dimthres;
00200   VECT *vec = OP_vec;
00201   short veclen = OP_veclen;
00202 
00203   if (binfo == NULL) return(LOG_ZERO);
00204   mean = binfo->mean;
00205   var = binfo->var->vec;
00206 
00207   tmp = 0.0;
00208   for (; veclen > 0; veclen--) {
00209     x = *(vec++) - *(mean++);
00210     tmp += x * x * *(var++);
00211     if ( tmp > *(th++)) {
00212       return LOG_ZERO;
00213     }
00214   }
00215   return((tmp + binfo->gconst) * -0.5);
00216 }
00217 
00218 
00224 boolean
00225 gprune_beam_init()
00226 {
00227   int i;
00228   
00229   OP_calced_maxnum = OP_hmminfo->maxmixturenum;
00230   OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * OP_gprune_num);
00231   OP_calced_id = (int *)mymalloc(sizeof(int) * OP_gprune_num);
00232   mixcalced = (boolean *)mymalloc(sizeof(int) * OP_calced_maxnum);
00233   for(i=0;i<OP_calced_maxnum;i++) mixcalced[i] = FALSE;
00234   dimthres_num = OP_hmminfo->opt.vec_size;
00235   dimthres = (LOGPROB *)mymalloc(sizeof(LOGPROB) * dimthres_num);
00236 
00237   return TRUE;
00238 }
00239 
00244 void
00245 gprune_beam_free()
00246 {
00247   free(OP_calced_score);
00248   free(OP_calced_id);
00249   free(mixcalced);
00250   free(dimthres);
00251 }
00252 
00276 void
00277 gprune_beam(HTK_HMM_Dens **g, int gnum, int *last_id)
00278 {
00279   int i, j, num = 0;
00280   LOGPROB score, thres;
00281 
00282   if (last_id != NULL) {        
00283     
00284     clear_dimthres();
00285     
00286     for (j=0; j<OP_gprune_num; j++) {
00287       i = last_id[j];
00288       score = compute_g_beam_updating(g[i]);
00289       num = cache_push(i, score, num);
00290       mixcalced[i] = TRUE;      
00291     }
00292     
00293     set_dimthres();
00294 
00295     
00296     for (i = 0; i < gnum; i++) {
00297       
00298       if (mixcalced[i]) {
00299         mixcalced[i] = FALSE;
00300         continue;
00301       }
00302       
00303       score = compute_g_beam_pruning(g[i]);
00304       if (score > LOG_ZERO) {
00305         num = cache_push(i, score, num);
00306       }
00307     }
00308   } else {                      
00309     
00310     
00311     thres = LOG_ZERO;
00312     for (i = 0; i < gnum; i++) {
00313       if (num < OP_gprune_num) {
00314         score = compute_g_base(g[i]);
00315       } else {
00316         score = compute_g_safe(g[i], thres);
00317         if (score <= thres) continue;
00318       }
00319       num = cache_push(i, score, num);
00320       thres = OP_calced_score[num-1];
00321     }
00322   }
00323   OP_calced_num = num;
00324 }