Skip to content
Snippets Groups Projects
Commit 9032ef49 authored by Franck Dary's avatar Franck Dary
Browse files

Tried to improve pretrained

parent 8072af59
No related branches found
No related tags found
No related merge requests found
Showing
with 87 additions and 19 deletions
...@@ -24,7 +24,7 @@ if(NOT CMAKE_BUILD_TYPE) ...@@ -24,7 +24,7 @@ if(NOT CMAKE_BUILD_TYPE)
endif() endif()
set(CMAKE_CXX_FLAGS "-Wall -Wextra") set(CMAKE_CXX_FLAGS "-Wall -Wextra")
set(CMAKE_CXX_FLAGS_DEBUG "-g3") set(CMAKE_CXX_FLAGS_DEBUG "-g3 -rdynamic")
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast") set(CMAKE_CXX_FLAGS_RELEASE "-Ofast")
include_directories(fmt/include) include_directories(fmt/include)
......
...@@ -36,6 +36,7 @@ class Dict ...@@ -36,6 +36,7 @@ class Dict
State state; State state;
bool isCountingOccs{false}; bool isCountingOccs{false};
std::set<std::string> prefixes{""}; std::set<std::string> prefixes{""};
bool locked;
public : public :
...@@ -51,6 +52,7 @@ class Dict ...@@ -51,6 +52,7 @@ class Dict
public : public :
void lock();
void countOcc(bool isCountingOccs); void countOcc(bool isCountingOccs);
std::set<std::size_t> getSpecialIndexes(); std::set<std::size_t> getSpecialIndexes();
int getIndexOrInsert(const std::string & element, const std::string & prefix); int getIndexOrInsert(const std::string & element, const std::string & prefix);
......
...@@ -29,6 +29,8 @@ void myThrow(std::string_view message, const std::experimental::source_location ...@@ -29,6 +29,8 @@ void myThrow(std::string_view message, const std::experimental::source_location
std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension); std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension);
std::string getStackTrace();
std::string_view getFilenameFromPath(std::string_view s); std::string_view getFilenameFromPath(std::string_view s);
std::vector<std::string> split(std::string_view s, char delimiter); std::vector<std::string> split(std::string_view s, char delimiter);
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
Dict::Dict(State state) Dict::Dict(State state)
{ {
locked = false;
setState(state); setState(state);
insert(unknownValueStr); insert(unknownValueStr);
insert(nullValueStr); insert(nullValueStr);
...@@ -18,6 +19,12 @@ Dict::Dict(const char * filename, State state) ...@@ -18,6 +19,12 @@ Dict::Dict(const char * filename, State state)
{ {
readFromFile(filename); readFromFile(filename);
setState(state); setState(state);
locked = false;
}
void Dict::lock()
{
locked = true;
} }
void Dict::readFromFile(const char * filename) void Dict::readFromFile(const char * filename)
...@@ -161,6 +168,7 @@ int Dict::_getIndexOrInsert(const std::string & element, const std::string & pre ...@@ -161,6 +168,7 @@ int Dict::_getIndexOrInsert(const std::string & element, const std::string & pre
void Dict::setState(State state) void Dict::setState(State state)
{ {
if (!locked)
this->state = state; this->state = state;
} }
......
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <unistd.h> #include <unistd.h>
#include <execinfo.h>
#include <cxxabi.h>
#include <cstdlib>
#include "upper2lower" #include "upper2lower"
float util::long2float(long l) float util::long2float(long l)
...@@ -445,3 +447,57 @@ std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename ...@@ -445,3 +447,57 @@ std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename
return sentences; return sentences;
} }
/// Builds a textual description of the current call stack.
///
/// Each captured frame (except the first one) is rendered on its own line as
/// "[bt] : (index) module : function+offset". The function name is demangled
/// when the symbol string has the "module(mangled+offset)" shape and
/// demangling succeeds; otherwise the raw symbol string is reported.
///
/// @return the formatted trace, lines separated by '\n' with no leading or
///         trailing newline; empty if no symbols could be obtained.
std::string util::getStackTrace()
{
  // Releases a malloc-allocated buffer on scope exit, including when
  // formatting throws. Both backtrace_symbols and __cxa_demangle hand back
  // memory that the caller has to free.
  struct MallocGuard
  {
    void * ptr;
    ~MallocGuard() { std::free(ptr); }
  };

  std::string res;

  try
  {
    constexpr int maxFrames = 100;
    void * frames[maxFrames];
    const int nbFrames = backtrace(frames, maxFrames);

    char ** messages = backtrace_symbols(frames, nbFrames);
    MallocGuard messagesGuard{messages};
    if (messages == nullptr)
      return res;

    // Index 0 is skipped: presumably it is the frame of this function itself.
    for (int i = 1; i < nbFrames; ++i)
    {
      char * mangledName = nullptr;
      char * offsetBegin = nullptr;
      char * offsetEnd = nullptr;

      // Locate the "(mangled+offset)" part of the symbol string.
      for (char * p = messages[i]; *p; ++p)
      {
        if (*p == '(')
          mangledName = p;
        else if (*p == '+')
          offsetBegin = p;
        else if (*p == ')')
        {
          offsetEnd = p;
          break;
        }
      }

      // Same separator rule for both branches, so the trace never starts with
      // a blank line.
      const char * separator = res.empty() ? "" : "\n";

      if (mangledName && offsetBegin && offsetEnd && mangledName < offsetBegin)
      {
        // Split the symbol string in place into module / name / offset / rest.
        *mangledName++ = '\0';
        *offsetBegin++ = '\0';
        *offsetEnd++ = '\0';

        int status = 0;
        char * realName = abi::__cxa_demangle(mangledName, nullptr, nullptr, &status);
        MallocGuard realNameGuard{realName};

        res += fmt::format("{}[bt] : ({}) {} : {}+{}{}", separator, i, messages[i], status == 0 ? realName : mangledName, offsetBegin, offsetEnd);
      }
      else
        res += fmt::format("{}[bt] : ({}) {}", separator, i, messages[i]);
    }
  }
  catch (std::exception & e)
  {
    error(e);
  }

  return res;
}
...@@ -37,7 +37,7 @@ class Classifier ...@@ -37,7 +37,7 @@ class Classifier
public : public :
Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train); Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained=false);
TransitionSet & getTransitionSet(const std::string & state); TransitionSet & getTransitionSet(const std::string & state);
NeuralNetwork & getNN(); NeuralNetwork & getNN();
const std::string & getName() const; const std::string & getName() const;
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "RandomNetwork.hpp" #include "RandomNetwork.hpp"
#include "ModularNetwork.hpp" #include "ModularNetwork.hpp"
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path) Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained) : path(path)
{ {
this->name = name; this->name = name;
std::size_t curIndex = 0; std::size_t curIndex = 0;
...@@ -79,12 +79,12 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std ...@@ -79,12 +79,12 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
getNN()->eval(); getNN()->eval();
getNN()->loadDicts(path); getNN()->loadDicts(path);
getNN()->registerEmbeddings(); getNN()->registerEmbeddings(loadPretrained);
if (!train) if (!train)
{ {
torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice()); torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
getNN()->registerEmbeddings(); getNN()->registerEmbeddings(loadPretrained);
} }
else if (std::filesystem::exists(getLastFilename())) else if (std::filesystem::exists(getLastFilename()))
{ {
......
...@@ -175,7 +175,7 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold) ...@@ -175,7 +175,7 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold)
void ReadingMachine::resetClassifiers() void ReadingMachine::resetClassifiers()
{ {
for (unsigned int i = 0; i < classifiers.size(); i++) for (unsigned int i = 0; i < classifiers.size(); i++)
classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train)); classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train, true));
} }
int ReadingMachine::getNbParameters() const int ReadingMachine::getNbParameters() const
......
...@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule ...@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(AppliableTransModule); TORCH_MODULE(AppliableTransModule);
......
...@@ -31,7 +31,7 @@ class ContextModuleImpl : public Submodule ...@@ -31,7 +31,7 @@ class ContextModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(ContextModule); TORCH_MODULE(ContextModule);
......
...@@ -32,7 +32,7 @@ class ContextualModuleImpl : public Submodule ...@@ -32,7 +32,7 @@ class ContextualModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(ContextualModule); TORCH_MODULE(ContextualModule);
......
...@@ -28,7 +28,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule ...@@ -28,7 +28,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(DepthLayerTreeEmbeddingModule); TORCH_MODULE(DepthLayerTreeEmbeddingModule);
......
...@@ -27,7 +27,7 @@ class DistanceModuleImpl : public Submodule ...@@ -27,7 +27,7 @@ class DistanceModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(DistanceModule); TORCH_MODULE(DistanceModule);
......
...@@ -30,7 +30,7 @@ class FocusedColumnModuleImpl : public Submodule ...@@ -30,7 +30,7 @@ class FocusedColumnModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(FocusedColumnModule); TORCH_MODULE(FocusedColumnModule);
......
...@@ -26,7 +26,7 @@ class HistoryMineModuleImpl : public Submodule ...@@ -26,7 +26,7 @@ class HistoryMineModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(HistoryMineModule); TORCH_MODULE(HistoryMineModule);
......
...@@ -26,7 +26,7 @@ class HistoryModuleImpl : public Submodule ...@@ -26,7 +26,7 @@ class HistoryModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(HistoryModule); TORCH_MODULE(HistoryModule);
......
...@@ -33,7 +33,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl ...@@ -33,7 +33,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path); ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input, const std::string & state) override; torch::Tensor forward(torch::Tensor input, const std::string & state) override;
torch::Tensor extractContext(Config & config) override; torch::Tensor extractContext(Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
void saveDicts(std::filesystem::path path) override; void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override; void setDictsState(Dict::State state) override;
......
...@@ -16,7 +16,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder ...@@ -16,7 +16,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0; virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
virtual torch::Tensor extractContext(Config & config) = 0; virtual torch::Tensor extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0; virtual void registerEmbeddings(bool loadPretrained) = 0;
virtual void saveDicts(std::filesystem::path path) = 0; virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0; virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0; virtual void setDictsState(Dict::State state) = 0;
......
...@@ -25,7 +25,7 @@ class NumericColumnModuleImpl : public Submodule ...@@ -25,7 +25,7 @@ class NumericColumnModuleImpl : public Submodule
std::size_t getOutputSize() override; std::size_t getOutputSize() override;
std::size_t getInputSize() override; std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override; void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
}; };
TORCH_MODULE(NumericColumnModule); TORCH_MODULE(NumericColumnModule);
......
...@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl ...@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState); RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input, const std::string & state) override; torch::Tensor forward(torch::Tensor input, const std::string & state) override;
torch::Tensor extractContext(Config &) override; torch::Tensor extractContext(Config &) override;
void registerEmbeddings() override; void registerEmbeddings(bool loadPretrained) override;
void saveDicts(std::filesystem::path path) override; void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override; void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override; void setDictsState(Dict::State state) override;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment