Skip to content
Snippets Groups Projects
Commit 89a7f9f3 authored by Alexis Nasr's avatar Alexis Nasr
Browse files

arc eager almost ready got better results than arc standard

parent c57a2280
No related branches found
No related tags found
No related merge requests found
...@@ -13,16 +13,11 @@ word_buffer *word_buffer_new(FILE *input_file, mcd *mcd_struct, int lookahead) ...@@ -13,16 +13,11 @@ word_buffer *word_buffer_new(FILE *input_file, mcd *mcd_struct, int lookahead)
wb->array = (word **)memalloc(wb->size * sizeof(word *)); wb->array = (word **)memalloc(wb->size * sizeof(word *));
wb->current_index = 0; wb->current_index = 0;
word_buffer_add(wb, word_new(NULL)); /* add dummy token */
/* load lookahead next words */ /* load lookahead next words */
wb->lookahead = lookahead; wb->lookahead = lookahead;
for(i=0; i <= lookahead; i++) for(i=0; i <= lookahead; i++)
word_buffer_read_next_word(wb); word_buffer_read_next_word(wb);
word_buffer_move_right(wb); /* pass dummy token */
return wb; return wb;
} }
......
...@@ -54,8 +54,7 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -54,8 +54,7 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
int eos_label = dico_string2int(ctx->dico_labels, "eos"); int eos_label = dico_string2int(ctx->dico_labels, "eos");
word_buffer *ref = word_buffer_load_mcf(ctx->input_filename, ctx->mcd_struct); word_buffer *ref = word_buffer_load_mcf(ctx->input_filename, ctx->mcd_struct);
FILE *mcf_file = myfopen(ctx->input_filename, "r"); FILE *mcf_file = myfopen(ctx->input_filename, "r");
int start_sentence_index = 1; int start_sentence_index = 0;
/* create an mcd that corresponds to ctx->mcd_struct, but without gov and label */ /* create an mcd that corresponds to ctx->mcd_struct, but without gov and label */
/* the idea is to ignore syntax in the mcf file that will be read */ /* the idea is to ignore syntax in the mcf file that will be read */
...@@ -65,19 +64,18 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -65,19 +64,18 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_GOV); mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_GOV);
mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_LABEL); mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_LABEL);
c = config_initial(mcf_file, mcd_struct_hyp, 5); c = config_initial_no_dummy_word(mcf_file, mcd_struct_hyp, 5);
while(!word_buffer_end(ref)){ while(!word_buffer_end(ref)){
/*printf("************ REF ************\n"); /*printf("************ REF ************\n");
word_buffer_print(stdout, ref); word_buffer_print(stdout, ref);
printf("*****************************\n");*/ printf("*****************************\n");*/
/* printf("*****************************\n"); */
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode); config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode);
/* feat_vec_print(stdout, fv); */ /* feat_vec_print(stdout, fv); */
mvt_code = oracle_parser_arc_eager(c, ref, start_sentence_index); mvt_code = oracle_parser_arc_eager(c, ref, start_sentence_index, root_label);
mvt_type = movement_type(mvt_code); mvt_type = movement_type(mvt_code);
mvt_label = movement_label(mvt_code); mvt_label = movement_label(mvt_code);
...@@ -88,67 +86,31 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -88,67 +86,31 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
fprintf(output_file, "%d", mvt_code); fprintf(output_file, "%d", mvt_code);
feat_vec_print(output_file, fv); feat_vec_print(output_file, fv);
if(mvt_type == MVT_LEFT){ if(mvt_type == MVT_EOS){
movement_left_arc(c, mvt_label, 0); /* printf("************BEFORE *****************\n"); */
continue;
}
if(mvt_type == MVT_RIGHT){
movement_right_arc(c, mvt_label, 0);
word_buffer_move_right(ref);
if((mvt_label == eos_label)){ /* sentence is complete */
#if 0
=======
while((ref = sentence_read(conll_file_ref , ctx->mcd_struct)) && (sentence_nb < ctx->sent_nb)){
/* sentence_print(stdout, ref, ctx->dico_labels); */
while(1){
/* config_print(stdout,c); */ /* config_print(stdout,c); */
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode);
/* feat_vec_print(stdout, fv); */
mvt_code = oracle_parser(c, ref);
mvt_type = movement_type(mvt_code);
mvt_label = movement_label(mvt_code);
/* printf("mvt code = %d\n", mvt_code); */
/* movement_print(stdout, mvt_code, ctx->dico_labels); */
fprintf(output_file, "%d", mvt_code);
feat_vec_print(output_file, fv);
if(queue_is_empty(c->bf)) break;
if((mvt_type == MVT_RIGHT) && (mvt_label == root_label)){ /* sentence is complete */ movement_eos(c, 0);
/* create the root arc */ /* printf("************AFTER*****************\n"); */
movement_right_arc(c, mvt_label, 0); /* config_print(stdout,c); */
/* shift dummy word in stack */
movement_shift(c, 1, 0);
/* printf("sentence complete config : ");
config_print(stdout,c); */
/* empty depset */
depset_free(c->ds);
c->ds = depset_new();
sentence_free(ref);
>>>>>>> master
#endif
sentence_nb++;
start_sentence_index = word_get_index(word_buffer_b0(config_get_buffer(c))) - 1; start_sentence_index = word_get_index(word_buffer_b0(config_get_buffer(c))) - 1;
/* printf("%d\n", start_sentence_index); */ /* printf("%d\n", start_sentence_index); */
/* printf("*****************************\n"); */
/* config_print(stdout,c); */
if(word_buffer_is_last(ref)){ if(word_buffer_is_last(ref)){
/* printf("it is the end\n"); */ /* printf("it is the end\n"); */
break; break;
} }
} }
if(mvt_type == MVT_LEFT){
movement_left_arc(c, mvt_label, 0);
continue;
}
if(mvt_type == MVT_RIGHT){
movement_right_arc(c, mvt_label, 0);
word_buffer_move_right(ref);
continue; continue;
} }
...@@ -156,6 +118,12 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -156,6 +118,12 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
movement_reduce(c, 0); movement_reduce(c, 0);
continue; continue;
} }
if(mvt_type == MVT_ROOT){
movement_root(c, 0, root_label);
continue;
}
if(mvt_type == MVT_SHIFT){ if(mvt_type == MVT_SHIFT){
movement_shift(c, 1, 0); movement_shift(c, 1, 0);
word_buffer_move_right(ref); word_buffer_move_right(ref);
......
...@@ -10,6 +10,7 @@ void movement_print(FILE *f, int mvt_code, dico *dico_labels){ ...@@ -10,6 +10,7 @@ void movement_print(FILE *f, int mvt_code, dico *dico_labels){
char *label; char *label;
if(mvt_type == MVT_SHIFT) {fprintf(f, "SHIFT\n"); return;} if(mvt_type == MVT_SHIFT) {fprintf(f, "SHIFT\n"); return;}
if(mvt_type == MVT_REDUCE) {fprintf(f, "REDUCE\n"); return;} if(mvt_type == MVT_REDUCE) {fprintf(f, "REDUCE\n"); return;}
if(mvt_type == MVT_ROOT) {fprintf(f, "ROOT\n"); return;}
if(mvt_type == MVT_RIGHT) fprintf(f, "RIGHT"); if(mvt_type == MVT_RIGHT) fprintf(f, "RIGHT");
else fprintf(f, "LEFT"); else fprintf(f, "LEFT");
label = dico_int2string(dico_labels, mvt_label); label = dico_int2string(dico_labels, mvt_label);
...@@ -20,6 +21,8 @@ int movement_type(int mvt) ...@@ -20,6 +21,8 @@ int movement_type(int mvt)
{ {
if(mvt == 0) return MVT_SHIFT; /* 0 is the code of shift */ if(mvt == 0) return MVT_SHIFT; /* 0 is the code of shift */
if(mvt == 1) return MVT_REDUCE; /* 1 is the code of reduce */ if(mvt == 1) return MVT_REDUCE; /* 1 is the code of reduce */
if(mvt == 2) return MVT_ROOT; /* 2 is the code of root */
if(mvt == 3) return MVT_EOS; /* 2 is the code of root */
if(mvt % 2 == 0) return MVT_LEFT; /* even movements are left movements */ if(mvt % 2 == 0) return MVT_LEFT; /* even movements are left movements */
return MVT_RIGHT; /* odd movements are right movements */ return MVT_RIGHT; /* odd movements are right movements */
} }
...@@ -28,15 +31,28 @@ int movement_label(int mvt) ...@@ -28,15 +31,28 @@ int movement_label(int mvt)
{ {
if(mvt == 0) return -1; /* 0 is the code of shift */ if(mvt == 0) return -1; /* 0 is the code of shift */
if(mvt == 1) return -1; /* 1 is the code of reduce */ if(mvt == 1) return -1; /* 1 is the code of reduce */
if(mvt == 2) return -1; /* 2 is the code of root */
if(mvt == 3) return -1; /* 3 is the code of eos */
if(mvt % 2 == 0) /* even codes correspond to left movements */ if(mvt % 2 == 0) /* even codes correspond to left movements */
return mvt / 2 - 1; return mvt / 2 - 2;
return (mvt - 1) / 2 - 1; /* odd codes correspond to right movements */ return (mvt - 1) / 2 - 2; /* odd codes correspond to right movements */
}
int movement_eos(config *c, float score)
{
/* perform all pending reduce */
while(movement_reduce(c,0));
/* remove root from stack */
stack_pop(config_get_stack(c));
return 1;
} }
int movement_left_arc(config *c, int label, float score) int movement_left_arc(config *c, int label, float score)
{ {
if(stack_height(config_get_stack(c)) < 2) return 0; /* the dummy word cannot be a dependent */ /* the dummy word cannot be a dependent */
if(word_buffer_is_empty(config_get_buffer(c))) return 0; /* if(stack_height(config_get_stack(c)) < 2) return 0; */
/* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */
/* word on top of the stack should not have a governor */ /* word on top of the stack should not have a governor */
if(word_get_gov(stack_top(config_get_stack(c))) != 0) return 0; if(word_get_gov(stack_top(config_get_stack(c))) != 0) return 0;
...@@ -57,9 +73,8 @@ int movement_left_arc(config *c, int label, float score) ...@@ -57,9 +73,8 @@ int movement_left_arc(config *c, int label, float score)
int movement_right_arc(config *c, int label, float score) int movement_right_arc(config *c, int label, float score)
{ {
/* printf("RA "); */
if(stack_is_empty(config_get_stack(c))) return 0; if(stack_is_empty(config_get_stack(c))) return 0;
if(word_buffer_is_empty(config_get_buffer(c))) return 0; /* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */
word *gov = stack_top(config_get_stack(c)); word *gov = stack_top(config_get_stack(c));
word *dep = word_buffer_b0(config_get_buffer(c)); word *dep = word_buffer_b0(config_get_buffer(c));
...@@ -98,3 +113,14 @@ int movement_reduce(config *c, float score) ...@@ -98,3 +113,14 @@ int movement_reduce(config *c, float score)
return 1; return 1;
} }
int movement_root(config *c, float score, int root_code)
{
word *b0 = word_buffer_b0(config_get_buffer(c));
word_set_gov(b0, 0);
word_set_label(b0, root_code);
/* stack_push(config_get_stack(c), b0); */
/* word_buffer_move_right(config_get_buffer(c)); */
config_add_mvt(c, MVT_ROOT);
return 1;
}
...@@ -7,15 +7,17 @@ ...@@ -7,15 +7,17 @@
#define MVT_SHIFT 0 #define MVT_SHIFT 0
#define MVT_REDUCE 1 #define MVT_REDUCE 1
#define MVT_LEFT 2 #define MVT_ROOT 2
#define MVT_RIGHT 3 #define MVT_EOS 3
#define MVT_LEFT 4
#define MVT_RIGHT 5
/* even movements are left movements (except 0, which is shift) */ /* even movements are left movements (except 0, which is shift and 2 which is root) */
#define movement_left_code(label) (2 * (label) + 2) #define movement_left_code(label) (2 * (label) + 4)
/* odd movements are right movements (except 1, which is reduce) */ /* odd movements are right movements (except 1, which is reduce and 3 which is end_of_sentence) */
#define movement_right_code(label) (2 * (label) + 3) #define movement_right_code(label) (2 * (label) + 5)
int movement_type(int mvt); int movement_type(int mvt);
int movement_label(int mvt); int movement_label(int mvt);
...@@ -24,7 +26,8 @@ int movement_left_arc(config *c, int label, float score); ...@@ -24,7 +26,8 @@ int movement_left_arc(config *c, int label, float score);
int movement_right_arc(config *c, int label, float score); int movement_right_arc(config *c, int label, float score);
int movement_shift(config *c, int stream, float score); int movement_shift(config *c, int stream, float score);
int movement_reduce(config *c, float score); int movement_reduce(config *c, float score);
int movement_root(config *c, float score, int root_code);
int movement_eos(config *c, float score);
void movement_print(FILE *f, int mvt_code, dico *dico_labels); void movement_print(FILE *f, int mvt_code, dico *dico_labels);
#endif #endif
...@@ -9,7 +9,8 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, ...@@ -9,7 +9,8 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref,
int dep; int dep;
int gov_ref; int gov_ref;
int gov_hyp; int gov_hyp;
int max = ((start_sentence_index + 500) > ref->nbelem)? ref->nbelem : (start_sentence_index + 500); int max_sent_length = 300;
int max = ((start_sentence_index + max_sent_length) > ref->nbelem)? ref->nbelem : (start_sentence_index + max_sent_length);
for(dep=start_sentence_index; dep < max; dep++){ for(dep=start_sentence_index; dep < max; dep++){
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep)); gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep));
if(gov_ref == word_index){ /* found a dependent of word in ref */ if(gov_ref == word_index){ /* found a dependent of word in ref */
...@@ -30,26 +31,41 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, ...@@ -30,26 +31,41 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref,
return 1; return 1;
} }
int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index) int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index, int root_label)
{ {
word *s0; /* word on top of stack */ word *s0; /* word on top of stack */
word *b0; /* next word in the bufer */ word *b0; /* next word in the bufer */
int s0_index, b0_index; int s0_index, b0_index;
int s0_gov_index, b0_gov_index; int s0_gov_index, b0_gov_index;
int b0_label;
if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ int b0_label_in_hyp;
s0 = stack_top(config_get_stack(c));
s0_index = word_get_index(s0);
s0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, s0_index));
b0 = word_buffer_b0(config_get_buffer(c)); b0 = word_buffer_b0(config_get_buffer(c));
b0_index = word_get_index(b0); b0_index = word_get_index(b0);
b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index)); b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index));
b0_label = word_get_label(word_buffer_get_word_n(ref, b0_index));
b0_label_in_hyp = word_get_label(word_buffer_get_word_n(config_get_buffer(c), b0_index));
/* b0 is the root of the sentence */
if((b0_label == root_label) && (b0_label_in_hyp != root_label)){
return MVT_ROOT;
}
/* if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ */
if(!stack_is_empty(config_get_stack(c))){
s0 = stack_top(config_get_stack(c));
s0_index = word_get_index(s0);
s0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, s0_index));
/* printf("s0_index = %d b0_index = %d\n", s0_index, b0_index); /* printf("s0_index = %d b0_index = %d\n", s0_index, b0_index);
printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index); printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index);
printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/ printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/
/* word on the top of the stack is an end of sentence marker */
if(word_get_sent_seg(s0) == 1){
return MVT_EOS;
}
/* LEFT ARC b0 is the governor and s0 the dependent */ /* LEFT ARC b0 is the governor and s0 the dependent */
if(s0_gov_index == b0_index){ if(s0_gov_index == b0_index){
return movement_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index))); return movement_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index)));
...@@ -60,15 +76,16 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_inde ...@@ -60,15 +76,16 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_inde
return movement_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index))); return movement_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index)));
} }
/* REDUCE */ /* REDUCE */
if((stack_height(config_get_stack(c)) > 2) if(
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, start_sentence_index) /* (stack_height(config_get_stack(c)) > 2) */
check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, start_sentence_index)
&& (word_get_gov(stack_top(config_get_stack(c))) != 0)) /* word on top of the stack has a goveror */ && (word_get_gov(stack_top(config_get_stack(c))) != 0)) /* word on top of the stack has a goveror */
{ {
return MVT_REDUCE; return MVT_REDUCE;
} }
}
/* SHIFT */ /* SHIFT */
return MVT_SHIFT; return MVT_SHIFT;
}
return -1;
} }
...@@ -6,6 +6,6 @@ ...@@ -6,6 +6,6 @@
#include"word_buffer.h" #include"word_buffer.h"
int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index); int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index, int root_label);
#endif #endif
...@@ -20,7 +20,7 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la ...@@ -20,7 +20,7 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la
config *c = NULL; config *c = NULL;
word *dep; word *dep;
c = config_initial(f, ctx->mcd_struct, 5); c = config_initial_no_dummy_word(f, ctx->mcd_struct, 5);
while(1){ while(1){
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE); config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
mvt_code = feature_table_argmax(fv, ft, &max); mvt_code = feature_table_argmax(fv, ft, &max);
...@@ -56,12 +56,16 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la ...@@ -56,12 +56,16 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la
if(movement_reduce(c, max)) if(movement_reduce(c, max))
continue; continue;
if(mvt_type == MVT_ROOT)
if(movement_root(c, max, root_label))
continue;
movement_shift(c, 1, max); movement_shift(c, 1, max);
if(word_buffer_is_last(config_get_buffer(c))) break; if(word_buffer_is_last(config_get_buffer(c))) break;
} }
for(int i=1; i < config_get_buffer(c)->nbelem; i++){ for(int i=0; i < config_get_buffer(c)->nbelem; i++){
dep = word_buffer_get_word_n(config_get_buffer(c), i); dep = word_buffer_get_word_n(config_get_buffer(c), i);
printf("%s\t", word_get_input(dep)); printf("%s\t", word_get_input(dep));
printf("%d\t", word_get_gov(dep)); printf("%d\t", word_get_gov(dep));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment