Commit 66560e71 authored by Franck Dary

Changed the way context indexes are handled by NeuralNetwork

parent 1ac29b7f
Showing with 179 additions and 171 deletions
......@@ -38,7 +38,7 @@ std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path di
std::string_view getFilenameFromPath(std::string_view s);
std::vector<std::string_view> split(std::string_view s, char delimiter);
std::vector<std::string> split(std::string_view s, char delimiter);
utf8string splitAsUtf8(std::string_view s);
......
......@@ -32,9 +32,9 @@ bool util::isIllegal(utf8char c)
return c == '\n' || c == '\t';
}
std::vector<std::string_view> util::split(std::string_view remaining, char delimiter)
std::vector<std::string> util::split(std::string_view remaining, char delimiter)
{
std::vector<std::string_view> result;
std::vector<std::string> result;
for (auto firstDelimiterIndex = remaining.find_first_of(delimiter); firstDelimiterIndex != std::string_view::npos; firstDelimiterIndex = remaining.find_first_of(delimiter))
{
......
......@@ -109,6 +109,7 @@ class Config
bool rawInputOnlySeparatorsLeft() const;
std::size_t getWordIndex() const;
std::size_t getCharacterIndex() const;
long getRelativeWordIndex(int relativeIndex) const;
const String & getHistory(int relativeIndex) const;
std::size_t getStack(int relativeIndex) const;
bool hasHistory(int relativeIndex) const;
......
......@@ -37,59 +37,72 @@ void Classifier::initNeuralNetwork(const std::string & topology)
"OneWord(focusedIndex) : Only use the word embedding of the focused word.",
[this,topology](auto sm)
{
this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm[1])));
this->nn.reset(new OneWordNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1))));
}
},
{
std::regex("ConcatWords\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
"ConcatWords(leftBorder,rightBorder,nbStack) : Concatenate embeddings of words in context.",
std::regex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)"),
"ConcatWords({bufferContext},{stackContext}) : Concatenate embeddings of words in context.",
[this,topology](auto sm)
{
this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
std::vector<int> bufferContext, stackContext;
for (auto s : util::split(sm.str(1), ','))
bufferContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(2), ','))
stackContext.emplace_back(std::stoi(s));
this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), bufferContext, stackContext));
}
},
{
std::regex("CNN\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"CNN(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
std::regex("CNN\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"CNN(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
[this,topology](auto sm)
{
std::vector<int> focusedBuffer, focusedStack, maxNbElements;
std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
std::vector<std::string> focusedColumns, columns;
for (auto s : util::split(std::string(sm[5]), ','))
for (auto s : util::split(sm.str(2), ','))
bufferContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(3), ','))
stackContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(4), ','))
columns.emplace_back(s);
for (auto s : util::split(std::string(sm[6]), ','))
focusedBuffer.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[7]), ','))
focusedStack.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[8]), ','))
for (auto s : util::split(sm.str(5), ','))
focusedBuffer.push_back(std::stoi(s));
for (auto s : util::split(sm.str(6), ','))
focusedStack.push_back(std::stoi(s));
for (auto s : util::split(sm.str(7), ','))
focusedColumns.emplace_back(s);
for (auto s : util::split(std::string(sm[9]), ','))
maxNbElements.push_back(std::stoi(std::string(s)));
for (auto s : util::split(sm.str(8), ','))
maxNbElements.push_back(std::stoi(s));
if (focusedColumns.size() != maxNbElements.size())
util::myThrow("focusedColumns.size() != maxNbElements.size()");
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
}
},
{
std::regex("LSTM\\(([+\\-]?\\d+),(\\d+),(\\d+),(\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"LSTM(unknownValueThreshold,leftBorder,rightBorder,nbStack,{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
std::regex("LSTM\\(([+\\-]?\\d+),\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\},\\{(.*)\\}\\,([+\\-]?\\d+)\\,([+\\-]?\\d+)\\)"),
"LSTM(unknownValueThreshold,{bufferContext},{stackContext},{columns},{focusedBuffer},{focusedStack},{focusedColumns},{maxNbElements},leftBorderRawInput, rightBorderRawInput) : CNN to capture context.",
[this,topology](auto sm)
{
std::vector<int> focusedBuffer, focusedStack, maxNbElements;
std::vector<int> focusedBuffer, focusedStack, maxNbElements, bufferContext, stackContext;
std::vector<std::string> focusedColumns, columns;
for (auto s : util::split(std::string(sm[5]), ','))
for (auto s : util::split(sm.str(2), ','))
bufferContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(3), ','))
stackContext.emplace_back(std::stoi(s));
for (auto s : util::split(sm.str(4), ','))
columns.emplace_back(s);
for (auto s : util::split(std::string(sm[6]), ','))
focusedBuffer.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[7]), ','))
focusedStack.push_back(std::stoi(std::string(s)));
for (auto s : util::split(std::string(sm[8]), ','))
for (auto s : util::split(sm.str(5), ','))
focusedBuffer.push_back(std::stoi(s));
for (auto s : util::split(sm.str(6), ','))
focusedStack.push_back(std::stoi(s));
for (auto s : util::split(sm.str(7), ','))
focusedColumns.emplace_back(s);
for (auto s : util::split(std::string(sm[9]), ','))
maxNbElements.push_back(std::stoi(std::string(s)));
for (auto s : util::split(sm.str(8), ','))
maxNbElements.push_back(std::stoi(s));
if (focusedColumns.size() != maxNbElements.size())
util::myThrow("focusedColumns.size() != maxNbElements.size()");
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3]), std::stoi(sm[4]), columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm[10]), std::stoi(sm[11])));
this->nn.reset(new LSTMNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), bufferContext, stackContext, columns, focusedBuffer, focusedStack, focusedColumns, maxNbElements, std::stoi(sm.str(9)), std::stoi(sm.str(10))));
}
},
{
......@@ -97,19 +110,28 @@ void Classifier::initNeuralNetwork(const std::string & topology)
"RLT(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
[this,topology](auto sm)
{
this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
this->nn.reset(new RLTNetworkImpl(this->transitionSet->size(), std::stoi(sm.str(1)), std::stoi(sm.str(2)), std::stoi(sm.str(3))));
}
},
};
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
for (auto & initializer : initializers)
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
try
{
this->nn->to(NeuralNetworkImpl::device);
return;
if (util::doIfNameMatch(std::get<0>(initializer),topology,std::get<2>(initializer)))
{
this->nn->to(NeuralNetworkImpl::device);
return;
}
}
catch (std::exception & e)
{
errorMessage = fmt::format("Caught({}) {}", e.what(), errorMessage);
break;
}
std::string errorMessage = fmt::format("Unknown neural network '{}', available networks :\n", topology);
for (auto & initializer : initializers)
errorMessage += std::get<1>(initializer) + "\n";
......
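For illustration, a minimal stand-alone check of the new brace-list topology syntax, using the ConcatWords regex introduced above (the concrete indexes are invented for this example and are not part of the commit):

#include <iostream>
#include <regex>
#include <string>

int main()
{
  // Hypothetical topology string: five buffer positions and two stack positions.
  std::string topology = "ConcatWords({-3,-2,-1,0,1},{0,1})";
  std::regex concatWordsRegex("ConcatWords\\(\\{(.*)\\},\\{(.*)\\}\\)");
  std::smatch sm;
  if (std::regex_match(topology, sm, concatWordsRegex))
    std::cout << "bufferContext=" << sm.str(1) << " stackContext=" << sm.str(2) << "\n";
  // Prints: bufferContext=-3,-2,-1,0,1 stackContext=0,1
}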
......@@ -573,3 +573,29 @@ void Config::addMissingColumns()
}
}
long Config::getRelativeWordIndex(int relativeIndex) const
{
if (relativeIndex < 0)
{
for (int index = getWordIndex()-1, counter = 0; has(0,index,0); --index)
if (!isCommentPredicted(index))
{
--counter;
if (counter == relativeIndex)
return index;
}
}
else
{
for (int index = getWordIndex(), counter = 0; has(0,index,0); ++index)
if (!isCommentPredicted(index))
{
if (counter == relativeIndex)
return index;
++counter;
}
}
return -1;
}
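A minimal stand-alone sketch of the lookup that getRelativeWordIndex performs, with a toy comment mask standing in for Config (the helper name, the mask, and the sample values are illustrative only):

#include <iostream>
#include <vector>

// Toy version: 'true' marks a predicted comment line that must be skipped.
long relativeWordIndex(const std::vector<bool> & isComment, long wordIndex, int relativeIndex)
{
  if (relativeIndex < 0)
  {
    for (long index = wordIndex-1, counter = 0; index >= 0; --index)
      if (!isComment[index])
        if (--counter == relativeIndex)
          return index;
  }
  else
  {
    for (long index = wordIndex, counter = 0; index < (long)isComment.size(); ++index)
      if (!isComment[index])
      {
        if (counter == relativeIndex)
          return index;
        ++counter;
      }
  }
  return -1; // out of range, as in the commit
}

int main()
{
  std::vector<bool> isComment{false, false, true, false, false};
  std::cout << relativeWordIndex(isComment, 3, -1) << "\n"; // 1 : the comment at index 2 is skipped
  std::cout << relativeWordIndex(isComment, 3, +1) << "\n"; // 4
  std::cout << relativeWordIndex(isComment, 3, -4) << "\n"; // -1 : past the start of the buffer
}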
......@@ -11,8 +11,6 @@ class CNNNetworkImpl : public NeuralNetworkImpl
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold;
std::vector<int> focusedBufferIndexes;
std::vector<int> focusedStackIndexes;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
int leftWindowRawInput;
......@@ -31,7 +29,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
public :
CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
......
......@@ -14,7 +14,7 @@ class ConcatWordsNetworkImpl : public NeuralNetworkImpl
public :
ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext);
torch::Tensor forward(torch::Tensor input) override;
};
......
......@@ -10,8 +10,6 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
static constexpr int maxNbEmbeddings = 50000;
int unknownValueThreshold;
std::vector<int> focusedBufferIndexes;
std::vector<int> focusedStackIndexes;
std::vector<std::string> focusedColumns;
std::vector<int> maxNbElements;
int leftWindowRawInput;
......@@ -30,7 +28,7 @@ class LSTMNetworkImpl : public NeuralNetworkImpl
public :
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const override;
};
......
......@@ -13,22 +13,25 @@ class NeuralNetworkImpl : public torch::nn::Module
protected :
unsigned leftBorder{5};
unsigned rightBorder{5};
unsigned nbStackElements{2};
std::vector<std::string> columns{"FORM"};
std::vector<int> bufferContext{-3,-2,-1,0,1};
std::vector<int> stackContext{};
std::vector<int> bufferFocused{};
std::vector<int> stackFocused{};
protected :
void setRightBorder(int rightBorder);
void setLeftBorder(int leftBorder);
void setNbStackElements(int nbStackElements);
void setBufferContext(const std::vector<int> & bufferContext);
void setStackContext(const std::vector<int> & stackContext);
void setBufferFocused(const std::vector<int> & bufferFocused);
void setStackFocused(const std::vector<int> & stackFocused);
public :
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config, Dict & dict) const;
std::vector<long> extractContextIndexes(const Config & config) const;
std::vector<long> extractFocusedIndexes(const Config & config) const;
int getContextSize() const;
void setColumns(const std::vector<std::string> & columns);
};
......
......@@ -9,7 +9,6 @@ class OneWordNetworkImpl : public NeuralNetworkImpl
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear{nullptr};
int focusedIndex;
public :
......
......@@ -11,6 +11,8 @@ class RLTNetworkImpl : public NeuralNetworkImpl
static inline std::vector<long> focusedBufferIndexes{0,1,2};
static inline std::vector<long> focusedStackIndexes{0,1};
int leftBorder, rightBorder;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear1{nullptr};
torch::nn::Linear linear2{nullptr};
......
#include "CNNNetwork.hpp"
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 1024;
constexpr int nbFiltersContext = 512;
constexpr int nbFiltersFocused = 64;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
setBufferContext(bufferContext);
setStackContext(stackContext);
setColumns(columns);
setBufferFocused(focusedBufferIndexes);
setStackFocused(focusedStackIndexes);
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
......@@ -42,7 +43,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
auto embeddings = embeddingsDropout(wordEmbeddings(input));
auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
auto context = embeddings.narrow(1, rawInputSize, getContextSize());
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
auto elementsEmbeddings = embeddings.narrow(1, rawInputSize+context.size(1), input.size(1)-(rawInputSize+context.size(1)));
......@@ -59,7 +60,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
long nbElements = maxNbElements[i];
for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
{
auto cnnInput = elementsEmbeddings.narrow(1, curIndex, nbElements).unsqueeze(1);
curIndex += nbElements;
......@@ -119,30 +120,13 @@ std::vector<std::vector<long>> CNNNetworkImpl::extractContext(Config & config, D
}
}
std::vector<long> focusedIndexes = extractFocusedIndexes(config);
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
std::vector<int> focusedIndexes;
for (auto relIndex : focusedBufferIndexes)
{
int index = relIndex + leftBorder;
if (index < 0 || index >= (int)contextIndexes.size())
focusedIndexes.push_back(-1);
else
focusedIndexes.push_back(contextIndexes[index]);
}
for (auto index : focusedStackIndexes)
{
if (!config.hasStack(index))
focusedIndexes.push_back(-1);
else if (!config.has(col, config.getStack(index), 0))
focusedIndexes.push_back(-1);
else
focusedIndexes.push_back(config.getStack(index));
}
for (auto index : focusedIndexes)
{
if (index == -1)
......
#include "ConcatWordsNetwork.hpp"
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
ConcatWordsNetworkImpl::ConcatWordsNetworkImpl(int nbOutputs, const std::vector<int> & bufferContext, const std::vector<int> & stackContext)
{
constexpr int embeddingsSize = 100;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 500;
setBufferContext(bufferContext);
setStackContext(stackContext);
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, 500));
linear2 = register_module("linear2", torch::nn::Linear(500, nbOutputs));
linear1 = register_module("linear1", torch::nn::Linear(getContextSize()*embeddingsSize, hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
dropout = register_module("dropout", torch::nn::Dropout(0.3));
}
torch::Tensor ConcatWordsNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence, embeddings}
auto wordsAsEmb = dropout(wordEmbeddings(input));
// reshaped dim = {batch, sequence of embeddings}
auto reshaped = wordsAsEmb.dim() == 3 ? torch::reshape(wordsAsEmb, {wordsAsEmb.size(0), wordsAsEmb.size(1)*wordsAsEmb.size(2)}) : torch::reshape(wordsAsEmb, {wordsAsEmb.size(0)*wordsAsEmb.size(1)});
return linear2(torch::relu(linear1(reshaped)));
if (input.dim() == 1)
input = input.unsqueeze(0);
auto wordsAsEmb = dropout(wordEmbeddings(input).view({input.size(0), -1}));
return linear2(torch::relu(linear1(wordsAsEmb)));
}
#include "LSTMNetwork.hpp"
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, std::vector<int> bufferContext, std::vector<int> stackContext, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : unknownValueThreshold(unknownValueThreshold), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 1024;
......@@ -8,10 +8,11 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l
constexpr int focusedLSTMSize = 64;
constexpr int rawInputLSTMSize = 16;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
setBufferContext(bufferContext);
setStackContext(stackContext);
setColumns(columns);
setBufferFocused(focusedBufferIndexes);
setStackFocused(focusedStackIndexes);
rawInputSize = leftWindowRawInput + rightWindowRawInput + 1;
if (leftWindowRawInput < 0 or rightWindowRawInput < 0)
......@@ -34,7 +35,7 @@ LSTMNetworkImpl::LSTMNetworkImpl(int nbOutputs, int unknownValueThreshold, int l
for (auto & col : focusedColumns)
{
lstms.emplace_back(register_module(fmt::format("LSTM_{}", col), torch::nn::LSTM(torch::nn::LSTMOptions(embeddingsSize, focusedLSTMSize).batch_first(true).bidirectional(true))));
totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (focusedBufferIndexes.size()+focusedStackIndexes.size());
totalLSTMOutputSize += lstms.back()->options.hidden_size() * (lstms.back()->options.bidirectional() ? 4 : 1) * (bufferFocused.size()+stackFocused.size());
}
linear1 = register_module("linear1", torch::nn::Linear(totalLSTMOutputSize, hiddenSize));
......@@ -48,7 +49,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
auto embeddings = embeddingsDropout(wordEmbeddings(input));
auto context = embeddings.narrow(1, rawInputSize, columns.size()*(1+leftBorder+rightBorder));
auto context = embeddings.narrow(1, rawInputSize, getContextSize());
context = context.view({context.size(0), context.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()});
......@@ -67,7 +68,7 @@ torch::Tensor LSTMNetworkImpl::forward(torch::Tensor input)
for (unsigned int i = 0; i < focusedColumns.size(); i++)
{
long nbElements = maxNbElements[i];
for (unsigned int focused = 0; focused < focusedBufferIndexes.size()+focusedStackIndexes.size(); focused++)
for (unsigned int focused = 0; focused < bufferFocused.size()+stackFocused.size(); focused++)
{
auto lstmInput = elementsEmbeddings.narrow(1, curIndex, nbElements);
curIndex += nbElements;
......@@ -136,30 +137,13 @@ std::vector<std::vector<long>> LSTMNetworkImpl::extractContext(Config & config,
}
}
std::vector<long> focusedIndexes = extractFocusedIndexes(config);
for (auto & contextElement : context)
for (unsigned int colIndex = 0; colIndex < focusedColumns.size(); colIndex++)
{
auto & col = focusedColumns[colIndex];
std::vector<int> focusedIndexes;
for (auto relIndex : focusedBufferIndexes)
{
int index = relIndex + leftBorder;
if (index < 0 || index >= (int)contextIndexes.size())
focusedIndexes.push_back(-1);
else
focusedIndexes.push_back(contextIndexes[index]);
}
for (auto index : focusedStackIndexes)
{
if (!config.hasStack(index))
focusedIndexes.push_back(-1);
else if (!config.has(col, config.getStack(index), 0))
focusedIndexes.push_back(-1);
else
focusedIndexes.push_back(config.getStack(index));
}
for (auto index : focusedIndexes)
{
if (index == -1)
......
......@@ -4,31 +4,30 @@ torch::Device NeuralNetworkImpl::device(torch::cuda::is_available() ? torch::kCU
std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config) const
{
std::stack<long> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
if (!config.isCommentPredicted(index))
leftContext.push(index);
std::vector<long> context;
while (context.size() < leftBorder-leftContext.size())
context.emplace_back(-1);
while (!leftContext.empty())
{
context.emplace_back(leftContext.top());
leftContext.pop();
}
for (int index : bufferContext)
context.emplace_back(config.getRelativeWordIndex(index));
for (int index : stackContext)
if (config.hasStack(index))
context.emplace_back(config.getStack(index));
else
context.emplace_back(-1);
return context;
}
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
if (!config.isCommentPredicted(index))
context.emplace_back(index);
std::vector<long> NeuralNetworkImpl::extractFocusedIndexes(const Config & config) const
{
std::vector<long> context;
while (context.size() < leftBorder+rightBorder+1)
context.emplace_back(-1);
for (int index : bufferFocused)
context.emplace_back(config.getRelativeWordIndex(index));
for (unsigned int i = 0; i < nbStackElements; i++)
if (config.hasStack(i))
context.emplace_back(config.getStack(i));
for (int index : stackFocused)
if (config.hasStack(index))
context.emplace_back(config.getStack(index));
else
context.emplace_back(-1);
......@@ -52,22 +51,27 @@ std::vector<std::vector<long>> NeuralNetworkImpl::extractContext(Config & config
int NeuralNetworkImpl::getContextSize() const
{
return columns.size()*(1 + leftBorder + rightBorder + nbStackElements);
return columns.size()*(bufferContext.size()+stackContext.size());
}
void NeuralNetworkImpl::setBufferContext(const std::vector<int> & bufferContext)
{
this->bufferContext = bufferContext;
}
void NeuralNetworkImpl::setRightBorder(int rightBorder)
void NeuralNetworkImpl::setStackContext(const std::vector<int> & stackContext)
{
this->rightBorder = rightBorder;
this->stackContext = stackContext;
}
void NeuralNetworkImpl::setLeftBorder(int leftBorder)
void NeuralNetworkImpl::setBufferFocused(const std::vector<int> & bufferFocused)
{
this->leftBorder = leftBorder;
this->bufferFocused = bufferFocused;
}
void NeuralNetworkImpl::setNbStackElements(int nbStackElements)
void NeuralNetworkImpl::setStackFocused(const std::vector<int> & stackFocused)
{
this->nbStackElements = nbStackElements;
this->stackFocused = stackFocused;
}
void NeuralNetworkImpl::setColumns(const std::vector<std::string> & columns)
......
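As a worked check of the new context-size arithmetic (getContextSize now counts one embedding cell per column for every buffer or stack context element), a self-contained sketch with example values not taken from the commit:

#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Example configuration: two columns, five buffer positions, two stack positions.
  std::vector<std::string> columns{"FORM", "UPOS"};
  std::vector<int> bufferContext{-3,-2,-1,0,1};
  std::vector<int> stackContext{0,1};
  // Mirrors NeuralNetworkImpl::getContextSize(): columns.size()*(bufferContext.size()+stackContext.size())
  std::size_t contextSize = columns.size()*(bufferContext.size()+stackContext.size());
  std::cout << contextSize << "\n"; // 2*(5+2) == 14 embedding indexes per configuration
}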
......@@ -2,36 +2,22 @@
OneWordNetworkImpl::OneWordNetworkImpl(int nbOutputs, int focusedIndex)
{
constexpr int embeddingsSize = 30;
constexpr int embeddingsSize = 64;
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize)));
linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
int leftBorder = 0;
int rightBorder = 0;
if (focusedIndex < 0)
leftBorder = -focusedIndex;
if (focusedIndex > 0)
rightBorder = focusedIndex;
this->focusedIndex = focusedIndex <= 0 ? 0 : focusedIndex;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(0);
setBufferContext({focusedIndex});
setStackContext({});
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear = register_module("linear", torch::nn::Linear(getContextSize()*embeddingsSize, nbOutputs));
}
torch::Tensor OneWordNetworkImpl::forward(torch::Tensor input)
{
// input dim = {batch, sequence, embeddings}
auto wordsAsEmb = wordEmbeddings(input);
auto reshaped = wordsAsEmb;
// reshaped dim = {sequence, batch, embeddings}
if (reshaped.dim() == 3)
reshaped = wordsAsEmb.permute({1,0,2});
auto res = linear(reshaped[focusedIndex]);
if (input.dim() == 1)
input = input.unsqueeze(0);
auto wordAsEmb = wordEmbeddings(input).view({input.size(0),-1});
auto res = linear(wordAsEmb);
return res;
}
......
......@@ -7,9 +7,11 @@ RLTNetworkImpl::RLTNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
constexpr int treeEmbeddingsSize = 256;
constexpr int hiddenSize = 500;
setLeftBorder(leftBorder);
setRightBorder(rightBorder);
setNbStackElements(nbStackElements);
//TODO handle these contexts
this->leftBorder = leftBorder;
this->rightBorder = rightBorder;
setBufferContext({});
setStackContext({});
setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
......@@ -27,10 +29,10 @@ torch::Tensor RLTNetworkImpl::forward(torch::Tensor input)
input = input.unsqueeze(0);
auto focusedIndexes = input.narrow(1, 0, focusedBufferIndexes.size()+focusedStackIndexes.size());
auto computeOrder = input.narrow(1, focusedIndexes.size(1), leftBorder+rightBorder+1);
auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(leftBorder+rightBorder+1));
auto computeOrder = input.narrow(1, focusedIndexes.size(1), getContextSize()/columns.size());
auto childsFlat = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1), maxNbChilds*(getContextSize()/columns.size()));
auto childs = torch::reshape(childsFlat, {childsFlat.size(0), computeOrder.size(1), maxNbChilds});
auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), columns.size()*(leftBorder+rightBorder+1));
auto wordIndexes = input.narrow(1, focusedIndexes.size(1)+computeOrder.size(1)+childsFlat.size(1), getContextSize());
auto baseEmbeddings = wordEmbeddings(wordIndexes);
auto concatBaseEmbeddings = torch::reshape(baseEmbeddings, {baseEmbeddings.size(0), (int)baseEmbeddings.size(1)/(int)columns.size(), (int)baseEmbeddings.size(2)*(int)columns.size()});
auto vectorRepresentations = vectorBiLSTM(concatBaseEmbeddings).output;
......