Skip to content
Snippets Groups Projects
Commit 72c543b6 authored by Franck Dary's avatar Franck Dary
Browse files

During trainning, convert rare forms to unknownValue to train the corresponding embedding

parent c8ea7e12
Branches
No related tags found
No related merge requests found
...@@ -44,6 +44,8 @@ class Dict ...@@ -44,6 +44,8 @@ class Dict
void save(std::FILE * destination, Encoding encoding) const; void save(std::FILE * destination, Encoding encoding) const;
bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding); bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const; void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
std::size_t size() const;
int getNbOccs(int index) const;
}; };
#endif #endif
...@@ -75,6 +75,8 @@ int Dict::getIndexOrInsert(const std::string & element) ...@@ -75,6 +75,8 @@ int Dict::getIndexOrInsert(const std::string & element)
if (state == State::Open) if (state == State::Open)
{ {
insert(element); insert(element);
if (isCountingOccs)
nbOccs[elementsToIndexes[element]]++;
return elementsToIndexes[element]; return elementsToIndexes[element];
} }
if (isCountingOccs) if (isCountingOccs)
...@@ -146,3 +148,15 @@ void Dict::countOcc(bool isCountingOccs) ...@@ -146,3 +148,15 @@ void Dict::countOcc(bool isCountingOccs)
this->isCountingOccs = isCountingOccs; this->isCountingOccs = isCountingOccs;
} }
std::size_t Dict::size() const
{
return elementsToIndexes.size();
}
int Dict::getNbOccs(int index) const
{
if (index < 0 || index >= (int)nbOccs.size())
return 0;
return nbOccs[index];
}
...@@ -36,6 +36,7 @@ class ReadingMachine ...@@ -36,6 +36,7 @@ class ReadingMachine
TransitionSet & getTransitionSet(); TransitionSet & getTransitionSet();
Strategy & getStrategy(); Strategy & getStrategy();
Dict & getDict(const std::string & state); Dict & getDict(const std::string & state);
std::map<std::string, Dict> & getDicts();
Classifier * getClassifier(); Classifier * getClassifier();
void save() const; void save() const;
bool isPredicted(const std::string & columnName) const; bool isPredicted(const std::string & columnName) const;
......
...@@ -131,3 +131,8 @@ void ReadingMachine::trainMode(bool isTrainMode) ...@@ -131,3 +131,8 @@ void ReadingMachine::trainMode(bool isTrainMode)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed); it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
} }
std::map<std::string, Dict> & ReadingMachine::getDicts()
{
return dicts;
}
...@@ -8,6 +8,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -8,6 +8,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl
{ {
private : private :
static constexpr int maxNbEmbeddings = 50000;
static constexpr int unknownValueThreshold = 0;
std::vector<int> focusedBufferIndexes; std::vector<int> focusedBufferIndexes;
std::vector<int> focusedStackIndexes; std::vector<int> focusedStackIndexes;
std::vector<std::string> focusedColumns; std::vector<std::string> focusedColumns;
......
...@@ -19,7 +19,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -19,7 +19,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize)); rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize(); int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3)); embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3)); cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3)); hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
...@@ -76,6 +76,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -76,6 +76,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{ {
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<long> contextIndexes = extractContextIndexes(config); std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<long> context; std::vector<long> context;
...@@ -100,7 +103,14 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c ...@@ -100,7 +103,14 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
if (index == -1) if (index == -1)
context.push_back(dict.getIndexOrInsert(Dict::nullValueStr)); context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else else
context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index))); {
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
if (col == "FORM" || col == "LEMMA")
if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);
context.push_back(dictIndex);
}
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++) for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{ {
......
...@@ -61,6 +61,22 @@ po::variables_map checkOptions(po::options_description & od, int argc, char ** a ...@@ -61,6 +61,22 @@ po::variables_map checkOptions(po::options_description & od, int argc, char ** a
return vm; return vm;
} }
void fillDicts(ReadingMachine & rm, const Config & config)
{
static std::vector<std::string> interestingColumns{"FORM", "LEMMA"};
for (auto & col : interestingColumns)
if (config.has(col,0,0))
for (auto & it : rm.getDicts())
{
it.second.countOcc(true);
for (unsigned int j = 0; j < config.getNbLines(); j++)
for (unsigned int k = 0; k < Config::nbHypothesesMax; k++)
it.second.getIndexOrInsert(config.getConst(col,j,k));
it.second.countOcc(false);
}
}
int main(int argc, char * argv[]) int main(int argc, char * argv[])
{ {
auto od = getOptionsDescription(); auto od = getOptionsDescription();
...@@ -89,6 +105,8 @@ int main(int argc, char * argv[]) ...@@ -89,6 +105,8 @@ int main(int argc, char * argv[])
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile); BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
SubConfig config(goldConfig); SubConfig config(goldConfig);
fillDicts(machine, goldConfig);
Trainer trainer(machine); Trainer trainer(machine);
trainer.createDataset(config, debug); trainer.createDataset(config, debug);
if (!computeDevScore) if (!computeDevScore)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment