Skip to content
Snippets Groups Projects
Commit 56e83bcf authored by Alexis Nasr
Browse files

Added a new tagparse decoder that tags and parses the sentence at the same time.

parent 1cddee84
No related branches found
No related tags found
No related merge requests found
set(SOURCES src/context.c set(SOURCES src/context.c
src/feat_desc.c src/feat_desc.c
# src/movement_parser_arc_eager.c src/movement_parser_arc_eager.c
src/movement_tagparser_arc_eager.c src/movement_tagparser_arc_eager.c
src/movement_tagger.c src/movement_tagger.c
src/feat_fct.c src/feat_fct.c
src/global_feat_vec.c src/global_feat_vec.c
# src/oracle_parser.c # src/oracle_parser.c
# src/oracle_parser_arc_eager.c src/oracle_parser_arc_eager.c
src/oracle_tagparser_arc_eager.c src/oracle_tagparser_arc_eager.c
src/oracle_tagger.c src/oracle_tagger.c
# src/simple_decoder_parser.c src/simple_decoder_parser.c
src/simple_decoder_parser_arc_eager.c src/simple_decoder_parser_arc_eager.c
src/simple_decoder_tagparser_arc_eager.c src/simple_decoder_tagparser_arc_eager.c
src/simple_decoder_forrest.c src/simple_decoder_forrest.c
......
...@@ -78,12 +78,12 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -78,12 +78,12 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode); config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, ctx->mode);
mvt_code = oracle_tagparser_arc_eager(c, ref, root_label); mvt_code = oracle_tagparser_arc_eager(c, ref, root_label);
mvt_type = movement_type(mvt_code); mvt_type = movement_tagparse_type(mvt_code);
mvt_label = movement_label(mvt_code); mvt_label = movement_tagparse_label(mvt_code);
if(ctx->debug_mode){ if(ctx->debug_mode){
config_print(stdout,c); config_print(stdout,c);
movement_print(stdout, mvt_code, ctx->dico_labels, dico_postag); movement_tagparse_print(stdout, mvt_code, ctx->dico_labels, dico_postag);
fprintf(stdout, "\n"); fprintf(stdout, "\n");
} }
...@@ -92,7 +92,7 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -92,7 +92,7 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
stack_print(output_file, c->st); stack_print(output_file, c->st);
fprintf(output_file, "\t"); fprintf(output_file, "\t");
movement_print(output_file, mvt_code, ctx->dico_labels, dico_postag); movement_tagparse_print(output_file, mvt_code, ctx->dico_labels, dico_postag);
fprintf(output_file, "\t1\n"); fprintf(output_file, "\t1\n");
} }
else{ else{
...@@ -101,40 +101,40 @@ void generate_training_file_stream(FILE *output_file, context *ctx) ...@@ -101,40 +101,40 @@ void generate_training_file_stream(FILE *output_file, context *ctx)
} }
if(mvt_type == MVT_EOS){ if(mvt_type == MVT_EOS){
movement_eos(c, 0); movement_tagparse_eos(c, 0);
sentence_nb++; sentence_nb++;
if(word_buffer_is_last(ref)) if(word_buffer_is_last(ref))
break; break;
} }
if(mvt_type == MVT_POSTAG){ if(mvt_type == MVT_POSTAG){
movement_add_pos(c, 0, mvt_label); movement_tagparse_add_pos(c, 0, mvt_label);
continue; continue;
} }
if(mvt_type == MVT_LEFT){ if(mvt_type == MVT_LEFT){
movement_left_arc(c, mvt_label, 0); movement_tagparse_left_arc(c, mvt_label, 0);
continue; continue;
} }
if(mvt_type == MVT_RIGHT){ if(mvt_type == MVT_RIGHT){
movement_right_arc(c, mvt_label, 0); movement_tagparse_right_arc(c, mvt_label, 0);
word_buffer_move_right(ref); word_buffer_move_right(ref);
continue; continue;
} }
if(mvt_type == MVT_REDUCE){ if(mvt_type == MVT_REDUCE){
movement_reduce(c, 0); movement_tagparse_reduce(c, 0);
continue; continue;
} }
if(mvt_type == MVT_ROOT){ if(mvt_type == MVT_ROOT){
movement_root(c, 0, root_label); movement_tagparse_root(c, 0, root_label);
continue; continue;
} }
if(mvt_type == MVT_SHIFT){ if(mvt_type == MVT_SHIFT){
movement_shift(c, 1, 0); movement_tagparse_shift(c, 1, 0);
word_buffer_move_right(ref); word_buffer_move_right(ref);
continue; continue;
} }
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
#include"util.h" #include"util.h"
#include"movement_tagparser_arc_eager.h" #include"movement_tagparser_arc_eager.h"
void movement_print(FILE *f, int mvt_code, dico *dico_labels, dico *dico_postag) void movement_tagparse_print(FILE *f, int mvt_code, dico *dico_labels, dico *dico_postag)
{ {
int mvt_type = movement_type(mvt_code); int mvt_type = movement_tagparse_type(mvt_code);
int mvt_label = movement_label(mvt_code); int mvt_label = movement_tagparse_label(mvt_code);
char *label; char *label;
if(mvt_type == MVT_SHIFT) {fprintf(f, "SHIFT"); return;} if(mvt_type == MVT_SHIFT) {fprintf(f, "SHIFT"); return;}
...@@ -27,7 +27,7 @@ void movement_print(FILE *f, int mvt_code, dico *dico_labels, dico *dico_postag) ...@@ -27,7 +27,7 @@ void movement_print(FILE *f, int mvt_code, dico *dico_labels, dico *dico_postag)
fprintf(f, " %s", label); fprintf(f, " %s", label);
} }
int movement_type(int mvt) int movement_tagparse_type(int mvt)
{ {
if(mvt == MVT_SHIFT) return MVT_SHIFT; /* 0 */ if(mvt == MVT_SHIFT) return MVT_SHIFT; /* 0 */
if(mvt == MVT_REDUCE) return MVT_REDUCE; /* 1 */ if(mvt == MVT_REDUCE) return MVT_REDUCE; /* 1 */
...@@ -38,7 +38,7 @@ int movement_type(int mvt) ...@@ -38,7 +38,7 @@ int movement_type(int mvt)
/*if(mvt % 3 == 2)*/ return MVT_LEFT; /* 6, 9, 12 ... */ /*if(mvt % 3 == 2)*/ return MVT_LEFT; /* 6, 9, 12 ... */
} }
int movement_label(int mvt) int movement_tagparse_label(int mvt)
{ {
if(mvt == MVT_SHIFT) return -1; if(mvt == MVT_SHIFT) return -1;
if(mvt == MVT_REDUCE) return -1; if(mvt == MVT_REDUCE) return -1;
...@@ -52,7 +52,7 @@ int movement_label(int mvt) ...@@ -52,7 +52,7 @@ int movement_label(int mvt)
return (mvt - 6) / 3; return (mvt - 6) / 3;
} }
int movement_add_pos(config *c, float score, int pos) int movement_tagparse_add_pos(config *c, float score, int pos)
{ {
if(word_buffer_b0(config_get_buffer(c)) == NULL) return 0; if(word_buffer_b0(config_get_buffer(c)) == NULL) return 0;
if(word_get_pos(word_buffer_b0(config_get_buffer(c))) != -1) return 0; if(word_get_pos(word_buffer_b0(config_get_buffer(c))) != -1) return 0;
...@@ -60,12 +60,12 @@ int movement_add_pos(config *c, float score, int pos) ...@@ -60,12 +60,12 @@ int movement_add_pos(config *c, float score, int pos)
/* stack_push(config_get_stack(c), word_buffer_b0(config_get_buffer(c))); /* stack_push(config_get_stack(c), word_buffer_b0(config_get_buffer(c)));
word_buffer_move_right(config_get_buffer(c));*/ word_buffer_move_right(config_get_buffer(c));*/
config_add_mvt(c, movement_postag(pos)); config_add_mvt(c, movement_tagparse_postag(pos));
return 1; return 1;
} }
int movement_eos(config *c, float score) int movement_tagparse_eos(config *c, float score)
{ {
if(stack_is_empty(config_get_stack(c))) return 0; if(stack_is_empty(config_get_stack(c))) return 0;
if(word_get_sent_seg(stack_top(config_get_stack(c))) == 1) return 0; if(word_get_sent_seg(stack_top(config_get_stack(c))) == 1) return 0;
...@@ -80,7 +80,7 @@ int movement_eos(config *c, float score) ...@@ -80,7 +80,7 @@ int movement_eos(config *c, float score)
return 1; return 1;
} }
int movement_left_arc(config *c, int label, float score) int movement_tagparse_left_arc(config *c, int label, float score)
{ {
if(stack_is_empty(config_get_stack(c))) return 0; if(stack_is_empty(config_get_stack(c))) return 0;
/* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */ /* if(word_buffer_is_empty(config_get_buffer(c))) return 0; */
...@@ -97,11 +97,11 @@ int movement_left_arc(config *c, int label, float score) ...@@ -97,11 +97,11 @@ int movement_left_arc(config *c, int label, float score)
word_set_label(dep, label); word_set_label(dep, label);
stack_pop(config_get_stack(c)); stack_pop(config_get_stack(c));
config_add_mvt(c, movement_left_code(label)); config_add_mvt(c, movement_tagparse_left_code(label));
return 1; return 1;
} }
int movement_right_arc(config *c, int label, float score) int movement_tagparse_right_arc(config *c, int label, float score)
{ {
if(stack_is_empty(config_get_stack(c))) return 0; if(stack_is_empty(config_get_stack(c))) return 0;
...@@ -116,11 +116,11 @@ int movement_right_arc(config *c, int label, float score) ...@@ -116,11 +116,11 @@ int movement_right_arc(config *c, int label, float score)
stack_push(config_get_stack(c), word_buffer_b0(config_get_buffer(c))); stack_push(config_get_stack(c), word_buffer_b0(config_get_buffer(c)));
word_buffer_move_right(config_get_buffer(c)); word_buffer_move_right(config_get_buffer(c));
config_add_mvt(c, movement_right_code(label)); config_add_mvt(c, movement_tagparse_right_code(label));
return 1; return 1;
} }
int movement_shift(config *c, int stream, float score) int movement_tagparse_shift(config *c, int stream, float score)
{ {
if(word_buffer_is_empty(config_get_buffer(c))) return 0; if(word_buffer_is_empty(config_get_buffer(c))) return 0;
stack_push(config_get_stack(c), word_buffer_b0(config_get_buffer(c))); stack_push(config_get_stack(c), word_buffer_b0(config_get_buffer(c)));
...@@ -129,7 +129,7 @@ int movement_shift(config *c, int stream, float score) ...@@ -129,7 +129,7 @@ int movement_shift(config *c, int stream, float score)
return 1; return 1;
} }
int movement_reduce(config *c, float score) int movement_tagparse_reduce(config *c, float score)
{ {
if(stack_nbelem(config_get_stack(c)) <= 1) return 0; if(stack_nbelem(config_get_stack(c)) <= 1) return 0;
...@@ -142,7 +142,7 @@ int movement_reduce(config *c, float score) ...@@ -142,7 +142,7 @@ int movement_reduce(config *c, float score)
return 1; return 1;
} }
int movement_root(config *c, float score, int root_code) int movement_tagparse_root(config *c, float score, int root_code)
{ {
word *s0 = stack_top(config_get_stack(c)); word *s0 = stack_top(config_get_stack(c));
if(s0 == NULL) return 0; if(s0 == NULL) return 0;
......
...@@ -13,23 +13,23 @@ ...@@ -13,23 +13,23 @@
#define MVT_RIGHT 5 #define MVT_RIGHT 5
#define MVT_POSTAG 6 #define MVT_POSTAG 6
#define movement_postag(postag) (3 * (postag) + 4) #define movement_tagparse_postag(postag) (3 * (postag) + 4)
/* left-arc codes are 3 * label + 5, i.e. 5, 8, 11, ... (congruent to 2 modulo 3; code 2 itself is root) */ /* left-arc codes are 3 * label + 5, i.e. 5, 8, 11, ... (congruent to 2 modulo 3; code 2 itself is root) */
#define movement_left_code(label) (3 * (label) + 5) #define movement_tagparse_left_code(label) (3 * (label) + 5)
/* right-arc codes are 3 * label + 6, i.e. 6, 9, 12, ... (multiples of 3; 0 is shift and 3 is end_of_sentence) */ /* right-arc codes are 3 * label + 6, i.e. 6, 9, 12, ... (multiples of 3; 0 is shift and 3 is end_of_sentence) */
#define movement_right_code(label) (3 * (label) + 6) #define movement_tagparse_right_code(label) (3 * (label) + 6)
int movement_type(int mvt); int movement_tagparse_type(int mvt);
int movement_label(int mvt); int movement_tagparse_label(int mvt);
int movement_left_arc(config *c, int label, float score); int movement_tagparse_left_arc(config *c, int label, float score);
int movement_right_arc(config *c, int label, float score); int movement_tagparse_right_arc(config *c, int label, float score);
int movement_shift(config *c, int stream, float score); int movement_tagparse_shift(config *c, int stream, float score);
int movement_reduce(config *c, float score); int movement_tagparse_reduce(config *c, float score);
int movement_root(config *c, float score, int root_code); int movement_tagparse_root(config *c, float score, int root_code);
int movement_eos(config *c, float score); int movement_tagparse_eos(config *c, float score);
int movement_add_pos(config *c, float score, int postag); int movement_tagparse_add_pos(config *c, float score, int postag);
void movement_print(FILE *f, int mvt_code, dico *dico_labels, dico *dico_postag); void movement_tagparse_print(FILE *f, int mvt_code, dico *dico_labels, dico *dico_postag);
#endif #endif
...@@ -75,9 +75,9 @@ int oracle_tagparser_arc_eager(config *c, word_buffer *ref, int root_label) ...@@ -75,9 +75,9 @@ int oracle_tagparser_arc_eager(config *c, word_buffer *ref, int root_label)
/* give a pos to b0 if it does not have one */ /* give a pos to b0 if it does not have one */
if(word_get_pos(b0) == -1){ if(word_get_pos(b0) == -1){
/* word_set_pos(b0, word_get_pos(word_buffer_get_word_n(ref, b0_index))); */ /* word_set_pos(b0, word_get_pos(word_buffer_get_word_n(ref, b0_index))); */
/* return movement_postag(word_get_pos(b0)); */ /* return movement_tagparse_postag(word_get_pos(b0)); */
return movement_postag(word_get_pos(word_buffer_get_word_n(ref, b0_index))); return movement_tagparse_postag(word_get_pos(word_buffer_get_word_n(ref, b0_index)));
} }
...@@ -113,12 +113,12 @@ int oracle_tagparser_arc_eager(config *c, word_buffer *ref, int root_label) ...@@ -113,12 +113,12 @@ int oracle_tagparser_arc_eager(config *c, word_buffer *ref, int root_label)
/* LEFT ARC b0 is the governor and s0 the dependent */ /* LEFT ARC b0 is the governor and s0 the dependent */
if(s0_gov_index == b0_index){ if(s0_gov_index == b0_index){
return movement_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index))); return movement_tagparse_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index)));
} }
/* RIGHT ARC s0 is the governor and b0 the dependent */ /* RIGHT ARC s0 is the governor and b0 the dependent */
if(b0_gov_index == s0_index){ if(b0_gov_index == s0_index){
return movement_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index))); return movement_tagparse_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index)));
} }
/* REDUCE */ /* REDUCE */
if((stack_nbelem(config_get_stack(c)) > 1) if((stack_nbelem(config_get_stack(c)) > 1)
......
...@@ -82,8 +82,8 @@ void simple_decoder_tagparser_arc_eager(context *ctx) ...@@ -82,8 +82,8 @@ void simple_decoder_tagparser_arc_eager(context *ctx)
config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE); config2feat_vec_cff(ctx->features_model, c, ctx->d_perceptron_features, fv, LOOKUP_MODE);
mvt_code = feature_table_argmax(fv, ft, &max); mvt_code = feature_table_argmax(fv, ft, &max);
mvt_type = movement_type(mvt_code); mvt_type = movement_tagparse_type(mvt_code);
mvt_label = movement_label(mvt_code); mvt_label = movement_tagparse_label(mvt_code);
if(ctx->trace_mode){ if(ctx->trace_mode){
index = word_get_index(word_buffer_b0(config_get_buffer(c))); index = word_get_index(word_buffer_b0(config_get_buffer(c)));
...@@ -92,7 +92,7 @@ void simple_decoder_tagparser_arc_eager(context *ctx) ...@@ -92,7 +92,7 @@ void simple_decoder_tagparser_arc_eager(context *ctx)
stack_print(stdout, c->st); stack_print(stdout, c->st);
fprintf(stdout, "\t"); fprintf(stdout, "\t");
movement_print(stdout, mvt_code, ctx->dico_labels, ctx->dico_postags); movement_tagparse_print(stdout, mvt_code, ctx->dico_labels, ctx->dico_postags);
fprintf(stdout, "\t"); fprintf(stdout, "\t");
feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2); feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
printf("%f\n", max1 - max2); printf("%f\n", max1 - max2);
...@@ -105,9 +105,9 @@ void simple_decoder_tagparser_arc_eager(context *ctx) ...@@ -105,9 +105,9 @@ void simple_decoder_tagparser_arc_eager(context *ctx)
entropy = feature_table_entropy(fv, ft); entropy = feature_table_entropy(fv, ft);
/* delta = feature_table_diff_scores(fv, ft); */ /* delta = feature_table_diff_scores(fv, ft); */
feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2); feature_table_argmax_1_2(fv, ft, &argmax1, &max1, &argmax2, &max2);
movement_print(stdout, argmax1, ctx->dico_labels, ctx->dico_postags); movement_tagparse_print(stdout, argmax1, ctx->dico_labels, ctx->dico_postags);
printf(":\t%f\n", max1); printf(":\t%f\n", max1);
movement_print(stdout, argmax2, ctx->dico_labels, ctx->dico_postags); movement_tagparse_print(stdout, argmax2, ctx->dico_labels, ctx->dico_postags);
printf(":\t%f\n", max2); printf(":\t%f\n", max2);
printf("delta = %f\n", max1 - max2); printf("delta = %f\n", max1 - max2);
...@@ -115,37 +115,37 @@ void simple_decoder_tagparser_arc_eager(context *ctx) ...@@ -115,37 +115,37 @@ void simple_decoder_tagparser_arc_eager(context *ctx)
/* printf("entropy = %f delta = %f\n", entropy, delta); */ /* printf("entropy = %f delta = %f\n", entropy, delta); */
printf("entropy = %f\n",entropy); printf("entropy = %f\n",entropy);
/* movement_print(stdout, mvt_code, ctx->dico_labels); */ /* movement_tagparse_print(stdout, mvt_code, ctx->dico_labels); */
} }
result = 0; result = 0;
switch(mvt_type){ switch(mvt_type){
case MVT_POSTAG : case MVT_POSTAG :
result = movement_add_pos(c, max, mvt_label); result = movement_tagparse_add_pos(c, max, mvt_label);
break; break;
case MVT_LEFT : case MVT_LEFT :
result = movement_left_arc(c, mvt_label, max); result = movement_tagparse_left_arc(c, mvt_label, max);
break; break;
case MVT_RIGHT: case MVT_RIGHT:
result = movement_right_arc(c, mvt_label, max); result = movement_tagparse_right_arc(c, mvt_label, max);
break; break;
case MVT_REDUCE: case MVT_REDUCE:
result = movement_reduce(c, max); result = movement_tagparse_reduce(c, max);
break; break;
case MVT_ROOT: case MVT_ROOT:
result = movement_root(c, max, root_label); result = movement_tagparse_root(c, max, root_label);
break; break;
case MVT_EOS: case MVT_EOS:
result = movement_eos(c, max); result = movement_tagparse_eos(c, max);
break; break;
case MVT_SHIFT: case MVT_SHIFT:
result = movement_shift(c, 1, max); result = movement_tagparse_shift(c, 1, max);
} }
if(result == 0){ if(result == 0){
if(ctx->debug_mode){ if(ctx->debug_mode){
fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n"); fprintf(stdout, "WARNING : movement cannot be executed doing a SHIFT instead !\n");
} }
movement_shift(c, 1, max); movement_tagparse_shift(c, 1, max);
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment