Skip to content
Snippets Groups Projects
Commit 5d34471e authored by Franck Dary's avatar Franck Dary
Browse files

Refactored transitions

parent e7ee9509
No related branches found
No related tags found
No related merge requests found
......@@ -46,7 +46,7 @@ class Action
static Action addHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Config::Object object, int relativeIndex, const std::string & addition);
static Action pushWordIndexOnStack();
static Action popStack();
static Action popStack(int relIndex);
static Action emptyStack();
static Action setRoot(int bufferIndex);
static Action updateIds(int bufferIndex);
......
......@@ -89,6 +89,7 @@ class Config
const String & getAsFeature(int colIndex, int lineIndex) const;
ValueIterator getIterator(int colIndex, int lineIndex, int hypothesisIndex);
ConstValueIterator getConstIterator(int colIndex, int lineIndex, int hypothesisIndex) const;
std::size_t & getStackRef(int relativeIndex);
long getRelativeWordIndex(int relativeIndex) const;
......@@ -112,6 +113,7 @@ class Config
void addToHistory(const std::string & transition);
void addToStack(std::size_t index);
void popStack();
void swapStack(int relIndex1, int relIndex2);
bool isComment(std::size_t lineIndex) const;
bool isCommentPredicted(std::size_t lineIndex) const;
bool isMultiword(std::size_t lineIndex) const;
......@@ -134,6 +136,7 @@ class Config
bool hasRelativeWordIndex(Object object, int relativeIndex) const;
const String & getHistory(int relativeIndex) const;
std::size_t getStack(int relativeIndex) const;
std::size_t getStackSize() const;
bool hasHistory(int relativeIndex) const;
bool hasStack(int relativeIndex) const;
String getState() const;
......
......@@ -16,6 +16,11 @@ class Transition
private :
static int getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config);
static int getFirstIndexOfSentence(int baseIndex, const Config & config);
static int getLastIndexOfSentence(int baseIndex, const Config & config);
void initWrite(std::string colName, std::string object, std::string index, std::string value);
void initAdd(std::string colName, std::string object, std::string index, std::string value);
void initShift();
......
......@@ -268,23 +268,27 @@ Action Action::pushWordIndexOnStack()
return {Type::Push, apply, undo, appliable};
}
Action Action::popStack()
Action Action::popStack(int relIndex)
{
auto apply = [](Config & config, Action & a)
auto apply = [relIndex](Config & config, Action & a)
{
auto toSave = config.getStack(0);
auto toSave = config.getStack(relIndex);
a.data.push_back(std::to_string(toSave));
for (int i = 0; relIndex-1-i >= 0; i++)
config.swapStack(relIndex-i, relIndex-1-i);
config.popStack();
};
auto undo = [](Config & config, Action & a)
auto undo = [relIndex](Config & config, Action & a)
{
config.addToStack(std::stoi(a.data.back()));
for (int i = 0; i+1 <= relIndex; i++)
config.swapStack(i, i+1);
};
auto appliable = [](const Config & config, const Action &)
auto appliable = [relIndex](const Config & config, const Action &)
{
return config.hasStack(0) and config.getStack(0) != config.getWordIndex();
return config.hasStack(relIndex) and config.getStack(relIndex) != config.getWordIndex();
};
return {Type::Pop, apply, undo, appliable};
......
......@@ -357,6 +357,13 @@ void Config::popStack()
stack.pop_back();
}
void Config::swapStack(int relIndex1, int relIndex2)
{
int tmp = getStack(relIndex1);
getStackRef(relIndex1) = relIndex2;
getStackRef(relIndex2) = tmp;
}
bool Config::hasCharacter(int letterIndex) const
{
return letterIndex >= 0 and letterIndex < (int)util::getSize(rawInput);
......@@ -529,6 +536,11 @@ std::size_t Config::getStack(int relativeIndex) const
return stack[stack.size()-1-relativeIndex];
}
std::size_t & Config::getStackRef(int relativeIndex)
{
return stack[stack.size()-1-relativeIndex];
}
bool Config::hasHistory(int relativeIndex) const
{
return relativeIndex >= 0 && relativeIndex < (int)history.size();
......@@ -710,3 +722,8 @@ void Config::setLastAttached(int lastAttached)
this->lastAttached = lastAttached;
}
std::size_t Config::getStackSize() const
{
return stack.size();
}
......@@ -246,22 +246,7 @@ void Transition::initShift()
if (!config.isToken(config.getWordIndex()))
return 0;
auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
int cost = 0;
for (int i = 0; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto stackIndex = config.getStack(i);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
++cost;
}
return cost;
return getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
};
}
......@@ -269,34 +254,18 @@ void Transition::initEagerLeft_rel(std::string label)
{
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
sequence.emplace_back(Action::popStack());
sequence.emplace_back(Action::popStack(0));
cost = [label](const Config & config)
{
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = wordIndex+1; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
}
//TODO : Check if this is necessary
if (stackGovIndex != std::to_string(wordIndex))
++cost;
......@@ -310,34 +279,18 @@ void Transition::initEagerLeft_rel(std::string label)
void Transition::initEagerLeft()
{
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack());
sequence.emplace_back(Action::popStack(0));
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = wordIndex+1; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
}
int cost = getNbLinkedWith(wordIndex+1, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
//TODO : Check if this is necessary
if (stackGovIndex != std::to_string(wordIndex))
++cost;
......@@ -359,37 +312,13 @@ void Transition::initEagerRight_rel(std::string label)
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
for (int i = wordIndex; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (bufferGovIndex == std::to_string(i))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
for (int i = 1; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
++cost;
}
int cost = getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
//TODO : Check if this is necessary
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
......@@ -413,37 +342,13 @@ void Transition::initEagerRight()
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
for (int i = wordIndex; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (bufferGovIndex == std::to_string(i))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
for (int i = 1; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
++cost;
}
int cost = getNbLinkedWith(0, config.getStackSize()-1, Config::Object::Stack, config.getWordIndex(), config);
//TODO : Check if this is necessary
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
......@@ -454,33 +359,19 @@ void Transition::initEagerRight()
void Transition::initReduce_strict()
{
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack());
sequence.emplace_back(Action::popStack(0));
cost = [](const Config & config)
{
if (!config.isToken(config.getStack(0)))
return 0;
int cost = 0;
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
auto wordIndex = config.getWordIndex();
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
if (!config.isToken(stackIndex))
return 0;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
++cost;
return cost;
......@@ -489,33 +380,19 @@ void Transition::initReduce_strict()
void Transition::initReduce_relaxed()
{
sequence.emplace_back(Action::popStack());
sequence.emplace_back(Action::popStack(0));
cost = [](const Config & config)
{
if (!config.isToken(config.getStack(0)))
return 0;
int cost = 0;
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
auto wordIndex = config.getWordIndex();
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
if (!config.isToken(stackIndex))
return 0;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
int cost = getNbLinkedWith(wordIndex, getLastIndexOfSentence(wordIndex, config), Config::Object::Buffer, stackIndex, config);
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
++cost;
return cost;
......@@ -549,3 +426,63 @@ void Transition::initDeprel(std::string label)
};
}
int Transition::getNbLinkedWith(int firstIndex, int lastIndex, Config::Object object, int withIndex, const Config & config)
{
auto govIndex = config.getConst(Config::headColName, withIndex, 0);
int nbLinkedWith = 0;
for (int i = firstIndex; i <= lastIndex; ++i)
{
int index = i;
if (object == Config::Object::Stack)
index = config.getStack(i);
if (!config.isToken(index))
continue;
auto otherGovIndex = config.getConst(Config::headColName, index, 0);
if (govIndex == std::to_string(index) || otherGovIndex == std::to_string(withIndex))
++nbLinkedWith;
}
return nbLinkedWith;
}
int Transition::getFirstIndexOfSentence(int baseIndex, const Config & config)
{
int firstIndex = baseIndex;
for (int i = baseIndex; config.has(0, i, 0); --i)
{
if (!config.isToken(i))
continue;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
firstIndex = i;
}
return firstIndex;
}
int Transition::getLastIndexOfSentence(int baseIndex, const Config & config)
{
int lastIndex = baseIndex;
for (int i = baseIndex; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
lastIndex = i;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
return lastIndex;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment