Skip to content
Snippets Groups Projects
Commit 085a30f2 authored by Franck Dary's avatar Franck Dary
Browse files

w2v in ContextModule is now relative path

parent 9f7e5b50
Branches
No related tags found
No related merge requests found
......@@ -52,7 +52,7 @@ class Dict
std::size_t size() const;
int getNbOccs(int index) const;
void removeRareElements();
void loadWord2Vec(std::filesystem::path & path);
void loadWord2Vec(std::filesystem::path path);
};
#endif
......@@ -217,7 +217,7 @@ void Dict::removeRareElements()
nbOccs = newNbOccs;
}
void Dict::loadWord2Vec(std::filesystem::path & path)
void Dict::loadWord2Vec(std::filesystem::path path)
{
if (path.empty())
return;
......
......@@ -24,8 +24,8 @@ class Classifier
private :
void initNeuralNetwork(const std::vector<std::string> & definition);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState);
void initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path);
void initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path);
public :
......
......@@ -58,7 +58,7 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
}))
util::myThrow(fmt::format("Invalid line '{}', expected '{}'\n", definition[1], "(LossMultiplier :) {state1,multiplier1 state2,multiplier2...}"));
initNeuralNetwork(definition);
initNeuralNetwork(definition, path.parent_path());
}
int Classifier::getNbParameters() const
......@@ -89,7 +89,7 @@ const std::string & Classifier::getName() const
return name;
}
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
void Classifier::initNeuralNetwork(const std::vector<std::string> & definition, std::filesystem::path path)
{
std::map<std::string,std::size_t> nbOutputsPerState;
for (auto & it : this->transitionSets)
......@@ -108,7 +108,7 @@ void Classifier::initNeuralNetwork(const std::vector<std::string> & definition)
if (networkType == "Random")
this->nn.reset(new RandomNetworkImpl(this->name, nbOutputsPerState));
else if (networkType == "Modular")
initModular(definition, curIndex, nbOutputsPerState);
initModular(definition, curIndex, nbOutputsPerState, path);
else
util::myThrow(fmt::format("Unknown network type '{}', available types are 'Random, Modular'", networkType));
......@@ -141,7 +141,7 @@ void Classifier::setState(const std::string & state)
nn->setState(state);
}
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState)
void Classifier::initModular(const std::vector<std::string> & definition, std::size_t & curIndex, const std::map<std::string,std::size_t> & nbOutputsPerState, std::filesystem::path path)
{
std::string anyBlanks = "(?:(?:\\s|\\t)*)";
std::regex endRegex(fmt::format("{}End{}",anyBlanks,anyBlanks));
......@@ -157,7 +157,7 @@ void Classifier::initModular(const std::vector<std::string> & definition, std::s
modulesDefinitions.emplace_back(definition[curIndex]);
}
this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions));
this->nn.reset(new ModularNetworkImpl(this->name, nbOutputsPerState, modulesDefinitions, path));
}
void Classifier::resetOptimizer()
......
......@@ -19,11 +19,12 @@ class ContextModuleImpl : public Submodule
std::vector<std::function<std::string(const std::string &)>> functions;
std::vector<std::tuple<Config::Object, int, std::optional<int>>> targets;
int inSize;
std::filesystem::path path;
std::filesystem::path w2vFile;
public :
ContextModuleImpl(std::string name, const std::string & definition);
ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input);
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
......
......@@ -27,7 +27,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
public :
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config) override;
void registerEmbeddings() override;
......
#include "ContextModule.hpp"
ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition)
ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & definition, std::filesystem::path path) : path(path)
{
setName(name);
......@@ -50,7 +50,7 @@ ContextModuleImpl::ContextModuleImpl(std::string name, const std::string & defin
if (!w2vFile.empty())
{
getDict().loadWord2Vec(w2vFile);
getDict().loadWord2Vec(this->path / w2vFile);
getDict().setState(Dict::State::Closed);
dictSetPretrained(true);
}
......@@ -144,6 +144,6 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
void ContextModuleImpl::registerEmbeddings()
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile);
loadPretrainedW2vEmbeddings(wordEmbeddings, w2vFile.empty() ? "" : path / w2vFile);
}
#include "ModularNetwork.hpp"
ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions)
ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path)
{
setName(name);
std::string anyBlanks = "(?:(?:\\s|\\t)*)";
......@@ -28,7 +28,7 @@ ModularNetworkImpl::ModularNetworkImpl(std::string name, std::map<std::string,st
std::string name = fmt::format("{}_{}", modules.size(), splited.first);
std::string nameH = fmt::format("{}_{}", getName(), name);
if (splited.first == "Context")
modules.emplace_back(register_module(name, ContextModule(nameH, splited.second)));
modules.emplace_back(register_module(name, ContextModule(nameH, splited.second, path)));
else if (splited.first == "StateName")
modules.emplace_back(register_module(name, StateNameModule(nameH, splited.second)));
else if (splited.first == "History")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment