Newer
Older
/// Builds a transition from its textual name.
///
/// The name is first matched against "(<(.+)> )?(.+)": the optional bracketed
/// prefix is stored in `state`, the remainder in `name`. The remainder is then
/// tried against the table of known transition patterns, in table order, and
/// the first entry accepted by util::doIfNameMatch runs its init* routine.
///
/// Any std::exception raised while doing so (including the
/// std::invalid_argument thrown when no pattern matches) is reported through
/// util::myThrow with an "Invalid name" message.
///
/// @param name Textual description of the transition.
Transition::Transition(const std::string & name)
{
  // Pattern -> initializer table. Order matters: the first match wins.
  std::vector<std::pair<std::regex, std::function<void(const std::smatch &)>>> inits
  {
    {std::regex("WRITE ([bs])\\.(.+) (.+) (.+)"),
      [this](auto sm){(initWrite(sm[3], sm[1], sm[2], sm[4]));}},
    {std::regex("ADD ([bs])\\.(.+) (.+) (.+)"),
      [this](auto sm){(initAdd(sm[3], sm[1], sm[2], sm[4]));}},
    {std::regex("SHIFT"),
      [this](auto){initShift();}},
    {std::regex("REDUCE"),
      [this](auto){initReduce();}},
    {std::regex("LEFT (.+)"),
      [this](auto sm){(initLeft(sm[1]));}},
    {std::regex("RIGHT (.+)"),
      [this](auto sm){(initRight(sm[1]));}},
    {std::regex("EOS"),
      [this](auto){initEOS();}},
    {std::regex("NOTHING"),
      [this](auto){initNothing();}},
    {std::regex("IGNORECHAR"),
      [this](auto){initIgnoreChar();}},
    {std::regex("ENDWORD"),
      [this](auto){initEndWord();}},
    {std::regex("ADDCHARTOWORD"),
      [this](auto){initAddCharToWord();}},
    // Group 1 is the consumed word, group 2 the '@'-prefixed list of produced words.
    {std::regex("SPLITWORD ([^@]+)(:?(:?@[^@]+)+)"),
      [this](auto sm)
      {
        std::vector<std::string> splitRes{sm[1]};
        auto splited = util::split(std::string(sm[2]), '@');
        for (auto & s : splited)
          splitRes.emplace_back(s);
        initSplitWord(splitRes);
      }},
  };

  // NOTE(review): the opening of this try block was missing from the file; it
  // is placed so that it covers both the name parsing and the table lookup —
  // confirm against version history.
  try
  {
    if (!util::doIfNameMatch(std::regex("(<(.+)> )?(.+)"), name, [this](auto sm)
          {
            this->state = sm[2];
            this->name = sm[3];
          }))
      util::myThrow("doesn't match nameRegex");

    for (auto & it : inits)
      if (util::doIfNameMatch(it.first, this->name, it.second))
        return;

    throw std::invalid_argument("no match");
  } catch (const std::exception & e) {util::myThrow(fmt::format("Invalid name '{}' ({})", this->name, e.what()));}
}
/// Runs every action of this transition on `config`, in the order they were
/// added to `sequence`.
///
/// @param config Configuration the actions are applied to.
void Transition::apply(Config & config)
{
  for (auto & step : sequence)
  {
    // Each action receives itself as second argument of its own callback.
    step.apply(config, step);
  }
}
bool Transition::appliable(const Config & config) const
{
if (!state.empty() && state != config.getState())
return false;
for (const Action & action : sequence)
if (!action.appliable(config, action))
return false;
return true;
}
/// Evaluates the cost function installed by the init* routine on `config`.
///
/// @param config Configuration the cost is computed for.
/// @return Value returned by the stored `cost` callable.
int Transition::getCost(const Config & config) const
{
  const int transitionCost = cost(config);
  return transitionCost;
}
const std::string & Transition::getName() const
{
return name;
}
Franck Dary
committed
/// Configures this transition as a WRITE: one action that writes `value` as
/// hypothesis in column `colName`, at a position relative to the buffer or
/// the stack.
///
/// @param colName Column to write into.
/// @param object  Textual object designator, converted by Action::str2object.
/// @param index   Relative index, parsed with std::stoi (which throws
///                std::invalid_argument / std::out_of_range on bad input).
/// @param value   Value to write.
///
/// The installed cost is 0 when the gold value of the targeted cell equals
/// `value`, 1 otherwise.
void Transition::initWrite(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Action::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [colName, objectValue, indexValue, value](const Config & config)
  {
    // Resolve the relative index into an absolute line index.
    int lineIndex = 0;
    if (objectValue == Action::Object::Buffer)
      lineIndex = config.getWordIndex() + indexValue;
    else
      lineIndex = config.getStack(indexValue);

    if (config.getConst(colName, lineIndex, 0) == value)
      return 0;

    return 1;
  };
}

/// Configures this transition as an ADD: one action built by
/// Action::addToHypothesisRelative for column `colName`, at a position
/// relative to the buffer or the stack.
///
/// @param colName Column to add to.
/// @param object  Textual object designator, converted by Action::str2object.
/// @param index   Relative index, parsed with std::stoi.
/// @param value   Value to add.
///
/// The installed cost is 0 when `value` is one of the '|'-separated parts of
/// the gold cell, 1 otherwise.
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
  auto objectValue = Action::str2object(object);
  int indexValue = std::stoi(index);

  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));

  cost = [column = colName, target = objectValue, offset = indexValue, wanted = value](const Config & config)
  {
    // Resolve the relative index into an absolute line index.
    int line = 0;
    if (target == Action::Object::Buffer)
      line = config.getWordIndex() + offset;
    else
      line = config.getStack(offset);

    auto goldParts = util::split(config.getConst(column, line, 0).get(), '|');
    for (auto & candidate : goldParts)
    {
      if (candidate == wanted)
        return 0;
    }

    return 1;
  };
}
/// Configures this transition as NOTHING: no action is added and the cost is
/// always 0.
void Transition::initNothing()
{
  cost = [](const Config & /*config*/) -> int { return 0; };
}
/// Configures this transition as IGNORECHAR: a single
/// Action::ignoreCurrentCharacter action, with a cost that is always 0.
void Transition::initIgnoreChar()
{
  sequence.emplace_back(Action::ignoreCurrentCharacter());

  cost = [](const Config & /*config*/) -> int { return 0; };
}
/// Configures this transition as ENDWORD: a single Action::endWord action.
///
/// The installed cost is 0 when the gold "FORM" of the current word equals
/// the "FORM" obtained through getAsFeature for that word, 1 otherwise.
void Transition::initEndWord()
{
  sequence.emplace_back(Action::endWord());

  cost = [](const Config & config)
  {
    const auto wordIndex = config.getWordIndex();
    return (config.getConst("FORM", wordIndex, 0) == config.getAsFeature("FORM", wordIndex)) ? 0 : 1;
  };
}
/// Configures this transition as ADDCHARTOWORD. The action sequence is:
/// assert the id column is empty, add lines if needed, append the current
/// character to the current word, then move the character index by 1.
///
/// The installed cost is:
///  - INT_MAX when there is no character at the current character index;
///  - 1 when appending the letter would exceed the gold form's length, or
///    when the letter differs from the gold form at that position;
///  - 0 otherwise.
void Transition::initAddCharToWord()
{
  sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
  sequence.emplace_back(Action::addLinesIfNeeded(0));
  sequence.emplace_back(Action::addCurCharToCurWord());
  sequence.emplace_back(Action::moveCharacterIndex(1));

  cost = [](const Config & config)
  {
    const auto charIndex = config.getCharacterIndex();
    if (!config.hasCharacter(charIndex))
      return std::numeric_limits<int>::max();

    // The letter is formatted to a string, so it may span several bytes.
    auto nextLetter = fmt::format("{}", config.getLetter(charIndex));
    auto & gold = config.getConst("FORM", config.getWordIndex(), 0).get();
    auto & current = config.getAsFeature("FORM", config.getWordIndex()).get();

    const auto alreadyBuilt = current.size();
    if (alreadyBuilt + nextLetter.size() > gold.size())
      return 1;

    // Compare the letter with the slice of the gold form it would occupy.
    for (unsigned int k = 0; k < nextLetter.size(); ++k)
    {
      if (gold[alreadyBuilt + k] != nextLetter[k])
        return 1;
    }

    return 0;
  };
}
void Transition::initSplitWord(std::vector<std::string> words)
{
auto & consumedWord = words[0];
sequence.emplace_back(Action::assertIsEmpty(Config::idColName));
sequence.emplace_back(Action::assertIsEmpty("FORM"));
sequence.emplace_back(Action::addLinesIfNeeded(words.size()));
sequence.emplace_back(Action::consumeCharacterIndex(consumedWord));
for (unsigned int i = 0; i < words.size(); i++)
sequence.emplace_back(Action::addHypothesisRelative("FORM", Action::Object::Buffer, i, words[i]));
Franck Dary
committed
sequence.emplace_back(Action::setMultiwordIds(words.size()-1));
cost = [words](const Config & config)
{
int cost = 0;
for (unsigned int i = 0; i < words.size(); i++)
if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
cost++;
return cost;
};
}
sequence.emplace_back(Action::pushWordIndexOnStack());
cost = [](const Config & config)
{
if (config.hasStack(0) && config.getConst(Config::EOSColName, config.getStack(0), 0) == Config::EOSSymbol1)
return std::numeric_limits<int>::max();
auto headGovIndex = config.getConst(Config::headColName, config.getWordIndex(), 0);
int cost = 0;
for (int i = 0; config.hasStack(i); ++i)
{
Franck Dary
committed
if (!config.has(0, config.getStack(i), 0))
continue;
auto stackIndex = config.getStack(i);
auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);
if (stackGovIndex == std::to_string(config.getWordIndex()) || headGovIndex == std::to_string(stackIndex))
++cost;
}
return cost;
};
}
/// Configures this transition as LEFT with dependency label `label`.
/// The action sequence is: Action::attach(Buffer 0, Stack 0), write `label`
/// in the deprel column of the stack top, then pop the stack.
///
/// @param label Dependency label written on the stack top.
///
/// The installed cost is INT_MAX unless both the stack top and the current
/// word are tokens. Otherwise it adds up:
///  - for each token after the current word, up to (excluding) the first one
///    whose gold EOS column equals Config::EOSSymbol1: 1 if it is the gold
///    head of the stack top or if its gold head is the stack top;
///  - 1 if the gold head of the stack top is not the current word;
///  - 1 if `label` differs from the gold deprel of the stack top.
void Transition::initLeft(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Buffer, 0, Action::Object::Stack, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Stack, 0, label));
  sequence.emplace_back(Action::popStack());

  cost = [label](const Config & config)
  {
    auto top = config.getStack(0);
    auto current = config.getWordIndex();

    const bool bothTokens = config.isToken(top) && config.isToken(current);
    if (!bothTokens)
      return std::numeric_limits<int>::max();

    auto goldHeadOfTop = config.getConst(Config::headColName, top, 0);

    int errors = 0;
    for (int pos = current+1; config.has(0, pos, 0); ++pos)
    {
      if (!config.isToken(pos))
        continue;

      // Stop at the sentence boundary, without counting that position.
      if (config.getConst(Config::EOSColName, pos, 0) == Config::EOSSymbol1)
        break;

      auto goldHeadOfPos = config.getConst(Config::headColName, pos, 0);
      if (goldHeadOfTop == std::to_string(pos) || goldHeadOfPos == std::to_string(top))
        ++errors;
    }

    //TODO : Check if this is necessary
    if (goldHeadOfTop != std::to_string(current))
      ++errors;

    if (label != config.getConst(Config::deprelColName, top, 0))
      ++errors;

    return errors;
  };
}
/// Configures this transition as RIGHT with dependency label `label`.
/// The action sequence is: Action::attach(Stack 0, Buffer 0), write `label`
/// in the deprel column of the current buffer word, then push the current
/// word index on the stack.
///
/// @param label Dependency label written on the current word.
///
/// The installed cost is INT_MAX unless both the stack top and the current
/// word are tokens. Otherwise it adds up:
///  - for each token from the current word up to (including) the first one
///    whose gold EOS column equals Config::EOSSymbol1: 1 if it is the gold
///    head of the current word or if its gold head is the current word;
///  - for each stack element below the top (those for which
///    config.has(0, index, 0) holds): 1 if its gold head is the current word
///    or if it is the gold head of the current word;
///  - 1 if the gold head of the current word is not the stack top;
///  - 1 if `label` differs from the gold deprel of the current word.
void Transition::initRight(std::string label)
{
  sequence.emplace_back(Action::attach(Action::Object::Stack, 0, Action::Object::Buffer, 0));
  sequence.emplace_back(Action::addHypothesisRelative(Config::deprelColName, Action::Object::Buffer, 0, label));
  sequence.emplace_back(Action::pushWordIndexOnStack());

  cost = [label](const Config & config)
  {
    auto stackIndex = config.getStack(0);
    auto wordIndex = config.getWordIndex();

    if (!(config.isToken(stackIndex) && config.isToken(wordIndex)))
      return std::numeric_limits<int>::max();

    int cost = 0;

    auto bufferGovIndex = config.getConst(Config::headColName, wordIndex, 0);

    for (int i = wordIndex; config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (bufferGovIndex == std::to_string(i) || otherGovIndex == std::to_string(wordIndex))
        ++cost;

      // The sentence-final position is counted before stopping.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    // Start at 1: the stack top is the governor being attached.
    for (int i = 1; config.hasStack(i); ++i)
    {
      // Skip stack entries that do not designate an existing line.
      if (!config.has(0, config.getStack(i), 0))
        continue;

      auto otherStackIndex = config.getStack(i);
      auto otherStackGov = config.getConst(Config::headColName, otherStackIndex, 0);

      if (otherStackGov == std::to_string(wordIndex) || bufferGovIndex == std::to_string(otherStackIndex))
        ++cost;
    }

    //TODO : Check if this is necessary
    if (bufferGovIndex != std::to_string(stackIndex))
      ++cost;

    if (label != config.getConst(Config::deprelColName, wordIndex, 0))
      ++cost;

    return cost;
  };
}
/// Configures this transition as REDUCE: a single action popping the stack.
///
/// The installed cost is 0 when the stack top does not designate an existing
/// line (config.has(0, index, 0) is false) or is not a token. Otherwise it
/// adds up:
///  - for each token from the current word up to (including) the first one
///    whose gold EOS column equals Config::EOSSymbol1: 1 if it is the gold
///    head of the stack top or if its gold head is the stack top;
///  - 1 if the gold EOS column of the stack top equals Config::EOSSymbol1.
void Transition::initReduce()
{
  sequence.emplace_back(Action::popStack());

  cost = [](const Config & config)
  {
    if (!config.has(0, config.getStack(0), 0))
      return 0;

    // NOTE(review): the body of this guard and the accumulator declaration
    // were missing from the file; `return 0` mirrors the guard just above —
    // confirm against version history.
    if (!config.isToken(config.getStack(0)))
      return 0;

    int cost = 0;

    auto stackIndex = config.getStack(0);
    auto stackGovIndex = config.getConst(Config::headColName, stackIndex, 0);

    for (int i = config.getWordIndex(); config.has(0, i, 0); ++i)
    {
      if (!config.isToken(i))
        continue;

      auto otherGovIndex = config.getConst(Config::headColName, i, 0);

      if (stackGovIndex == std::to_string(i) || otherGovIndex == std::to_string(stackIndex))
        ++cost;

      // The sentence-final position is counted before stopping.
      if (config.getConst(Config::EOSColName, i, 0) == Config::EOSSymbol1)
        break;
    }

    if (config.getConst(Config::EOSColName, stackIndex, 0) == Config::EOSSymbol1)
      ++cost;

    return cost;
  };
}

void Transition::initEOS()
{
sequence.emplace_back(Action::setRoot());
sequence.emplace_back(Action::updateIds());
sequence.emplace_back(Action::addHypothesisRelative(Config::EOSColName, Action::Object::Stack, 0, Config::EOSSymbol1));
sequence.emplace_back(Action::emptyStack());
cost = [](const Config & config)
{
if (!config.has(0, config.getStack(0), 0))
return std::numeric_limits<int>::max();
if (!config.isToken(config.getStack(0)))
return std::numeric_limits<int>::max();
if (config.getConst(Config::EOSColName, config.getStack(0), 0) != Config::EOSSymbol1)
return std::numeric_limits<int>::max();
--cost;
for (int i = 0; config.hasStack(i); ++i)
{
if (!config.has(0, config.getStack(i), 0))
continue;
auto otherStackIndex = config.getStack(i);
auto otherStackGovPred = config.getAsFeature(Config::headColName, otherStackIndex);
if (util::isEmpty(otherStackGovPred))