Commit 28045459 authored by Franck Dary's avatar Franck Dary
Browse files

Split unknown values only when extracting the train dataset

parent 183f0297
......@@ -9,6 +9,7 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
{
torch::AutoGradMode useGrad(false);
machine.trainMode(false);
machine.splitUnknown(false);
machine.setDictsState(Dict::State::Closed);
machine.getStrategy().reset();
config.addPredicted(machine.getPredicted());
......
......@@ -47,6 +47,7 @@ class ReadingMachine
bool isPredicted(const std::string & columnName) const;
const std::set<std::string> & getPredicted() const;
void trainMode(bool isTrainMode);
void splitUnknown(bool splitUnknown);
void setDictsState(Dict::State state);
void saveBest() const;
void saveLast() const;
......
......@@ -182,6 +182,11 @@ void ReadingMachine::trainMode(bool isTrainMode)
classifier->getNN()->train(isTrainMode);
}
// Forward the split-unknown flag to the classifier's neural network.
// When enabled, context extraction may yield multiple context variants for
// rare values (see the mustSplitUnknown() check in extractContext); this is
// turned on only while extracting the train dataset (Trainer::createDataset)
// and off for decoding and dev-dataset extraction.
void ReadingMachine::splitUnknown(bool splitUnknown)
{
// Mirrors ReadingMachine::trainMode, which delegates to getNN()->train().
classifier->getNN()->setSplitUnknown(splitUnknown);
}
void ReadingMachine::setDictsState(Dict::State state)
{
for (auto & it : dicts)
......
......@@ -22,7 +22,7 @@ class ContextLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(ContextLSTM);
......
......@@ -21,7 +21,7 @@ class DepthLayerTreeEmbeddingImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(DepthLayerTreeEmbedding);
......
......@@ -20,7 +20,7 @@ class FocusedColumnLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(FocusedColumnLSTM);
......
......@@ -11,6 +11,10 @@ class NeuralNetworkImpl : public torch::nn::Module
static torch::Device device;
private :
bool splitUnknown{false};
protected :
static constexpr int maxNbEmbeddings = 150000;
......@@ -19,6 +23,8 @@ class NeuralNetworkImpl : public torch::nn::Module
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const = 0;
bool mustSplitUnknown() const;
void setSplitUnknown(bool splitUnknown);
};
TORCH_MODULE(NeuralNetwork);
......
......@@ -18,7 +18,7 @@ class RawInputLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(RawInputLSTM);
......
......@@ -18,7 +18,7 @@ class SplitTransLSTMImpl : public torch::nn::Module, public Submodule
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const override;
void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const override;
};
TORCH_MODULE(SplitTransLSTM);
......
......@@ -15,7 +15,7 @@ class Submodule
void setFirstInputIndex(std::size_t firstInputIndex);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const = 0;
};
#endif
......
......@@ -15,7 +15,7 @@ std::size_t ContextLSTMImpl::getInputSize()
return columns.size()*(bufferContext.size()+stackContext.size());
}
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool splitUnknown) const
{
std::vector<long> contextIndexes;
......@@ -31,8 +31,10 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
for (auto index : contextIndexes)
for (auto & col : columns)
if (index == -1)
{
for (auto & contextElement : context)
contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
else
{
int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
......@@ -40,7 +42,8 @@ void ContextLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dic
for (auto & contextElement : context)
contextElement.push_back(dictIndex);
if (is_training())
if (splitUnknown)
for (auto & targetCol : unknownValueColumns)
if (col == targetCol)
if (dict.getNbOccs(dictIndex) <= unknownValueThreshold)
......
......@@ -42,7 +42,7 @@ std::size_t DepthLayerTreeEmbeddingImpl::getInputSize()
return inputSize;
}
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void DepthLayerTreeEmbeddingImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
std::vector<long> focusedIndexes;
......
......@@ -24,7 +24,7 @@ std::size_t FocusedColumnLSTMImpl::getInputSize()
return (focusedBuffer.size()+focusedStack.size()) * maxNbElements;
}
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void FocusedColumnLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
std::vector<long> focusedIndexes;
......
......@@ -94,21 +94,21 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
context.back().emplace_back(dict.getIndexOrInsert(config.getState()));
contextLSTM->addToContext(context, dict, config);
contextLSTM->addToContext(context, dict, config, mustSplitUnknown());
if (!rawInputLSTM.is_empty())
rawInputLSTM->addToContext(context, dict, config);
rawInputLSTM->addToContext(context, dict, config, mustSplitUnknown());
if (!treeEmbedding.is_empty())
treeEmbedding->addToContext(context, dict, config);
treeEmbedding->addToContext(context, dict, config, mustSplitUnknown());
splitTransLSTM->addToContext(context, dict, config);
splitTransLSTM->addToContext(context, dict, config, mustSplitUnknown());
for (auto & lstm : focusedLstms)
lstm->addToContext(context, dict, config);
lstm->addToContext(context, dict, config, mustSplitUnknown());
if (!is_training() && context.size() > 1)
util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
if (!mustSplitUnknown() && context.size() > 1)
util::myThrow(fmt::format("Not in splitUnknown mode, yet context yields multiple variants (size={})", context.size()));
return context;
}
......
......@@ -2,3 +2,13 @@
torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCUDA : torch::kCPU);
// Whether context extraction should apply unknown-value splitting.
// Defaults to false (see the in-class initializer of splitUnknown);
// toggled via setSplitUnknown.
bool NeuralNetworkImpl::mustSplitUnknown() const
{
return splitUnknown;
}
// Record whether unknown-value splitting is active for this network.
// Callers enable it only while extracting the training dataset
// (Trainer::createDataset) and disable it everywhere else.
void NeuralNetworkImpl::setSplitUnknown(bool splitUnknown)
{
// Parameter shadows the member; qualify with this-> to assign the member.
this->splitUnknown = splitUnknown;
}
......@@ -20,7 +20,7 @@ std::size_t RawInputLSTMImpl::getInputSize()
return leftWindow + rightWindow + 1;
}
void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void RawInputLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
if (leftWindow < 0 or rightWindow < 0)
return;
......
......@@ -21,7 +21,7 @@ std::size_t SplitTransLSTMImpl::getInputSize()
return maxNbTrans;
}
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config) const
void SplitTransLSTMImpl::addToContext(std::vector<std::vector<long>> & context, Dict & dict, const Config & config, bool) const
{
auto & splitTransitions = config.getAppliableSplitTransitions();
for (auto & contextElement : context)
......
......@@ -10,6 +10,7 @@ void Trainer::createDataset(BaseConfig & goldConfig, bool debug, std::filesystem
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.splitUnknown(true);
machine.setDictsState(Dict::State::Open);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
......@@ -23,6 +24,7 @@ void Trainer::createDevDataset(BaseConfig & goldConfig, bool debug, std::filesys
SubConfig config(goldConfig, goldConfig.getNbLines());
machine.trainMode(false);
machine.splitUnknown(false);
machine.setDictsState(Dict::State::Closed);
extractExamples(config, debug, dir, epoch, dynamicOracleInterval);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment