Skip to content
Snippets Groups Projects
Commit 6d627fa1 authored by Franck Dary's avatar Franck Dary
Browse files

Added module HistoryMine

parent f0df0508
No related branches found
No related tags found
No related merge requests found
...@@ -54,6 +54,7 @@ class Config ...@@ -54,6 +54,7 @@ class Config
std::size_t currentSentenceStartRawInput{0}; std::size_t currentSentenceStartRawInput{0};
util::String state{"NONE"}; util::String state{"NONE"};
boost::circular_buffer<util::String> history{10}; boost::circular_buffer<util::String> history{10};
std::map<std::string, boost::circular_buffer<util::String>> stateHistory;
boost::circular_buffer<std::size_t> stack{50}; boost::circular_buffer<std::size_t> stack{50};
float chosenActionScore{0.0}; float chosenActionScore{0.0};
std::vector<std::string> extraColumns{commentsColName, rawRangeStartColName, rawRangeEndColName, isMultiColName, childsColName, sentIdColName, EOSColName}; std::vector<std::string> extraColumns{commentsColName, rawRangeStartColName, rawRangeEndColName, isMultiColName, childsColName, sentIdColName, EOSColName};
...@@ -142,9 +143,11 @@ class Config ...@@ -142,9 +143,11 @@ class Config
long getRelativeWordIndex(Object object, int relativeIndex) const; long getRelativeWordIndex(Object object, int relativeIndex) const;
bool hasRelativeWordIndex(Object object, int relativeIndex) const; bool hasRelativeWordIndex(Object object, int relativeIndex) const;
const util::String & getHistory(int relativeIndex) const; const util::String & getHistory(int relativeIndex) const;
const util::String & getHistoryState(int relativeIndex) const;
std::size_t getStack(int relativeIndex) const; std::size_t getStack(int relativeIndex) const;
std::size_t getStackSize() const; std::size_t getStackSize() const;
bool hasHistory(int relativeIndex) const; bool hasHistory(int relativeIndex) const;
bool hasHistoryState(int relativeIndex) const;
bool hasStack(int relativeIndex) const; bool hasStack(int relativeIndex) const;
util::String getState() const; util::String getState() const;
void setState(const std::string state); void setState(const std::string state);
......
...@@ -18,6 +18,7 @@ Config::Config(const Config & other) ...@@ -18,6 +18,7 @@ Config::Config(const Config & other)
this->characterIndex = other.characterIndex; this->characterIndex = other.characterIndex;
this->state = other.state; this->state = other.state;
this->history = other.history; this->history = other.history;
this->stateHistory = other.stateHistory;
this->stack = other.stack; this->stack = other.stack;
this->extraColumns = other.extraColumns; this->extraColumns = other.extraColumns;
this->mcd = other.mcd; this->mcd = other.mcd;
...@@ -361,7 +362,10 @@ Config::ConstValueIterator Config::getConstIterator(int colIndex, int lineIndex, ...@@ -361,7 +362,10 @@ Config::ConstValueIterator Config::getConstIterator(int colIndex, int lineIndex,
void Config::addToHistory(const std::string & transition) void Config::addToHistory(const std::string & transition)
{ {
if (stateHistory.count(state) == 0)
stateHistory.insert({state, boost::circular_buffer<util::String>(10)});
history.push_back(util::String(transition)); history.push_back(util::String(transition));
stateHistory.at(state).push_back(util::String(transition));
} }
void Config::addToStack(std::size_t index) void Config::addToStack(std::size_t index)
...@@ -491,6 +495,11 @@ const util::String & Config::getHistory(int relativeIndex) const ...@@ -491,6 +495,11 @@ const util::String & Config::getHistory(int relativeIndex) const
return history[history.size()-1-relativeIndex]; return history[history.size()-1-relativeIndex];
} }
const util::String & Config::getHistoryState(int relativeIndex) const
{
return stateHistory.at(state)[stateHistory.at(state).size()-1-relativeIndex];
}
std::size_t Config::getStack(int relativeIndex) const std::size_t Config::getStack(int relativeIndex) const
{ {
if (relativeIndex == -1) if (relativeIndex == -1)
...@@ -508,6 +517,13 @@ bool Config::hasHistory(int relativeIndex) const ...@@ -508,6 +517,13 @@ bool Config::hasHistory(int relativeIndex) const
return relativeIndex >= 0 && relativeIndex < (int)history.size(); return relativeIndex >= 0 && relativeIndex < (int)history.size();
} }
bool Config::hasHistoryState(int relativeIndex) const
{
if (stateHistory.count(state) == 0)
return false;
return relativeIndex >= 0 && relativeIndex < (int)stateHistory.at(state).size();
}
bool Config::hasStack(int relativeIndex) const bool Config::hasStack(int relativeIndex) const
{ {
if (relativeIndex == -1) if (relativeIndex == -1)
......
#ifndef HISTORYMINEMODULE__H
#define HISTORYMINEMODULE__H
#include <torch/torch.h>
#include "Submodule.hpp"
#include "MyModule.hpp"
#include "LSTM.hpp"
#include "GRU.hpp"
#include "CNN.hpp"
#include "Concat.hpp"
#include "WordEmbeddings.hpp"
class HistoryMineModuleImpl : public Submodule
{
private :
WordEmbeddings wordEmbeddings{nullptr};
std::shared_ptr<MyModule> myModule{nullptr};
int maxNbElements;
int inSize;
public :
HistoryMineModuleImpl(std::string name, const std::string & definition);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
};
TORCH_MODULE(HistoryMineModule);
#endif
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "UppercaseRateModule.hpp" #include "UppercaseRateModule.hpp"
#include "NumericColumnModule.hpp" #include "NumericColumnModule.hpp"
#include "HistoryModule.hpp" #include "HistoryModule.hpp"
#include "HistoryMineModule.hpp"
#include "DistanceModule.hpp" #include "DistanceModule.hpp"
#include "MLP.hpp" #include "MLP.hpp"
......
#include "HistoryMineModule.hpp"
HistoryMineModuleImpl::HistoryMineModuleImpl(std::string name, const std::string & definition)
{
setName(name);
std::regex regex("(?:(?:\\s|\\t)*)NbElem\\{(.*)\\}(?:(?:\\s|\\t)*)(\\S+)\\{(.*)\\}(?:(?:\\s|\\t)*)In\\{(.*)\\}(?:(?:\\s|\\t)*)Out\\{(.*)\\}(?:(?:\\s|\\t)*)");
if (!util::doIfNameMatch(regex, definition, [this,&definition](auto sm)
{
try
{
maxNbElements = std::stoi(sm.str(1));
auto subModuleType = sm.str(2);
auto subModuleArguments = util::split(sm.str(3), ' ');
auto options = MyModule::ModuleOptions(true)
.bidirectional(std::stoi(subModuleArguments[0]))
.num_layers(std::stoi(subModuleArguments[1]))
.dropout(std::stof(subModuleArguments[2]))
.complete(std::stoi(subModuleArguments[3]));
inSize = std::stoi(sm.str(4));
int outSize = std::stoi(sm.str(5));
if (subModuleType == "LSTM")
myModule = register_module("myModule", LSTM(inSize, outSize, options));
else if (subModuleType == "GRU")
myModule = register_module("myModule", GRU(inSize, outSize, options));
else if (subModuleType == "CNN")
myModule = register_module("myModule", CNN(inSize, outSize, options));
else if (subModuleType == "Concat")
myModule = register_module("myModule", Concat(inSize));
else
util::myThrow(fmt::format("unknown sumodule type '{}'", subModuleType));
} catch (std::exception & e) {util::myThrow(fmt::format("{} in '{}'",e.what(),definition));}
}))
util::myThrow(fmt::format("invalid definition '{}'", definition));
}
torch::Tensor HistoryMineModuleImpl::forward(torch::Tensor input)
{
return myModule->forward(wordEmbeddings(input.narrow(1, firstInputIndex, maxNbElements))).reshape({input.size(0), -1});
}
std::size_t HistoryMineModuleImpl::getOutputSize()
{
return myModule->getOutputSize(maxNbElements);
}
std::size_t HistoryMineModuleImpl::getInputSize()
{
return maxNbElements;
}
void HistoryMineModuleImpl::addToContext(torch::Tensor & context, const Config & config)
{
auto & dict = getDict();
std::string prefix = "HISTORYMINE";
for (int i = 0; i < maxNbElements; i++)
if (config.hasHistoryState(i))
context[firstInputIndex+i] = dict.getIndexOrInsert(config.getHistoryState(i), prefix);
else
context[firstInputIndex+i] = dict.getIndexOrInsert(Dict::nullValueStr, prefix);
}
void HistoryMineModuleImpl::registerEmbeddings()
{
if (!wordEmbeddings)
wordEmbeddings = register_module("embeddings", WordEmbeddings(getDict().size(), inSize));
}
...@@ -35,6 +35,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st ...@@ -35,6 +35,8 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second))); modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
else if (splited.first == "History") else if (splited.first == "History")
modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second))); modules.emplace_back(register_module(name, HistoryModule(nameH, splited.second)));
else if (splited.first == "HistoryMine")
modules.emplace_back(register_module(name, HistoryMineModule(nameH, splited.second)));
else if (splited.first == "NumericColumn") else if (splited.first == "NumericColumn")
modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second))); modules.emplace_back(register_module(name, NumericColumnModule(nameH, splited.second)));
else if (splited.first == "UppercaseRate") else if (splited.first == "UppercaseRate")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment