Skip to content
Snippets Groups Projects
Commit bf428a88 authored by Franck Dary's avatar Franck Dary
Browse files

Splited arc eager transitions in 2 kinds, one that does deprel and one who...

Splited arc eager transitions in 2 kinds, one that does deprel and one who doesnt. Also added deprel transition
parent c73a43d3
No related branches found
No related tags found
No related merge requests found
......@@ -60,6 +60,7 @@ class Action
static Action setMultiwordIds(int multiwordSize);
static Action split(int index);
static Action setRootUpdateIdsEmptyStackIfSentChanged();
static Action deprel(std::string value);
};
#endif
......@@ -45,6 +45,7 @@ class Config
std::vector<String> lines;
std::set<std::string> predicted;
int lastPoppedStack{-1};
int lastAttached{-1};
int currentWordId{0};
std::vector<Transition *> appliableSplitTransitions;
std::vector<int> appliableTransitions;
......@@ -141,6 +142,8 @@ class Config
void addPredicted(const std::set<std::string> & predicted);
bool isPredicted(const std::string & colName) const;
int getLastPoppedStack() const;
int getLastAttached() const;
void setLastAttached(int lastAttached);
int getCurrentWordId() const;
void setCurrentWordId(int currentWordId);
void addMissingColumns();
......
......@@ -19,8 +19,11 @@ class Transition
void initWrite(std::string colName, std::string object, std::string index, std::string value);
void initAdd(std::string colName, std::string object, std::string index, std::string value);
void initShift();
void initLeft(std::string label);
void initRight(std::string label);
void initEagerLeft_rel(std::string label);
void initEagerRight_rel(std::string label);
void initDeprel(std::string label);
void initEagerLeft();
void initEagerRight();
void initReduce();
void initEOS(int bufferIndex);
void initNothing();
......
......@@ -582,12 +582,16 @@ Action Action::attach(Config::Object governorObject, int governorIndex, Config::
addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, std::to_string(lineIndex)).apply(config, a);
addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, std::to_string(depIndex)).apply(config, a);
a.data.emplace_back(std::to_string(config.getLastAttached()));
config.setLastAttached(depIndex);
};
auto undo = [governorObject, governorIndex, dependentObject, dependentIndex](Config & config, Action & a)
{
addHypothesisRelative(Config::headColName, dependentObject, dependentIndex, "").undo(config, a);
addToHypothesisRelative(Config::childsColName, governorObject, governorIndex, "").apply(config, a);
config.setLastAttached(std::stoi(a.data.back()));
a.data.pop_back();
};
auto appliable = [governorObject, governorIndex, dependentObject, dependentIndex](const Config & config, const Action & action)
......@@ -740,3 +744,23 @@ Action Action::setRootUpdateIdsEmptyStackIfSentChanged()
return {Type::Write, apply, undo, appliable};
}
Action Action::deprel(std::string value)
{
auto apply = [value](Config & config, Action & a)
{
addHypothesis(Config::deprelColName, config.getLastAttached(), value).apply(config, a);
};
auto undo = [](Config & config, Action & a)
{
addHypothesis(Config::deprelColName, config.getLastAttached(), "").undo(config, a);
};
auto appliable = [](const Config & config, const Action & action)
{
return config.has(0,config.getLastAttached(),0);
};
return {Type::Write, apply, undo, appliable};
}
......@@ -524,6 +524,8 @@ const Config::String & Config::getHistory(int relativeIndex) const
std::size_t Config::getStack(int relativeIndex) const
{
if (relativeIndex == -1)
return getLastPoppedStack();
return stack[stack.size()-1-relativeIndex];
}
......@@ -534,6 +536,8 @@ bool Config::hasHistory(int relativeIndex) const
bool Config::hasStack(int relativeIndex) const
{
if (relativeIndex == -1)
return has(0,getLastPoppedStack(),0);
return relativeIndex >= 0 && relativeIndex < (int)stack.size();
}
......@@ -696,3 +700,13 @@ bool Config::isExtraColumn(const std::string & colName) const
return false;
}
int Config::getLastAttached() const
{
return lastAttached;
}
void Config::setLastAttached(int lastAttached)
{
this->lastAttached = lastAttached;
}
......@@ -13,10 +13,16 @@ Transition::Transition(const std::string & name)
[this](auto){initShift();}},
{std::regex("REDUCE"),
[this](auto){initReduce();}},
{std::regex("LEFT (.+)"),
[this](auto sm){(initLeft(sm[1]));}},
{std::regex("RIGHT (.+)"),
[this](auto sm){(initRight(sm[1]));}},
{std::regex("eager_LEFT_rel (.+)"),
[this](auto sm){(initEagerLeft_rel(sm[1]));}},
{std::regex("eager_RIGHT_rel (.+)"),
[this](auto sm){(initEagerRight_rel(sm[1]));}},
{std::regex("eager_LEFT"),
[this](auto){(initEagerLeft());}},
{std::regex("eager_RIGHT"),
[this](auto){(initEagerRight());}},
{std::regex("deprel (.+)"),
[this](auto sm){(initDeprel(sm[1]));}},
{std::regex("EOS b\\.(.+)"),
[this](auto sm){initEOS(std::stoi(sm[1]));}},
{std::regex("NOTHING"),
......@@ -257,7 +263,7 @@ void Transition::initShift()
};
}
void Transition::initLeft(std::string label)
void Transition::initEagerLeft_rel(std::string label)
{
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
......@@ -299,7 +305,45 @@ void Transition::initLeft(std::string label)
};
}
void Transition::initRight(std::string label)
void Transition::initEagerLeft()
{
sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
sequence.emplace_back(Action::popStack());
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
for (int i = wordIndex+1; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
}
//TODO : Check if this is necessary
if (stackGovIndex != std::to_string(wordIndex))
++cost;
return cost;
};
}
void Transition::initEagerRight_rel(std::string label)
{
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
......@@ -354,6 +398,57 @@ void Transition::initRight(std::string label)
};
}
void Transition::initEagerRight()
{
sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
sequence.emplace_back(Action::pushWordIndexOnStack());
sequence.emplace_back(Action::setRootUpdateIdsEmptyStackIfSentChanged());
cost = [](const Config & config)
{
auto stackIndex = config.getStack(0);
auto wordIndex = config.getWordIndex();
if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
return std::numeric_limits<int>::max();
int cost = 0;
auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);
for (int i = wordIndex; config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (bufferGovIndex == std::to_string(i))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
for (int i = 1; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
++cost;
}
//TODO : Check if this is necessary
if (bufferGovIndex != std::to_string(stackIndex))
++cost;
return cost;
};
}
void Transition::initReduce()
{
sequence.emplace_back(Action::assertIsNotEmpty(Config::headColName, Config::Object::Stack, 0));
......@@ -407,3 +502,13 @@ void Transition::initEOS(int bufferIndex)
};
}
void Transition::initDeprel(std::string label)
{
sequence.emplace_back(Action::deprel(label));
cost = [label](const Config & config)
{
return config.getConst(Config::deprelColName, config.getLastAttached(), 0) == label ? 0 : 1;
};
}
......@@ -102,6 +102,25 @@ Transition * TransitionSet::getBestAppliableTransition(const Config & c)
}
}
if (!result)
{
for (unsigned int i = 0; i < transitions.size(); i++)
{
fmt::print(stderr, "{}\n", transitions[i].getName());
if (!transitions[i].appliable(c))
{
fmt::print(stderr, "not appliable\n");
continue;
}
fmt::print(stderr, "appliable\n");
int cost = transitions[i].getCost(c);
fmt::print(stderr, "cost {}\n", cost);
}
}
return result;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment