// Transition: name parsing, action sequences and cost functions.
/// Builds a transition from its textual name.
///
/// The name may start with an optional "<state> " prefix. When present, the
/// text between the angle brackets is stored in `state`; the remainder is
/// stored in `name`. The remainder is then tried against each known pattern in
/// declaration order and the initializer of the first matching pattern is run.
///
/// Any std::exception raised while parsing (including the "no match" case) is
/// reported through util::myThrow together with the parsed name.
///
/// @param name textual description of the transition, e.g. "SHIFT" or "LEFT <label>".
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer table. Order matters: the first match wins.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](const std::smatch & sm){initWrite(sm.str(3), sm.str(1), sm.str(2), sm.str(4));}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](const std::smatch & sm){initAdd(sm.str(3), sm.str(1), sm.str(2), sm.str(4));}},
    {std::regex("SHIFT"),
      [this](const std::smatch &){initShift();}},
    {std::regex("REDUCE"),
      [this](const std::smatch &){initReduce();}},
    {std::regex("LEFT (.+)"),
      [this](const std::smatch & sm){initLeft(sm.str(1));}},
    {std::regex("RIGHT (.+)"),
      [this](const std::smatch & sm){initRight(sm.str(1));}},
    {std::regex("EOS b\\.(.+)"),
      [this](const std::smatch & sm){initEOS(std::stoi(sm.str(1)));}},
    {std::regex("NOTHING"),
      [this](const std::smatch &){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](const std::smatch &){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](const std::smatch &){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](const std::smatch &){initAddCharToWord();}},
    {std::regex("SPLIT (.+)"),
      [this](const std::smatch & sm){initSplit(std::stoi(sm.str(1)));}},
    // NOTE(review): "(:?" reads like a typo for the non-capturing "(?:".
    // The pattern is left untouched because group 2 is read below.
    {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
      [this](const std::smatch & sm)
      {
        std::vector<std::string> splitRes{sm.str(1)};
        // Keep the matched text in a named local so it outlives the split
        // result, in case util::split returns views into its argument.
        const std::string tail = sm.str(2);
        auto splited = util::split(tail, '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  // NOTE(review): the opening of this try block was absent from the text under
  // review although its catch clause was present; confirm its exact position.
  try
  {
    // Separate the optional "<state> " prefix from the transition name proper.
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](const std::smatch & sm)
          {
            this->state = sm.str(2);
            this->name = sm.str(3);
          }))
      util::myThrow("doesn't match nameRegex");

    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  }
  catch (const std::exception & e)
  {
    util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));
  }
}
/// Runs every action of this transition on `config`, in sequence order.
/// Each action's `apply` callable receives the configuration and the action itself.
void Transition::apply(Config & config)
{
  for (auto it = sequence.begin(), last = sequence.end(); it != last; ++it)
  {
    Action & current = *it;
    current.apply(config, current);
  }
}
bool Transition::appliable(const Config & config) const
{
if (!state.empty() && state != config.getState())
return false;
for (const Action & action : sequence)
if (!action.appliable(config, action))
return false;
return true;
}
/// Evaluates this transition's cost callable on `config` and returns its result.
int Transition::getCost(const Config & config) const
{
  const int value = cost(config);
  return value;
}
const std::string & Transition::getName() const
{
return name;
}
// Initializers: each init* method configures 'sequence' and 'cost' for one kind of transition.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
Franck Dary
committed
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
Franck Dary
committed
if (config.getConst(colName, lineIndex, 0) == value)
return 0;
Franck Dary
committed
return 1;
};
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = config.getRelativeWordIndex(objectValue, indexValue);
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
for (auto & part : gold)
if (part == value)
return 0;
return 1;
};
}
/// Sets up the NOTHING transition: no action is added and the cost is always 0.
void Transition::initNothing()
{
  cost = [](const Config &) -> int { return 0; };
}
/// Sets up the IGNORECHAR transition: one Action::ignoreCurrentCharacter
/// action, with a cost that is always 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  auto alwaysFree = [](const Config &) -> int
  {
    return 0;
  };
  cost = alwaysFree;
}
/// Sets up the ENDWORD transition: one Action::endWord action.
///
/// Cost is 0 when, for the current word index, config.getConst("FORM", ..., 0)
/// equals config.getAsFeature("FORM", ...), and 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config) -> int
  {
    return (config.getConst("FORM", config.getWordIndex(), 0) == config.getAsFeature("FORM", config.getWordIndex()))
      ? 0
      : 1;
  };
}
/// Sets up the ADDCHARTOWORD transition.
///
/// Actions, in order: assert the id column is empty, add lines if needed (0),
/// append the current character to the current word, advance the character
/// index by 1.
///
/// Cost:
///  - int max when there is no character at the current character index;
///  - 1 when appending the letter would make the current FORM longer than the
///    FORM returned by getConst, or when the letter's bytes differ from the
///    bytes found at the same position in that FORM;
///  - 0 otherwise.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config) -> int
  {
    if (!config.hasCharacter(config.getCharacterIndex()))
      return std::numeric_limits<int>::max();

    auto nextLetter = fmt::format("{}", config.getLetter(config.getCharacterIndex()));
    auto & target = config.getConst("FORM", config.getWordIndex(), 0).get();
    auto & built = config.getAsFeature("FORM", config.getWordIndex()).get();

    // The letter must still fit inside the target form.
    if (built.size() + nextLetter.size() > target.size())
      return 1;

    // The letter may span several bytes; all of them must match the target.
    unsigned int pos = 0;
    while (pos < nextLetter.size())
    {
      if (target[built.size()+pos] != nextLetter[pos])
        return 1;
      pos++;
    }

    return 0;
  };
}
/// Sets up a SPLITWORD transition.
///
/// words[0] is split with util::splitAsUtf8 and handed to
/// Action::consumeCharacterIndex; every element of `words` (including
/// words[0]) is then written to the "FORM" column at buffer-relative offset i,
/// and Action::setMultiwordIds(words.size()-1) is appended.
///
/// Cost:
///  - int max when the current word is not a multiword, or when
///    getMultiwordSize(current word)+2 differs from words.size();
///  - otherwise the number of offsets i for which "FORM" at (word index + i)
///    is absent or differs from words[i].
///
/// @param words the consumed word followed by the words it is split into.
/// @throws std::invalid_argument if `words` is empty.
void Transition::initSplitWord(std::vector<std::string> words)
{
  // words[0] is read unconditionally below; reject an empty list up front.
  if (words.empty())
    throw std::invalid_argument("SPLITWORD needs at least one word");

  auto consumedWord = util::splitAsUtf8(words[0]);

  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::assertIsEmpty("FORM"));
  sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
  sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
  for (unsigned int i = 0; i < words.size(); i++)
    sequence.emplace_back(Action::addHypothesisRelative("FORM", Config::Object::Buffer, i, words[i]));
  sequence.emplace_back(Action::setMultiwordIds(words.size()-1));

  cost = [words](const Config & config)
  {
    if (!config.isMultiword(config.getWordIndex()))
      return std::numeric_limits<int>::max();

    if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size())
      return std::numeric_limits<int>::max();

    int cost = 0;
    for (unsigned int i = 0; i < words.size(); i++)
      if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
        cost++;

    return cost;
  };
}
/// Sets up a SPLIT transition: one Action::split(index) action.
///
/// Cost is delegated to the `index`-th entry of
/// config.getAppliableSplitTransitions(); when `index` is outside that list
/// the cost is int max.
void Transition::initSplit(int index)
{
  sequence.emplace_back(Action::split(index));

  cost = [index](const Config & config) -> int
  {
    auto & candidates = config.getAppliableSplitTransitions();

    const bool inRange = index >= 0 and index < static_cast<int>(candidates.size());
    if (inRange)
      return candidates[index]->getCost(config);

    return std::numeric_limits<int>::max();
  };
}
sequence.emplace_back(Action::pushWordIndexOnStack());
cost = [](const Config & config)
{
if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
int cost = 0;
for (int i = 0; config.hasStack(i); ++i)
{
Franck Dary
committed
if (!config.has(0, config.getStack(i), 0))
continue;
auto stackIndex = config.getStack(i);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
++cost;
}
return cost;
};
}
/// Sets up a LEFT transition with the given dependency label.
///
/// Actions, in order: attach (Buffer 0 -> Stack 0), write `label` to the
/// deprel column of Stack 0, pop the stack.
///
/// Cost is int max unless both the stack top and the current word are tokens.
/// Otherwise it is the sum of:
///  - one per token after the current word (scanning stops at the first token
///    whose EOS column equals Config::EOSSymbol1) that the stack top's head
///    column names, or whose head column names the stack top;
///  - one if the stack top's head column does not name the current word;
///  - one if `label` differs from the stack top's deprel column.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Buffer, 0, Config::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    auto top = config.getStack(0);
    auto front = config.getWordIndex();

    if (!config.isToken(top) || !config.isToken(front))
      return std::numeric_limits<int>::max();

    int penalty = 0;
    auto topGov = config.getConst(Config::headColName, top, 0);

    for (int line = front+1; config.has(0, line, 0); ++line)
    {
      if (!config.isToken(line))
        continue;

      // The scan ends at the sentence boundary, which is itself not counted.
      if (config.getConst(Config::EOSColName, line, 0) == Config::EOSSymbol1)
        break;

      auto lineGov = config.getConst(Config::headColName, line, 0);
      if (topGov == std::to_string(line) || lineGov == std::to_string(top))
        penalty += 1;
    }

    //TODO : Check if this is necessary
    if (topGov != std::to_string(front))
      penalty += 1;

    if (label != config.getConst(Config::deprelColName, top, 0))
      penalty += 1;

    return penalty;
  };
}
/// Sets up a RIGHT transition with the given dependency label.
///
/// Actions, in order: attach (Stack 0 -> Buffer 0), write `label` to the
/// deprel column of Buffer 0, push the current word index on the stack.
///
/// Cost is int max unless both the stack top and the current word are tokens.
/// Otherwise it is the sum of:
///  - one per token from the current word up to and including the first token
///    whose EOS column equals Config::EOSSymbol1 (see the review note below);
///  - one per stack element below the top (skipping those for which
///    config.has(0, element, 0) is false) whose head column names the current
///    word, or which the current word's head column names;
///  - one if the current word's head column does not name the stack top;
///  - one if `label` differs from the current word's deprel column.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Config::Object::Stack, 0, Config::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Config::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      // NOTE(review): otherGovIndex is read but never used and the increment
      // below is unconditional. A guarding condition may have been lost from
      // the text under review; the visible behavior is kept as is. Verify
      // against the repository.
      auto otherGovIndex = config.getConst(Config::headColName, i, 0);
      ++cost;

      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    // Index 0 is the stack top itself, so the scan of the stack starts at 1.
    for (int i = 1; config.hasStack(i); ++i)
    {
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

      if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
        ++cost;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++cost;

    return cost;
  };
}
// Sets up the REDUCE transition: a single Action::popStack action plus a cost
// function computed from the stack top and the lines starting at the current
// word index.
// NOTE(review): this block looks truncated (see the notes inside); restore the
// missing lines from the repository rather than guessing them.
void Transition::initReduce()
{
sequence.emplace_back(Action::popStack());
cost = [](const Config & config)
{
// NOTE(review): the next two lines are not C++; they look like residue from a
// version-control blame view.
Franck Dary
committed
// NOTE(review): no consequence statement is visible for this `if`, and the
// local `cost` counter incremented below is never declared in this block.
if (!config.isToken(config.getStack(0)))
auto stackIndex = config.getStack(0);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
// Scan lines from the current word index while config.has(0, i, 0) holds.
for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
{
// Lines that are not tokens are skipped.
if (!config.isToken(i))
continue;
auto otherGovIndex = config.getConst(Config::headColName, i, 0);
// Count the line when the stack top's head column names it, or when its own
// head column names the stack top.
if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
++cost;
// The first token whose EOS column equals EOSSymbol1 is counted, then ends the scan.
if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
break;
}
// One extra unit when the stack top's own EOS column equals EOSSymbol1.
if (config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
++cost;
return cost;
};
// NOTE(review): the closing brace of initReduce() is not visible after this line.
sequence.emplace_back(Action::setRoot(bufferIndex));
sequence.emplace_back(Action::updateIds(bufferIndex));
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Config::Object::Buffer, bufferIndex, Config::EOSSymbol1));
int lineIndex = config.getRelativeWordIndex(Config::Buffer, bufferIndex);
if (config.getConst(Config::EOSColName, lineIndex, 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();