00001
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068 #include <sent/stddefs.h>
00069 #include <sent/htk_hmm.h>
00070 #include <sent/htk_param.h>
00071 #include <sent/hmm.h>
00072 #include <sent/hmm_calc.h>
00073
00074 #undef NORMALIZE_GS_SCORE
00075
00076
00077
00078
00079
00086 static void
00087 build_gsset(HMMWork *wrk)
00088 {
00089 HTK_HMM_State *st;
00090
00091
00092 wrk->gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * wrk->OP_gshmm->totalstatenum);
00093 wrk->gsset_num = wrk->OP_gshmm->totalstatenum;
00094
00095 for(st = wrk->OP_gshmm->ststart; st; st=st->next) {
00096 wrk->gsset[st->id].state = st;
00097 }
00098 }
00099
00106 static void
00107 free_gsset(HMMWork *wrk)
00108 {
00109 free(wrk->gsset);
00110 }
00111
00119 static boolean
00120 build_state2gs(HMMWork *wrk)
00121 {
00122 HTK_HMM_Data *dt;
00123 HTK_HMM_State *st, *cr;
00124 int i;
00125 char gstr[MAX_HMMNAME_LEN], cbuf[MAX_HMMNAME_LEN];
00126 boolean ok_p = TRUE;
00127
00128
00129 wrk->state2gs = (int *)mymalloc(sizeof(int) * wrk->OP_hmminfo->totalstatenum);
00130 for(i=0;i<wrk->OP_hmminfo->totalstatenum;i++) wrk->state2gs[i] = -1;
00131
00132
00133 for(dt = wrk->OP_hmminfo->start; dt; dt=dt->next) {
00134 if (strlen(dt->name) >= MAX_HMMNAME_LEN - 2) {
00135 jlog("Error: gms: too long hmm name (>%d): \"%s\"\n",
00136 MAX_HMMNAME_LEN-3, dt->name);
00137 jlog("Error: gms: change value of MAX_HMMNAME_LEN\n");
00138 ok_p = FALSE;
00139 continue;
00140 }
00141 for(i=1;i<dt->state_num-1;i++) {
00142 st = dt->s[i];
00143
00144 if (wrk->state2gs[st->id] != -1) continue;
00145
00146 sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00147
00148 if ((cr = state_lookup(wrk->OP_gshmm, gstr)) == NULL) {
00149 jlog("Error: gms: GS HMM \"%s\" not defined\n", gstr);
00150 ok_p = FALSE;
00151 continue;
00152 }
00153
00154 wrk->state2gs[st->id] = cr->id;
00155 }
00156 }
00157 #ifdef PARANOIA
00158 {
00159 HTK_HMM_State *st;
00160 for(st=wrk->OP_hmminfo->ststart; st; st=st->next) {
00161 printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00162 (wrk->gsset[wrk->state2gs[st->id]].state)->name);
00163 }
00164 }
00165 #endif
00166 return ok_p;
00167 }
00168
00175 static void
00176 free_state2gs(HMMWork *wrk)
00177 {
00178 free(wrk->state2gs);
00179 }
00180
00181
00182
/*
 * Helper macros for the heap sort over the score index.  They expect the
 * following names to be visible at the point of use:
 *   idx - int array of score indices, addressed with 1-based positions
 *   fs  - score array addressed through idx
 *   s   - int holding the index currently being placed
 * Arguments and expansions are fully parenthesized so that any expression
 * may be passed as a position.
 */
#define SD(A) (idx[(A)-1])          /* index stored at 1-based position A */
#define SCOPY(D,S) ((D) = (S))      /* copy an index value */
#define SVAL(A) (fs[idx[(A)-1]])    /* score of the index at position A */
#define STVAL (fs[s])               /* score of the index held in s */
00187
00188
00194 static void
00195 sort_gsindex_upward(HMMWork *wrk)
00196 {
00197 int n,root,child,parent;
00198 int s;
00199 int *idx;
00200 LOGPROB *fs;
00201 int neednum, totalnum;
00202
00203 idx = wrk->gsindex;
00204 fs = wrk->t_fs;
00205 neednum = wrk->my_nbest;
00206 totalnum = wrk->gsset_num;
00207
00208 for (root = totalnum/2; root >= 1; root--) {
00209 SCOPY(s, SD(root));
00210 parent = root;
00211 while ((child = parent * 2) <= totalnum) {
00212 if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00213 child++;
00214 }
00215 if (STVAL >= SVAL(child)) {
00216 break;
00217 }
00218 SCOPY(SD(parent), SD(child));
00219 parent = child;
00220 }
00221 SCOPY(SD(parent), s);
00222 }
00223 n = totalnum;
00224 while ( n > totalnum - neednum) {
00225 SCOPY(s, SD(n));
00226 SCOPY(SD(n), SD(1));
00227 n--;
00228 parent = 1;
00229 while ((child = parent * 2) <= n) {
00230 if (child < n && SVAL(child) < SVAL(child+1)) {
00231 child++;
00232 }
00233 if (STVAL >= SVAL(child)) {
00234 break;
00235 }
00236 SCOPY(SD(parent), SD(child));
00237 parent = child;
00238 }
00239 SCOPY(SD(parent), s);
00240 }
00241 }
00242
00249 static void
00250 do_gms(HMMWork *wrk)
00251 {
00252 int i;
00253
00254
00255 compute_gs_scores(wrk);
00256
00257 sort_gsindex_upward(wrk);
00258 for(i=wrk->gsset_num - wrk->my_nbest;i<wrk->gsset_num;i++) {
00259
00260 wrk->t_fs[wrk->gsindex[i]] = LOG_ZERO;
00261 }
00262
00263
00264 #ifdef NORMALIZE_GS_SCORE
00265
00266 for(i=0;i<wrk->gsset_num;i++) {
00267 if (wrk->t_fs[i] != LOG_ZERO) {
00268 wrk->t_fs[i] *= 0.975;
00269 }
00270 }
00271 #endif
00272 }
00273
00274
00282 boolean
00283 gms_init(HMMWork *wrk)
00284 {
00285 int i;
00286
00287
00288 if (wrk->OP_gshmm->is_triphone) {
00289 jlog("Error: gms: GS HMM should be a monophone model\n");
00290 return FALSE;
00291 }
00292 if (wrk->OP_gshmm->is_tied_mixture) {
00293 jlog("Error: gms: GS HMM should not be a tied mixture model\n");
00294 return FALSE;
00295 }
00296
00297
00298 build_gsset(wrk);
00299
00300 if (build_state2gs(wrk) == FALSE) {
00301 jlog("Error: gms: failed in assigning GS HMM state for each state\n");
00302 return FALSE;
00303 }
00304 jlog("Stat: gms: GS HMMs are mapped to HMM states\n");
00305
00306
00307 wrk->gsindex = (int *)mymalloc(sizeof(int) * wrk->gsset_num);
00308 for(i=0;i<wrk->gsset_num;i++) wrk->gsindex[i] = i;
00309
00310
00311 wrk->fallback_score = NULL;
00312 wrk->gms_is_selected = NULL;
00313 wrk->gms_allocframenum = -1;
00314
00315
00316 gms_gprune_init(wrk);
00317
00318 return TRUE;
00319 }
00320
00329 boolean
00330 gms_prepare(HMMWork *wrk, int framenum)
00331 {
00332 LOGPROB *tmp;
00333 int t;
00334
00335
00336 if (wrk->gms_allocframenum < framenum) {
00337 if (wrk->fallback_score != NULL) {
00338 free(wrk->fallback_score[0]);
00339 free(wrk->fallback_score);
00340 free(wrk->gms_is_selected);
00341 }
00342 wrk->fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00343 tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * wrk->gsset_num * framenum);
00344 for(t=0;t<framenum;t++) {
00345 wrk->fallback_score[t] = &(tmp[wrk->gsset_num * t]);
00346 }
00347 wrk->gms_is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00348 wrk->gms_allocframenum = framenum;
00349 }
00350
00351 for(t=0;t<framenum;t++) wrk->gms_is_selected[t] = FALSE;
00352
00353
00354 gms_gprune_prepare(wrk);
00355
00356 return TRUE;
00357 }
00358
00365 void
00366 gms_free(HMMWork *wrk)
00367 {
00368 free_gsset(wrk);
00369 free_state2gs(wrk);
00370 free(wrk->gsindex);
00371 if (wrk->fallback_score != NULL) {
00372 free(wrk->fallback_score[0]);
00373 free(wrk->fallback_score);
00374 free(wrk->gms_is_selected);
00375 }
00376 gms_gprune_free(wrk);
00377 }
00378
00379
00380
00393 LOGPROB
00394 gms_state(HMMWork *wrk)
00395 {
00396 LOGPROB gsprob;
00397 if (wrk->OP_last_time != wrk->OP_time) {
00398
00399 wrk->t_fs = wrk->fallback_score[wrk->OP_time];
00400
00401 if (!wrk->gms_is_selected[wrk->OP_time]) {
00402 do_gms(wrk);
00403 wrk->gms_is_selected[wrk->OP_time] = TRUE;
00404 }
00405 }
00406 if ((gsprob = wrk->t_fs[wrk->state2gs[wrk->OP_state_id]]) != LOG_ZERO) {
00407
00408 return(gsprob);
00409 }
00410
00411 return((*(wrk->calc_outprob))(wrk));
00412 }