Skip to content
Snippets Groups Projects
Commit 2b0d8d6c authored by Franck Dary's avatar Franck Dary
Browse files

Added more info to debug mode in maca_tm_decoder. We are now able to see...

Added more info to debug mode in maca_tm_decoder. We are now able to see movement choices and scores.
parent 9e78bd47
Branches
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ classifier *classifier_new(char *name) ...@@ -25,6 +25,7 @@ classifier *classifier_new(char *name)
classif->fplm_filename = NULL; classif->fplm_filename = NULL;
classif->d_features_filename = NULL; classif->d_features_filename = NULL;
classif->d_tapes_filename = NULL; classif->d_tapes_filename = NULL;
classif->last_prediction = -1;
/* compute default filenames */ /* compute default filenames */
strcpy(string, name); strcpy(string, name);
...@@ -81,29 +82,35 @@ int classifier_argmax(classifier *classif, config *c, word_emb *emb, mcd *m) ...@@ -81,29 +82,35 @@ int classifier_argmax(classifier *classif, config *c, word_emb *emb, mcd *m)
{ {
if(classif->type == classifier::Type::Classifier){ if(classif->type == classifier::Type::Classifier){
if(classif->mlp) if(classif->mlp)
return classifier_argmax_nn(classif, c, emb, m); classif->last_prediction = classifier_argmax_nn(classif, c, emb, m);
else else
return classifier_argmax_perceptron(classif, c); classif->last_prediction = classifier_argmax_perceptron(classif, c);
} }
else if(classif->type == classifier::Type::Lookup){ else if(classif->type == classifier::Type::Lookup){
if(!strcmp("LEMMATIZER_LOOKUP", classifier_get_oracle_name(classif))){ if(!strcmp("LEMMATIZER_LOOKUP", classifier_get_oracle_name(classif))){
dico *d_form = dico_vec_get_dico(classif->d_tapes, (char*)"FORM"); dico *d_form = dico_vec_get_dico(classif->d_tapes, (char*)"FORM");
dico *d_pos = dico_vec_get_dico(classif->d_tapes, (char*)"POS"); dico *d_pos = dico_vec_get_dico(classif->d_tapes, (char*)"POS");
dico *d_lemma = dico_vec_get_dico(classif->d_tapes, (char*)"LEMMA"); dico *d_lemma = dico_vec_get_dico(classif->d_tapes, (char*)"LEMMA");
return oracle_lemmatizer_lookup(c, classifier_get_output_tagset(classif), d_form, d_lemma, d_pos, classif->fplm); classif->last_prediction = oracle_lemmatizer_lookup(c, classifier_get_output_tagset(classif), d_form, d_lemma, d_pos, classif->fplm);
} }
else{
fprintf(stderr, "do not know which oracle to use for classifier %s, oracle_name = %s\n", classif->name, classifier_get_oracle_name(classif)); fprintf(stderr, "do not know which oracle to use for classifier %s, oracle_name = %s\n", classif->name, classifier_get_oracle_name(classif));
exit(1); exit(1);
} }
}
else if(classif->type == classifier::Type::Forced){ else if(classif->type == classifier::Type::Forced){
return 0; classif->last_prediction = 0;
} }
else{
fprintf(stderr, "ERROR %s : wrong classifier type '%s' for classifier '%s'\n", fprintf(stderr, "ERROR %s : wrong classifier type '%s' for classifier '%s'\n",
__func__, type2string(classif->type), classif->name); __func__, type2string(classif->type), classif->name);
exit(1); exit(1);
} }
return classif->last_prediction;
}
int classifier_argmax_nn(classifier *classif, config *c, word_emb *emb, mcd *m) int classifier_argmax_nn(classifier *classif, config *c, word_emb *emb, mcd *m)
{ {
config2feat_vec_fann(classif->fm, c, classif->d_features, classif->fv, LOOKUP_MODE); config2feat_vec_fann(classif->fm, c, classif->d_features, classif->fv, LOOKUP_MODE);
...@@ -125,19 +132,51 @@ int classifier_argmax_perceptron(classifier *classif, config *c) ...@@ -125,19 +132,51 @@ int classifier_argmax_perceptron(classifier *classif, config *c)
/* Build an array of (class_code, score) entries describing the current
 * decision of `classif`.
 *
 * - Classifier backed by an MLP: copies the entries exposed by
 *   Mlp::get_vcode_array() into a freshly allocated array of `vcode_size`
 *   entries; unused slots hold class_code = -1 / score = -1.0.
 * - Classifier without an MLP: extracts the features of configuration `c`
 *   and returns whatever feature_table_get_vcode_array() produces.
 * - Lookup / Forced classifiers: a `vcode_size` array whose first entry is
 *   `classif->last_prediction` with score 1.0, the rest being -1 sentinels.
 *
 * Returns NULL for any other classifier type. The caller releases the result
 * with free() (see classifier_print_vcode_array).
 */
vcode *classifier_vcode_array(classifier *classif, config *c)
{
  /* We only need a handful of entries, but this size must stay greater than
     the 'n' argument of the classifier_print_vcode_array function. */
  const int vcode_size = 10;

  if(classif->type == classifier::Type::Classifier){
    if(classif->mlp){
      vcode *res = (vcode*)memalloc(vcode_size * sizeof *res);

      /* Start from sentinels so that slots we do not overwrite are skipped
         by the printer instead of containing garbage. */
      for(int i = 0; i < vcode_size; i++){
        res[i].class_code = -1;
        res[i].score = -1.0;
      }

      /* NOTE(review): this cast assumes the project-wide vcode has the same
         layout as Mlp's private {int class_code; float score;} - confirm. */
      vcode *model = (vcode*)classif->mlp->get_vcode_array();

      /* Never copy more entries than there are output classes: the source
         buffer is not guaranteed to hold vcode_size elements, and copying a
         fixed vcode_size entries reads past its end for small tagsets.
         NOTE(review): presumably the MLP has one output per element of the
         output tagset - confirm against the network construction. */
      int nb_classes = classifier_get_output_tagset(classif)->nbelem;
      int nb_copied = (nb_classes < vcode_size) ? nb_classes : vcode_size;

      if(model && nb_copied > 0)
        memcpy(res, model, nb_copied * sizeof *res);

      return res;
    }

    config2feat_vec_cff(classif->fm, c, classif->d_features, classif->fv, LOOKUP_MODE);
    return feature_table_get_vcode_array(classifier_get_feat_vec(classif), classifier_get_feature_table(classif));
  }
  else if(classif->type == classifier::Type::Lookup || classif->type == classifier::Type::Forced){
    /* These classifiers have no scores: report the last prediction with an
       arbitrary score of 1.0 and pad with sentinels. */
    vcode *res = (vcode*)memalloc(vcode_size * sizeof *res);

    res[0].class_code = classif->last_prediction;
    res[0].score = 1.0;

    for(int i = 1; i < vcode_size; i++){
      res[i].class_code = -1;
      res[i].score = -1.0;
    }

    return res;
  }

  return NULL;
}
void classifier_print_vcode_array(FILE *f, classifier *classif, config *c, int n) void classifier_print_vcode_array(FILE *f, classifier *classif, config *c, int n)
{ {
vcode *vcode_array = classifier_vcode_array(classif, c); vcode *vcode_array = classifier_vcode_array(classif, c);
if(!vcode_array)
return;
mvt_tagset *tagset = classifier_get_output_tagset(classif); mvt_tagset *tagset = classifier_get_output_tagset(classif);
for(int i=0; (i < n) && (i < tagset->nbelem); i++){ for(int i=0; (i < n) && (i < tagset->nbelem); i++){
if(vcode_array[i].class_code < 0)
continue;
fprintf(f, "%d\t", i); fprintf(f, "%d\t", i);
mvt_tagset_print_mvt(f, tagset, vcode_array[i].class_code); mvt_tagset_print_mvt(f, tagset, vcode_array[i].class_code);
printf("\t \t%.4f\n", vcode_array[i].score); fprintf(f, "\t \t%.4f\n", vcode_array[i].score);
} }
free(vcode_array); free(vcode_array);
} }
......
...@@ -39,6 +39,7 @@ struct classifier{ ...@@ -39,6 +39,7 @@ struct classifier{
char *mlp_model_filename; /* name of the file that stores the mlp weights*/ char *mlp_model_filename; /* name of the file that stores the mlp weights*/
char *mlp_struct_filename; /* name of the file that stores the mlp structure */ char *mlp_struct_filename; /* name of the file that stores the mlp structure */
char *oracle_name; /* what oracle to use when training (TAGGER,MORPHO,PARSER) */ char *oracle_name; /* what oracle to use when training (TAGGER,MORPHO,PARSER) */
int last_prediction; /* the last class we predicted, usefull for printing vcode_array of Lookup classifiers */
}; };
......
...@@ -257,17 +257,17 @@ void maca_tm_decoder(context *ctx) ...@@ -257,17 +257,17 @@ void maca_tm_decoder(context *ctx)
// fprintf(stderr, "mvt_code = %d\n", mvt_code); // fprintf(stderr, "mvt_code = %d\n", mvt_code);
if(ctx->debug_mode){ if(ctx->debug_mode){
fprintf(stdout, "***********************************\n"); fprintf(stderr, "***********************************\n");
fprintf(stdout, "%s ", ctx->machine->state_array[c->current_state_nb]->name); fprintf(stderr, "%s ", ctx->machine->state_array[c->current_state_nb]->name);
config_print(stdout, c); config_print(stderr, c);
//classifier_print_vcode_array(stdout, classif, c, 4); classifier_print_vcode_array(stderr, classif, c, 4);
} }
/* If EOS has been predicted (sentence has been treated), we empty the stack so that the failure /* If EOS has been predicted (sentence has been treated), we empty the stack so that the failure
to attach every word of the previous sentence does not affect the analysis of the new sentence */ to attach every word of the previous sentence does not affect the analysis of the new sentence */
if(word_is_eos(word_buffer_bm1(config_get_buffer(c)), ctx->mcd_struct) == 1){ if(word_is_eos(word_buffer_bm1(config_get_buffer(c)), ctx->mcd_struct) == 1){
if((state_is_parser[c->current_state_nb]) && !stack_is_empty(config_get_stack(c))){ if((state_is_parser[c->current_state_nb]) && !stack_is_empty(config_get_stack(c))){
if(ctx->debug_mode) fprintf(stdout, "WARNING : emptying stack, because it's a new sentence !\n"); if(ctx->debug_mode) fprintf(stderr, "WARNING : emptying stack, because it's a new sentence !\n");
while(!stack_is_empty(config_get_stack(c))){ while(!stack_is_empty(config_get_stack(c))){
mvt_code = mvt_tagset_get_code(classifier_get_output_tagset(classif), MVT_ROOT, 0); mvt_code = mvt_tagset_get_code(classifier_get_output_tagset(classif), MVT_ROOT, 0);
......
...@@ -60,6 +60,9 @@ class Mlp{ ...@@ -60,6 +60,9 @@ class Mlp{
static dynet::ComputationGraph computation_graph; static dynet::ComputationGraph computation_graph;
using vcode = struct{int class_code; float score;};
std::vector<vcode> vcode_array;
public : public :
Mlp(std::vector<Layer> layers, unsigned int batch_size = 0); Mlp(std::vector<Layer> layers, unsigned int batch_size = 0);
...@@ -71,6 +74,7 @@ class Mlp{ ...@@ -71,6 +74,7 @@ class Mlp{
std::vector< std::vector<float> > & x_dev, std::vector<int> & y_dev); std::vector< std::vector<float> > & x_dev, std::vector<int> & y_dev);
void save(); void save();
void set_filenames(char * model_filename, char * struct_filename); void set_filenames(char * model_filename, char * struct_filename);
void *get_vcode_array();
private : private :
...@@ -93,7 +97,6 @@ class Mlp{ ...@@ -93,7 +97,6 @@ class Mlp{
std::vector<int> & y_dev); std::vector<int> & y_dev);
double get_score_on_set(std::vector< std::vector<float> > & x_dev, double get_score_on_set(std::vector< std::vector<float> > & x_dev,
std::vector<int> & y_dev); std::vector<int> & y_dev);
}; };
#endif #endif
...@@ -145,12 +145,20 @@ unsigned int Mlp::predict(dynet::Expression x){ ...@@ -145,12 +145,20 @@ unsigned int Mlp::predict(dynet::Expression x){
std::vector<float> probas = as_vector(computation_graph.forward(y)); std::vector<float> probas = as_vector(computation_graph.forward(y));
vcode_array.resize(probas.size());
unsigned int argmax = 0; unsigned int argmax = 0;
for(unsigned int i = 0; i < probas.size(); i++){ for(unsigned int i = 0; i < probas.size(); i++){
vcode_array[i] = {(int)i, probas[i]};
if(probas[i] > probas[argmax]) if(probas[i] > probas[argmax])
argmax = i; argmax = i;
} }
std::sort(vcode_array.begin(), vcode_array.end(), [](vcode a, vcode b)
{
return a.score > b.score;
});
return argmax; return argmax;
} }
...@@ -535,3 +543,7 @@ dynet::DynetParams & Mlp::get_default_params(){ ...@@ -535,3 +543,7 @@ dynet::DynetParams & Mlp::get_default_params(){
return params; return params;
} }
/* Give type-erased access to the internal vcode_array, i.e. the
 * (class_code, score) entries filled and sorted by decreasing score in
 * predict().
 *
 * The returned pointer refers to storage owned by this Mlp; predict()
 * resizes and rewrites that storage, so callers should copy what they need
 * right away. When no prediction has been made yet the vector is empty and
 * the pointer must not be dereferenced.
 */
void *Mlp::get_vcode_array(){
  /* data() is well defined on an empty vector, unlike &vcode_array[0]. */
  return vcode_array.data();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment