From c3f1f0a746454f2e9a173d6df99dbddef246246b Mon Sep 17 00:00:00 2001
From: RP <robin.perrotin@lis-lab.fr>
Date: Wed, 18 Jul 2018 02:18:22 +0200
Subject: [PATCH] Added --singleroot option to ensure a tree is produced

---
 maca_trans_parser/src/confidence_score.c      |  4 +--
 maca_trans_parser/src/context.c               | 18 ++++++++++---
 maca_trans_parser/src/context.h               |  3 +++
 .../src/partial_parser_conditional.c          | 24 ++++++++++++++---
 .../src/partial_parser_conditional.h          |  2 +-
 .../src/simple_decoder_parser_arc_eager.c     | 26 ++++++++++++-------
 6 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/maca_trans_parser/src/confidence_score.c b/maca_trans_parser/src/confidence_score.c
index d575291..17520a9 100644
--- a/maca_trans_parser/src/confidence_score.c
+++ b/maca_trans_parser/src/confidence_score.c
@@ -15,7 +15,7 @@ float confidence_score(int mvt_code, vcode *vcode_array, int size, context *ctx,
   case 1: //methode 1: First - Second.
 
     while(firstindex == -1 && i < size){
-      int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
+      int b1 = respect_standard_constraint(c, ctx,movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
       int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
       int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
       if(b1 && b2 && b3)
@@ -23,7 +23,7 @@ float confidence_score(int mvt_code, vcode *vcode_array, int size, context *ctx,
       i += 1;
     }
     while(secondindex == -1 && i < size){
-      int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
+      int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
       int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
       int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
       if(b1 && b2 && b3)
diff --git a/maca_trans_parser/src/context.c b/maca_trans_parser/src/context.c
index bb08772..651de66 100644
--- a/maca_trans_parser/src/context.c
+++ b/maca_trans_parser/src/context.c
@@ -106,7 +106,9 @@ context *context_new(void)
   
   ctx->proba_mode = 0;
   ctx->proba_factor = 1;
-      
+
+  ctx->single_root_mode = 0;
+  
   return ctx;
 }
 
@@ -197,6 +199,10 @@ void context_proba_mode_help_message(context *ctx){
   fprintf(stderr, "\t-Z --probabilistic            : activate probabilistic mode with factor >0 (use 1 for default).\n");
 }
 
+void context_single_root_mode_help_message(context *ctx){ /* Prints the usage line for -X/--singleroot on stderr; ctx is not used. */
+  fprintf(stderr, "\t-X --singleroot               : activate single root mode: a tree with one root will be produced.\n");
+}
+
 context *context_read_options(int argc, char *argv[])
 {
   int c;
@@ -205,7 +211,7 @@ context *context_read_options(int argc, char *argv[])
 
   ctx->program_name = strdup(argv[0]);
 
-  static struct option long_options[29] =
+  static struct option long_options[30] =
     {
       {"help",                no_argument,       0, 'h'},
       {"verbose",             no_argument,       0, 'v'},
@@ -235,12 +241,13 @@ context *context_read_options(int argc, char *argv[])
       {"dnn_model",           required_argument, 0, 'N'},
       {"l_rules",             required_argument, 0, 'l'},
       {"probabilistic",       required_argument, 0, 'Z'},
-      {"fplm",                required_argument, 0, 'w'}
+      {"fplm",                required_argument, 0, 'w'},
+      {"singleroot",          no_argument,       0, 'X'}
     };
   optind = 0;
   opterr = 0;
   
-  while ((c = getopt_long (argc, argv, "hvdcTpm:i:n:x:u:r:M:b:f:s:C:F:V:L:D:R:P:J:N:w:l:S:T:Z:", long_options, &option_index)) != -1){ 
+  while ((c = getopt_long (argc, argv, "hvdcTpm:i:n:x:u:r:M:b:f:s:C:F:V:L:D:R:P:J:N:w:l:S:T:Z:X", long_options, &option_index)) != -1){ 
 
     switch (c)
       {
@@ -338,6 +345,9 @@ context *context_read_options(int argc, char *argv[])
       ctx->proba_mode = 1;
       ctx->proba_factor = atoi(optarg);
 	break;
+      case 'X':
+	ctx->single_root_mode = 1;
+	break;
       }
   }
 
diff --git a/maca_trans_parser/src/context.h b/maca_trans_parser/src/context.h
index c96a3f1..3d80cb6 100644
--- a/maca_trans_parser/src/context.h
+++ b/maca_trans_parser/src/context.h
@@ -112,6 +112,9 @@ typedef struct {
   int trace_mode;
   int partial_mode;
   int score_method;
+
+  int single_root_mode;
+  
   char *json_filename;
   char *dnn_model_filename;
   char *l_rules_filename;
diff --git a/maca_trans_parser/src/partial_parser_conditional.c b/maca_trans_parser/src/partial_parser_conditional.c
index 44507c0..622f5c2 100644
--- a/maca_trans_parser/src/partial_parser_conditional.c
+++ b/maca_trans_parser/src/partial_parser_conditional.c
@@ -1,7 +1,7 @@
 #include"context.h"
 #include"movement_parser_arc_eager.h"
 
-int respect_standard_constraint(config *c, int mvt_type, int mvt_label){
+int respect_standard_constraint(config *c, context* ctx, int mvt_type, int mvt_label){
   int gov;
 	switch(mvt_type){
       case MVT_PARSER_LEFT :
@@ -16,20 +16,22 @@ int respect_standard_constraint(config *c, int mvt_type, int mvt_label){
       case MVT_PARSER_RIGHT:
         if(stack_is_empty(config_get_stack(c))) return 0;
 	if(word_buffer_end(config_get_buffer(c))) return 0;
+	if((config_get_buffer(c)->current_index >= config_get_buffer(c)->nbelem - 1) && ctx->single_root_mode && !(stack_top(config_get_stack(c))->is_root || config_get_stack(c)->top == 1)) return 0;
         return 1;
       case MVT_PARSER_REDUCE:
         if(stack_is_empty(config_get_stack(c))) return 0;
         gov = word_get_gov(stack_top(config_get_stack(c)));
 	//	printf("\n%d %d???\n",gov,stack_top(config_get_stack(c))->is_root);
-        if(stack_top(config_get_stack(c))->is_root || gov != WORD_INVALID_GOV)
+        if((stack_top(config_get_stack(c))->is_root && !(ctx->single_root_mode && !(word_buffer_end(config_get_buffer(c))))) || gov != WORD_INVALID_GOV)
 	  //if(gov != WORD_INVALID_GOV)
           return 1;
         return 0;
       case MVT_PARSER_SHIFT:
-        if(word_buffer_end(config_get_buffer(c))) return 0;
+        if(config_get_buffer(c)->current_index >= config_get_buffer(c)->nbelem - 1) return 0;
 	return 1;
       case MVT_PARSER_ROOT:
-        if(stack_is_empty(config_get_stack(c))) return 0;
+        if(config_get_stack(c)->top != 1) return 0;
+	if(ctx->single_root_mode && config_get_buffer(c)->current_index < config_get_buffer(c)->nbelem) return 0;
 	gov = word_get_gov(stack_top(config_get_stack(c)));
         return (gov == WORD_INVALID_GOV);
       case MVT_PARSER_EOS:
@@ -92,6 +94,14 @@ int respect_stack_constraint(int mode_partial, config *c, int mvt_type, int mvt_
  //   printf("%d %d %d ",stack_id, gov_rel_id, buffer_id);
   if(gov_rel_id > 0){
     //top of stack needs to be governed by a left dependency.
+
+
+    //kind of a hack : 100000 _ means the word needs to be governed by a left dependency.
+    if(gov_rel_id == 100000){
+      return (movement_type_safe_for_top_stack(mvt_type) || mvt_type == MVT_PARSER_LEFT);
+    }
+
+
     if(buffer_id - stack_id < gov_rel_id){
       //allow only if top of stack doesn't move and isn't set new dep (left move. implyed by not moving).
       if(!movement_type_safe_for_top_stack(mvt_type))
@@ -146,6 +156,12 @@ int respect_buffer_constraint(int mode_partial, config *c, int mvt_type, int mvt
   int buffer_id = word_get_index(w_buffer);
   if(gov_rel_id < 0){
     //top of buffer needs to be governed by a right dependency.
+
+    //kind of a hack : -100000 _ means the word needs to be governed by a right dependency.
+    if(gov_rel_id == -100000){
+      return (movement_type_safe_for_top_buffer(mvt_type) || mvt_type == MVT_PARSER_RIGHT);
+    }
+    
     if(stack_id - buffer_id > gov_rel_id){
       //allow only if top of buffer doesn't move and isn't set new dep (right move. implyed by not moving OR ROOT move?).
       if(!movement_type_safe_for_top_buffer(mvt_type))
diff --git a/maca_trans_parser/src/partial_parser_conditional.h b/maca_trans_parser/src/partial_parser_conditional.h
index 55bd8a4..ed4fd05 100644
--- a/maca_trans_parser/src/partial_parser_conditional.h
+++ b/maca_trans_parser/src/partial_parser_conditional.h
@@ -2,7 +2,7 @@
 #define __PARTIAL_PARSER_CONDITIONAL__
 #include"context.h"
 
-int respect_standard_constraint(config *c,int mvt_type, int mv_label);
+int respect_standard_constraint(config *c,context* ctx, int mvt_type, int mv_label);
 int respect_stack_constraint(int mode_partial, config *c, int mvt_type, int mvt_label);
 int respect_buffer_constraint(int mode_partial, config *c, int mvt_type, int mvt_label);
 
diff --git a/maca_trans_parser/src/simple_decoder_parser_arc_eager.c b/maca_trans_parser/src/simple_decoder_parser_arc_eager.c
index 3fd5a3b..9e007b9 100644
--- a/maca_trans_parser/src/simple_decoder_parser_arc_eager.c
+++ b/maca_trans_parser/src/simple_decoder_parser_arc_eager.c
@@ -185,6 +185,8 @@ void simple_decoder_parser_arc_eager(context *ctx)
   if(root_label == -1) root_label = 0;
   
   c = config_new(f, ctx->mcd_struct, 5);
+  int noRootYet = 1;
+
   while(!config_is_terminal(c)){
     
     sumExp = 0;
@@ -220,11 +222,12 @@ void simple_decoder_parser_arc_eager(context *ctx)
       if(ctx->proba_mode){
         /* Get the probabilistic parameters */
         for(int i=0; i < ft->classes_nb; i++){
-          int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
+          int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
           int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
           int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
-
-          if(b1 && b2 && b3){
+	  int b4 = (noRootYet || MVT_PARSER_ROOT != movement_parser_type(vcode_array[i].class_code)) ;
+	  
+          if(b1 && b2 && b3 && b4){
             if(FlagNotInitExp){
               ScoreTranslation += vcode_array[i].score;
               FlagNotInitExp = 0;
@@ -244,10 +247,12 @@ void simple_decoder_parser_arc_eager(context *ctx)
           printf("\t%.4f", vcode_array[i].score);
           fflush(stdout);
         }
-        int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
+        int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
         int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
         int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
-        if(b1 && b2 && b3){
+	int b4 = (noRootYet || MVT_PARSER_ROOT != movement_parser_type(vcode_array[i].class_code)) ;
+	
+	if(b1 && b2 && b3 && b4){
           if(ctx->proba_mode && (vcode_array[i].score - ScoreTranslation)/ProbaDivider > 0){
             if(ctx->debug_mode){
          //     printf(" %f %f %f",sumExp, currentSumExp,ScoreTranslation);
@@ -290,15 +295,16 @@ void simple_decoder_parser_arc_eager(context *ctx)
       if(ctx->debug_mode)
         printf("< %f > is our random number. \n",randomFloat);
     }
-    if(ctx->partial_mode || ctx->proba_mode){
+    if(ctx->partial_mode || ctx->proba_mode || ctx->single_root_mode){
       vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
       mvt_code = 0;
       for(int i=0; i < ft->classes_nb; i++){
-        int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
+        int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
         int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
         int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
-
-        if(b1 && b2 && b3){
+	int b4 = (noRootYet || MVT_PARSER_ROOT != movement_parser_type(vcode_array[i].class_code)) ;	
+	
+        if(b1 && b2 && b3 && b4){
           if(ctx->proba_mode){
             currentSumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
             if(currentSumExp/sumExp >= randomFloat){ 
@@ -347,6 +353,8 @@ void simple_decoder_parser_arc_eager(context *ctx)
       case MVT_PARSER_ROOT:
 	word_scored = stack_top(config_get_stack(c));
 	result = movement_parser_root(c, root_label);
+	if(result && ctx->single_root_mode)
+	  noRootYet = 0;
 	/*	while(!stack_is_empty(config_get_stack(c)))
 		movement_parser_root(c, root_label);*/
 	break;
-- 
GitLab