diff --git a/reading_machine/src/Classifier.cpp b/reading_machine/src/Classifier.cpp
index 61a3d87c4ba4e9da636ae64ceb18372440576df0..096800b01a4db30e06f0a3c2df4dc0a3b2b4f853 100644
--- a/reading_machine/src/Classifier.cpp
+++ b/reading_machine/src/Classifier.cpp
@@ -3,6 +3,7 @@
 #include "OneWordNetwork.hpp"
 #include "ConcatWordsNetwork.hpp"
 #include "RTLSTMNetwork.hpp"
+#include "CNNNetwork.hpp"
 
 Classifier::Classifier(const std::string & name, const std::string & topology, const std::string & tsFile)
 {
@@ -46,6 +47,14 @@ void Classifier::initNeuralNetwork(const std::string & topology)
         this->nn.reset(new ConcatWordsNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
       }
     },
+    {
+      std::regex("CNN\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
+      "CNN(leftBorder,rightBorder,nbStack) : CNN to capture context.",
+      [this,topology](auto sm)
+      {
+        this->nn.reset(new CNNNetworkImpl(this->transitionSet->size(), std::stoi(sm[1]), std::stoi(sm[2]), std::stoi(sm[3])));
+      }
+    },
     {
       std::regex("RTLSTM\\(([+\\-]?\\d+),([+\\-]?\\d+),([+\\-]?\\d+)\\)"),
       "RTLSTM(leftBorder,rightBorder,nbStack) : Recursive tree LSTM.",
diff --git a/torch_modules/include/CNNNetwork.hpp b/torch_modules/include/CNNNetwork.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..d5ec5bf5a8f6d0f3633e74c110fffb8986454f79
--- /dev/null
+++ b/torch_modules/include/CNNNetwork.hpp
@@ -0,0 +1,27 @@
+#ifndef CNNNETWORK__H
+#define CNNNETWORK__H
+
+#include "NeuralNetwork.hpp"
+
+class CNNNetworkImpl : public NeuralNetworkImpl
+{
+  private :
+
+  static inline std::vector<long> focusedBufferIndexes{0,1};
+  static inline std::vector<long> windowSizes{2,3,4};
+  static constexpr unsigned int maxNbLetters = 10;
+
+  torch::nn::Embedding wordEmbeddings{nullptr};
+  torch::nn::Linear linear1{nullptr};
+  torch::nn::Linear linear2{nullptr};
+  std::vector<torch::nn::Conv2d> CNNs;
+  std::vector<torch::nn::Conv2d> lettersCNNs;
+
+  public :
+
+  CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements);
+  torch::Tensor forward(torch::Tensor input) override;
+  std::vector<long> extractContext(Config & config, Dict & dict) const override;
+};
+
+#endif
diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7612f7c60bf568a4ae60df510e2d3f51b3cecea9
--- /dev/null
+++ b/torch_modules/src/CNNNetwork.cpp
@@ -0,0 +1,138 @@
+#include "CNNNetwork.hpp"
+
+CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements)
+{
+  constexpr int embeddingsSize = 64;
+  constexpr int hiddenSize = 512;
+  constexpr int nbFilters = 512;
+  constexpr int nbFiltersLetters = 64;
+
+  setLeftBorder(leftBorder);
+  setRightBorder(rightBorder);
+  setNbStackElements(nbStackElements);
+  setColumns({"FORM", "UPOS"});
+
+  wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(50000, embeddingsSize)));
+  linear1 = register_module("linear1", torch::nn::Linear(nbFilters*windowSizes.size()+nbFiltersLetters*windowSizes.size()*focusedBufferIndexes.size(), hiddenSize));
+  linear2 = register_module("linear2", torch::nn::Linear(hiddenSize, nbOutputs));
+  for (auto & windowSize : windowSizes)
+  {
+    CNNs.emplace_back(register_module(fmt::format("cnn_context_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFilters, torch::ExpandingArray<2>({windowSize,2*embeddingsSize})).padding({windowSize-1, 0}))));
+    lettersCNNs.emplace_back(register_module(fmt::format("cnn_letters_{}", windowSize), torch::nn::Conv2d(torch::nn::Conv2dOptions(1, nbFiltersLetters, torch::ExpandingArray<2>({windowSize,embeddingsSize})).padding({windowSize-1, 0}))));
+  }
+}
+
+torch::Tensor CNNNetworkImpl::forward(torch::Tensor input)
+{
+  if (input.dim() == 1)
+    input = input.unsqueeze(0);
+
+  auto wordIndexes = input.narrow(1, 0, columns.size()*(1+leftBorder+rightBorder));
+  auto wordLetters = input.narrow(1, columns.size()*(1+leftBorder+rightBorder), maxNbLetters*focusedBufferIndexes.size());
+
+  auto embeddings = wordEmbeddings(wordIndexes).view({wordIndexes.size(0), wordIndexes.size(1)/(int)columns.size(), (int)columns.size()*(int)wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+  auto lettersEmbeddings = wordEmbeddings(wordLetters).view({wordLetters.size(0), wordLetters.size(1)/maxNbLetters, maxNbLetters, wordEmbeddings->options.embedding_dim()}).unsqueeze(1);
+
+  auto permuted = lettersEmbeddings.permute({2,0,1,3,4});
+  std::vector<torch::Tensor> windows;
+  for (unsigned int word = 0; word < focusedBufferIndexes.size(); word++)
+    for (unsigned int i = 0; i < lettersCNNs.size(); i++)
+    {
+      auto input = permuted[word];
+      auto convOut = torch::relu(lettersCNNs[i](input).squeeze(-1));
+      auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+      windows.emplace_back(pooled);
+    }
+  auto lettersCnnOut = torch::cat(windows, 2);
+  lettersCnnOut = lettersCnnOut.view({lettersCnnOut.size(0), -1});
+
+  windows.clear();
+  for (unsigned int i = 0; i < CNNs.size(); i++)
+  {
+    auto convOut = torch::relu(CNNs[i](embeddings).squeeze(-1));
+    auto pooled = torch::max_pool1d(convOut, convOut.size(2));
+    windows.emplace_back(pooled);
+  }
+
+  auto cnnOut = torch::cat(windows, 2);
+  cnnOut = cnnOut.view({cnnOut.size(0), -1});
+
+  auto totalInput = torch::cat({cnnOut, lettersCnnOut}, 1);
+
+  return linear2(torch::relu(linear1(totalInput)));
+}
+
+std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) const
+{
+  std::stack<int> leftContext;
+  std::stack<std::string> leftForms;
+  for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < columns.size()*leftBorder; --index)
+    if (config.isToken(index))
+      for (auto & column : columns)
+      {
+        leftContext.push(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+        if (column == "FORM")
+          leftForms.push(config.getAsFeature(column, index));
+      }
+
+  std::vector<long> context;
+  std::vector<std::string> forms;
+
+  while ((int)context.size() < (int)columns.size()*(leftBorder-(int)leftContext.size()))
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while (forms.size() < leftBorder-leftForms.size())
+    forms.emplace_back("");
+  while (!leftForms.empty())
+  {
+    forms.emplace_back(leftForms.top());
+    leftForms.pop();
+  }
+  while (!leftContext.empty())
+  {
+    context.emplace_back(leftContext.top());
+    leftContext.pop();
+  }
+
+  for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < columns.size()*(leftBorder+rightBorder+1); ++index)
+    if (config.isToken(index))
+      for (auto & column : columns)
+      {
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, index)));
+        if (column == "FORM")
+          forms.emplace_back(config.getAsFeature(column, index));
+      }
+
+  while (context.size() < columns.size()*(leftBorder+rightBorder+1))
+    context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+  while ((int)forms.size() < leftBorder+rightBorder+1)
+    forms.emplace_back("");
+
+  for (int i = 0; i < nbStackElements; i++)
+    for (auto & column : columns)
+      if (config.hasStack(i))
+        context.emplace_back(dict.getIndexOrInsert(config.getAsFeature(column, config.getStack(i))));
+      else
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+
+  for (auto index : focusedBufferIndexes)
+  {
+    util::utf8string letters;
+    if (leftBorder+index >= 0 && leftBorder+index < (int)forms.size() && !forms[leftBorder+index].empty())
+      letters = util::splitAsUtf8(forms[leftBorder+index]);
+    for (unsigned int i = 0; i < maxNbLetters; i++)
+    {
+      if (i < letters.size())
+      {
+        std::string sLetter = fmt::format("Letter({})", letters[i]);
+        context.emplace_back(dict.getIndexOrInsert(sLetter));
+      }
+      else
+      {
+        context.emplace_back(dict.getIndexOrInsert(Dict::nullValueStr));
+      }
+    }
+  }
+
+  return context;
+}
+
diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp
index 5a8c30230b0f56b1aeb98b04879d40cf3d51ab20..6e889171c5f1fa461f333605446fb8545d287270 100644
--- a/trainer/include/Trainer.hpp
+++ b/trainer/include/Trainer.hpp
@@ -18,7 +18,7 @@ class Trainer
   DataLoader dataLoader{nullptr};
   std::unique_ptr<torch::optim::Adam> optimizer;
   std::size_t epochNumber{0};
-  int batchSize{1};
+  int batchSize{50};
   int nbExamples{0};
 
   public :