00001
00051
00052
00053
00054
00055
00056
00057
00058 #include <sent/stddefs.h>
00059 #include <sent/speech.h>
00060 #include <sent/htk_hmm.h>
00061 #include <sent/htk_param.h>
00062 #include <sent/hmm.h>
00063 #include <sent/gprune.h>
00064 #include "globalvars.h"
00065
00066
00067
00068 static int statenum;
00069 static LOGPROB **outprob_cache = NULL;
00070 static int allocframenum;
00071 static int allocblock;
00072 static LOGPROB *last_cache;
00073 #define LOG_UNDEF (LOG_ZERO - 1)
00074
00075
00080 boolean
00081 outprob_cache_init()
00082 {
00083 statenum = OP_hmminfo->totalstatenum;
00084 outprob_cache = NULL;
00085 allocframenum = 0;
00086 allocblock = OUTPROB_CACHE_PERIOD;
00087 OP_time = -1;
00088 return TRUE;
00089 }
00090
00097 boolean
00098 outprob_cache_prepare()
00099 {
00100 int s,t;
00101
00102
00103 for (t = 0; t < allocframenum; t++) {
00104 for (s = 0; s < statenum; s++) {
00105 outprob_cache[t][s] = LOG_UNDEF;
00106 }
00107 }
00108
00109 return TRUE;
00110 }
00111
00117 static void
00118 outprob_cache_extend(int reqframe)
00119 {
00120 int newnum;
00121 int size;
00122 int t, s;
00123 LOGPROB *tmpp;
00124
00125
00126 if (reqframe < allocframenum) return;
00127
00128
00129 newnum = reqframe + 1;
00130 if (newnum < allocframenum + allocblock) newnum = allocframenum + allocblock;
00131 size = (newnum - allocframenum) * statenum;
00132
00133
00134 if (outprob_cache == NULL) {
00135 outprob_cache = (LOGPROB **)mymalloc(sizeof(LOGPROB *) * newnum);
00136 } else {
00137 outprob_cache = (LOGPROB **)myrealloc(outprob_cache, sizeof(LOGPROB *) * newnum);
00138 }
00139 tmpp = (LOGPROB *)mymalloc(sizeof(LOGPROB) * size);
00140
00141 for(t = allocframenum; t < newnum; t++) {
00142 outprob_cache[t] = &(tmpp[(t - allocframenum) * statenum]);
00143 for (s = 0; s < statenum; s++) {
00144 outprob_cache[t][s] = LOG_UNDEF;
00145 }
00146 }
00147
00148
00149 allocframenum = newnum;
00150 }
00151
00152
00166 LOGPROB
00167 outprob_state(
00168 int t,
00169 HTK_HMM_State *stateinfo,
00170 HTK_Param *param)
00171 {
00172 LOGPROB outp;
00173
00174
00175 OP_state = stateinfo;
00176 OP_state_id = stateinfo->id;
00177 OP_param = param;
00178 if (OP_time != t) {
00179 OP_last_time = OP_time;
00180 OP_time = t;
00181 OP_vec = param->parvec[t];
00182 OP_veclen = param->veclen;
00183
00184 outprob_cache_extend(t);
00185 last_cache = outprob_cache[t];
00186 }
00187
00188
00189 if ((outp = last_cache[OP_state_id]) == LOG_UNDEF) {
00190 outp = last_cache[OP_state_id] = calc_outprob_state();
00191 }
00192 return(outp);
00193 }
00194
00195 static LOGPROB *maxprobs;
00196 static int maxn;
00197
00203 void
00204 outprob_cd_nbest_init(int num)
00205 {
00206 maxprobs = (LOGPROB *)mymalloc(sizeof(LOGPROB) * num);
00207 maxn = num;
00208 }
00209
00219 static LOGPROB
00220 outprob_cd_nbest(int t, CD_State_Set *lset, HTK_Param *param)
00221 {
00222 LOGPROB prob;
00223 int i, k, n;
00224
00225 n = 0;
00226 for(i=0;i<lset->num;i++) {
00227 prob = outprob_state(t, lset->s[i], param);
00228
00229 if (prob <= LOG_ZERO) continue;
00230 if (n == 0 || prob <= maxprobs[n-1]) {
00231 if (n == maxn) continue;
00232 maxprobs[n] = prob;
00233 n++;
00234 } else {
00235 for(k=0; k<n; k++) {
00236 if (prob > maxprobs[k]) {
00237 memmove(&(maxprobs[k+1]), &(maxprobs[k]),
00238 sizeof(LOGPROB) * (n - k - ( (n == maxn) ? 1 : 0)));
00239 maxprobs[k] = prob;
00240 break;
00241 }
00242 }
00243 if (n < maxn) n++;
00244 }
00245 }
00246 prob = 0.0;
00247 for(i=0;i<n;i++) {
00248
00249 prob += maxprobs[i];
00250 }
00251 return(prob/(float)n);
00252 }
00253
00263 static LOGPROB
00264 outprob_cd_max(int t, CD_State_Set *lset, HTK_Param *param)
00265 {
00266 LOGPROB maxprob, prob;
00267 int i;
00268 maxprob = LOG_ZERO;
00269 for(i=0;i<lset->num;i++) {
00270 prob = outprob_state(t, lset->s[i], param);
00271 if (maxprob < prob) maxprob = prob;
00272 }
00273 return(maxprob);
00274 }
00275
00285 static LOGPROB
00286 outprob_cd_avg(int t, CD_State_Set *lset, HTK_Param *param)
00287 {
00288 LOGPROB sum, p;
00289 int i,j;
00290 sum = 0.0;
00291 j = 0;
00292 for(i=0;i<lset->num;i++) {
00293 p = outprob_state(t, lset->s[i], param);
00294 if (p > LOG_ZERO) {
00295 sum += p;
00296 j++;
00297 }
00298 }
00299 return(sum/(float)j);
00300 }
00301
00311 LOGPROB
00312 outprob_cd(int t, CD_State_Set *lset, HTK_Param *param)
00313 {
00314 LOGPROB ret;
00315
00316
00317 switch(OP_hmminfo->cdset_method) {
00318 case IWCD_AVG:
00319 ret = outprob_cd_avg(t, lset, param);
00320 break;
00321 case IWCD_MAX:
00322 ret = outprob_cd_max(t, lset, param);
00323 break;
00324 case IWCD_NBEST:
00325 ret = outprob_cd_nbest(t, lset, param);
00326 break;
00327 default:
00328 j_error("unknown cdhmm method!\n");
00329 ret = 0;
00330 break;
00331 }
00332 return(ret);
00333 }
00334
00335
00345 LOGPROB
00346 outprob(int t, HMM_STATE *hmmstate, HTK_Param *param)
00347 {
00348 if (hmmstate->is_pseudo_state) {
00349 return(outprob_cd(t, hmmstate->out.cdset, param));
00350 } else {
00351 return(outprob_state(t, hmmstate->out.state, param));
00352 }
00353 }