From 89a7f9f3e4e3a1a5e73920da3b5c474c01100636 Mon Sep 17 00:00:00 2001 From: Alexis Nasr <alexis.nasr@lif.univ-mrs.fr> Date: Thu, 13 Oct 2016 22:19:50 -0400 Subject: [PATCH] arc eager almost ready got better results than arc standard --- maca_common/src/word_buffer.c | 5 -- .../src/maca_trans_parser_mcf2cff.c | 90 ++++++------------- .../src/movement_parser_arc_eager.c | 42 +++++++-- .../src/movement_parser_arc_eager.h | 17 ++-- .../src/oracle_parser_arc_eager.c | 49 ++++++---- .../src/oracle_parser_arc_eager.h | 2 +- maca_trans_parser/src/simple_decoder_parser.c | 16 ++-- 7 files changed, 117 insertions(+), 104 deletions(-) diff --git a/maca_common/src/word_buffer.c b/maca_common/src/word_buffer.c index a22ae9e..61695c8 100644 --- a/maca_common/src/word_buffer.c +++ b/maca_common/src/word_buffer.c @@ -13,15 +13,10 @@ word_buffer *word_buffer_new(FILE *input_file, mcd *mcd_struct, int lookahead) wb->array = (word **)memalloc(wb->size * sizeof(word *)); wb->current_index = 0; - - word_buffer_add(wb, word_new(NULL)); /* add dummy token */ - /* load lookahead next words */ wb->lookahead = lookahead; for(i=0; i <= lookahead; i++) word_buffer_read_next_word(wb); - - word_buffer_move_right(wb); /* pass dummy token */ return wb; } diff --git a/maca_trans_parser/src/maca_trans_parser_mcf2cff.c b/maca_trans_parser/src/maca_trans_parser_mcf2cff.c index 6705393..cb3d6b7 100644 --- a/maca_trans_parser/src/maca_trans_parser_mcf2cff.c +++ b/maca_trans_parser/src/maca_trans_parser_mcf2cff.c @@ -54,9 +54,8 @@ void generate_training_file_stream(FILE *output_file, context *ctx) int eos_label = dico_string2int(ctx->dico_labels, "eos"); word_buffer *ref = word_buffer_load_mcf(ctx->input_filename, ctx->mcd_struct); FILE *mcf_file = myfopen(ctx->input_filename, "r"); - int start_sentence_index = 1; + int start_sentence_index = 0; - /* create an mcd that corresponds to ctx->mcd_struct, but without gov and label */ /* the idea is to ignore syntax in the mcf file that will be read */ /* it 
is ugly !!! */ @@ -65,28 +64,44 @@ void generate_training_file_stream(FILE *output_file, context *ctx) mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_GOV); mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_LABEL); - c = config_initial(mcf_file, mcd_struct_hyp, 5); + c = config_initial_no_dummy_word(mcf_file, mcd_struct_hyp, 5); while(!word_buffer_end(ref)){ /*printf("************ REF ************\n"); word_buffer_print(stdout, ref); - printf("*****************************\n");*/ + printf("*****************************\n");*/ - /* printf("*****************************\n"); */ config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode); /* feat_vec_print(stdout, fv); */ - mvt_code = oracle_parser_arc_eager(c, ref, start_sentence_index); + mvt_code = oracle_parser_arc_eager(c, ref, start_sentence_index, root_label); mvt_type = movement_type(mvt_code); mvt_label = movement_label(mvt_code); - /* config_print(stdout,c); */ - /* movement_print(stdout, mvt_code, ctx->dico_labels); */ + /* config_print(stdout,c); */ + /* movement_print(stdout, mvt_code, ctx->dico_labels); */ fprintf(output_file, "%d", mvt_code); feat_vec_print(output_file, fv); + + if(mvt_type == MVT_EOS){ + /* printf("************BEFORE *****************\n"); */ + /* config_print(stdout,c); */ + + movement_eos(c, 0); + + /* printf("************AFTER*****************\n"); */ + /* config_print(stdout,c); */ + start_sentence_index = word_get_index(word_buffer_b0(config_get_buffer(c))) - 1; + /* printf("%d\n", start_sentence_index); */ + + if(word_buffer_is_last(ref)){ + /* printf("it is the end\n"); */ + break; + } + } if(mvt_type == MVT_LEFT){ movement_left_arc(c, mvt_label, 0); @@ -96,59 +111,6 @@ void generate_training_file_stream(FILE *output_file, context *ctx) if(mvt_type == MVT_RIGHT){ movement_right_arc(c, mvt_label, 0); word_buffer_move_right(ref); - if((mvt_label == eos_label)){ /* sentence is complete */ -#if 0 - ======= - while((ref = sentence_read(conll_file_ref , 
ctx->mcd_struct)) && (sentence_nb < ctx->sent_nb)){ - /* sentence_print(stdout, ref, ctx->dico_labels); */ - while(1){ - /* config_print(stdout,c); */ - config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode); - - /* feat_vec_print(stdout, fv); */ - - mvt_code = oracle_parser(c, ref); - - mvt_type = movement_type(mvt_code); - mvt_label = movement_label(mvt_code); - - /* printf("mvt code = %d\n", mvt_code); */ - /* movement_print(stdout, mvt_code, ctx->dico_labels); */ - - fprintf(output_file, "%d", mvt_code); - feat_vec_print(output_file, fv); - - if(queue_is_empty(c->bf)) break; - - if((mvt_type == MVT_RIGHT) && (mvt_label == root_label)){ /* sentence is complete */ - - /* create the root arc */ - movement_right_arc(c, mvt_label, 0); - - /* shift dummy word in stack */ - movement_shift(c, 1, 0); - - /* printf("sentence complete config : "); - config_print(stdout,c); */ - - /* empty depset */ - depset_free(c->ds); - c->ds = depset_new(); - sentence_free(ref); ->>>>>>> master -#endif - - sentence_nb++; - start_sentence_index = word_get_index(word_buffer_b0(config_get_buffer(c))) - 1; - /* printf("%d\n", start_sentence_index); */ - - /* printf("*****************************\n"); */ - /* config_print(stdout,c); */ - if(word_buffer_is_last(ref)){ - /* printf("it is the end\n"); */ - break; - } - } continue; } @@ -156,6 +118,12 @@ void generate_training_file_stream(FILE *output_file, context *ctx) movement_reduce(c, 0); continue; } + + if(mvt_type == MVT_ROOT){ + movement_root(c, 0, root_label); + continue; + } + if(mvt_type == MVT_SHIFT){ movement_shift(c, 1, 0); word_buffer_move_right(ref); diff --git a/maca_trans_parser/src/movement_parser_arc_eager.c b/maca_trans_parser/src/movement_parser_arc_eager.c index 0481cb0..df00df7 100644 --- a/maca_trans_parser/src/movement_parser_arc_eager.c +++ b/maca_trans_parser/src/movement_parser_arc_eager.c @@ -10,6 +10,7 @@ void movement_print(FILE *f, int mvt_code, dico *dico_labels){ char *label; 
if(mvt_type == MVT_SHIFT) {fprintf(f, "SHIFT\n"); return;} if(mvt_type == MVT_REDUCE) {fprintf(f, "REDUCE\n"); return;} + if(mvt_type == MVT_ROOT) {fprintf(f, "ROOT\n"); return;} if(mvt_type == MVT_RIGHT) fprintf(f, "RIGHT"); else fprintf(f, "LEFT"); label = dico_int2string(dico_labels, mvt_label); @@ -20,6 +21,8 @@ int movement_type(int mvt) { if(mvt == 0) return MVT_SHIFT; /* 0 is the code of shift */ if(mvt == 1) return MVT_REDUCE; /* 1 is the code of reduce */ + if(mvt == 2) return MVT_ROOT; /* 2 is the code of root */ + if(mvt == 3) return MVT_EOS; /* 3 is the code of eos */ if(mvt % 2 == 0) return MVT_LEFT; /* even movements are left movements */ return MVT_RIGHT; /* odd movements are right movements */ } @@ -28,15 +31,28 @@ int movement_label(int mvt) { if(mvt == 0) return -1; /* 0 is the code of shift */ if(mvt == 1) return -1; /* 1 is the code of reduce */ + if(mvt == 2) return -1; /* 2 is the code of root */ + if(mvt == 3) return -1; /* 3 is the code of eos */ if(mvt % 2 == 0) /* even codes correspond to left movements */ - return mvt / 2 - 1; - return (mvt - 1) / 2 - 1; /* odd codes correspond to right movements */ + return mvt / 2 - 2; + return (mvt - 1) / 2 - 2; /* odd codes correspond to right movements */ +} + +int movement_eos(config *c, float score) +{ + /* perform all pending reduce */ + while(movement_reduce(c,0)); + + /* remove root from stack */ + stack_pop(config_get_stack(c)); + return 1; } int movement_left_arc(config *c, int label, float score) { - if(stack_height(config_get_stack(c)) < 2) return 0; /* the dummy word cannot be a dependent */ - if(word_buffer_is_empty(config_get_buffer(c))) return 0; + /* the dummy word cannot be a dependent */ + /* if(stack_height(config_get_stack(c)) < 2) return 0; */ + /* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */ /* word on top of the stack should not have a governor */ if(word_get_gov(stack_top(config_get_stack(c))) != 0) return 0; @@ -57,14 +73,13 @@ int movement_left_arc(config *c,
int label, float score) int movement_right_arc(config *c, int label, float score) { - /* printf("RA "); */ if(stack_is_empty(config_get_stack(c))) return 0; - if(word_buffer_is_empty(config_get_buffer(c))) return 0; - + /* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */ + word *gov = stack_top(config_get_stack(c)); word *dep = word_buffer_b0(config_get_buffer(c)); int dist = (word_get_index(gov)) - (word_get_index(dep)); - + /* printf("create right arc %d -> %d dist = %d\n", word_get_index(gov), word_get_index(dep), dist); */ @@ -98,3 +113,14 @@ int movement_reduce(config *c, float score) return 1; } +int movement_root(config *c, float score, int root_code) +{ + word *b0 = word_buffer_b0(config_get_buffer(c)); + word_set_gov(b0, 0); + word_set_label(b0, root_code); + /* stack_push(config_get_stack(c), b0); */ + /* word_buffer_move_right(config_get_buffer(c)); */ + config_add_mvt(c, MVT_ROOT); + return 1; +} + diff --git a/maca_trans_parser/src/movement_parser_arc_eager.h b/maca_trans_parser/src/movement_parser_arc_eager.h index d117c5c..63f2f8a 100644 --- a/maca_trans_parser/src/movement_parser_arc_eager.h +++ b/maca_trans_parser/src/movement_parser_arc_eager.h @@ -7,15 +7,17 @@ #define MVT_SHIFT 0 #define MVT_REDUCE 1 -#define MVT_LEFT 2 -#define MVT_RIGHT 3 +#define MVT_ROOT 2 +#define MVT_EOS 3 +#define MVT_LEFT 4 +#define MVT_RIGHT 5 -/* even movements are left movements (except 0, which is shift) */ -#define movement_left_code(label) (2 * (label) + 2) +/* even movements are left movements (except 0, which is shift and 2 which is root) */ +#define movement_left_code(label) (2 * (label) + 4) -/* odd movements are right movements (except 1, which is reduce) */ -#define movement_right_code(label) (2 * (label) + 3) +/* odd movements are right movements (except 1, which is reduce and 3 which is end_of_sentence) */ +#define movement_right_code(label) (2 * (label) + 5) int movement_type(int mvt); int movement_label(int mvt); @@ -24,7 +26,8 @@ int 
movement_left_arc(config *c, int label, float score); int movement_right_arc(config *c, int label, float score); int movement_shift(config *c, int stream, float score); int movement_reduce(config *c, float score); - +int movement_root(config *c, float score, int root_code); +int movement_eos(config *c, float score); void movement_print(FILE *f, int mvt_code, dico *dico_labels); #endif diff --git a/maca_trans_parser/src/oracle_parser_arc_eager.c b/maca_trans_parser/src/oracle_parser_arc_eager.c index 082130a..2e41c60 100644 --- a/maca_trans_parser/src/oracle_parser_arc_eager.c +++ b/maca_trans_parser/src/oracle_parser_arc_eager.c @@ -9,7 +9,8 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, int dep; int gov_ref; int gov_hyp; - int max = ((start_sentence_index + 500) > ref->nbelem)? ref->nbelem : (start_sentence_index + 500); + int max_sent_length = 300; + int max = ((start_sentence_index + max_sent_length) > ref->nbelem)? ref->nbelem : (start_sentence_index + max_sent_length); for(dep=start_sentence_index; dep < max; dep++){ gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep)); if(gov_ref == word_index){ /* found a dependent of word in ref */ @@ -30,26 +31,41 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, return 1; } -int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index) +int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index, int root_label) { word *s0; /* word on top of stack */ word *b0; /* next word in the bufer */ int s0_index, b0_index; int s0_gov_index, b0_gov_index; + int b0_label; + int b0_label_in_hyp; - if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ + b0 = word_buffer_b0(config_get_buffer(c)); + b0_index = word_get_index(b0); + b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index)); + b0_label = word_get_label(word_buffer_get_word_n(ref, b0_index)); + 
b0_label_in_hyp = word_get_label(word_buffer_get_word_n(config_get_buffer(c), b0_index)); + + /* b0 is the root of the sentence */ + if((b0_label == root_label) && (b0_label_in_hyp != root_label)){ + return MVT_ROOT; + } + + /* if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ */ + if(!stack_is_empty(config_get_stack(c))){ s0 = stack_top(config_get_stack(c)); s0_index = word_get_index(s0); s0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, s0_index)); - b0 = word_buffer_b0(config_get_buffer(c)); - b0_index = word_get_index(b0); - b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index)); - /* printf("s0_index = %d b0_index = %d\n", s0_index, b0_index); - printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index); - printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/ - + printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index); + printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/ + + /* word on the top of the stack is an end of sentence marker */ + if(word_get_sent_seg(s0) == 1){ + return MVT_EOS; + } + /* LEFT ARC b0 is the governor and s0 the dependent */ if(s0_gov_index == b0_index){ return movement_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index))); @@ -60,15 +76,16 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_inde return movement_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index))); } /* REDUCE */ - if((stack_height(config_get_stack(c)) > 2) - && check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, start_sentence_index) + if( + /* (stack_height(config_get_stack(c)) > 2) */ + check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, start_sentence_index) && (word_get_gov(stack_top(config_get_stack(c))) != 0)) /* word on top of the stack has a goveror */ { return MVT_REDUCE; } - - /* SHIFT */ - return MVT_SHIFT; } - return -1; + + /* SHIFT */ + return MVT_SHIFT; + } diff 
--git a/maca_trans_parser/src/oracle_parser_arc_eager.h b/maca_trans_parser/src/oracle_parser_arc_eager.h index 4f75682..9b92d5d 100644 --- a/maca_trans_parser/src/oracle_parser_arc_eager.h +++ b/maca_trans_parser/src/oracle_parser_arc_eager.h @@ -6,6 +6,6 @@ #include"word_buffer.h" -int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index); +int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index, int root_label); #endif diff --git a/maca_trans_parser/src/simple_decoder_parser.c b/maca_trans_parser/src/simple_decoder_parser.c index 81aea38..b6f8888 100644 --- a/maca_trans_parser/src/simple_decoder_parser.c +++ b/maca_trans_parser/src/simple_decoder_parser.c @@ -20,16 +20,16 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la config *c = NULL; word *dep; - c = config_initial(f, ctx->mcd_struct, 5); + c = config_initial_no_dummy_word(f, ctx->mcd_struct, 5); while(1){ config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE); mvt_code = feature_table_argmax(fv, ft, &max); mvt_type = movement_type(mvt_code); mvt_label = movement_label(mvt_code); - - /* config_print(stdout, c); */ - /* movement_print(stdout, mvt_code, ctx->dico_labels); */ - + + /* config_print(stdout, c); */ + /* movement_print(stdout, mvt_code, ctx->dico_labels); */ + if(mvt_type == MVT_LEFT){ dep = stack_s0(config_get_stack(c)); if(movement_left_arc(c, mvt_label, max)){ @@ -55,13 +55,17 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la if(mvt_type == MVT_REDUCE) if(movement_reduce(c, max)) continue; + + if(mvt_type == MVT_ROOT) + if(movement_root(c, max, root_label)) + continue; movement_shift(c, 1, max); if(word_buffer_is_last(config_get_buffer(c))) break; } - for(int i=1; i < config_get_buffer(c)->nbelem; i++){ + for(int i=0; i < config_get_buffer(c)->nbelem; i++){ dep = word_buffer_get_word_n(config_get_buffer(c), i); printf("%s\t", 
word_get_input(dep)); printf("%d\t", word_get_gov(dep)); -- GitLab