Commit 26cc83e4 authored by Franck Dary

NeuralNetwork::extractContext can now generate multiple variants of context

parent 25ac06cb
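In short: extractContext now returns std::vector<std::vector<long>> (a list of context variants) instead of a single std::vector<long>. The trainer turns every variant into a training example, while the decoder expects exactly one variant and takes .back(). A minimal, self-contained sketch of that contract (illustrative names and values only, not the project's actual API):

// Illustrative sketch only (not part of the commit): the shape of the new
// extractContext contract. Names are hypothetical stand-ins.
#include <vector>

using Context = std::vector<long>;

// Before this commit the function returned a single Context; it now returns
// a list of variants of that context.
std::vector<Context> extractContext()
{
  return {{1, 2, 3}}; // base variant; training mode may append more
}

int main()
{
  // Training: every variant becomes its own example (all share one gold label).
  for (const auto & variant : extractContext())
    (void)variant;

  // Decoding: a single variant is expected, so the caller takes .back().
  Context single = extractContext().back();
  (void)single;
}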
@@ -25,10 +25,10 @@ void Decoder::decode(BaseConfig & config, std::size_t beamSize, bool debug, bool
       config.printForDebug(stderr);
     auto dictState = machine.getDict(config.getState()).getState();
-    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState()));
+    auto context = machine.getClassifier()->getNN()->extractContext(config, machine.getDict(config.getState())).back();
     machine.getDict(config.getState()).setState(dictState);
-    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone().to(NeuralNetworkImpl::device);
+    auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
     auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
     int chosenTransition = -1;
@@ -33,7 +33,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
 #endif
@@ -27,7 +27,7 @@ class NeuralNetworkImpl : public torch::nn::Module
   public :
   virtual torch::Tensor forward(torch::Tensor input) = 0;
-  virtual std::vector<long> extractContext(Config & config, Dict & dict) const;
+  virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
   std::vector<long> extractContextIndexes(const Config & config) const;
   int getContextSize() const;
   void setColumns(const std::vector<std::string> & columns);
@@ -23,7 +23,7 @@ class RLTNetworkImpl : public NeuralNetworkImpl
   RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
   torch::Tensor forward(torch::Tensor input) override;
-  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+  std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
 };
 #endif
@@ -74,44 +74,52 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
   return linear2(hiddenDropout(torch::relu(linear1(totalInput))));
 }
-std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   if (dict.size() >= maxNbEmbeddings)
     util::warning(fmt::format("dict.size()={} > maxNbEmbeddings={}", dict.size(), maxNbEmbeddings));
   std::vector<long> contextIndexes = extractContextIndexes(config);
-  std::vector<long> context;
+  std::vector<std::vector<long>> context;
+  context.emplace_back();
   if (rawInputSize > 0)
   {
     for (int i = 0; i < leftWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()-leftWindowRawInput+i))
-        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()-leftWindowRawInput+i))));
       else
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
     for (int i = 0; i <= rightWindowRawInput; i++)
       if (config.hasCharacter(config.getCharacterIndex()+i))
-        context.push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
+        context.back().push_back(dict.getIndexOrInsert(fmt::format("Letter({})", config.getLetter(config.getCharacterIndex()+i))));
      else
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        context.back().push_back(dict.getIndexOrInsert(Dict::nullValueStr));
   }
   for (auto index : contextIndexes)
     for (auto & col : columns)
       if (index == -1)
-        context.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        for (auto & contextElement : context)
+          contextElement.push_back(dict.getIndexOrInsert(Dict::nullValueStr));
       else
       {
         int dictIndex = dict.getIndexOrInsert(config.getAsFeature(col, index));
+        for (auto & contextElement : context)
+          contextElement.push_back(dictIndex);
+        if (is_training())
           if (col == "FORM" || col == "LEMMA")
             if (dict.getNbOccs(dictIndex) < unknownValueThreshold)
-              dictIndex = dict.getIndexOrInsert(Dict::unknownValueStr);
-        context.push_back(dictIndex);
+            {
+              context.emplace_back(context.back());
+              context.back().back() = dict.getIndexOrInsert(Dict::unknownValueStr);
+            }
       }
+  for (auto & contextElement : context)
   for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
   {
     auto & col = focusedColumns[colIndex];
@@ -140,7 +148,7 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
     if (index == -1)
     {
       for (int i = 0; i < maxNbElements[colIndex]; i++)
-        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+        contextElement.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
       continue;
     }
@@ -183,10 +191,13 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
       util::myThrow(fmt::format("elements.size ({}) != maxNbElements[colIndex ({},{})]", elements.size(), maxNbElements[colIndex], col));
     for (auto & element : elements)
-      context.emplace_back(dict.getIndexOrInsert(element));
+      contextElement.emplace_back(dict.getIndexOrInsert(element));
   }
   }
+  if (!is_training() && context.size() > 1)
+    util::myThrow(fmt::format("Not in training mode, yet context yields multiple variants (size={})", context.size()));
   return context;
 }
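The interesting part of the hunk above is the duplication step: when a FORM or LEMMA value has been seen fewer than unknownValueThreshold times, and only in training mode, an extra copy of the current variant is appended with that slot overwritten by the unknown-value index. A simplified sketch of that step, with a hypothetical helper and plain integers standing in for the Dict and Config lookups:

// Simplified sketch (hypothetical helper, not the project's API) of the
// variant-duplication step: the current value is appended to every existing
// variant, then a rare value additionally spawns one masked variant.
#include <vector>

void pushWithRareVariant(std::vector<std::vector<long>> & context,
                         long dictIndex, long nbOccs, long unknownIndex,
                         long unknownValueThreshold, bool training)
{
  for (auto & contextElement : context)
    contextElement.push_back(dictIndex);      // the value goes into every variant

  if (training && nbOccs < unknownValueThreshold)
  {
    context.emplace_back(context.back());     // copy the latest variant...
    context.back().back() = unknownIndex;     // ...and mask the rare value in the copy
  }
}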
@@ -35,7 +35,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config)
   return context;
 }
-std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> indexes = extractContextIndexes(config);
   std::vector<long> context;
@@ -47,7 +47,7 @@ std::vector<long> NeuralNetworkImpl::extractContext(Config & config, Dict & dict) const
     else
       context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, index)));
-  return context;
+  return {context};
 }
 int NeuralNetworkImpl::getContextSize() const
@@ -79,7 +79,7 @@ torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
   return linear2(torch::relu(linear1(representation)));
 }
-std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
+std::vector<std::vector<long>> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
 {
   std::vector<long> contextIndexes;
   std::stack<int> leftContext;
@@ -183,6 +183,6 @@ std::vector<long> RLTNetworkImpl::extractContext(Config & config, Dict & dict) const
     else
       context.push_back(dict.getIndexOrInsert(config.getAsFeature(col, l)));
-  return context;
+  return {context};
 }
@@ -48,19 +48,23 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::vector<torch:
       util::myThrow("No transition appliable !");
     }
+    std::vector<std::vector<long>> context;
     try
     {
-      auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-      contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
+      context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+      for (auto & element : context)
+        contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::TensorOptions(torch::kLong)).clone().to(NeuralNetworkImpl::device));
     } catch(std::exception & e)
     {
       util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
     }
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
-    auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
+    auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong).device(NeuralNetworkImpl::device));
     gold[0] = goldIndex;
+    for (auto & element : context)
       classes.emplace_back(gold);
     transition->apply(config);
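On the trainer side, each variant is converted to its own input tensor and the gold label is pushed once per variant, so that contexts and classes stay index-aligned. A rough sketch of that pattern (placeholder identifiers, device transfer omitted):

// Sketch of how the trainer consumes the variants: one input tensor per
// variant, one shared gold label pushed per variant.
#include <torch/torch.h>
#include <vector>

void appendExamples(std::vector<torch::Tensor> & contexts,
                    std::vector<torch::Tensor> & classes,
                    std::vector<std::vector<long>> & context, long goldIndex)
{
  auto gold = torch::zeros(1, torch::TensorOptions(torch::kLong));
  gold[0] = goldIndex;

  for (auto & element : context)
  {
    // from_blob does not take ownership of the buffer, hence the clone()
    contexts.emplace_back(torch::from_blob(element.data(), {(long)element.size()}, torch::kLong).clone());
    classes.emplace_back(gold); // keeps contexts and classes the same length
  }
}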