Select Git revision
analyzeResults.py
-
Baptiste Bauvin authoredBaptiste Bauvin authored
maca_trans_parser_nn.cc 8.77 KiB
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<unistd.h>
#include<getopt.h>
#include"context.h"
#include"movement_parser.h"
#include"oracle_parser_arc_eager.h"
#include"feat_fct.h"
#include"feature_table.h"
#include"dico.h"
#include "keras.h"
#include"movement_parser_arc_eager.h"
#include"feat_fct.h"
#include"feature_table.h"
/* Prints the usage message of the program by chaining the help printers */
/* of every option group this program relies on. */
/* ctx : context handed over to each context_*_help_message printer */
void maca_trans_parser_nn_help_message(context *ctx)
{
  /* option groups printed before the INPUT heading */
  context_general_help_message(ctx);
  context_debug_help_message(ctx);

  /* option groups printed under the INPUT heading (heading goes to stderr) */
  fputs("INPUT\n", stderr);
  context_input_help_message(ctx);
  context_mcd_help_message(ctx);
  context_vocabs_help_message(ctx);
  context_features_model_help_message(ctx);
  context_root_label_help_message(ctx);
  context_json_help_message(ctx);
  context_dnn_model_help_message(ctx);
}
/* Validates the command line options: when help was requested, prints the */
/* help message and terminates the program with exit status 1. */
/* The checks on the conll, perc_model, mcd, vocabs and features_model */
/* filenames are currently disabled, so no filename is mandatory here. */
void maca_trans_parser_nn_check_options(context *ctx){
  if(!ctx->help)
    return;

  maca_trans_parser_nn_help_message(ctx);
  exit(1);
}
/* Returns a newly allocated string made of dir immediately followed by file. */
/* The buffer is sized from the actual lengths, so arbitrarily long paths are safe. */
/* A NULL dir is treated as an empty prefix. */
/* Exits with status 1 when memory cannot be allocated. */
/* The caller owns the returned string. */
static char *parser_nn_default_filename(const char *dir, const char *file)
{
  const char *prefix = (dir != NULL)? dir : "";
  size_t prefix_len = strlen(prefix);
  size_t file_len = strlen(file);
  char *path = (char *)malloc(prefix_len + file_len + 1);

  if(path == NULL){
    fprintf(stderr, "set_linguistic_resources_filenames_parser: out of memory\n");
    exit(1);
  }
  memcpy(path, prefix, prefix_len);
  memcpy(path + prefix_len, file, file_len + 1); /* + 1 copies the terminating '\0' */
  return path;
}

/* Gives a default value to every resource filename of ctx that was not set: */
/* dnn model, json description, vocabularies and features model. */
/* Each default is ctx->maca_data_path followed by the matching DEFAULT_*_PARSER_NN_FILENAME. */
/* Filenames that are already set are left untouched. */
/* In verbose mode, the resulting filenames are printed on stderr. */
void set_linguistic_resources_filenames_parser(context *ctx)
{
  if(!ctx->dnn_model_filename)
    ctx->dnn_model_filename = parser_nn_default_filename(ctx->maca_data_path, DEFAULT_MODEL_PARSER_NN_FILENAME);

  if(!ctx->json_filename)
    ctx->json_filename = parser_nn_default_filename(ctx->maca_data_path, DEFAULT_JSON_PARSER_NN_FILENAME);

  if(!ctx->vocabs_filename)
    ctx->vocabs_filename = parser_nn_default_filename(ctx->maca_data_path, DEFAULT_VOCABS_PARSER_NN_FILENAME);

  if(!ctx->features_model_filename)
    ctx->features_model_filename = parser_nn_default_filename(ctx->maca_data_path, DEFAULT_FEATURES_MODEL_PARSER_NN_FILENAME);

  if(ctx->verbose){
    fprintf(stderr, "dnn_model = %s\n", ctx->dnn_model_filename);
    fprintf(stderr, "json = %s\n", ctx->json_filename);
    fprintf(stderr, "vocabs %s\n", ctx->vocabs_filename);
    /* mcd_filename receives no default here and may be NULL: never hand NULL to %s */
    fprintf(stderr, "mcd = %s\n", (ctx->mcd_filename != NULL)? ctx->mcd_filename : "(null)");
    fprintf(stderr, "features_model = %s\n", ctx->features_model_filename);
  }
}
void print_word_buffer(config *c, dico *dico_labels, mcd *mcd_struct)
{
int i;
word *w;
char *label;
char *buffer = NULL;
char *token = NULL;
int col_nb = 0;
for(i=0; i < config_get_buffer(c)->nbelem; i++){
w = word_buffer_get_word_n(config_get_buffer(c), i);
if((mcd_get_gov_col(mcd_struct) == -1)
&& (mcd_get_label_col(mcd_struct) == -1)
&& (mcd_get_sent_seg_col(mcd_struct) == -1)){
printf("%s\t", word_get_input(w));
printf("%d\t", word_get_gov(w));
label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
if(label != NULL)
printf("%s\t", label) ;
else
printf("_\t");
if(word_get_sent_seg(w) == 1)
printf("1\n") ;
else
printf("0\n");
}
else{
buffer = strdup(w->input);
token = strtok(buffer, "\t");
col_nb = 0;
while(token){
if(col_nb != 0) printf("\t");
if(col_nb == mcd_get_gov_col(mcd_struct)){
printf("%d", word_get_gov(w));
}
else
if(col_nb == mcd_get_label_col(mcd_struct)){
label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
if(label != NULL)
printf("%s", label) ;
else
printf("_");
}
else
if(col_nb == mcd_get_sent_seg_col(mcd_struct)){
if(word_get_sent_seg(w) == 1)
printf("1") ;
else
printf("0");
}
else{
word_print_col_n(stdout, w, col_nb);
}
col_nb++;
token = strtok(NULL, "\t");
}
if((col_nb <= mcd_get_gov_col(mcd_struct)) || (mcd_get_gov_col(mcd_struct) == -1)){
printf("\t%d", word_get_gov(w));
}
if((col_nb <= mcd_get_label_col(mcd_struct)) || (mcd_get_label_col(mcd_struct) == -1)){
label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
if(label != NULL)
printf("\t%s", label) ;
else
printf("\t_");
}
if((col_nb <= mcd_get_sent_seg_col(mcd_struct)) || (mcd_get_sent_seg_col(mcd_struct) == -1)){
if(word_get_sent_seg(w) == 1)
printf("\t1") ;
else
printf("\t0");
}
printf("\n");
free(buffer);
}
}
}
/* Builds the input of the network from configuration c: one 1x1 matrix per */
/* feature descriptor of fm, holding the value of the feature plus one. */
/* fm must be exclusively composed of simple features; for a complex feature */
/* only its first simple feature is taken into account. */
std::vector<Matrix<float> > config2keras_vec(feat_model *fm, config *c)
{
  std::vector<Matrix<float> > inputs(fm->nbelem, Matrix<float>(1, 1));

  for(int idx = 0; idx < fm->nbelem; ++idx){
    feat_desc *desc = fm->array[idx];
    const int value = desc->array[0]->fct(c);
    inputs[idx][0][0] = value + 1;
  }
  return inputs;
}
/* Greedy decoding loop: reads words from ctx->input_filename (stdin when it is */
/* not set), asks model for the movement to perform in every configuration, */
/* performs it, and finally prints the word buffer on stdout. */
/* ctx   : context giving the input, the features model, the label dictionary, */
/*         the root label and the debug flag */
/* model : network whose first output is searched for its argmax, used as movement code */
void simple_decoder_parser_arc_eager_nn(context *ctx, Model &model)
{
  FILE *input = stdin;
  if(ctx->input_filename)
    input = myfopen(ctx->input_filename, "r");

  /* unknown root label names fall back to label code 0 */
  int root_label = dico_string2int(ctx->dico_labels, ctx->root_label);
  if(root_label == -1)
    root_label = 0;

  int parsed_sentences = 0;
  std::vector<Matrix<float> > net_input;
  config *c = config_new(input, ctx->mcd_struct, 5);

  while(!config_is_terminal(c)){
    if(ctx->debug_mode){
      fprintf(stdout, "***********************************\n");
      config_print(stdout, c);
    }

    /* forced EOS: the top of the stack is marked eos although the preceding movement */
    /* is not MVT_PARSER_EOS, i.e. the eos status comes from the input. */
    /* Finish the sentence: perform all pending reduce actions, then determine its root. */
    if((word_get_sent_seg(stack_top(config_get_stack(c))) == 1) && (mvt_get_type(mvt_stack_top(config_get_history(c))) != MVT_PARSER_EOS)){
      word_set_sent_seg(stack_top(config_get_stack(c)), -1);
      movement_parser_eos(c);
      while(movement_parser_reduce(c)){
      }
      while(movement_parser_root(c, root_label)){
      }
      if(ctx->debug_mode) printf("force EOS\n");
      continue;
    }

    /* normal behaviour: the classifier chooses the next movement */
    net_input = config2keras_vec(ctx->features_model, c);
    std::vector<Matrix<float> > scores = model.forward(net_input);
    Matrix<float> best = scores[0].argmax();
    const int mvt_code = best.at(0, 0);
    const int mvt_type = movement_parser_type(mvt_code);
    const int mvt_label = movement_parser_label(mvt_code);

    /* stays 0 when the movement type is not handled or the movement fails */
    int done = 0;
    switch(mvt_type){
    case MVT_PARSER_LEFT:
      done = movement_parser_left_arc(c, mvt_label);
      break;
    case MVT_PARSER_RIGHT:
      done = movement_parser_right_arc(c, mvt_label);
      break;
    case MVT_PARSER_REDUCE:
      done = movement_parser_reduce(c);
      break;
    case MVT_PARSER_ROOT:
      done = movement_parser_root(c, root_label);
      break;
    case MVT_PARSER_EOS:
      done = movement_parser_eos(c);
      /* progress report on stderr, every second sentence */
      parsed_sentences++;
      if((parsed_sentences % 2) == 0)
        fprintf(stderr, "\rsentence %d", parsed_sentences);
      break;
    case MVT_PARSER_SHIFT:
      done = movement_parser_shift(c);
      break;
    default:
      break;
    }
    if(done != 0)
      continue;

    /* the chosen movement could not be performed: try a SHIFT instead */
    if(ctx->debug_mode) fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n");
    if(movement_parser_shift(c) != 0)
      continue;

    /* SHIFT failed as well, no more words to read: empty the stack */
    if(ctx->debug_mode) fprintf(stdout, "WARNING : cannot exectue a SHIFT emptying stack !\n");
    while(!stack_is_empty(config_get_stack(c)))
      movement_parser_root(c, root_label);
  }

  fprintf(stderr, "\n");
  print_word_buffer(c, ctx->dico_labels, ctx->mcd_struct);
  config_free(c);

  /* stdin is not ours to close */
  if(ctx->input_filename)
    fclose(input);
}
/* Entry point: reads the options, resolves the resource filenames, loads the */
/* network, the features model and the vocabularies, then runs the decoder. */
/* Returns 0 on success and 1 when the "LABEL" dictionary cannot be found. */
int main(int argc, char *argv[])
{
  context *ctx = context_read_options(argc, argv);
  maca_trans_parser_nn_check_options(ctx);
  set_linguistic_resources_filenames_parser(ctx);

  /* resources are loaded in this order: network, features model, vocabularies */
  Model model = Model::load(ctx->json_filename, ctx->dnn_model_filename);
  ctx->features_model = feat_model_read(ctx->features_model_filename, feat_lib_build(), ctx->verbose);
  ctx->vocabs = dico_vec_read(ctx->vocabs_filename, ctx->hash_ratio);
  mcd_link_to_dico(ctx->mcd_struct, ctx->vocabs, ctx->verbose);

  ctx->dico_labels = dico_vec_get_dico(ctx->vocabs, (char *)"LABEL");
  if(ctx->dico_labels != NULL){
    simple_decoder_parser_arc_eager_nn(ctx, model);
    /* NOTE: ctx is not released here, the call to context_free(ctx) is disabled */
    return 0;
  }

  fprintf(stderr, "cannot find label names\n");
  return 1;
}