From 89a7f9f3e4e3a1a5e73920da3b5c474c01100636 Mon Sep 17 00:00:00 2001
From: Alexis Nasr <alexis.nasr@lif.univ-mrs.fr>
Date: Thu, 13 Oct 2016 22:19:50 -0400
Subject: [PATCH] arc eager almost ready got better results than arc standard

---
 maca_common/src/word_buffer.c                 |  5 --
 .../src/maca_trans_parser_mcf2cff.c           | 90 ++++++-------------
 .../src/movement_parser_arc_eager.c           | 42 +++++++--
 .../src/movement_parser_arc_eager.h           | 17 ++--
 .../src/oracle_parser_arc_eager.c             | 49 ++++++----
 .../src/oracle_parser_arc_eager.h             |  2 +-
 maca_trans_parser/src/simple_decoder_parser.c | 16 ++--
 7 files changed, 117 insertions(+), 104 deletions(-)

diff --git a/maca_common/src/word_buffer.c b/maca_common/src/word_buffer.c
index a22ae9e..61695c8 100644
--- a/maca_common/src/word_buffer.c
+++ b/maca_common/src/word_buffer.c
@@ -13,15 +13,10 @@ word_buffer *word_buffer_new(FILE *input_file, mcd *mcd_struct, int lookahead)
   wb->array = (word **)memalloc(wb->size * sizeof(word *));
   wb->current_index = 0;
 
-
-  word_buffer_add(wb, word_new(NULL)); /* add dummy token */
-
   /* load lookahead next words */
   wb->lookahead = lookahead;
   for(i=0; i <= lookahead; i++)
     word_buffer_read_next_word(wb);
-
-  word_buffer_move_right(wb); /* pass dummy token */
   
   return wb;
 }
diff --git a/maca_trans_parser/src/maca_trans_parser_mcf2cff.c b/maca_trans_parser/src/maca_trans_parser_mcf2cff.c
index 6705393..cb3d6b7 100644
--- a/maca_trans_parser/src/maca_trans_parser_mcf2cff.c
+++ b/maca_trans_parser/src/maca_trans_parser_mcf2cff.c
@@ -54,9 +54,8 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
   int eos_label = dico_string2int(ctx->dico_labels, "eos");
   word_buffer *ref = word_buffer_load_mcf(ctx->input_filename, ctx->mcd_struct);
   FILE *mcf_file = myfopen(ctx->input_filename, "r");
-  int start_sentence_index = 1;
+  int start_sentence_index = 0;
   
-
   /* create an mcd that corresponds to ctx->mcd_struct, but without gov and label */
   /* the idea is to ignore syntax in the mcf file that will be read */
   /* it is ugly !!! */
@@ -65,28 +64,44 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
   mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_GOV);
   mcd_remove_wf_column(mcd_struct_hyp, MCD_WF_LABEL);
   
-  c = config_initial(mcf_file, mcd_struct_hyp, 5);
+  c = config_initial_no_dummy_word(mcf_file, mcd_struct_hyp, 5);
   
   while(!word_buffer_end(ref)){
     /*printf("************ REF ************\n");
     word_buffer_print(stdout, ref);
-    	  printf("*****************************\n");*/
+    printf("*****************************\n");*/
     
-    /* printf("*****************************\n");  */
     config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode);
     
     /* feat_vec_print(stdout, fv);    */
     
-    mvt_code = oracle_parser_arc_eager(c, ref, start_sentence_index);
+    mvt_code = oracle_parser_arc_eager(c, ref, start_sentence_index, root_label);
       
     mvt_type = movement_type(mvt_code);
     mvt_label = movement_label(mvt_code);
      
-    /* config_print(stdout,c);          */
-     /* movement_print(stdout, mvt_code, ctx->dico_labels);       */
+    /* config_print(stdout,c);           */
+    /* movement_print(stdout, mvt_code, ctx->dico_labels);        */
     
     fprintf(output_file, "%d", mvt_code);
     feat_vec_print(output_file, fv);
+
+    if(mvt_type == MVT_EOS){
+      /* printf("************BEFORE *****************\n");    */
+      /* config_print(stdout,c);             */
+
+      movement_eos(c, 0);
+
+      /* printf("************AFTER*****************\n");    */
+      /* config_print(stdout,c);             */
+      start_sentence_index = word_get_index(word_buffer_b0(config_get_buffer(c))) - 1;
+      /* 	printf("%d\n", start_sentence_index); */
+      
+      if(word_buffer_is_last(ref)){
+	/* printf("it is the end\n"); */
+	break;
+      }
+    }
     
     if(mvt_type == MVT_LEFT){
       movement_left_arc(c, mvt_label, 0);
@@ -96,59 +111,6 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
     if(mvt_type == MVT_RIGHT){
       movement_right_arc(c, mvt_label, 0);
       word_buffer_move_right(ref);
-      if((mvt_label == eos_label)){       /* sentence is complete */
-#if 0
-	=======
-  while((ref = sentence_read(conll_file_ref , ctx->mcd_struct)) && (sentence_nb < ctx->sent_nb)){ 
-     /* sentence_print(stdout, ref, ctx->dico_labels);   */
-    while(1){
-  /* config_print(stdout,c);         */
-      config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode);
-      
-      /* feat_vec_print(stdout, fv);    */
-      
-      mvt_code = oracle_parser(c, ref);
-      
-      mvt_type = movement_type(mvt_code);
-      mvt_label = movement_label(mvt_code);
-
-      /* printf("mvt code = %d\n", mvt_code); */
-       /* movement_print(stdout, mvt_code, ctx->dico_labels);   */
-      
-      fprintf(output_file, "%d", mvt_code);
-      feat_vec_print(output_file, fv);
-      
-      if(queue_is_empty(c->bf)) break;
-      
-      if((mvt_type == MVT_RIGHT) && (mvt_label == root_label)){       /* sentence is complete */
-	
-	/* create the root arc */
-	movement_right_arc(c, mvt_label, 0);
-	
-	/* shift dummy word in stack */
-	movement_shift(c, 1, 0);
-
-	/*		printf("sentence complete config : ");
-			config_print(stdout,c);   */
-	
-	/* empty depset */
-	depset_free(c->ds);
-	c->ds = depset_new();
-	sentence_free(ref);
->>>>>>> master
-#endif
-
-	  sentence_nb++;
-	start_sentence_index = word_get_index(word_buffer_b0(config_get_buffer(c))) - 1;
-	/* 	printf("%d\n", start_sentence_index); */
-	
-	/* printf("*****************************\n");  */
-	 /* config_print(stdout,c);           */
-	if(word_buffer_is_last(ref)){
-	  /* printf("it is the end\n"); */
-	  break;
-	}
-      }      
       continue;
     }
     
@@ -156,6 +118,12 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
       movement_reduce(c, 0);
       continue;
     }
+   
+    if(mvt_type == MVT_ROOT){
+      movement_root(c, 0, root_label);
+      continue;
+    }
+
     if(mvt_type == MVT_SHIFT){
       movement_shift(c, 1, 0);
       word_buffer_move_right(ref);
diff --git a/maca_trans_parser/src/movement_parser_arc_eager.c b/maca_trans_parser/src/movement_parser_arc_eager.c
index 0481cb0..df00df7 100644
--- a/maca_trans_parser/src/movement_parser_arc_eager.c
+++ b/maca_trans_parser/src/movement_parser_arc_eager.c
@@ -10,6 +10,7 @@ void movement_print(FILE *f, int mvt_code, dico *dico_labels){
   char *label;
   if(mvt_type == MVT_SHIFT) {fprintf(f, "SHIFT\n"); return;}
   if(mvt_type == MVT_REDUCE) {fprintf(f, "REDUCE\n"); return;}
+  if(mvt_type == MVT_ROOT) {fprintf(f, "ROOT\n"); return;}
   if(mvt_type == MVT_RIGHT) fprintf(f, "RIGHT");
   else fprintf(f, "LEFT");
   label = dico_int2string(dico_labels, mvt_label);
@@ -20,6 +21,8 @@ int movement_type(int mvt)
 {
   if(mvt == 0) return MVT_SHIFT;      /* 0  is the code of shift */
   if(mvt == 1) return MVT_REDUCE;     /* 1  is the code of reduce */
+  if(mvt == 2) return MVT_ROOT;       /* 2  is the code of root */
+  if(mvt == 3) return MVT_EOS;       /* 2  is the code of root */
   if(mvt % 2 == 0) return MVT_LEFT;   /* even movements are left movements */
   return MVT_RIGHT;                    /* odd movements are right movements */
 }
@@ -28,15 +31,28 @@ int movement_label(int mvt)
 {
   if(mvt == 0) return -1;     /* 0  is the code of shift */ 
   if(mvt == 1) return -1;     /* 1  is the code of reduce */ 
+  if(mvt == 2) return -1;     /* 2  is the code of root */ 
+  if(mvt == 3) return -1;     /* 3  is the code of eos */ 
   if(mvt % 2 == 0)            /* even codes correspond to left movements */
-    return mvt / 2 - 1;
-  return (mvt - 1) / 2 - 1;   /* odd codes correspond to right movements */
+    return mvt / 2 - 2;
+  return (mvt - 1) / 2 - 2;   /* odd codes correspond to right movements */
+}
+
+int movement_eos(config *c, float score)
+{
+  /* perform all pending reduce */
+  while(movement_reduce(c,0));
+  
+  /* remove root from stack */
+  stack_pop(config_get_stack(c));
+  return 1;
 }
 
 int movement_left_arc(config *c, int label, float score)
 {
-  if(stack_height(config_get_stack(c)) < 2) return 0; /* the dummy word cannot be a dependent */
-  if(word_buffer_is_empty(config_get_buffer(c))) return 0;
+   /* the dummy word cannot be a dependent */
+  /* if(stack_height(config_get_stack(c)) < 2) return 0; */
+  /* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */
   /* word on top of the stack should not have a governor */
   if(word_get_gov(stack_top(config_get_stack(c))) != 0) return 0; 
 
@@ -57,14 +73,13 @@ int movement_left_arc(config *c, int label, float score)
 
 int movement_right_arc(config *c, int label, float score)
 {
- /* printf("RA "); */
   if(stack_is_empty(config_get_stack(c))) return 0;
-  if(word_buffer_is_empty(config_get_buffer(c))) return 0;
-
+  /* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */
+  
   word *gov = stack_top(config_get_stack(c));
   word *dep = word_buffer_b0(config_get_buffer(c));
   int dist = (word_get_index(gov)) - (word_get_index(dep));
-
+  
   /* printf("create right arc %d -> %d dist = %d\n", word_get_index(gov), word_get_index(dep), dist);  */
 
 
@@ -98,3 +113,14 @@ int movement_reduce(config *c, float score)
   return 1;
 }
 
+int movement_root(config *c, float score, int root_code)
+{
+  word *b0 = word_buffer_b0(config_get_buffer(c));
+  word_set_gov(b0, 0);
+  word_set_label(b0, root_code);
+  /* stack_push(config_get_stack(c), b0); */
+  /* word_buffer_move_right(config_get_buffer(c)); */
+  config_add_mvt(c, MVT_ROOT);
+  return 1;
+}
+
diff --git a/maca_trans_parser/src/movement_parser_arc_eager.h b/maca_trans_parser/src/movement_parser_arc_eager.h
index d117c5c..63f2f8a 100644
--- a/maca_trans_parser/src/movement_parser_arc_eager.h
+++ b/maca_trans_parser/src/movement_parser_arc_eager.h
@@ -7,15 +7,17 @@
 
 #define MVT_SHIFT 0
 #define MVT_REDUCE 1
-#define MVT_LEFT 2
-#define MVT_RIGHT 3
+#define MVT_ROOT 2
+#define MVT_EOS 3
+#define MVT_LEFT 4
+#define MVT_RIGHT 5
 
 
-/* even movements are left movements (except 0, which is shift) */
-#define movement_left_code(label) (2 * (label) + 2)
+/* even movements are left movements (except 0, which is shift and 2 which is root) */
+#define movement_left_code(label) (2 * (label) + 4)
 
-/* odd movements are right movements  (except 1, which is reduce) */
-#define movement_right_code(label) (2 * (label) + 3)
+/* odd movements are right movements  (except 1, which is reduce and 3 which is end_of_sentence) */
+#define movement_right_code(label) (2 * (label) + 5)
 
 int movement_type(int mvt);
 int movement_label(int mvt);
@@ -24,7 +26,8 @@ int movement_left_arc(config *c, int label, float score);
 int movement_right_arc(config *c, int label, float score);
 int movement_shift(config *c, int stream, float score);
 int movement_reduce(config *c, float score);
-
+int movement_root(config *c, float score, int root_code);
+int movement_eos(config *c, float score);
 void movement_print(FILE *f, int mvt_code, dico *dico_labels);
 
 #endif
diff --git a/maca_trans_parser/src/oracle_parser_arc_eager.c b/maca_trans_parser/src/oracle_parser_arc_eager.c
index 082130a..2e41c60 100644
--- a/maca_trans_parser/src/oracle_parser_arc_eager.c
+++ b/maca_trans_parser/src/oracle_parser_arc_eager.c
@@ -9,7 +9,8 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref,
   int dep;
   int gov_ref;
   int gov_hyp;
-  int max = ((start_sentence_index + 500) > ref->nbelem)? ref->nbelem : (start_sentence_index + 500);
+  int max_sent_length = 300;
+  int max = ((start_sentence_index + max_sent_length) > ref->nbelem)? ref->nbelem : (start_sentence_index + max_sent_length);
   for(dep=start_sentence_index; dep < max; dep++){
     gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep)); 
     if(gov_ref  == word_index){ /* found a dependent of word in ref */
@@ -30,26 +31,41 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref,
   return 1;
 }
 
-int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index)
+int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index, int root_label)
 {
   word *s0; /* word on top of stack */
   word *b0; /* next word in the bufer */
   int s0_index, b0_index;
   int s0_gov_index, b0_gov_index;
+  int b0_label;
+  int b0_label_in_hyp;
 
-  if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){
+  b0 = word_buffer_b0(config_get_buffer(c));
+  b0_index = word_get_index(b0);
+  b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index));
+  b0_label = word_get_label(word_buffer_get_word_n(ref, b0_index));
+  b0_label_in_hyp = word_get_label(word_buffer_get_word_n(config_get_buffer(c), b0_index));
+  
+  /* b0 is the root of the sentence */
+  if((b0_label == root_label) && (b0_label_in_hyp != root_label)){
+    return MVT_ROOT;
+  }
+  
+  /* if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ */
+  if(!stack_is_empty(config_get_stack(c))){
     s0 = stack_top(config_get_stack(c));
     s0_index = word_get_index(s0);
     s0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, s0_index));
     
-    b0 = word_buffer_b0(config_get_buffer(c));
-    b0_index = word_get_index(b0);
-    b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index));
-    
     /*    printf("s0_index = %d b0_index = %d\n", s0_index, b0_index);  
-    printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index);
-    printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/
-    
+	  printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index);
+	  printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/
+
+    /* word on the top of the stack is an end of sentence marker */
+    if(word_get_sent_seg(s0) == 1){
+      return MVT_EOS;
+    }
+
     /* LEFT ARC  b0 is the governor and s0 the dependent */
     if(s0_gov_index == b0_index){
       return movement_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index)));
@@ -60,15 +76,16 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_inde
       return movement_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index))); 
     }
     /* REDUCE */
-    if((stack_height(config_get_stack(c)) > 2)
-       && check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, start_sentence_index)
+    if(
+       /* (stack_height(config_get_stack(c)) > 2) */
+       check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, start_sentence_index)
        && (word_get_gov(stack_top(config_get_stack(c))) != 0)) /* word on top of the stack has a goveror */
       {
 	return MVT_REDUCE;
       }
-    
-    /* SHIFT */
-    return MVT_SHIFT;
   }
-  return -1;
+
+  /* SHIFT */
+  return MVT_SHIFT;
+
 }
diff --git a/maca_trans_parser/src/oracle_parser_arc_eager.h b/maca_trans_parser/src/oracle_parser_arc_eager.h
index 4f75682..9b92d5d 100644
--- a/maca_trans_parser/src/oracle_parser_arc_eager.h
+++ b/maca_trans_parser/src/oracle_parser_arc_eager.h
@@ -6,6 +6,6 @@
 #include"word_buffer.h"
 
 
-int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index);
+int oracle_parser_arc_eager(config *c, word_buffer *ref, int start_sentence_index, int root_label);
 
 #endif
diff --git a/maca_trans_parser/src/simple_decoder_parser.c b/maca_trans_parser/src/simple_decoder_parser.c
index 81aea38..b6f8888 100644
--- a/maca_trans_parser/src/simple_decoder_parser.c
+++ b/maca_trans_parser/src/simple_decoder_parser.c
@@ -20,16 +20,16 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la
   config *c = NULL;
   word *dep;
   
-  c = config_initial(f, ctx->mcd_struct, 5);
+  c = config_initial_no_dummy_word(f, ctx->mcd_struct, 5);
   while(1){
     config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
     mvt_code = feature_table_argmax(fv, ft, &max);
     mvt_type = movement_type(mvt_code);
     mvt_label = movement_label(mvt_code);
-    
-    /* config_print(stdout, c);     */
-    /* movement_print(stdout, mvt_code, ctx->dico_labels);     */
-    
+
+    /* config_print(stdout, c);    */
+    /* movement_print(stdout, mvt_code, ctx->dico_labels);      */
+
     if(mvt_type == MVT_LEFT){
       dep = stack_s0(config_get_stack(c));
       if(movement_left_arc(c, mvt_label, max)){
@@ -55,13 +55,17 @@ void simple_decoder_stream(context *ctx, FILE *f, feature_table *ft, int root_la
     if(mvt_type == MVT_REDUCE)
       if(movement_reduce(c, max))
 	continue;
+   
+    if(mvt_type == MVT_ROOT)
+      if(movement_root(c, max, root_label))
+      continue;
     
     movement_shift(c, 1, max);
     
     if(word_buffer_is_last(config_get_buffer(c))) break;
   }
   
-  for(int i=1; i < config_get_buffer(c)->nbelem; i++){
+  for(int i=0; i < config_get_buffer(c)->nbelem; i++){
     dep = word_buffer_get_word_n(config_get_buffer(c), i);
     printf("%s\t", word_get_input(dep));
     printf("%d\t", word_get_gov(dep));
-- 
GitLab