Commit 05062ca7 authored by Franck Dary's avatar Franck Dary
Browse files

Removed pretrainedEmbeddings as a global parameter, instead submodules can now...

Removed pretrainedEmbeddings as a global parameter, instead submodules can now have their own pretrained w2v
parent 74a4a5f6
......@@ -50,6 +50,7 @@ class Dict
std::size_t size() const;
int getNbOccs(int index) const;
void removeRareElements();
void loadWord2Vec(std::filesystem::path & path);
};
#endif
......@@ -200,3 +200,52 @@ void Dict::removeRareElements()
nbOccs = newNbOccs;
}
void Dict::loadWord2Vec(std::filesystem::path & path)
{
if (path.empty())
return;
if (!std::filesystem::exists(path))
util::myThrow(fmt::format("pretrained word2vec file '{}' do not exist", path.string()));
auto originalState = getState();
setState(Dict::State::Open);
std::FILE * file = std::fopen(path.c_str(), "r");
char buffer[100000];
bool firstLine = true;
try
{
while (!std::feof(file))
{
if (buffer != std::fgets(buffer, 100000, file))
break;
if (firstLine)
{
firstLine = false;
continue;
}
auto splited = util::split(util::strip(buffer), ' ');
if (splited.size() < 2)
util::myThrow(fmt::format("invalid w2v line '{}' less than 2 columns", buffer));
auto dictIndex = getIndexOrInsert(splited[0]);
if (dictIndex == getIndexOrInsert(Dict::unknownValueStr) or dictIndex == getIndexOrInsert(Dict::nullValueStr) or dictIndex == getIndexOrInsert(Dict::emptyValueStr))
util::myThrow(fmt::format("w2v line '{}' gave unexpected special dict index", buffer));
}
} catch (std::exception & e)
{
util::myThrow(fmt::format("caught '{}'", e.what()));
}
std::fclose(file);
setState(originalState);
}
......@@ -11,7 +11,7 @@ ReadingMachine::ReadingMachine(std::filesystem::path path, std::vector<std::file
readFromFile(path);
loadDicts();
classifier->getNN()->registerEmbeddings("");
classifier->getNN()->registerEmbeddings();
classifier->getNN()->to(NeuralNetworkImpl::device);
if (models.size() > 1)
......
......@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(AppliableTransModule);
......
......@@ -19,6 +19,7 @@ class ContextModuleImpl : public Submodule
std::vector<int> bufferContext;
std::vector<int> stackContext;
int inSize;
std::filesystem::path w2vFile;
public :
......@@ -27,7 +28,7 @@ class ContextModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(ContextModule);
......
......@@ -27,7 +27,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(DepthLayerTreeEmbeddingModule);
......
......@@ -13,6 +13,7 @@ class DictHolder : public NameHolder
static constexpr char * filenameTemplate = "{}.dict";
std::unique_ptr<Dict> dict;
bool pretrained{false};
private :
......@@ -24,6 +25,8 @@ class DictHolder : public NameHolder
void saveDict(std::filesystem::path path);
void loadDict(std::filesystem::path path);
Dict & getDict();
bool dictIsPretrained();
void dictSetPretrained(bool pretrained);
};
#endif
......
......@@ -26,7 +26,7 @@ class DistanceModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(DistanceModule);
......
......@@ -27,7 +27,7 @@ class FocusedColumnModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(FocusedColumnModule);
......
......@@ -24,7 +24,7 @@ class HistoryModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(HistoryModule);
......
......@@ -30,7 +30,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
......
......@@ -21,7 +21,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder, public St
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual std::vector<std::vector<long>> extractContext(Config & config) = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
virtual void registerEmbeddings() = 0;
virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0;
......
......@@ -24,7 +24,7 @@ class NumericColumnModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(NumericColumnModule);
......
......@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input) override;
std::vector<std::vector<long>> extractContext(Config &) override;
void registerEmbeddings(std::filesystem::path) override;
void registerEmbeddings() override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
......
......@@ -24,7 +24,7 @@ class RawInputModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(RawInputModule);
......
......@@ -24,7 +24,7 @@ class SplitTransModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(SplitTransModule);
......
......@@ -21,7 +21,7 @@ class StateNameModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(StateNameModule);
......
......@@ -21,7 +21,7 @@ class Submodule : public torch::nn::Module, public DictHolder, public StateHolde
virtual std::size_t getInputSize() = 0;
virtual void addToContext(std::vector<std::vector<long>> & context, const Config & config) = 0;
virtual torch::Tensor forward(torch::Tensor input) = 0;
virtual void registerEmbeddings(std::filesystem::path pretrained) = 0;
virtual void registerEmbeddings() = 0;
std::function<std::string(const std::string &)> getFunction(const std::string functionNames);
};
......
......@@ -23,7 +23,7 @@ class UppercaseRateModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(std::vector<std::vector<long>> & context, const Config & config) override;
void registerEmbeddings(std::filesystem::path pretrained) override;
void registerEmbeddings() override;
};
TORCH_MODULE(UppercaseRateModule);
......
......@@ -31,7 +31,7 @@ void AppliableTransModuleImpl::addToContext(std::vector<std::vector<long>> & con
contextElement.emplace_back(0);
}
void AppliableTransModuleImpl::registerEmbeddings(std::filesystem::path)
void AppliableTransModuleImpl::registerEmbeddings()
{
}
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment