#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<unistd.h>
#include<getopt.h>
#include<math.h>
#include<time.h>
#include"context.h"
#include"movement_parser_arc_eager.h"
#include"feat_fct.h"
#include"config2feat_vec.h"
#include"feature_table.h"
#include"dico.h"
#include"partial_parser_conditional.h"
#include"confidence_score.h"

/*
 * Mix step of Bob Jenkins' lookup2 hash (originally designed for 32-bit
 * words, applied here to unsigned long): scrambles a, b and c together and
 * returns the final value of c. In this file it is used to derive a srand()
 * seed from clock(), time() and the process id.
 * All arithmetic is unsigned, so the subtractions wrap instead of overflowing.
 */
unsigned long mix(unsigned long a, unsigned long b, unsigned long c)
{
  /* Shift amounts of each round: a ^= c >> s[0]; b ^= a << s[1]; c ^= b >> s[2]. */
  static const unsigned int shifts[3][3] = {
    {13, 8, 13},
    {12, 16, 5},
    {3, 10, 15}
  };
  int round;

  for(round = 0; round < 3; round++){
    a -= b;  a -= c;  a ^= c >> shifts[round][0];
    b -= c;  b -= a;  b ^= a << shifts[round][1];
    c -= a;  c -= b;  c ^= b >> shifts[round][2];
  }
  return c;
}

/*
 * Older variant of print_word_buffer() (not called from the code visible here).
 * For every word of c's buffer, writes on stdout:
 *   input<TAB>gov<TAB>label<TAB>
 * where label is resolved through dico_labels ("_" when the word's label is -1
 * or unknown). Only when the MCD declares no sentence-segmentation column,
 * a "1" (sent_seg == 1) or "0" flag follows. Each word ends with a newline.
 */
void print_word_buffer_old(config *c, dico *dico_labels, mcd *mcd_struct)
{
  int pos = 0;

  while(pos < config_get_buffer(c)->nbelem){
    word *cur = word_buffer_get_word_n(config_get_buffer(c), pos);
    char *lab_str = NULL;

    printf("%s\t", word_get_input(cur));
    printf("%d\t", word_get_gov(cur));

    if(word_get_label(cur) != -1)
      lab_str = dico_int2string(dico_labels, word_get_label(cur));
    if(lab_str == NULL)
      printf("_\t");
    else
      printf("%s\t", lab_str);

    /* The flag is printed only when no MCD column already carries it. */
    if(mcd_get_sent_seg_col(mcd_struct) == -1)
      putchar(word_get_sent_seg(cur) == 1 ? '1' : '0');

    putchar('\n');
    pos++;
  }
}

/* Number of fields strtok(s, "\t") would return for s, i.e. the maximal runs
   of non-tab characters: leading, trailing and repeated tabs produce no field,
   exactly as with strtok. A NULL string has no field. */
static int count_tab_fields(const char *s)
{
  int nb = 0;
  int in_field = 0;

  if(s == NULL)
    return 0;
  for(; *s != '\0'; s++){
    if(*s == '\t'){
      in_field = 0;
    }
    else if(!in_field){
      in_field = 1;
      nb++;
    }
  }
  return nb;
}

/*
 * Print every word of the buffer of configuration c on stdout, one line per word.
 *
 * When the MCD declares none of the governor, label and sentence-segmentation
 * columns, each line is "input<TAB>gov<TAB>label<TAB>sent_seg". The label is
 * looked up in dico_labels ("_" when the word has none). sent_seg is printed
 * as 1 when word_get_sent_seg() == 1 and as 0 otherwise.
 *
 * Otherwise the word's input fields are re-emitted one by one through
 * word_print_col_n(). The governor, label and sentence-segmentation fields are
 * replaced by the parser's values. Each of those values is appended after the
 * last input field when its column is undeclared (-1) or lies at or past the
 * number of input fields. The S and T confidence scores (-1 when not positive)
 * are appended when their declared column lies at or past that number.
 *
 * Input fields are counted as strtok(input, "\t") would split them (empty
 * fields between consecutive tabs are skipped). Unlike the previous
 * strdup+strtok version, this does not allocate (no unchecked strdup), and it
 * leaves strtok's hidden static state alone, so the column printer may use
 * strtok itself.
 */
void print_word_buffer(config *c, dico *dico_labels, mcd *mcd_struct)
{
  int i;
  word *w;
  char *label;
  int col_nb;
  int nb_fields;
  /* mcd_struct is not modified here: read the column layout once. */
  int gov_col = mcd_get_gov_col(mcd_struct);
  int label_col = mcd_get_label_col(mcd_struct);
  int sent_seg_col = mcd_get_sent_seg_col(mcd_struct);
  int s_col = mcd_get_s_col(mcd_struct);
  int t_col = mcd_get_t_col(mcd_struct);

  for(i=0; i < config_get_buffer(c)->nbelem; i++){
    w = word_buffer_get_word_n(config_get_buffer(c), i);

    /* No predicted column in the MCD: fixed 4-column output. */
    if((gov_col == -1) && (label_col == -1) && (sent_seg_col == -1)){
      printf("%s\t", word_get_input(w));
      printf("%d\t", word_get_gov(w));
      label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
      if(label != NULL)
        printf("%s\t", label);
      else
        printf("_\t");
      if(word_get_sent_seg(w) == 1)
        printf("1\n");
      else
        printf("0\n");
      continue;
    }

    /* Re-emit the input fields, substituting the predicted ones. */
    nb_fields = count_tab_fields(w->input);
    for(col_nb = 0; col_nb < nb_fields; col_nb++){
      if(col_nb != 0) printf("\t");
      if(col_nb == gov_col){
        printf("%d", word_get_gov(w));
      }
      else if(col_nb == label_col){
        label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
        if(label != NULL)
          printf("%s", label);
        else
          printf("_");
      }
      else if(col_nb == sent_seg_col){
        if(word_get_sent_seg(w) == 1)
          printf("1");
        else
          printf("0");
      }
      else{
        word_print_col_n(stdout, w, col_nb);
      }
    }

    /* Predicted values whose column is missing from the input are appended. */
    if((nb_fields <= gov_col) || (gov_col == -1)){
      printf("\t%d", word_get_gov(w));
    }
    if((nb_fields <= label_col) || (label_col == -1)){
      label = (word_get_label(w) == -1)? NULL : dico_int2string(dico_labels, word_get_label(w));
      if(label != NULL)
        printf("\t%s", label);
      else
        printf("\t_");
    }
    if((nb_fields <= sent_seg_col) || (sent_seg_col == -1)){
      if(word_get_sent_seg(w) == 1)
        printf("\t1");
      else
        printf("\t0");
    }
    /* Confidence scores: only when a column past the input is declared. */
    if(nb_fields <= s_col){
      if(word_get_S(w) > 0)
        printf("\t%d", word_get_S(w));
      else
        printf("\t-1");
    }
    if(nb_fields <= t_col){
      if(word_get_T(w) > 0)
        printf("\t%d", word_get_T(w));
      else
        printf("\t-1");
    }
    printf("\n");
  }
}


/*
 * Run the arc-eager transition parser over the whole input and print the result.
 *
 * Input source and model:
 *   - Input is read from ctx->input_filename (opened with myfopen), or from
 *     stdin when no filename is set.
 *   - The perceptron model is loaded from ctx->perc_model_filename.
 *
 * Starting from config_new(f, ctx->mcd_struct, 5), each iteration applies one
 * movement until config_is_terminal(c) holds:
 *   - If the stack top carries an input sentence-end flag (sent_seg == 1) and
 *     the previous movement was not EOS, the sentence is closed by force: EOS,
 *     then REDUCE while possible, then ROOT while possible.
 *   - Otherwise the classifier scores every movement and one is chosen:
 *       * default: the argmax returned by feature_table_argmax();
 *       * partial_mode / single_root_mode: the first movement, in
 *         feature_table_get_vcode_array() order, that satisfies the standard,
 *         stack, buffer and single-root constraints (presumably that array is
 *         sorted by decreasing score -- TODO confirm);
 *       * proba_mode: a movement sampled among those constraint-satisfying
 *         movements, with weight exp((score - ScoreTranslation) / ProbaDivider).
 *     If the chosen movement cannot be executed, a SHIFT is tried instead.
 *     If that fails too, the stack is emptied with ROOT movements.
 *
 * When ctx->score_method > 0, each successful movement stores a confidence
 * score (score * 1000, truncated to int) on the word it affects: in S for
 * LEFT/RIGHT/ROOT, in T for REDUCE/SHIFT.
 *
 * Output modes:
 *   - debug_mode prints the configuration and the first (at most) 10 scored
 *     movements at each step.
 *   - trace_mode prints one line per step (index of buffer word b0, stack,
 *     predicted movement, margin between the two best scores) and suppresses
 *     the final print_word_buffer() output.
 *
 * On exit the config, feature vector and feature table are freed, and the
 * input file is closed if one was opened.
 */
void simple_decoder_parser_arc_eager(context *ctx)
{
  FILE *f = (ctx->input_filename)? myfopen(ctx->input_filename, "r") : stdin;
  feature_table *ft = feature_table_load(ctx->perc_model_filename, ctx->verbose);
  int root_label;
  int mvt_code;
  int mvt_type;
  int mvt_label;
  float max;
  feat_vec *fv = feat_vec_new(feature_types_nb);
  config *c = NULL;
  int result;
  /* float entropy; */
  /* float delta; */
  int argmax1, argmax2;
  float max1, max2;
  int index;
  float score;
  
  /* proba_mode state:
     - sumExp: normaliser of the sampling distribution;
     - currentSumExp: running cumulative sum of that distribution;
     - ScoreTranslation: offset subtracted from every score before exp();
     - FlagNotInitExp: 1 until ScoreTranslation has been anchored on the
       first admissible movement. */
  double sumExp;
  double currentSumExp;
  double ScoreTranslation;
  int FlagNotInitExp;
  /* Divisor applied to shifted scores before exp(); larger values flatten the distribution. */
  double ProbaDivider = ctx->proba_factor;
  double randomFloat;
  
  /* Non-positive divisors are replaced by 1 (avoids division by zero / sign flip). */
  if(ProbaDivider <= 0)
    ProbaDivider = 1;
  
  /* Seed rand() from clock, wall time and pid, presumably so that runs
     started in the same second still sample differently. */
  if(ctx->proba_mode)
    srand(mix(clock(), time(NULL), getpid()));
    
  /* Word that receives the confidence score of the current movement.
     Not assigned for EOS; the scoring switch below sends EOS to default. */
  word* word_scored;
  
  /* Unknown root label name: fall back to label code 0. */
  root_label = dico_string2int(ctx->dico_labels, ctx->root_label);
  if(root_label == -1) root_label = 0;
  
  c = config_new(f, ctx->mcd_struct, 5);
  /* single_root_mode: cleared after the first successful classifier-chosen
     ROOT; afterwards ROOT movements are filtered out (b4 below).
     NOTE(review): never set back to 1 at sentence boundaries, and the forced
     ROOTs do not clear it, so with several sentences only the first can get a
     classifier-chosen ROOT -- confirm this is intended. */
  int noRootYet = 1;

  while(!config_is_terminal(c)){
    
    sumExp = 0;
    currentSumExp = 0;
    /* Once the first admissible score is added below, that movement's shifted
       score is exactly 100; scores 100 or more below it shift to <= 0. */
    ScoreTranslation = -100; // TO SETUP?
    FlagNotInitExp = 1;

    if(ctx->debug_mode){
      fprintf(stdout, "***********************************\n");
      config_print(stdout, c);      
    }	
    /* forced EOS (the element on the top of the stack is eos, but the preceding movement is not MVT_PARSER_EOS */
    /* which means that the top of the stack got its eos status from input */
    /* force the parser to finish parsing the sentence (perform all pending reduce actions) and determine root of the sentence */ 

    if((word_get_sent_seg(stack_top(config_get_stack(c))) == 1) && (mvt_get_type(mvt_stack_top(config_get_history(c))) != MVT_PARSER_EOS)){
      /* Clear the input flag so this branch is not taken again for this word. */
      word_set_sent_seg(stack_top(config_get_stack(c)), -1);
      movement_parser_eos(c);
      while(movement_parser_reduce(c));
      while(movement_parser_root(c, root_label));
      if(ctx->debug_mode) printf("force EOS\n");
    }

    /* normal behaviour, ask classifier what is the next movement to do and do it */
    else{
      config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
      /* Unconstrained best movement; may be overridden by the constrained selection below. */
      mvt_code = feature_table_argmax(fv, ft, &max);

    if(ctx->proba_mode || ctx->debug_mode){
      vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
      
      
      if(ctx->proba_mode){
        /* Get the probabilistic parameters */
        /* Compute the normaliser sumExp over admissible movements (b1..b4 all true). */
        for(int i=0; i < ft->classes_nb; i++){
          int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
          int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
          int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
	  int b4 = (noRootYet || MVT_PARSER_ROOT != movement_parser_type(vcode_array[i].class_code)) ;
	  
          if(b1 && b2 && b3 && b4){
            /* Anchor the translation on the first admissible movement. */
            if(FlagNotInitExp){
              ScoreTranslation += vcode_array[i].score;
              FlagNotInitExp = 0;
            }
            /* Movements with non-positive shifted score are left out of the normaliser.
               NOTE(review): exp(100 / ProbaDivider) overflows a double to inf once
               ProbaDivider drops below roughly 0.14, making the ratios below NaN. */
            if((vcode_array[i].score - ScoreTranslation)/ProbaDivider > 0){
              sumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
            }	  
          }
        }
      }
      
      /* Debug display of the first (at most) 10 movements. In proba_mode the
         bracketed pair is the movement's slice of the cumulative distribution;
         "<b1,b2,b3>" shows which constraints failed (b4 is not displayed). */
      currentSumExp = 0.;
      for(int i=0; i < ft->classes_nb && i < 10; i++){
        if(ctx->debug_mode){
          printf("%d\t", i);
          movement_parser_print(stdout, vcode_array[i].class_code, ctx->dico_labels);
          printf("\t%.4f", vcode_array[i].score);
          fflush(stdout);
        }
        int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
        int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
        int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
	int b4 = (noRootYet || MVT_PARSER_ROOT != movement_parser_type(vcode_array[i].class_code)) ;
	
	if(b1 && b2 && b3 && b4){
          if(ctx->proba_mode && (vcode_array[i].score - ScoreTranslation)/ProbaDivider > 0){
            if(ctx->debug_mode){
         //     printf(" %f %f %f",sumExp, currentSumExp,ScoreTranslation);
              printf(" [%f-",currentSumExp/sumExp);
            }
            currentSumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
            if(ctx->debug_mode)
              printf("%f[", currentSumExp/sumExp);
          }
          if(ctx->debug_mode)
            printf("\t<----");
        }else if(ctx->debug_mode){
          printf("\t<%d,%d,%d>",b1,b2,b3);
        }
        // printf("\t%d", respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)));
        //printf("AAAAAAA\n");
        if(ctx->debug_mode)
          printf("\n");
      }
      free(vcode_array);
    }
      
    /* One trace line per step: b0 index, stack, movement, margin between the
       two best scores. NOTE: mvt_code here is still the unconstrained argmax;
       the constrained/sampled choice below is made afterwards. */
    if(ctx->trace_mode){
      index = word_get_index(word_buffer_b0(config_get_buffer(c)));
      fprintf(stdout, "%d\t", index);
      
      stack_print(stdout, c->st);
      fprintf(stdout, "\t");
      
      movement_parser_print(stdout, mvt_code, ctx->dico_labels);        
      fprintf(stdout, "\t");
      feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
      printf("%f\n", max1 - max2);
    }

    /* Draw the sampling threshold, uniform in [0, 1]. */
    if(ctx->proba_mode){
      currentSumExp = 0.;
      randomFloat = (double) rand()/(double)RAND_MAX;
   
      if(ctx->debug_mode)
        printf("< %f > is our random number. \n",randomFloat);
    }
    /* Constrained selection: take the first admissible movement, or in
       proba_mode the first one whose cumulative weight reaches randomFloat.
       If no movement is admissible, mvt_code stays 0. */
    if(ctx->partial_mode || ctx->proba_mode || ctx->single_root_mode){
      vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
      mvt_code = 0;
      for(int i=0; i < ft->classes_nb; i++){
        int b1 = respect_standard_constraint(c, ctx, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
        int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
        int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
	int b4 = (noRootYet || MVT_PARSER_ROOT != movement_parser_type(vcode_array[i].class_code)) ;	
	
        if(b1 && b2 && b3 && b4){
          if(ctx->proba_mode){
            /* NOTE(review): unlike the sumExp computation above, this sum does not
               skip movements whose shifted score is <= 0, so the cumulative
               weights do not exactly match the normaliser -- confirm intent. */
            currentSumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
            if(currentSumExp/sumExp >= randomFloat){ 
              if(ctx->debug_mode)
                printf("The %d th move has been selected by the probabilistic parser.\n",i);
              mvt_code = vcode_array[i].class_code;
              break;
            }
          }else{
            mvt_code = vcode_array[i].class_code;
            break;
          }
        }
      }
      free(vcode_array);
    }

     
      mvt_type = movement_parser_type(mvt_code);
      mvt_label = movement_parser_label(mvt_code);

      /* EOS was chosen although the stack top is not flagged as sentence end:
         fall back to the second-best movement of the unconstrained scores
         (argmax2 is not re-checked against the constraints). */
      if((mvt_type == MVT_PARSER_EOS) && (word_get_sent_seg(stack_top(config_get_stack(c))) == 0)){
        if(ctx->verbose)
          fprintf(stderr, "the classifier did predict EOS but this is not the case\n");
        feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
        mvt_code = argmax2;
        mvt_type = movement_parser_type(mvt_code);
        mvt_label = movement_parser_label(mvt_code);
	
      }
      
      /* Apply the movement; word_scored is captured before it alters the config. */
      result = 0;
      switch(mvt_type){
      case MVT_PARSER_LEFT :
	word_scored = stack_top(config_get_stack(c));
	result = movement_parser_left_arc(c, mvt_label);
	break;
      case MVT_PARSER_RIGHT:
	word_scored = word_buffer_b0(config_get_buffer(c));
	result = movement_parser_right_arc(c, mvt_label);
	break;
      case MVT_PARSER_REDUCE:
	word_scored = stack_top(config_get_stack(c));
	result = movement_parser_reduce(c);
	break;
      case MVT_PARSER_ROOT:
	word_scored = stack_top(config_get_stack(c));
	result = movement_parser_root(c, root_label);
	/* Only a successful ROOT closes the single-root window. */
	if(result && ctx->single_root_mode)
	  noRootYet = 0;
	/*	while(!stack_is_empty(config_get_stack(c)))
		movement_parser_root(c, root_label);*/
	break;
      case MVT_PARSER_EOS:
	result = movement_parser_eos(c);
	break;
      case MVT_PARSER_SHIFT:
	word_scored = word_buffer_b0(config_get_buffer(c));
	result = movement_parser_shift(c);
      }
      
      /* Movement not executable: try a SHIFT; if that fails as well, pop the
         whole stack with ROOT movements. */
      if(result == 0){
        if(ctx->debug_mode) fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n");
	result = movement_parser_shift(c);
	if(result == 0){ /* SHIFT failed no more words to read, let's get out of here ! */
	  if(ctx->debug_mode) fprintf(stdout, "WARNING : cannot exectue a SHIFT emptying stack !\n");
	  while(!stack_is_empty(config_get_stack(c)))
	    movement_parser_root(c, root_label);
	}
      }else{
	 /* Store score * 1000 (truncated) in S for arc/root decisions, in T for
	    reduce/shift decisions. NOTE(review): the array returned by
	    feature_table_get_vcode_array() is freed after use elsewhere in this
	    function but not here -- a leak unless confidence_score() frees it; confirm. */
	 if(ctx->score_method > 0){
	   score = confidence_score(mvt_code,feature_table_get_vcode_array(fv,ft),ft->classes_nb,ctx,c);
	   switch(mvt_type){
	   case MVT_PARSER_LEFT :
	   case MVT_PARSER_RIGHT :
	   case MVT_PARSER_ROOT :
//	     printf("dep score: %d %d!!\n", word_get_form(word_scored), (int)(score*1000));
	     word_set_S(word_scored,(int)(score*1000));
	     break;
	   case MVT_PARSER_REDUCE:
	   case MVT_PARSER_SHIFT:
//	     printf("pop/shift score: %d  %d!!\n", word_get_form(word_scored), (int)(score*1000));
	     word_set_T(word_scored,(int)(score*1000));
	     break;
	   default:
	     break;
	   }
	 }
      }
    }
  }
  
  if(!ctx->trace_mode)
    print_word_buffer(c, ctx->dico_labels, ctx->mcd_struct);
  
  config_free(c); 
  feat_vec_free(fv);
  feature_table_free(ft);
  if(ctx->input_filename)
    fclose(f);
}