/* Select Git revision
 * -
 * Benoit Favre authored
 * simple_decoder_parser_arc_eager.c 10.93 KiB
 */
#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<unistd.h>
#include<getopt.h>
#include"context.h"
#include<math.h>
#include"movement_parser_arc_eager.h"
#include"feat_fct.h"
#include"config2feat_vec.h"
#include"feature_table.h"
#include"dico.h"
#include"partial_parser_conditional.h"
#include"confidence_score.h"
/*
 * Legacy buffer printer.
 *
 * For every word currently held in the configuration's word buffer, write one
 * line to stdout: the raw input, a tab, the governor, a tab, the dependency
 * label (or "_" when the word has no label, i.e. label code -1) and a tab.
 * When the MCD declares no sentence-segmentation column (column index -1),
 * a '1' is appended if the word's sent_seg flag equals 1, '0' otherwise.
 */
void print_word_buffer_old(config *c, dico *dico_labels, mcd *mcd_struct)
{
  int n;

  for(n = 0; n < config_get_buffer(c)->nbelem; n++){
    word *cur = word_buffer_get_word_n(config_get_buffer(c), n);
    char *lab = NULL;

    printf("%s\t", word_get_input(cur));
    printf("%d\t", word_get_gov(cur));

    if(word_get_label(cur) != -1)
      lab = dico_int2string(dico_labels, word_get_label(cur));
    printf("%s\t", (lab != NULL) ? lab : "_");

    /* segmentation flag is only emitted when the MCD has no column for it */
    if(mcd_get_sent_seg_col(mcd_struct) == -1)
      putchar((word_get_sent_seg(cur) == 1) ? '1' : '0');

    putchar('\n');
  }
}
void print_word_buffer(config *c, dico *dico_labels, mcd *mcd_struct)
{
int i;
word *w;
char *label;
char *buffer = NULL;
char *token = NULL;
int col_nb = 0;
for(i=0; i < config_get_buffer(c)->nbelem; i++){
w = word_buffer_get_word_n(config_get_buffer(c), i);
if((mcd_get_gov_col(mcd_struct) == -1)
&& (mcd_get_label_col(mcd_struct) == -1)
&& (mcd_get_sent_seg_col(mcd_struct) == -1)){
printf("%s\t", word_get_input(w));
printf("%d\t", word_get_gov(w));
label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
if(label != NULL)
printf("%s\t", label) ;
else
printf("_\t");
if(word_get_sent_seg(w) == 1)
printf("1\n") ;
else
printf("0\n");
}
else{
buffer = strdup(w->input);
token = strtok(buffer, "\t");
col_nb = 0;
while(token){
if(col_nb != 0) printf("\t");
if(col_nb == mcd_get_gov_col(mcd_struct)){
printf("%d", word_get_gov(w));
}
else
if(col_nb == mcd_get_label_col(mcd_struct)){
label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
if(label != NULL)
printf("%s", label) ;
else
printf("_");
}
else
if(col_nb == mcd_get_sent_seg_col(mcd_struct)){
if(word_get_sent_seg(w) == 1)
printf("1") ;
else
printf("0");
}
else{
word_print_col_n(stdout, w, col_nb);
}
col_nb++;
token = strtok(NULL, "\t");
}
if((col_nb <= mcd_get_gov_col(mcd_struct)) || (mcd_get_gov_col(mcd_struct) == -1)){
printf("\t%d", word_get_gov(w));
}
if((col_nb <= mcd_get_label_col(mcd_struct)) || (mcd_get_label_col(mcd_struct) == -1)){
label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
if(label != NULL)
printf("\t%s", label) ;
else
printf("\t_");
}
if((col_nb <= mcd_get_sent_seg_col(mcd_struct)) || (mcd_get_sent_seg_col(mcd_struct) == -1)){
if(word_get_sent_seg(w) == 1)
printf("\t1") ;
else
printf("\t0");
}
if(col_nb <= mcd_get_s_col(mcd_struct)){
if(word_get_S(w) > 0)
printf("\t%d",word_get_S(w));
else
printf("\t-1");
}
if(col_nb <= mcd_get_t_col(mcd_struct)){
if(word_get_T(w) > 0)
printf("\t%d",word_get_T(w));
else
printf("\t-1");
}
printf("\n");
free(buffer);
}
}
}
void simple_decoder_parser_arc_eager(context *ctx)
{
FILE *f = (ctx->input_filename)? myfopen(ctx->input_filename, "r") : stdin;
feature_table *ft = feature_table_load(ctx->perc_model_filename, ctx->verbose);
int root_label;
int mvt_code;
int mvt_type;
int mvt_label;
float max;
feat_vec *fv = feat_vec_new(feature_types_nb);
config *c = NULL;
int result;
/* float entropy; */
/* float delta; */
int argmax1, argmax2;
float max1, max2;
int index;
float score;
double sumExp = -1.;
double currentSumExp = 0.;
double ScoreTranslation = -150.;
word* word_scored;
root_label = dico_string2int(ctx->dico_labels, ctx->root_label);
if(root_label == -1) root_label = 0;
c = config_new(f, ctx->mcd_struct, 5);
while(!config_is_terminal(c)){
if(ctx->debug_mode){
fprintf(stdout, "***********************************\n");
config_print(stdout, c);
}
/* forced EOS (the element on the top of the stack is eos, but the preceding movement is not MVT_PARSER_EOS */
/* which means that the top of the stack got its eos status from input */
/* force the parser to finish parsing the sentence (perform all pending reduce actions) and determine root of the sentence */
if((word_get_sent_seg(stack_top(config_get_stack(c))) == 1) && (mvt_get_type(mvt_stack_top(config_get_history(c))) != MVT_PARSER_EOS)){
word_set_sent_seg(stack_top(config_get_stack(c)), -1);
movement_parser_eos(c);
while(movement_parser_reduce(c));
while(movement_parser_root(c, root_label));
if(ctx->debug_mode) printf("force EOS\n");
}
/* normal behaviour, ask classifier what is the next movement to do and do it */
else{
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
mvt_code = feature_table_argmax(fv, ft, &max);
if(ctx->debug_mode){
vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
/* Get the probabilistic parameters */
for(int i=0; i < ft->classes_nb; i++){
int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
if(b1 && b2 && b3){
if(sumExp < 0.){
ScoreTranslation += vcode_array[i].score;
sumExp = 0.;
}
if(vcode_array[i].score - ScoreTranslation >= 0){
sumExp += exp(vcode_array[i].score - ScoreTranslation);
}
}
}
currentSumExp = 0.;
for(int i=0; i < ft->classes_nb && i < 1000; i++){
printf("%d\t", i);
movement_parser_print(stdout, vcode_array[i].class_code, ctx->dico_labels);
printf("\t%.4f", vcode_array[i].score);
fflush(stdout);
int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
if(b1 && b2 && b3){
if(vcode_array[i].score - ScoreTranslation >= 0){
printf("%f %f %f",sumExp, currentSumExp,ScoreTranslation);
printf(" [%f-",currentSumExp/sumExp);
currentSumExp += exp(vcode_array[i].score - ScoreTranslation);
printf("%f[", currentSumExp/sumExp);
}
printf("\t<----");
}else
printf("\t<%d,%d,%d>",b1,b2,b3);
// printf("\t%d", respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)));
//printf("AAAAAAA\n");
printf("\n");
}
free(vcode_array);
}
if(ctx->trace_mode){
index = word_get_index(word_buffer_b0(config_get_buffer(c)));
fprintf(stdout, "%d\t", index);
stack_print(stdout, c->st);
fprintf(stdout, "\t");
movement_parser_print(stdout, mvt_code, ctx->dico_labels);
fprintf(stdout, "\t");
feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
printf("%f\n", max1 - max2);
}
if(ctx->partial_mode){
vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
mvt_code = 0;
for(int i=0; i < ft->classes_nb; i++){
int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
if(b1 && b2 && b3){
mvt_code = vcode_array[i].class_code;
break;
}
}
free(vcode_array);
}
mvt_type = movement_parser_type(mvt_code);
mvt_label = movement_parser_label(mvt_code);
if((mvt_type == MVT_PARSER_EOS) && (word_get_sent_seg(stack_top(config_get_stack(c))) == 0)){
if(ctx->verbose)
fprintf(stderr, "the classifier did predict EOS but this is not the case\n");
feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
mvt_code = argmax2;
mvt_type = movement_parser_type(mvt_code);
mvt_label = movement_parser_label(mvt_code);
}
result = 0;
switch(mvt_type){
case MVT_PARSER_LEFT :
word_scored = stack_top(config_get_stack(c));
result = movement_parser_left_arc(c, mvt_label);
break;
case MVT_PARSER_RIGHT:
word_scored = word_buffer_b0(config_get_buffer(c));
result = movement_parser_right_arc(c, mvt_label);
break;
case MVT_PARSER_REDUCE:
word_scored = stack_top(config_get_stack(c));
result = movement_parser_reduce(c);
break;
case MVT_PARSER_ROOT:
word_scored = stack_top(config_get_stack(c));
result = movement_parser_root(c, root_label);
break;
case MVT_PARSER_EOS:
result = movement_parser_eos(c);
break;
case MVT_PARSER_SHIFT:
word_scored = word_buffer_b0(config_get_buffer(c));
result = movement_parser_shift(c);
}
if(result == 0){
if(ctx->debug_mode) fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n");
result = movement_parser_shift(c);
if(result == 0){ /* SHIFT failed no more words to read, let's get out of here ! */
if(ctx->debug_mode) fprintf(stdout, "WARNING : cannot exectue a SHIFT emptying stack !\n");
while(!stack_is_empty(config_get_stack(c)))
movement_parser_root(c, root_label);
}
}else{
if(ctx->score_method > 0){
score = confidence_score(mvt_code,feature_table_get_vcode_array(fv,ft),ft->classes_nb,ctx,c);
switch(mvt_type){
case MVT_PARSER_LEFT :
case MVT_PARSER_RIGHT :
case MVT_PARSER_ROOT :
// printf("dep score: %d %d!!\n", word_get_form(word_scored), (int)(score*1000));
word_set_S(word_scored,(int)(score*1000));
break;
case MVT_PARSER_REDUCE:
case MVT_PARSER_SHIFT:
// printf("pop/shift score: %d %d!!\n", word_get_form(word_scored), (int)(score*1000));
word_set_T(word_scored,(int)(score*1000));
break;
default:
break;
}
}
}
}
}
if(!ctx->trace_mode)
print_word_buffer(c, ctx->dico_labels, ctx->mcd_struct);
config_free(c);
feat_vec_free(fv);
feature_table_free(ft);
if(ctx->input_filename)
fclose(f);
}