Skip to content
Snippets Groups Projects
Commit 03221c3f authored by Franck Dary's avatar Franck Dary
Browse files

Corrected a bug where splitword had 0 cost even when it didn't match the size of the gold multiword

parent f75f941f
No related branches found
No related tags found
No related merge requests found
......@@ -22,7 +22,7 @@ class Config
static constexpr const char * deprelColName = "DEPREL";
static constexpr const char * idColName = "ID";
static constexpr int nbHypothesesMax = 1;
static constexpr int maxNbAppliableSplitTransitions = 3;
static constexpr int maxNbAppliableSplitTransitions = 8;
public :
......
......@@ -203,6 +203,12 @@ void Transition::initSplitWord(std::vector<std::string> words)
cost = [words](const Config & config)
{
if (!config.isMultiword(config.getWordIndex()))
return std::numeric_limits<int>::max();
if (config.getMultiwordSize(config.getWordIndex())+2 != (int)words.size())
return std::numeric_limits<int>::max();
int cost = 0;
for (unsigned int i = 0; i < words.size(); i++)
if (!config.has("FORM", config.getWordIndex()+i, 0) or config.getConst("FORM", config.getWordIndex()+i, 0) != words[i])
......
......@@ -46,10 +46,13 @@ std::vector<Transition *> TransitionSet::getNAppliableTransitions(const Config &
{
std::vector<Transition *> result;
for (unsigned int i = 0; i < transitions.size() && result.size() < n; i++)
for (unsigned int i = 0; i < transitions.size(); i++)
if (transitions[i].appliable(c))
result.emplace_back(&transitions[i]);
if ((int)result.size() > n)
util::myThrow(fmt::format("there are {} appliable transitions n = {}\n", result.size(), n));
return result;
}
......
......@@ -52,6 +52,24 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
util::myThrow("No transition appliable !");
}
if (config.isMultiword(config.getWordIndex()))
if (transition->getName() == "ADDCHARTOWORD")
{
config.printForDebug(stderr);
auto & splitTrans = config.getAppliableSplitTransitions();
fmt::print(stderr, "splitTrans.size() = {}\n", splitTrans.size());
for (auto & trans : splitTrans)
fmt::print(stderr, "cost {} : '{}'\n", trans->getCost(config), trans->getName());
util::myThrow(fmt::format("Transition should have been a split"));
}
if (transition->getName() == "ENDWORD")
if (config.getAsFeature("FORM",config.getWordIndex()) != config.getConst("FORM",config.getWordIndex(),0))
{
config.printForDebug(stderr);
util::myThrow(fmt::format("Words don't match"));
}
std::vector<std::vector<long>> context;
try
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment