Skip to content
Snippets Groups Projects
Commit 38571427 authored by Franck Dary's avatar Franck Dary
Browse files
parents f93305a5 14bcdc4e
No related branches found
No related tags found
No related merge requests found
...@@ -49,7 +49,9 @@ class Action ...@@ -49,7 +49,9 @@ class Action
static Action moveWordIndex(int movement); static Action moveWordIndex(int movement);
static Action moveCharacterIndex(int movement); static Action moveCharacterIndex(int movement);
static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis); static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis); static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis);
static Action addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition);
static Action pushWordIndexOnStack(); static Action pushWordIndexOnStack();
static Action popStack(); static Action popStack();
static Action emptyStack(); static Action emptyStack();
......
...@@ -17,11 +17,13 @@ class Transition ...@@ -17,11 +17,13 @@ class Transition
private : private :
void initWrite(std::string colName, std::string object, std::string index, std::string value); void initWrite(std::string colName, std::string object, std::string index, std::string value);
void initAdd(std::string colName, std::string object, std::string index, std::string value);
void initShift(); void initShift();
void initLeft(std::string label); void initLeft(std::string label);
void initRight(std::string label); void initRight(std::string label);
void initReduce(); void initReduce();
void initEOS(); void initEOS();
void initNothing();
public : public :
......
...@@ -87,6 +87,79 @@ Action Action::addHypothesis(const std::string & colName, std::size_t lineIndex, ...@@ -87,6 +87,79 @@ Action Action::addHypothesis(const std::string & colName, std::size_t lineIndex,
return {Type::Write, apply, undo, appliable}; return {Type::Write, apply, undo, appliable};
} }
Action Action::addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition)
{
auto apply = [colName, lineIndex, addition](Config & config, Action &)
{
auto & current = config.getLastNotEmptyHyp(colName, lineIndex);
current = util::isEmpty(current) ? addition : '|' + addition;
};
auto undo = [colName, lineIndex](Config & config, Action &)
{
std::string newValue = config.getLastNotEmpty(colName, lineIndex);
while (!newValue.empty() and newValue.back() == '|')
newValue.pop_back();
if (!newValue.empty())
newValue.pop_back();
config.getLastNotEmpty(colName, lineIndex) = newValue;
};
auto appliable = [colName, lineIndex, addition](const Config & config, const Action &)
{
if (!config.has(colName, lineIndex, 0))
return false;
auto & current = config.getLastNotEmptyHypConst(colName, lineIndex);
auto splited = util::split(current.get(), '|');
for (auto & part : splited)
if (part == addition)
return false;
return true;
};
return {Type::Write, apply, undo, appliable};
}
Action Action::addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition)
{
auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else
lineIndex = config.getStack(relativeIndex);
return addToHypothesis(colName, lineIndex, addition).apply(config, a);
};
auto undo = [colName, object, relativeIndex](Config & config, Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else
lineIndex = config.getStack(relativeIndex);
return addToHypothesis(colName, lineIndex, "").undo(config, a);
};
auto appliable = [colName, object, relativeIndex, addition](const Config & config, const Action & a)
{
int lineIndex = 0;
if (object == Object::Buffer)
lineIndex = config.getWordIndex() + relativeIndex;
else if (config.hasStack(relativeIndex))
lineIndex = config.getStack(relativeIndex);
else
return false;
return addToHypothesis(colName, lineIndex, addition).appliable(config, a);
};
return {Type::Write, apply, undo, appliable};
}
Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis) Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis)
{ {
auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a) auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a)
......
...@@ -5,11 +5,13 @@ Transition::Transition(const std::string & name) ...@@ -5,11 +5,13 @@ Transition::Transition(const std::string & name)
{ {
std::regex nameRegex("(<(.+)> )?(.+)"); std::regex nameRegex("(<(.+)> )?(.+)");
std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)"); std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
std::regex addRegex("ADD ([bs])\\.(.+) (.+) (.+)");
std::regex shiftRegex("SHIFT"); std::regex shiftRegex("SHIFT");
std::regex reduceRegex("REDUCE"); std::regex reduceRegex("REDUCE");
std::regex leftRegex("LEFT (.+)"); std::regex leftRegex("LEFT (.+)");
std::regex rightRegex("RIGHT (.+)"); std::regex rightRegex("RIGHT (.+)");
std::regex eosRegex("EOS"); std::regex eosRegex("EOS");
std::regex nothingRegex("NOTHING");
try try
{ {
...@@ -22,6 +24,8 @@ Transition::Transition(const std::string & name) ...@@ -22,6 +24,8 @@ Transition::Transition(const std::string & name)
if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);})) if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
return; return;
if (util::doIfNameMatch(addRegex, this->name, [this](auto sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}))
return;
if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();})) if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();}))
return; return;
if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();})) if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();}))
...@@ -32,6 +36,8 @@ Transition::Transition(const std::string & name) ...@@ -32,6 +36,8 @@ Transition::Transition(const std::string & name)
return; return;
if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();})) if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();}))
return; return;
if (util::doIfNameMatch(nothingRegex, this->name, [this](auto){initNothing();}))
return;
throw std::invalid_argument("no match"); throw std::invalid_argument("no match");
...@@ -89,6 +95,39 @@ void Transition::initWrite(std::string colName, std::string object, std::string ...@@ -89,6 +95,39 @@ void Transition::initWrite(std::string colName, std::string object, std::string
}; };
} }
void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
{
auto objectValue = Action::str2object(object);
int indexValue = std::stoi(index);
sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
cost = [colName, objectValue, indexValue, value](const Config & config)
{
int lineIndex = 0;
if (objectValue == Action::Object::Buffer)
lineIndex = config.getWordIndex() + indexValue;
else
lineIndex = config.getStack(indexValue);
auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
for (auto & part : gold)
if (part == value)
return 0;
return 1;
};
}
void Transition::initNothing()
{
cost = [](const Config &)
{
return 0;
};
}
void Transition::initShift() void Transition::initShift()
{ {
sequence.emplace_back(Action::pushWordIndexOnStack()); sequence.emplace_back(Action::pushWordIndexOnStack());
......
...@@ -8,6 +8,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl ...@@ -8,6 +8,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
private : private :
static inline std::vector<long> focusedBufferIndexes{0,1}; static inline std::vector<long> focusedBufferIndexes{0,1};
static inline std::vector<long> focusedStackIndexes{0,1};
static inline std::vector<long> windowSizes{2,3,4}; static inline std::vector<long> windowSizes{2,3,4};
static constexpr unsigned int maxNbLetters = 10; static constexpr unsigned int maxNbLetters = 10;
......
...@@ -13,7 +13,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i ...@@ -13,7 +13,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
setColumns({"FORM", "UPOS"}); setColumns({"FORM", "UPOS"});
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize))); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize)); linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs)); linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
for (auto & windowSize : windowSizes) for (auto & windowSize : windowSizes)
{ {
...@@ -28,7 +28,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -28,7 +28,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
input = input.unsqueeze(0); input = input.unsqueeze(0);
auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder)); auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*focusedBufferIndexes.size()); auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));
auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1); auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1); auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
...@@ -43,6 +43,14 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input) ...@@ -43,6 +43,14 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
auto pooled = torch::max_pool1d(convOut, convOut.size(2)); auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled); windows.emplace_back(pooled);
} }
for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
for (unsigned int i = 0; i < lettersCNNs.size(); i++)
{
auto input = permuted[focusedBufferIndexes.size()+word];
auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
auto pooled = torch::max_pool1d(convOut, convOut.size(2));
windows.emplace_back(pooled);
}
auto lettersCnnOut = torch::cat(windows, 2); auto lettersCnnOut = torch::cat(windows, 2);
lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1}); lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
...@@ -133,6 +141,25 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c ...@@ -133,6 +141,25 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
} }
} }
for (auto index : focusedStackIndexes)
{
util::utf8string letters;
if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
for (unsigned int i = 0; i < maxNbLetters; i++)
{
if (i < letters.size())
{
std::string sLetter = fmt::format("Letter({})", letters[i]);
context.emplace_back(dict.getIndexOrInsert(sLetter));
}
else
{
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
}
}
}
return context; return context;
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment