Skip to content
Snippets Groups Projects
Commit c8ea7e12 authored by Franck Dary's avatar Franck Dary
Browse files

Close dict for decode

parent f00fdc1f
No related branches found
No related tags found
No related merge requests found
......@@ -3,6 +3,7 @@
#include <string>
#include <unordered_map>
#include <vector>
class Dict
{
......@@ -20,7 +21,9 @@ class Dict
private :
std::unordered_map<std::string, int> elementsToIndexes;
std::vector<int> nbOccs;
State state;
bool isCountingOccs{false};
public :
......@@ -34,6 +37,7 @@ class Dict
public :
void countOcc(bool isCountingOccs);
int getIndexOrInsert(const std::string & element);
void setState(State state);
State getState() const;
......
......@@ -59,6 +59,8 @@ void Dict::insert(const std::string & element)
util::myThrow(fmt::format("inserting element of size={} > maxElementSize={}", element.size(), maxEntrySize));
elementsToIndexes.emplace(element, elementsToIndexes.size());
while (nbOccs.size() < elementsToIndexes.size())
nbOccs.emplace_back(0);
}
int Dict::getIndexOrInsert(const std::string & element)
......@@ -75,9 +77,13 @@ int Dict::getIndexOrInsert(const std::string & element)
insert(element);
return elementsToIndexes[element];
}
if (isCountingOccs)
nbOccs[elementsToIndexes[unknownValueStr]]++;
return elementsToIndexes[unknownValueStr];
}
if (isCountingOccs)
nbOccs[found->second]++;
return found->second;
}
......@@ -135,3 +141,8 @@ void Dict::printEntry(std::FILE * file, int index, const std::string & entry, En
}
}
void Dict::countOcc(bool isCountingOccs)
{
this->isCountingOccs = isCountingOccs;
}
......@@ -8,7 +8,7 @@ Decoder::Decoder(ReadingMachine & machine) : machine(machine)
void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool printAdvancement)
{
torch::AutoGradMode useGrad(false);
machine.getClassifier()->getNN()->train(false);
machine.trainMode(false);
config.addPredicted(machine.getPredicted());
constexpr int printInterval = 50;
......
......@@ -40,6 +40,7 @@ class ReadingMachine
void save() const;
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
};
#endif
......@@ -124,3 +124,10 @@ const std::set<std::string> & ReadingMachine::getPredicted() const
return predicted;
}
void ReadingMachine::trainMode(bool isTrainMode)
{
classifier->getNN()->train(isTrainMode);
for (auto & it : dicts)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
}
......@@ -7,6 +7,7 @@ Trainer::Trainer(ReadingMachine & machine) : machine(machine)
void Trainer::createDataset(SubConfig & config, bool debug)
{
machine.trainMode(true);
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
......@@ -21,6 +22,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
void Trainer::createDevDataset(SubConfig & config, bool debug)
{
machine.trainMode(false);
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
......@@ -91,7 +93,7 @@ float Trainer::processDataset(DataLoader & loader, bool train, bool printAdvance
int currentBatchNumber = 0;
torch::AutoGradMode useGrad(train);
machine.getClassifier()->getNN()->train(train);
machine.trainMode(train);
auto lossFct = torch::nn::CrossEntropyLoss();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment