Commit 6d627fa1 authored by Franck Dary's avatar Franck Dary
Browse files

Added module HistoryMine

parent f0df0508
......@@ -54,6 +54,7 @@ class Config
std::size_t currentSentenceStartRawInput{0};
util::String state{"NONE"};
boost::circular_buffer<util::String> history{10};
std::map<std::string, boost::circular_buffer<util::String>> stateHistory;
boost::circular_buffer<std::size_t> stack{50};
float chosenActionScore{0.0};
std::vector<std::string> extraColumns{commentsColName, rawRangeStartColName, rawRangeEndColName, isMultiColName, childsColName, sentIdColName, EOSColName};
......@@ -142,9 +143,11 @@ class Config
long getRelativeWordIndex(Object object, int relativeIndex) const;
bool hasRelativeWordIndex(Object object, int relativeIndex) const;
const util::String & getHistory(int relativeIndex) const;
const util::String & getHistoryState(int relativeIndex) const;
std::size_t getStack(int relativeIndex) const;
std::size_t getStackSize() const;
bool hasHistory(int relativeIndex) const;
bool hasHistoryState(int relativeIndex) const;
bool hasStack(int relativeIndex) const;
util::String getState() const;
void setState(const std::string state);
......
......@@ -18,6 +18,7 @@ Config::Config(const Config & other)
this->characterIndex = other.characterIndex;
this->state = other.state;
this->history = other.history;
this->stateHistory = other.stateHistory;
this->stack = other.stack;
this->extraColumns = other.extraColumns;
this->mcd = other.mcd;
......@@ -361,7 +362,10 @@ Config::ConstValueIterator Config::getConstIterator(int colIndex, int lineIndex,
void Config::addToHistory(const std::string & transition)
{
if (stateHistory.count(state) == 0)
stateHistory.insert({state, boost::circular_buffer<util::String>(10)});
history.push_back(util::String(transition));
stateHistory.at(state).push_back(util::String(transition));
}
void Config::addToStack(std::size_t index)
......@@ -491,6 +495,11 @@ const util::String & Config::getHistory(int relativeIndex) const
return history[history.size()-1-relativeIndex];
}
const util::String & Config::getHistoryState(int relativeIndex) const
{
return stateHistory.at(state)[stateHistory.at(state).size()-1-relativeIndex];
}
std::size_t Config::getStack(int relativeIndex) const
{
if (relativeIndex == -1)
......@@ -508,6 +517,13 @@ bool Config::hasHistory(int relativeIndex) const
return relativeIndex >= 0 && relativeIndex < (int)history.size();
}
bool Config::hasHistoryState(int relativeIndex) const
{
if (stateHistory.count(state) == 0)
return false;
return relativeIndex >= 0 && relativeIndex < (int)stateHistory.at(state).size();
}
bool Config::hasStack(int relativeIndex) const
{
if (relativeIndex == -1)
......
#ifndef HISTORYMINEMODULE__H
#define HISTORYMINEMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "CNN.hpp"
#include "Concat.hpp"
#include "WordEmbeddings.hpp"
class HistoryMineModuleImpl : public Submodule
{
private :
WordEmbeddings wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
int maxNbElements;
int inSize;
public :
HistoryMineModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(HistoryMineModule);
#endif
......@@ -13,6 +13,7 @@
#include "UppercaseRateModule.hpp"
#include "NumericColumnModule.hpp"
#include "HistoryModule.hpp"
#include "HistoryMineModule.hpp"
#include "DistanceModule.hpp"
#include "MLP.hpp"
......
#include "HistoryMineModule.hpp"
HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string & definition)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
maxNbElements = std::stoi(sm.str(1));
auto subModuleType = sm.str(2);
auto subModuleArguments = util::split(sm.str(3), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
inSize = std::stoi(sm.str(4));
int outSize = std::stoi(sm.str(5));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
torch::Tensor HistoryMineModuleImpl::forward(torch::Tensor input)
{
return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))).reshape({input.size(0), -1});
}
std::size_t HistoryMineModuleImpl::getOutputSize()
{
return myModule->getOutputSize(maxNbElements);
}
std::size_t HistoryMineModuleImpl::getInputSize()
{
return maxNbElements;
}
void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config & config)
{
auto & dict = getDict();
std::string prefix = "HISTORYMINE";
for (int i = 0; i < maxNbElements; i++)
if (config.hasHistoryState(i))
context[firstInputIndex+i] = dict.getIndexOrInsert(config.getHistoryState(i), prefix);
else
context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
}
void HistoryMineModuleImpl::registerEmbeddings()
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
......@@ -35,6 +35,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
else if (splited.first == "History")
modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
else if (splited.first == "HistoryMine")
modules.emplace_back(register_module(name, HistoryMineModule(nameH, splited.second)));
else if (splited.first == "NumericColumn")
modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
else if (splited.first == "UppercaseRate")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment