Skip to content
Snippets Groups Projects
Commit c8ea7e12 authored by Franck Dary
Browse files

Close dict for decode

parent f00fdc1f
No related branches found
No related tags found
No related merge requests found
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
class Dict class Dict
{ {
...@@ -20,7 +21,9 @@ class Dict ...@@ -20,7 +21,9 @@ class Dict
private : private :
std::unordered_map<std::string, int> elementsToIndexes; std::unordered_map<std::string, int> elementsToIndexes;
std::vector<int> nbOccs;
State state; State state;
bool isCountingOccs{false};
public : public :
...@@ -34,6 +37,7 @@ class Dict ...@@ -34,6 +37,7 @@ class Dict
public : public :
void countOcc(bool isCountingOccs);
int getIndexOrInsert(const std::string & element); int getIndexOrInsert(const std::string & element);
void setState(State state); void setState(State state);
State getState() const; State getState() const;
......
...@@ -59,6 +59,8 @@ void Dict::insert(const std::string & element) ...@@ -59,6 +59,8 @@ void Dict::insert(const std::string & element)
util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize)); util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
elementsToIndexes.emplace(element, elementsToIndexes.size()); elementsToIndexes.emplace(element, elementsToIndexes.size());
while (nbOccs.size() < elementsToIndexes.size())
nbOccs.emplace_back(0);
} }
int Dict::getIndexOrInsert(const std::string & element) int Dict::getIndexOrInsert(const std::string & element)
...@@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element) ...@@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element)
insert(element); insert(element);
return elementsToIndexes[element]; return elementsToIndexes[element];
} }
if (isCountingOccs)
nbOccs[elementsToIndexes[unknownValueStr]]++;
return elementsToIndexes[unknownValueStr]; return elementsToIndexes[unknownValueStr];
} }
if (isCountingOccs)
nbOccs[found->second]++;
return found->second; return found->second;
} }
...@@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En ...@@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En
} }
} }
/// @brief Switches occurrence counting on or off for this dictionary.
/// @param isCountingOccs New value stored in the member flag of the same name.
void Dict::countOcc(bool isCountingOccs)
{
// The parameter shadows the member, so `this->` is needed to reach the member.
this->isCountingOccs = isCountingOccs;
}
...@@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine) ...@@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement) void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
{ {
torch::AutoGradMode useGrad(false); torch::AutoGradMode useGrad(false);
machine.getClassifier()->getNN()->train(false); machine.trainMode(false);
config.addPredicted(machine.getPredicted()); config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50; constexpr int printInterval = 50;
......
...@@ -40,6 +40,7 @@ class ReadingMachine ...@@ -40,6 +40,7 @@ class ReadingMachine
void save() const; void save() const;
bool isPredicted(const std::string & columnName) const; bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const; const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
}; };
#endif #endif
...@@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const ...@@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
return predicted; return predicted;
} }
void ReadingMachine::trainMode(bool isTrainMode)
{
classifier->getNN()->train(isTrainMode);
for (auto & it : dicts)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
}
...@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine) ...@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
void Trainer::createDataset(SubConfig & config, bool debug) void Trainer::createDataset(SubConfig & config, bool debug)
{ {
machine.trainMode(true);
std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes; std::vector<torch::Tensor> classes;
...@@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) ...@@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
void Trainer::createDevDataset(SubConfig & config, bool debug) void Trainer::createDevDataset(SubConfig & config, bool debug)
{ {
machine.trainMode(false);
std::vector<torch::Tensor> contexts; std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes; std::vector<torch::Tensor> classes;
...@@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance ...@@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
int currentBatchNumber = 0; int currentBatchNumber = 0;
torch::AutoGradMode useGrad(train); torch::AutoGradMode useGrad(train);
machine.getClassifier()->getNN()->train(train); machine.trainMode(train);
auto lossFct = torch::nn::CrossEntropyLoss(); auto lossFct = torch::nn::CrossEntropyLoss();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment