From 764c2fc7b192d0d7a15ba6778947b38eb990bc9c Mon Sep 17 00:00:00 2001
From: Franck Dary <franck.dary@lis-lab.fr>
Date: Thu, 5 Mar 2020 15:04:40 +0100
Subject: [PATCH] Improved error messages in case of trying to use an unknown
 column

---
 reading_machine/src/BaseConfig.cpp |  5 ++++-
 reading_machine/src/Config.cpp     |  7 ++++++-
 trainer/src/Trainer.cpp            | 10 ++++++++--
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp
index ce61016..b197250 100644
--- a/reading_machine/src/BaseConfig.cpp
+++ b/reading_machine/src/BaseConfig.cpp
@@ -166,7 +166,10 @@ std::size_t BaseConfig::getNbColumns() const
 
 std::size_t BaseConfig::getColIndex(const std::string & colName) const
 {
-  return colName2Index.at(colName);
+  auto it = colName2Index.find(colName);
+  if (it == colName2Index.end())
+    util::myThrow(fmt::format("unknown column name '{}'", colName));
+  return it->second;
 }
 
 bool BaseConfig::hasColIndex(const std::string & colName) const
diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp
index fd8d1bc..5cc4a88 100644
--- a/reading_machine/src/Config.cpp
+++ b/reading_machine/src/Config.cpp
@@ -476,7 +476,12 @@ bool Config::stateIsDone() const
 
 void Config::addPredicted(const std::set<std::string> & predicted)
 {
-  this->predicted.insert(predicted.begin(), predicted.end());
+  for (auto & col : predicted)
+  {
+    if (!hasColIndex(col))
+      util::myThrow(fmt::format("unknown column '{}'", col));
+    this->predicted.insert(col);
+  }
 }
 
 bool Config::isPredicted(const std::string & colName) const
diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp
index 191ac47..106df3f 100644
--- a/trainer/src/Trainer.cpp
+++ b/trainer/src/Trainer.cpp
@@ -25,8 +25,14 @@ void Trainer::createDataset(SubConfig & config, bool debug)
       util::myThrow("No transition appliable !");
     }
 
-    auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
-    contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
+    try
+    {
+      auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
+      contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
+    } catch(std::exception & e)
+    {
+      util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
+    }
 
     int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
     auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
-- 
GitLab