Skip to content
Snippets Groups Projects
Commit cff43421 authored by robin.perrotin's avatar robin.perrotin
Browse files

Added option -Z for probabilistic mode with a ratio parameter

parent 4ea335aa
Branches
No related tags found
No related merge requests found
...@@ -103,6 +103,10 @@ context *context_new(void) ...@@ -103,6 +103,10 @@ context *context_new(void)
ctx->dnn_model_filename = NULL; ctx->dnn_model_filename = NULL;
ctx->l_rules_filename = NULL; ctx->l_rules_filename = NULL;
ctx->proba_mode = 0;
ctx->proba_factor = 1;
return ctx; return ctx;
} }
...@@ -189,6 +193,9 @@ void context_json_help_message(context *ctx){ ...@@ -189,6 +193,9 @@ void context_json_help_message(context *ctx){
void context_dnn_model_help_message(context *ctx){ void context_dnn_model_help_message(context *ctx){
fprintf(stderr, "\t-N --dnn_model : weight file for dnn\n"); fprintf(stderr, "\t-N --dnn_model : weight file for dnn\n");
} }
void context_proba_mode_help_message(context *ctx){
fprintf(stderr, "\t-Z --probabilistic : activate probabilistic mode with factor >0 (use 1 for default).\n");
}
context *context_read_options(int argc, char *argv[]) context *context_read_options(int argc, char *argv[])
{ {
...@@ -198,7 +205,7 @@ context *context_read_options(int argc, char *argv[]) ...@@ -198,7 +205,7 @@ context *context_read_options(int argc, char *argv[])
ctx->program_name = strdup(argv[0]); ctx->program_name = strdup(argv[0]);
static struct option long_options[28] = static struct option long_options[29] =
{ {
{"help", no_argument, 0, 'h'}, {"help", no_argument, 0, 'h'},
{"verbose", no_argument, 0, 'v'}, {"verbose", no_argument, 0, 'v'},
...@@ -227,12 +234,13 @@ context *context_read_options(int argc, char *argv[]) ...@@ -227,12 +234,13 @@ context *context_read_options(int argc, char *argv[])
{"json", required_argument, 0, 'J'}, {"json", required_argument, 0, 'J'},
{"dnn_model", required_argument, 0, 'N'}, {"dnn_model", required_argument, 0, 'N'},
{"l_rules", required_argument, 0, 'l'}, {"l_rules", required_argument, 0, 'l'},
{"probabilistic", required_argument, 0, 'Z'},
{"fplm", required_argument, 0, 'w'} {"fplm", required_argument, 0, 'w'}
}; };
optind = 0; optind = 0;
opterr = 0; opterr = 0;
while ((c = getopt_long (argc, argv, "hvdcTpm:i:n:x:u:r:M:b:f:s:C:F:V:L:D:R:P:J:N:w:l:S:T:", long_options, &option_index)) != -1){ while ((c = getopt_long (argc, argv, "hvdcTpm:i:n:x:u:r:M:b:f:s:C:F:V:L:D:R:P:J:N:w:l:S:T:Z:", long_options, &option_index)) != -1){
switch (c) switch (c)
{ {
...@@ -326,6 +334,10 @@ context *context_read_options(int argc, char *argv[]) ...@@ -326,6 +334,10 @@ context *context_read_options(int argc, char *argv[])
case 'S': case 'S':
ctx->score_method = atoi(optarg); ctx->score_method = atoi(optarg);
break; break;
case 'Z':
ctx->proba_mode = 1;
ctx->proba_factor = atoi(optarg);
break;
} }
} }
......
...@@ -115,6 +115,9 @@ typedef struct { ...@@ -115,6 +115,9 @@ typedef struct {
char *json_filename; char *json_filename;
char *dnn_model_filename; char *dnn_model_filename;
char *l_rules_filename; char *l_rules_filename;
int proba_mode;
float proba_factor;
} context; } context;
context *context_new(void); context *context_new(void);
......
...@@ -170,11 +170,14 @@ void simple_decoder_parser_arc_eager(context *ctx) ...@@ -170,11 +170,14 @@ void simple_decoder_parser_arc_eager(context *ctx)
double currentSumExp; double currentSumExp;
double ScoreTranslation; double ScoreTranslation;
int FlagNotInitExp; int FlagNotInitExp;
double ProbaDivider = 15; // TO SETUP double ProbaDivider = ctx->proba_factor;
double randomFloat; double randomFloat;
srand(mix(clock(), time(NULL), getpid())); if(ProbaDivider <= 0)
ProbaDivider = 1;
if(ctx->proba_mode)
srand(mix(clock(), time(NULL), getpid()));
word* word_scored; word* word_scored;
...@@ -186,7 +189,7 @@ void simple_decoder_parser_arc_eager(context *ctx) ...@@ -186,7 +189,7 @@ void simple_decoder_parser_arc_eager(context *ctx)
sumExp = 0; sumExp = 0;
currentSumExp = 0; currentSumExp = 0;
ScoreTranslation = -5; // TO SETUP ScoreTranslation = -100; // TO SETUP?
FlagNotInitExp = 1; FlagNotInitExp = 1;
if(ctx->debug_mode){ if(ctx->debug_mode){
...@@ -210,9 +213,11 @@ void simple_decoder_parser_arc_eager(context *ctx) ...@@ -210,9 +213,11 @@ void simple_decoder_parser_arc_eager(context *ctx)
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE); config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
mvt_code = feature_table_argmax(fv, ft, &max); mvt_code = feature_table_argmax(fv, ft, &max);
if(ctx->debug_mode){ if(ctx->proba_mode || ctx->debug_mode){
vcode *vcode_array = feature_table_get_vcode_array(fv, ft); vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
if(ctx->proba_mode){
/* Get the probabilistic parameters */ /* Get the probabilistic parameters */
for(int i=0; i < ft->classes_nb; i++){ for(int i=0; i < ft->classes_nb; i++){
int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)); int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
...@@ -221,36 +226,45 @@ void simple_decoder_parser_arc_eager(context *ctx) ...@@ -221,36 +226,45 @@ void simple_decoder_parser_arc_eager(context *ctx)
if(b1 && b2 && b3){ if(b1 && b2 && b3){
if(FlagNotInitExp){ if(FlagNotInitExp){
ScoreTranslation += vcode_array[i].score/ProbaDivider; ScoreTranslation += vcode_array[i].score;
FlagNotInitExp = 0; FlagNotInitExp = 0;
} }
if(vcode_array[i].score/ProbaDivider - ScoreTranslation > 0){ if((vcode_array[i].score - ScoreTranslation)/ProbaDivider > 0){
sumExp += exp(vcode_array[i].score/ProbaDivider - ScoreTranslation); sumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
}
} }
} }
} }
currentSumExp = 0.; currentSumExp = 0.;
for(int i=0; i < ft->classes_nb && i < 10; i++){ for(int i=0; i < ft->classes_nb && i < 10; i++){
if(ctx->debug_mode){
printf("%d\t", i); printf("%d\t", i);
movement_parser_print(stdout, vcode_array[i].class_code, ctx->dico_labels); movement_parser_print(stdout, vcode_array[i].class_code, ctx->dico_labels);
printf("\t%.4f", vcode_array[i].score); printf("\t%.4f", vcode_array[i].score);
fflush(stdout); fflush(stdout);
}
int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)); int b1 = respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)); int b2 = respect_stack_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)); int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
if(b1 && b2 && b3){ if(b1 && b2 && b3){
if(vcode_array[i].score/ProbaDivider - ScoreTranslation > 0){ if(ctx->proba_mode && (vcode_array[i].score - ScoreTranslation)/ProbaDivider > 0){
printf(" %f %f %f",sumExp, currentSumExp,ScoreTranslation); if(ctx->debug_mode){
// printf(" %f %f %f",sumExp, currentSumExp,ScoreTranslation);
printf(" [%f-",currentSumExp/sumExp); printf(" [%f-",currentSumExp/sumExp);
currentSumExp += exp(vcode_array[i].score/ProbaDivider - ScoreTranslation); }
currentSumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
if(ctx->debug_mode)
printf("%f[", currentSumExp/sumExp); printf("%f[", currentSumExp/sumExp);
} }
if(ctx->debug_mode)
printf("\t<----"); printf("\t<----");
}else }else if(ctx->debug_mode){
printf("\t<%d,%d,%d>",b1,b2,b3); printf("\t<%d,%d,%d>",b1,b2,b3);
}
// printf("\t%d", respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code))); // printf("\t%d", respect_standard_constraint(c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)));
//printf("AAAAAAA\n"); //printf("AAAAAAA\n");
if(ctx->debug_mode)
printf("\n"); printf("\n");
} }
free(vcode_array); free(vcode_array);
...@@ -269,12 +283,14 @@ void simple_decoder_parser_arc_eager(context *ctx) ...@@ -269,12 +283,14 @@ void simple_decoder_parser_arc_eager(context *ctx)
printf("%f\n", max1 - max2); printf("%f\n", max1 - max2);
} }
if(ctx->proba_mode){
currentSumExp = 0.; currentSumExp = 0.;
randomFloat = (double) rand()/(double)RAND_MAX; randomFloat = (double) rand()/(double)RAND_MAX;
if(ctx->debug_mode)
printf("< %f > is our random number. \n",randomFloat); printf("< %f > is our random number. \n",randomFloat);
}
// if(ctx->partial_mode){ // NOT YET COMPATIBLE if(ctx->partial_mode || ctx->proba_mode){
vcode *vcode_array = feature_table_get_vcode_array(fv, ft); vcode *vcode_array = feature_table_get_vcode_array(fv, ft);
mvt_code = 0; mvt_code = 0;
for(int i=0; i < ft->classes_nb; i++){ for(int i=0; i < ft->classes_nb; i++){
...@@ -283,16 +299,22 @@ void simple_decoder_parser_arc_eager(context *ctx) ...@@ -283,16 +299,22 @@ void simple_decoder_parser_arc_eager(context *ctx)
int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code)); int b3 = respect_buffer_constraint(ctx->partial_mode, c, movement_parser_type(vcode_array[i].class_code), movement_parser_label(vcode_array[i].class_code));
if(b1 && b2 && b3){ if(b1 && b2 && b3){
currentSumExp += exp(vcode_array[i].score/ProbaDivider - ScoreTranslation); if(ctx->proba_mode){
currentSumExp += exp((vcode_array[i].score - ScoreTranslation)/ProbaDivider);
if(currentSumExp/sumExp >= randomFloat){ if(currentSumExp/sumExp >= randomFloat){
if(ctx->debug_mode)
printf("The %d th move has been selected by the probabilistic parser.\n",i); printf("The %d th move has been selected by the probabilistic parser.\n",i);
mvt_code = vcode_array[i].class_code; mvt_code = vcode_array[i].class_code;
break; break;
} }
}else{
mvt_code = vcode_array[i].class_code;
break;
}
} }
} }
free(vcode_array); free(vcode_array);
// } }
mvt_type = movement_parser_type(mvt_code); mvt_type = movement_parser_type(mvt_code);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment