Skip to content
Snippets Groups Projects
Commit 28045459 authored by Franck Dary's avatar Franck Dary
Browse files

split unknown only when extracting train dataset

parent 183f0297
No related branches found
No related tags found
No related merge requests found
Showing
with 47 additions and 19 deletions
......@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
machine.splitUnknown(false);
machine.setDictsState(Dict::State::Closed);
machine.getStrategy().reset();
config.addPredicted(machine.getPredicted());
......
......@@ -47,6 +47,7 @@ class ReadingMachine
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
void splitUnknown(bool splitUnknown);
void setDictsState(Dict::State state);
void saveBest() const;
void saveLast() const;
......
......@@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode)
classifier->getNN()->train(isTrainMode);
}
// Forward the split-unknown flag to the classifier's neural network, so that
// context extraction splits (or does not split) rare/unknown values.
void ReadingMachine::splitUnknown(bool splitUnknown)
{
  auto network = classifier->getNN();
  network->setSplitUnknown(splitUnknown);
}
void ReadingMachine::setDictsState(Dict::State state)
{
for (auto & it : dicts)
......
......@@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(ContextLSTM);
......
......@@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(DepthLayerTreeEmbedding);
......
......@@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(FocusedColumnLSTM);
......
......@@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module
static torch::Device device;
private :
bool splitUnknown{false};
protected :
static constexpr int maxNbEmbeddings = 150000;
......@@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
bool mustSplitUnknown() const;
void setSplitUnknown(bool splitUnknown);
};
TORCH_MODULE(NeuralNetwork);
......
......@@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(RawInputLSTM);
......
......@@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(SplitTransLSTM);
......
......@@ -15,7 +15,7 @@ class Submodule
void setFirstInputIndex(std::size_t firstInputIndex);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
};
#endif
......
......@@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize()
return columns.size()*(bufferContext.size()+stackContext.size());
}
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
{
std::vector<long> contextIndexes;
......@@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
......@@ -40,7 +42,8 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
if (splitUnknown)
for (auto & targetCol : unknownValueColumns)
if (col == targetCol)
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
......
......@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
return inputSize;
}
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
std::vector<long> focusedIndexes;
......
......@@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize()
return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
}
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
std::vector<long> focusedIndexes;
......
......@@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
contextLSTM->addToContext(context, dict, config);
contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
if (!rawInputLSTM.is_empty())
rawInputLSTM->addToContext(context, dict, config);
rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
if (!treeEmbedding.is_empty())
treeEmbedding->addToContext(context, dict, config);
treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
splitTransLSTM->addToContext(context, dict, config);
splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
for (auto & lstm : focusedLstms)
lstm->addToContext(context, dict, config);
lstm->addToContext(context, dict, config, mustSplitUnknown());
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
if (!mustSplitUnknown() && context.size() > 1)
util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
return context;
}
......
......@@ -2,3 +2,13 @@
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
bool NeuralNetworkImpl::mustSplitUnknown() const
{
return splitUnknown;
}
void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
{
this->splitUnknown = splitUnknown;
}
......@@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize()
return leftWindow + rightWindow + 1;
}
void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
if (leftWindow < 0 or rightWindow < 0)
return;
......
......@@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize()
return maxNbTrans;
}
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
auto & splitTransitions = config.getAppliableSplitTransitions();
for (auto & contextElement : context)
......
......@@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.splitUnknown(true);
machine.setDictsState(Dict::State::Open);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
......@@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.splitUnknown(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
......
0% — Loading.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment