Skip to content
Snippets Groups Projects
Commit 0733a6af authored by Franck Dary's avatar Franck Dary
Browse files

Started to implement data loader

parent 0f5a864f
No related branches found
No related tags found
No related merge requests found
......@@ -6,6 +6,7 @@
#include "TransitionSet.hpp"
#include "ReadingMachine.hpp"
#include "TestNetwork.hpp"
#include "ConfigDataset.hpp"
int main(int argc, char * argv[])
{
......@@ -27,7 +28,15 @@ int main(int argc, char * argv[])
SubConfig config(goldConfig);
config.setState(machine.getStrategy().getInitialState());
TestNetwork nn;
TestNetwork nn(machine.getTransitionSet().size());
torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
optimizer.zero_grad();
std::vector<torch::Tensor> predictionsBatch;
std::vector<torch::Tensor> referencesBatch;
std::vector<SubConfig> configs;
while (true)
{
......@@ -36,9 +45,22 @@ int main(int argc, char * argv[])
util::myThrow("No transition appliable !");
//here train
auto testo = nn(config);
int goldIndex = 3;
auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
gold[goldIndex] = 1;
// referencesBatch.emplace_back(gold);
// predictionsBatch.emplace_back(nn(config));
// auto loss = torch::nll_loss(prediction, gold);
// loss.backward();
// optimizer.step();
configs.emplace_back(config);
if (config.getWordIndex()%1 == 0)
fmt::print("{:.5f}%\n", config.getWordIndex()*100.0/goldConfig.getNbLines());
// std::cout << testo << std::endl;
// if (config.getWordIndex() >= 500)
// exit(1);
transition->apply(config);
config.addToHistory(transition->getName());
......
#ifndef FEATUREFUNCTION__H
#define FEATUREFUNCTION__H
#include <map>
#include <string>
#include "Config.hpp"
class FeatureFunction
{
using Representation = std::vector<std::size_t>;
using Feature = std::function<Config::String(const Config &)>;
private :
std::map<std::string, Feature> features;
std::map<Config::String, std::size_t> dictionary;
private :
const Feature & getOrCreateFeature(const std::string & name);
public :
FeatureFunction(const std::vector<std::string_view> & lines);
Representation getRepresentation(const Config & config) const;
};
#endif
......@@ -4,6 +4,7 @@
#include <memory>
#include "Classifier.hpp"
#include "Strategy.hpp"
#include "FeatureFunction.hpp"
class ReadingMachine
{
......@@ -12,6 +13,7 @@ class ReadingMachine
std::string name;
std::unique_ptr<Classifier> classifier;
std::unique_ptr<Strategy> strategy;
std::unique_ptr<FeatureFunction> featureFunction;
public :
......
......@@ -18,6 +18,7 @@ class TransitionSet
TransitionSet(const std::string & filename);
std::vector<std::pair<Transition*, int>> getAppliableTransitionsCosts(const Config & c);
Transition * getBestAppliableTransition(const Config & c);
std::size_t size() const;
};
#endif
#include "FeatureFunction.hpp"
FeatureFunction::FeatureFunction(const std::vector<std::string_view> & lines)
{
if (!util::doIfNameMatch(std::regex("Features :(.*)"), lines[0], [](auto){}))
util::myThrow(fmt::format("Wrong line '{}', expected 'Features :'", lines[0]));
for (unsigned int i = 1; i < lines.size(); i++)
{
if (util::doIfNameMatch(std::regex("(?: |\\t)*buffer from ((?:-|\\+|)\\d+) to ((?:-|\\+|)\\d+)"), lines[i], [this](auto &sm)
{
getOrCreateFeature(fmt::format("b."));
}))
continue;
util::myThrow(fmt::format("Unknown feature directive '{}'", lines[i]));
}
for (auto & it : features)
fmt::print("{}\n", it.first);
}
FeatureFunction::Representation FeatureFunction::getRepresentation(const Config & config) const
{
Representation representation;
return representation;
}
const FeatureFunction::Feature & FeatureFunction::getOrCreateFeature(const std::string & name)
{
auto found = features.find(name);
if (found != features.end())
return found->second;
if (util::doIfNameMatch(std::regex(""), name, [this,name](auto){features[name] = Feature();}))
return features[name];
util::myThrow(fmt::format("Unknown feature '{}'", name));
return found->second;
}
......@@ -7,17 +7,22 @@ ReadingMachine::ReadingMachine(const std::string & filename)
char buffer[1024];
std::string fileContent;
std::vector<std::string> lines;
while (!std::feof(file))
{
if (buffer != std::fgets(buffer, 1024, file))
break;
// If line is blank or commented (# or //), ignore it
if (util::doIfNameMatch(std::regex("((\\s|\\t)*)(((#|//).*)|)(\n|)"), buffer, [](auto){}))
continue;
fileContent += buffer;
if (buffer[std::strlen(buffer)-1] == '\n')
buffer[std::strlen(buffer)-1] = '\0';
lines.emplace_back(buffer);
}
std::fclose(file);
auto lines = util::split(fileContent, '\n');
try
{
unsigned int curLine = 0;
......@@ -28,7 +33,15 @@ ReadingMachine::ReadingMachine(const std::string & filename)
if (!classifier.get())
util::myThrow("No Classifier specified");
std::vector<std::string_view> restOfFile(lines.begin()+curLine-1, lines.end());
--curLine;
//std::vector<std::string_view> restOfFile;
//while (curLine < lines.size() and !util::doIfNameMatch(std::regex("Strategy(.*)"),lines[curLine], [](auto){}))
// restOfFile.emplace_back(lines[curLine++]);
//featureFunction.reset(new FeatureFunction(restOfFile));
auto restOfFile = std::vector<std::string_view>(lines.begin()+curLine, lines.end());
strategy.reset(new Strategy(restOfFile));
} catch(std::exception & e) {util::myThrow(fmt::format("during reading of '{}' : {}", filename, e.what()));}
......
......@@ -67,3 +67,8 @@ Transition * TransitionSet::getBestAppliableTransition(const Config & c)
return result;
}
std::size_t TransitionSet::size() const
{
return transitions.size();
}
#ifndef CONFIGDATASET__H
#define CONFIGDATASET__H
#include <torch/torch.h>
#include "Config.hpp"
class ConfigDataset : public torch::data::Dataset<ConfigDataset>
{
private :
std::vector<Config> const & configs;
std::vector<std::size_t> const & classes;
public :
ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes);
torch::optional<size_t> size() const override;
torch::data::Example<> get(size_t index) override;
};
#endif
......@@ -4,16 +4,17 @@
#include <torch/torch.h>
#include "Config.hpp"
class TestNetworkImpl : torch::nn::Module
class TestNetworkImpl : public torch::nn::Module
{
private :
std::map<Config::String, std::size_t> dict;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear{nullptr};
public :
TestNetworkImpl();
TestNetworkImpl(int nbOutputs);
torch::Tensor forward(const Config & config);
std::size_t getOrAddDictValue(Config::String s);
};
......
#include "ConfigDataset.hpp"
ConfigDataset::ConfigDataset(std::vector<Config> const & configs, std::vector<std::size_t> const & classes) : configs(configs), classes(classes)
{
}
torch::optional<size_t> ConfigDataset::size() const
{
}
torch::data::Example<> ConfigDataset::get(size_t index)
{
}
#include "TestNetwork.hpp"
TestNetworkImpl::TestNetworkImpl()
TestNetworkImpl::TestNetworkImpl(int nbOutputs)
{
getOrAddDictValue(Config::String("_null_"));
getOrAddDictValue(Config::String("_unknown_"));
getOrAddDictValue(Config::String("_S_"));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
}
torch::Tensor TestNetworkImpl::forward(const Config & config)
......@@ -15,11 +16,12 @@ torch::Tensor TestNetworkImpl::forward(const Config & config)
// torch::Tensor tens = torch::from_blob(test.data(), {1,2});
// return wordEmbeddings(tens);
constexpr int windowSize = 5;
int startIndex = config.getWordIndex();
while (config.has(0,startIndex-1,0) and config.getWordIndex()-startIndex < windowSize)
int wordIndex = config.getWordIndex();
int startIndex = wordIndex;
while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
startIndex--;
int endIndex = config.getWordIndex();
while (config.has(0,endIndex+1,0) and -config.getWordIndex()+endIndex < windowSize)
int endIndex = wordIndex;
while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize)
endIndex++;
std::vector<std::size_t> words;
......@@ -32,9 +34,11 @@ torch::Tensor TestNetworkImpl::forward(const Config & config)
}
if (words.empty())
util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), config.getWordIndex(), startIndex, endIndex));
util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
return wordEmbeddings(torch::from_blob(words.data(), {1, (long int)words.size()}, at::kLong));
auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong));
return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
}
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment