diff --git a/common/include/Dict.hpp b/common/include/Dict.hpp
index fb4b04bae5a812d2710def5e83f3418a7c57e60b..d976573f3c3c383acfb1e769a8ec84277ed43861 100644
--- a/common/include/Dict.hpp
+++ b/common/include/Dict.hpp
@@ -11,7 +11,7 @@ class Dict
   enum State {Open, Closed};
   enum Encoding {Binary, Ascii};
 
-  private :
+  public :
 
   static constexpr char const * unknownValueStr = "__unknownValue__";
   static constexpr char const * nullValueStr = "__nullValue__";
diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index 8d83e28f843445eafa04cb7ae068b23f4e91e547..b5bd81f47d2f1b37427c699e3eb1d3eb48b15c27 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -29,17 +29,10 @@ int main(int argc, char * argv[])
 
   config.setState(machine.getStrategy().getInitialState());
 
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
 
-  TestNetwork nn(machine.getTransitionSet().size());
-  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-  optimizer.zero_grad();
-
-  std::vector<torch::Tensor> predictionsBatch;
-  std::vector<torch::Tensor> referencesBatch;
-  std::vector<std::unique_ptr<Config>> configs;
-  std::vector<std::size_t> classes;
-
-  fmt::print("Generating dataset...");
+  fmt::print("Generating dataset...\n");
 
   Dict dict(Dict::State::Open);
 
@@ -49,21 +42,13 @@ int main(int argc, char * argv[])
     if (!transition)
       util::myThrow("No transition appliable !");
 
-    //here train
-    int goldIndex = 3;
-    auto gold = torch::zeros(machine.getTransitionSet().size(), at::kLong);
-    gold[goldIndex] = 1;
-//    referencesBatch.emplace_back(gold);
-//    predictionsBatch.emplace_back(nn(config));
+    auto context = config.extractContext(5,5,dict);
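+    // from_blob does not copy the vector's memory; clone() gives the tensor its own storage so it survives past this iteration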
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
 
-//    auto loss = torch::nll_loss(prediction, gold);
-//    loss.backward();
-//    optimizer.step();
-    configs.emplace_back(std::unique_ptr<Config>(new SubConfig(config)));
-    classes.emplace_back(goldIndex);
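+    // NOTE: goldIndex is still a hard-coded placeholder class, not a real oracle label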
+    long goldIndex = 3;  // 64-bit so it matches at::kLong in from_blob below
+    auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone();
 
-//    if (config.getWordIndex() >= 500)
-//      exit(1);
+    classes.emplace_back(gold);
 
     transition->apply(config);
     config.addToHistory(transition->getName());
@@ -80,16 +65,39 @@ int main(int argc, char * argv[])
       config.update();
   }
 
-  auto dataset = ConfigDataset(configs, classes, machine.getTransitionSet().size(), dict).map(torch::data::transforms::Stack<>());
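+  // Stack<> collates the per-example tensors into batched data/target tensors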
+  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+
+  fmt::print("Done! size={}\n", *dataset.size());
 
-  fmt::print("Done!\n");
+  int batchSize = 100;
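+  // no sampler specified: make_data_loader defaults to RandomSampler, so batches are shuffled each epoch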
+  auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize));
 
-  auto dataLoader = torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(std::move(dataset), 50);
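+  // 5 = position of the focused word inside the 11-token context (5 left + word itself + 5 right)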
+  TestNetwork nn(machine.getTransitionSet().size(), 5);
+  torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
 
-  for (auto & batch : *dataLoader)
+  for (int epoch = 1; epoch <= 5; ++epoch)
   {
-    auto data = batch.data;
-    auto labels = batch.target.squeeze();
+    float totalLoss = 0.0;
+    torch::Tensor example;
+
+    for (auto & batch : *dataLoader)
+    {
+      optimizer.zero_grad();
+
+      auto data = batch.data;
+      auto labels = batch.target.squeeze();
+
+      auto prediction = nn(data);
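+      // keep one prediction around so it can be printed after the epoch as a sanity check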
+      example = prediction[0];
+
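+      // nll_loss expects log-probabilities; log over the softmax output works, though log_softmax inside the net would be more stable numerically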
+      auto loss = torch::nll_loss(torch::log(prediction), labels);
+      totalLoss += loss.item<float>();
+      loss.backward();
+      optimizer.step();
+    }
+
+    fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss);
+    std::cout << example << std::endl;
   }
 
   return 0;
diff --git a/reading_machine/include/Config.hpp b/reading_machine/include/Config.hpp
index 21edddd928d7a4fd7820ef5aea8778f9142f51ac..5f34b241d7e494a48b345b2d0713647f3e6000e3 100644
--- a/reading_machine/include/Config.hpp
+++ b/reading_machine/include/Config.hpp
@@ -99,7 +99,7 @@ class Config
   String getState() const;
   void setState(const std::string state);
   bool stateIsDone() const;
-  std::vector<int> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
+  std::vector<long> extractContext(int leftBorder, int rightBorder, Dict & dict) const;
 
 };
 
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index 0075bcb458f864cd05b884a4ba747e039a4d8be9..2f6e7cf138875e225877990894964700e78fbc35 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -366,30 +366,29 @@ bool Config::stateIsDone() const
   return !has(0, wordIndex+1, 0);
 }
 
-std::vector<int> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
+std::vector<long> Config::extractContext(int leftBorder, int rightBorder, Dict & dict) const
 {
-  std::vector<int> context;
+  std::stack<int> leftContext;
+  for (int index = wordIndex-1; has(0,index,0) && (int)leftContext.size() < leftBorder; --index)
+    if (isToken(index))
+      leftContext.push(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
 
-  int startIndex = wordIndex;
+  std::vector<long> context;
 
-  for (int i = 0; i < leftBorder and has(0,startIndex-1,0); i++)
-    do
-      --startIndex;
-    while (!isToken(startIndex) and has(0,startIndex-1,0));
-
-  int endIndex = wordIndex;
-
-  for (int i = 0; i < rightBorder and has(0,endIndex+1,0); i++)
-    do
-      ++endIndex;
-    while (!isToken(endIndex) and has(0,endIndex+1,0));
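+  // Left padding with nullValue comes first, so the focused word always lands at index leftBorder in the context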
+  while ((int)context.size() < leftBorder-(int)leftContext.size())
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (!leftContext.empty())
+  {
+    context.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
 
-  for (int i = startIndex; i <= endIndex; ++i)
-    if (isToken(i))
-      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", i)));
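+  // Current word plus up to rightBorder tokens on the right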
+  for (int index = wordIndex; has(0,index,0) && (int)context.size() < leftBorder+rightBorder+1; ++index)
+    if (isToken(index))
+      context.emplace_back(dict.getIndexOrInsert(getLastNotEmptyConst("FORM", index)));
 
-  //TODO gérer les cas où la taille est differente...
-  return {0};
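+  // Right padding with nullValue so the context always has leftBorder+rightBorder+1 entries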
+  while ((int)context.size() < leftBorder+rightBorder+1)
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
 
   return context;
 }
diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index ee4430ce3171cda9f9e8f49575d09a5f821121c8..7aa878e62ce6d8beeabca4cd763353bb9023ff22 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,14 +8,12 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
 {
   private :
 
-  std::vector<std::unique_ptr<Config>> const & configs;
-  std::vector<std::size_t> const & classes;
-  std::size_t nbClasses;
-  Dict & dict;
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
 
   public :
 
-  explicit ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict);
+  explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes);
   torch::optional<size_t> size() const override;
   torch::data::Example<> get(size_t index) override;
 };
diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp
index eceb9c9b8911364ad01f4df16314c80f0c7af550..b1bb4e11200197f029556f7f3d93338ca0ce072e 100644
--- a/torch_modules/include/TestNetwork.hpp
+++ b/torch_modules/include/TestNetwork.hpp
@@ -8,15 +8,14 @@ class TestNetworkImpl : public torch::nn::Module
 {
   private :
 
-  std::map<Config::String, std::size_t> dict;
   torch::nn::Embedding wordEmbeddings{nullptr};
   torch::nn::Linear linear{nullptr};
+  int focusedIndex;
 
   public :
 
-  TestNetworkImpl(int nbOutputs);
-  torch::Tensor forward(const Config & config);
-  std::size_t getOrAddDictValue(Config::String s);
+  TestNetworkImpl(int nbOutputs, int focusedIndex);
+  torch::Tensor forward(torch::Tensor input);
 };
 TORCH_MODULE(TestNetwork);
 
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index f9b5b57c976c30bcbdedb6abe2bd4685af13a95d..e2d3853312fc657c4dfd31f08197238adbab7e47 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,20 +1,16 @@
 #include "ConfigDataset.hpp"
 
-ConfigDataset::ConfigDataset(std::vector<std::unique_ptr<Config>> const & configs, std::vector<std::size_t> const & classes, std::size_t nbClasses, Dict & dict) : configs(configs), classes(classes), nbClasses(nbClasses), dict(dict)
+ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(std::move(contexts)), classes(std::move(classes))
 {
 }
 
 torch::optional<size_t> ConfigDataset::size() const
 {
-  return configs.size();
+  return contexts.size();
 }
 
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
-  auto context = configs[index]->extractContext(5,5,dict);
-  auto tensorClass = torch::zeros(nbClasses);
-  tensorClass[classes[index]] = 1;
-
-  return {torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone(), tensorClass};
+  return {contexts[index], classes[index]};
 }
 
diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp
index 7ac71241fb9691b77829d179a092fd2232862478..3e3c010a7634824181e205b69161afbaefd96585 100644
--- a/torch_modules/src/TestNetwork.cpp
+++ b/torch_modules/src/TestNetwork.cpp
@@ -1,56 +1,22 @@
 #include "TestNetwork.hpp"
 
-TestNetworkImpl::TestNetworkImpl(int nbOutputs)
+TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex)
 {
-  getOrAddDictValue(Config::String("_null_"));
-  getOrAddDictValue(Config::String("_unknown_"));
-  getOrAddDictValue(Config::String("_S_"));
-
-  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, 100));
-  linear = register_module("linear", torch::nn::Linear(100, nbOutputs));
-}
-
-torch::Tensor TestNetworkImpl::forward(const Config & config)
-{
-//  std::vector<std::size_t> test{0,1};
-//  torch::Tensor tens = torch::from_blob(test.data(), {1,2});
-//  return wordEmbeddings(tens);
-  constexpr int windowSize = 5;
-  int wordIndex = config.getWordIndex();
-  int startIndex = wordIndex;
-  while (config.has(0,startIndex-1,0) and wordIndex-startIndex < windowSize)
-    startIndex--;
-  int endIndex = wordIndex;
-  while (config.has(0,endIndex+1,0) and -wordIndex+endIndex < windowSize)
-    endIndex++;
-
-  std::vector<std::size_t> words;
-  for (int i = startIndex; i <= endIndex; ++i)
-  {
-    if (!config.has(0, i, 0))
-      util::myThrow(fmt::format("Config do not have line %d", i));
-
-    words.emplace_back(getOrAddDictValue(config.getLastNotEmptyConst("FORM", i)));
-  }
-
-  if (words.empty())
-    util::myThrow(fmt::format("Empty context with nbLines={} head={} start={} end={}", config.getNbLines(), wordIndex, startIndex, endIndex));
-
-  auto wordsAsEmb = wordEmbeddings(torch::from_blob(words.data(), {(long int)words.size()}, at::kLong));
-
-  return torch::softmax(linear(wordsAsEmb[wordIndex-startIndex]), 0);
+  constexpr int embeddingsSize = 100;
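+  // 200000 = assumed upper bound on the number of distinct word indices the Dict can produce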
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize));
+  linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs));
+  this->focusedIndex = focusedIndex;
 }
 
-std::size_t TestNetworkImpl::getOrAddDictValue(Config::String s)
+torch::Tensor TestNetworkImpl::forward(torch::Tensor input)
 {
-  if (s.get().empty())
-    return dict[Config::String("_null_")];
-
-  const auto & found = dict.find(s);
+  // input dim = {batch, sequence} (word indices)
+  auto wordsAsEmb = wordEmbeddings(input);
+  // wordsAsEmb dim = {batch, sequence, embeddings}; permute to {sequence, batch, embeddings}
+  auto reshaped = wordsAsEmb.permute({1,0,2});
 
-  if (found == dict.end())
-    return dict[s] = dict.size();
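+  // reshaped[focusedIndex] is {batch, embeddings}: the focused word's embedding for each example; softmax over dim 1 = classes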
+  auto res = torch::softmax(linear(reshaped[focusedIndex]), 1);
 
-  return found->second;
+  return res;
 }