00001
00041
00042
00043
00044
00045
00046
00047 #include <julius.h>
00048
00049 #undef MES
00050
00051 static LOGPROB *gmm_score; /* Per-model score accumulated by gmm_proceed(); allocated in gmm_init(), zeroed in gmm_prepare() */
00052 static int framecount; /* Number of frames passed to gmm_proceed() since the last gmm_prepare() */
00053
00054
00055 static LOGPROB *OP_calced_score; /* Scores of the Gaussians kept by pruning, held in descending order */
00056 static int *OP_calced_id; /* Mixture index of each entry in OP_calced_score */
00057 static int OP_calced_num; /* Number of valid entries in the two arrays above, set by gmm_gprune_safe() */
00058 static int OP_calced_maxnum; /* Copy of hmminfo->maxmixturenum taken at init; not read elsewhere in the code visible here */
00059 static int OP_gprune_num; /* Maximum number of Gaussians kept per state; also the allocated length of the cache arrays */
00060 static VECT *OP_vec; /* Input vector of the frame currently being scored */
00061 static short OP_veclen; /* Number of elements in OP_vec */
00062
00082 static int
00083 gmm_find_insert_point(LOGPROB score, int len)
00084 {
00085
00086 int left = 0;
00087 int right = len - 1;
00088 int mid;
00089
00090 while (left < right) {
00091 mid = (left + right) / 2;
00092 if (OP_calced_score[mid] > score) {
00093 left = mid + 1;
00094 } else {
00095 right = mid;
00096 }
00097 }
00098 return(left);
00099 }
00100
00121 static int
00122 gmm_cache_push(int id, LOGPROB score, int len)
00123 {
00124 int insertp;
00125
00126 if (len == 0) {
00127 OP_calced_score[0] = score;
00128 OP_calced_id[0] = id;
00129 return(1);
00130 }
00131 if (OP_calced_score[len-1] >= score) {
00132 if (len < OP_gprune_num) {
00133 OP_calced_score[len] = score;
00134 OP_calced_id[len] = id;
00135 len++;
00136 }
00137 return len;
00138 }
00139 if (OP_calced_score[0] < score) {
00140 insertp = 0;
00141 } else {
00142 insertp = gmm_find_insert_point(score, len);
00143 }
00144 if (len < OP_gprune_num) {
00145 memmove(&(OP_calced_score[insertp+1]), &(OP_calced_score[insertp]), sizeof(LOGPROB)*(len - insertp));
00146 memmove(&(OP_calced_id[insertp+1]), &(OP_calced_id[insertp]), sizeof(int)*(len - insertp));
00147 } else if (insertp < len - 1) {
00148 memmove(&(OP_calced_score[insertp+1]), &(OP_calced_score[insertp]), sizeof(LOGPROB)*(len - insertp - 1));
00149 memmove(&(OP_calced_id[insertp+1]), &(OP_calced_id[insertp]), sizeof(int)*(len - insertp - 1));
00150 }
00151 OP_calced_score[insertp] = score;
00152 OP_calced_id[insertp] = id;
00153 if (len < OP_gprune_num) len++;
00154 return(len);
00155 }
00156
00175 static LOGPROB
00176 gmm_compute_g_base(HTK_HMM_Dens *binfo)
00177 {
00178 VECT tmp, x;
00179 VECT *mean;
00180 VECT *var;
00181 VECT *vec = OP_vec;
00182 short veclen = OP_veclen;
00183
00184 if (binfo == NULL) return(LOG_ZERO);
00185 mean = binfo->mean;
00186 var = binfo->var->vec;
00187 tmp = 0.0;
00188 for (; veclen > 0; veclen--) {
00189 x = *(vec++) - *(mean++);
00190 tmp += x * x / *(var++);
00191 }
00192 return((tmp + binfo->gconst) / -2.0);
00193 }
00194
00215 static LOGPROB
00216 gmm_compute_g_safe(HTK_HMM_Dens *binfo, LOGPROB thres)
00217 {
00218 VECT tmp, x;
00219 VECT *mean;
00220 VECT *var;
00221 VECT *vec = OP_vec;
00222 short veclen = OP_veclen;
00223 VECT fthres = thres * (-2.0);
00224
00225 if (binfo == NULL) return(LOG_ZERO);
00226 mean = binfo->mean;
00227 var = binfo->var->vec;
00228 tmp = binfo->gconst;
00229 for (; veclen > 0; veclen--) {
00230 x = *(vec++) - *(mean++);
00231 tmp += x * x / *(var++);
00232 if (tmp > fthres) return LOG_ZERO;
00233 }
00234 return(tmp / -2.0);
00235 }
00236
00251 static void
00252 gmm_gprune_safe_init(HTK_HMM_INFO *hmminfo, int prune_num)
00253 {
00254
00255 OP_gprune_num = prune_num;
00256
00257 OP_calced_maxnum = hmminfo->maxmixturenum;
00258
00259 OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * OP_gprune_num);
00260 OP_calced_id = (int *)mymalloc(sizeof(int) * OP_gprune_num);
00261 }
00262
00288 static void
00289 gmm_gprune_safe(HTK_HMM_Dens **g, int gnum)
00290 {
00291 int i, num = 0;
00292 LOGPROB score, thres;
00293
00294 thres = LOG_ZERO;
00295 for (i = 0; i < gnum; i++) {
00296 if (num < OP_gprune_num) {
00297 score = gmm_compute_g_base(g[i]);
00298 } else {
00299 score = gmm_compute_g_safe(g[i], thres);
00300 if (score <= thres) continue;
00301 }
00302 num = gmm_cache_push(i, score, num);
00303 thres = OP_calced_score[num-1];
00304 }
00305 OP_calced_num = num;
00306 }
00307
00308
00325 static LOGPROB
00326 gmm_calc_mix(HTK_HMM_State *s)
00327 {
00328 int i;
00329 LOGPROB logprob = LOG_ZERO;
00330
00331
00332 gmm_gprune_safe(s->b, s->mix_num);
00333
00334
00335
00336
00337
00338 for(i=0;i<OP_calced_num;i++) {
00339 OP_calced_score[i] += s->bweight[OP_calced_id[i]];
00340 }
00341 logprob = addlog_array(OP_calced_score, OP_calced_num);
00342 if (logprob <= LOG_ZERO) return LOG_ZERO;
00343 return (logprob / LOG_TEN);
00344 }
00345
00367 static LOGPROB
00368 outprob_state_nocache(int t, HTK_HMM_State *stateinfo, HTK_Param *param)
00369 {
00370
00371 OP_vec = param->parvec[t];
00372 OP_veclen = param->veclen;
00373 return(gmm_calc_mix(stateinfo));
00374 }
00375
00376
00377
00378
00379
00396 void
00397 gmm_init(HTK_HMM_INFO *gmm, int gmm_prune_num)
00398 {
00399 HTK_HMM_Data *d;
00400
00401
00402
00403 if (gmm->is_tied_mixture) {
00404 j_exit("Error: mixture-tying GMM is not supported yet.\n");
00405 }
00406
00407 for(d=gmm->start;d;d=d->next) {
00408 if (d->state_num > 3) {
00409 j_exit("Error: GMM has more than 1 output state! [%s]\n", d->name);
00410 }
00411 }
00412
00413
00414
00415
00416 gmm_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * gmm->totalhmmnum);
00417
00418
00419 gmm_gprune_safe_init(gmm, gmm_prune_num);
00420 }
00421
00435 void
00436 gmm_prepare(HTK_HMM_INFO *gmm)
00437 {
00438 HTK_HMM_Data *d;
00439 int i;
00440
00441
00442 i = 0;
00443 for(d=gmm->start;d;d=d->next) {
00444 gmm_score[i] = 0.0;
00445 i++;
00446 }
00447 framecount = 0;
00448 }
00449
00468 void
00469 gmm_proceed(HTK_HMM_INFO *gmm, HTK_Param *param, int t)
00470 {
00471 HTK_HMM_Data *d;
00472 int i;
00473
00474 framecount++;
00475 i = 0;
00476 for(d=gmm->start;d;d=d->next) {
00477 gmm_score[i] += outprob_state_nocache(t, d->s[1], param);
00478 #ifdef MES
00479 printf("[%d:total=%f avg=%f]\n", i, gmm_score[i], gmm_score[i] / (float)framecount);
00480 #endif
00481 i++;
00482 }
00483 }
00484
00485 static HTK_HMM_Data *max_d; /* Model selected by gmm_end() as the best-scoring one; read by the validity check and output functions */
00486 #ifdef CONFIDENCE_MEASURE
00487 static LOGPROB gmm_max_cm; /* Confidence value computed in gmm_end() as 1 / sum over models of 10^(cm_alpha * (score - max score)) */
00488 #endif
00489 static HTK_HMM_INFO *gmm_local; /* GMM definition given to the last gmm_end() call; used by ttyout_gmm() to list the models */
00490
00511 void
00512 gmm_end(HTK_HMM_INFO *gmm)
00513 {
00514 HTK_HMM_Data *d;
00515 LOGPROB maxprob, sum;
00516 int i;
00517
00518
00519 i = 0;
00520 maxprob = LOG_ZERO;
00521 for(d=gmm->start;d;d=d->next) {
00522 if (maxprob < gmm_score[i]) {
00523 max_d = d;
00524 maxprob = gmm_score[i];
00525 }
00526 i++;
00527 }
00528 #ifdef CONFIDENCE_MEASURE
00529
00530 sum = 0.0;
00531 i = 0;
00532 for(d=gmm->start;d;d=d->next) {
00533 sum += pow(10, cm_alpha * (gmm_score[i] - maxprob));
00534 i++;
00535 }
00536 gmm_max_cm = 1.0 / sum;
00537 #endif
00538
00539
00540 gmm_local = gmm;
00541 result_gmm();
00542 }
00543
00561 boolean
00562 gmm_valid_input()
00563 {
00564 if (max_d == NULL) return FALSE;
00565 if (strstr(gmm_reject_cmn_string, max_d->name)) {
00566 return FALSE;
00567 }
00568 return TRUE;
00569 }
00570
00571
00572
00573
00583 void
00584 ttyout_gmm(){
00585 HTK_HMM_Data *d;
00586 int i;
00587
00588 if (debug2_flag) {
00589 j_printf("--- GMM result begin ---\n");
00590 i = 0;
00591 for(d=gmm_local->start;d;d=d->next) {
00592 j_printf(" [%8s: total=%f avg=%f]\n", d->name, gmm_score[i], gmm_score[i] / (float)framecount);
00593 i++;
00594 }
00595 j_printf(" max = \"%s\"", max_d->name);
00596 #ifdef CONFIDENCE_MEASURE
00597 j_printf(" (CM: %f)", gmm_max_cm);
00598 #endif
00599 j_printf("\n");
00600 j_printf("--- GMM result end ---\n");
00601 } else if (verbose_flag) {
00602 j_printf("GMM: max = \"%s\"", max_d->name);
00603 #ifdef CONFIDENCE_MEASURE
00604 j_printf(" (CM: %f)", gmm_max_cm);
00605 #endif
00606 j_printf("\n");
00607 } else {
00608 j_printf("[GMM: %s]\n", max_d->name);
00609 }
00610 }
00611
00621 void
00622 msock_gmm()
00623 {
00624 module_send(module_sd, "<GMM RESULT=\"%s\"", max_d->name);
00625 #ifdef CONFIDENCE_MEASURE
00626 module_send(module_sd, " CMSCORE=\"%f\"", gmm_max_cm);
00627 #endif
00628 module_send(module_sd, "/>\n.\n");
00629 }