00001
00052
00053
00054
00055
00056
00057
00058
00059 #include <sent/stddefs.h>
00060 #include <sent/speech.h>
00061 #include <sent/htk_hmm.h>
00062 #include <sent/htk_param.h>
00063 #include <sent/hmm.h>
00064 #include <sent/hmm_calc.h>
00065
00066
00067
/* Sentinel stored in the output-probability cache for "not computed yet".
   It is one below LOG_ZERO, so it differs from a computed value of exactly
   LOG_ZERO, which is then still treated as a valid cached result. */
#define LOG_UNDEF ((LOG_ZERO) - 1)
00069
00070
00077 boolean
00078 outprob_cache_init(HMMWork *wrk)
00079 {
00080 wrk->statenum = wrk->OP_hmminfo->totalstatenum;
00081 wrk->outprob_cache = NULL;
00082 wrk->outprob_allocframenum = 0;
00083 wrk->OP_time = -1;
00084 wrk->croot = NULL;
00085 return TRUE;
00086 }
00087
00095 boolean
00096 outprob_cache_prepare(HMMWork *wrk)
00097 {
00098 int s,t;
00099
00100
00101 for (t = 0; t < wrk->outprob_allocframenum; t++) {
00102 for (s = 0; s < wrk->statenum; s++) {
00103 wrk->outprob_cache[t][s] = LOG_UNDEF;
00104 }
00105 }
00106
00107 return TRUE;
00108 }
00109
00116 static void
00117 outprob_cache_extend(HMMWork *wrk, int reqframe)
00118 {
00119 int newnum;
00120 int size;
00121 int t, s;
00122 LOGPROB *tmpp;
00123
00124
00125 if (reqframe < wrk->outprob_allocframenum) return;
00126
00127
00128 newnum = reqframe + 1;
00129 if (newnum < wrk->outprob_allocframenum + OUTPROB_CACHE_PERIOD) newnum = wrk->outprob_allocframenum + OUTPROB_CACHE_PERIOD;
00130 size = (newnum - wrk->outprob_allocframenum) * wrk->statenum;
00131
00132
00133 if (wrk->outprob_cache == NULL) {
00134 wrk->outprob_cache = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * newnum);
00135 } else {
00136 wrk->outprob_cache = (LOGPROB **)myrealloc(wrk->outprob_cache, sizeof(LOGPROB *) * newnum);
00137 }
00138 tmpp = (LOGPROB *)mybmalloc2(sizeof(LOGPROB) * size, &(wrk->croot));
00139
00140 for(t = wrk->outprob_allocframenum; t < newnum; t++) {
00141 wrk->outprob_cache[t] = &(tmpp[(t - wrk->outprob_allocframenum) * wrk->statenum]);
00142 for (s = 0; s < wrk->statenum; s++) {
00143 wrk->outprob_cache[t][s] = LOG_UNDEF;
00144 }
00145 }
00146
00147
00148 wrk->outprob_allocframenum = newnum;
00149 }
00150
00157 void
00158 outprob_cache_free(HMMWork *wrk)
00159 {
00160 if (wrk->croot != NULL) mybfree2(&(wrk->croot));
00161 if (wrk->outprob_cache != NULL) free(wrk->outprob_cache);
00162 }
00163
00164
00183 LOGPROB
00184 outprob_state(HMMWork *wrk, int t, HTK_HMM_State *stateinfo, HTK_Param *param)
00185 {
00186 LOGPROB outp;
00187 int sid;
00188 int i, d;
00189
00190 sid = stateinfo->id;
00191
00192
00193 wrk->OP_state = stateinfo;
00194 wrk->OP_state_id = sid;
00195 wrk->OP_param = param;
00196 if (wrk->OP_time != t) {
00197 wrk->OP_last_time = wrk->OP_time;
00198 wrk->OP_time = t;
00199 for(d=0,i=0;i<wrk->OP_nstream;i++) {
00200 wrk->OP_vec_stream[i] = &(param->parvec[t][d]);
00201 d += wrk->OP_veclen_stream[i];
00202 }
00203
00204 outprob_cache_extend(wrk, t);
00205 wrk->last_cache = wrk->outprob_cache[t];
00206 }
00207
00208
00209 if ((outp = wrk->last_cache[sid]) == LOG_UNDEF) {
00210 outp = wrk->last_cache[sid] = (*(wrk->calc_outprob_state))(wrk);
00211 }
00212 return(outp);
00213 }
00214
00221 void
00222 outprob_cd_nbest_init(HMMWork *wrk, int num)
00223 {
00224 wrk->cd_nbest_maxprobs = (LOGPROB *)mymalloc(sizeof(LOGPROB) * num);
00225 wrk->cd_nbest_maxn = num;
00226 }
00227
00234 void
00235 outprob_cd_nbest_free(HMMWork *wrk)
00236 {
00237 free(wrk->cd_nbest_maxprobs);
00238 }
00239
00250 static LOGPROB
00251 outprob_cd_nbest(HMMWork *wrk, int t, CD_State_Set *lset, HTK_Param *param)
00252 {
00253 LOGPROB prob;
00254 int i, k, n;
00255
00256 n = 0;
00257 for(i=0;i<lset->num;i++) {
00258 prob = outprob_state(wrk, t, lset->s[i], param);
00259
00260 if (prob <= LOG_ZERO) continue;
00261 if (n == 0 || prob <= wrk->cd_nbest_maxprobs[n-1]) {
00262 if (n == wrk->cd_nbest_maxn) continue;
00263 wrk->cd_nbest_maxprobs[n] = prob;
00264 n++;
00265 } else {
00266 for(k=0; k<n; k++) {
00267 if (prob > wrk->cd_nbest_maxprobs[k]) {
00268 memmove(&(wrk->cd_nbest_maxprobs[k+1]), &(wrk->cd_nbest_maxprobs[k]),
00269 sizeof(LOGPROB) * (n - k - ( (n == wrk->cd_nbest_maxn) ? 1 : 0)));
00270 wrk->cd_nbest_maxprobs[k] = prob;
00271 break;
00272 }
00273 }
00274 if (n < wrk->cd_nbest_maxn) n++;
00275 }
00276 }
00277 prob = 0.0;
00278 for(i=0;i<n;i++) {
00279
00280 prob += wrk->cd_nbest_maxprobs[i];
00281 }
00282 return(prob/(float)n);
00283 }
00284
00295 static LOGPROB
00296 outprob_cd_max(HMMWork *wrk, int t, CD_State_Set *lset, HTK_Param *param)
00297 {
00298 LOGPROB maxprob, prob;
00299 int i;
00300
00301 maxprob = LOG_ZERO;
00302 for(i=0;i<lset->num;i++) {
00303 prob = outprob_state(wrk, t, lset->s[i], param);
00304 if (maxprob < prob) maxprob = prob;
00305 }
00306 return(maxprob);
00307 }
00308
00319 static LOGPROB
00320 outprob_cd_avg(HMMWork *wrk, int t, CD_State_Set *lset, HTK_Param *param)
00321 {
00322 LOGPROB sum, p;
00323 int i,j;
00324 sum = 0.0;
00325 j = 0;
00326 for(i=0;i<lset->num;i++) {
00327 p = outprob_state(wrk, t, lset->s[i], param);
00328 if (p > LOG_ZERO) {
00329 sum += p;
00330 j++;
00331 }
00332 }
00333 return(sum/(float)j);
00334 }
00335
00346 LOGPROB
00347 outprob_cd(HMMWork *wrk, int t, CD_State_Set *lset, HTK_Param *param)
00348 {
00349 LOGPROB ret;
00350
00351
00352 switch(wrk->OP_hmminfo->cdset_method) {
00353 case IWCD_AVG:
00354 ret = outprob_cd_avg(wrk, t, lset, param);
00355 break;
00356 case IWCD_MAX:
00357 ret = outprob_cd_max(wrk, t, lset, param);
00358 break;
00359 case IWCD_NBEST:
00360 ret = outprob_cd_nbest(wrk, t, lset, param);
00361 break;
00362 }
00363 return(ret);
00364 }
00365
00366
00377 LOGPROB
00378 outprob(HMMWork *wrk, int t, HMM_STATE *hmmstate, HTK_Param *param)
00379 {
00380 if (hmmstate->is_pseudo_state) {
00381 return(outprob_cd(wrk, t, hmmstate->out.cdset, param));
00382 } else {
00383 return(outprob_state(wrk, t, hmmstate->out.state, param));
00384 }
00385 }