Skip to content
Snippets Groups Projects
Commit 586b6b0b authored by Franck Dary's avatar Franck Dary
Browse files

Added decoder

parent 1759bd0b
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,7 @@ include_directories(common/include)
include_directories(reading_machine/include)
include_directories(torch_modules/include)
include_directories(trainer/include)
include_directories(decoder/include)
include_directories(utf8)
add_subdirectory(fmt)
......@@ -35,4 +36,5 @@ add_subdirectory(dev)
add_subdirectory(reading_machine)
add_subdirectory(torch_modules)
add_subdirectory(trainer)
add_subdirectory(decoder)
......@@ -36,6 +36,7 @@ class Dict
void insert(const std::string & element);
int getIndexOrInsert(const std::string & element);
void setState(State state);
State getState();
void save(std::FILE * destination, Encoding encoding);
bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding);
......
......@@ -79,6 +79,11 @@ void Dict::setState(State state)
this->state = state;
}
Dict::State Dict::getState()
{
return state;
}
void Dict::save(std::FILE * destination, Encoding encoding)
{
fprintf(destination, "Encoding : %s\n", encoding == Encoding::Ascii ? "Ascii" : "Binary");
......
FILE(GLOB SOURCES src/*.cpp)
add_library(decoder STATIC ${SOURCES})
target_link_libraries(decoder reading_machine)
#ifndef DECODER__H
#define DECODER__H
#include <filesystem>
#include "ReadingMachine.hpp"
#include "SubConfig.hpp"
class Decoder
{
private :
ReadingMachine & machine;
std::map<std::string, std::array<float,4>> evaluation;
public :
Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize);
void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
float getMetricScore(const std::string & metric, std::size_t scoreIndex);
float getPrecision(const std::string & metric);
float getF1Score(const std::string & metric);
float getRecall(const std::string & metric);
float getAlignedAcc(const std::string & metric);
};
#endif
#include "Decoder.hpp"
#include "SubConfig.hpp"
Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}
void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{
config.setState(machine.getStrategy().getInitialState());
while (true)
{
auto dictState = machine.getDict(config.getState()).getState();
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
machine.getDict(config.getState()).setState(dictState);
//TODO : check if clone is mandatory
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone();
//TODO : check if NoGradGuard does anything
torch::NoGradGuard guard;
auto prediction = machine.getClassifier()->getNN()(neuralInput);
int chosenTransition = -1;
for (unsigned int i = 0; i < prediction.size(0); i++)
if ((chosenTransition == -1 or prediction[i].item<float>() > prediction[chosenTransition].item<float>()) and machine.getTransitionSet().getTransition(i))
chosenTransition = i;
if (chosenTransition == -1)
util::myThrow("No transition appliable !");
auto * transition = machine.getTransitionSet().getTransition(chosenTransition);
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = machine.getStrategy().getMovement(config, transition->getName());
if (movement == Strategy::endMovement)
break;
config.setState(movement.first);
if (!config.moveWordIndex(movement.second))
util::myThrow("Cannot move word index !");
}
}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
{
auto found = evaluation.find(metric);
if (found == evaluation.end())
util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));
return found->second[scoreIndex];
}
float Decoder::getPrecision(const std::string & metric)
{
return getMetricScore(metric, 0);
}
float Decoder::getRecall(const std::string & metric)
{
return getMetricScore(metric, 1);
}
float Decoder::getF1Score(const std::string & metric)
{
return getMetricScore(metric, 2);
}
float Decoder::getAlignedAcc(const std::string & metric)
{
return getMetricScore(metric, 3);
}
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
evaluation.clear();
auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
config.print(predictedTSVFile);
std::fclose(predictedTSVFile);
std::FILE * evalFromUD = popen(fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r");
char buffer[1024];
while (!std::feof(evalFromUD))
{
if (buffer != std::fgets(buffer, 1024, evalFromUD))
break;
if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this](auto sm)
{
for (unsigned int i = 0; i < this->evaluation[sm[1]].size(); i++)
this->evaluation[sm[1]][i] = std::stof(sm[1+i]);
})){}
}
pclose(evalFromUD);
}
......@@ -5,3 +5,4 @@ target_link_libraries(dev common)
target_link_libraries(dev reading_machine)
target_link_libraries(dev torch_modules)
target_link_libraries(dev trainer)
target_link_libraries(dev decoder)
#include <cstdio>
#include <filesystem>
#include "fmt/core.h"
#include "util.hpp"
#include "BaseConfig.hpp"
......@@ -6,6 +6,7 @@
#include "TransitionSet.hpp"
#include "ReadingMachine.hpp"
#include "Trainer.hpp"
#include "Decoder.hpp"
int main(int argc, char * argv[])
{
......@@ -15,13 +16,16 @@ int main(int argc, char * argv[])
exit(1);
}
std::string machineFile = argv[1];
std::string model = argv[1];
std::string mcdFile = argv[2];
std::string tsvFile = argv[3];
//std::string rawFile = argv[4];
std::string rawFile = "";
ReadingMachine machine(machineFile);
std::filesystem::path modelPath(model);
auto machinePath = modelPath / "machine.rm";
ReadingMachine machine(machinePath.string());
BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
SubConfig config(goldConfig);
......@@ -29,13 +33,17 @@ int main(int argc, char * argv[])
Trainer trainer(machine);
trainer.createDataset(config);
int nbEpoch = 5;
Decoder decoder(machine);
int nbEpoch = 1;
for (int i = 0; i < nbEpoch; i++)
{
float loss = trainer.epoch();
fmt::print("\r{:80}", " ");
fmt::print("\rEpoch {}/{} loss = {}\n", i+1, nbEpoch, loss);
auto devConfig = goldConfig;
decoder.decode(devConfig, 1);
decoder.evaluate(devConfig, modelPath, tsvFile);
fmt::print(stderr, "\r{:80}\rEpoch {}/{} loss = {} dev = {}\n", " ", i+1, nbEpoch, loss, decoder.getF1Score("UPOS"));
}
return 0;
......
......@@ -19,6 +19,7 @@ class TransitionSet
std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
Transition * getBestAppliableTransition(const Config & c);
std::size_t getTransitionIndex(const Transition * transition) const;
Transition * getTransition(std::size_t index);
std::size_t size() const;
};
......
......@@ -97,7 +97,7 @@ Action Action::addHypothesisRelative(const std::string & colName, Object object,
else
lineIndex = config.getStack(relativeIndex);
return addHypothesis(colName, lineIndex, "").apply(config, a);
return addHypothesis(colName, lineIndex, hypothesis).apply(config, a);
};
auto undo = [colName, object, relativeIndex](Config & config, Action & a)
......
......@@ -80,3 +80,8 @@ std::size_t TransitionSet::getTransitionIndex(const Transition * transition) con
return transition - &transitions[0];
}
Transition * TransitionSet::getTransition(std::size_t index)
{
return &transitions[index];
}
......@@ -29,10 +29,12 @@ torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence, embeddings}
auto wordsAsEmb = wordEmbeddings(input);
auto reshaped = wordsAsEmb;
// reshaped dim = {sequence, batch, embeddings}
auto reshaped = wordsAsEmb.permute({1,0,2});
if (reshaped.dim() == 3)
reshaped = wordsAsEmb.permute({1,0,2});
auto res = torch::softmax(linear(reshaped[focusedIndex]), 1);
auto res = torch::softmax(linear(reshaped[focusedIndex]), reshaped.dim() == 3 ? 1 : 0);
return res;
}
......
......@@ -18,6 +18,7 @@ void Trainer::createDataset(SubConfig & config)
if (!transition)
util::myThrow("No transition appliable !");
//TODO : check if clone is mandatory
auto context = config.extractContext(5,5,machine.getDict(config.getState()));
contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
......@@ -82,8 +83,7 @@ float Trainer::epoch()
if (nbExamplesUntilPrint <= 0)
{
nbExamplesUntilPrint = printInterval;
fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
std::fflush(stdout);
fmt::print(stderr, "\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*(currentBatchNumber*batchSize)/nbExamples, lossSoFar);
lossSoFar = 0;
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment