Skip to content
Snippets Groups Projects
Commit bc2ede62 authored by Franck Dary's avatar Franck Dary
Browse files

macaon_decode is working

parent e5cd481f
Branches
No related tags found
No related merge requests found
......@@ -19,6 +19,7 @@
#include <array>
#include <unordered_map>
#include <regex>
#include <filesystem>
#include <experimental/source_location>
#include <boost/flyweight.hpp>
#include "fmt/core.h"
......@@ -33,6 +34,8 @@ void error(std::string_view message, const std::experimental::source_location &
void error(const std::exception & e, const std::experimental::source_location & location = std::experimental::source_location::current());
void myThrow(std::string_view message, const std::experimental::source_location & location = std::experimental::source_location::current());
std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension);
std::string_view getFilenameFromPath(std::string_view s);
std::vector<std::string_view> split(std::string_view s, char delimiter);
......
......@@ -19,11 +19,11 @@ void Dict::readFromFile(const char * filename)
std::FILE * file = std::fopen(filename, "r");
if (!file)
util::myThrow(fmt::format("could not open file \'%s\'", filename));
util::myThrow(fmt::format("could not open file '{}'", filename));
char buffer[1048];
if (std::fscanf(file, "Encoding : %1047s\n", buffer) != 1)
util::myThrow(fmt::format("file \'%s\' bad format", filename));
util::myThrow(fmt::format("file '{}' bad format", filename));
Encoding encoding{Encoding::Ascii};
if (std::string(buffer) == "Ascii")
......@@ -31,12 +31,12 @@ void Dict::readFromFile(const char * filename)
else if (std::string(buffer) == "Binary")
encoding = Encoding::Binary;
else
util::myThrow(fmt::format("file \'%s\' bad format", filename));
util::myThrow(fmt::format("file '{}' bad format", filename));
int nbEntries;
if (std::fscanf(file, "Nb entries : %d\n", &nbEntries) != 1)
util::myThrow(fmt::format("file \'%s\' bad format", filename));
util::myThrow(fmt::format("file '{}' bad format", filename));
elementsToIndexes.reserve(nbEntries);
......@@ -45,7 +45,7 @@ void Dict::readFromFile(const char * filename)
for (int i = 0; i < nbEntries; i++)
{
if (!readEntry(file, &entryIndex, entryString, encoding))
util::myThrow(fmt::format("file \'%s\' line {} bad format", filename, i));
util::myThrow(fmt::format("file '{}' line {} bad format", filename, i));
elementsToIndexes[entryString] = entryIndex;
}
......
......@@ -171,3 +171,18 @@ std::string util::strip(const std::string & s)
return std::string(s.begin()+first, s.begin()+last+1);
}
std::vector<std::filesystem::path> util::findFilesByExtension(std::filesystem::path directory, std::string extension)
{
std::vector<std::filesystem::path> files;
for (auto entry : std::filesystem::directory_iterator(directory))
if (entry.is_regular_file())
{
auto path = entry.path();
if (path.extension() == extension)
files.push_back(path);
}
return files;
}
......@@ -6,6 +6,8 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
}
void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{
try
{
config.setState(machine.getStrategy().getInitialState());
......@@ -42,6 +44,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize)
if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !");
}
} catch(std::exception & e) {util::myThrow(e.what());}
}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
......
......@@ -64,19 +64,31 @@ int main(int argc, char * argv[])
auto variables = checkOptions(od, argc, argv);
std::filesystem::path modelPath(variables["model"].as<std::string>());
auto machinePath = modelPath / ReadingMachine::defaultMachineName;
auto machinePath = modelPath / ReadingMachine::defaultMachineFilename;
auto dictPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultDictFilename, ""));
auto modelPaths = util::findFilesByExtension(modelPath, fmt::format(ReadingMachine::defaultModelFilename, ""));
auto inputTSV = variables.count("inputTSV") ? variables["inputTSV"].as<std::string>() : "";
auto inputTXT = variables.count("inputTXT") ? variables["inputTXT"].as<std::string>() : "";
auto mcdFile = variables["mcd"].as<std::string>();
ReadingMachine machine(machinePath.string());
if (dictPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultDictFilename, "")));
if (modelPaths.empty())
util::error(fmt::format("no '{}' files were found, and none were given. Has the model been trained yet ?", fmt::format(ReadingMachine::defaultModelFilename, "")));
try
{
ReadingMachine machine(machinePath, modelPaths, dictPaths);
Decoder decoder(machine);
BaseConfig config(mcdFile, inputTSV, inputTXT);
decoder.decode(config, 1);
fmt::print(stderr, "\n");
config.print(stdout);
} catch(std::exception & e) {util::error(e);}
return 0;
}
......
......@@ -18,6 +18,7 @@ class Classifier
Classifier(const std::string & name, const std::string & topology, const std::string & tsFile);
TransitionSet & getTransitionSet();
TestNetwork & getNN();
const std::string & getName() const;
};
#endif
......@@ -12,9 +12,10 @@ class ReadingMachine
{
public :
static inline const std::string defaultMachineName = "machine.rm";
static inline const std::string defaultModelName = "{}.pt";
static inline const std::string defaultDictName = "{}.dict";
static inline const std::string defaultMachineFilename = "machine.rm";
static inline const std::string defaultModelFilename = "{}.pt";
static inline const std::string defaultDictFilename = "{}.dict";
static inline const std::string defaultDictName = "_default_";
private :
......@@ -25,14 +26,19 @@ class ReadingMachine
std::unique_ptr<FeatureFunction> featureFunction;
std::map<std::string, Dict> dicts;
private :
void readFromFile(std::filesystem::path path);
public :
ReadingMachine(std::filesystem::path path);
ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts);
ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts);
TransitionSet & getTransitionSet();
Strategy & getStrategy();
Dict & getDict(const std::string & state);
Classifier * getClassifier();
void save();
};
#endif
......@@ -17,3 +17,8 @@ TestNetwork & Classifier::getNN()
return nn;
}
const std::string & Classifier::getName() const
{
return name;
}
......@@ -3,8 +3,23 @@
ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
{
dicts.emplace(std::make_pair("", Dict::State::Open));
dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
readFromFile(path);
}
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
{
readFromFile(path);
for (auto path : dicts)
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Closed});
torch::load(classifier->getNN(), models[0]);
}
void ReadingMachine::readFromFile(std::filesystem::path path)
{
std::FILE * file = std::fopen(path.c_str(), "r");
char buffer[1024];
......@@ -49,11 +64,6 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", path.string(), e.what()));}
}
ReadingMachine::ReadingMachine(const std::string & filename, const std::vector<std::string> & models, const std::vector<std::string> & dicts)
{
}
TransitionSet & ReadingMachine::getTransitionSet()
{
return classifier->getTransitionSet();
......@@ -68,8 +78,11 @@ Dict & ReadingMachine::getDict(const std::string & state)
{
auto found = dicts.find(state);
try
{
if (found == dicts.end())
return dicts.at("");
return dicts.at(defaultDictName);
} catch (std::exception & e) {util::myThrow(fmt::format("can't find dict '{}'", defaultDictName));}
return found->second;
}
......@@ -79,3 +92,21 @@ Classifier * ReadingMachine::getClassifier()
return classifier.get();
}
void ReadingMachine::save()
{
for (auto & it : dicts)
{
auto pathToDict = path.parent_path() / fmt::format(defaultDictFilename, it.first);
std::FILE * file = std::fopen(pathToDict.c_str(), "w");
if (!file)
util::myThrow(fmt::format("couldn't create file '{}'", pathToDict.c_str()));
it.second.save(file, Dict::Encoding::Ascii);
std::fclose(file);
}
auto pathToClassifier = path.parent_path() / fmt::format(defaultModelFilename, classifier->getName());
torch::save(classifier->getNN(), pathToClassifier);
}
......@@ -82,13 +82,22 @@ int main(int argc, char * argv[])
Decoder decoder(machine);
BaseConfig devGoldConfig(mcdFile, devTsvFile, devRawFile);
float bestDevScore = 0;
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch();
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1);
decoder.evaluate(devConfig, modelPath, devTsvFile);
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {}%\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, decoder.getF1Score("UPOS"));
float devScore = decoder.getF1Score("UPOS");
bool saved = devScore > bestDevScore;
if (saved)
{
bestDevScore = devScore;
machine.save();
}
fmt::print(stderr, "\r{:80}\rEpoch {:^9} loss = {:7.2f} dev = {:6.2f}% {:5}\n", " ", fmt::format("{}/{}", i+1, nbEpoch), loss, devScore, saved ? "SAVED" : "");
}
return 0;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment