00001
00046
00047
00048
00049
00050
00051
00052
00053 #include <sent/stddefs.h>
00054 #include <sent/htk_hmm.h>
00055 #include <sent/htk_param.h>
00056 #include <sent/hmm.h>
00057 #include <sent/hmm_calc.h>
00058
00059 #define TEST2
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00105
00106
00107
00108
00109
00110
00111
00112
00113
00114
00115
00116
00117
00124 static void
00125 clear_dimthres(HMMWork *wrk)
00126 {
00127 int i;
00128 for(i=0;i<wrk->dimthres_num;i++) wrk->dimthres[i] = 0.0;
00129 }
00130
00138 static void
00139 set_dimthres(HMMWork *wrk)
00140 {
00141 int i;
00142 for(i=0;i<wrk->dimthres_num;i++) wrk->dimthres[i] += TMBEAMWIDTH;
00143 }
00144
00158 static LOGPROB
00159 compute_g_beam_updating(HMMWork *wrk, HTK_HMM_Dens *binfo)
00160 {
00161 VECT tmp, x;
00162 VECT *mean;
00163 VECT *var;
00164 VECT *th = wrk->dimthres;
00165 VECT *vec = wrk->OP_vec;
00166 short veclen = wrk->OP_veclen;
00167
00168 #ifndef TEST2
00169 if (binfo == NULL) return(LOG_ZERO);
00170 #endif
00171
00172 mean = binfo->mean;
00173 var = binfo->var->vec;
00174
00175 tmp = 0.0;
00176 for (; veclen > 0; veclen--) {
00177 x = *(vec++) - *(mean++);
00178 tmp += x * x * *(var++);
00179 if ( *th < tmp) *th = tmp;
00180 th++;
00181 }
00182 return((tmp + binfo->gconst) * -0.5);
00183 }
00184
00198 static LOGPROB
00199 compute_g_beam_pruning(HMMWork *wrk, HTK_HMM_Dens *binfo)
00200 {
00201 VECT tmp, x;
00202 VECT *mean;
00203 VECT *var;
00204 VECT *th = wrk->dimthres;
00205 VECT *vec = wrk->OP_vec;
00206 short veclen = wrk->OP_veclen;
00207
00208 #ifndef TEST2
00209 if (binfo == NULL) return(LOG_ZERO);
00210 #endif
00211 mean = binfo->mean;
00212 var = binfo->var->vec;
00213
00214 tmp = 0.0;
00215 for (; veclen > 0; veclen--) {
00216 x = *(vec++) - *(mean++);
00217 tmp += x * x * *(var++);
00218 if ( tmp > *(th++)) {
00219 return LOG_ZERO;
00220 }
00221 }
00222 return((tmp + binfo->gconst) * -0.5);
00223 }
00224
00225
00233 boolean
00234 gprune_beam_init(HMMWork *wrk)
00235 {
00236 int i;
00237
00238 wrk->OP_calced_maxnum = wrk->OP_hmminfo->maxmixturenum;
00239 wrk->OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->OP_gprune_num);
00240 wrk->OP_calced_id = (int *)mymalloc(sizeof(int) * wrk->OP_gprune_num);
00241 wrk->mixcalced = (boolean *)mymalloc(sizeof(int) * wrk->OP_calced_maxnum);
00242 for(i=0;i<wrk->OP_calced_maxnum;i++) wrk->mixcalced[i] = FALSE;
00243 wrk->dimthres_num = wrk->OP_hmminfo->opt.vec_size;
00244 wrk->dimthres = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->dimthres_num);
00245
00246 return TRUE;
00247 }
00248
00255 void
00256 gprune_beam_free(HMMWork *wrk)
00257 {
00258 free(wrk->OP_calced_score);
00259 free(wrk->OP_calced_id);
00260 free(wrk->mixcalced);
00261 free(wrk->dimthres);
00262 }
00263
00288 void
00289 gprune_beam(HMMWork *wrk, HTK_HMM_Dens **g, int gnum, int *last_id)
00290 {
00291 int i, j, num = 0;
00292 LOGPROB score, thres;
00293
00294 if (last_id != NULL) {
00295
00296 clear_dimthres(wrk);
00297
00298 for (j=0; j<wrk->OP_gprune_num; j++) {
00299 i = last_id[j];
00300 #ifdef TEST2
00301 if (!g[i]) {
00302 score = LOG_ZERO;
00303 } else {
00304 score = compute_g_beam_updating(wrk, g[i]);
00305 }
00306 num = cache_push(wrk, i, score, num);
00307 #else
00308 score = compute_g_beam_updating(wrk, g[i]);
00309 num = cache_push(wrk, i, score, num);
00310 #endif
00311 wrk->mixcalced[i] = TRUE;
00312 }
00313
00314 set_dimthres(wrk);
00315
00316
00317 for (i = 0; i < gnum; i++) {
00318
00319 if (wrk->mixcalced[i]) {
00320 wrk->mixcalced[i] = FALSE;
00321 continue;
00322 }
00323 #ifdef TEST2
00324
00325 if (!g[i]) continue;
00326 score = compute_g_beam_pruning(wrk, g[i]);
00327 if (score > LOG_ZERO) {
00328 num = cache_push(wrk, i, score, num);
00329 }
00330 #else
00331
00332 score = compute_g_beam_pruning(wrk, g[i]);
00333 if (score > LOG_ZERO) {
00334 num = cache_push(wrk, i, score, num);
00335 }
00336 #endif
00337 }
00338 } else {
00339
00340
00341 thres = LOG_ZERO;
00342 for (i = 0; i < gnum; i++) {
00343 if (num < wrk->OP_gprune_num) {
00344 score = compute_g_base(wrk, g[i]);
00345 } else {
00346 score = compute_g_safe(wrk, g[i], thres);
00347 if (score <= thres) continue;
00348 }
00349 num = cache_push(wrk, i, score, num);
00350 thres = wrk->OP_calced_score[num-1];
00351 }
00352 }
00353 wrk->OP_calced_num = num;
00354 }