Skip to content
Snippets Groups Projects
Commit 29907cb5 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed a bug where the dicts were modified when training was resumed

parent 431fab99
Branches
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ class ReadingMachine
std::unique_ptr<Strategy> strategy;
std::map<std::string, Dict> dicts;
std::set<std::string> predicted;
bool _dictsAreNew{false};
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
......@@ -51,6 +52,7 @@ class ReadingMachine
void saveBest() const;
void saveLast() const;
void saveDicts() const;
bool dictsAreNew() const;
};
#endif
......@@ -14,8 +14,11 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
if (dicts.count(defaultDictName) == 0)
{
_dictsAreNew = true;
dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
}
}
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
{
......@@ -199,3 +202,8 @@ std::map<std::string, Dict> & ReadingMachine::getDicts()
return dicts;
}
bool ReadingMachine::dictsAreNew() const
{
return _dictsAreNew;
}
......@@ -114,8 +114,9 @@ int MacaonTrain::main()
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
if (machine.dictsAreNew())
{
trainer.fillDicts(goldConfig);
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
{
std::size_t originalSize = it.second.size();
......@@ -127,11 +128,15 @@ int MacaonTrain::main()
if (decrease >= rarityThreshold or lastSize == it.second.size())
break;
}
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
}
machine.saveDicts();
}
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
machine.saveDicts();
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment