00001
00018
00019
00020
00021
00022
00023
00024
00025 #include <sent/stddefs.h>
00026 #include <sent/htk_hmm.h>
00027 #include <sent/htk_param.h>
00028 #include <sent/hmm.h>
00029 #include <sent/hmm_calc.h>
00030
00031
/* When defined, compute_gs_scores() scores each GS state with compute_g_max()
   (best single mixture component per stream) instead of compute_g_base(). */
#define GS_MAX_PROB
/* When defined, the index of each stream's best component is remembered
   between calls (gms_last_max_id_list) and evaluated first on the next call,
   so its score gives an early pruning threshold for the other components. */
#define LAST_BEST
00034
00035
00042 void
00043 gms_gprune_init(HMMWork *wrk)
00044 {
00045 int i;
00046 wrk->gms_last_max_id_list = (int **)mymalloc(sizeof(int *) * wrk->gsset_num);
00047 for(i=0;i<wrk->gsset_num;i++) {
00048 wrk->gms_last_max_id_list[i] = (int *)mymalloc(sizeof(int) * wrk->OP_nstream);
00049 }
00050 }
00051
00058 void
00059 gms_gprune_prepare(HMMWork *wrk)
00060 {
00061 int i, j;
00062 for(i=0;i<wrk->gsset_num;i++) {
00063 for(j=0;j<wrk->OP_nstream;j++) {
00064 wrk->gms_last_max_id_list[i][j] = -1;
00065 }
00066 }
00067 }
00068
00075 void
00076 gms_gprune_free(HMMWork *wrk)
00077 {
00078 int i;
00079 for(i=0;i<wrk->gsset_num;i++) free(wrk->gms_last_max_id_list[i]);
00080 free(wrk->gms_last_max_id_list);
00081 }
00082
00083
00084
00094 static LOGPROB
00095 calc_contprob_with_safe_pruning(HMMWork *wrk, HTK_HMM_Dens *binfo, LOGPROB thres)
00096 {
00097 LOGPROB tmp, x;
00098 VECT *mean;
00099 VECT *var;
00100 LOGPROB fthres = thres * (-2.0);
00101 VECT *vec = wrk->OP_vec;
00102 short veclen = wrk->OP_veclen;
00103
00104 if (binfo == NULL) return(LOG_ZERO);
00105 mean = binfo->mean;
00106 var = binfo->var->vec;
00107
00108 tmp = binfo->gconst;
00109 for (; veclen > 0; veclen--) {
00110 x = *(vec++) - *(mean++);
00111 tmp += x * x * *(var++);
00112 if ( tmp > fthres) {
00113 return LOG_ZERO;
00114 }
00115 }
00116 return(tmp * -0.5);
00117 }
00118
00119 #ifdef LAST_BEST
00120
00132 static LOGPROB
00133 compute_g_max(HMMWork *wrk, HTK_HMM_State *stateinfo, int *last_maxi)
00134 {
00135 int i, maxi;
00136 LOGPROB prob;
00137 LOGPROB maxprob = LOG_ZERO;
00138 int s;
00139 PROB stream_weight;
00140 LOGPROB logprobsum;
00141
00142 logprobsum = 0.0;
00143 for(s=0;s<wrk->OP_nstream;s++) {
00144
00145 if (stateinfo->w) stream_weight = stateinfo->w->weight[s];
00146 else stream_weight = 1.0;
00147
00148 wrk->OP_vec = wrk->OP_vec_stream[s];
00149 wrk->OP_veclen = wrk->OP_veclen_stream[s];
00150
00151 if (last_maxi[s] != -1) {
00152 maxi = last_maxi[s];
00153 maxprob = calc_contprob_with_safe_pruning(wrk, stateinfo->pdf[s]->b[maxi], LOG_ZERO);
00154 for (i = stateinfo->pdf[s]->mix_num - 1; i >= 0; i--) {
00155 if (i == last_maxi[s]) continue;
00156 prob = calc_contprob_with_safe_pruning(wrk, stateinfo->pdf[s]->b[i], maxprob);
00157 if (prob > maxprob) {
00158 maxprob = prob;
00159 maxi = i;
00160 }
00161 }
00162 last_maxi[s] = maxi;
00163 } else {
00164 maxi = stateinfo->pdf[s]->mix_num - 1;
00165 maxprob = calc_contprob_with_safe_pruning(wrk, stateinfo->pdf[s]->b[maxi], LOG_ZERO);
00166 for (i = maxi - 1; i >= 0; i--) {
00167 prob = calc_contprob_with_safe_pruning(wrk, stateinfo->pdf[s]->b[i], maxprob);
00168 if (prob > maxprob) {
00169 maxprob = prob;
00170 maxi = i;
00171 }
00172 }
00173 last_maxi[s] = maxi;
00174 }
00175 logprobsum += (maxprob + stateinfo->pdf[s]->bweight[maxi]) * stream_weight;
00176 }
00177 return (logprobsum * INV_LOG_TEN);
00178 }
00179
00180 #else
00181
00191 static LOGPROB
00192 compute_g_max(HMMWork *wrk, HTK_HMM_State *stateinfo)
00193 {
00194 int i, maxi;
00195 LOGPROB prob;
00196 LOGPROB maxprob = LOG_ZERO;
00197 int s;
00198 PROB stream_weight;
00199 LOGPROB logprob, logprobsum;
00200
00201 logprobsum = 0.0;
00202 for(s=0;s<wrk->OP_nstream;s++) {
00203
00204 if (stateinfo->w) stream_weight = stateinfo->w->weight[s];
00205 else stream_weight = 1.0;
00206
00207 wrk->OP_vec = wrk->OP_vec_stream[s];
00208 wrk->OP_veclen = wrk->OP_veclen_stream[s];
00209
00210 i = maxi = stateinfo->pdf[s]->mix_num - 1;
00211 for (; i >= 0; i--) {
00212 prob = calc_contprob_with_safe_pruning(wrk, stateinfo->pdf[s]->b[i], maxprob);
00213 if (prob > maxprob) {
00214 maxprob = prob;
00215 maxi = i;
00216 }
00217 }
00218 logprobsum += (maxprob + stateinfo->pdf[s]->bweight[maxi]) * stream_weight;
00219 }
00220 return (logprobsum * INV_LOG_TEN);
00221 }
00222 #endif
00223
00224
00225
00226
00227
00237 void
00238 compute_gs_scores(HMMWork *wrk)
00239 {
00240 int i;
00241
00242 for (i=0;i<wrk->gsset_num;i++) {
00243 #ifdef GS_MAX_PROB
00244 #ifdef LAST_BEST
00245
00246 wrk->t_fs[i] = compute_g_max(wrk, wrk->gsset[i].state, wrk->gms_last_max_id_list[i]);
00247 #else
00248 wrk->t_fs[i] = compute_g_max(wrk, wrk->gsset[i].state);
00249 #endif
00250 #else
00251
00252 wrk->t_fs[i] = compute_g_base(wrk, wrk->gsset[i].state);
00253 #endif
00254
00255 }
00256
00257 }