00001
00046
00047
00048
00049
00050
00051
00052
00053 #include <sent/stddefs.h>
00054 #include <sent/htk_hmm.h>
00055 #include <sent/htk_param.h>
00056 #include <sent/hmm.h>
00057 #include <sent/hmm_calc.h>
00058
00059 #define TEST2
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00124 static void
00125 clear_dimthres(HMMWork *wrk)
00126 {
00127 int i;
00128 for(i=0;i<wrk->dimthres_num;i++) wrk->dimthres[i] = 0.0;
00129 }
00130
00138 static void
00139 set_dimthres(HMMWork *wrk)
00140 {
00141 int i;
00142 for(i=0;i<wrk->dimthres_num;i++) wrk->dimthres[i] += TMBEAMWIDTH;
00143 }
00144
00158 static LOGPROB
00159 compute_g_beam_updating(HMMWork *wrk, HTK_HMM_Dens *binfo)
00160 {
00161 VECT tmp, x;
00162 VECT *mean;
00163 VECT *var;
00164 VECT *th = wrk->dimthres;
00165 VECT *vec = wrk->OP_vec;
00166 short veclen = wrk->OP_veclen;
00167
00168 #ifndef TEST2
00169 if (binfo == NULL) return(LOG_ZERO);
00170 #endif
00171
00172 mean = binfo->mean;
00173 var = binfo->var->vec;
00174
00175 tmp = 0.0;
00176 for (; veclen > 0; veclen--) {
00177 x = *(vec++) - *(mean++);
00178 tmp += x * x * *(var++);
00179 if ( *th < tmp) *th = tmp;
00180 th++;
00181 }
00182 return((tmp + binfo->gconst) * -0.5);
00183 }
00184
00198 static LOGPROB
00199 compute_g_beam_pruning(HMMWork *wrk, HTK_HMM_Dens *binfo)
00200 {
00201 VECT tmp, x;
00202 VECT *mean;
00203 VECT *var;
00204 VECT *th = wrk->dimthres;
00205 VECT *vec = wrk->OP_vec;
00206 short veclen = wrk->OP_veclen;
00207
00208 #ifndef TEST2
00209 if (binfo == NULL) return(LOG_ZERO);
00210 #endif
00211 mean = binfo->mean;
00212 var = binfo->var->vec;
00213
00214 tmp = 0.0;
00215 for (; veclen > 0; veclen--) {
00216 x = *(vec++) - *(mean++);
00217 tmp += x * x * *(var++);
00218 if ( tmp > *(th++)) {
00219 return LOG_ZERO;
00220 }
00221 }
00222 return((tmp + binfo->gconst) * -0.5);
00223 }
00224
00225
00233 boolean
00234 gprune_beam_init(HMMWork *wrk)
00235 {
00236 int i;
00237
00238
00239 wrk->OP_calced_maxnum = wrk->OP_hmminfo->maxmixturenum * wrk->OP_nstream;
00240 wrk->OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->OP_calced_maxnum);
00241 wrk->OP_calced_id = (int *)mymalloc(sizeof(int) * wrk->OP_calced_maxnum);
00242 wrk->mixcalced = (boolean *)mymalloc(sizeof(int) * wrk->OP_calced_maxnum);
00243 for(i=0;i<wrk->OP_calced_maxnum;i++) wrk->mixcalced[i] = FALSE;
00244 wrk->dimthres_num = wrk->OP_hmminfo->opt.vec_size;
00245 wrk->dimthres = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->dimthres_num);
00246
00247 return TRUE;
00248 }
00249
00256 void
00257 gprune_beam_free(HMMWork *wrk)
00258 {
00259 free(wrk->OP_calced_score);
00260 free(wrk->OP_calced_id);
00261 free(wrk->mixcalced);
00262 free(wrk->dimthres);
00263 }
00264
00290 void
00291 gprune_beam(HMMWork *wrk, HTK_HMM_Dens **g, int gnum, int *last_id, int lnum)
00292 {
00293 int i, j, num = 0;
00294 LOGPROB score, thres;
00295
00296 if (last_id != NULL) {
00297
00298 clear_dimthres(wrk);
00299
00300 for (j=0; j<lnum; j++) {
00301 i = last_id[j];
00302 #ifdef TEST2
00303 if (!g[i]) {
00304 score = LOG_ZERO;
00305 } else {
00306 score = compute_g_beam_updating(wrk, g[i]);
00307 }
00308 num = cache_push(wrk, i, score, num);
00309 #else
00310 score = compute_g_beam_updating(wrk, g[i]);
00311 num = cache_push(wrk, i, score, num);
00312 #endif
00313 wrk->mixcalced[i] = TRUE;
00314 }
00315
00316 set_dimthres(wrk);
00317
00318
00319 for (i = 0; i < gnum; i++) {
00320
00321 if (wrk->mixcalced[i]) {
00322 wrk->mixcalced[i] = FALSE;
00323 continue;
00324 }
00325 #ifdef TEST2
00326
00327 if (!g[i]) continue;
00328 score = compute_g_beam_pruning(wrk, g[i]);
00329 if (score > LOG_ZERO) {
00330 num = cache_push(wrk, i, score, num);
00331 }
00332 #else
00333
00334 score = compute_g_beam_pruning(wrk, g[i]);
00335 if (score > LOG_ZERO) {
00336 num = cache_push(wrk, i, score, num);
00337 }
00338 #endif
00339 }
00340 } else {
00341
00342
00343 thres = LOG_ZERO;
00344 for (i = 0; i < gnum; i++) {
00345 if (num < wrk->OP_gprune_num) {
00346 score = compute_g_base(wrk, g[i]);
00347 } else {
00348 score = compute_g_safe(wrk, g[i], thres);
00349 if (score <= thres) continue;
00350 }
00351 num = cache_push(wrk, i, score, num);
00352 thres = wrk->OP_calced_score[num-1];
00353 }
00354 }
00355 wrk->OP_calced_num = num;
00356 }