Skip to content
Snippets Groups Projects
Commit 29907cb5 authored by Franck Dary's avatar Franck Dary
Browse files

Corrected a bug where dict was modified if training was resumed

parent 431fab99
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ class ReadingMachine ...@@ -25,6 +25,7 @@ class ReadingMachine
std::unique_ptr<Strategy> strategy; std::unique_ptr<Strategy> strategy;
std::map<std::string, Dict> dicts; std::map<std::string, Dict> dicts;
std::set<std::string> predicted; std::set<std::string> predicted;
bool _dictsAreNew{false};
std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr}; std::unique_ptr<TransitionSet> splitWordTransitionSet{nullptr};
...@@ -51,6 +52,7 @@ class ReadingMachine ...@@ -51,6 +52,7 @@ class ReadingMachine
void saveBest() const; void saveBest() const;
void saveLast() const; void saveLast() const;
void saveDicts() const; void saveDicts() const;
bool dictsAreNew() const;
}; };
#endif #endif
...@@ -14,7 +14,10 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path) ...@@ -14,7 +14,10 @@ ReadingMachine::ReadingMachine(std::filesystem::path path) : path(path)
this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open}); this->dicts.emplace(path.stem().string(), Dict{path.c_str(), Dict::State::Open});
if (dicts.count(defaultDictName) == 0) if (dicts.count(defaultDictName) == 0)
{
_dictsAreNew = true;
dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open)); dicts.emplace(std::make_pair(defaultDictName, Dict::State::Open));
}
} }
ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts) ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::filesystem::path> models, std::vector<std::filesystem::path> dicts)
...@@ -199,3 +202,8 @@ std::map<std::string, Dict> & ReadingMachine::getDicts() ...@@ -199,3 +202,8 @@ std::map<std::string, Dict> & ReadingMachine::getDicts()
return dicts; return dicts;
} }
bool ReadingMachine::dictsAreNew() const
{
return _dictsAreNew;
}
...@@ -114,24 +114,29 @@ int MacaonTrain::main() ...@@ -114,24 +114,29 @@ int MacaonTrain::main()
Trainer trainer(machine, batchSize); Trainer trainer(machine, batchSize);
Decoder decoder(machine); Decoder decoder(machine);
trainer.fillDicts(goldConfig); if (machine.dictsAreNew())
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
{ {
std::size_t originalSize = it.second.size(); trainer.fillDicts(goldConfig);
for (;;) for (auto & it : machine.getDicts())
{ {
std::size_t lastSize = it.second.size(); std::size_t originalSize = it.second.size();
it.second.removeRareElements(); for (;;)
float decrease = 100.0*(originalSize-it.second.size())/originalSize; {
if (decrease >= rarityThreshold or lastSize == it.second.size()) std::size_t lastSize = it.second.size();
break; it.second.removeRareElements();
float decrease = 100.0*(originalSize-it.second.size())/originalSize;
if (decrease >= rarityThreshold or lastSize == it.second.size())
break;
}
} }
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size()); machine.saveDicts();
} }
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize); machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device); machine.getClassifier()->getNN()->to(NeuralNetworkImpl::device);
machine.saveDicts();
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max(); float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment