Skip to content
Snippets Groups Projects
Commit 75b4d5ff authored by Franck Dary's avatar Franck Dary
Browse files

Removed memory bugs. Using mean probability over the last X transitions

parent bd1a31a8
No related branches found
No related tags found
No related merge requests found
......@@ -9,6 +9,7 @@
#include <filesystem>
#include <experimental/source_location>
#include <boost/flyweight.hpp>
#include <boost/circular_buffer.hpp>
#include "fmt/core.h"
#include "utf8.hpp"
#include "utf8string.hpp"
......@@ -90,6 +91,17 @@ std::string join(const std::string & delim, const std::vector<T> elems)
return result;
}
// Concatenates the elements of a circular buffer, in buffer order, into one
// string, placing `delim` between consecutive elements. Each element is
// rendered with fmt's default "{}" formatting. An empty buffer yields "".
template <typename T>
std::string join(const std::string & delim, const boost::circular_buffer<T> elems)
{
  std::string joined;
  const auto count = elems.size();
  for (std::size_t pos = 0; pos < count; ++pos)
  {
    // The delimiter goes before every element except the first one, which is
    // the same as putting it after every element except the last one.
    if (pos != 0)
      joined += delim;
    joined += fmt::format("{}", elems[pos]);
  }
  return joined;
}
};
template <>
......
......@@ -15,14 +15,16 @@ class Beam
public :
BaseConfig config;
int nextTransition;
float totalProbability;
std::string name;
int nextTransition{-1};
boost::circular_buffer<double> probabilities{500};
boost::circular_buffer<std::string> name{20};
float meanProbability{0.0};
bool ended{false};
public :
Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name);
Element(const BaseConfig & model, int nextTransition, const boost::circular_buffer<double> & probabilities, const boost::circular_buffer<std::string> & name);
Element(const BaseConfig & model);
};
private :
......
......@@ -25,7 +25,7 @@ class Decoder
public :
Decoder(ReadingMachine & machine);
void decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement);
void decode(BaseConfig & config, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement);
void evaluate(const Config & config, std::filesystem::path modelPath, const std::string goldTSV);
std::vector<std::pair<float,std::string>> getF1Scores(const std::set<std::string> & colNames) const;
std::vector<std::pair<float,std::string>> getAlignedAccs(const std::set<std::string> & colNames) const;
......
......@@ -5,10 +5,14 @@ Beam::Beam(std::size_t width, float threshold, BaseConfig & model, const Reading
model.setStrategy(machine.getStrategyDefinition());
model.addPredicted(machine.getPredicted());
model.setState(model.getStrategy().getInitialState());
elements.emplace_back(model, -1, 0.0, "0");
elements.emplace_back(model);
}
Beam::Element::Element(BaseConfig & model, int nextTransition, float totalProbability, std::string name) : config(model), nextTransition(nextTransition), totalProbability(totalProbability), name(name)
// Builds a beam element from a configuration to copy, the index of the
// transition to apply next, and the probability / name histories to start from.
// meanProbability and ended are not touched here and keep their in-class
// default initializers.
Beam::Element::Element(const BaseConfig & model,
                       int nextTransition,
                       const boost::circular_buffer<double> & probabilities,
                       const boost::circular_buffer<std::string> & name)
  : config(model),
    nextTransition(nextTransition),
    probabilities(probabilities),
    name(name)
{
}
// Builds a beam element that only copies the given configuration; every other
// member keeps its in-class default initializer.
Beam::Element::Element(const BaseConfig & model)
  : config(model)
{
}
......@@ -23,7 +27,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
auto currentNbElements = elements.size();
if (debug)
fmt::print(stderr, "{:-<{}}\nBEAM SEARCH CONTENT :\n", "", 80);
fmt::print(stderr, "{:*<{}}BEAM START{:*<{}}\n", "", 37, "", 36);
for (unsigned int index = 0; index < currentNbElements; index++)
{
......@@ -49,7 +53,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
for (unsigned int i = 0; i < prediction.size(0); i++)
{
float score = prediction[i].item<float>();
if (appliableTransitions[i] and score >= threshold)
if (appliableTransitions[i])
scoresOfTransitions.emplace_back(std::make_pair(score, i));
}
......@@ -61,18 +65,32 @@ void Beam::update(ReadingMachine & machine, bool debug)
std::sort(scoresOfTransitions.rbegin(), scoresOfTransitions.rend());
while (!scoresOfTransitions.empty() and scoresOfTransitions.back().first < threshold)
scoresOfTransitions.pop_back();
if (width > 1)
for (unsigned int i = 1; i < scoresOfTransitions.size(); i++)
{
elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].totalProbability + scoresOfTransitions[i].first, elements[index].name + ":" + std::to_string(scoresOfTransitions[i].second));
elements.emplace_back(elements[index].config, scoresOfTransitions[i].second, elements[index].probabilities, elements[index].name);
elements.back().name.push_back(std::to_string(i));
elements.back().probabilities.push_back(scoresOfTransitions[i].first);
elements.back().meanProbability = 0.0;
for (auto & p : elements.back().probabilities)
elements.back().meanProbability += p;
elements.back().meanProbability /= elements.back().probabilities.size();
}
elements[index].nextTransition = scoresOfTransitions[0].second;
elements[index].totalProbability += scoresOfTransitions[0].first;
elements[index].name += ":" + std::to_string(elements[index].nextTransition);
elements[index].probabilities.push_back(scoresOfTransitions[0].first);
elements[index].name.push_back("0");
elements[index].meanProbability = 0.0;
for (auto & p : elements[index].probabilities)
elements[index].meanProbability += p;
elements[index].meanProbability /= elements[index].probabilities.size();
if (debug)
{
fmt::print(stderr, "Element {:<3} Probability={} Name={}\n", index, elements[index].meanProbability, util::join(":", elements[index].name));
elements[index].config.printForDebug(stderr);
std::vector<std::pair<float,std::string>> toPrint;
for (unsigned int i = 0; i < prediction.size(0); i++)
......@@ -89,7 +107,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
std::sort(elements.begin(), elements.end(), [](const Element & a, const Element & b)
{
return a.totalProbability > b.totalProbability;
return a.meanProbability > b.meanProbability;
});
while (elements.size() > width)
......@@ -122,7 +140,7 @@ void Beam::update(ReadingMachine & machine, bool debug)
}
if (debug)
fmt::print(stderr, "END OF BEAM SEARCH CONTENT\n{:-<{}}\n", "", 80);
fmt::print(stderr, "{:*<{}}BEAM END{:*<{}}\n", "", 37, "", 38);
}
bool Beam::isEnded() const
......
......@@ -6,7 +6,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
{
}
void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug, bool printAdvancement)
void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, float beamThreshold, bool debug, bool printAdvancement)
{
constexpr int printInterval = 50;
......@@ -17,7 +17,7 @@ void Decoder::decode(BaseConfig & baseConfig, std::size_t beamSize, bool debug,
int nbExamplesProcessed = 0;
auto pastTime = std::chrono::high_resolution_clock::now();
Beam beam(beamSize, 0.1, baseConfig, machine);
Beam beam(beamSize, beamThreshold, baseConfig, machine);
try
{
......
......@@ -24,6 +24,8 @@ po::options_description MacaonDecode::getOptionsDescription()
("silent", "Don't print speed and progress")
("beamSize", po::value<int>()->default_value(1),
"Size of the beam during beam search")
("beamThreshold", po::value<float>()->default_value(0.1),
"Minimal probability an action must have to be considered in the beam search")
("help,h", "Produce this help message");
desc.add(req).add(opt);
......@@ -74,6 +76,7 @@ int MacaonDecode::main()
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
auto beamSize = variables["beamSize"].as<int>();
auto beamThreshold = variables["beamThreshold"].as<float>();
torch::globalContext().setBenchmarkCuDNN(true);
......@@ -89,7 +92,7 @@ int MacaonDecode::main()
BaseConfig config(mcdFile, inputTSV, inputTXT);
decoder.decode(config, beamSize, debug, printAdvancement);
decoder.decode(config, beamSize, beamThreshold, debug, printAdvancement);
config.print(stdout);
} catch(std::exception & e) {util::error(e);}
......
......@@ -17,8 +17,6 @@ class BaseConfig : public Config
std::vector<std::string> colIndex2Name;
std::unordered_map<std::string, int> colName2Index;
Utf8String rawInputUtf8;
private :
void readMCD(std::string_view mcdFilename);
......
......@@ -44,29 +44,29 @@ class Config
private :
std::vector<String> lines;
std::set<std::string> predicted;
int lastPoppedStack{-1};
int lastAttached{-1};
int currentWordId{0};
std::vector<Transition *> appliableSplitTransitions;
std::vector<int> appliableTransitions;
std::shared_ptr<Strategy> strategy;
protected :
const Utf8String * rawInput;
Utf8String rawInput;
std::size_t wordIndex{0};
std::size_t characterIndex{0};
String state{"NONE"};
boost::circular_buffer<String> history{10};
boost::circular_buffer<std::size_t> stack{50};
std::vector<std::string> extraColumns{isMultiColName, childsColName, sentIdColName, EOSColName};
std::set<std::string> predicted;
int lastPoppedStack{-1};
int lastAttached{-1};
int currentWordId{0};
std::vector<Transition *> appliableSplitTransitions;
std::vector<int> appliableTransitions;
std::shared_ptr<Strategy> strategy;
protected :
Config(const Utf8String & rawInput);
Config(const Utf8String & rawInput, const Config & other);
Config(const Config & other) = delete;
Config() = default;
Config & operator=(const Config & other) = default;
Config(const Config & other);
virtual ~Config() = default;
public :
......
......@@ -43,9 +43,9 @@ void BaseConfig::readRawInput(std::string_view rawFilename)
std::fclose(file);
rawInputUtf8 = util::splitAsUtf8(rawInputTemp);
rawInputUtf8.replace(util::utf8char("\n"), util::utf8char(" "));
rawInputUtf8.replace(util::utf8char("\t"), util::utf8char(" "));
rawInput = util::splitAsUtf8(rawInputTemp);
rawInput.replace(util::utf8char("\n"), util::utf8char(" "));
rawInput.replace(util::utf8char("\t"), util::utf8char(" "));
}
void BaseConfig::readTSVInput(std::string_view tsvFilename)
......@@ -154,11 +154,11 @@ void BaseConfig::readTSVInput(std::string_view tsvFilename)
std::fclose(file);
}
BaseConfig::BaseConfig(const BaseConfig & other) : Config(rawInputUtf8, other), colIndex2Name(other.colIndex2Name), colName2Index(other.colName2Index), rawInputUtf8(other.rawInputUtf8)
// Copy constructor: the Config base sub-object is copy-constructed from
// `other`, then both column lookup tables (index -> name and name -> index)
// are copied.
BaseConfig::BaseConfig(const BaseConfig & other)
  : Config(other),
    colIndex2Name(other.colIndex2Name),
    colName2Index(other.colName2Index)
{
}
BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename) : Config(rawInputUtf8)
BaseConfig::BaseConfig(std::string_view mcdFilename, std::string_view tsvFilename, std::string_view rawFilename)
{
if (tsvFilename.empty() and rawFilename.empty())
util::myThrow("tsvFilename and rawFilenames can't be both empty");
......
#include "Config.hpp"
#include "util.hpp"
Config::Config(const Utf8String & rawInput) : rawInput(&rawInput)
{
}
Config::Config(const Utf8String & rawInput, const Config & other) : rawInput(&rawInput)
Config::Config(const Config & other)
{
this->lines = other.lines;
this->predicted = other.predicted;
......@@ -17,6 +13,7 @@ Config::Config(const Utf8String & rawInput, const Config & other) : rawInput(&ra
this->strategy.reset(new Strategy(*other.strategy));
this->rawInput = other.rawInput;
this->wordIndex = other.wordIndex;
this->characterIndex = other.characterIndex;
this->state = other.state;
......@@ -151,8 +148,6 @@ void Config::printForDebug(FILE * dest) const
static constexpr int lettersWindowSize = 40;
static constexpr int maxWordLength = 7;
fmt::print(dest, "\n");
int firstLineToPrint = wordIndex;
int lastLineToPrint = wordIndex;
while (wordIndex-firstLineToPrint < windowSize and has(0, firstLineToPrint-1, 0))
......@@ -228,9 +223,9 @@ void Config::printForDebug(FILE * dest) const
if (!stackStr.empty())
stackStr.pop_back();
fmt::print(dest, "{}\n", longLine);
for (std::size_t index = characterIndex; index < util::getSize(*rawInput) and index - characterIndex < lettersWindowSize; index++)
for (std::size_t index = characterIndex; index < util::getSize(rawInput) and index - characterIndex < lettersWindowSize; index++)
fmt::print(dest, "{}", getLetter(index));
if (rawInput->size())
if (!util::isEmpty(rawInput))
fmt::print(dest, "\n{}\n", longLine);
fmt::print(dest, "State={}\nwordIndex={} characterIndex={}\nhistory=({})\nstack=({})\n", state, wordIndex, characterIndex, historyStr, stackStr);
fmt::print(dest, "{}\n", longLine);
......@@ -386,12 +381,12 @@ void Config::swapStack(int relIndex1, int relIndex2)
// Tells whether `letterIndex` is a valid position in the raw input, i.e.
// 0 <= letterIndex < util::getSize(rawInput).
bool Config::hasCharacter(int letterIndex) const
{
  if (letterIndex < 0)
    return false;

  // rawInput is held by value, so it is used directly (no dereference).
  // The comparison is done in the unsigned domain so that the input size is
  // never narrowed to int.
  return static_cast<std::size_t>(letterIndex) < util::getSize(rawInput);
}
// Returns the character of the raw input located at position `letterIndex`.
// No range check is performed here.
// NOTE(review): callers presumably validate the index with hasCharacter()
// before calling this — confirm.
util::utf8char Config::getLetter(int letterIndex) const
{
  // rawInput is held by value, so it is indexed directly (no dereference).
  return rawInput[letterIndex];
}
bool Config::isComment(std::size_t lineIndex) const
......@@ -515,20 +510,20 @@ bool Config::canMoveWordIndex(int relativeMovement) const
// Moves characterIndex by `relativeMovement`, clamping the result to the range
// [0, util::getSize(rawInput)].
// Returns true only when the whole movement could be applied, i.e. when no
// clamping was needed.
bool Config::moveCharacterIndex(int relativeMovement)
{
  // The target is computed in a signed type: characterIndex is unsigned, so
  // adding a negative movement that goes below zero would wrap around to a
  // huge value and end up clamped to the END of the input instead of 0.
  const long long inputSize = static_cast<long long>(util::getSize(rawInput));
  const long long wanted = static_cast<long long>(characterIndex) + relativeMovement;
  const long long clamped = std::max(0LL, std::min(wanted, inputSize));

  characterIndex = static_cast<std::size_t>(clamped);
  return clamped == wanted;
}
// Tells whether moving characterIndex by `relativeMovement` would land inside
// [0, util::getSize(rawInput)] without any clamping. Does not modify the
// configuration.
bool Config::canMoveCharacterIndex(int relativeMovement) const
{
  // Signed arithmetic, so that a movement going below zero is seen as negative
  // rather than as a wrapped-around unsigned value.
  const long long inputSize = static_cast<long long>(util::getSize(rawInput));
  const long long wanted = static_cast<long long>(characterIndex) + relativeMovement;

  return wanted >= 0 and wanted <= inputSize;
}
bool Config::rawInputOnlySeparatorsLeft() const
{
for (unsigned int i = characterIndex; i < rawInput->size(); i++)
if (!util::isSeparator((*rawInput)[i]))
for (unsigned int i = characterIndex; i < util::getSize(rawInput); i++)
if (!util::isSeparator(rawInput[i]))
return false;
return true;
......@@ -585,7 +580,7 @@ void Config::setState(const std::string state)
bool Config::stateIsDone() const
{
if (!rawInput->empty())
if (!util::isEmpty(rawInput))
return rawInputOnlySeparatorsLeft() and !has(0, wordIndex+1, 0) and !hasStack(0);
return !has(0, wordIndex+1, 0) and !hasStack(0);
......
#include "SubConfig.hpp"
SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize) : Config(*model.rawInput), model(model), spanSize(spanSize)
// Builds a SubConfig bound to `model` with the given span size.
// The members listed below are copied one by one from `model`; the strategy,
// when `model` has one, is duplicated into a freshly allocated Strategy.
// update() is called last, once every member has been copied.
SubConfig::SubConfig(BaseConfig & model, std::size_t spanSize)
  : model(model),
    spanSize(spanSize)
{
  // Raw text and reading position.
  rawInput = model.rawInput;
  wordIndex = model.wordIndex;
  characterIndex = model.characterIndex;

  // Current state, history and stack.
  state = model.state;
  history = model.history;
  stack = model.stack;

  // Column bookkeeping.
  extraColumns = model.extraColumns;
  predicted = model.predicted;

  // Counters and last-action markers.
  lastPoppedStack = model.lastPoppedStack;
  lastAttached = model.lastAttached;
  currentWordId = model.currentWordId;

  // Cached applicable transitions.
  appliableSplitTransitions = model.appliableSplitTransitions;
  appliableTransitions = model.appliableTransitions;

  // Only duplicate the strategy when the model actually owns one.
  if (model.strategy)
    strategy.reset(new Strategy(model.getStrategy()));

  update();
}
......
......@@ -178,7 +178,7 @@ int MacaonTrain::main()
if (computeDevScore)
{
auto devConfig = devGoldConfig;
decoder.decode(devConfig, 1, debug, printAdvancement);
decoder.decode(devConfig, 1, 0.0, debug, printAdvancement);
decoder.evaluate(devConfig, modelPath, devTsvFile);
devScores = decoder.getF1Scores(machine.getPredicted());
}
......@@ -217,7 +217,7 @@ int MacaonTrain::main()
machine.getClassifier()->saveOptimizer(optimizerCheckpoint);
if (printAdvancement)
fmt::print(stderr, "\r{:80}\r", "");
std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), loss, devScoresStr, saved ? "SAVED" : "");
std::string iterStr = fmt::format("[{}] Epoch {:^5} loss = {:6.4f} dev = {} {:5}", util::getTime(), fmt::format("{}/{}", currentEpoch+1, nbEpoch), 100.0*loss, devScoresStr, saved ? "SAVED" : "");
fmt::print(stderr, "{}\n", iterStr);
std::FILE * f = std::fopen(trainInfos.c_str(), "a");
fmt::print(f, "{}\t{}\n", iterStr, devScoreMean);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment