Skip to content
Snippets Groups Projects
Commit 6376f588 authored by Franck Dary's avatar Franck Dary
Browse files

Decoder is now working, but its parameters are hard-coded

parent 1ac56cf3
No related branches found
No related tags found
No related merge requests found
......@@ -60,13 +60,20 @@ class MLP
dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
inline dynet::Expression activate(dynet::Expression h, Activation f);
void printParameters(FILE * output);
void saveStruct(const std::string & filename);
void saveParameters(const std::string & filename);
void loadStruct(const std::string & filename);
void loadParameters(const std::string & filename);
void load(const std::string & filename);
public :
MLP(std::vector<Layer> layers);
MLP(const std::string & filename);
std::vector<float> predict(FeatureModel::FeatureDescription & fd, int goldClass);
int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
void save(const std::string & filename);
};
#endif
#include "MLP.hpp"
#include "File.hpp"
#include "util.hpp"
#include <dynet/param-init.h>
......@@ -113,9 +114,7 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
std::vector<dynet::Expression> expressions;
for (auto & featValue : fd.values)
{
expressions.emplace_back(featValue2Expression(cg, featValue));
}
dynet::Expression input = dynet::concatenate(expressions);
......@@ -310,3 +309,79 @@ int MLP::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescriptio
return nbCorrect;
}
// Persist the whole model to `filename`.
// Order matters : saveStruct() opens the file in "w" mode (truncating it),
// then saveParameters() appends the dynet weights to the same file.
void MLP::save(const std::string & filename)
{
saveStruct(filename);
saveParameters(filename);
}
void MLP::saveStruct(const std::string & filename)
{
File file(filename, "w");
FILE * fd = file.getDescriptor();
for (auto & layer : layers)
{
fprintf(fd, "Layer : %d %d %s %.2f\n", layer.input_dim, layer.output_dim, activation2str(layer.activation).c_str(), layer.dropout_rate);
}
}
// Append every layer's weight matrix and bias vector to `filename` via
// dynet's text serializer. Keys are "Layer_<i>_W" / "Layer_<i>_b" — the
// same naming scheme loadParameters() looks up.
void MLP::saveParameters(const std::string & filename)
{
  dynet::TextFileSaver s(filename, true);

  for (unsigned int layerIndex = 0; layerIndex < parameters.size(); layerIndex++)
  {
    const std::string base = "Layer_" + std::to_string(layerIndex);
    s.save(parameters[layerIndex][0], base + "_W");
    s.save(parameters[layerIndex][1], base + "_b");
  }
}
// Restore a model previously written by save().
// Order matters : loadStruct() rebuilds the layers and registers their
// parameters in the model, which loadParameters() then fills with the
// saved weights.
void MLP::load(const std::string & filename)
{
loadStruct(filename);
loadParameters(filename);
}
// Read the network topology back from `filename` (as written by saveStruct),
// rebuild the `layers` description, check it, and register each layer's
// parameters in the dynet model.
void MLP::loadStruct(const std::string & filename)
{
  File file(filename, "r");
  FILE * fd = file.getDescriptor();

  char activation[1024];
  int input;
  int output;
  float dropout;

  // %1023s bounds the read so a malformed file cannot overflow `activation`
  // (an unbounded %s would write past the buffer).
  while (fscanf(fd, "Layer : %d %d %1023s %f\n", &input, &output, activation, &dropout) == 4)
    layers.emplace_back(input, output, dropout, str2activation(activation));

  checkLayersCompatibility();

  for (auto & layer : layers)
    addLayerToModel(layer);
}
// Fill the already-registered parameters with the weights saved in
// `filename`, looking them up by the "Layer_<i>_W" / "Layer_<i>_b" keys
// that saveParameters() wrote.
void MLP::loadParameters(const std::string & filename)
{
  dynet::TextFileLoader loader(filename);

  for (unsigned int layerIndex = 0; layerIndex < parameters.size(); layerIndex++)
  {
    const std::string base = "Layer_" + std::to_string(layerIndex);
    parameters[layerIndex][0] = loader.load_param(model, base + "_W");
    parameters[layerIndex][1] = loader.load_param(model, base + "_b");
  }
}
// Construct an MLP by loading a previously saved model from `filename`.
// The model is put in inference mode (trainMode = false : presumably
// disables dropout — confirm in predict()).
// NOTE(review): the Adam trainer hyper-parameters (lr=0.001, beta1=0.9,
// beta2=0.999, eps=1e-8) are hard-coded here — TODO make them configurable.
MLP::MLP(const std::string & filename)
: trainer(model, 0.001, 0.9, 0.999, 1e-8)
{
dynet::initialize(getDefaultParams());
trainMode = false;
load(filename);
}
......@@ -7,6 +7,44 @@ Decoder::Decoder(TapeMachine & tm, MCD & mcd, Config & config)
// Run the tape machine over the configuration until it reaches a final
// state, scoring each classifier's predicted action against the oracle
// action, then print per-classifier accuracy.
//
// NOTE(review): the machine currently follows the ORACLE action
// (neededActionName), not the predicted one — predictions are only
// scored. Confirm this is intentional for a decoder.
void Decoder::decode()
{
  int nbIter = 1; // TODO : hard-coded parameter

  // Was "Training of" — copy-paste from Trainer; this is decoding.
  fprintf(stderr, "Decoding of \'%s\' :\n", tm.name.c_str());

  for (int i = 0; i < nbIter; i++)
  {
    // Per classifier name : <nb examples seen, nb predicted correctly>.
    std::map< std::string, std::pair<int, int> > nbExamples;

    while (!config.isFinal())
    {
      TapeMachine::State * currentState = tm.getCurrentState();
      Classifier * classifier = currentState->classifier;
      //config.printForDebug(stderr);
      //fprintf(stderr, "State : \'%s\'\n", currentState->name.c_str());

      std::string neededActionName = classifier->getOracleAction(config);
      auto weightedActions = classifier->weightActions(config, neededActionName);
      //Classifier::printWeightedActions(stderr, weightedActions);
      std::string & predictedAction = weightedActions[0].second;

      nbExamples[classifier->name].first++;
      if(predictedAction == neededActionName)
        nbExamples[classifier->name].second++;

      //fprintf(stderr, "Action : \'%s\'\n", neededActionName.c_str());
      TapeMachine::Transition * transition = tm.getTransition(neededActionName);
      tm.takeTransition(transition);
      config.moveHead(transition->headMvt);
    }

    fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
    for(auto & it : nbExamples)
    {
      int seen = it.second.first;
      int correct = it.second.second;
      // Guard against division by zero for a classifier that saw no example.
      double accuracy = seen ? 100.0 * correct / seen : 0.0;
      fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), accuracy);
    }

    config.reset();
  }
}
......@@ -59,6 +59,7 @@ class Dict
std::vector<float> * getValue(const std::string & s);
std::vector<float> * getNullValue();
int getDimension();
void printForDebug(FILE * output);
};
#endif
......@@ -71,10 +71,11 @@ Dict::Dict(Policy policy, const std::string & filename)
if(this->policy == Policy::FromZero)
return;
while(fscanf(fd, "%s", b1) != 1)
while(fscanf(fd, "%s", b1) == 1)
{
std::string entry = b1;
str2vec.emplace(entry, std::vector<float>());
//str2vec.emplace(entry, std::vector<float>());
str2vec[entry] = std::vector<float>();
auto & vec = str2vec[entry];
// For OneHot we only write the index
......@@ -195,3 +196,8 @@ int Dict::getDimension()
return dimension;
}
// Print a one-line summary (dictionary name + number of entries) to `output`.
void Dict::printForDebug(FILE * output)
{
  // %zu is the portable conversion for size_t (str2vec.size());
  // %lu is wrong on platforms where size_t != unsigned long (e.g. 64-bit Windows).
  fprintf(output, "Dict name \'%s\' nbElems = %zu\n", name.c_str(), str2vec.size());
}
......@@ -25,6 +25,7 @@ class Classifier
private :
bool trainMode;
Type type;
std::unique_ptr<FeatureModel> fm;
std::unique_ptr<ActionSet> as;
......@@ -36,7 +37,7 @@ class Classifier
static void printWeightedActions(FILE * output, WeightedActions & wa);
static Type str2type(const std::string & filename);
Classifier(const std::string & filename);
Classifier(const std::string & filename, bool trainMode);
WeightedActions weightActions(Config & config, const std::string & goldAction);
FeatureModel::FeatureDescription getFeatureDescription(Config & config);
std::string getOracleAction(Config & config);
......@@ -44,6 +45,7 @@ class Classifier
int trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end);
std::string getActionName(int actionIndex);
void initClassifier(Config & config);
void save(const std::string & filename);
};
#endif
......@@ -29,9 +29,11 @@ class TapeMachine
private :
bool trainMode;
std::map< std::string, std::unique_ptr<Classifier> > str2classifier;
std::map< std::string, std::unique_ptr<State> > str2state;
State * currentState;
std::vector<Classifier*> classifiers;
public :
......@@ -39,10 +41,11 @@ class TapeMachine
public :
TapeMachine(const std::string & filename);
TapeMachine(const std::string & filename, bool trainMode);
State * getCurrentState();
Transition * getTransition(const std::string & action);
void takeTransition(Transition * transition);
std::vector<Classifier*> & getClassifiers();
};
#endif
......@@ -2,8 +2,10 @@
#include "File.hpp"
#include "util.hpp"
Classifier::Classifier(const std::string & filename)
Classifier::Classifier(const std::string & filename, bool trainMode)
{
this->trainMode = trainMode;
auto badFormatAndAbort = [&filename](const char * errInfo)
{
fprintf(stderr, "ERROR (%s) : file %s bad format. Aborting.\n", errInfo, filename.c_str());
......@@ -85,6 +87,12 @@ void Classifier::initClassifier(Config & config)
if(mlp.get())
return;
if(!trainMode)
{
mlp.reset(new MLP("toto.txt"));
return;
}
int nbInputs = 0;
int nbHidden = 200;
int nbOutputs = as->actions.size();
......@@ -138,3 +146,8 @@ void Classifier::printWeightedActions(FILE * output, WeightedActions & wa)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
}
// Persist this classifier's underlying MLP to `filename`.
// NOTE(review): assumes mlp has been initialized (see initClassifier) —
// calling this before any example was seen would dereference a null mlp.
void Classifier::save(const std::string & filename)
{
mlp->save(filename);
}
......@@ -3,7 +3,7 @@
#include "util.hpp"
#include <cstring>
TapeMachine::TapeMachine(const std::string & filename)
TapeMachine::TapeMachine(const std::string & filename, bool trainMode)
{
auto badFormatAndAbort = [&filename](const std::string & errInfo)
{
......@@ -12,6 +12,8 @@ TapeMachine::TapeMachine(const std::string & filename)
exit(1);
};
this->trainMode = trainMode;
File file(filename, "r");
FILE * fd = file.getDescriptor();
......@@ -35,7 +37,9 @@ TapeMachine::TapeMachine(const std::string & filename)
if(fscanf(fd, "%s %s\n", buffer, buffer2) != 2)
badFormatAndAbort(ERRINFO);
str2classifier.emplace(buffer, std::unique_ptr<Classifier>(new Classifier(buffer2)));
str2classifier.emplace(buffer, std::unique_ptr<Classifier>(new Classifier(buffer2, trainMode)));
classifiers.emplace_back(str2classifier[buffer].get());
}
// Reading %STATES
......@@ -124,3 +128,8 @@ void TapeMachine::takeTransition(Transition * transition)
currentState = transition->dest;
}
// Accessor for the machine's classifiers (non-owning pointers; the
// Classifier objects are owned by str2classifier). The returned reference
// stays valid for the lifetime of this TapeMachine.
std::vector<Classifier*> & TapeMachine::getClassifiers()
{
return classifiers;
}
......@@ -3,3 +3,7 @@ FILE(GLOB SOURCES src/*.cpp)
add_executable(test_train src/test_train.cpp)
target_link_libraries(test_train tape_machine)
target_link_libraries(test_train trainer)
add_executable(test_decode src/test_decode.cpp)
target_link_libraries(test_decode tape_machine)
target_link_libraries(test_decode decoder)
#include <cstdio>
#include <cstdlib>
#include "MCD.hpp"
#include "Config.hpp"
#include "TapeMachine.hpp"
#include "Decoder.hpp"
// Print the expected command line on stderr and terminate with failure.
void printUsageAndExit(char * argv[])
{
  fprintf(stderr, "USAGE : %s mcd inputFile tm\n", argv[0]);
  exit(1);
}
// Entry point of test_decode : build a tape machine in decode mode
// (trainMode = false) and run it over the input.
// argv[1] = mcd file, argv[2] = input file, argv[3] = tape machine file.
int main(int argc, char * argv[])
{
if (argc != 4)
printUsageAndExit(argv);
MCD mcd(argv[1]);
Config config(mcd);
// false = decode mode : classifiers load their trained MLP instead of
// creating a fresh one (see Classifier::initClassifier).
TapeMachine tapeMachine(argv[3], false);
// Input must be read before decoding starts.
config.readInput(argv[2]);
Decoder decoder(tapeMachine, mcd, config);
decoder.decode();
return 0;
}
......@@ -19,7 +19,7 @@ int main(int argc, char * argv[])
MCD mcd(argv[1]);
Config config(mcd);
TapeMachine tapeMachine(argv[3]);
TapeMachine tapeMachine(argv[3], true);
config.readInput(argv[2]);
......
......@@ -7,7 +7,7 @@ Trainer::Trainer(TapeMachine & tm, MCD & mcd, Config & config)
void Trainer::trainUnbatched()
{
int nbIter = 5;
int nbIter = 20;
fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
......@@ -26,7 +26,7 @@ void Trainer::trainUnbatched()
std::string neededActionName = classifier->getOracleAction(config);
auto weightedActions = classifier->weightActions(config, neededActionName);
//printWeightedActions(stderr, weightedActions);
//Classifier::printWeightedActions(stderr, weightedActions);
std::string & predictedAction = weightedActions[0].second;
nbExamples[classifier->name].first++;
......@@ -46,6 +46,10 @@ void Trainer::trainUnbatched()
config.reset();
}
auto & classifiers = tm.getClassifiers();
for(Classifier * cla : classifiers)
cla->save("toto.txt");
}
void Trainer::trainBatched()
......@@ -101,8 +105,11 @@ void Trainer::trainBatched()
fprintf(stderr, "Iteration %d/%d :\n", i+1, nbIter);
for(auto & it : nbExamples)
fprintf(stderr, "\t%s %.2f%% accuracy\n", it.first.c_str(), 100.0*it.second.second / it.second.first);
}
auto & classifiers = tm.getClassifiers();
for(Classifier * cla : classifiers)
cla->save("toto.txt");
}
void Trainer::train()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment