diff --git a/reading_machine/include/Action.hpp b/reading_machine/include/Action.hpp
index a20f68a3d252fe3add08bffa6e80b4408d9ed549..994fe9cad4ecf07c810857beaeb7816788731d74 100644
--- a/reading_machine/include/Action.hpp
+++ b/reading_machine/include/Action.hpp
@@ -49,7 +49,9 @@ class Action
   static Action moveWordIndex(int movement);
   static Action moveCharacterIndex(int movement);
   static Action addHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & hypothesis);
+  static Action addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition);
   static Action addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis);
+  static Action addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition);
   static Action pushWordIndexOnStack();
   static Action popStack();
   static Action emptyStack();
diff --git a/reading_machine/include/Transition.hpp b/reading_machine/include/Transition.hpp
index c3a4589170208ed897b1eebca731c74edf399dd1..c7309a6c6a9601867b48b4b3580ef3d275623890 100644
--- a/reading_machine/include/Transition.hpp
+++ b/reading_machine/include/Transition.hpp
@@ -17,11 +17,13 @@ class Transition
   private :
 
   void initWrite(std::string colName, std::string object, std::string index, std::string value);
+  void initAdd(std::string colName, std::string object, std::string index, std::string value);
   void initShift();
   void initLeft(std::string label);
   void initRight(std::string label);
   void initReduce();
   void initEOS();
+  void initNothing();
 
   public :
 
diff --git a/reading_machine/src/Action.cpp b/reading_machine/src/Action.cpp
index 6cf16785636a7d8e1a2b820ae8217292e0331511..ca954735fd571324a78a3d91c38a22a7b9c781c4 100644
--- a/reading_machine/src/Action.cpp
+++ b/reading_machine/src/Action.cpp
@@ -87,6 +87,96 @@ Action Action::addHypothesis(const std::string & colName, std::size_t lineIndex,
   return {Type::Write, apply, undo, appliable};
 }
 
+/// Builds a Write action that appends `addition` to the hypothesis of column
+/// `colName` at line `lineIndex`; elements are separated by '|'.
+///
+/// apply replaces an empty hypothesis with `addition` and otherwise appends
+/// "|addition". undo removes the last '|'-separated element, so it needs no
+/// knowledge of what was appended. appliable is false when
+/// config.has(colName, lineIndex, 0) fails or `addition` is already an element.
+Action Action::addToHypothesis(const std::string & colName, std::size_t lineIndex, const std::string & addition)
+{
+  auto apply = [colName, lineIndex, addition](Config & config, Action &)
+  {
+    auto & current = config.getLastNotEmptyHyp(colName, lineIndex);
+    // Copy first: the previous content must survive the assignment below.
+    const std::string existing = current;
+    current = util::isEmpty(current) ? addition : existing + '|' + addition;
+  };
+
+  auto undo = [colName, lineIndex](Config & config, Action &)
+  {
+    // Work on the same cell apply wrote to.
+    auto & current = config.getLastNotEmptyHyp(colName, lineIndex);
+    const std::string value = current;
+    // No separator means the appended element was the whole value.
+    const auto separator = value.rfind('|');
+    current = separator == std::string::npos ? std::string() : value.substr(0, separator);
+  };
+
+  auto appliable = [colName, lineIndex, addition](const Config & config, const Action &)
+  {
+    if (!config.has(colName, lineIndex, 0))
+      return false;
+    auto & current = config.getLastNotEmptyHypConst(colName, lineIndex);
+    // Refuse to add an element that is already present.
+    auto splited = util::split(current.get(), '|');
+    for (auto & part : splited)
+      if (part == addition)
+        return false;
+    return true;
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
+/// Same as addToHypothesis, except that the target line is resolved each time
+/// one of the callbacks runs: word index + `relativeIndex` when `object` is
+/// Object::Buffer, config.getStack(relativeIndex) otherwise.
+///
+/// Only appliable checks config.hasStack(relativeIndex) first; apply and undo
+/// call getStack unconditionally.
+Action Action::addToHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & addition)
+{
+  auto apply = [colName, object, relativeIndex, addition](Config & config, Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else
+      lineIndex = config.getStack(relativeIndex);
+
+    return addToHypothesis(colName, lineIndex, addition).apply(config, a);
+  };
+
+  auto undo = [colName, object, relativeIndex](Config & config, Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else
+      lineIndex = config.getStack(relativeIndex);
+
+    // undo drops the last element whatever it is, so the addition is not needed here.
+    return addToHypothesis(colName, lineIndex, "").undo(config, a);
+  };
+
+  auto appliable = [colName, object, relativeIndex, addition](const Config & config, const Action & a)
+  {
+    int lineIndex = 0;
+    if (object == Object::Buffer)
+      lineIndex = config.getWordIndex() + relativeIndex;
+    else if (config.hasStack(relativeIndex))
+      lineIndex = config.getStack(relativeIndex);
+    else
+      return false;
+
+    return addToHypothesis(colName, lineIndex, addition).appliable(config, a);
+  };
+
+  return {Type::Write, apply, undo, appliable};
+}
+
 Action Action::addHypothesisRelative(const std::string & colName, Object object, int relativeIndex, const std::string & hypothesis)
 {
   auto apply = [colName, object, relativeIndex, hypothesis](Config & config, Action & a)
diff --git a/reading_machine/src/Transition.cpp b/reading_machine/src/Transition.cpp
index f18450c0a0d644fc07402461dcc04d9fd9d7f617..259e7f4db3cf01e9424ed77ba65ca4a191138d00 100644
--- a/reading_machine/src/Transition.cpp
+++ b/reading_machine/src/Transition.cpp
@@ -5,11 +5,13 @@ Transition::Transition(const std::string & name)
 {
   std::regex nameRegex("(<(.+)> )?(.+)");
   std::regex writeRegex("WRITE ([bs])\\.(.+) (.+) (.+)");
+  std::regex addRegex("ADD ([bs])\\.(.+) (.+) (.+)");
   std::regex shiftRegex("SHIFT");
   std::regex reduceRegex("REDUCE");
   std::regex leftRegex("LEFT (.+)");
   std::regex rightRegex("RIGHT (.+)");
   std::regex eosRegex("EOS");
+  std::regex nothingRegex("NOTHING");
 
   try
   {
@@ -22,6 +24,8 @@ Transition::Transition(const std::string & name)
 
     if (util::doIfNameMatch(writeRegex, this->name, [this](auto sm){initWrite(sm[3], sm[1], sm[2], sm[4]);}))
       return;
+    if (util::doIfNameMatch(addRegex, this->name, [this](auto sm){initAdd(sm[3], sm[1], sm[2], sm[4]);}))
+      return;
     if (util::doIfNameMatch(shiftRegex, this->name, [this](auto){initShift();}))
       return;
     if (util::doIfNameMatch(reduceRegex, this->name, [this](auto){initReduce();}))
@@ -32,6 +36,8 @@ Transition::Transition(const std::string & name)
       return;
     if (util::doIfNameMatch(eosRegex, this->name, [this](auto){initEOS();}))
       return;
+    if (util::doIfNameMatch(nothingRegex, this->name, [this](auto){initNothing();}))
+      return;
 
     throw std::invalid_argument("no match");
 
@@ -89,6 +95,47 @@ void Transition::initWrite(std::string colName, std::string object, std::string
   };
 }
 
+/// Initialises an "ADD" transition: its single action appends `value` to the
+/// hypothesis of column `colName` for the word designated by `object`
+/// (parsed with Action::str2object) and `index` (parsed with std::stoi).
+///
+/// The cost is 0 when `value` is one of the '|'-separated elements of the
+/// cell read through config.getConst(colName, line, 0), and 1 otherwise.
+void Transition::initAdd(std::string colName, std::string object, std::string index, std::string value)
+{
+  auto objectValue = Action::str2object(object);
+  int indexValue = std::stoi(index);
+
+  sequence.emplace_back(Action::addToHypothesisRelative(colName, objectValue, indexValue, value));
+
+  cost = [colName, objectValue, indexValue, value](const Config & config)
+  {
+    int lineIndex = 0;
+    if (objectValue == Action::Object::Buffer)
+      lineIndex = config.getWordIndex() + indexValue;
+    else
+      lineIndex = config.getStack(indexValue);
+
+    auto gold = util::split(config.getConst(colName, lineIndex, 0).get(), '|');
+
+    for (auto & part : gold)
+      if (part == value)
+        return 0;
+
+    return 1;
+  };
+}
+
+/// Initialises a "NOTHING" transition: no action is added to the sequence
+/// and the cost is always 0.
+void Transition::initNothing()
+{
+  cost = [](const Config &)
+  {
+    return 0;
+  };
+}
+
 void Transition::initShift()
 {
   sequence.emplace_back(Action::pushWordIndexOnStack());
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
index d5ec5bf5a8f6d0f3633e74c110fffb8986454f79..b9a730c7000983231b7c75e76e251f435868a049 100644
--- a/torch_modules/include/CNNNetwork.hpp
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -8,6 +8,7 @@ class CNNNetworkImpl : public NeuralNetworkImpl
   private :
 
   static inline std::vector<long> focusedBufferIndexes{0,1};
+  static inline std::vector<long> focusedStackIndexes{0,1};
   static inline std::vector<long> windowSizes{2,3,4};
   static constexpr unsigned int maxNbLetters = 10;
 
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
index 7612f7c60bf568a4ae60df510e2d3f51b3cecea9..9633cb4f185ee392fa4299204317224ca692ff84 100644
--- a/torch_modules/src/CNNNetwork.cpp
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -13,7 +13,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, i
   setColumns({"FORM", "UPOS"});
 
   wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
-  linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize));
+  linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*(focusedBufferIndexes.size()+focusedStackIndexes.size()), hiddenSize));
   linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
   for (auto & windowSize : windowSizes)
   {
@@ -28,7 +28,7 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
     input = input.unsqueeze(0);
 
   auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
-  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*focusedBufferIndexes.size());
+  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*(focusedBufferIndexes.size()+focusedStackIndexes.size()));
 
   auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
   auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
@@ -43,6 +43,15 @@ torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
       auto pooled = torch::max_pool1d(convOut, convOut.size(2));
       windows.emplace_back(pooled);
     }
+  // Stack-word slots are indexed after the focusedBufferIndexes.size() buffer-word slots of `permuted`.
+  for (unsigned int word = 0; word < focusedStackIndexes.size(); word++)
+    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
+    {
+      auto stackWordLetters = permuted[focusedBufferIndexes.size()+word];
+      auto convOut = torch::relu(lettersCNNs[i](stackWordLetters).squeeze(-1));
+      auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+      windows.emplace_back(pooled);
+    }
   auto lettersCnnOut = torch::cat(windows, 2);
   lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
 
@@ -133,6 +142,27 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
     }
   }
 
+  // For each focused stack position, emit exactly maxNbLetters entries: the letters of
+  // the word's FORM (extra letters are dropped), then Dict::nullValueStr as padding.
+  for (auto index : focusedStackIndexes)
+  {
+    util::utf8string letters;
+    if (config.hasStack(index) and config.has("FORM", config.getStack(index),0))
+      letters = util::splitAsUtf8(config.getAsFeature("FORM", config.getStack(index)).get());
+    for (unsigned int i = 0; i < maxNbLetters; i++)
+    {
+      if (i < letters.size())
+      {
+        std::string sLetter = fmt::format("Letter({})", letters[i]);
+        context.emplace_back(dict.getIndexOrInsert(sLetter));
+      }
+      else
+      {
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
+    }
+  }
+
   return context;
 }
 