/* beam.c -- beam-search decoding over parser configurations. */
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include"beam.h"
#include"feat_fct.h"
#include"movement_parser.h"
#include"config2feat_vec.h"
/*
 * Return the highest-scoring configuration stored in beam b.
 *
 * Only the first b->nbelem slots are examined. On ties the earliest entry
 * wins, because a later entry replaces the current best only when its score
 * is strictly greater. The returned pointer is the one held by the beam;
 * no copy is made here.
 *
 * Returns NULL when b is NULL or holds no element, rather than reading the
 * never-assigned slot b->t[0].
 */
config *beam_argmax(beam *b)
{
    int i;
    config *argmax;

    if (b == NULL || b->nbelem <= 0)
        return NULL;

    argmax = b->t[0];
    for (i = 1; i < b->nbelem; i++) {
        if (b->t[i]->score > argmax->score)
            argmax = b->t[i];
    }
    return argmax;
}
/*
 * Dump the content of beam b on stream f.
 * For every stored element, its index and its score (four decimals) are
 * written followed by a tab, then the element is passed to config_print.
 */
void beam_print(FILE *f, beam *b)
{
    int idx = 0;

    while (idx < b->nbelem) {
        config *cur = b->t[idx];

        fprintf(f, "%d %.4f\t", idx, cur->score);
        config_print(f, cur);
        idx++;
    }
}
/*
 * Create a beam able to hold `size` configuration pointers.
 * The beam starts with nbelem set to 0; the slots of t are not initialised
 * here and keep whatever memalloc returned.
 */
beam *beam_new(int size)
{
    beam *fresh = memalloc(sizeof(beam));

    fresh->t = memalloc(size * sizeof(config *));
    fresh->nbelem = 0;
    fresh->size = size;
    return fresh;
}
/*
 * Release a beam: first every configuration it still holds (through
 * beam_empty), then the slot array, then the beam structure itself.
 * Passing NULL is a no-op, mirroring free().
 */
void beam_free(beam *b)
{
    if (b == NULL)
        return;

    beam_empty(b);
    free(b->t);
    free(b);
}
/*
 * Append configuration c to beam b and return the new element count.
 *
 * The pointer is stored as is; beam_empty later hands every stored pointer
 * to config_free.
 *
 * The slot array grows to 2 * (size + 1) entries as soon as fewer than two
 * free slots remain. The test uses >= rather than == so that a beam created
 * with size 0 also grows before its first write instead of writing past a
 * zero-length array.
 *
 * If the array cannot be grown, a message is printed on stderr and the
 * program exits; the previous array is never overwritten with NULL.
 */
int beam_add(beam *b, config * c)
{
    if (b->nbelem >= b->size - 1) {
        int new_size = 2 * (b->size + 1);
        /* Keep the old pointer until realloc is known to have succeeded. */
        void *grown = realloc(b->t, (size_t)new_size * sizeof(config *));

        if (grown == NULL) {
            fprintf(stderr, "beam_add: cannot grow beam to %d elements\n", new_size);
            exit(EXIT_FAILURE);
        }
        b->t = grown;
        b->size = new_size;
    }
    b->t[b->nbelem] = c;
    b->nbelem++;
    return b->nbelem;
}
/*
 * Remove every element from beam b.
 * Each stored configuration is passed to config_free, in storage order,
 * and the element count is reset to 0. The slot array itself is kept.
 */
void beam_empty(beam *b)
{
    int idx = 0;

    while (idx < b->nbelem) {
        config_free(b->t[idx]);
        idx++;
    }
    b->nbelem = 0;
}
int compare_triplets_cms(const void *t1, const void *t2)
{
float ret = ((triplet_cms *)t2)->score - ((triplet_cms *)t1)->score;
if(ret > 0) return 1;
if(ret < 0) return -1;
return 0;
}
/*
 * Enumerate every (configuration, movement) candidate reachable from
 * current_beam and sort them with compare_triplets_cms.
 *
 * For each configuration of the beam, a feature vector is computed with
 * config2feat_vec_cff and scored against ft; one triplet is then written per
 * class, holding the index of the configuration in the beam, the class code
 * and the class score added to the configuration's own score.
 *
 * k_best_array must have room for current_beam->nbelem * ft->classes_nb
 * entries; no bound is checked here. beam_width is accepted but not used
 * by this function.
 *
 * Returns the number of triplets written.
 */
int beam_fill_k_best_array(beam *current_beam, triplet_cms *k_best_array, feature_table *ft, feat_model *fm, dico *dico_features, int beam_width, int mode)
{
    const int nb_classes = ft->classes_nb;
    feat_vec *features = feat_vec_new(feature_types_nb);
    int filled = 0;
    int cfg;

    for (cfg = 0; cfg < current_beam->nbelem; cfg++) {
        vcode *scored;
        int cls;

        config2feat_vec_cff(fm, current_beam->t[cfg], dico_features, features, mode);
        scored = feature_table_get_vcode_array(features, ft);

        for (cls = 0; cls < nb_classes; cls++) {
            triplet_cms *slot = &k_best_array[filled];

            slot->config = cfg;
            slot->movement = scored[cls].class_code;
            slot->score = scored[cls].score + current_beam->t[cfg]->score;
            filled++;
        }
        /* The score array is released once its classes have been copied. */
        free(scored);
    }
    feat_vec_free(features);

    qsort(k_best_array, filled, sizeof(triplet_cms), compare_triplets_cms);
    return filled;
}
/*
 * Expand current_beam by one movement.
 *
 * All (configuration, movement) candidates are collected and sorted by
 * beam_fill_k_best_array, then visited in that order. For each candidate the
 * matching movement_*_dup function is applied to the source configuration.
 * When it yields a configuration, that configuration goes to final_beam if
 * config_is_terminal accepts it and to next_beam otherwise. Candidates whose
 * movement yields NULL, or whose movement type is none of MVT_LEFT,
 * MVT_RIGHT and MVT_SHIFT, are skipped.
 *
 * The walk stops once next_beam holds beam_width elements or the candidates
 * are exhausted; final_beam is not bounded here.
 *
 * The candidate array is sized from the actual number of configurations in
 * current_beam, because beam_fill_k_best_array writes nbelem * classes_nb
 * entries regardless of beam_width.
 */
void fill_next_beam(beam *current_beam, beam *next_beam, beam *final_beam, feature_table *ft, feat_model *fm, dico *dico_features, int beam_width, int mode)
{
    triplet_cms *k_best_array;
    feat_vec *fv;
    int i, k;

    /* Nothing to expand: avoid a zero-sized allocation. */
    if (current_beam->nbelem <= 0)
        return;

    k_best_array = memalloc((size_t)current_beam->nbelem * (size_t)ft->classes_nb * sizeof(triplet_cms));
    fv = feat_vec_new(feature_types_nb);
    k = beam_fill_k_best_array(current_beam, k_best_array, ft, fm, dico_features, beam_width, mode);

    for (i = 0; (i < k) && (next_beam->nbelem < beam_width); i++) {
        float configuration_score;
        config *c;
        config *next_config = NULL;
        int mvt_code;
        int mvt_type;
        int mvt_label;

        configuration_score = k_best_array[i].score;
        c = current_beam->t[k_best_array[i].config];
        config2feat_vec_cff(fm, c, dico_features, fv, mode);
        mvt_code = k_best_array[i].movement;
        mvt_type = movement_type(mvt_code);
        mvt_label = movement_label(mvt_code);

        if (mvt_type == MVT_LEFT)
            next_config = movement_left_arc_dup(c, mvt_label, configuration_score, fv);
        else if (mvt_type == MVT_RIGHT)
            next_config = movement_right_arc_dup(c, mvt_label, configuration_score, fv);
        else if (mvt_type == MVT_SHIFT)
            next_config = movement_shift_dup(c, 0, configuration_score, fv);

        /* Movement rejected or of unknown type: try the next candidate. */
        if (next_config == NULL)
            continue;

        if (config_is_terminal(next_config))
            beam_add(final_beam, next_config);
        else
            beam_add(next_beam, next_config);
    }

    feat_vec_free(fv);
    free(k_best_array);
}
/*
 * Run beam search from initial_config and return a copy of the best terminal
 * configuration found.
 *
 * initial_config is placed in the working beam and is released by
 * beam_empty after the first expansion, so the caller must not use or free
 * it after this call. The beams are expanded with fill_next_beam in
 * LOOKUP_MODE until the working beam is empty. The result is produced with
 * config_copy, so it remains valid after the beams are freed.
 *
 * Returns NULL when no terminal configuration was reached, instead of asking
 * beam_argmax for the best element of an empty beam.
 *
 * mvt_nb is accepted but not used by this function.
 */
config *beam_decoder_sentence(config *initial_config, dico *dico_features, feature_table *ft, int beam_width, int mvt_nb, feat_model *fm)
{
    beam *current_beam = beam_new(beam_width);
    beam *next_beam = beam_new(beam_width);
    beam *final_beam = beam_new(beam_width);
    beam *tmp_beam = NULL;
    config *argmax = NULL;

    beam_add(current_beam, initial_config);
    while (current_beam->nbelem > 0) {
        fill_next_beam(current_beam, next_beam, final_beam, ft, fm, dico_features, beam_width, LOOKUP_MODE);
        beam_empty(current_beam);

        /* The emptied beam becomes the target of the next expansion. */
        tmp_beam = current_beam;
        current_beam = next_beam;
        next_beam = tmp_beam;
    }

    if (final_beam->nbelem > 0)
        argmax = config_copy(beam_argmax(final_beam));

    beam_free(current_beam);
    beam_free(next_beam);
    beam_free(final_beam);
    return argmax;
}
/*
 * Decode every sentence read from f and print the resulting dependency set
 * on stdout.
 *
 * For each sentence loaded by queue_read_sentence, the configuration is
 * handed to beam_decoder_sentence (which releases it), the returned
 * configuration is completed with config_connect_subtrees using root_label,
 * printed with depset_print2 and freed. A fresh initial configuration is
 * built for the next sentence.
 *
 * When beam_decoder_sentence returns NULL, a warning is written on stderr
 * and nothing is printed for that sentence.
 *
 * The configuration created for the read that ends the loop is never given
 * to the decoder, so it is freed here before returning.
 *
 * verbose is accepted but not used by this function.
 */
void beam_decoder(FILE *f, mcd *mcd_struct, dico *dico_features, dico *dico_labels, feature_table *ft, feat_model *fm, int verbose, int root_label, int beam_width, int mvt_nb)
{
    config *c;
    config *c_final;

    c = config_initial(f, mcd_struct, 0);
    while (queue_read_sentence(c->bf, f, mcd_struct)) {
        c_final = beam_decoder_sentence(c, dico_features, ft, beam_width, mvt_nb, fm);
        if (c_final != NULL) {
            config_connect_subtrees(c_final, root_label);
            depset_print2(stdout, c_final->ds, dico_labels);
            config_free(c_final);
        } else {
            fprintf(stderr, "beam_decoder: no terminal configuration found, sentence skipped\n");
        }
        c = config_initial(f, mcd_struct, 0);
    }
    /* Ownership of this last configuration was never transferred. */
    config_free(c);
}