Skip to content
Snippets Groups Projects
Select Git revision
  • 7cc3e7f13cc7892286ba214f55a1c41a578df512
  • master default protected
  • erased
  • states
  • negatives
  • temp
  • negativeExamples
  • Rl
8 results

Util.py

Blame
  • simple_decoder_parser_arc_eager.c 10.93 KiB
    #include<stdio.h>
    #include<stdlib.h>
    #include<string.h>
    #include<unistd.h>
    #include<getopt.h>
    #include"context.h"
    #include<math.h>
    #include"movement_parser_arc_eager.h"
    #include"feat_fct.h"
    #include"config2feat_vec.h"
    #include"feature_table.h"
    #include"dico.h"
    #include"partial_parser_conditional.h"
    #include"confidence_score.h"
    
    /*
     * Legacy printer: writes one line on stdout for every word stored in the
     * buffer of configuration c.
     *
     * Each line holds, separated by tabulations: the input string of the word,
     * its governor, and its label looked up in dico_labels ("_" when the label
     * code is -1 or when the lookup yields NULL). When mcd_struct declares no
     * sentence segmentation column (-1), "1" or "0" is appended depending on
     * whether the sent_seg flag of the word equals 1.
     */
    void print_word_buffer_old(config *c, dico *dico_labels, mcd *mcd_struct)
    {
      int pos = 0;

      while(pos < config_get_buffer(c)->nbelem){
        word *current = word_buffer_get_word_n(config_get_buffer(c), pos);
        char *label_str;

        printf("%s\t%d\t", word_get_input(current), word_get_gov(current));

        /* a label code of -1 means "no label" */
        if(word_get_label(current) == -1)
          label_str = NULL;
        else
          label_str = dico_int2string(dico_labels, word_get_label(current));
        printf("%s\t", (label_str != NULL)? label_str : "_");

        /* the segmentation flag is only printed when the input had no such column */
        if(mcd_get_sent_seg_col(mcd_struct) == -1)
          printf("%s", (word_get_sent_seg(current) == 1)? "1" : "0");

        printf("\n");
        pos++;
      }
    }
    
    /*
     * Print every word of the buffer of configuration c on stdout, one word per
     * line, merging the parser's decisions into the original input columns.
     *
     * c           configuration whose word buffer is printed
     * dico_labels dictionary used to turn a label code into its string
     * mcd_struct  column description of the input
     *
     * Two layouts are produced:
     *  - when mcd_struct declares neither a governor, nor a label, nor a sentence
     *    segmentation column, the raw input string is printed followed by
     *    governor, label and segmentation flag;
     *  - otherwise the input columns are re-emitted one by one, the governor,
     *    label and segmentation columns being replaced by the parser's values.
     *    Values whose column is absent from the input (column == -1) or lies
     *    beyond the last input column are appended at the end of the line, as
     *    are the S and T scores when their columns lie beyond the last input
     *    column.
     *
     * A missing label (code -1, or no string found in dico_labels) is printed
     * as "_".
     *
     * The input columns are counted by scanning w->input in place with
     * strspn()/strcspn(). This yields the same field segmentation as
     * strtok(copy, "\t") (runs of tabulations count as one separator, leading
     * tabulations are skipped) without duplicating the string, without relying
     * on strtok()'s hidden static state while word_print_col_n() runs, and
     * without dereferencing a failed strdup(). A NULL input string is treated
     * as a line with no input column.
     */
    void print_word_buffer(config *c, dico *dico_labels, mcd *mcd_struct)
    {
      int i;
      word *w;
      char *label;
      const char *label_text;
      const char *cursor;
      int col_nb = 0;

      for(i=0; i < config_get_buffer(c)->nbelem; i++){
        w = word_buffer_get_word_n(config_get_buffer(c), i);

        /* the label string is needed by both layouts: resolve it once per word */
        label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
        label_text = (label != NULL)? label : "_";

        if((mcd_get_gov_col(mcd_struct) == -1)
           && (mcd_get_label_col(mcd_struct) == -1)
           && (mcd_get_sent_seg_col(mcd_struct) == -1)){
          /* no output column exists in the input: dump input then the three values */
          printf("%s\t", word_get_input(w));
          printf("%d\t", word_get_gov(w));
          printf("%s\t", label_text);
          if(word_get_sent_seg(w) == 1)
            printf("1\n");
          else
            printf("0\n");
          continue;
        }

        /* walk through the tab separated fields of the input, one per iteration */
        col_nb = 0;
        cursor = w->input;
        if(cursor != NULL)
          cursor += strspn(cursor, "\t");
        while((cursor != NULL) && (*cursor != '\0')){
          if(col_nb != 0) printf("\t");

          if(col_nb == mcd_get_gov_col(mcd_struct)){
            printf("%d", word_get_gov(w));
          }
          else if(col_nb == mcd_get_label_col(mcd_struct)){
            printf("%s", label_text);
          }
          else if(col_nb == mcd_get_sent_seg_col(mcd_struct)){
            if(word_get_sent_seg(w) == 1)
              printf("1");
            else
              printf("0");
          }
          else{
            word_print_col_n(stdout, w, col_nb);
          }

          col_nb++;
          cursor += strcspn(cursor, "\t");  /* skip the field just handled */
          cursor += strspn(cursor, "\t");   /* skip the separator(s) after it */
        }

        /* append the values that could not be written in place */
        if((col_nb <= mcd_get_gov_col(mcd_struct)) || (mcd_get_gov_col(mcd_struct) == -1)){
          printf("\t%d", word_get_gov(w));
        }
        if((col_nb <= mcd_get_label_col(mcd_struct)) || (mcd_get_label_col(mcd_struct) == -1)){
          printf("\t%s", label_text);
        }
        if((col_nb <= mcd_get_sent_seg_col(mcd_struct)) || (mcd_get_sent_seg_col(mcd_struct) == -1)){
          if(word_get_sent_seg(w) == 1)
            printf("\t1");
          else
            printf("\t0");
        }
        /* S and T scores: values that are not strictly positive are printed as -1 */
        if(col_nb <= mcd_get_s_col(mcd_struct)){
          if(word_get_S(w) > 0)
            printf("\t%d", word_get_S(w));
          else
            printf("\t-1");
        }
        if(col_nb <= mcd_get_t_col(mcd_struct)){
          if(word_get_T(w) > 0)
            printf("\t%d", word_get_T(w));
          else
            printf("\t-1");
        }
        printf("\n");
      }
    }
    
    
    void simple_decoder_parser_arc_eager(context *ctx)
    {
      FILE *f = (ctx->input_filename)? myfopen(ctx->input_filename, "r") : stdin;
      feature_table *ft = feature_table_load(ctx->perc_model_filename, ctx->verbose);
      int root_label;
      int mvt_code;
      int mvt_type;
      int mvt_label;
      float max;
      feat_vec *fv = feat_vec_new(feature_types_nb);
      config *c = NULL;
      int result;
      /* float entropy; */
      /* float delta; */
      int argmax1, argmax2;
      float max1, max2;
      int index;
      float score;
      
      double sumExp = -1.;
      double currentSumExp = 0.;
      double ScoreTranslation = -150.;
      
      word* word_scored;
      
      root_label = dico_string2int(ctx->dico_labels, ctx->root_label);
      if(root_label == -1) root_label = 0;
      
      c = config_new(f, ctx->mcd_struct, 5);
      while(!config_is_terminal(c)){
    
        if(ctx->debug_mode){
          fprintf(stdout, "***********************************\n");
          config_print(stdout, c);      
        }	
        /* forced EOS (the element on the top of the stack is eos, but the preceding movement is not MVT_PARSER_EOS */
        /* which means that the top of the stack got its eos status from input */
        /* force the parser to finish parsing the sentence (perform all pending reduce actions) and determine root of the sentence */ 
    
        if((word_get_sent_seg(stack_top(config_get_stack(c))) == 1) && (mvt_get_type(mvt_stack_top(config_get_history(c))) != MVT_PARSER_EOS)){
          word_set_sent_seg(stack_top(config_get_stack(c)), -1);
          movement_parser_eos(c);
          while(movement_parser_reduce(c));
          while(movement_parser_root(c, root_label));
          if(ctx->debug_mode) printf("force EOS\n");
        }
    
        /* normal behaviour, ask classifier what is the next movement to do and do it */
        else{
          config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
          mvt_code = feature_table_argmax(fv, ft, &max);
    
          if(ctx->debug_mode){
    	vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
    	
    	/* Get the probabilistic parameters */
    	for(int i=0; i < ft->classes_nb; i++){
    	  int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    	  int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    	  int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    
    	  if(b1 && b2 && b3){
          if(sumExp < 0.){
            ScoreTranslation += vcode_array[i].score;
            sumExp = 0.;
          }
          if(vcode_array[i].score - ScoreTranslation >= 0){
            sumExp += exp(vcode_array[i].score - ScoreTranslation);
          }	  
    	  }
    	}
      
      currentSumExp = 0.;
    	for(int i=0; i < ft->classes_nb && i < 1000; i++){
    	  printf("%d\t", i);
    	  movement_parser_print(stdout, vcode_array[i].class_code, ctx->dico_labels);
    	  printf("\t%.4f", vcode_array[i].score);
    	  fflush(stdout);
    	  int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    	  int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    	  int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    	  if(b1 && b2 && b3){
          if(vcode_array[i].score - ScoreTranslation >= 0){
            printf("%f %f %f",sumExp, currentSumExp,ScoreTranslation);
            printf(" [%f-",currentSumExp/sumExp);
            currentSumExp += exp(vcode_array[i].score - ScoreTranslation);
            printf("%f[", currentSumExp/sumExp);
          }
          printf("\t<----");
    	  }else
    		printf("\t<%d,%d,%d>",b1,b2,b3);
    	  // printf("\t%d", respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)));
    	  //printf("AAAAAAA\n");
    	  printf("\n");
    	}
    	free(vcode_array);
          }
          
          if(ctx->trace_mode){
    	index = word_get_index(word_buffer_b0(config_get_buffer(c)));
    	fprintf(stdout, "%d\t", index);
    	
    	stack_print(stdout, c->st);
    	fprintf(stdout, "\t");
    	
    	movement_parser_print(stdout, mvt_code, ctx->dico_labels);        
    	fprintf(stdout, "\t");
    	feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
    	printf("%f\n", max1 - max2);
          }
    
          if(ctx->partial_mode){
    	vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
            mvt_code = 0;
    	for(int i=0; i < ft->classes_nb; i++){
              int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
              int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
              int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
    
              if(b1 && b2 && b3){
                mvt_code = vcode_array[i].class_code;
                break;
              }
    	}
    	free(vcode_array);
          }
    
         
          mvt_type = movement_parser_type(mvt_code);
          mvt_label = movement_parser_label(mvt_code);
    
          if((mvt_type == MVT_PARSER_EOS) && (word_get_sent_seg(stack_top(config_get_stack(c))) == 0)){
    	if(ctx->verbose)
    	  fprintf(stderr, "the classifier did predict EOS but this is not the case\n");
    	feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
    	mvt_code = argmax2;
    	mvt_type = movement_parser_type(mvt_code);
    	mvt_label = movement_parser_label(mvt_code);
    	
          }
          
          result = 0;
          switch(mvt_type){
          case MVT_PARSER_LEFT :
    	word_scored = stack_top(config_get_stack(c));
    	result = movement_parser_left_arc(c, mvt_label);
    	break;
          case MVT_PARSER_RIGHT:
    	word_scored = word_buffer_b0(config_get_buffer(c));
    	result = movement_parser_right_arc(c, mvt_label);
    	break;
          case MVT_PARSER_REDUCE:
    	word_scored = stack_top(config_get_stack(c));
    	result = movement_parser_reduce(c);
    	break;
          case MVT_PARSER_ROOT:
    	word_scored = stack_top(config_get_stack(c));
    	result = movement_parser_root(c, root_label);
    	break;
          case MVT_PARSER_EOS:
    	result = movement_parser_eos(c);
    	break;
          case MVT_PARSER_SHIFT:
    	word_scored = word_buffer_b0(config_get_buffer(c));
    	result = movement_parser_shift(c);
          }
          
          if(result == 0){
    	if(ctx->debug_mode) fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n");
    	result = movement_parser_shift(c);
    	if(result == 0){ /* SHIFT failed no more words to read, let's get out of here ! */
    	  if(ctx->debug_mode) fprintf(stdout, "WARNING : cannot exectue a SHIFT emptying stack !\n");
    	  while(!stack_is_empty(config_get_stack(c)))
    	    movement_parser_root(c, root_label);
    	}
          }else{
    	 if(ctx->score_method > 0){
    	   score = confidence_score(mvt_code,feature_table_get_vcode_array(fv,ft),ft->classes_nb,ctx,c);
    	   switch(mvt_type){
    	   case MVT_PARSER_LEFT :
    	   case MVT_PARSER_RIGHT :
    	   case MVT_PARSER_ROOT :
    //	     printf("dep score: %d %d!!\n", word_get_form(word_scored), (int)(score*1000));
    	     word_set_S(word_scored,(int)(score*1000));
    	     break;
    	   case MVT_PARSER_REDUCE:
    	   case MVT_PARSER_SHIFT:
    //	     printf("pop/shift score: %d  %d!!\n", word_get_form(word_scored), (int)(score*1000));
    	     word_set_T(word_scored,(int)(score*1000));
    	     break;
    	   default:
    	     break;
    	   }
    	 }
          }
        }
      }
      
      if(!ctx->trace_mode)
        print_word_buffer(c, ctx->dico_labels, ctx->mcd_struct);
      
      config_free(c); 
      feat_vec_free(fv);
      feature_table_free(ft);
      if(ctx->input_filename)
        fclose(f);
    }