Skip to content
Snippets Groups Projects
Commit d217cf58 authored by Franck Dary's avatar Franck Dary
Browse files

Rare values are now treated as unknown values. Embeddings sizes now exactly match dict size

parent 7982090e
No related branches found
No related tags found
No related merge requests found
......@@ -19,11 +19,9 @@ RawInputModuleImpl::RawInputModuleImpl(const std::string & definition)
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
int inSize = std::stoi(sm.str(5));
inSize = std::stoi(sm.str(5));
int outSize = std::stoi(sm.str(6));
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
......@@ -51,7 +49,7 @@ std::size_t RawInputModuleImpl::getInputSize()
return leftWindow + rightWindow + 1;
}
void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
if (leftWindow < 0 or rightWindow < 0)
return;
......@@ -72,3 +70,8 @@ void RawInputModuleImpl::addToContext(std::vector<std::vector<long>> & context,
}
}
void RawInputModuleImpl::registerEmbeddings(std::size_t nbElements)
{
  // Size the embedding table to exactly nbElements rows (one per dict entry),
  // each of dimension inSize, and register it under the "embeddings" key.
  auto embeddingOptions = torch::nn::EmbeddingOptions(nbElements, inSize);
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(embeddingOptions));
}
......@@ -17,11 +17,9 @@ SplitTransModuleImpl::SplitTransModuleImpl(int maxNbTrans, const std::string & d
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
int inSize = std::stoi(sm.str(3));
inSize = std::stoi(sm.str(3));
int outSize = std::stoi(sm.str(4));
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(60000, inSize)));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
......@@ -49,7 +47,7 @@ std::size_t SplitTransModuleImpl::getInputSize()
return maxNbTrans;
}
void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
{
auto & splitTransitions = config.getAppliableSplitTransitions();
for (auto & contextElement : context)
......@@ -60,3 +58,8 @@ void SplitTransModuleImpl::addToContext(std::vector<std::vector<long>> & context
contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
void SplitTransModuleImpl::registerEmbeddings(std::size_t nbElements)
{
  // Size the embedding table to exactly nbElements rows (one per dict entry),
  // each of dimension inSize, and register it under the "embeddings" key.
  auto embeddingOptions = torch::nn::EmbeddingOptions(nbElements, inSize);
  wordEmbeddings = register_module("embeddings", torch::nn::Embedding(embeddingOptions));
}
......@@ -19,7 +19,6 @@ class MacaonTrain
po::options_description getOptionsDescription();
po::variables_map checkOptions(po::options_description & od);
void fillDicts(ReadingMachine & rm, const Config & config);
public :
......
......@@ -43,12 +43,14 @@ class Trainer
void extractExamples(SubConfig & config, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
float processDataset(DataLoader & loader, bool train, bool printAdvancement, int nbExamples);
void fillDicts(SubConfig & config);
public :
Trainer(ReadingMachine & machine, int batchSize);
void createDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void createDevDataset(BaseConfig & goldConfig, bool debug, std::filesystem::path dir, int epoch, int dynamicOracleInterval);
void fillDicts(BaseConfig & goldConfig);
float epoch(bool printAdvancement);
float evalOnDev(bool printAdvancement);
};
......
......@@ -35,6 +35,8 @@ po::options_description MacaonTrain::getOptionsDescription()
"Number of examples per batch")
("dynamicOracleInterval", po::value<int>()->default_value(-1),
"Number of examples per batch")
("rarityThreshold", po::value<float>()->default_value(20.0),
"During train, the X% rarest elements will be treated as unknown values")
("machine", po::value<std::string>()->default_value(""),
"Reading machine file content")
("help,h", "Produce this help message");
......@@ -65,22 +67,6 @@ po::variables_map MacaonTrain::checkOptions(po::options_description & od)
return vm;
}
void MacaonTrain::fillDicts(ReadingMachine & rm, const Config & config)
{
  // Columns whose corpus values are bulk-inserted into every dictionary.
  static std::vector<std::string> interestingColumns{"FORM", "LEMMA"};

  for (auto & column : interestingColumns)
  {
    // Skip columns absent from this configuration.
    if (!config.has(column, 0, 0))
      continue;

    for (auto & [dictName, dict] : rm.getDicts())
    {
      // Count occurrences while inserting, so rare values can be
      // identified later; counting is switched off again afterwards.
      dict.countOcc(true);
      for (unsigned int line = 0; line < config.getNbLines(); ++line)
        for (unsigned int hyp = 0; hyp < Config::nbHypothesesMax; ++hyp)
          dict.getIndexOrInsert(config.getConst(column, line, hyp));
      dict.countOcc(false);
    }
  }
}
int MacaonTrain::main()
{
auto od = getOptionsDescription();
......@@ -96,6 +82,7 @@ int MacaonTrain::main()
auto nbEpoch = variables["nbEpochs"].as<int>();
auto batchSize = variables["batchSize"].as<int>();
auto dynamicOracleInterval = variables["dynamicOracleInterval"].as<int>();
auto rarityThreshold = variables["rarityThreshold"].as<float>();
bool debug = variables.count("debug") == 0 ? false : true;
bool printAdvancement = !debug && variables.count("silent") == 0 ? true : false;
bool computeDevScore = variables.count("devScore") == 0 ? false : true;
......@@ -124,11 +111,27 @@ int MacaonTrain::main()
BaseConfig goldConfig(mcdFile, trainTsvFile, trainRawFile);
BaseConfig devGoldConfig(mcdFile, computeDevScore ? (devRawFile.empty() ? devTsvFile : "") : devTsvFile, devRawFile);
fillDicts(machine, goldConfig);
Trainer trainer(machine, batchSize);
Decoder decoder(machine);
trainer.fillDicts(goldConfig);
std::size_t maxDictSize = 0;
for (auto & it : machine.getDicts())
{
std::size_t originalSize = it.second.size();
for (;;)
{
std::size_t lastSize = it.second.size();
it.second.removeRareElements();
float decrease = 100.0*(originalSize-it.second.size())/originalSize;
if (decrease >= rarityThreshold or lastSize == it.second.size())
break;
}
maxDictSize = std::max<std::size_t>(maxDictSize, it.second.size());
}
machine.getClassifier()->getNN()->registerEmbeddings(maxDictSize);
machine.saveDicts();
float bestDevScore = computeDevScore ? std::numeric_limits<float>::min() : std::numeric_limits<float>::max();
auto trainInfos = machinePath.parent_path() / "train.info";
......
......@@ -10,8 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.splitUnknown(true);
machine.setDictsState(Dict::State::Open);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
trainDataset.reset(new Dataset(dir));
......@@ -24,7 +23,6 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.splitUnknown(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
......@@ -43,9 +41,9 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
std::filesystem::create_directories(dir);
config.addPredicted(machine.getPredicted());
machine.getStrategy().reset();
config.setState(machine.getStrategy().getInitialState());
machine.getClassifier()->setState(machine.getStrategy().getInitialState());
machine.getStrategy().reset();
auto currentEpochAllExtractedFile = dir / fmt::format("extracted.{}", epoch);
bool mustExtract = !std::filesystem::exists(currentEpochAllExtractedFile);
......@@ -154,8 +152,6 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
util::myThrow(fmt::format("could not create file '{}'", currentEpochAllExtractedFile.c_str()));
std::fclose(f);
machine.saveDicts();
fmt::print(stderr, "[{}] Extracted {} examples\n", util::getTime(), util::int2HumanStr(totalNbExamples));
}
......@@ -274,3 +270,66 @@ void Trainer::Examples::addClass(int goldIndex)
classes.emplace_back(gold);
}
void Trainer::fillDicts(BaseConfig & goldConfig)
{
  // Work on a sub-view of the gold configuration.
  SubConfig subConfig(goldConfig, goldConfig.getNbLines());

  // Turn on occurrence counting for every dictionary before filling.
  for (auto & [dictName, dict] : machine.getDicts())
    dict.countOcc(true);

  // Dicts must be Open so unseen values get inserted; no training here.
  machine.trainMode(false);
  machine.setDictsState(Dict::State::Open);

  fillDicts(subConfig);

  // Filling done — stop accumulating occurrence counts.
  for (auto & [dictName, dict] : machine.getDicts())
    dict.countOcc(false);
}
void Trainer::fillDicts(SubConfig & config)
{
  // Replays the gold (oracle) derivation over the whole configuration so
  // that every value the machine would observe is inserted into the dicts
  // (the caller opens the dicts beforehand). No learning is performed:
  // gradients are disabled and only gold transitions are applied.
  torch::AutoGradMode useGrad(false);

  config.addPredicted(machine.getPredicted());
  machine.getStrategy().reset();
  // Start both the configuration and the classifier in the strategy's
  // initial state so they stay in sync throughout the replay.
  config.setState(machine.getStrategy().getInitialState());
  machine.getClassifier()->setState(machine.getStrategy().getInitialState());

  while (true)
  {
    // Refresh the appliable split-word transitions if the machine has any.
    if (machine.hasSplitWordTransitionSet())
      config.setAppliableSplitTransitions(machine.getSplitWordTransitionSet().getNAppliableTransitions(config, Config::maxNbAppliableSplitTransitions));

    try
    {
      // Context extraction is what actually feeds values into the dicts.
      machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
    } catch(std::exception & e)
    {
      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
    }

    // Follow the oracle: always take the best appliable gold transition.
    Transition * goldTransition = nullptr;
    goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);

    if (!goldTransition)
    {
      // Dead end: dump the configuration for debugging, then abort.
      config.printForDebug(stderr);
      util::myThrow("No transition appliable !");
    }

    goldTransition->apply(config);
    config.addToHistory(goldTransition->getName());

    // Ask the strategy where to go next; endMovement terminates the replay.
    auto movement = machine.getStrategy().getMovement(config, goldTransition->getName());
    if (movement == Strategy::endMovement)
      break;

    // movement = (next state, word-index delta); keep classifier in sync.
    config.setState(movement.first);
    machine.getClassifier()->setState(movement.first);
    config.moveWordIndexRelaxed(movement.second);

    if (config.needsUpdate())
      config.update();
  }
}
0% Loading — or an error occurred while loading.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment