Skip to content
Snippets Groups Projects
Commit 05062ca7 authored by Franck Dary's avatar Franck Dary
Browse files

Removed pretrainedEmbeddings as a global parameter, instead submodules can now...

Removed pretrainedEmbeddings as a global parameter; instead, submodules can now have their own pretrained w2v
parent 74a4a5f6
No related branches found
No related tags found
No related merge requests found
Showing
with 71 additions and 17 deletions
...@@ -50,6 +50,7 @@ class Dict ...@@ -50,6 +50,7 @@ class Dict
std::size_t size() const; std::size_t size() const;
int getNbOccs(int index) const; int getNbOccs(int index) const;
void removeRareElements(); void removeRareElements();
void loadWord2Vec(std::filesystem::path & path);
}; };
#endif #endif
...@@ -200,3 +200,52 @@ void Dict::removeRareElements() ...@@ -200,3 +200,52 @@ void Dict::removeRareElements()
nbOccs = newNbOccs; nbOccs = newNbOccs;
} }
/// @brief Inserts every word of a pretrained word2vec text file into this dict.
///
/// Does nothing when `path` is empty. The first line of the file is skipped
/// (presumably the word2vec "<count> <dimension>" header — TODO confirm).
/// Every following line is stripped and split on single spaces; its first
/// column is inserted through getIndexOrInsert while the dict is temporarily
/// in the Open state. The state observed on entry is restored before leaving,
/// on both the success and the error path.
///
/// @param path Location of the word2vec file; an empty path is a no-op.
///
/// Errors are reported through util::myThrow when: the file does not exist,
/// the file cannot be opened, a line does not fit in the read buffer, a line
/// has fewer than 2 columns, or a word resolves to the index of one of the
/// special values (unknown / null / empty).
void Dict::loadWord2Vec(std::filesystem::path & path)
{
  if (path.empty())
    return;

  if (!std::filesystem::exists(path))
    util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));

  // Minimal scope-bound owner: the stream is closed on every exit path,
  // including when an error propagates out of the parsing loop.
  struct FileCloser
  {
    std::FILE * handle;
    ~FileCloser()
    {
      if (handle != nullptr)
        std::fclose(handle);
    }
  };

  FileCloser file{std::fopen(path.c_str(), "r")};

  // exists() succeeding does not guarantee the file is readable.
  if (file.handle == nullptr)
    util::myThrow(fmt::format("could not open pretrained word2vec file '{}'", path.string()));

  auto originalState = getState();
  setState(Dict::State::Open);

  constexpr int bufferSize = 100000;
  char buffer[bufferSize];
  bool firstLine = true;

  try
  {
    // fgets returns nullptr on end-of-file and on read error, so no separate feof test is needed.
    while (std::fgets(buffer, bufferSize, file.handle) != nullptr)
    {
      // A chunk without a trailing newline that is not the end of the file
      // means the line was cut by the buffer; parsing the remainder as a new
      // entry would insert garbage into the dict, so fail loudly instead.
      const std::string line(buffer);
      if (!line.empty() and line.back() != '\n' and !std::feof(file.handle))
        util::myThrow(fmt::format("w2v line starting with '{}' is longer than {} characters", line.substr(0, 50), bufferSize-1));

      if (firstLine)
      {
        firstLine = false;
        continue;
      }

      auto splited = util::split(util::strip(buffer), ' ');

      if (splited.size() < 2)
        util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));

      auto dictIndex = getIndexOrInsert(splited[0]);

      // NOTE(review): the special values are looked up on every line exactly as before;
      // not hoisted because getIndexOrInsert may have side effects not visible here.
      if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
        util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
    }
  }
  catch (const std::exception & e)
  {
    // Do not leave the dict stuck in the Open state when loading fails.
    setState(originalState);
    util::myThrow(fmt::format("caught '{}'", e.what()));
  }

  setState(originalState);
}
...@@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file ...@@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
readFromFile(path); readFromFile(path);
loadDicts(); loadDicts();
classifier->getNN()->registerEmbeddings(""); classifier->getNN()->registerEmbeddings();
classifier->getNN()->to(NeuralNetworkImpl::device); classifier->getNN()->to(NeuralNetworkImpl::device);
if (models.size() > 1) if (models.size() > 1)
......
...@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule ...@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(AppliableTransModule); TORCH_MODULE(AppliableTransModule);
......
...@@ -19,6 +19,7 @@ class ContextModuleImpl : public Submodule ...@@ -19,6 +19,7 @@ class ContextModuleImpl : public Submodule
std::vector<int> bufferContext; std::vector<int> bufferContext;
std::vector<int> stackContext; std::vector<int> stackContext;
int inSize; int inSize;
std::filesystem::path w2vFile;
public : public :
...@@ -27,7 +28,7 @@ class ContextModuleImpl : public Submodule ...@@ -27,7 +28,7 @@ class ContextModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(ContextModule); TORCH_MODULE(ContextModule);
......
...@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule ...@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(DepthLayerTreeEmbeddingModule); TORCH_MODULE(DepthLayerTreeEmbeddingModule);
......
...@@ -13,6 +13,7 @@ class DictHolder : public NameHolder ...@@ -13,6 +13,7 @@ class DictHolder : public NameHolder
static constexpr char * filenameTemplate = "{}.dict"; static constexpr char * filenameTemplate = "{}.dict";
std::unique_ptr<Dict> dict; std::unique_ptr<Dict> dict;
bool pretrained{false};
private : private :
...@@ -24,6 +25,8 @@ class DictHolder : public NameHolder ...@@ -24,6 +25,8 @@ class DictHolder : public NameHolder
void saveDict(std::filesystem::path path); void saveDict(std::filesystem::path path);
void loadDict(std::filesystem::path path); void loadDict(std::filesystem::path path);
Dict & getDict(); Dict & getDict();
bool dictIsPretrained();
void dictSetPretrained(bool pretrained);
}; };
#endif #endif
......
...@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule ...@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(DistanceModule); TORCH_MODULE(DistanceModule);
......
...@@ -27,7 +27,7 @@ class FocusedColumnModuleImpl : public Submodule ...@@ -27,7 +27,7 @@ class FocusedColumnModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(FocusedColumnModule); TORCH_MODULE(FocusedColumnModule);
......
...@@ -24,7 +24,7 @@ class HistoryModuleImpl : public Submodule ...@@ -24,7 +24,7 @@ class HistoryModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(HistoryModule); TORCH_MODULE(HistoryModule);
......
...@@ -30,7 +30,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ...@@ -30,7 +30,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions); ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config) override; std::vector<std::vector<long>> extractContext(Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override; void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override; void setDictsState(Dict::State state) override;
......
...@@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St ...@@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St
virtual torch::Tensor forward(torch::Tensor input) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0; virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; virtual void registerEmbeddings() = 0;
virtual void saveDicts(std::filesystem::path path) = 0; virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0; virtual void setDictsState(Dict::State state) = 0;
......
...@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule ...@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(NumericColumnModule); TORCH_MODULE(NumericColumnModule);
......
...@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl ...@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input) override; torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config &) override; std::vector<std::vector<long>> extractContext(Config &) override;
void registerEmbeddings(std::filesystem::path) override; void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override; void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override; void setDictsState(Dict::State state) override;
......
...@@ -24,7 +24,7 @@ class RawInputModuleImpl : public Submodule ...@@ -24,7 +24,7 @@ class RawInputModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(RawInputModule); TORCH_MODULE(RawInputModule);
......
...@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule ...@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(SplitTransModule); TORCH_MODULE(SplitTransModule);
......
...@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule ...@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(StateNameModule); TORCH_MODULE(StateNameModule);
......
...@@ -21,7 +21,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde ...@@ -21,7 +21,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
virtual std::size_t getInputSize() = 0; virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0; virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0; virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0; virtual void registerEmbeddings() = 0;
std::function<std::string(const std::string &)> getFunction(const std::string functionNames); std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
}; };
......
...@@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule ...@@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override; void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override; void registerEmbeddings() override;
}; };
TORCH_MODULE(UppercaseRateModule); TORCH_MODULE(UppercaseRateModule);
......
...@@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con ...@@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con
contextElement.emplace_back(0); contextElement.emplace_back(0);
} }
void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path) void AppliableTransModuleImpl::registerEmbeddings()
{ {
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment