00001
00017
00018
00019
00020
00021
00022
00023
00024 #include <sent/stddefs.h>
00025 #include <sent/htk_hmm.h>
00026 #include <sent/htk_param.h>
00027 #include <sent/hmm.h>
00028 #include <sent/gprune.h>
00029 #include "globalvars.h"
00030
00031
/*
 * Compile-time configuration of the GS-set scoring in this file.
 *
 *   GS_MAX_PROB : score each GS state with compute_g_max() (best single
 *                 mixture component) instead of compute_g_base().
 *   LAST_BEST   : start each evaluation from the component that won on the
 *                 previous call (kept in last_max_id[]), so the remaining
 *                 components are pruned against a good initial threshold.
 *   BEAM        : use per-dimension beam thresholds (dimthres[]) instead of
 *                 the single safe-pruning threshold.  Disabled here.
 *   BEAM_OFFSET : margin added to every per-dimension threshold after the
 *                 reference component has been evaluated (BEAM only).
 */
#define GS_MAX_PROB
#define LAST_BEST
#undef BEAM
#define BEAM_OFFSET 10.0

/* Enforce dependencies: BEAM needs LAST_BEST, and LAST_BEST needs GS_MAX_PROB. */
#ifdef BEAM
#define LAST_BEST
#endif
#ifdef LAST_BEST
#define GS_MAX_PROB
#endif


/* Number of GS sets; set by gms_gprune_init() and used by gms_gprune_prepare(). */
static int my_gsset_num;
/* Per-GS-set index of the mixture component chosen on the last call, or -1 if none. */
static int *last_max_id;
#ifdef BEAM
/* Per-dimension cumulative distance thresholds used by the BEAM pruning pass. */
static VECT *dimthres;
/* Length of dimthres (the feature vector size). */
static int dimthres_num;
#endif
00051
00052
00059 void
00060 gms_gprune_init(HTK_HMM_INFO *hmminfo, int gsset_num)
00061 {
00062 my_gsset_num = gsset_num;
00063 last_max_id = (int *)mybmalloc(sizeof(int) * gsset_num);
00064 #ifdef BEAM
00065 dimthres_num = hmminfo->opt.vec_size;
00066 dimthres = (LOGPROB *)mybmalloc(sizeof(LOGPROB) * dimthres_num);
00067 #endif
00068 }
00069
00074 void
00075 gms_gprune_prepare()
00076 {
00077 int i;
00078 for(i=0;i<my_gsset_num;i++) {
00079 last_max_id[i] = -1;
00080 }
00081 }
00082
00083
00084
00085
00094 static LOGPROB
00095 calc_contprob_with_safe_pruning(HTK_HMM_Dens *binfo, LOGPROB thres)
00096 {
00097 LOGPROB tmp, x;
00098 VECT *mean;
00099 VECT *var;
00100 LOGPROB fthres = thres * (-2.0);
00101 VECT *vec = OP_vec;
00102 short veclen = OP_veclen;
00103
00104 if (binfo == NULL) return(LOG_ZERO);
00105 mean = binfo->mean;
00106 var = binfo->var->vec;
00107
00108 tmp = binfo->gconst;
00109 for (; veclen > 0; veclen--) {
00110 x = *(vec++) - *(mean++);
00111 tmp += x * x / *(var++);
00112 if ( tmp > fthres) {
00113 return LOG_ZERO;
00114 }
00115 }
00116 return(tmp / -2.0);
00117 }
00118
#ifdef BEAM

/**
 * First pass of beam pruning: fully evaluate a reference Gaussian for the
 * current input vector.  Also record in dimthres the cumulative distance
 * reached after each dimension.
 *
 * For each dimension d, dimthres[d] is raised to the partial sum of
 * (x_k - mean_k)^2 / var_k over k <= d if that sum is larger.  The caller
 * is expected to have reset dimthres beforehand.  If binfo is NULL,
 * dimthres is left untouched.
 *
 * @param binfo [in] Gaussian density; NULL yields LOG_ZERO
 *
 * @return -(sum + gconst) / 2, or LOG_ZERO if binfo is NULL.
 */
static LOGPROB
calc_contprob_with_beam_pruning_pre(HTK_HMM_Dens *binfo)
{
  LOGPROB tmp, x;
  VECT *mean;
  VECT *var;
  VECT *th = dimthres;
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x / *(var++);
    if ( *th < tmp) *th = tmp;
    th++;
  }
  return((tmp + binfo->gconst) / -2.0);
}

/**
 * Second pass of beam pruning: evaluate a Gaussian for the current input
 * vector.  Give up as soon as the cumulative distance after any dimension
 * d exceeds dimthres[d].
 *
 * @param binfo [in] Gaussian density; NULL yields LOG_ZERO
 *
 * @return -(sum + gconst) / 2, or LOG_ZERO if binfo is NULL or a
 *         per-dimension threshold was exceeded.
 */
static LOGPROB
calc_contprob_with_beam_pruning_post(HTK_HMM_Dens *binfo)
{
  LOGPROB tmp, x;
  /* These point into VECT arrays (binfo->mean, binfo->var->vec, dimthres),
     so they must be VECT *, matching calc_contprob_with_beam_pruning_pre(). */
  VECT *mean;
  VECT *var;
  VECT *th = dimthres;
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x / *(var++);
    if ( tmp > *(th++)) {
      return LOG_ZERO;
    }
  }
  return((tmp + binfo->gconst) / -2.0);
}

#endif
00187
00188 #ifdef LAST_BEST
00189
00200 static LOGPROB
00201 compute_g_max(HTK_HMM_State *stateinfo, int last_maxi, int *maxi_ret)
00202 {
00203 int i, maxi;
00204 LOGPROB prob;
00205 LOGPROB maxprob = LOG_ZERO;
00206
00207 if (last_maxi != -1) {
00208 maxi = last_maxi;
00209 #ifdef BEAM
00210
00211 for(i=0;i<dimthres_num;i++) dimthres[i] = 0.0;
00212
00213 maxprob = calc_contprob_with_beam_pruning_pre(stateinfo->b[maxi]);
00214
00215 for(i=0;i<dimthres_num;i++) dimthres[i] += BEAM_OFFSET;
00216 #else
00217 maxprob = calc_contprob_with_safe_pruning(stateinfo->b[maxi], LOG_ZERO);
00218 #endif
00219 for (i = stateinfo->mix_num - 1; i >= 0; i--) {
00220 if (i == last_maxi) continue;
00221 #ifdef BEAM
00222 prob = calc_contprob_with_beam_pruning_post(stateinfo->b[i]);
00223 #else
00224 prob = calc_contprob_with_safe_pruning(stateinfo->b[i], maxprob);
00225 #endif
00226 if (prob > maxprob) {
00227 maxprob = prob;
00228 maxi = i;
00229 }
00230 }
00231 *maxi_ret = maxi;
00232 } else {
00233 maxi = stateinfo->mix_num - 1;
00234 maxprob = calc_contprob_with_safe_pruning(stateinfo->b[maxi], LOG_ZERO);
00235 i = maxi - 1;
00236 for (; i >= 0; i--) {
00237 prob = calc_contprob_with_safe_pruning(stateinfo->b[i], maxprob);
00238 if (prob > maxprob) {
00239 maxprob = prob;
00240 maxi = i;
00241 }
00242 }
00243 *maxi_ret = maxi;
00244 }
00245
00246 return((maxprob + stateinfo->bweight[maxi]) / LOG_TEN);
00247 }
00248
00249 #else
00250
00259 static LOGPROB
00260 compute_g_max(HTK_HMM_State *stateinfo)
00261 {
00262 int i, maxi;
00263 LOGPROB prob;
00264 LOGPROB maxprob = LOG_ZERO;
00265
00266 i = maxi = stateinfo->mix_num - 1;
00267 for (; i >= 0; i--) {
00268 prob = calc_contprob_with_safe_pruning(stateinfo->b[i], maxprob);
00269 if (prob > maxprob) {
00270 maxprob = prob;
00271 maxi = i;
00272 }
00273 }
00274 return((maxprob + stateinfo->bweight[maxi]) / LOG_TEN);
00275 }
00276 #endif
00277
00278
00279
00280
00281
00292 void
00293 compute_gs_scores(GS_SET *gsset, int gsset_num, LOGPROB *scores_ret)
00294 {
00295 int i;
00296 #ifdef LAST_BEST
00297 int max_id;
00298 #endif
00299
00300 for (i=0;i<gsset_num;i++) {
00301 #ifdef GS_MAX_PROB
00302 #ifdef LAST_BEST
00303
00304 scores_ret[i] = compute_g_max(gsset[i].state, last_max_id[i], &max_id);
00305 last_max_id[i] = max_id;
00306 #else
00307 scores_ret[i] = compute_g_max(gsset[i].state);
00308 #endif
00309 #else
00310
00311 scores_ret[i] = compute_g_base(gsset[i].state);
00312 #endif
00313
00314 }
00315
00316 }