Skip to content
Snippets Groups Projects
Commit 830bd4b3 authored by Alexis Nasr's avatar Alexis Nasr
Browse files

modified oracle functions

parent 841ad3e5
No related branches found
No related tags found
No related merge requests found
...@@ -96,6 +96,13 @@ typedef struct _word { ...@@ -96,6 +96,13 @@ typedef struct _word {
#define word_set_signature(w, val) ((w)->signature = (val)) #define word_set_signature(w, val) ((w)->signature = (val))
#define word_set_index(w, val) ((w)->index = (val)) #define word_set_index(w, val) ((w)->index = (val))
int word_sprint_lemma(char *s, word *w);
int word_sprint_form(char *s, word *w);
word *word_new(char *input); word *word_new(char *input);
word *word_create_dummy(mcd *mcd_struct); word *word_create_dummy(mcd *mcd_struct);
...@@ -110,6 +117,6 @@ word *word_parse_buffer(char *buffer, mcd *mcd_struct); ...@@ -110,6 +117,6 @@ word *word_parse_buffer(char *buffer, mcd *mcd_struct);
int word_is_eos(word *w, mcd *mcd_struct); int word_is_eos(word *w, mcd *mcd_struct);
int word_get_gov_index(word *w); int word_get_gov_index(word *w);
void word_print_col_n(FILE *f, word *w, int n); void word_print_col_n(FILE *f, word *w, int n);
void word_sprint_col_n(char *s, word *w, int n); int word_sprint_col_n(char *s, word *w, int n);
#endif #endif
...@@ -166,13 +166,15 @@ void word_print_col_n(FILE *f, word *w, int n) ...@@ -166,13 +166,15 @@ void word_print_col_n(FILE *f, word *w, int n)
} }
} }
void word_sprint_col_n(char *s, word *w, int n) int word_sprint_col_n(char *s, word *w, int n)
{ {
int i; int i;
int col = 0; int col = 0;
int j = 0; int j = 0;
s[0] = '\0';
if(n == -1) return 0;
char *buffer = w->input; char *buffer = w->input;
if(buffer == NULL) return; if(buffer == NULL) return 0;
int l= strlen(buffer); int l= strlen(buffer);
for(i=0; i < l; i++){ for(i=0; i < l; i++){
if(buffer[i] == '\t') { if(buffer[i] == '\t') {
...@@ -183,4 +185,5 @@ void word_sprint_col_n(char *s, word *w, int n) ...@@ -183,4 +185,5 @@ void word_sprint_col_n(char *s, word *w, int n)
s[j++] = buffer[i]; s[j++] = buffer[i];
} }
s[j] = '\0'; s[j] = '\0';
return 1;
} }
...@@ -3,6 +3,8 @@ set(SOURCES src/context.c ...@@ -3,6 +3,8 @@ set(SOURCES src/context.c
src/feat_fct.c src/feat_fct.c
src/feat_types.c src/feat_types.c
src/oracle_parser_arc_eager.c src/oracle_parser_arc_eager.c
src/oracle_tagparser_arc_eager.c
src/oracle_tagger.c
src/simple_decoder_parser_arc_eager.c src/simple_decoder_parser_arc_eager.c
src/simple_decoder_tagger.c src/simple_decoder_tagger.c
src/feat_lib.c src/feat_lib.c
...@@ -20,14 +22,13 @@ set(SOURCES src/context.c ...@@ -20,14 +22,13 @@ set(SOURCES src/context.c
src/classifier.c src/classifier.c
src/simple_decoder_tagparser_arc_eager.c src/simple_decoder_tagparser_arc_eager.c
# src/simple_decoder_parser.c
# src/oracle_parser.c
# src/global_feat_vec.c # src/global_feat_vec.c
# src/movement_parser_arc_eager.c # src/movement_parser_arc_eager.c
# src/movement_tagparser_arc_eager.c # src/movement_tagparser_arc_eager.c
# src/movement_tagger.c # src/movement_tagger.c
# src/oracle_parser.c
# src/oracle_tagparser_arc_eager.c
# src/oracle_tagger.c
# src/simple_decoder_parser.c
# src/simple_decoder_tagparser_arc_eager.c # src/simple_decoder_tagparser_arc_eager.c
# src/simple_decoder_forrest.c # src/simple_decoder_forrest.c
# src/simple_decoder_tagger_bt.c # src/simple_decoder_tagger_bt.c
......
...@@ -12,10 +12,10 @@ ...@@ -12,10 +12,10 @@
#include"word_emb.h" #include"word_emb.h"
#include"config2feat_vec.h" #include"config2feat_vec.h"
int oracle_tagger(config *c) /*int oracle_tagger(config *c)
{ {
return word_get_pos(word_buffer_b0(config_get_buffer(c))); return word_get_pos(word_buffer_b0(config_get_buffer(c)));
} }*/
#if 1 #if 1
void add_signature_to_words_in_word_buffer(word_buffer *bf, form2pos *f2p) void add_signature_to_words_in_word_buffer(word_buffer *bf, form2pos *f2p)
...@@ -96,7 +96,6 @@ void generate_training_file(FILE *output_file, context *ctx) ...@@ -96,7 +96,6 @@ void generate_training_file(FILE *output_file, context *ctx)
int mvt_code; int mvt_code;
feat_vec *fv = ctx->classif->fv; feat_vec *fv = ctx->classif->fv;
FILE *mcf_file = myfopen(ctx->input_filename, "r"); FILE *mcf_file = myfopen(ctx->input_filename, "r");
int postag;
int word_nb = 0; int word_nb = 0;
/* dico *dico_pos = dico_vec_get_dico(ctx->vocabs, (char *)"POS"); */ /* dico *dico_pos = dico_vec_get_dico(ctx->vocabs, (char *)"POS"); */
...@@ -109,11 +108,10 @@ void generate_training_file(FILE *output_file, context *ctx) ...@@ -109,11 +108,10 @@ void generate_training_file(FILE *output_file, context *ctx)
if((++word_nb % 1000) == 0) fprintf(stderr, "\rword %d", word_nb); if((++word_nb % 1000) == 0) fprintf(stderr, "\rword %d", word_nb);
postag = oracle_tagger(c);
config2feat_vec_cff(ctx->classif->fm, c, ctx->classif->d_features, fv, ctx->mode); config2feat_vec_cff(ctx->classif->fm, c, ctx->classif->d_features, fv, ctx->mode);
/* word_print(stdout, word_buffer_b0(config_get_buffer(c))); */ /* word_print(stdout, word_buffer_b0(config_get_buffer(c))); */
mvt_code = mvt_tagset_get_code(classifier_get_output_tagset(ctx->classif), MVT_POS, postag); mvt_code = oracle_tagger(c, classifier_get_output_tagset(ctx->classif));
fprintf(output_file, "%d", mvt_code); fprintf(output_file, "%d", mvt_code);
feat_vec_print(output_file, fv); feat_vec_print(output_file, fv);
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include<stdlib.h> #include<stdlib.h>
#include<string.h> #include<string.h>
#include"word_buffer.h" #include"word_buffer.h"
/* #include"movement_parser_arc_eager.h" */
#include"mvt_tagset.h" #include"mvt_tagset.h"
#include"config.h" #include"config.h"
#include"dico.h" #include"dico.h"
...@@ -69,7 +68,6 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label, mvt_tag ...@@ -69,7 +68,6 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label, mvt_tag
){ ){
return mvt_tagset_get_code(tagset, MVT_ROOT, 0); return mvt_tagset_get_code(tagset, MVT_ROOT, 0);
/* return MVT_PARSER_ROOT; */
} }
/* word on the top of the stack is an end of sentence marker */ /* word on the top of the stack is an end of sentence marker */
...@@ -77,31 +75,25 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label, mvt_tag ...@@ -77,31 +75,25 @@ int oracle_parser_arc_eager(config *c, word_buffer *ref, int root_label, mvt_tag
&& (word_get_sent_seg(word_buffer_get_word_n(config_get_buffer(c), s0_index)) != 1)){ && (word_get_sent_seg(word_buffer_get_word_n(config_get_buffer(c), s0_index)) != 1)){
return mvt_tagset_get_code(tagset, MVT_EOS, 0); return mvt_tagset_get_code(tagset, MVT_EOS, 0);
/* return MVT_PARSER_EOS; */
} }
/* LEFT ARC b0 is the governor and s0 the dependent */ /* LEFT ARC b0 is the governor and s0 the dependent */
if(s0_gov_index == b0_index){ if(s0_gov_index == b0_index){
return mvt_tagset_get_code(tagset, MVT_LEFT, word_get_label(word_buffer_get_word_n(ref, s0_index))); return mvt_tagset_get_code(tagset, MVT_LEFT, word_get_label(word_buffer_get_word_n(ref, s0_index)));
/* return movement_parser_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index))); */
} }
/* RIGHT ARC s0 is the governor and b0 the dependent */ /* RIGHT ARC s0 is the governor and b0 the dependent */
if(b0_gov_index == s0_index){ if(b0_gov_index == s0_index){
return mvt_tagset_get_code(tagset, MVT_RIGHT, word_get_label(word_buffer_get_word_n(ref, b0_index))); return mvt_tagset_get_code(tagset, MVT_RIGHT, word_get_label(word_buffer_get_word_n(ref, b0_index)));
/* return movement_parser_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index))); */
} }
/* REDUCE */ /* REDUCE */
if((stack_nbelem(config_get_stack(c)) > 1) if((stack_nbelem(config_get_stack(c)) > 1)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index) /* word on top must have all its dependents */ && check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index) /* word on top must have all its dependents */
&& (word_get_gov(stack_top(config_get_stack(c))) != WORD_INVALID_GOV)){ /* word on top of the stack has a governor */ && (word_get_gov(stack_top(config_get_stack(c))) != WORD_INVALID_GOV)){ /* word on top of the stack has a governor */
return mvt_tagset_get_code(tagset, MVT_REDUCE, 0); return mvt_tagset_get_code(tagset, MVT_REDUCE, 0);
/* return MVT_PARSER_REDUCE; */
} }
} }
/* SHIFT */ /* SHIFT */
return mvt_tagset_get_code(tagset, MVT_SHIFT, 0); return mvt_tagset_get_code(tagset, MVT_SHIFT, 0);
/* return MVT_PARSER_SHIFT; */
} }
#include"oracle_tagger.h" #include"oracle_tagger.h"
int oracle_tagger(config *c) int oracle_tagger(config *c, mvt_tagset *tagset)
{ {
return word_get_pos(word_buffer_b0(config_get_buffer(c))); int postag = word_get_pos(word_buffer_b0(config_get_buffer(c)));
int mvt_code = mvt_tagset_get_code(tagset, MVT_POS, postag);
return mvt_code;
} }
...@@ -5,6 +5,11 @@ ...@@ -5,6 +5,11 @@
#include<stdlib.h> #include<stdlib.h>
#include"config.h" #include"config.h"
int oracle_tagger(config *c);
#include"mvt_tagset.h"
int oracle_tagger(config *c, mvt_tagset *tagset);
#endif #endif
#include<stdio.h> #include<stdio.h>
#include<stdlib.h> #include<stdlib.h>
#include<string.h> #include<string.h>
#include"word_buffer.h" #include"config.h"
#include"movement_tagparser_arc_eager.h" #include"tm.h"
#include"oracle_tagger.h"
#include"oracle_parser_arc_eager.h"
int check_all_dependents_of_word_in_ref_are_in_hyp(config *c, word_buffer *ref, int word_index) int oracle_tagparser_arc_eager(config *c, word_buffer *ref, tm *machine, int root_label)
{ {
int dep; char *state_name = tm_get_state_name(machine, c->current_state_nb);
int gov_ref; tm_state *state = machine->state_array[c->current_state_nb];
int gov_hyp; classifier *classif = state->classif;
int sentence_change; mvt_tagset *output_tagset = classif->output_tagset;
#if 0 if(state_name == NULL) return -1;
for(dep = word_index - 1; (dep >= 0) && (word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 0); dep--){
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep)); if(!strcmp(state_name, "TAGGER"))
if(gov_ref == word_index){ /* dep is a dependent of word in ref */ return oracle_tagger(c, output_tagset);
/* check that dep has the same governor in hyp */ if(!strcmp(state_name, "PARSER"))
gov_hyp = word_get_gov_index(word_buffer_get_word_n(config_get_buffer(c), dep)); return oracle_parser_arc_eager(c, ref, root_label, output_tagset);
if(gov_hyp != gov_ref) return 0; return -1;
}
}
for(dep = word_index + 1; ((dep < word_buffer_get_nbelem(ref)) && (word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 0)); dep++){
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep));
if(gov_ref == word_index){ /* dep is a dependent of word in ref */
/* check that dep has the same governor in hyp */
gov_hyp = word_get_gov_index(word_buffer_get_word_n(config_get_buffer(c), dep));
if(gov_hyp != gov_ref) return 0;
}
}
#endif
#if 1
for(dep = word_index - 1; (dep >= 0) && (word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 0); dep--){
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep));
if(gov_ref == word_index){ /* dep is a dependent of word in ref */
/* check that dep has the same governor in hyp */
gov_hyp = word_get_gov_index(word_buffer_get_word_n(config_get_buffer(c), dep));
if(gov_hyp != gov_ref) return 0;
}
}
sentence_change = 0;
for(dep = word_index + 1; (dep < word_buffer_get_nbelem(ref)) && (sentence_change == 0); dep++){
if(word_get_sent_seg(word_buffer_get_word_n(ref, dep)) == 1)
sentence_change = 1;
gov_ref = word_get_gov_index(word_buffer_get_word_n(ref, dep));
if(gov_ref == word_index){ /* dep is a dependent of word in ref */
/* look for a dependency in hyp such that its dependent is dep */
gov_hyp = word_get_gov_index(word_buffer_get_word_n(config_get_buffer(c), dep));
if(gov_hyp != gov_ref) return 0;
}
}
#endif
return 1;
}
int oracle_tagparser_arc_eager(config *c, word_buffer *ref, int root_label)
{
word *s0; /* word on top of stack */
word *b0; /* next word in the bufer */
int s0_index, b0_index;
int s0_gov_index, b0_gov_index;
int s0_label;
/* int s0_label_in_hyp; */
b0 = word_buffer_b0(config_get_buffer(c));
b0_index = word_get_index(b0);
b0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, b0_index));
/* give a pos to b0 if it does not have one */
if(word_get_pos(b0) == -1){
/* word_set_pos(b0, word_get_pos(word_buffer_get_word_n(ref, b0_index))); */
/* return movement_tagparser_postag(word_get_pos(b0)); */
return movement_tagparser_postag_code(word_get_pos(word_buffer_get_word_n(ref, b0_index)));
}
/* if(!stack_is_empty(config_get_stack(c)) && !word_buffer_is_empty(config_get_buffer(c))){ */
if(!stack_is_empty(config_get_stack(c))){
s0 = stack_top(config_get_stack(c));
s0_index = word_get_index(s0);
s0_gov_index = word_get_gov_index(word_buffer_get_word_n(ref, s0_index));
s0_label = word_get_label(word_buffer_get_word_n(ref, s0_index));
/* s0_label_in_hyp = word_get_label(word_buffer_get_word_n(config_get_buffer(c), s0_index)); */
/* printf("s0_index = %d b0_index = %d\n", s0_index, b0_index);
printf("dans ref gov de s0 (%d) = %d\n", s0_index, s0_gov_index);
printf("dans ref gov de b0 (%d) = %d\n", b0_index, b0_gov_index);*/
/* s0 is the root of the sentence */
if((s0_label == root_label)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index)
){
return MVT_TAGPARSER_ROOT;
}
/* word on the top of the stack is an end of sentence marker */
if((word_get_sent_seg(word_buffer_get_word_n(ref, s0_index)) == 1)
&& (word_get_sent_seg(word_buffer_get_word_n(config_get_buffer(c), s0_index)) != 1)){
return MVT_TAGPARSER_EOS;
}
/* LEFT ARC b0 is the governor and s0 the dependent */
if(s0_gov_index == b0_index){
return movement_tagparser_left_code(word_get_label(word_buffer_get_word_n(ref, s0_index)));
}
/* RIGHT ARC s0 is the governor and b0 the dependent */
if(b0_gov_index == s0_index){
return movement_tagparser_right_code(word_get_label(word_buffer_get_word_n(ref, b0_index)));
}
/* REDUCE */
if((stack_nbelem(config_get_stack(c)) > 1)
&& check_all_dependents_of_word_in_ref_are_in_hyp(c, ref, s0_index) /* word on top must have all its dependents */
&& (word_get_gov(stack_top(config_get_stack(c))) != WORD_INVALID_GOV)){ /* word on top of the stack has a governor */
return MVT_TAGPARSER_REDUCE;
}
}
/* SHIFT */
return MVT_TAGPARSER_SHIFT;
} }
...@@ -60,6 +60,11 @@ tm_state *tm_get_state(tm *machine, char *state_name) ...@@ -60,6 +60,11 @@ tm_state *tm_get_state(tm *machine, char *state_name)
return machine->state_array[state_number]; return machine->state_array[state_number];
} }
char *tm_get_state_name(tm *machine, int state_nb)
{
return dico_int2string(machine->d_states, state_nb);
}
void tm_add_transition(tm *machine, char *origin_state_name, char *destination_state_name, int mvt_code) void tm_add_transition(tm *machine, char *origin_state_name, char *destination_state_name, int mvt_code)
{ {
......
...@@ -41,6 +41,8 @@ tm *tm_load(char *filename, char *absolute_path, int verbose); ...@@ -41,6 +41,8 @@ tm *tm_load(char *filename, char *absolute_path, int verbose);
void tm_print(FILE *f, tm *machine); void tm_print(FILE *f, tm *machine);
void tm_link_to_classifier_vector(tm *machine, classifier_vec *classif_vec); void tm_link_to_classifier_vector(tm *machine, classifier_vec *classif_vec);
int tm_delta(tm *machine, int state_nb, int symbol); int tm_delta(tm *machine, int state_nb, int symbol);
char *tm_get_state_name(tm *machine, int state_nb);
/* tm_state *tm_delta(tm *machine, tm_state *state, int symbol); */ /* tm_state *tm_delta(tm *machine, tm_state *state, int symbol); */
#endif #endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment