Skip to content
Snippets Groups Projects
Commit d2c96a06 authored by Alexis Nasr's avatar Alexis Nasr
Browse files

new movement called IGNORE to skip words in the buffer, used for punctuation

parent 256e0277
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@ void context_free(context *ctx)
if(ctx->maca_data_path) free(ctx->maca_data_path);
if(ctx->language) free(ctx->language);
if(ctx->root_label) free(ctx->root_label);
if(ctx->punct_label) free(ctx->punct_label);
if(ctx->vocabs_filename) free(ctx->vocabs_filename);
if(ctx->fplm_filename) free(ctx->fplm_filename);
if(ctx->json_filename) free(ctx->json_filename);
......@@ -79,6 +80,7 @@ context *context_new(void)
ctx->language = strdup("fr");
ctx->root_label = strdup("root");
/* default punctuation label; can be overridden with -U / --punct_label */
ctx->punct_label = strdup("ponct");
ctx->d_perceptron_features = NULL;
ctx->d_perceptron_features_error = NULL;
ctx->mcd_struct = NULL;
......@@ -175,6 +177,9 @@ void context_maca_data_path_help_message(context *ctx){
/* Write the usage line of the -R / --root_label option to stderr. */
void context_root_label_help_message(context *ctx)
{
  (void) ctx; /* the message does not depend on the context */
  fputs("\t-R --root_label <str> : name of the root label (default is \"root\")\n", stderr);
}
/* Write the usage line of the -U / --punct_label option to stderr. */
void context_punct_label_help_message(context *ctx)
{
  (void) ctx; /* the message does not depend on the context */
  fputs("\t-U --punct_label <str> : name of the punct label (default is \"ponct\")\n", stderr);
}
/* Write the usage line of the -P / --f2p option to stderr. */
void context_f2p_filename_help_message(context *ctx)
{
  (void) ctx; /* the message does not depend on the context */
  fputs("\t-P --f2p <file> : form to pos (f2p) filename\n", stderr);
}
......@@ -201,7 +206,7 @@ context *context_read_options(int argc, char *argv[])
ctx->program_name = strdup(argv[0]);
static struct option long_options[30] =
static struct option long_options[31] =
{
{"help", no_argument, 0, 'h'},
{"force", no_argument, 0, 'K'},
......@@ -227,6 +232,7 @@ context *context_read_options(int argc, char *argv[])
{"language", required_argument, 0, 'L'},
{"maca_data_path", required_argument, 0, 'D'},
{"root_label", required_argument, 0, 'R'},
{"punct_label", required_argument, 0, 'U'},
{"f2p", required_argument, 0, 'P'},
{"traces", required_argument, 0, 'T'},
{"json", required_argument, 0, 'J'},
......@@ -238,7 +244,7 @@ context *context_read_options(int argc, char *argv[])
opterr = 0;
while ((c = getopt_long (argc, argv, "hKvdcSTm:i:A:B:n:x:q:u:r:M:b:f:s:C:F:V:L:D:R:P:J:N:w:l:", long_options, &option_index)) != -1){
while ((c = getopt_long (argc, argv, "hKvdcSTm:i:A:B:n:x:q:u:r:M:b:f:s:C:F:V:L:D:R:U:P:J:N:w:l:", long_options, &option_index)) != -1){
switch (c)
{
case 'A':
......@@ -324,6 +330,10 @@ context *context_read_options(int argc, char *argv[])
if (ctx->root_label) free(ctx->root_label); // libérer le default (strdup("root") )
ctx->root_label = strdup(optarg);
break;
case 'U':
if (ctx->punct_label) free(ctx->punct_label); // libérer le default (strdup("ponct") )
ctx->punct_label = strdup(optarg);
break;
case 'P':
ctx->f2p_filename = strdup(optarg);
if(!strcmp(ctx->f2p_filename, "_") || !strcmp(ctx->f2p_filename, "NULL"))
......
......@@ -106,6 +106,7 @@ typedef struct {
char *maca_data_path;
char *language;
char *root_label;
char *punct_label;
form2pos *f2p;
int conll;
int ifpls;
......@@ -149,6 +150,7 @@ void context_conll_help_message(context *ctx);
void context_ifpls_help_message(context *ctx);
void context_input_help_message(context *ctx);
void context_root_label_help_message(context *ctx);
void context_punct_label_help_message(context *ctx);
void context_debug_help_message(context *ctx);
void context_json_help_message(context *ctx);
......
......@@ -75,7 +75,7 @@ void oracle_movement(int *mvt_code_oracle, char *mvt_type_oracle, int *mvt_label
{
if (!word_buffer_end(ref_oracle) && (*sentence_nb < ctx->sent_nb)) {
*mvt_code_oracle = oracle_parser_arc_eager(config_oracle, ref_oracle, root_label_oracle);
*mvt_code_oracle = oracle_parser_arc_eager(config_oracle, ref_oracle, root_label_oracle, -1);
*mvt_type_oracle = movement_parser_type(*mvt_code_oracle);
*mvt_label_oracle = movement_parser_label(*mvt_code_oracle);
......
......@@ -28,6 +28,10 @@ void maca_trans_parser_mcf2cff_help_message(context *ctx)
fprintf(stderr, "IN TRAIN MODE\n");
context_vocabs_help_message(ctx);
context_root_label_help_message(ctx);
context_punct_label_help_message(ctx);
}
void maca_trans_parser_mcf2cff_check_options(context *ctx)
......@@ -51,6 +55,7 @@ void generate_training_file(FILE *output_file, context *ctx)
feat_vec *fv = feat_vec_new(feature_types_nb);
int sentence_nb = 0;
int root_label = dico_string2int(ctx->dico_labels, (char *) ctx->root_label);
int punct_label = dico_string2int(ctx->dico_labels, (char *) ctx->punct_label);
word_buffer *ref = word_buffer_load_mcf(ctx->input_filename, ctx->mcd_struct);
FILE *mcf_file = myfopen(ctx->input_filename, "r");
......@@ -66,7 +71,7 @@ void generate_training_file(FILE *output_file, context *ctx)
c = config_new(mcf_file, mcd_struct_hyp, 5);
while(!word_buffer_end(ref) && (sentence_nb < ctx->sent_nb)){
mvt_code = oracle_parser_arc_eager(c, ref, root_label);
mvt_code = oracle_parser_arc_eager(c, ref, root_label, punct_label);
mvt_type = movement_parser_type(mvt_code);
mvt_label = movement_parser_label(mvt_code);
......@@ -110,6 +115,10 @@ void generate_training_file(FILE *output_file, context *ctx)
case MVT_PARSER_ROOT :
movement_parser_root(c, root_label);
break;
case MVT_PARSER_IGNORE :
movement_parser_ignore(c);
word_buffer_move_right(ref);
break;
case MVT_PARSER_SHIFT :
movement_parser_shift(c);
word_buffer_move_right(ref);
......
......@@ -13,6 +13,7 @@ void movement_parser_print(FILE *f, int mvt_code, dico *dico_labels){
if(mvt_type == MVT_PARSER_REDUCE) {fprintf(f, "REDUCE"); return;}
if(mvt_type == MVT_PARSER_ROOT) {fprintf(f, "ROOT"); return;}
if(mvt_type == MVT_PARSER_EOS) {fprintf(f, "EOS"); return;}
if(mvt_type == MVT_PARSER_IGNORE) {fprintf(f, "IGNORE"); return;}
if(mvt_type == MVT_PARSER_RIGHT) fprintf(f, "RIGHT");
else fprintf(f, "LEFT");
label = dico_int2string(dico_labels, mvt_label);
......@@ -27,6 +28,7 @@ void movement_parser_sprint(char *f, int mvt_code, dico *dico_labels){
if(mvt_type == MVT_PARSER_REDUCE) {sprintf(f, "REDUCE"); return;}
if(mvt_type == MVT_PARSER_ROOT) {sprintf(f, "ROOT"); return;}
if(mvt_type == MVT_PARSER_EOS) {sprintf(f, "EOS"); return;}
if(mvt_type == MVT_PARSER_IGNORE) {sprintf(f, "IGNORE"); return;}
if(mvt_type == MVT_PARSER_RIGHT) sprintf(f, "RIGHT");
else sprintf(f, "LEFT");
label = dico_int2string(dico_labels, mvt_label);
......@@ -40,8 +42,9 @@ int movement_parser_type(int mvt)
if(mvt == MVT_PARSER_REDUCE) return MVT_PARSER_REDUCE;
if(mvt == MVT_PARSER_ROOT) return MVT_PARSER_ROOT;
if(mvt == MVT_PARSER_EOS) return MVT_PARSER_EOS;
if(mvt % 2 == 0) return MVT_PARSER_LEFT; /* even movements are left movements */
return MVT_PARSER_RIGHT; /* odd movements are right movements */
if(mvt == MVT_PARSER_IGNORE) return MVT_PARSER_IGNORE;
if(mvt % 2 == 0) return MVT_PARSER_RIGHT; /* even movements are right movements */
return MVT_PARSER_LEFT; /* odd movements are left movements */
}
int movement_parser_label(int mvt)
......@@ -50,9 +53,10 @@ int movement_parser_label(int mvt)
if(mvt == MVT_PARSER_REDUCE) return -1;
if(mvt == MVT_PARSER_ROOT) return -1;
if(mvt == MVT_PARSER_EOS) return -1;
if(mvt % 2 == 0) /* even codes correspond to left movements */
return mvt / 2 - 2;
return (mvt - 1) / 2 - 2; /* odd codes correspond to right movements */
if(mvt == MVT_PARSER_IGNORE) return -1;
if(mvt % 2 == 0) /* even codes correspond to right movements */
return mvt / 2 - 3;
return (mvt + 1) / 2 - 3; /* odd codes correspond to left movements */
}
int movement_parser_eos(config *c)
......@@ -90,6 +94,11 @@ int movement_parser_shift(config *c)
return movement_shift(c, MVT_PARSER_SHIFT);
}
/*
 * Parser-level IGNORE movement: delegates to movement_ignore() with the
 * MVT_PARSER_IGNORE code and returns its result unchanged.
 */
int movement_parser_ignore(config *c)
{
return movement_ignore(c, MVT_PARSER_IGNORE);
}
int movement_parser_shift_undo(config *c)
{
return movement_shift_undo(c);
......
......@@ -8,14 +8,15 @@
#define MVT_PARSER_REDUCE 1
#define MVT_PARSER_ROOT 2
#define MVT_PARSER_EOS 3
#define MVT_PARSER_LEFT 4
#define MVT_PARSER_RIGHT 5
#define MVT_PARSER_IGNORE 4
#define MVT_PARSER_LEFT 5
#define MVT_PARSER_RIGHT 6
/* even movements are left movements (except 0, which is shift and 2 which is root) */
#define movement_parser_left_code(label) (2 * (label) + 4)
/* odd movements are left movements (except 1, which is reduce and 3 which is end_of_sentence) */
#define movement_parser_left_code(label) (2 * (label) + 5)
/* odd movements are right movements (except 1, which is reduce and 3 which is end_of_sentence) */
#define movement_parser_right_code(label) (2 * (label) + 5)
/* even movements are right movements (except 0, which is shift, 2 which is root and 4 which is ignore) */
#define movement_parser_right_code(label) (2 * (label) + 6)
int movement_parser_type(int mvt);
int movement_parser_label(int mvt);
......@@ -24,6 +25,7 @@ int movement_parser_left_arc(config *c, int label);
int movement_parser_left_arc_undo(config *c);
int movement_parser_right_arc(config *c, int label);
int movement_parser_right_arc_undo(config *c);
int movement_parser_ignore(config *c);
int movement_parser_shift(config *c);
int movement_parser_shift_undo(config *c);
int movement_parser_reduce(config *c);
......
......@@ -84,8 +84,12 @@ int movement_right_arc_undo(config *c)
/*
 * IGNORE movement: skips the word at the front of the buffer.
 *
 * The skipped word gets WORD_INVALID_GOV as governor and -1 as label, the
 * movement is pushed on the configuration with config_push_mvt(), and the
 * buffer is moved one position to the right.
 *
 * Returns 0 without touching anything when word_buffer_end() reports the end
 * of the buffer, 1 otherwise.
 */
int movement_ignore(config *c, int movement_code)
{
  word *skipped;

  if(!word_buffer_end(config_get_buffer(c))){
    skipped = word_buffer_b0(config_get_buffer(c));

    /* detach the word: no governor, no label */
    word_set_gov(skipped, WORD_INVALID_GOV);
    word_set_label(skipped, -1);

    config_push_mvt(c, movement_code, skipped, NULL);
    word_buffer_move_right(config_get_buffer(c));
    return 1;
  }
  return 0;
}
......@@ -162,6 +166,16 @@ int movement_root_undo(config *c)
}
/*
 * EOS movement: marks the word at the front of the buffer as a sentence
 * boundary (sent_seg = 1), pushes the movement on the configuration with
 * config_push_mvt(), and moves the buffer one position to the right.
 *
 * c             : configuration to update
 * movement_code : code pushed on the configuration for this movement
 *
 * Returns 1 when the movement was applied, 0 when word_buffer_end() reports
 * the end of the buffer (there is no word to mark). The guard mirrors the one
 * in movement_ignore().
 */
int movement_eos(config *c, int movement_code)
{
  /* nothing to mark once the buffer is exhausted */
  if(word_buffer_end(config_get_buffer(c))) return 0;

  word *b0 = word_buffer_b0(config_get_buffer(c));

  /* set b0 to eos */
  word_set_sent_seg(b0, 1);

  config_push_mvt(c, movement_code, b0, NULL);
  word_buffer_move_right(config_get_buffer(c));
  return 1;
}
int movement_eos_old(config *c, int movement_code)
{
if(stack_is_empty(config_get_stack(c))) return 0;
word *s0 = stack_top(config_get_stack(c));
......
......@@ -4,16 +4,20 @@
#include"word_buffer.h"
#include"movement_parser_arc_eager.h"
int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, int word_index)
int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, int word_index, int punct_label)
{
int dep;
int gov_ref;
int gov_hyp;
int sentence_change;
for(dep = word_index - 1; (dep >= 0) && (word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 0); dep--){
for(dep = word_index - 1;
(dep >= 0) && (word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 0);
dep--){
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep));
if(gov_ref == word_index){ /* dep is a dependent of word in ref */
if((gov_ref == word_index)
&& (word_get_label(word_buffer_get_word_n(ref, dep)) != punct_label))
{ /* dep is a dependent of word in ref */
/* check that dep has the same governor in hyp */
gov_hyp = word_get_gov_index(word_buffer_get_word_n(config_get_buffer(c), dep));
if(gov_hyp != gov_ref) return 0;
......@@ -25,7 +29,9 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref,
if(word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 1)
sentence_change = 1;
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep));
if(gov_ref == word_index){ /* dep is a dependent of word in ref */
if((gov_ref == word_index)
&& (word_get_label(word_buffer_get_word_n(ref, dep)) != punct_label))
{ /* dep is a dependent of word in ref */
/* look for a dependency in hyp such that its dependent is dep */
gov_hyp = word_get_gov_index(word_buffer_get_word_n(config_get_buffer(c), dep));
if(gov_hyp != gov_ref) return 0;
......@@ -34,21 +40,31 @@ int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref,
return 1;
}
int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label)
int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label, int punct_label)
{
word *s0; /* word on top of stack */
word *b0; /* next word in the bufer */
int s0_index, b0_index;
int s0_gov_index, b0_gov_index;
int s0_label;
int s0_label, b0_label;
/* int s0_label_in_hyp; */
/* if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ */
if(!stack_is_empty(config_get_stack(c))){
b0 = word_buffer_b0(config_get_buffer(c));
b0_index = word_get_index(b0);
b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index));
b0_label = word_get_label(word_buffer_get_word_n(ref, b0_index));
/* word in front of the buffer is an end of sentence marker */
if(word_get_sent_seg(word_buffer_get_word_n(ref, b0_index)) == 1) return MVT_PARSER_EOS;
/* ignore punctuation */
if(b0_label == punct_label) return MVT_PARSER_IGNORE;
/* SHIFT */
if(stack_is_empty(config_get_stack(c))) return MVT_PARSER_SHIFT;
s0 = stack_top(config_get_stack(c));
s0_index = word_get_index(s0);
......@@ -63,18 +79,11 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label)
/* s0 is the root of the sentence */
if((s0_label == root_label)
// && (word_get_label(word_buffer_get_word_n(config_get_buffer(c), s0_index)) != root_label)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, punct_label)
){
return MVT_PARSER_ROOT;
}
/* word on the top of the stack is an end of sentence marker */
if((word_get_sent_seg(word_buffer_get_word_n(ref, s0_index)) == 1)
// && (word_get_sent_seg(word_buffer_get_word_n(config_get_buffer(c), s0_index)) != 1)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index)
){
return MVT_PARSER_EOS;
}
/* LEFT ARC b0 is the governor and s0 the dependent */
if(s0_gov_index == b0_index){
......@@ -87,13 +96,8 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label)
}
/* REDUCE */
if((stack_nbelem(config_get_stack(c)) > 1)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index) /* word on top must have all its dependents */
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index, punct_label) /* word on top must have all its dependents */
&& (word_get_gov(stack_top(config_get_stack(c))) != WORD_INVALID_GOV)){ /* word on top of the stack has a governor */
return MVT_PARSER_REDUCE;
}
}
/* SHIFT */
return MVT_PARSER_SHIFT;
}
......@@ -6,6 +6,6 @@
#include"word_buffer.h"
int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label);
int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label, int punct_label);
#endif
......@@ -216,6 +216,9 @@ void simple_decoder_parser_arc_eager(context *ctx)
case MVT_PARSER_EOS:
result = movement_parser_eos(c);
break;
case MVT_PARSER_IGNORE:
result = movement_parser_ignore(c);
break;
case MVT_PARSER_SHIFT:
result = movement_parser_shift(c);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment