Commit 0f5a864f authored by Franck Dary

Starting to build neural network

parent f6d4084b
@@ -3,3 +3,4 @@ FILE(GLOB SOURCES src/*.cpp)
add_executable(dev src/dev.cpp)
target_link_libraries(dev common)
target_link_libraries(dev reading_machine)
target_link_libraries(dev torch_modules)
@@ -5,6 +5,7 @@
#include "SubConfig.hpp"
#include "TransitionSet.hpp"
#include "ReadingMachine.hpp"
#include "TestNetwork.hpp"
int main(int argc, char * argv[])
{
@@ -26,8 +27,7 @@ int main(int argc, char * argv[])
SubConfig config(goldConfig);
config.setState(machine.getStrategy().getInitialState());
std::vector<std::pair<SubConfig, Transition*>> trainingExamples;
TestNetwork nn;
while (true)
{
@@ -35,7 +35,10 @@ int main(int argc, char * argv[])
if (!transition)
util::myThrow("No transition appliable !");
trainingExamples.emplace_back(config, transition);
//here train
auto testo = nn(config);
// std::cout << testo << std::endl;
transition->apply(config);
config.addToHistory(transition->getName());
@@ -52,8 +55,6 @@ int main(int argc, char * argv[])
config.update();
}
trainingExamples[10000].first.printForDebug(stderr);
return 0;
}
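For context only, a minimal sketch of the kind of update step the //here train placeholder points toward, using the libtorch optimizer API. The Linear scorer, its sizes, and the SGD / negative-log-likelihood choice are illustrative assumptions, not the project's design:

#include <torch/torch.h>
#include <cstdint>
#include <vector>

int main()
{
  // Hypothetical scorer standing in for the real network: 100 features in, 20 transition scores out
  torch::nn::Linear scorer(100, 20);
  torch::optim::SGD optimizer(scorer->parameters(), /*lr=*/0.01);

  auto features = torch::randn({1, 100});                   // stand-in for nn(config)
  auto gold = torch::tensor(std::vector<std::int64_t>{3});  // stand-in for the gold transition index

  optimizer.zero_grad();
  auto loss = torch::nll_loss(torch::log_softmax(scorer(features), /*dim=*/1), gold);
  loss.backward();
  optimizer.step();
  return 0;
}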
#ifndef TESTNETWORK__H
#define TESTNETWORK__H
#include <torch/torch.h>
#include "Config.hpp"
class TestNetworkImpl : public torch::nn::Module
{
private :
std::map<Config::String, std::size_t> dict;
torch::nn::Embedding wordEmbeddings{nullptr};
public :
TestNetworkImpl();
torch::Tensor forward(const Config & config);
std::size_t getOrAddDictValue(Config::String s);
};
TORCH_MODULE(TestNetwork);
#endif
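The header pairs a string-to-index dictionary with a torch::nn::Embedding table. As a point of reference, here is a self-contained sketch of that lookup pattern; it is not taken from the project, and the vocabulary size, embedding dimension and example words are arbitrary:

#include <torch/torch.h>
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main()
{
  std::map<std::string, std::int64_t> dict;
  auto indexOf = [&dict](const std::string & s)
  {
    auto it = dict.find(s);
    if (it == dict.end())
      it = dict.emplace(s, (std::int64_t)dict.size()).first;
    return it->second;
  };

  torch::nn::Embedding embeddings(10, 4); // 10 possible ids, 4-dimensional vectors

  std::vector<std::int64_t> ids{indexOf("the"), indexOf("cat"), indexOf("the")};
  auto input = torch::tensor(ids).unsqueeze(0); // shape {1, 3}
  auto output = embeddings(input);              // shape {1, 3, 4}
  std::cout << output.sizes() << std::endl;
  return 0;
}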
#include "TestNetwork.hpp"
TestNetworkImpl::TestNetworkImpl()
{
getOrAddDictValue(Config::String("_null_"));
getOrAddDictValue(Config::String("_unknown_"));
getOrAddDictValue(Config::String("_S_"));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
}
torch::Tensor TestNetworkImpl::forward(const Config & config)
{
// std::vector<std::size_t> test{0,1};
// torch::Tensor tens = torch::from_blob(test.data(), {1,2});
// return wordEmbeddings(tens);
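// take the words at most windowSize positions to the left and to the right of the current word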
constexpr int windowSize = 5;
int startIndex = config.getWordIndex();
while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize)
startIndex--;
int endIndex = config.getWordIndex();
while (config.has(0,endIndex+1,0) and endIndex-config.getWordIndex() < windowSize)
endIndex++;
std::vector<std::size_t> words;
for (int i = startIndex; i <= endIndex; ++i)
{
if (!config.has(0, i, 0))
util::myThrow(fmt::format("Config do not have line %d", i));
words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
}
if (words.empty())
util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));
// from_blob does not copy or take ownership, so clone before the local words vector goes away
return wordEmbeddings(torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong).clone());
}
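// Map a word form to its index in the dictionary, adding unseen forms; empty forms map to _null_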
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
{
if (s.get().empty())
return dict[Config::String("_null_")];
const auto & found = dict.find(s);
if (found == dict.end())
return dict[s] = dict.size();
return found->second;
}
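One note on the tensor construction in forward: torch::from_blob wraps the caller's buffer without copying or taking ownership, which is why the indices are cloned before the local vector disappears. A minimal standalone illustration of two ways to get an owning tensor (values are arbitrary):

#include <torch/torch.h>
#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
  std::vector<std::int64_t> words{0, 1, 2};

  // Option 1: from_blob aliases words.data(); clone() copies into tensor-owned storage
  auto aliased = torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong);
  auto owned = aliased.clone();

  // Option 2: torch::tensor copies the values up front
  auto copied = torch::tensor(words).unsqueeze(0);

  std::cout << owned << std::endl << copied << std::endl;
  return 0;
}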