/*
 * dnn_decoder.c
 * Author: Alexis Nasr
 * Commit message: "one module available: maca_trans_parser, a transition-based parser"
 */
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<unistd.h>
#include<getopt.h>
#include"context.h"
#include"movement.h"
#include"oracle.h"
#include"feat_fct.h"
#include"feature_table.h"
#include"dico_vec.h"
#include"word_emb.h"
#include"fann.h"
#include"config2feat_vec.h"
/* Decoding strategies, defined further down in this file. */
void dnn_decoder_buffer(FILE *f, mcd *mcd_struct, struct fann *ann, feat_model *fm, int verbose, int root_label);
void dnn_decoder_stream(FILE *f, mcd *mcd_struct, struct fann *ann, feat_model *fm, int verbose, int root_label);

/*
 * Entry point of the DNN decoder.
 *
 * Chooses a decoding strategy and forwards all the other arguments to it
 * unchanged:
 *   stream_mode != 0  ->  dnn_decoder_stream()
 *   stream_mode == 0  ->  dnn_decoder_buffer()
 * Both strategies print their parses on stdout in CoNLL-07 format.
 */
void dnn_decoder(FILE *f, mcd *mcd_struct, struct fann *ann, feat_model *fm, int verbose, int root_label, int stream_mode)
{
  /* both strategies share one signature, so pick one through a pointer */
  void (*decode)(FILE *, mcd *, struct fann *, feat_model *, int, int) =
    stream_mode ? dnn_decoder_stream : dnn_decoder_buffer;

  decode(f, mcd_struct, ann, fm, verbose, root_label);
}
/*
 * Stream-mode decoder: treats the input as one continuous flow of words.
 *
 * On every step it:
 *   1. scores the current configuration with the network;
 *   2. decodes the best-scoring output index with movement_type() and
 *      movement_label();
 *   3. applies that movement.
 * A RIGHT movement labelled root_label, chosen while the stack holds exactly
 * one element, is treated as the end of a sentence. The root arc is created
 * and a fresh dummy word replaces the stack content. The accumulated
 * dependency set is then printed on stdout in CoNLL-07 format and reset.
 *
 * f          input stream the configuration reads from
 * mcd_struct column description of the input
 * ann        trained FANN network
 * fm         feature model used to build the network input vector
 * verbose    currently unused
 * root_label label identifying the root arc of a sentence
 */
void dnn_decoder_stream(FILE *f, mcd *mcd_struct, struct fann *ann, feat_model *fm, int verbose, int root_label)
{
  config *c = NULL;
  int mvt_type;
  int mvt_label;
  feat_vec *fv = feat_vec_new(feature_types_nb);
  float max;
  fann_type *output_array = NULL;
  int argmax, i;
  sentence *s = NULL;
  int input_nb = fann_get_num_input(ann);
  int output_nb = fann_get_num_output(ann);
  fann_type *input_array = (fann_type *)memalloc(input_nb * sizeof(fann_type));

  (void)verbose; /* not used by this decoder */

  /* NOTE(review): meaning of 10 and 5 is defined by config_initial(); presumably buffer size / look-ahead — confirm */
  c = config_initial(f, mcd_struct, 10, 5);
  while(!config_is_terminal(c)){
    config2feat_vec_fann(fm, c, fv, LOOKUP_MODE);
    feat_vec_fill_input_array_dnn(input_array, fv, fm, mcd_struct);
    output_array = fann_run(ann, input_array);

    /* argmax over the network outputs; element 0 is the initial candidate */
    argmax = 0;
    max = output_array[0];
    for(i = 1; i < output_nb; i++){
      if(output_array[i] > max){
        max = output_array[i];
        argmax = i;
      }
    }
    mvt_type = movement_type(argmax);
    mvt_label = movement_label(argmax);

    if((stack_height(c->st) == 1) && (mvt_type == MVT_RIGHT) && (mvt_label == root_label)){ /* sentence is complete */
      /* create the root arc */
      movement_right_arc(c, mvt_label, 0);
      /* shift dummy word in stack */
      movement_shift(c, 1, 0);
      /* pop it */
      stack_pop(c->st);
      /* replace it with a fresh one */
      stack_push(c->st, word_create_dummy(mcd_struct));

      /* create a new empty sentence, fill it, print it and destroy it */
      s = sentence_new(mcd_struct, f);
      config_connect_subtrees(c, root_label);
      sentence_depset_update(s, c->ds);
      sentence_print_conll07(stdout, s);
      sentence_free(s);

      /* start the next sentence with an empty dependency set */
      depset_free(c->ds);
      c->ds = depset_new();
      continue;
    }

    /* an arc movement that cannot be applied falls back to a shift */
    if(mvt_type == MVT_LEFT)
      if(movement_left_arc(c, mvt_label, max))
        continue;
    if(mvt_type == MVT_RIGHT)
      if(movement_right_arc(c, mvt_label, max))
        continue;
    movement_shift(c, 1, max);
  }
  /* NOTE(review): dependencies collected after the last root arc are not printed when the loop ends — confirm this is intended */

  /* release everything this function allocated (previously leaked) */
  config_free(c);
  free(input_array); /* assumes memalloc() is a malloc() wrapper */
  feat_vec_free(fv);
}
/*
 * Buffered decoder: parses the input one sentence at a time.
 *
 * For each sentence read into the configuration buffer by
 * queue_read_sentence(), it first runs the network-driven transition loop
 * until the configuration is terminal. It then calls
 * config_connect_subtrees(c, root_label), prints the resulting dependencies
 * on stdout in CoNLL-07 format, and starts over from a fresh configuration.
 *
 * f          input stream sentences are read from
 * mcd_struct column description of the input
 * ann        trained FANN network
 * fm         feature model used to build the network input vector
 * verbose    currently unused
 * root_label label passed to config_connect_subtrees()
 */
void dnn_decoder_buffer(FILE *f, mcd *mcd_struct, struct fann *ann, feat_model *fm, int verbose, int root_label)
{
  config *c = NULL;
  int mvt_type;
  int mvt_label;
  feat_vec *fv = feat_vec_new(feature_types_nb);
  float max;
  int input_nb = fann_get_num_input(ann);
  int output_nb = fann_get_num_output(ann);
  fann_type *output_array = NULL;
  int argmax, i;
  sentence *s = NULL;
  fann_type *input_array = (fann_type *)memalloc(input_nb * sizeof(fann_type));

  (void)verbose; /* not used by this decoder */

  /* NOTE(review): meaning of 1000 and 0 is defined by config_initial() — confirm */
  c = config_initial(f, mcd_struct, 1000, 0);
  while(queue_read_sentence(c->bf, f, mcd_struct)){
    while(!config_is_terminal(c)){
      config2feat_vec_fann(fm, c, fv, LOOKUP_MODE);
      feat_vec_fill_input_array_dnn(input_array, fv, fm, mcd_struct);
      output_array = fann_run(ann, input_array);

      /* argmax over the network outputs; element 0 is the initial candidate */
      argmax = 0;
      max = output_array[0];
      for(i = 1; i < output_nb; i++){
        if(output_array[i] > max){
          max = output_array[i];
          argmax = i;
        }
      }
      mvt_type = movement_type(argmax);
      mvt_label = movement_label(argmax);

      /* an arc movement that cannot be applied falls back to a shift */
      if(mvt_type == MVT_LEFT)
        if(movement_left_arc(c, mvt_label, max))
          continue;
      if(mvt_type == MVT_RIGHT)
        if(movement_right_arc(c, mvt_label, max))
          continue;
      movement_shift(c, 0, max);
    }

    /* create a new empty sentence, fill it, print it and destroy it */
    s = sentence_new(mcd_struct, f);
    config_connect_subtrees(c, root_label);
    sentence_depset_update(s, c->ds);
    sentence_print_conll07(stdout, s);
    sentence_free(s);

    /* next sentence starts from a fresh configuration */
    config_free(c);
    c = config_initial(f, mcd_struct, 1000, 0);
  }

  /* the configuration prepared for a sentence that never came, plus the
     work buffers, used to leak here */
  config_free(c);
  free(input_array); /* assumes memalloc() is a malloc() wrapper */
  feat_vec_free(fv);
}