Newer
Older
// Builds a transition from its textual name.
// `name` is first matched against nameRegex "(<(.+)> )?(.+)": the optional
// "<...>" prefix (group 2) is stored in `state` (empty when absent) and the
// remainder (group 3) in `name`. That remainder is then matched against the
// transition regexes below; the matching init* helper fills `sequence` and
// `cost`. Errors are reported through util::myThrow, with the offending name.
//
// NOTE(review): this block appears damaged (it looks copied from a diff web
// view) and does not compile as shown — restore it from version control:
//  - the `catch` near the end has no matching `try {`;
//  - eosRegex and nothingRegex are used but never declared here;
//  - there is no `return;` between consecutive `if`s, so they nest instead of
//    forming a dispatch chain (e.g. a WRITE name only returns if it also
//    matches addRegex, otherwise it falls through to the "no match" throw);
//  - the "Franck Dary" / "committed" lines are not C++.
Transition::Transition(const std::string & name)
{
std::regex nameRegex("(<(.+)> )?(.+)");
std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
std::regex addRegex("ADD ([bs])\\.(.+) (.+) (.+)");
std::regex shiftRegex("SHIFT");
std::regex reduceRegex("REDUCE");
std::regex leftRegex("LEFT (.+)");
std::regex rightRegex("RIGHT (.+)");
// Split the optional "<state> " prefix from the transition name proper.
if (!util::doIfNameMatch(nameRegex, name, [this, name](auto sm)
{
this->state = sm[2];
this->name = sm[3];
}))
util::myThrow("doesn't match nameRegex");
// Dispatch on the transition family. Capture groups are forwarded in the
// order (column, object letter, index, value) for WRITE/ADD.
if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
if (util::doIfNameMatch(addRegex, this->name, [this](auto sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}))
return;
if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();}))
if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();}))
if (util::doIfNameMatch(leftRegex, this->name, [this](auto sm){initLeft(sm[1]);}))
if (util::doIfNameMatch(rightRegex, this->name, [this](auto sm){initRight(sm[1]);}))
if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();}))
if (util::doIfNameMatch(nothingRegex, this->name, [this](auto){initNothing();}))
return;
Franck Dary
committed
// No family matched: signal it so the catch below can add the name.
throw std::invalid_argument("no match");
} catch (std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));}
Franck Dary
committed
}
// Executes this transition on `config`: every action stored in `sequence`
// is run in insertion order, and each one is handed a reference to itself.
void Transition::apply(Config & config)
{
  for (Action & step : sequence)
  {
    step.apply(config, step);
  }
}
// Returns true when this transition may be applied to `config`:
//  - if the transition is restricted to a state, `config` must be in it;
//  - every action of `sequence` must report itself appliable.
bool Transition::appliable(const Config & config) const
{
  const bool stateRestricted = !state.empty();
  if (stateRestricted && state != config.getState())
    return false;

  for (const Action & step : sequence)
  {
    const bool ok = step.appliable(config, step);
    if (!ok)
      return false;
  }

  return true;
}
// Evaluates the cost function that one of the init* helpers stored in `cost`
// for the given configuration.
int Transition::getCost(const Config & config) const
{
  const int value = cost(config);
  return value;
}
const std::string & Transition::getName() const
{
return name;
}
Franck Dary
committed
// Initializes a WRITE transition.
// colName: column whose hypothesis is written.
// object:  "b" (buffer) or "s" (stack), converted with Action::str2object.
// index:   position relative to that object; parsed with std::stoi, which
//          throws std::invalid_argument / std::out_of_range if malformed.
// value:   the value written.
// The single action overwrites the hypothesis of `colName` at the targeted
// line. The cost is 0 when the reference (getConst) value at that line equals
// `value`, and 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Action::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    // Resolve the relative (object, index) pair to an absolute line index.
    int lineIndex = 0;
    if (objectValue == Action::Object::Buffer)
      lineIndex = config.getWordIndex() + indexValue;
    else
      lineIndex = config.getStack(indexValue);

    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;

    return 1;
  };
}
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
// Initializes an ADD transition.
// colName: column whose hypothesis is extended.
// object:  "b" (buffer) or "s" (stack), converted with Action::str2object.
// index:   position relative to that object (parsed with std::stoi).
// value:   the value added to the hypothesis.
// The cost is 0 when `value` is one of the '|'-separated reference values of
// `colName` at the targeted line, and 1 otherwise.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
  const auto target = Action::str2object(object);
  const int relIndex = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, target, relIndex, value));

  cost = [colName, target, relIndex, value](const Config & config)
  {
    int line = 0;
    if (target != Action::Object::Buffer)
      line = config.getStack(relIndex);
    else
      line = config.getWordIndex() + relIndex;

    for (const auto & goldPart : util::split(config.getConst(colName, line, 0).get(), '|'))
    {
      if (goldPart == value)
        return 0;
    }

    return 1;
  };
}
// Initializes a transition that adds no action to `sequence` and whose cost
// is always 0.
void Transition::initNothing()
{
  cost = [](const Config &) { return 0; };
}
// NOTE(review): the function header for this body is missing here. It is
// presumably `void Transition::initShift()` followed by `{` (the constructor
// calls initShift()). Restore it from version control. The "Franck Dary" /
// "committed" lines below are not C++ either.
//
// Pushes the current word index onto the stack.
sequence.emplace_back(Action::pushWordIndexOnStack());
// Cost:
//  - INT_MAX when stack[0] exists and is marked EOSSymbol1 in EOSColName;
//  - otherwise, +1 for each stack element present in the configuration whose
//    reference head is the current word, or that is the current word's
//    reference head.
cost = [](const Config & config)
{
if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
int cost = 0;
for (int i = 0; config.hasStack(i); ++i)
{
Franck Dary
committed
if (!config.has(0, config.getStack(i), 0))
continue;
auto stackIndex = config.getStack(i);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
++cost;
}
return cost;
};
}
// Initializes a LEFT <label> transition. It has three actions:
//  1. attach buffer[0] and stack[0], with stack[0] as the dependent;
//  2. record `label` as stack[0]'s relation (deprelColName);
//  3. pop stack[0].
// Cost:
//  - INT_MAX when stack[0] or the current word is not a token;
//  - +1 per token after the current word, scanning up to (but excluding) the
//    next token marked EOSSymbol1, that is stack[0]'s reference head or whose
//    reference head is stack[0];
//  - +1 if stack[0]'s reference head is not the current word;
//  - +1 if `label` differs from stack[0]'s reference relation.
// NOTE(review): unlike initRight/initReduce, the EOS check here happens
// *before* the token is counted. So an arc between stack[0] and the
// EOS-marked token itself is never counted, and if the current word is the
// EOS token the scan continues into the next sentence. Confirm this is
// intended.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    const auto s0 = config.getStack(0);
    const auto b0 = config.getWordIndex();

    if (!config.isToken(s0) || !config.isToken(b0))
      return std::numeric_limits<int>::max();

    int penalty = 0;
    auto s0Gov = config.getConst(Config::headColName, s0, 0);

    for (int i = b0 + 1; config.has(0, i, 0); ++i)
    {
      if (config.isToken(i))
      {
        if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
          break;

        auto iGov = config.getConst(Config::headColName, i, 0);
        if (s0Gov == std::to_string(i) || iGov == std::to_string(s0))
          ++penalty;
      }
    }

    //TODO : Check if this is necessary
    if (s0Gov != std::to_string(b0))
      ++penalty;
    if (label != config.getConst(Config::deprelColName, s0, 0))
      ++penalty;

    return penalty;
  };
}
// Initializes a RIGHT <label> transition. It has three actions:
//  1. attach stack[0] and buffer[0], with buffer[0] as the dependent;
//  2. record `label` as buffer[0]'s relation (deprelColName);
//  3. push the current word index onto the stack.
// Cost:
//  - INT_MAX when stack[0] or the current word is not a token;
//  - +1 per token from the current word to the end of its sentence (the
//    EOSSymbol1-marked token is counted before stopping) that is the current
//    word's reference head or whose reference head is the current word;
//  - +1 per stack element below stack[0] that is present in the configuration
//    and is the current word's reference head or has it as reference head;
//  - +1 if the current word's reference head is not stack[0];
//  - +1 if `label` differs from the current word's reference relation.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    // Named `total` rather than `cost` so it does not shadow the member.
    int total = 0;
    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

    // Buffer side: current word up to and including the sentence-final token.
    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
      if (bufferGovIndex == std::to_string(i) || otherGovIndex == std::to_string(wordIndex))
        ++total;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    // Stack side: every element below stack[0].
    for (int i = 1; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);
      if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
        ++total;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++total;
    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++total;

    return total;
  };
}
// Initializes a REDUCE transition: its single action pops stack[0].
//
// NOTE(review): this block appears damaged and does not compile as shown.
// Restore it from version control:
//  - the `if (!config.isToken(config.getStack(0)))` below has lost its own
//    statement, so it now guards the `auto stackIndex` declaration;
//  - `cost` is incremented but never declared inside the captureless lambda;
//  - the function's closing brace after the lambda is missing;
//  - the "Franck Dary" / "committed" lines are not C++.
void Transition::initReduce()
{
sequence.emplace_back(Action::popStack());
// Cost (as far as the visible code shows):
//  - 0 when stack[0] has no line in the configuration;
//  - +1 per token from the current word to the end of its sentence (the
//    EOSSymbol1-marked token is counted before stopping) that is stack[0]'s
//    reference head or whose reference head is stack[0];
//  - +1 if stack[0] itself is marked EOSSymbol1.
cost = [](const Config & config)
{
Franck Dary
committed
if (!config.has(0, config.getStack(0), 0))
return 0;
if (!config.isToken(config.getStack(0)))
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, config.getStack(0), 0);
for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
{
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
++cost;
return cost;
};
void Transition::initEOS()
{
sequence.emplace_back(Action::setRoot());
sequence.emplace_back(Action::updateIds());
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
sequence.emplace_back(Action::emptyStack());
cost = [](const Config & config)
{
if (!config.has(0, config.getStack(0), 0))
return std::numeric_limits<int>::max();
if (!config.isToken(config.getStack(0)))
return std::numeric_limits<int>::max();
if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
--cost;
for (int i = 0; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGovPred = config.getLastNotEmptyHypConst(Config::headColName, otherStackIndex);
if (util::isEmpty(otherStackGovPred))