00001
00044
00045
00046
00047
00048
00049
00050
00051 #include <sent/stddefs.h>
00052 #include <sent/htk_hmm.h>
00053 #include <sent/htk_param.h>
00054 #include <sent/hmm.h>
00055 #include <sent/hmm_calc.h>
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104
00111 static void
00112 init_backmax(HMMWork *wrk)
00113 {
00114 int i;
00115 for(i=0;i<wrk->backmax_num;i++) wrk->backmax[i] = 0;
00116 }
00117
00124
00125
00126
00127
00128
00129
00130
00131
00132 static void
00133 make_backmax(HMMWork *wrk)
00134 {
00135 int i;
00136
00137 wrk->backmax[wrk->backmax_num-1] = 0.0;
00138
00139 for(i=wrk->backmax_num-2;i>=0;i--) {
00140 wrk->backmax[i] += wrk->backmax[i+1];
00141 }
00142
00143
00144
00145 }
00146
00160 static LOGPROB
00161 compute_g_heu_updating(HMMWork *wrk, HTK_HMM_Dens *binfo)
00162 {
00163 VECT tmp, x, sum = 0.0;
00164 VECT *mean;
00165 VECT *var;
00166 VECT *bm = wrk->backmax;
00167 VECT *vec = wrk->OP_vec;
00168 short veclen = wrk->OP_veclen;
00169
00170 if (binfo == NULL) return(LOG_ZERO);
00171 mean = binfo->mean;
00172 var = binfo->var->vec;
00173
00174 tmp = 0.0;
00175 for (; veclen > 0; veclen--) {
00176 x = *(vec++) - *(mean++);
00177 tmp = x * x * *(var++);
00178 sum += tmp;
00179 if ( *bm < tmp) *bm = tmp;
00180 bm++;
00181 }
00182 return((sum + binfo->gconst) * -0.5);
00183 }
00184
00199 static LOGPROB
00200 compute_g_heu_pruning(HMMWork *wrk, HTK_HMM_Dens *binfo, LOGPROB thres)
00201 {
00202 VECT tmp, x;
00203 VECT *mean;
00204 VECT *var;
00205 VECT *bm = wrk->backmax;
00206 VECT *vec = wrk->OP_vec;
00207 short veclen = wrk->OP_veclen;
00208 LOGPROB fthres;
00209
00210 if (binfo == NULL) return(LOG_ZERO);
00211 mean = binfo->mean;
00212 var = binfo->var->vec;
00213 fthres = thres * (-2.0);
00214
00215 tmp = 0.0;
00216 bm++;
00217 for (; veclen > 0; veclen--) {
00218 x = *(vec++) - *(mean++);
00219 tmp += x * x * *(var++);
00220 if ( tmp + *bm > fthres) {
00221 return LOG_ZERO;
00222 }
00223 bm++;
00224 }
00225 return((tmp + binfo->gconst) * -0.5);
00226 }
00227
00228
00236 boolean
00237 gprune_heu_init(HMMWork *wrk)
00238 {
00239 int i;
00240
00241 wrk->OP_calced_maxnum = wrk->OP_hmminfo->maxmixturenum;
00242 wrk->OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->OP_gprune_num);
00243 wrk->OP_calced_id = (int *)mymalloc(sizeof(int) * wrk->OP_gprune_num);
00244 wrk->mixcalced = (boolean *)mymalloc(sizeof(int) * wrk->OP_calced_maxnum);
00245 for(i=0;i<wrk->OP_calced_maxnum;i++) wrk->mixcalced[i] = FALSE;
00246 wrk->backmax_num = wrk->OP_hmminfo->opt.vec_size + 1;
00247 wrk->backmax = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->backmax_num);
00248
00249 return TRUE;
00250 }
00251
00258 void
00259 gprune_heu_free(HMMWork *wrk)
00260 {
00261 free(wrk->OP_calced_score);
00262 free(wrk->OP_calced_id);
00263 free(wrk->mixcalced);
00264 free(wrk->backmax);
00265 }
00266
00292 void
00293 gprune_heu(HMMWork *wrk, HTK_HMM_Dens **g, int gnum, int *last_id)
00294 {
00295 int i, j, num = 0;
00296 LOGPROB score, thres;
00297
00298 if (last_id != NULL) {
00299
00300 init_backmax(wrk);
00301
00302 for (j=0; j<wrk->OP_gprune_num; j++) {
00303 i = last_id[j];
00304 score = compute_g_heu_updating(wrk, g[i]);
00305 num = cache_push(wrk, i, score, num);
00306 wrk->mixcalced[i] = TRUE;
00307 }
00308
00309 make_backmax(wrk);
00310
00311 thres = wrk->OP_calced_score[num-1];
00312 for (i = 0; i < gnum; i++) {
00313
00314 if (wrk->mixcalced[i]) {
00315 wrk->mixcalced[i] = FALSE;
00316 continue;
00317 }
00318
00319 score = compute_g_heu_pruning(wrk, g[i], thres);
00320 if (score > LOG_ZERO) {
00321 num = cache_push(wrk, i, score, num);
00322 thres = wrk->OP_calced_score[num-1];
00323 }
00324 }
00325 } else {
00326
00327
00328 thres = LOG_ZERO;
00329 for (i = 0; i < gnum; i++) {
00330 if (num < wrk->OP_gprune_num) {
00331 score = compute_g_base(wrk, g[i]);
00332 } else {
00333 score = compute_g_safe(wrk, g[i], thres);
00334 if (score <= thres) continue;
00335 }
00336 num = cache_push(wrk, i, score, num);
00337 thres = wrk->OP_calced_score[num-1];
00338 }
00339 }
00340 wrk->OP_calced_num = num;
00341 }