diff --git a/torch_modules/include/ConfigDataset.hpp b/torch_modules/include/ConfigDataset.hpp
index 7aa878e62ce6d8beeabca4cd763353bb9023ff22..1c60de78ffc4cced19ed0138f45b6003bd419a35 100644
--- a/torch_modules/include/ConfigDataset.hpp
+++ b/torch_modules/include/ConfigDataset.hpp
@@ -8,12 +8,13 @@ class ConfigDataset : public torch::data::Dataset<ConfigDataset>
 {
   private :
 
-  std::vector<torch::Tensor> contexts;
-  std::vector<torch::Tensor> classes;
+  torch::Tensor data;
+  std::size_t size_{0};
+  std::size_t contextSize{0};
 
   public :
 
-  explicit ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes);
+  explicit ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes);
   torch::optional<size_t> size() const override;
   torch::data::Example<> get(size_t index) override;
 };
diff --git a/torch_modules/src/ConfigDataset.cpp b/torch_modules/src/ConfigDataset.cpp
index e2d3853312fc657c4dfd31f08197238adbab7e47..439cbcc6d32abfeeff17fd090f15bf34a3b3f438 100644
--- a/torch_modules/src/ConfigDataset.cpp
+++ b/torch_modules/src/ConfigDataset.cpp
@@ -1,16 +1,29 @@
 #include "ConfigDataset.hpp"
 
-ConfigDataset::ConfigDataset(std::vector<torch::Tensor> contexts, std::vector<torch::Tensor> classes) : contexts(contexts), classes(classes)
+ConfigDataset::ConfigDataset(const std::vector<torch::Tensor> & contexts, const std::vector<torch::Tensor> & classes)
 {
+  if (contexts.size() != classes.size())
+    util::myThrow(fmt::format("contexts.size()={} classes.size()={}", contexts.size(), classes.size()));
+
+  size_ = contexts.size();
+  contextSize = contexts.back().size(0);
+  std::vector<torch::Tensor> total;
+  for (unsigned int i = 0; i < contexts.size(); i++)
+  {
+    total.emplace_back(contexts[i]);
+    total.emplace_back(classes[i]);
+  }
+
+  data = torch::cat(total);
 }
 
 torch::optional<size_t> ConfigDataset::size() const
 {
-  return contexts.size();
+  return size_;
 }
 
 torch::data::Example<> ConfigDataset::get(size_t index)
 {
-  return {contexts[index], classes[index]};
+  return {data.narrow(0, index*(contextSize+1), contextSize), data.narrow(0, index*(contextSize+1)+contextSize, 1)};
 }