00001
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067 #include <sent/stddefs.h>
00068 #include <sent/htk_hmm.h>
00069 #include <sent/htk_param.h>
00070 #include <sent/hmm.h>
00071 #include <sent/gprune.h>
00072 #include "globalvars.h"
00073
00074 #undef NORMALIZE_GS_SCORE
00075
00076
00077
00078
00079
00080 static int my_nbest;
00081 static int allocframenum;
00082
00083
00084 static GS_SET *gsset;
00085 static int gsset_num;
00086 static int *state2gs;
00087
00088
00089 static boolean *is_selected;
00090 static LOGPROB **fallback_score = NULL;
00091
00092
00093 static int *gsindex;
00094 static LOGPROB *t_fs;
00095
00096
00101 static void
00102 build_gsset()
00103 {
00104 HTK_HMM_State *st;
00105
00106
00107 gsset = (GS_SET *)mymalloc(sizeof(GS_SET) * OP_gshmm->totalstatenum);
00108 gsset_num = OP_gshmm->totalstatenum;
00109
00110 for(st = OP_gshmm->ststart; st; st=st->next) {
00111 gsset[st->id].state = st;
00112 }
00113 }
00114
00115 #define MAXHMMNAMELEN 40
00116
00117
00122 static boolean
00123 build_state2gs()
00124 {
00125 HTK_HMM_Data *dt;
00126 HTK_HMM_State *st, *cr;
00127 int i;
00128 char gstr[MAXHMMNAMELEN], cbuf[MAXHMMNAMELEN];
00129 boolean ok_p = TRUE;
00130
00131
00132 state2gs = (int *)mymalloc(sizeof(int) * OP_hmminfo->totalstatenum);
00133 for(i=0;i<OP_hmminfo->totalstatenum;i++) state2gs[i] = -1;
00134
00135
00136 for(dt = OP_hmminfo->start; dt; dt=dt->next) {
00137 if (strlen(dt->name) >= MAXHMMNAMELEN - 2) {
00138 j_printerr("Error: too long hmm name (>%d): \"%s\"\n",
00139 MAXHMMNAMELEN-3, dt->name);
00140 ok_p = FALSE;
00141 continue;
00142 }
00143 for(i=1;i<dt->state_num-1;i++) {
00144 st = dt->s[i];
00145
00146 if (state2gs[st->id] != -1) continue;
00147
00148 sprintf(gstr, "%s%dm", center_name(dt->name, cbuf), i + 1);
00149
00150 if ((cr = state_lookup(OP_gshmm, gstr)) == NULL) {
00151 j_printerr("Error: GS HMM \"%s\" not defined\n", gstr);
00152 ok_p = FALSE;
00153 continue;
00154 }
00155
00156 state2gs[st->id] = cr->id;
00157 }
00158 }
00159 #ifdef PARANOIA
00160 {
00161 HTK_HMM_State *st;
00162 for(st=OP_hmminfo->ststart; st; st=st->next) {
00163 printf("%s -> %s\n", (st->name == NULL) ? "(NULL)" : st->name,
00164 (gsset[state2gs[st->id]].state)->name);
00165 }
00166 }
00167 #endif
00168 return ok_p;
00169 }
00170
00171
00172
00173 #define SD(A) gsindex[A-1]
00174 #define SCOPY(D,S) D = S
00175 #define SVAL(A) (t_fs[gsindex[A-1]])
00176 #define STVAL (t_fs[s])
00177
00178
00184 static void
00185 sort_gsindex_upward(int neednum, int totalnum)
00186 {
00187 int n,root,child,parent;
00188 int s;
00189 for (root = totalnum/2; root >= 1; root--) {
00190 SCOPY(s, SD(root));
00191 parent = root;
00192 while ((child = parent * 2) <= totalnum) {
00193 if (child < totalnum && SVAL(child) < SVAL(child+1)) {
00194 child++;
00195 }
00196 if (STVAL >= SVAL(child)) {
00197 break;
00198 }
00199 SCOPY(SD(parent), SD(child));
00200 parent = child;
00201 }
00202 SCOPY(SD(parent), s);
00203 }
00204 n = totalnum;
00205 while ( n > totalnum - neednum) {
00206 SCOPY(s, SD(n));
00207 SCOPY(SD(n), SD(1));
00208 n--;
00209 parent = 1;
00210 while ((child = parent * 2) <= n) {
00211 if (child < n && SVAL(child) < SVAL(child+1)) {
00212 child++;
00213 }
00214 if (STVAL >= SVAL(child)) {
00215 break;
00216 }
00217 SCOPY(SD(parent), SD(child));
00218 parent = child;
00219 }
00220 SCOPY(SD(parent), s);
00221 }
00222 }
00223
00228 static void
00229 do_gms()
00230 {
00231 int i;
00232
00233
00234 compute_gs_scores(gsset, gsset_num, t_fs);
00235
00236 sort_gsindex_upward(my_nbest, gsset_num);
00237 for(i=gsset_num - my_nbest;i<gsset_num;i++) {
00238
00239 t_fs[gsindex[i]] = LOG_ZERO;
00240 }
00241
00242
00243 #ifdef NORMALIZE_GS_SCORE
00244
00245 for(i=0;i<gsset_num;i++) {
00246 if (t_fs[i] != LOG_ZERO) {
00247 t_fs[i] = t_fs[i] * 0.975;
00248 }
00249 }
00250 #endif
00251 }
00252
00253
00261 boolean
00262 gms_init(int nbest)
00263 {
00264 int i;
00265
00266
00267 if (OP_gshmm->is_triphone) {
00268 j_printerr("Error: GS HMM should be a monophone model\n");
00269 return FALSE;
00270 }
00271 if (OP_gshmm->is_tied_mixture) {
00272 j_printerr("Error: GS HMM should not be a tied mixture model\n");
00273 return FALSE;
00274 }
00275
00276
00277 my_nbest = nbest;
00278
00279
00280 build_gsset();
00281
00282 j_printerr("Mapping HMM states to GS HMM...");
00283 if (build_state2gs() == FALSE) {
00284 j_printerr("Error: failed in assigning GS HMM state for each state\n");
00285 return FALSE;
00286 }
00287 j_printerr("done\n");
00288
00289
00290 gsindex = (int *)mymalloc(sizeof(int) * gsset_num);
00291 for(i=0;i<gsset_num;i++) gsindex[i] = i;
00292
00293
00294 fallback_score = NULL;
00295 is_selected = NULL;
00296 allocframenum = -1;
00297
00298
00299 gms_gprune_init(OP_hmminfo, gsset_num);
00300
00301 return TRUE;
00302 }
00303
00311 boolean
00312 gms_prepare(int framenum)
00313 {
00314 LOGPROB *tmp;
00315 int t;
00316
00317
00318 if (allocframenum < framenum) {
00319 if (fallback_score != NULL) {
00320 free(fallback_score[0]);
00321 free(fallback_score);
00322 free(is_selected);
00323 }
00324 fallback_score = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * framenum);
00325 tmp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gsset_num * framenum);
00326 for(t=0;t<framenum;t++) {
00327 fallback_score[t] = &(tmp[gsset_num * t]);
00328 }
00329 is_selected = (boolean *)mymalloc(sizeof(boolean) * framenum);
00330 allocframenum = framenum;
00331 }
00332
00333 for(t=0;t<framenum;t++) is_selected[t] = FALSE;
00334
00335
00336 gms_gprune_prepare();
00337
00338 return TRUE;
00339 }
00340
00351 LOGPROB
00352 gms_state()
00353 {
00354 LOGPROB gsprob;
00355 if (OP_last_time != OP_time) {
00356
00357 t_fs = fallback_score[OP_time];
00358
00359 if (!is_selected[OP_time]) {
00360 do_gms();
00361 is_selected[OP_time] = TRUE;
00362 }
00363 }
00364 if ((gsprob = t_fs[state2gs[OP_state_id]]) != LOG_ZERO) {
00365
00366 return(gsprob);
00367 }
00368
00369 return(calc_outprob());
00370 }