Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#include "Decoder.hpp"
#include "SubConfig.hpp"
// Builds a decoder around the given ReadingMachine, which is stored in the
// `machine` member and used by every decoding step; nothing else is set up here.
Decoder::Decoder(ReadingMachine & machine)
  : machine(machine)
{}
/// Greedily decodes `config` with the reading machine's classifier.
///
/// The config is first put in the strategy's initial state. Then each step:
///   1. extracts a feature context from the config;
///   2. scores it with the classifier's network;
///   3. applies the highest-scoring index whose transition-set entry
///      (getTransition(i)) is non-null, and appends its name to the history;
///   4. asks the strategy for the next movement.
/// Decoding stops when the strategy returns Strategy::endMovement.
///
/// @param config   configuration to decode; modified in place.
/// @param beamSize currently unused — decoding is greedy.
/// Errors are reported through util::myThrow (presumably throwing) when no
/// transition can be chosen or when the word index cannot be moved.
void Decoder::decode(BaseConfig & config, std::size_t beamSize)
{
  config.setState(machine.getStrategy().getInitialState());

  while (true)
  {
    // Save the dictionary's state and restore it after extraction, so any
    // state change made by extractContext does not persist.
    // NOTE(review): the meaning of the (5,5) arguments is not visible here —
    // presumably left/right window sizes; confirm in BaseConfig.
    auto dictState = machine.getDict(config.getState()).getState();
    auto context = config.extractContext(5,5,machine.getDict(config.getState()));
    machine.getDict(config.getState()).setState(dictState);

    // torch::from_blob does not take ownership of `context`'s buffer;
    // clone() gives the tensor its own copy so it never aliases the local.
    // NOTE(review): at::kLong assumes `context` holds 64-bit integers — confirm.
    auto neuralInput = torch::from_blob(context.data(), {static_cast<long>(context.size())}, at::kLong).clone();

    // NoGradGuard disables autograd recording for its scope: this is
    // inference only, so no computation graph is built for the forward pass.
    torch::NoGradGuard guard;
    auto prediction = machine.getClassifier()->getNN()(neuralInput);

    // Read all scores once from a contiguous CPU float copy, instead of
    // building two tensor views and calling item() twice per comparison.
    auto scores = prediction.cpu().to(torch::kFloat).contiguous();
    const float * score = scores.data_ptr<float>();
    const auto nbScores = static_cast<unsigned int>(prediction.size(0));

    // Argmax over indices whose transition exists; strict '>' keeps the first
    // of equal scores. The score test runs before getTransition(i), as before.
    int chosenTransition = -1;
    for (unsigned int i = 0; i < nbScores; i++)
      if ((chosenTransition == -1 or score[i] > score[chosenTransition]) and machine.getTransitionSet().getTransition(i))
        chosenTransition = static_cast<int>(i);

    if (chosenTransition == -1)
      util::myThrow("No transition appliable !");

    auto * transition = machine.getTransitionSet().getTransition(chosenTransition);
    transition->apply(config);
    config.addToHistory(transition->getName());

    auto movement = machine.getStrategy().getMovement(config, transition->getName());
    if (movement == Strategy::endMovement)
      break;

    config.setState(movement.first);
    if (!config.moveWordIndex(movement.second))
      util::myThrow("Cannot move word index !");
  }
}
/// Returns one score of `metric` from the results stored in `evaluation`
/// (filled by Decoder::evaluate()).
///
/// @param metric     metric name used as the key in `evaluation`.
/// @param scoreIndex slot in the metric's score list. getPrecision, getRecall,
///                   getF1Score and getAlignedAcc use slots 0, 1, 2 and 3.
/// Reports through util::myThrow if the metric is unknown or if scoreIndex is
/// past the end of the metric's score list.
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex)
{
  auto found = evaluation.find(metric);
  if (found == evaluation.end())
    util::myThrow(fmt::format("Cannot find metric '{}' {}\n", metric, evaluation.empty() ? "(call Decoder::evaluate() first)" : ""));

  // Check the index explicitly: the score container's operator[]
  // (presumably std::array/std::vector) does no bounds checking.
  const auto & scores = found->second;
  if (scoreIndex >= scores.size())
    util::myThrow(fmt::format("Score index {} out of range for metric '{}' ({} scores available)\n", scoreIndex, metric, scores.size()));

  return scores[scoreIndex];
}
float Decoder::getPrecision(const std::string & metric)
{
return getMetricScore(metric, 0);
}
float Decoder::getRecall(const std::string & metric)
{
return getMetricScore(metric, 1);
}
float Decoder::getF1Score(const std::string & metric)
{
return getMetricScore(metric, 2);
}
float Decoder::getAlignedAcc(const std::string & metric)
{
return getMetricScore(metric, 3);
}
void Decoder::evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV)
{
evaluation.clear();
auto predictedTSV = (modelPath/"predicted_dev.tsv").string();
std::FILE * predictedTSVFile = std::fopen(predictedTSV.c_str(), "w");
config.print(predictedTSVFile);
std::fclose(predictedTSVFile);
std::FILE * evalFromUD = popen(fmt::format("{} {} {} -v", "../scripts/conll18_ud_eval.py", goldTSV, predictedTSV).c_str(), "r");
char buffer[1024];
while (!std::feof(evalFromUD))
{
if (buffer != std::fgets(buffer, 1024, evalFromUD))
break;
if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0';
if (util::doIfNameMatch(std::regex("(.*)Metric(.*)"), buffer, [this, buffer](auto sm){}))
continue;
if (util::doIfNameMatch(std::regex("(.*)\\|(.*)\\|(.*)\\|(.*)\\|(.*)"), buffer, [this, buffer](auto sm)
{
for (unsigned int i = 0; i < this->evaluation[sm[1]].size(); i++)
{
if (std::string(sm[i+2]).empty())
{
this->evaluation[sm[1]][i] = 0.0;
continue;
}
try {this->evaluation[sm[1]][i] = std::stof(sm[i+2]);}
catch (std::exception &)
{
util::myThrow(fmt::format("score '{}' is not a number in line '{}'", std::string(sm[i+2]), buffer));
}
}