Skip to content
Snippets Groups Projects
Commit bd1a31a8 authored by Franck Dary's avatar Franck Dary
Browse files

BeamSearch works for tagger, but there is a bug with rawInput

parent b9bb67a5
Branches
No related tags found
No related merge requests found
#ifndef BEAM__H
#define BEAM__H
#include <vector>
#include <string>
#include "BaseConfig.hpp"
#include "ReadingMachine.hpp"
class Beam
{
public :
class Element
{
public :
BaseConfig config;
int nextTransition;
float totalProbability;
std::string name;
bool ended{false};
public :
Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name);
};
private :
std::size_t width;
float threshold;
std::vector<Element> elements;
bool ended{false};
public :
Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine);
Element & operator[](std::size_t index);
void update(ReadingMachine & machine, bool debug);
bool isEnded() const;
};
#endif
#include "Beam.hpp"
Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const ReadingMachine & machine) : width(width), threshold(threshold)
{
model.setStrategy(machine.getStrategyDefinition());
model.addPredicted(machine.getPredicted());
model.setState(model.getStrategy().getInitialState());
elements.emplace_back(model, -1, 0.0, "0");
}
Beam::Element::Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name) : config(model), nextTransition(nextTransition), totalProbability(totalProbability), name(name)
{
}
Beam::Element & Beam::operator[](std::size_t index)
{
return elements[index];
}
void Beam::update(ReadingMachine & machine, bool debug)
{
ended = true;
auto currentNbElements = elements.size();
if (debug)
fmt::print(stderr, "{:-<{}}\nBEAM SEARCH CONTENT :\n", "", 80);
for (unsigned int index = 0; index < currentNbElements; index++)
{
if (elements[index].ended)
continue;
ended = false;
auto & classifier = *machine.getClassifier();
classifier.setState(elements[index].config.getState());
auto appliableTransitions = machine.getTransitionSet().getAppliableTransitions(elements[index].config);
elements[index].config.setAppliableTransitions(appliableTransitions);
if (machine.hasSplitWordTransitionSet())
elements[index].config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(elements[index].config, Config::maxNbAppliableSplitTransitions));
auto context = classifier.getNN()->extractContext(elements[index].config).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = torch::softmax(classifier.getNN()(neuralInput).squeeze(), 0);
std::vector<std::pair<float, int>> scoresOfTransitions;
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and score >= threshold)
scoresOfTransitions.emplace_back(std::make_pair(score, i));
}
if (scoresOfTransitions.empty())
{
elements[index].config.printForDebug(stderr);
util::myThrow("No suitable transition found !");
}
std::sort(scoresOfTransitions.rbegin(), scoresOfTransitions.rend());
if (width > 1)
for (unsigned int i = 1; i < scoresOfTransitions.size(); i++)
{
elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].totalProbability + scoresOfTransitions[i].first, elements[index].name + ":" + std::to_string(scoresOfTransitions[i].second));
}
elements[index].nextTransition = scoresOfTransitions[0].second;
elements[index].totalProbability += scoresOfTransitions[0].first;
elements[index].name += ":" + std::to_string(elements[index].nextTransition);
if (debug)
{
elements[index].config.printForDebug(stderr);
std::vector<std::pair<float,std::string>> toPrint;
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
toPrint.emplace_back(std::make_pair(score,nicePrint));
}
std::sort(toPrint.rbegin(), toPrint.rend());
for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++)
fmt::print(stderr, "{}\n", toPrint[i].second);
}
}
std::sort(elements.begin(), elements.end(), [](const Element & a, const Element & b)
{
return a.totalProbability > b.totalProbability;
});
while (elements.size() > width)
elements.pop_back();
for (auto & element : elements)
{
if (element.ended)
continue;
auto & config = element.config;
auto & classifier = *machine.getClassifier();
classifier.setState(config.getState());
auto * transition = machine.getTransitionSet().getTransition(element.nextTransition);
transition->apply(config);
config.addToHistory(transition->getName());
auto movement = config.getStrategy().getMovement(config, transition->getName());
if (movement == Strategy::endMovement)
{
element.ended = true;
continue;
}
config.setState(movement.first);
config.moveWordIndexRelaxed(movement.second);
}
if (debug)
fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n{:-<{}}\n", "", 80);
}
bool Beam::isEnded() const
{
return ended;
}
#include "Decoder.hpp" #include "Decoder.hpp"
#include "SubConfig.hpp" #include "SubConfig.hpp"
#include "Beam.hpp"
Decoder::Decoder(ReadingMachine & machine) : machine(machine) Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{ {
...@@ -7,7 +8,6 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -7,7 +8,6 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement) void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement)
{ {
constexpr float beamThreshold = 0.1;
constexpr int printInterval = 50; constexpr int printInterval = 50;
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
...@@ -17,130 +17,13 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -17,130 +17,13 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
int nbExamplesProcessed = 0; int nbExamplesProcessed = 0;
auto pastTime = std::chrono::high_resolution_clock::now(); auto pastTime = std::chrono::high_resolution_clock::now();
std::vector<BaseConfig> beam; Beam beam(beamSize, 0.1, baseConfig, machine);
std::vector<bool> endFlag;
try try
{ {
while (!beam.isEnded())
for (unsigned int i = 0; i < beamSize; i++)
{
beam.emplace_back(baseConfig);
beam.back().setStrategy(machine.getStrategyDefinition());
beam.back().addPredicted(machine.getPredicted());
beam.back().setState(beam.back().getStrategy().getInitialState());
endFlag.emplace_back(false);
}
while (true)
{
if (machine.hasSplitWordTransitionSet())
for (auto & c : beam)
c.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(c, Config::maxNbAppliableSplitTransitions));
std::vector<std::vector<int>> appliableTransitions;
for (auto & c : beam)
{
machine.getClassifier()->setState(c.getState());
appliableTransitions.emplace_back(machine.getTransitionSet().getAppliableTransitions(c));
c.setAppliableTransitions(appliableTransitions.back());
}
std::vector<torch::Tensor> predictions;
for (auto & c : beam)
{
machine.getClassifier()->setState(c.getState());
auto context = machine.getClassifier()->getNN()->extractContext(c).back();
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
predictions.emplace_back(machine.getClassifier()->getNN()(neuralInput).squeeze());
}
if (debug)
{
fmt::print(stderr, "{:-<{}}\n", "", 80);
fmt::print(stderr, "BEAM SEARCH CONTENT :\n");
for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++)
{
auto & c = beam[beamIndex];
machine.getClassifier()->setState(c.getState());
c.printForDebug(stderr);
auto softmaxed = torch::softmax(predictions[beamIndex],-1);
std::vector<std::pair<float,std::string>> toPrint;
for (unsigned int i = 0; i < softmaxed.size(0); i++)
{
float score = softmaxed[i].item<float>();
std::string nicePrint = fmt::format("{} {:7.2f} {}", appliableTransitions[beamIndex][i] ? "*" : " ", score, machine.getTransitionSet().getTransition(i)->getName());
toPrint.emplace_back(std::make_pair(score,nicePrint));
}
std::sort(toPrint.rbegin(), toPrint.rend());
for (unsigned int i = 0; i < 5 and i < toPrint.size(); i++)
fmt::print(stderr, "{}\n", toPrint[i].second);
}
fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n");
fmt::print(stderr, "{:-<{}}\n", "", 80);
}
for (unsigned int beamIndex = 0; beamIndex < beam.size(); beamIndex++)
{
if (endFlag[beamIndex])
continue;
auto & c = beam[beamIndex];
machine.getClassifier()->setState(c.getState());
int chosenTransition = -1;
float bestScore = std::numeric_limits<float>::min();
auto softmaxed = torch::softmax(predictions[beamIndex], 0);
std::vector<int> consideredTransitions;
try
{
for (unsigned int i = 0; i < predictions[beamIndex].size(0); i++)
{
float score = predictions[beamIndex][i].item<float>();
if ((chosenTransition == -1 or score > bestScore) and appliableTransitions[beamIndex][i])
{ {
chosenTransition = i; beam.update(machine, debug);
bestScore = score;
}
if (softmaxed[i].item<float>() >= beamThreshold)
consideredTransitions.emplace_back(i);
}
} catch(std::exception & e) {util::myThrow(e.what());}
if (chosenTransition == -1)
{
c.printForDebug(stderr);
util::myThrow("No transition appliable !");
}
auto * transition = machine.getTransitionSet().getTransition(chosenTransition);
transition->apply(c);
c.addToHistory(transition->getName());
auto movement = c.getStrategy().getMovement(c, transition->getName());
if (debug)
{
//TODO improve this for beam search
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", transition->getName(), movement.first, movement.second);
}
if (movement == Strategy::endMovement)
{
endFlag[beamIndex] = true;
continue;
}
c.setState(movement.first);
c.moveWordIndexRelaxed(movement.second);
}
bool allBeamAreEnded = true;
for (unsigned int i = 0; i < beam.size(); i++)
if (!endFlag[i])
allBeamAreEnded = false;
if (allBeamAreEnded)
break;
if (printAdvancement) if (printAdvancement)
if (++nbExamplesProcessed >= printInterval) if (++nbExamplesProcessed >= printInterval)
...@@ -153,30 +36,26 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, ...@@ -153,30 +36,26 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
} }
} }
} catch(std::exception & e) {util::myThrow(e.what());} } catch(std::exception & e) {util::myThrow(e.what());}
for (auto & c : beam) baseConfig = beam[0].config;
{ machine.getClassifier()->setState(baseConfig.getState());
// Force EOS when needed
if (machine.getTransitionSet().getTransition("EOS b.0") and c.getLastNotEmptyHypConst(Config::EOSColName, c.getWordIndex()) != Config::EOSSymbol1) if (machine.getTransitionSet().getTransition("EOS b.0") and baseConfig.getLastNotEmptyHypConst(Config::EOSColName, baseConfig.getWordIndex()) != Config::EOSSymbol1)
{ {
machine.getTransitionSet().getTransition("EOS b.0")->apply(c); machine.getTransitionSet().getTransition("EOS b.0")->apply(baseConfig);
if (debug) if (debug)
{ {
fmt::print(stderr, "Forcing EOS transition\n"); fmt::print(stderr, "Forcing EOS transition\n");
c.printForDebug(stderr); baseConfig.printForDebug(stderr);
} }
} }
// Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script // Fill holes in important columns like "ID" and "HEAD" to be compatible with eval script
try {c.addMissingColumns();} try {baseConfig.addMissingColumns();}
catch (std::exception & e) {util::myThrow(e.what());} catch (std::exception & e) {util::myThrow(e.what());}
} }
baseConfig = beam[0];
}
float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const float Decoder::getMetricScore(const std::string & metric, std::size_t scoreIndex) const
{ {
auto found = evaluation.find(metric); auto found = evaluation.find(metric);
......
...@@ -28,6 +28,8 @@ class BaseConfig : public Config ...@@ -28,6 +28,8 @@ class BaseConfig : public Config
public : public :
BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename); BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename);
BaseConfig(const BaseConfig & other);
BaseConfig & operator=(const BaseConfig & other) = default;
std::size_t getNbColumns() const override; std::size_t getNbColumns() const override;
std::size_t getFirstLineIndex() const override; std::size_t getFirstLineIndex() const override;
......
...@@ -65,6 +65,9 @@ class Config ...@@ -65,6 +65,9 @@ class Config
protected : protected :
Config(const Utf8String & rawInput); Config(const Utf8String & rawInput);
Config(const Utf8String & rawInput, const Config & other);
Config(const Config & other) = delete;
virtual ~Config() = default;
public : public :
...@@ -97,7 +100,6 @@ class Config ...@@ -97,7 +100,6 @@ class Config
public : public :
virtual ~Config() {}
void print(FILE * dest) const; void print(FILE * dest) const;
void printForDebug(FILE * dest) const; void printForDebug(FILE * dest) const;
bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const; bool has(const std::string & colName, int lineIndex, int hypothesisIndex) const;
...@@ -158,7 +160,7 @@ class Config ...@@ -158,7 +160,7 @@ class Config
const std::vector<Transition *> & getAppliableSplitTransitions() const; const std::vector<Transition *> & getAppliableSplitTransitions() const;
const std::vector<int> & getAppliableTransitions() const; const std::vector<int> & getAppliableTransitions() const;
bool isExtraColumn(const std::string & colName) const; bool isExtraColumn(const std::string & colName) const;
void setStrategy(std::vector<std::string> & strategyDefinition); void setStrategy(const std::vector<std::string> & strategyDefinition);
Strategy & getStrategy(); Strategy & getStrategy();
}; };
......
...@@ -35,7 +35,7 @@ class ReadingMachine ...@@ -35,7 +35,7 @@ class ReadingMachine
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
TransitionSet & getSplitWordTransitionSet(); TransitionSet & getSplitWordTransitionSet();
bool hasSplitWordTransitionSet() const; bool hasSplitWordTransitionSet() const;
std::vector<std::string> & getStrategyDefinition(); const std::vector<std::string> & getStrategyDefinition() const;
Classifier * getClassifier(); Classifier * getClassifier();
bool isPredicted(const std::string & columnName) const; bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const; const std::set<std::string> & getPredicted() const;
......
...@@ -154,6 +154,10 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename) ...@@ -154,6 +154,10 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
std::fclose(file); std::fclose(file);
} }
BaseConfig::BaseConfig(const BaseConfig & other) : Config(rawInputUtf8, other), colIndex2Name(other.colIndex2Name), colName2Index(other.colName2Index), rawInputUtf8(other.rawInputUtf8)
{
}
BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) : Config(rawInputUtf8) BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) : Config(rawInputUtf8)
{ {
if (tsvFilename.empty() and rawFilename.empty()) if (tsvFilename.empty() and rawFilename.empty())
......
...@@ -5,6 +5,26 @@ Config::Config(const Utf8String & rawInput) : rawInput(&rawInput) ...@@ -5,6 +5,26 @@ Config::Config(const Utf8String & rawInput) : rawInput(&rawInput)
{ {
} }
Config::Config(const Utf8String & rawInput, const Config & other) : rawInput(&rawInput)
{
this->lines = other.lines;
this->predicted = other.predicted;
this->lastPoppedStack = other.lastPoppedStack;
this->lastAttached = other.lastAttached;
this->currentWordId = other.currentWordId;
this->appliableSplitTransitions = other.appliableSplitTransitions;
this->appliableTransitions = other.appliableTransitions;
this->strategy.reset(new Strategy(*other.strategy));
this->wordIndex = other.wordIndex;
this->characterIndex = other.characterIndex;
this->state = other.state;
this->history = other.history;
this->stack = other.stack;
this->extraColumns = this->extraColumns;
}
std::size_t Config::getIndexOfLine(int lineIndex) const std::size_t Config::getIndexOfLine(int lineIndex) const
{ {
return lineIndex * getNbColumns() * (nbHypothesesMax+1); return lineIndex * getNbColumns() * (nbHypothesesMax+1);
...@@ -727,7 +747,7 @@ std::size_t Config::getStackSize() const ...@@ -727,7 +747,7 @@ std::size_t Config::getStackSize() const
return stack.size(); return stack.size();
} }
void Config::setStrategy(std::vector<std::string> & strategyDefinition) void Config::setStrategy(const std::vector<std::string> & strategyDefinition)
{ {
strategy.reset(new Strategy(strategyDefinition)); strategy.reset(new Strategy(strategyDefinition));
} }
......
...@@ -121,7 +121,7 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet() ...@@ -121,7 +121,7 @@ TransitionSet & ReadingMachine::getSplitWordTransitionSet()
return *splitWordTransitionSet; return *splitWordTransitionSet;
} }
std::vector<std::string> & ReadingMachine::getStrategyDefinition() const std::vector<std::string> & ReadingMachine::getStrategyDefinition() const
{ {
return strategyDefinition; return strategyDefinition;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment