Skip to content
Snippets Groups Projects
Commit 29907cb5 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected a bug where dict was modified if training was resumed

parent 431fab99
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ class ReadingMachine
std::unique_ptr<Strategy> strategy;
std::map<std::string, Dict> dicts;
std::set<std::string> predicted;
bool _dictsAreNew{false};
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
......@@ -51,6 +52,7 @@ class ReadingMachine
void saveBest() const;
void saveLast() const;
void saveDicts() const;
bool dictsAreNew() const;
};
#endif
......@@ -14,7 +14,10 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
if (dicts.count(defaultDictName) == 0)
{
_dictsAreNew = true;
dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
}
}
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
......@@ -199,3 +202,8 @@ std::map<std::string, Dict> & ReadingMachine::getDicts()
return dicts;
}
bool ReadingMachine::dictsAreNew() const
{
return _dictsAreNew;
}
......@@ -114,24 +114,29 @@ int MacaonTrain::main()
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
trainer.fillDicts(goldConfig);
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
if (machine.dictsAreNew())
{
std::size_t originalSize = it.second.size();
for (;;)
trainer.fillDicts(goldConfig);
for (auto & it : machine.getDicts())
{
std::size_t lastSize = it.second.size();
it.second.removeRareElements();
float decrease = 100.0*(originalSize-it.second.size())/originalSize;
if (decrease >= rarityThreshold or lastSize == it.second.size())
break;
std::size_t originalSize = it.second.size();
for (;;)
{
std::size_t lastSize = it.second.size();
it.second.removeRareElements();
float decrease = 100.0*(originalSize-it.second.size())/originalSize;
if (decrease >= rarityThreshold or lastSize == it.second.size())
break;
}
}
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
machine.saveDicts();
}
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
machine.saveDicts();
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment