Skip to content
Snippets Groups Projects
Select Git revision
  • f7f725c806a5f00f4e4b67de6122ca50cfe23cf9
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

analyzeResults.py

Blame
  • maca_trans_parser_nn.cc 8.77 KiB
    #include<stdio.h>
    #include<stdlib.h>
    #include<string.h>
    #include<unistd.h>
    #include<getopt.h>
    #include"context.h"
    #include"movement_parser.h"
    #include"oracle_parser_arc_eager.h"
    #include"feat_fct.h"
    #include"feature_table.h"
    #include"dico.h"
    #include "keras.h"
    #include"movement_parser_arc_eager.h"
    #include"feat_fct.h"
    #include"feature_table.h"
    
    
    
    
    /* Print the usage of maca_trans_parser_nn.
       General and debug options come first, followed by an "INPUT" heading
       (written to stderr) and the options describing the resources the parser reads. */
    void maca_trans_parser_nn_help_message(context *ctx)
    {
      /* general and debugging options */
      context_general_help_message(ctx);
      context_debug_help_message(ctx);

      /* input section: data, mcd, vocabularies, feature model, root label, network files */
      fputs("INPUT\n", stderr);
      context_input_help_message(ctx);
      context_mcd_help_message(ctx);
      context_vocabs_help_message(ctx);
      context_features_model_help_message(ctx);
      context_root_label_help_message(ctx);
      context_json_help_message(ctx);
      context_dnn_model_help_message(ctx);
    }
    
    /* Validate command-line options.
       Only an explicit help request stops the program: the help message is
       printed and the process exits with status 1. Former checks requiring the
       conll/model/mcd/vocabs/features-model filenames were disabled in this tool. */
    void maca_trans_parser_nn_check_options(context *ctx)
    {
      if(!ctx->help)
        return;

      maca_trans_parser_nn_help_message(ctx);
      exit(1);
    }
    
    /* Join dir and name (no separator added, exactly like plain concatenation)
       into a freshly malloc'ed string sized to fit. A NULL dir is treated as "".
       Returns NULL if the allocation fails; release the result with free(). */
    static char *nn_make_resource_path(const char *dir, const char *name)
    {
      const char *prefix = (dir != NULL) ? dir : "";
      size_t prefix_len = strlen(prefix);
      size_t name_len = strlen(name);
      char *path = (char *)malloc(prefix_len + name_len + 1);

      if(path == NULL)
        return NULL;
      memcpy(path, prefix, prefix_len);
      memcpy(path + prefix_len, name, name_len + 1); /* includes the terminating '\0' */
      return path;
    }

    /* Return s, or the text "(null)" when s is NULL, so it can be passed to %s safely. */
    static const char *nn_str_or_null(const char *s)
    {
      return (s != NULL) ? s : "(null)";
    }

    /* Fill in default paths, built as ctx->maca_data_path followed by the
       DEFAULT_*_PARSER_NN_FILENAME constants, for the network weights, the network
       JSON description, the vocabularies and the feature model whenever they were
       not given explicitly. When ctx->verbose is set, the resulting filenames
       (plus the mcd filename) are reported on stderr.
       Paths are allocated to their exact length, so long data paths can no longer
       overflow a fixed-size buffer. */
    void set_linguistic_resources_filenames_parser(context *ctx)
    {
      if(!ctx->dnn_model_filename)
        ctx->dnn_model_filename = nn_make_resource_path(ctx->maca_data_path, DEFAULT_MODEL_PARSER_NN_FILENAME);

      if(!ctx->json_filename)
        ctx->json_filename = nn_make_resource_path(ctx->maca_data_path, DEFAULT_JSON_PARSER_NN_FILENAME);

      if(!ctx->vocabs_filename)
        ctx->vocabs_filename = nn_make_resource_path(ctx->maca_data_path, DEFAULT_VOCABS_PARSER_NN_FILENAME);

      if(!ctx->features_model_filename)
        ctx->features_model_filename = nn_make_resource_path(ctx->maca_data_path, DEFAULT_FEATURES_MODEL_PARSER_NN_FILENAME);

      if(ctx->verbose){
        /* mcd_filename is never defaulted here and may be NULL; passing NULL to %s is undefined */
        fprintf(stderr, "dnn_model = %s\n",      nn_str_or_null(ctx->dnn_model_filename));
        fprintf(stderr, "json = %s\n",           nn_str_or_null(ctx->json_filename));
        fprintf(stderr, "vocabs %s\n",           nn_str_or_null(ctx->vocabs_filename));
        fprintf(stderr, "mcd = %s\n",            nn_str_or_null(ctx->mcd_filename));
        fprintf(stderr, "features_model = %s\n", nn_str_or_null(ctx->features_model_filename));
      }
    }
    
    
    
    void print_word_buffer(config *c, dico *dico_labels, mcd *mcd_struct)
    {
      int i;
      word *w;
      char *label;
      char *buffer = NULL;
      char *token = NULL;
      int col_nb = 0;
    
      
      for(i=0; i < config_get_buffer(c)->nbelem; i++){
        w = word_buffer_get_word_n(config_get_buffer(c), i);
    
        if((mcd_get_gov_col(mcd_struct) == -1)
           && (mcd_get_label_col(mcd_struct) == -1)
           && (mcd_get_sent_seg_col(mcd_struct) == -1)){
          printf("%s\t", word_get_input(w));
          printf("%d\t", word_get_gov(w));
          label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
          if(label != NULL)
    	printf("%s\t", label) ;
          else
    	printf("_\t");
          if(word_get_sent_seg(w) == 1)
    	printf("1\n") ;
          else
    	printf("0\n");
        }
        else{
          buffer = strdup(w->input);
          token = strtok(buffer, "\t");
          col_nb = 0;
          while(token){
    	if(col_nb != 0) printf("\t");
    	if(col_nb == mcd_get_gov_col(mcd_struct)){
    	  printf("%d", word_get_gov(w));
    	}
    	else
    	  if(col_nb == mcd_get_label_col(mcd_struct)){
    	    label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
    	    if(label != NULL)
    	      printf("%s", label) ;
    	    else
    	      printf("_");
    	  }
    	  else
    	    if(col_nb == mcd_get_sent_seg_col(mcd_struct)){
    	      if(word_get_sent_seg(w) == 1)
    		printf("1") ;
    	      else
    		printf("0");
    	    }
    	    else{
    	      word_print_col_n(stdout, w, col_nb);
    	    }
    	col_nb++;
    	token = strtok(NULL, "\t");
          }
          if((col_nb <= mcd_get_gov_col(mcd_struct)) || (mcd_get_gov_col(mcd_struct) == -1)){
    	printf("\t%d", word_get_gov(w));
          }
          if((col_nb <= mcd_get_label_col(mcd_struct)) || (mcd_get_label_col(mcd_struct) == -1)){
    	label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
    	if(label != NULL)
    	  printf("\t%s", label) ;
    	else
    	  printf("\t_");
          }
          if((col_nb <= mcd_get_sent_seg_col(mcd_struct)) || (mcd_get_sent_seg_col(mcd_struct) == -1)){
    	if(word_get_sent_seg(w) == 1)
    	  printf("\t1") ;
    	else
    	  printf("\t0");
          }
          printf("\n");
          free(buffer);
        }
      }
    }
    
    /* Encode configuration c as the network input: one 1x1 matrix per entry of
       the feature model, holding the feature value plus one.
       fm is expected to contain only simple features; for a complex feature only
       its first component is evaluated. */
    std::vector<Matrix<float> > config2keras_vec(feat_model *fm, config *c)
    {
      const int nb_feats = fm->nbelem;
      std::vector<Matrix<float> > inputs(nb_feats, Matrix<float>(1, 1));

      for(int k = 0; k < nb_feats; k++){
        feat_desc *desc = fm->array[k];
        int value = desc->array[0]->fct(c);
        /* shift by one (a value of -1 becomes 0) — presumably 0 is the network's reserved index; confirm against training code */
        inputs[k][0][0] = value + 1;
      }
      return inputs;
    }
    
    void simple_decoder_parser_arc_eager_nn(context *ctx, Model &model)
    {
      FILE *f = (ctx->input_filename)? myfopen(ctx->input_filename, "r") : stdin;
      //  feature_table *ft = feature_table_load(ctx->perc_model_filename, ctx->verbose);
      int root_label;
      int mvt_code;
      int mvt_type;
      int mvt_label;
      config *c = NULL;
      int result;
      std::vector<Matrix<float> > keras_vec;
      int sentence_nb = 0;  
      root_label = dico_string2int(ctx->dico_labels, ctx->root_label);
      if(root_label == -1) root_label = 0;
      
      c = config_new(f, ctx->mcd_struct, 5);
      while(!config_is_terminal(c)){
    
        if(ctx->debug_mode){
          fprintf(stdout, "***********************************\n");
          config_print(stdout, c);      
        }	
        /* forced EOS (the element on the top of the stack is eos, but the preceding movement is not MVT_PARSER_EOS */
        /* which means that the top of the stack got its eos status from input */
        /* force the parser to finish parsing the sentence (perform all pending reduce actions) and determine root of the sentence */ 
    
        if((word_get_sent_seg(stack_top(config_get_stack(c))) == 1) && (mvt_get_type(mvt_stack_top(config_get_history(c))) != MVT_PARSER_EOS)){
          word_set_sent_seg(stack_top(config_get_stack(c)), -1);
          movement_parser_eos(c);
          while(movement_parser_reduce(c));
          while(movement_parser_root(c, root_label));
          if(ctx->debug_mode) printf("force EOS\n");
        }
    
        /* normal behaviour, ask classifier what is the next movement to do and do it */
        else{
          keras_vec = config2keras_vec(ctx->features_model, c);
          std::vector<Matrix<float> > y = model.forward(keras_vec);
          Matrix<float> argmax = y[0].argmax();
          mvt_code = argmax.at(0, 0);
    
          //      fprintf(stderr,"mvt code = %d\n", mvt_code);
          mvt_type = movement_parser_type(mvt_code);
          mvt_label = movement_parser_label(mvt_code);
                
          result = 0;
          switch(mvt_type){
          case MVT_PARSER_LEFT :
    	result = movement_parser_left_arc(c, mvt_label);
    	break;
          case MVT_PARSER_RIGHT:
    	result = movement_parser_right_arc(c, mvt_label);
    	break;
          case MVT_PARSER_REDUCE:
    	result = movement_parser_reduce(c);
    	break;
          case MVT_PARSER_ROOT:
    	result = movement_parser_root(c, root_label);
    	break;
          case MVT_PARSER_EOS:
    	result = movement_parser_eos(c);
    	sentence_nb++;
    	if((sentence_nb % 2) == 0)
    	  fprintf(stderr, "\rsentence %d", sentence_nb);
    	break;
          case MVT_PARSER_SHIFT:
    	result = movement_parser_shift(c);
          }
          
          if(result == 0){
    	if(ctx->debug_mode) fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n");
    	result = movement_parser_shift(c);
    	if(result == 0){ /* SHIFT failed no more words to read, let's get out of here ! */
    	  if(ctx->debug_mode) fprintf(stdout, "WARNING : cannot exectue a SHIFT emptying stack !\n");
    	  while(!stack_is_empty(config_get_stack(c)))
    	    movement_parser_root(c, root_label);
    	}
          }
        }
        
      }
      fprintf(stderr, "\n");  
      print_word_buffer(c, ctx->dico_labels, ctx->mcd_struct);
      
      config_free(c); 
      if(ctx->input_filename)
        fclose(f);
    }
    
    
    
    /* Entry point: read options, load the network, the feature model and the
       vocabularies, then parse the input. Returns 1 when no LABEL vocabulary is
       available, 0 otherwise. */
    int main(int argc, char *argv[])
    {
      /* command line (check_options exits after printing help when requested) */
      context *ctx = context_read_options(argc, argv);
      maca_trans_parser_nn_check_options(ctx);

      /* default resource paths, then the network (JSON description + weights file) */
      set_linguistic_resources_filenames_parser(ctx);
      Model model = Model::load(ctx->json_filename, ctx->dnn_model_filename);

      /* feature model and vocabularies, linked to the mcd columns */
      ctx->features_model = feat_model_read(ctx->features_model_filename, feat_lib_build(), ctx->verbose);
      ctx->vocabs = dico_vec_read(ctx->vocabs_filename, ctx->hash_ratio);
      mcd_link_to_dico(ctx->mcd_struct, ctx->vocabs, ctx->verbose);

      /* label names are required to print the predicted labels */
      ctx->dico_labels = dico_vec_get_dico(ctx->vocabs, const_cast<char *>("LABEL"));
      if(!ctx->dico_labels){
        fprintf(stderr, "cannot find label names\n");
        return 1;
      }

      simple_decoder_parser_arc_eager_nn(ctx, model);

      /* NOTE: ctx is not released; the context_free(ctx) call was disabled */
      return 0;
    }