Commit 06b877e7 authored by Franck Dary

Worked with pytorch

parent 6e1d31e3
......@@ -11,7 +11,7 @@ class Dict
enum State {Open, Closed};
enum Encoding {Binary, Ascii};
private :
public :
static constexpr char const * unknownValueStr = "__unknownValue__";
static constexpr char const * nullValueStr = "__nullValue__";
......
......@@ -29,17 +29,10 @@ int main(int argc, char * argv[])
config.setState(machine.getStrategy().getInitialState());
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
TestNetwork nn(machine.getTransitionSet().size());
torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
optimizer.zero_grad();
std::vector<torch::Tensor> predictionsBatch;
std::vector<torch::Tensor> referencesBatch;
std::vector<std::unique_ptr<Config>> configs;
std::vector<std::size_t> classes;
fmt::print("Generating dataset...");
fmt::print("Generating dataset...\n");
Dict dict(Dict::State::Open);
......@@ -49,21 +42,13 @@ int main(int argc, char * argv[])
if (!transition)
util::myThrow("No transition appliable !");
//here train
int goldIndex = 3;
auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
gold[goldIndex] = 1;
// referencesBatch.emplace_back(gold);
// predictionsBatch.emplace_back(nn(config));
auto context = config.extractContext(5,5,dict);
contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
// auto loss = torch::nll_loss(prediction, gold);
// loss.backward();
// optimizer.step();
configs.emplace_back(std::unique_ptr<Config>(new SubConfig(config)));
classes.emplace_back(goldIndex);
int goldIndex = 3;
auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
// if (config.getWordIndex() >= 500)
// exit(1);
classes.emplace_back(gold);
transition->apply(config);
config.addToHistory(transition->getName());
......@@ -80,16 +65,39 @@ int main(int argc, char * argv[])
config.update();
}
auto dataset = ConfigDataset(configs, classes, machine.getTransitionSet().size(), dict).map(torch::data::transforms::Stack<>());
auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
fmt::print("Done! size={}\n", *dataset.size());
fmt::print("Done!\n");
int batchSize = 100;
auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize));
auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), 50);
TestNetwork nn(machine.getTransitionSet().size(), 5);
torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
for (auto & batch : *dataLoader)
for (int epoch = 1; epoch <= 5; ++epoch)
{
auto data = batch.data;
auto labels = batch.target.squeeze();
float totalLoss = 0.0;
torch::Tensor example;
for (auto & batch : *dataLoader)
{
optimizer.zero_grad();
auto data = batch.data;
auto labels = batch.target.squeeze();
auto prediction = nn(data);
example = prediction[0];
auto loss = torch::nll_loss(torch::log(prediction), labels);
totalLoss += loss.item<float>();
loss.backward();
optimizer.step();
}
fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
std::cout << example << std::endl;
}
return 0;
......
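Reading aid, not part of the commit: the main.cpp changes above adopt the usual libtorch pattern of a custom Dataset mapped through Stack<>, a DataLoader, and an epoch/batch training loop with Adam and nll_loss. A minimal self-contained sketch of that pattern, with made-up shapes, vocabulary size, and hyper-parameters, might look like this:

#include <torch/torch.h>
#include <iostream>
#include <vector>

// Dataset that hands back pre-computed (context, gold class) tensor pairs,
// in the spirit of the new ConfigDataset(contexts, classes) constructor.
struct ToyDataset : torch::data::Dataset<ToyDataset>
{
  std::vector<torch::Tensor> inputs, targets;
  ToyDataset(std::vector<torch::Tensor> in, std::vector<torch::Tensor> tgt)
    : inputs(std::move(in)), targets(std::move(tgt)) {}
  torch::data::Example<> get(size_t i) override { return {inputs[i], targets[i]}; }
  torch::optional<size_t> size() const override { return inputs.size(); }
};

// Stand-in classifier: embed the word indices, average them, score the classes.
struct ToyNetImpl : torch::nn::Module
{
  torch::nn::Embedding emb{nullptr};
  torch::nn::Linear out{nullptr};
  ToyNetImpl(int vocab, int dim, int nbClasses)
  {
    emb = register_module("emb", torch::nn::Embedding(vocab, dim));
    out = register_module("out", torch::nn::Linear(dim, nbClasses));
  }
  torch::Tensor forward(torch::Tensor input)
  {
    // input: {batch, sequence} of word indices; output: log-probabilities.
    return torch::log_softmax(out(emb(input).mean(1)), /*dim=*/1);
  }
};
TORCH_MODULE(ToyNet);

int main()
{
  // Fabricated data: 200 contexts of 11 word indices, 10 possible transitions.
  std::vector<torch::Tensor> contexts, classes;
  for (int i = 0; i < 200; ++i)
  {
    contexts.push_back(torch::randint(0, 1000, {11}, torch::kLong));
    classes.push_back(torch::randint(0, 10, {1}, torch::kLong));
  }

  auto dataset = ToyDataset(contexts, classes).map(torch::data::transforms::Stack<>());
  auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
      std::move(dataset), /*batchSize=*/50);

  ToyNet nn(1000, 100, 10);
  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(1e-3));

  for (int epoch = 1; epoch <= 5; ++epoch)
  {
    float totalLoss = 0.0;
    for (auto & batch : *dataLoader)
    {
      optimizer.zero_grad();
      auto prediction = nn(batch.data);                        // {batch, nbClasses}
      auto loss = torch::nll_loss(prediction, batch.target.squeeze());
      totalLoss += loss.item<float>();
      loss.backward();
      optimizer.step();
    }
    std::cout << "Epoch " << epoch << " : loss=" << totalLoss << std::endl;
  }
  return 0;
}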
......@@ -99,7 +99,7 @@ class Config
String getState() const;
void setState(const std::string state);
bool stateIsDone() const;
std::vector<int> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
};
......
......@@ -366,30 +366,29 @@ bool Config::stateIsDone() const
return !has(0, wordIndex+1, 0);
}
std::vector<int> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
{
std::vector<int> context;
std::stack<int> leftContext;
for (int index = wordIndex-1; has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
if (isToken(index))
leftContext.push(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
int startIndex = wordIndex;
std::vector<long> context;
for (int i = 0; i < leftBorder and has(0,startIndex-1,0); i++)
do
--startIndex;
while (!isToken(startIndex) and has(0,startIndex-1,0));
int endIndex = wordIndex;
for (int i = 0; i < rightBorder and has(0,endIndex+1,0); i++)
do
++endIndex;
while (!isToken(endIndex) and has(0,endIndex+1,0));
while ((int)context.size() < leftBorder-(int)leftContext.size())
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
while (!leftContext.empty())
{
context.emplace_back(leftContext.top());
leftContext.pop();
}
for (int i = startIndex; i <= endIndex; ++i)
if (isToken(i))
context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", i)));
for (int index = wordIndex; has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
if (isToken(index))
context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
//TODO handle the cases where the size is different...
return {0};
while ((int)context.size() < leftBorder+rightBorder+1)
context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
return context;
}
......
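Aside, not in the commit: the new extractContext above walks left and right from wordIndex, skipping non-token lines, and then pads with the nullValueStr index so every context has exactly leftBorder + 1 + rightBorder entries. A stripped-down sketch of that fixed-size windowing over a plain token sequence, with hypothetical names:

#include <algorithm>
#include <vector>

// tokens: dictionary indices of the words of one sentence.
// Returns the indices in [wordIndex-leftBorder, wordIndex+rightBorder],
// padded with nullIndex so the output always has leftBorder + 1 + rightBorder entries.
std::vector<long> extractWindow(const std::vector<long> & tokens, int wordIndex,
                                int leftBorder, int rightBorder, long nullIndex)
{
  int start = std::max(0, wordIndex - leftBorder);
  int end = std::min((int)tokens.size() - 1, wordIndex + rightBorder);
  std::vector<long> context(tokens.begin() + start, tokens.begin() + end + 1);
  while ((int)context.size() < leftBorder + 1 + rightBorder)
    context.push_back(nullIndex); // window ran past the sentence: pad
  return context;
}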
......@@ -8,14 +8,12 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
{
private :
std::vector<std::unique_ptr<Config>> const & configs;
std::vector<std::size_t> const & classes;
std::size_t nbClasses;
Dict & dict;
std::vector<torch::Tensor> contexts;
std::vector<torch::Tensor> classes;
public :
explicit ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict);
explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes);
torch::optional<size_t> size() const override;
torch::data::Example<> get(size_t index) override;
};
......
......@@ -8,15 +8,14 @@ class TestNetworkImpl : public torch::nn::Module
{
private :
std::map<Config::String, std::size_t> dict;
torch::nn::Embedding wordEmbeddings{nullptr};
torch::nn::Linear linear{nullptr};
int focusedIndex;
public :
TestNetworkImpl(int nbOutputs);
torch::Tensor forward(const Config & config);
std::size_t getOrAddDictValue(Config::String s);
TestNetworkImpl(int nbOutputs, int focusedIndex);
torch::Tensor forward(torch::Tensor input);
};
TORCH_MODULE(TestNetwork);
......
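Side note, not part of the commit: TORCH_MODULE(TestNetwork) generates a TestNetwork holder wrapping a std::shared_ptr<TestNetworkImpl>, which is why main.cpp writes nn->parameters() yet calls nn(data) directly. The same Impl/holder pattern in miniature, with hypothetical names:

#include <torch/torch.h>

// The Impl class owns the parameters and defines forward().
struct TinyImpl : torch::nn::Module
{
  torch::nn::Linear linear{nullptr};
  TinyImpl(int in, int out) { linear = register_module("linear", torch::nn::Linear(in, out)); }
  torch::Tensor forward(torch::Tensor x) { return linear(x); }
};
// Generates `Tiny`, a copyable shared_ptr-like handle:
//   Tiny net(3, 2);  net->parameters();  net(torch::rand({4, 3}));
TORCH_MODULE(Tiny);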
#include "ConfigDataset.hpp"
ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict) : configs(configs), classes(classes), nbClasses(nbClasses), dict(dict)
ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(contexts), classes(classes)
{
}
torch::optional<size_t> ConfigDataset::size() const
{
return configs.size();
return contexts.size();
}
torch::data::Example<> ConfigDataset::get(size_t index)
{
auto context = configs[index]->extractContext(5,5,dict);
auto tensorClass = torch::zeros(nbClasses);
tensorClass[classes[index]] = 1;
return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};
return {contexts[index], classes[index]};
}
#include "TestNetwork.hpp"
TestNetworkImpl::TestNetworkImpl(int nbOutputs)
TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
{
getOrAddDictValue(Config::String("_null_"));
getOrAddDictValue(Config::String("_unknown_"));
getOrAddDictValue(Config::String("_S_"));
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
}
torch::Tensor TestNetworkImpl::forward(const Config & config)
{
// std::vector<std::size_t> test{0,1};
// torch::Tensor tens = torch::from_blob(test.data(), {1,2});
// return wordEmbeddings(tens);
constexpr int windowSize = 5;
int wordIndex = config.getWordIndex();
int startIndex = wordIndex;
while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
startIndex--;
int endIndex = wordIndex;
while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize)
endIndex++;
std::vector<std::size_t> words;
for (int i = startIndex; i <= endIndex; ++i)
{
if (!config.has(0, i, 0))
util::myThrow(fmt::format("Config do not have line %d", i));
words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
}
if (words.empty())
util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong));
return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
constexpr int embeddingsSize = 100;
wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
this->focusedIndex = focusedIndex;
}
std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
{
if (s.get().empty())
return dict[Config::String("_null_")];
const auto & found = dict.find(s);
// input dim = {batch, sequence, embeddings}
auto wordsAsEmb = wordEmbeddings(input);
// reshaped dim = {sequence, batch, embeddings}
auto reshaped = wordsAsEmb.permute({1,0,2});
if (found == dict.end())
return dict[s] = dict.size();
auto res = torch::softmax(linear(reshaped[focusedIndex]), 1);
return found->second;
return res;
}
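A quick illustration with assumed shapes, not project code, of the permute step in the new forward: after permuting {batch, sequence, embeddings} to {sequence, batch, embeddings}, indexing with focusedIndex picks the focused word's embedding for every example of the batch, which is equivalent to selecting along the sequence dimension.

#include <torch/torch.h>
#include <iostream>

int main()
{
  int batch = 4, sequence = 11, embeddings = 100, focusedIndex = 5;
  auto wordsAsEmb = torch::rand({batch, sequence, embeddings});

  // {batch, sequence, embeddings} -> {sequence, batch, embeddings}, then take
  // the focused word's row: shape {batch, embeddings}.
  auto viaPermute = wordsAsEmb.permute({1, 0, 2})[focusedIndex];

  // Same thing, selecting directly along the sequence dimension.
  auto viaSelect = wordsAsEmb.select(/*dim=*/1, focusedIndex);

  std::cout << viaPermute.sizes() << " equal=" << viaPermute.equal(viaSelect) << std::endl;
  return 0;
}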