Commit 0733a6af authored by Franck Dary

Started to implement data loader

parent 0f5a864f
@@ -6,6 +6,7 @@
 #include "TransitionSet.hpp"
 #include "ReadingMachine.hpp"
 #include "TestNetwork.hpp"
+#include "ConfigDataset.hpp"

 int main(int argc, char * argv[])
 {
@@ -27,7 +28,15 @@ int main(int argc, char * argv[])
   SubConfig config(goldConfig);
   config.setState(machine.getStrategy().getInitialState());

-  TestNetwork nn;
+  TestNetwork nn(machine.getTransitionSet().size());
+  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+  optimizer.zero_grad();
+
+  std::vector<torch::Tensor> predictionsBatch;
+  std::vector<torch::Tensor> referencesBatch;
+  std::vector<SubConfig> configs;

   while (true)
   {
@@ -36,9 +45,22 @@ int main(int argc, char * argv[])
       util::myThrow("No transition appliable !");

     //here train
-    auto testo = nn(config);
+    int goldIndex = 3;
+    auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
+    gold[goldIndex] = 1;
+
+    // referencesBatch.emplace_back(gold);
+    // predictionsBatch.emplace_back(nn(config));
+    // auto loss = torch::nll_loss(prediction, gold);
+    // loss.backward();
+    // optimizer.step();
+
+    configs.emplace_back(config);
+
+    if (config.getWordIndex()%1 == 0)
+      fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines());
-    // std::cout << testo << std::endl;
+    // if (config.getWordIndex() >= 500)
+    //   exit(1);

     transition->apply(config);
     config.addToHistory(transition->getName());
...
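
The training step itself is still commented out in the loop above. Below is a minimal sketch, not part of the commit, of how one batch could be consumed once predictionsBatch and referencesBatch are filled. It assumes the network returns unnormalized scores over transitions and that each reference is a gold transition index stored as a 0-dim kLong tensor; torch::nll_loss expects log-probabilities and class indices, not a one-hot vector like the gold tensor built above, and the current forward() applies a softmax rather than a log_softmax, so both ends would need adjusting. The helper name trainBatch is hypothetical.

#include <torch/torch.h>
#include <vector>

// Hypothetical helper (not in the commit): perform one optimization step on a
// batch of predictions and gold transition indices, then clear the buffers.
void trainBatch(std::vector<torch::Tensor> & predictionsBatch,
                std::vector<torch::Tensor> & referencesBatch,
                torch::optim::Adam & optimizer)
{
  // {nbExamples, nbTransitions} score matrix and {nbExamples} index vector.
  auto predictions = torch::stack(predictionsBatch);
  auto references = torch::stack(referencesBatch);

  // nll_loss wants log-probabilities as input and class indices as targets.
  auto loss = torch::nll_loss(torch::log_softmax(predictions, 1), references);

  optimizer.zero_grad();
  loss.backward();
  optimizer.step();

  predictionsBatch.clear();
  referencesBatch.clear();
}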
FeatureFunction.hpp (new file)

#ifndef FEATUREFUNCTION__H
#define FEATUREFUNCTION__H

#include <map>
#include <string>
#include "Config.hpp"

class FeatureFunction
{
  using Representation = std::vector<std::size_t>;
  using Feature = std::function<Config::String(const Config &)>;

  private :

  std::map<std::string, Feature> features;
  std::map<Config::String, std::size_t> dictionary;

  private :

  const Feature & getOrCreateFeature(const std::string & name);

  public :

  FeatureFunction(const std::vector<std::string_view> & lines);
  Representation getRepresentation(const Config & config) const;
};

#endif
ReadingMachine.hpp

@@ -4,6 +4,7 @@
 #include <memory>
 #include "Classifier.hpp"
 #include "Strategy.hpp"
+#include "FeatureFunction.hpp"

 class ReadingMachine
 {
@@ -12,6 +13,7 @@ class ReadingMachine
   std::string name;
   std::unique_ptr<Classifier> classifier;
   std::unique_ptr<Strategy> strategy;
+  std::unique_ptr<FeatureFunction> featureFunction;

   public :
...
TransitionSet.hpp

@@ -18,6 +18,7 @@ class TransitionSet
   TransitionSet(const std::string & filename);
   std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
   Transition * getBestAppliableTransition(const Config & c);
+  std::size_t size() const;
 };

 #endif
#include "FeatureFunction.hpp"
FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines)
{
if (!util::doIfNameMatch(std::regex("Features :(.*)"), lines[0], [](auto){}))
util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0]));
for (unsigned int i = 1; i < lines.size(); i++)
{
if (util::doIfNameMatch(std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)"), lines[i], [this](auto &sm)
{
getOrCreateFeature(fmt::format("b."));
}))
continue;
util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i]));
}
for (auto & it : features)
fmt::print("{}\n", it.first);
}
FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const
{
Representation representation;
return representation;
}
const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name)
{
auto found = features.find(name);
if (found != features.end())
return found->second;
if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();}))
return features[name];
util::myThrow(fmt::format("Unknown feature '{}'", name));
return found->second;
}
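
For reference, the constructor expects the machine file to contain a block starting with "Features :" followed by directives such as "buffer from -2 to +2", but every recognized directive currently registers only a single placeholder feature named "b.". Below is a self-contained sketch, an assumption about where this is heading rather than code from the commit, of how the two captured bounds could be expanded into one feature name per relative buffer position.

#include <iostream>
#include <regex>
#include <string>

// Standalone illustration: expand "buffer from -2 to +2" into the feature
// names b.-2, b.-1, b.0, b.1, b.2 using the same regex as the constructor.
int main()
{
  std::string line = "buffer from -2 to +2";
  std::smatch sm;

  if (std::regex_match(line, sm, std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)")))
    for (int rel = std::stoi(sm.str(1)); rel <= std::stoi(sm.str(2)); rel++)
      std::cout << "b." << rel << "\n";   // each name would go through getOrCreateFeature

  return 0;
}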
ReadingMachine.cpp

@@ -7,17 +7,22 @@ ReadingMachine::ReadingMachine(const std::string & filename)
   char buffer[1024];
   std::string fileContent;
+  std::vector<std::string> lines;

   while (!std::feof(file))
   {
     if (buffer != std::fgets(buffer, 1024, file))
       break;

+    // If line is blank or commented (# or //), ignore it
+    if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
+      continue;
+
-    fileContent += buffer;
+    if (buffer[std::strlen(buffer)-1] == '\n')
+      buffer[std::strlen(buffer)-1] = '\0';
+    lines.emplace_back(buffer);
   }

   std::fclose(file);

-  auto lines = util::split(fileContent, '\n');

   try
   {
     unsigned int curLine = 0;
@@ -28,7 +33,15 @@ ReadingMachine::ReadingMachine(const std::string & filename)
     if (!classifier.get())
       util::myThrow("No Classifier specified");

-    std::vector<std::string_view> restOfFile(lines.begin()+curLine-1, lines.end());
+    --curLine;
+
+    //std::vector<std::string_view> restOfFile;
+    //while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){}))
+    //  restOfFile.emplace_back(lines[curLine++]);
+
+    //featureFunction.reset(new FeatureFunction(restOfFile));
+
+    auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());

     strategy.reset(new Strategy(restOfFile));
   } catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));}
...
TransitionSet.cpp

@@ -67,3 +67,8 @@ Transition * TransitionSet::getBestAppliableTransition(const Config & c)
   return result;
 }
+
+std::size_t TransitionSet::size() const
+{
+  return transitions.size();
+}
ConfigDataset.hpp (new file)

#ifndef CONFIGDATASET__H
#define CONFIGDATASET__H

#include <torch/torch.h>
#include "Config.hpp"

class ConfigDataset : public torch::data::Dataset<ConfigDataset>
{
  private :

  std::vector<Config> const & configs;
  std::vector<std::size_t> const & classes;

  public :

  ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes);
  torch::optional<size_t> size() const override;
  torch::data::Example<> get(size_t index) override;
};

#endif
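
Because ConfigDataset derives from torch::data::Dataset, it can be handed to libtorch's data-loader machinery once size() and get() are implemented. A hedged sketch of that wiring follows; the function name iterateDataset and the batch size are illustrative, not from the commit.

#include <torch/torch.h>
#include <vector>
#include "ConfigDataset.hpp"

// Sketch, not from the commit: iterate over a ConfigDataset in mini-batches.
// Assumes get() eventually returns {input, target} Examples of tensors.
void iterateDataset(std::vector<Config> const & configs,
                    std::vector<std::size_t> const & classes)
{
  // Stack<> collates the per-example {data, target} pairs into batched tensors.
  auto dataset = ConfigDataset(configs, classes)
                   .map(torch::data::transforms::Stack<>());

  auto loader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
    std::move(dataset), torch::data::DataLoaderOptions().batch_size(64));

  for (auto & batch : *loader)
  {
    // batch.data holds the stacked inputs, batch.target the stacked gold classes.
    (void) batch;
  }
}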
TestNetwork.hpp

@@ -4,16 +4,17 @@
 #include <torch/torch.h>
 #include "Config.hpp"

-class TestNetworkImpl : torch::nn::Module
+class TestNetworkImpl : public torch::nn::Module
 {
   private :

   std::map<Config::String, std::size_t> dict;
   torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear{nullptr};

   public :

-  TestNetworkImpl();
+  TestNetworkImpl(int nbOutputs);
   torch::Tensor forward(const Config & config);
   std::size_t getOrAddDictValue(Config::String s);
 };
...
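
TestNetworkImpl follows libtorch's Impl/holder naming convention; the holder declaration itself is presumably in the part of the header not shown by this diff, so the line below is a reminder of that convention rather than code from the commit. It is what makes TestNetwork nn(...), nn->parameters() and nn(config) in the main hunk above work.

// Assumed to live in the unshown part of TestNetwork.hpp: generates a
// TestNetwork holder with shared-pointer semantics around TestNetworkImpl,
// so nn->parameters() forwards to the module and nn(config) calls forward().
TORCH_MODULE(TestNetwork);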
#include "ConfigDataset.hpp"
ConfigDataset::ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes) : configs(configs), classes(classes)
{
}
torch::optional<size_t> ConfigDataset::size() const
{
}
torch::data::Example<> ConfigDataset::get(size_t index)
{
}
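
size() and get() are left empty in this commit (they do not return anything yet). Below is a hedged sketch of what they might become; encodeConfig is a hypothetical placeholder for whatever Config-to-tensor extraction is chosen later, possibly FeatureFunction::getRepresentation.

// Sketch, not from the commit. encodeConfig is a hypothetical helper standing
// in for the eventual Config -> tensor conversion.
torch::optional<size_t> ConfigDataset::size() const
{
  // Number of examples in the dataset.
  return configs.size();
}

torch::data::Example<> ConfigDataset::get(size_t index)
{
  torch::Tensor input = encodeConfig(configs[index]);
  torch::Tensor target = torch::tensor((int64_t)classes[index], at::kLong);

  return {input, target};
}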
#include "TestNetwork.hpp" #include "TestNetwork.hpp"
TestNetworkImpl::TestNetworkImpl() TestNetworkImpl::TestNetworkImpl(int nbOutputs)
{ {
getOrAddDictValue(Config::String("_null_")); getOrAddDictValue(Config::String("_null_"));
getOrAddDictValue(Config::String("_unknown_")); getOrAddDictValue(Config::String("_unknown_"));
getOrAddDictValue(Config::String("_S_")); getOrAddDictValue(Config::String("_S_"));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100)); wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
} }
torch::Tensor TestNetworkImpl::forward(const Config & config) torch::Tensor TestNetworkImpl::forward(const Config & config)
...@@ -15,11 +16,12 @@ torch::Tensor TestNetworkImpl::forward(const Config & config) ...@@ -15,11 +16,12 @@ torch::Tensor TestNetworkImpl::forward(const Config & config)
// torch::Tensor tens = torch::from_blob(test.data(), {1,2}); // torch::Tensor tens = torch::from_blob(test.data(), {1,2});
// return wordEmbeddings(tens); // return wordEmbeddings(tens);
constexpr int windowSize = 5; constexpr int windowSize = 5;
int startIndex = config.getWordIndex(); int wordIndex = config.getWordIndex();
while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize) int startIndex = wordIndex;
while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
startIndex--; startIndex--;
int endIndex = config.getWordIndex(); int endIndex = wordIndex;
while (config.has(0,endIndex+1,0) and -config.getWordIndex()+endIndex < windowSize) while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize)
endIndex++; endIndex++;
std::vector<std::size_t> words; std::vector<std::size_t> words;
...@@ -32,9 +34,11 @@ torch::Tensor TestNetworkImpl::forward(const Config & config) ...@@ -32,9 +34,11 @@ torch::Tensor TestNetworkImpl::forward(const Config & config)
} }
if (words.empty()) if (words.empty())
util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex)); util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
return wordEmbeddings(torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong)); auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong));
return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
} }
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s) std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
......
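
With these changes forward() embeds a window of up to five words on each side of the current word and returns a softmax over the transition set, computed from the embedding of the focus word. A hedged usage sketch reusing the names from the main hunk (note that since forward() already applies a softmax rather than a log_softmax, its output would need a log before being passed to torch::nll_loss):

// Sketch, not from the commit: consuming the scores produced by forward().
TestNetwork nn(machine.getTransitionSet().size());

auto scores = nn(config);                    // shape {nbTransitions}, probabilities
int predicted = scores.argmax().item<int>(); // index of the most probable transition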