Skip to content
Snippets Groups Projects
Commit 72c543b6 authored by Franck Dary's avatar Franck Dary
Browse files

During trainning, convert rare forms to unknownValue to train the corresponding embedding

parent c8ea7e12
No related branches found
No related tags found
No related merge requests found
......@@ -44,6 +44,8 @@ class Dict
void save(std::FILE * destination, Encoding encoding) const;
bool readEntry(std::FILE * file, int * index, char * entry, Encoding encoding);
void printEntry(std::FILE * file, int index, const std::string & entry, Encoding encoding) const;
std::size_t size() const;
int getNbOccs(int index) const;
};
#endif
......@@ -75,6 +75,8 @@ int Dict::getIndexOrInsert(const std::string & element)
if (state == State::Open)
{
insert(element);
if (isCountingOccs)
nbOccs[elementsToIndexes[element]]++;
return elementsToIndexes[element];
}
if (isCountingOccs)
......@@ -146,3 +148,15 @@ void Dict::countOcc(bool isCountingOccs)
this->isCountingOccs = isCountingOccs;
}
std::size_t Dict::size() const
{
return elementsToIndexes.size();
}
int Dict::getNbOccs(int index) const
{
if (index < 0 || index >= (int)nbOccs.size())
return 0;
return nbOccs[index];
}
......@@ -36,6 +36,7 @@ class ReadingMachine
TransitionSet & getTransitionSet();
Strategy & getStrategy();
Dict & getDict(const std::string & state);
std::map<std::string, Dict> & getDicts();
Classifier * getClassifier();
void save() const;
bool isPredicted(const std::string & columnName) const;
......
......@@ -131,3 +131,8 @@ void ReadingMachine::trainMode(bool isTrainMode)
it.second.setState(isTrainMode ? Dict::State::Open : Dict::State::Closed);
}
std::map<std::string, Dict> & ReadingMachine::getDicts()
{
return dicts;
}
......@@ -8,6 +8,9 @@ class CNNNetworkImpl : public NeuralNetworkImpl
{
private :
static constexpr int maxNbEmbeddings = 50000;
static constexpr int unknownValueThreshold = 0;
std::vector<int> focusedBufferIndexes;
std::vector<int> focusedStackIndexes;
std::vector<std::string> focusedColumns;
......
......@@ -19,7 +19,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
rawInputCNN = register_module("rawInputCNN", CNN(std::vector<int>{2,3,4}, nbFiltersFocused, embeddingsSize));
int rawInputCNNOutputSize = rawInputSize == 0 ? 0 : rawInputCNN->getOutputSize();
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingsSize)));
embeddingsDropout = register_module("embeddings_dropout", torch::nn::Dropout(0.3));
cnnDropout = register_module("cnn_dropout", torch::nn::Dropout(0.3));
hiddenDropout = register_module("hidden_dropout", torch::nn::Dropout(0.3));
......@@ -76,6 +76,9 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
{
if (dict.size() >= maxNbEmbeddings)
util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
std::vector<long> contextIndexes = extractContextIndexes(config);
std::vector<long> context;
......@@ -100,7 +103,14 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
if (index == -1)
context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
else
context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
if (col == "FORM" || col == "LEMMA")
if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);
context.push_back(dictIndex);
}
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
......
......@@ -61,6 +61,22 @@ po::variables_map checkOptions(po::options_description & od, int argc, char ** a
return vm;
}
void fillDicts(ReadingMachine & rm, const Config & config)
{
static std::vector<std::string> interestingColumns{"FORM", "LEMMA"};
for (auto & col : interestingColumns)
if (config.has(col,0,0))
for (auto & it : rm.getDicts())
{
it.second.countOcc(true);
for (unsigned int j = 0; j < config.getNbLines(); j++)
for (unsigned int k = 0; k < Config::nbHypothesesMax; k++)
it.second.getIndexOrInsert(config.getConst(col,j,k));
it.second.countOcc(false);
}
}
int main(int argc, char * argv[])
{
auto od = getOptionsDescription();
......@@ -89,6 +105,8 @@ int main(int argc, char * argv[])
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
SubConfig config(goldConfig);
fillDicts(machine, goldConfig);
Trainer trainer(machine);
trainer.createDataset(config, debug);
if (!computeDevScore)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment