00001
00018
00019
00020
00021
00022
00023
00024
00025 #include <sent/stddefs.h>
00026 #include <sent/htk_hmm.h>
00027 #include <sent/htk_param.h>
00028 #include <sent/hmm.h>
00029 #include <sent/hmm_calc.h>
00030
00031
00032 #define GS_MAX_PROB
00033 #define LAST_BEST
00034
00035
00042 void
00043 gms_gprune_init(HMMWork *wrk)
00044 {
00045 wrk->gms_last_max_id = (int *)mymalloc(sizeof(int) * wrk->gsset_num);
00046 }
00047
00054 void
00055 gms_gprune_prepare(HMMWork *wrk)
00056 {
00057 int i;
00058 for(i=0;i<wrk->gsset_num;i++) {
00059 wrk->gms_last_max_id[i] = -1;
00060 }
00061 }
00062
00069 void
00070 gms_gprune_free(HMMWork *wrk)
00071 {
00072 free(wrk->gms_last_max_id);
00073 }
00074
00075
00076
00086 static LOGPROB
00087 calc_contprob_with_safe_pruning(HMMWork *wrk, HTK_HMM_Dens *binfo, LOGPROB thres)
00088 {
00089 LOGPROB tmp, x;
00090 VECT *mean;
00091 VECT *var;
00092 LOGPROB fthres = thres * (-2.0);
00093 VECT *vec = wrk->OP_vec;
00094 short veclen = wrk->OP_veclen;
00095
00096 if (binfo == NULL) return(LOG_ZERO);
00097 mean = binfo->mean;
00098 var = binfo->var->vec;
00099
00100 tmp = binfo->gconst;
00101 for (; veclen > 0; veclen--) {
00102 x = *(vec++) - *(mean++);
00103 tmp += x * x * *(var++);
00104 if ( tmp > fthres) {
00105 return LOG_ZERO;
00106 }
00107 }
00108 return(tmp * -0.5);
00109 }
00110
00111 #ifdef LAST_BEST
00112
00124 static LOGPROB
00125 compute_g_max(HMMWork *wrk, HTK_HMM_State *stateinfo, int last_maxi, int *maxi_ret)
00126 {
00127 int i, maxi;
00128 LOGPROB prob;
00129 LOGPROB maxprob = LOG_ZERO;
00130
00131 if (last_maxi != -1) {
00132 maxi = last_maxi;
00133 maxprob = calc_contprob_with_safe_pruning(wrk, stateinfo->b[maxi], LOG_ZERO);
00134 for (i = stateinfo->mix_num - 1; i >= 0; i--) {
00135 if (i == last_maxi) continue;
00136 prob = calc_contprob_with_safe_pruning(wrk, stateinfo->b[i], maxprob);
00137 if (prob > maxprob) {
00138 maxprob = prob;
00139 maxi = i;
00140 }
00141 }
00142 *maxi_ret = maxi;
00143 } else {
00144 maxi = stateinfo->mix_num - 1;
00145 maxprob = calc_contprob_with_safe_pruning(wrk, stateinfo->b[maxi], LOG_ZERO);
00146 i = maxi - 1;
00147 for (; i >= 0; i--) {
00148 prob = calc_contprob_with_safe_pruning(wrk, stateinfo->b[i], maxprob);
00149 if (prob > maxprob) {
00150 maxprob = prob;
00151 maxi = i;
00152 }
00153 }
00154 *maxi_ret = maxi;
00155 }
00156
00157 return((maxprob + stateinfo->bweight[maxi]) * INV_LOG_TEN);
00158 }
00159
00160 #else
00161
00171 static LOGPROB
00172 compute_g_max(HMMWork *wrk, HTK_HMM_State *stateinfo)
00173 {
00174 int i, maxi;
00175 LOGPROB prob;
00176 LOGPROB maxprob = LOG_ZERO;
00177
00178 i = maxi = stateinfo->mix_num - 1;
00179 for (; i >= 0; i--) {
00180 prob = calc_contprob_with_safe_pruning(wrk, stateinfo->b[i], maxprob);
00181 if (prob > maxprob) {
00182 maxprob = prob;
00183 maxi = i;
00184 }
00185 }
00186 return((maxprob + stateinfo->bweight[maxi]) * INV_LOG_TEN);
00187 }
00188 #endif
00189
00190
00191
00192
00193
00203 void
00204 compute_gs_scores(HMMWork *wrk)
00205 {
00206 int i;
00207 #ifdef LAST_BEST
00208 int max_id;
00209 #endif
00210
00211 for (i=0;i<wrk->gsset_num;i++) {
00212 #ifdef GS_MAX_PROB
00213 #ifdef LAST_BEST
00214
00215 wrk->t_fs[i] = compute_g_max(wrk, wrk->gsset[i].state, wrk->gms_last_max_id[i], &max_id);
00216 wrk->gms_last_max_id[i] = max_id;
00217 #else
00218 wrk->t_fs[i] = compute_g_max(wrk, wrk->gsset[i].state);
00219 #endif
00220 #else
00221
00222 wrk->t_fs[i] = compute_g_base(wrk, wrk->gsset[i].state);
00223 #endif
00224
00225 }
00226
00227 }