Commit 28045459 authored by Franck Dary

split unknown only when extracting train dataset

parent 183f0297
Showing 47 additions and 19 deletions

@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
 {
   torch::AutoGradMode useGrad(false);
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
   machine.getStrategy().reset();
   config.addPredicted(machine.getPredicted());

@@ -47,6 +47,7 @@ class ReadingMachine
   bool isPredicted(const std::string & columnName) const;
   const std::set<std::string> & getPredicted() const;
   void trainMode(bool isTrainMode);
+  void splitUnknown(bool splitUnknown);
   void setDictsState(Dict::State state);
   void saveBest() const;
   void saveLast() const;

@@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode)
     classifier->getNN()->train(isTrainMode);
 }
+void ReadingMachine::splitUnknown(bool splitUnknown)
+{
+  classifier->getNN()->setSplitUnknown(splitUnknown);
+}
 void ReadingMachine::setDictsState(Dict::State state)
 {
   for (auto & it : dicts)

@@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(ContextLSTM);

@@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(DepthLayerTreeEmbedding);

@@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(FocusedColumnLSTM);

@@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module
   static torch::Device device;
+  private :
+  bool splitUnknown{false};
   protected :
   static constexpr int maxNbEmbeddings = 150000;
@@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module
   virtual torch::Tensor forward(torch::Tensor input) = 0;
   virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
+  bool mustSplitUnknown() const;
+  void setSplitUnknown(bool splitUnknown);
 };
 TORCH_MODULE(NeuralNetwork);

@@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(RawInputLSTM);

@@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
   torch::Tensor forward(torch::Tensor input);
   std::size_t getOutputSize() override;
   std::size_t getInputSize() override;
-  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
+  void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
 };
 TORCH_MODULE(SplitTransLSTM);

@@ -15,7 +15,7 @@ class Submodule
   void setFirstInputIndex(std::size_t firstInputIndex);
   virtual std::size_t getOutputSize() = 0;
   virtual std::size_t getInputSize() = 0;
-  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
+  virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
 };
 #endif

@@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize()
   return columns.size()*(bufferContext.size()+stackContext.size());
 }
-void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
 {
   std::vector<long> contextIndexes;
@@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
+      {
        for (auto & contextElement : context)
          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
@@ -40,7 +42,8 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
         for (auto & contextElement : context)
           contextElement.push_back(dictIndex);
-        if (is_training())
+        if (splitUnknown)
           for (auto & targetCol : unknownValueColumns)
             if (col == targetCol)
               if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)

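The hunk above is cut off right after the rarity check, so the exact splitting logic is not visible here. As a rough, hypothetical sketch of what "splitting unknowns" can mean for a std::vector<std::vector<long>> context (all names and thresholds below are illustrative assumptions, not the project's code): when the flag is on and a value occurs rarely, each context row is duplicated with the rare value replaced by the unknown index, which is also why extractContext further down may end up with more than one context variant.

// Hypothetical, self-contained sketch of the "split unknown" idea.
// Names (appendWithOptionalSplit, UNKNOWN_INDEX, OCC_THRESHOLD) are illustrative only.
#include <cstdio>
#include <vector>

constexpr long UNKNOWN_INDEX = 0;  // assumed index of the <unknown> token
constexpr int  OCC_THRESHOLD = 1;  // assumed rarity threshold

// Append 'dictIndex' to every context row; if it is rare and splitting is enabled,
// also emit a copy of each row with <unknown> in that position.
void appendWithOptionalSplit(std::vector<std::vector<long>> & context,
                             long dictIndex, int nbOccurrences, bool splitUnknown)
{
  std::size_t originalSize = context.size();
  for (std::size_t i = 0; i < originalSize; ++i)
    context[i].push_back(dictIndex);

  if (splitUnknown and nbOccurrences <= OCC_THRESHOLD)
    for (std::size_t i = 0; i < originalSize; ++i)
    {
      auto variant = context[i];       // copy the row just extended
      variant.back() = UNKNOWN_INDEX;  // swap the rare value for <unknown>
      context.push_back(std::move(variant));
    }
}

int main()
{
  std::vector<std::vector<long>> context{{7, 12}};
  appendWithOptionalSplit(context, 42, /*nbOccurrences=*/1, /*splitUnknown=*/true);
  for (auto & row : context)
  {
    for (long v : row)
      std::printf("%ld ", v);
    std::printf("\n");
  }
  // Prints "7 12 42" and "7 12 0": the original row plus its <unknown> variant.
}
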
@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
   return inputSize;
 }
-void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;

@@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize()
   return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
 }
-void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   std::vector<long> focusedIndexes;

@@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
   context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
-  contextLSTM->addToContext(context, dict, config);
+  contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
   if (!rawInputLSTM.is_empty())
-    rawInputLSTM->addToContext(context, dict, config);
+    rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
   if (!treeEmbedding.is_empty())
-    treeEmbedding->addToContext(context, dict, config);
-  splitTransLSTM->addToContext(context, dict, config);
+    treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
+  splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
   for (auto & lstm : focusedLstms)
-    lstm->addToContext(context, dict, config);
-  if (!is_training() && context.size() > 1)
-    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
+    lstm->addToContext(context, dict, config, mustSplitUnknown());
+  if (!mustSplitUnknown() && context.size() > 1)
+    util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
   return context;
 }

@@ -2,3 +2,13 @@
 torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
+bool NeuralNetworkImpl::mustSplitUnknown() const
+{
+  return splitUnknown;
+}
+void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
+{
+  this->splitUnknown = splitUnknown;
+}

@@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize()
   return leftWindow + rightWindow + 1;
 }
-void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   if (leftWindow < 0 or rightWindow < 0)
     return;

@@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize()
   return maxNbTrans;
 }
-void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
+void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
 {
   auto & splitTransitions = config.getAppliableSplitTransitions();
   for (auto & contextElement : context)

@@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
   SubConfig config(goldConfig, goldConfig.getNbLines());
   machine.trainMode(false);
+  machine.splitUnknown(true);
   machine.setDictsState(Dict::State::Open);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
@@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
   SubConfig config(goldConfig, goldConfig.getNbLines());
   machine.trainMode(false);
+  machine.splitUnknown(false);
   machine.setDictsState(Dict::State::Closed);
   extractExamples(config, debug, dir, epoch, dynamicOracleInterval);

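Read end to end, the commit threads one boolean from the entry points down to the submodules: Decoder::decode and Trainer::createDevDataset turn splitting off, Trainer::createDataset turns it on for train-set extraction, and the network exposes the flag through mustSplitUnknown() instead of relying on is_training(). A compressed, hypothetical sketch of that plumbing (stand-in types; only the call pattern mirrors the diff above):

// Compressed sketch of the flag plumbing; classes are stand-ins, not the project's code.
#include <cstdio>

struct NeuralNetworkSketch
{
  bool splitUnknown{false};  // off by default
  void setSplitUnknown(bool v) { splitUnknown = v; }
  bool mustSplitUnknown() const { return splitUnknown; }
};

struct ReadingMachineSketch
{
  NeuralNetworkSketch nn;
  // ReadingMachine::splitUnknown forwards to the classifier's network.
  void splitUnknown(bool v) { nn.setSplitUnknown(v); }
};

int main()
{
  ReadingMachineSketch machine;

  machine.splitUnknown(true);   // train-set extraction: split rare/unknown values
  std::printf("train extraction: %d\n", machine.nn.mustSplitUnknown());

  machine.splitUnknown(false);  // dev-set extraction and decoding: never split
  std::printf("dev/decoding:     %d\n", machine.nn.mustSplitUnknown());
}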