Commit 9032ef49 authored by Franck Dary's avatar Franck Dary
Browse files

Tried to improve pretrained

parent 8072af59
......@@ -24,7 +24,7 @@ if(NOT CMAKE_BUILD_TYPE)
endif()
set(CMAKE_CXX_FLAGS "-Wall -Wextra")
set(CMAKE_CXX_FLAGS_DEBUG "-g3")
set(CMAKE_CXX_FLAGS_DEBUG "-g3 -rdynamic")
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast")
include_directories(fmt/include)
......
......@@ -36,6 +36,7 @@ class Dict
State state;
bool isCountingOccs{false};
std::set<std::string> prefixes{""};
bool locked;
public :
......@@ -51,6 +52,7 @@ class Dict
public :
void lock();
void countOcc(bool isCountingOccs);
std::set<std::size_t> getSpecialIndexes();
int getIndexOrInsert(const std::string & element, const std::string & prefix);
......
......@@ -29,6 +29,8 @@ void myThrow(std::string_view message, const std::experimental::source_location
std::vector<std::filesystem::path> findFilesByExtension(std::filesystem::path directory, std::string extension);
std::string getStackTrace();
std::string_view getFilenameFromPath(std::string_view s);
std::vector<std::string> split(std::string_view s, char delimiter);
......
......@@ -3,6 +3,7 @@
Dict::Dict(State state)
{
locked = false;
setState(state);
insert(unknownValueStr);
insert(nullValueStr);
......@@ -18,6 +19,12 @@ Dict::Dict(const char * filename, State state)
{
readFromFile(filename);
setState(state);
locked = false;
}
void Dict::lock()
{
locked = true;
}
void Dict::readFromFile(const char * filename)
......@@ -161,7 +168,8 @@ int Dict::_getIndexOrInsert(const std::string & element, const std::string & pre
void Dict::setState(State state)
{
this->state = state;
if (!locked)
this->state = state;
}
Dict::State Dict::getState() const
......
......@@ -5,6 +5,8 @@
#include <iostream>
#include <fstream>
#include <unistd.h>
#include <execinfo.h>
#include <cxxabi.h>
#include "upper2lower"
float util::long2float(long l)
......@@ -445,3 +447,57 @@ std::vector<std::vector<std::string>> util::readTSV(std::string_view tsvFilename
return sentences;
}
std::string util::getStackTrace()
{
std::string res;
try
{
void * array[100];
size_t size;
size = backtrace(array, 100);
char ** messages = backtrace_symbols(array, size);
for (unsigned int i = 1; i < size && messages != NULL; ++i)
{
char *mangled_name = 0, *offset_begin = 0, *offset_end = 0;
for (char *p = messages[i]; *p; ++p)
{
if (*p == '(')
mangled_name = p;
else if (*p == '+')
offset_begin = p;
else if (*p == ')')
{
offset_end = p;
break;
}
}
if (mangled_name && offset_begin && offset_end &&
mangled_name < offset_begin)
{
*mangled_name++ = '\0';
*offset_begin++ = '\0';
*offset_end++ = '\0';
int status = 0;
char * real_name = abi::__cxa_demangle(mangled_name, 0, 0, &status);
res = fmt::format("{}{}[bt] : ({}) {} : {}+{}{}", res, res.size() == 0 ? "" : "\n", i, messages[i], status == 0 ? real_name : mangled_name, offset_begin, offset_end);
}
else
res = fmt::format("{}\n[bt] : ({}) {}", res, i, messages[i]);
}
}
catch (std::exception & e)
{
error(e);
}
return res;
}
......@@ -37,7 +37,7 @@ class Classifier
public :
Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train);
Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained=false);
TransitionSet & getTransitionSet(const std::string & state);
NeuralNetwork & getNN();
const std::string & getName() const;
......
......@@ -3,7 +3,7 @@
#include "RandomNetwork.hpp"
#include "ModularNetwork.hpp"
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train) : path(path)
Classifier::Classifier(const std::string & name, std::filesystem::path path, std::vector<std::string> definition, bool train, bool loadPretrained) : path(path)
{
this->name = name;
std::size_t curIndex = 0;
......@@ -79,12 +79,12 @@ Classifier::Classifier(const std::string & name, std::filesystem::path path, std
getNN()->eval();
getNN()->loadDicts(path);
getNN()->registerEmbeddings();
getNN()->registerEmbeddings(loadPretrained);
if (!train)
{
torch::load(getNN(), getBestFilename(), NeuralNetworkImpl::getDevice());
getNN()->registerEmbeddings();
getNN()->registerEmbeddings(loadPretrained);
}
else if (std::filesystem::exists(getLastFilename()))
{
......
......@@ -175,7 +175,7 @@ void ReadingMachine::removeRareDictElements(float rarityThreshold)
void ReadingMachine::resetClassifiers()
{
for (unsigned int i = 0; i < classifiers.size(); i++)
classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train));
classifiers[i].reset(new Classifier(classifierNames[i], path.parent_path(), classifierDefinitions[i], train, true));
}
int ReadingMachine::getNbParameters() const
......
......@@ -20,7 +20,7 @@ class AppliableTransModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(AppliableTransModule);
......
......@@ -31,7 +31,7 @@ class ContextModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(ContextModule);
......
......@@ -32,7 +32,7 @@ class ContextualModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(ContextualModule);
......
......@@ -28,7 +28,7 @@ class DepthLayerTreeEmbeddingModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(DepthLayerTreeEmbeddingModule);
......
......@@ -27,7 +27,7 @@ class DistanceModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(DistanceModule);
......
......@@ -30,7 +30,7 @@ class FocusedColumnModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(FocusedColumnModule);
......
......@@ -26,7 +26,7 @@ class HistoryMineModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(HistoryMineModule);
......
......@@ -26,7 +26,7 @@ class HistoryModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(HistoryModule);
......
......@@ -33,7 +33,7 @@ class ModularNetworkImpl : public NeuralNetworkImpl
ModularNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState, std::vector<std::string> definitions, std::filesystem::path path);
torch::Tensor forward(torch::Tensor input, const std::string & state) override;
torch::Tensor extractContext(Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
......
......@@ -16,7 +16,7 @@ class NeuralNetworkImpl : public torch::nn::Module, public NameHolder
virtual torch::Tensor forward(torch::Tensor input, const std::string & state) = 0;
virtual torch::Tensor extractContext(Config & config) = 0;
virtual void registerEmbeddings() = 0;
virtual void registerEmbeddings(bool loadPretrained) = 0;
virtual void saveDicts(std::filesystem::path path) = 0;
virtual void loadDicts(std::filesystem::path path) = 0;
virtual void setDictsState(Dict::State state) = 0;
......
......@@ -25,7 +25,7 @@ class NumericColumnModuleImpl : public Submodule
std::size_t getOutputSize() override;
std::size_t getInputSize() override;
void addToContext(torch::Tensor & context, const Config & config) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
};
TORCH_MODULE(NumericColumnModule);
......
......@@ -14,7 +14,7 @@ class RandomNetworkImpl : public NeuralNetworkImpl
RandomNetworkImpl(std::string name, std::map<std::string,std::size_t> nbOutputsPerState);
torch::Tensor forward(torch::Tensor input, const std::string & state) override;
torch::Tensor extractContext(Config &) override;
void registerEmbeddings() override;
void registerEmbeddings(bool loadPretrained) override;
void saveDicts(std::filesystem::path path) override;
void loadDicts(std::filesystem::path path) override;
void setDictsState(Dict::State state) override;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment