Commit 77afafd7 authored by Franck Dary
Browse files

Added a program parameter to supply pretrained word embeddings in w2v format

parent 0089639f
......@@ -157,7 +157,7 @@ std::string util::strip(const std::string & s)
++first;
std::size_t last = s.size()-1;
while (last > first and (s[last] == ' ' or s[last] == '\t'))
while (last > first and (s[last] == ' ' or s[last] == '\t' or s[last] == '\n'))
--last;
return std::string(s.begin()+first, s.begin()+last+1);
......
......@@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
readFromFile(path);
loadDicts();
classifier->getNN()->registerEmbeddings();
classifier->getNN()->registerEmbeddings("");
classifier->getNN()->to(NeuralNetworkImpl::device);
if (models.size() > 1)
......
......@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(AppliableTransModule);
......
......@@ -25,7 +25,7 @@ class ContextModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(ContextModule);
......
......@@ -26,7 +26,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(DepthLayerTreeEmbeddingModule);
......
......@@ -25,7 +25,7 @@ class FocusedColumnModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(FocusedColumnModule);
......
......@@ -23,7 +23,7 @@ class HistoryModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(HistoryModule);
......
......@@ -29,7 +29,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
......
......@@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0;
......
......@@ -23,7 +23,7 @@ class NumericColumnModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(NumericColumnModule);
......
......@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config &) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path) override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
......
......@@ -23,7 +23,7 @@ class RawInputModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(RawInputModule);
......
......@@ -23,7 +23,7 @@ class SplitTransModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(SplitTransModule);
......
......@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(StateNameModule);
......
......@@ -2,6 +2,7 @@
#define SUBMODULE__H
#include <torch/torch.h>
#include <filesystem>
#include "Config.hpp"
#include "DictHolder.hpp"
#include "StateHolder.hpp"
......@@ -15,11 +16,12 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
public :
void setFirstInputIndex(std::size_t firstInputIndex);
void loadPretrainedW2vEmbeddings(torch::nn::Embedding & embeddings, std::filesystem::path path);
virtual std::size_t getOutputSize() = 0;
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual void registerEmbeddings() = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
};
#endif
......
......@@ -22,7 +22,7 @@ class UppercaseRateModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(std::filesystem::path pretrained) override;
};
TORCH_MODULE(UppercaseRateModule);
......
......@@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con
contextElement.emplace_back(0);
}
void AppliableTransModuleImpl::registerEmbeddings()
void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path)
{
}
......@@ -89,8 +89,9 @@ torch::Tensor ContextModuleImpl::forward(torch::Tensor input)
return myModule->forward(context);
}
void ContextModuleImpl::registerEmbeddings()
void ContextModuleImpl::registerEmbeddings(std::filesystem::path path)
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, path);
}
......@@ -122,8 +122,9 @@ void DepthLayerTreeEmbeddingModuleImpl::addToContext(std::vector<std::vector<lon
}
}
void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings()
void DepthLayerTreeEmbeddingModuleImpl::registerEmbeddings(std::filesystem::path path)
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, path);
}
......@@ -134,8 +134,9 @@ void FocusedColumnModuleImpl::addToContext(std::vector<std::vector<long>> & cont
}
}
void FocusedColumnModuleImpl::registerEmbeddings()
void FocusedColumnModuleImpl::registerEmbeddings(std::filesystem::path path)
{
wordEmbeddings = register_module("embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(getDict().size(), inSize)));
loadPretrainedW2vEmbeddings(wordEmbeddings, path);
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment