Skip to content
Snippets Groups Projects
Commit d377af89 authored by Franck Dary
Browse files

modernized EOS

parent 5ed6be65
Branches
No related tags found
No related merge requests found
......@@ -21,12 +21,6 @@ class Action
Check
};
enum Object
{
Buffer,
Stack
};
private :
Type type;
......@@ -44,22 +38,21 @@ class Action
public :
static Object str2object(const std::string & s);
static Action addLinesIfNeeded(int nbLines);
static Action moveWordIndex(int movement);
static Action moveCharacterIndex(int movement);
static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition);
static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
static Action pushWordIndexOnStack();
static Action popStack();
static Action emptyStack();
static Action setRoot();
static Action updateIds();
static Action setRoot(int bufferIndex);
static Action updateIds(int bufferIndex);
static Action endWord();
static Action assertIsEmpty(const std::string & colName);
static Action attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex);
static Action attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex);
static Action addCurCharToCurWord();
static Action ignoreCurrentCharacter();
static Action consumeCharacterIndex(util::utf8string consumed);
......
......@@ -21,11 +21,18 @@ class Config
static constexpr const char * headColName = "HEAD";
static constexpr const char * deprelColName = "DEPREL";
static constexpr const char * idColName = "ID";
static constexpr const char * sentIdColName = "SENTID";
static constexpr const char * isMultiColName = "MULTI";
static constexpr const char * childsColName = "CHILDS";
static constexpr int nbHypothesesMax = 1;
static constexpr int maxNbAppliableSplitTransitions = 8;
enum Object
{
Buffer,
Stack
};
public :
using String = boost::flyweight<std::string>;
......@@ -56,6 +63,8 @@ class Config
public :
static Object str2object(const std::string & s);
virtual std::size_t getNbColumns() const = 0;
virtual std::size_t getColIndex(const std::string & colName) const = 0;
virtual bool hasColIndex(const std::string & colName) const = 0;
......@@ -78,6 +87,8 @@ class Config
ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex);
ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const;
long getRelativeWordIndex(int relativeIndex) const;
public :
virtual ~Config() {}
......@@ -116,7 +127,8 @@ class Config
bool rawInputOnlySeparatorsLeft() const;
std::size_t getWordIndex() const;
std::size_t getCharacterIndex() const;
long getRelativeWordIndex(int relativeIndex) const;
long getRelativeWordIndex(Object object, int relativeIndex) const;
bool hasRelativeWordIndex(Object object, int relativeIndex) const;
const String & getHistory(int relativeIndex) const;
std::size_t getStack(int relativeIndex) const;
bool hasHistory(int relativeIndex) const;
......
......@@ -31,7 +31,7 @@ class Strategy
public :
Strategy(const std::vector<std::string_view> & lines);
Strategy(std::vector<std::string> lines);
std::pair<std::string, int> getMovement(const Config & c, const std::string & transition);
const std::string getInitialState() const;
void reset();
......
......@@ -22,7 +22,7 @@ class Transition
void initLeft(std::string label);
void initRight(std::string label);
void initReduce();
void initEOS();
void initEOS(int bufferIndex);
void initNothing();
void initIgnoreChar();
void initEndWord();
......
......@@ -53,11 +53,11 @@ Action Action::setMultiwordIds(int multiwordSize)
{
auto apply = [multiwordSize](Config & config, Action & a)
{
addHypothesisRelative(Config::idColName, Object::Buffer, 0, fmt::format("{}-{}", config.getCurrentWordId()+1, config.getCurrentWordId()+multiwordSize)).apply(config, a);
addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, fmt::format("{}-{}", config.getCurrentWordId()+1, config.getCurrentWordId()+multiwordSize)).apply(config, a);
for (int i = 0; i < multiwordSize; i++)
{
addHypothesisRelative(Config::idColName, Object::Buffer, i+1, fmt::format("{}", config.getCurrentWordId()+1+i)).apply(config, a);
addHypothesisRelative(Config::isMultiColName, Object::Buffer, i+1, Config::EOSSymbol1).apply(config, a);
addHypothesisRelative(Config::idColName, Config::Object::Buffer, i+1, fmt::format("{}", config.getCurrentWordId()+1+i)).apply(config, a);
addHypothesisRelative(Config::isMultiColName, Config::Object::Buffer, i+1, Config::EOSSymbol1).apply(config, a);
}
};
......@@ -176,80 +176,58 @@ Action Action::addToHypothesis(const std::string & colName, std::size_t lineInde
return {Type::Write, apply, undo, appliable};
}
Action Action::addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition)
Action Action::addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition)
{
auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else
lineIndex = config.getStack(relativeIndex);
int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
return addToHypothesis(colName, lineIndex, addition).apply(config, a);
};
auto undo = [colName, object, relativeIndex](Config & config, Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else
lineIndex = config.getStack(relativeIndex);
int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
return addToHypothesis(colName, lineIndex, "").undo(config, a);
};
auto appliable = [colName, object, relativeIndex, addition](const Config & config, const Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else if (config.hasStack(relativeIndex))
lineIndex = config.getStack(relativeIndex);
else
if (!config.hasRelativeWordIndex(object, relativeIndex))
return false;
int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
return addToHypothesis(colName, lineIndex, addition).appliable(config, a);
};
return {Type::Write, apply, undo, appliable};
}
Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis)
Action Action::addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis)
{
auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else
lineIndex = config.getStack(relativeIndex);
int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
return addHypothesis(colName, lineIndex, hypothesis).apply(config, a);
};
auto undo = [colName, object, relativeIndex](Config & config, Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else
lineIndex = config.getStack(relativeIndex);
int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
return addHypothesis(colName, lineIndex, "").undo(config, a);
};
auto appliable = [colName, object, relativeIndex](const Config & config, const Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else if (config.hasStack(relativeIndex))
lineIndex = config.getStack(relativeIndex);
else
if (!config.hasRelativeWordIndex(object, relativeIndex))
return false;
int lineIndex = config.getRelativeWordIndex(object, relativeIndex);
return addHypothesis(colName, lineIndex, "").appliable(config, a);
};
......@@ -317,7 +295,7 @@ Action Action::endWord()
auto apply = [](Config & config, Action & a)
{
config.setCurrentWordId(config.getCurrentWordId()+1);
addHypothesisRelative(Config::idColName, Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a);
addHypothesisRelative(Config::idColName, Config::Object::Buffer, 0, std::to_string(config.getCurrentWordId())).apply(config, a);
if (!config.rawInputOnlySeparatorsLeft() and !config.has(0,config.getWordIndex()+1,0))
config.addLines(1);
......@@ -442,14 +420,14 @@ Action Action::addCurCharToCurWord()
return {Type::Write, apply, undo, appliable};
}
Action Action::setRoot()
Action Action::setRoot(int bufferIndex)
{
auto apply = [](Config & config, Action & a)
auto apply = [bufferIndex](Config & config, Action & a)
{
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
int rootIndex = -1;
int firstSentenceIndex = -1;
for (int i = config.getStack(0); true; --i)
for (int i = lineIndex; true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -460,19 +438,17 @@ Action Action::setRoot()
if (!config.isTokenPredicted(i))
continue;
if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1)
if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1)
break;
firstSentenceIndex = i;
if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, i)))
if (util::isEmpty(config.getAsFeature(Config::headColName, i)))
{
rootIndex = i;
a.data.push_back(std::to_string(i));
}
}
for (int i = config.getStack(0); true; --i)
for (int i = lineIndex; true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -483,10 +459,10 @@ Action Action::setRoot()
if (!config.isTokenPredicted(i))
continue;
if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1)
if (config.getAsFeature(Config::EOSColName, i) == Config::EOSSymbol1)
break;
if (util::isEmpty(config.getLastNotEmptyHypConst(Config::headColName, i)))
if (util::isEmpty(config.getAsFeature(Config::headColName, i)))
{
if (i == rootIndex)
{
......@@ -498,11 +474,6 @@ Action Action::setRoot()
config.getFirstEmpty(Config::headColName, i) = std::to_string(rootIndex);
}
}
else
{
if (std::stoi(config.getLastNotEmptyHypConst(Config::headColName, i)) < firstSentenceIndex)
config.getFirstEmpty(Config::headColName, i) = std::to_string(rootIndex);
}
}
};
......@@ -516,20 +487,23 @@ Action Action::setRoot()
}
};
auto appliable = [](const Config & config, const Action &)
auto appliable = [bufferIndex](const Config & config, const Action &)
{
return config.hasStack(0) and config.isTokenPredicted(config.getStack(0)) and config.getLastNotEmptyConst(Config::isMultiColName, config.getStack(0)) != Config::EOSSymbol1;
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
return config.has(0,lineIndex,0) and config.isTokenPredicted(lineIndex) and config.getAsFeature(Config::isMultiColName, lineIndex) != Config::EOSSymbol1;
};
return {Type::Write, apply, undo, appliable};
}
Action Action::updateIds()
Action Action::updateIds(int bufferIndex)
{
auto apply = [](Config & config, Action & a)
auto apply = [bufferIndex](Config & config, Action & a)
{
int lineIndex = config.getRelativeWordIndex(Config::Object::Buffer, bufferIndex);
int firstIndexOfSentence = -1;
for (int i = config.getStack(0); true; --i)
int lastSentId = -1;
for (int i = lineIndex; true; --i)
{
if (!config.has(0, i, 0))
{
......@@ -541,7 +515,10 @@ Action Action::updateIds()
continue;
if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1)
{
lastSentId = std::stoi(config.getAsFeature(Config::sentIdColName, i));
break;
}
firstIndexOfSentence = i;
}
......@@ -549,18 +526,17 @@ Action Action::updateIds()
if (firstIndexOfSentence < 0)
util::myThrow("could not find any token in current sentence");
for (unsigned int i = firstIndexOfSentence, currentId = 1; i <= config.getStack(0); ++i)
for (int i = firstIndexOfSentence, currentId = 1; i <= lineIndex; ++i)
{
if (config.isComment(i) || config.isEmptyNode(i))
continue;
if (config.getLastNotEmptyHypConst(Config::EOSColName, i) == Config::EOSSymbol1)
break;
if (config.isMultiwordPredicted(i))
config.getFirstEmpty(Config::idColName, i) = fmt::format("{}-{}", currentId, currentId+config.getMultiwordSizePredicted(i));
else
config.getFirstEmpty(Config::idColName, i) = fmt::format("{}", currentId++);
config.getFirstEmpty(Config::sentIdColName, i) = fmt::format("{}", lastSentId+1);
}
};
......@@ -577,20 +553,12 @@ Action Action::updateIds()
return {Type::Write, apply, undo, appliable};
}
Action Action::attach(Object governorObject, int governorIndex, Object dependentObject, int dependentIndex)
Action Action::attach(Config::Object governorObject, int governorIndex, Config::Object dependentObject, int dependentIndex)
{
auto apply = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a)
{
int lineIndex = 0;
if (governorObject == Object::Buffer)
lineIndex = config.getWordIndex() + governorIndex;
else
lineIndex = config.getStack(governorIndex);
int depIndex = 0;
if (dependentObject == Object::Buffer)
depIndex = config.getWordIndex() + dependentIndex;
else
depIndex = config.getStack(dependentIndex);
long lineIndex = config.getRelativeWordIndex(governorObject, governorIndex);
long depIndex = config.getRelativeWordIndex(dependentObject, dependentIndex);
addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a);
addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a);
......@@ -604,35 +572,16 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action)
{
int govLineIndex = 0;
if (governorObject == Object::Buffer)
{
govLineIndex = config.getWordIndex() + governorIndex;
if (!config.has(0, govLineIndex, 0))
if (!config.hasRelativeWordIndex(governorObject, governorIndex) or !config.hasRelativeWordIndex(dependentObject, dependentIndex))
return false;
}
else
{
if (!config.hasStack(governorIndex))
return false;
govLineIndex = config.getStack(governorIndex);
}
long govLineIndex = config.getRelativeWordIndex(governorObject, governorIndex);
long depLineIndex = config.getRelativeWordIndex(dependentObject, dependentIndex);
int depLineIndex = 0;
if (dependentObject == Object::Buffer)
{
depLineIndex = config.getWordIndex() + dependentIndex;
if (!config.has(0, depLineIndex, 0))
return false;
}
else
{
if (!config.hasStack(dependentIndex))
if (!config.isTokenPredicted(govLineIndex) or !config.isTokenPredicted(depLineIndex))
return false;
depLineIndex = config.getStack(dependentIndex);
}
if (!config.isTokenPredicted(govLineIndex) or !config.isTokenPredicted(depLineIndex))
// Check if dep and head belongs to the same sentence
if (config.getAsFeature(Config::sentIdColName, govLineIndex) != config.getAsFeature(Config::sentIdColName, depLineIndex))
return false;
// Check for cycles
......@@ -640,7 +589,7 @@ Action Action::attach(Object governorObject, int governorIndex, Object dependent
{
try
{
govLineIndex = std::stoi(config.getLastNotEmptyHypConst(Config::headColName, govLineIndex));
govLineIndex = std::stoi(config.getAsFeature(Config::headColName, govLineIndex));
} catch(std::exception &) {return true;}
}
......@@ -677,14 +626,3 @@ Action Action::split(int index)
return {Type::Write, apply, undo, appliable};
}
// Converts the textual object designator used in transition descriptions
// into the matching Object enumerator: "b" selects the buffer, "s" the stack.
// Any other string is reported through util::myThrow.
Action::Object Action::str2object(const std::string & s)
{
  if (s == "s")
    return Object::Stack;
  if (s == "b")
    return Object::Buffer;

  util::myThrow(fmt::format("Invalid object '{}'", s));

  // Presumably never reached if util::myThrow throws (TODO confirm);
  // kept so that every control path returns a value.
  return Object::Buffer;
}
......@@ -38,6 +38,11 @@ void BaseConfig::readMCD(std::string_view mcdFilename)
colIndex2Name.emplace_back(childsColName);
colName2Index.emplace(childsColName, colIndex2Name.size()-1);
if (colName2Index.count(sentIdColName))
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, sentIdColName));
colIndex2Name.emplace_back(sentIdColName);
colName2Index.emplace(sentIdColName, colIndex2Name.size()-1);
if (colName2Index.count(EOSColName))
util::myThrow(fmt::format("mcd '{}' must not contain column '{}'", mcdFilename, EOSColName));
colIndex2Name.emplace_back(EOSColName);
......
......@@ -637,6 +637,22 @@ long Config::getRelativeWordIndex(int relativeIndex) const
return -1;
}
// Resolves a position expressed relatively to either the buffer or the stack
// into a line index.
//  - Object::Buffer : forwards to the single-argument getRelativeWordIndex overload.
//  - otherwise (Object::Stack) : returns getStack(relativeIndex), converted to long.
// No existence check is made here; hasRelativeWordIndex is the companion
// predicate for that.
long Config::getRelativeWordIndex(Object object, int relativeIndex) const
{
  if (object != Object::Buffer)
    return getStack(relativeIndex);

  return getRelativeWordIndex(relativeIndex);
}
// Tells whether a position expressed relatively to the buffer or the stack
// designates something that exists.
//  - Object::Buffer : the index computed by the single-argument
//    getRelativeWordIndex overload is tested with has(0, index, 0)
//    (presumably column 0 / hypothesis 0 -- TODO confirm against has()).
//  - otherwise (Object::Stack) : delegates to hasStack(relativeIndex).
bool Config::hasRelativeWordIndex(Object object, int relativeIndex) const
{
  if (object != Object::Buffer)
    return hasStack(relativeIndex);

  return has(0,getRelativeWordIndex(relativeIndex),0);
}
void Config::setAppliableSplitTransitions(const std::vector<Transition *> & appliableSplitTransitions)
{
this->appliableSplitTransitions = appliableSplitTransitions;
......@@ -647,3 +663,14 @@ const std::vector<Transition *> & Config::getAppliableSplitTransitions() const
return appliableSplitTransitions;
}
// Converts a textual object designator into the matching Object enumerator:
// "b" selects the buffer, "s" the stack.
// Any other string is reported through util::myThrow.
Config::Object Config::str2object(const std::string & s)
{
  if (s == "s")
    return Object::Stack;
  if (s == "b")
    return Object::Buffer;

  util::myThrow(fmt::format("Invalid object '{}'", s));

  // Presumably never reached if util::myThrow throws (TODO confirm);
  // kept so that every control path returns a value.
  return Object::Buffer;
}
......@@ -104,7 +104,7 @@ void ReadingMachine::readFromFile(std::filesystem::path path)
}))
util::myThrow("No predictions specified");
auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
auto restOfFile = std::vector<std::string>(lines.begin()+curLine, lines.end());
strategy.reset(new Strategy(restOfFile));
......
#include "Strategy.hpp"
Strategy::Strategy(const std::vector<std::string_view> & lines)
Strategy::Strategy(std::vector<std::string> lines)
{
if (!util::doIfNameMatch(std::regex("Strategy : ((incremental)|(sequential))"), lines[0], [this](auto sm)
{type = sm[1] == "sequential" ? Type::Sequential : Type::Incremental;}))
......@@ -8,6 +8,7 @@ Strategy::Strategy(const std::vector<std::string_view> & lines)
for (unsigned int i = 1; i < lines.size(); i++)
{
std::replace(lines[i].begin(), lines[i].end(), '\t', ' ');
auto splited = util::split(lines[i], ' ');
std::pair<std::string, std::string> key;
std::string value;
......
......@@ -17,8 +17,8 @@ Transition::Transition(const std::string & name)
[this](auto sm){(initLeft(sm[1]));}},
{std::regex("RIGHT (.+)"),
[this](auto sm){(initRight(sm[1]));}},
{std::regex("EOS"),
[this](auto){initEOS();}},
{std::regex("EOS b\\.(.+)"),
[this](auto sm){initEOS(std::stoi(sm[1]));}},
{std::regex("NOTHING"),
[this](auto){initNothing();}},
{std::regex("IGNORECHAR"),
......@@ -89,18 +89,14 @@ const std::string & Transition::getName() const
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
auto objectValue = Action::str2object(object);
auto objectValue = Config::str2object(object);
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = 0;
if (objectValue == Action::Object::Buffer)
lineIndex = config.getWordIndex() + indexValue;
else
lineIndex = config.getStack(indexValue);
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
if (config.getConst(colName, lineIndex, 0) == value)
return 0;
......@@ -111,18 +107,14 @@ void Transition::initWrite(std::string colName, std::string object, std::string
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
auto objectValue = Action::str2object(object);
auto objectValue = Config::str2object(object);
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = 0;
if (objectValue == Action::Object::Buffer)
lineIndex = config.getWordIndex() + indexValue;
else
lineIndex = config.getStack(indexValue);
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
......@@ -198,7 +190,7 @@ void Transition::initSplitWord(std::vector<std::string> words)
sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
for (unsigned int i = 0; i < words.size(); i++)
sequence.emplace_back(Action::addHypothesisRelative("FORM", Action::Object::Buffer, i, words[i]));
sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
cost = [words](const Config & config)
......@@ -266,8 +258,8 @@ void Transition::initShift()
void Transition::initLeft(std::string label)
{
sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack());
cost = [label](const Config & config)
......@@ -308,8 +300,8 @@ void Transition::initLeft(std::string label)
void Transition::initRight(std::string label)
{
sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
sequence.emplace_back(Action::pushWordIndexOnStack());
cost = [label](const Config & config)
......@@ -395,40 +387,19 @@ void Transition::initReduce()
};
}
void Transition::initEOS()
void Transition::initEOS(int bufferIndex)
{
sequence.emplace_back(Action::setRoot());
sequence.emplace_back(Action::updateIds());
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
sequence.emplace_back(Action::emptyStack());
sequence.emplace_back(Action::setRoot(bufferIndex));
sequence.emplace_back(Action::updateIds(bufferIndex));
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
cost = [](const Config & config)
cost = [bufferIndex](const Config & config)
{
if (!config.has(0, config.getStack(0), 0))
return std::numeric_limits<int>::max();
if (!config.isToken(config.getStack(0)))
return std::numeric_limits<int>::max();
if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
int cost = 0;
--cost;
for (int i = 0; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGovPred = config.getAsFeature(Config::headColName, otherStackIndex);
if (util::isEmpty(otherStackGovPred))
++cost;
}
return cost;
return 0;
};
}
......@@ -43,14 +43,14 @@ class Trainer
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
void fillDicts(SubConfig & config);
void fillDicts(SubConfig & config, bool debug);
public :
Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void fillDicts(BaseConfig & goldConfig);
void fillDicts(BaseConfig & goldConfig, bool debug);
float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
};
......
......@@ -114,7 +114,7 @@ int MacaonTrain::main()
if (machine.dictsAreNew())
{
trainer.fillDicts(goldConfig);
trainer.fillDicts(goldConfig, debug);
for (auto & it : machine.getDicts())
{
std::size_t originalSize = it.second.size();
......
......@@ -270,7 +270,7 @@ void Trainer::Examples::addClass(int goldIndex)
classes.emplace_back(gold);
}
void Trainer::fillDicts(BaseConfig & goldConfig)
void Trainer::fillDicts(BaseConfig & goldConfig, bool debug)
{
SubConfig config(goldConfig, goldConfig.getNbLines());
......@@ -280,13 +280,13 @@ void Trainer::fillDicts(BaseConfig & goldConfig)
machine.trainMode(false);
machine.setDictsState(Dict::State::Open);
fillDicts(config);
fillDicts(config, debug);
for (auto & it : machine.getDicts())
it.second.countOcc(false);
}
void Trainer::fillDicts(SubConfig & config)
void Trainer::fillDicts(SubConfig & config, bool debug)
{
torch::AutoGradMode useGrad(false);
......@@ -297,6 +297,9 @@ void Trainer::fillDicts(SubConfig & config)
while (true)
{
if (debug)
config.printForDebug(stderr);
if (machine.hasSplitWordTransitionSet())
config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));
......@@ -321,6 +324,8 @@ void Trainer::fillDicts(SubConfig & config)
config.addToHistory(goldTransition->getName());
auto movement = machine.getStrategy().getMovement(config, goldTransition->getName());
if (debug)
fmt::print(stderr, "(Transition,Newstate,Movement) = ({},{},{})\n", goldTransition->getName(), movement.first, movement.second);
if (movement == Strategy::endMovement)
break;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment