From 764c2fc7b192d0d7a15ba6778947b38eb990bc9c Mon Sep 17 00:00:00 2001 From: Franck Dary <franck.dary@lis-lab.fr> Date: Thu, 5 Mar 2020 15:04:40 +0100 Subject: [PATCH] Improved error messages in case of trying to use an unknown column --- reading_machine/src/BaseConfig.cpp | 5 ++++- reading_machine/src/Config.cpp | 7 ++++++- trainer/src/Trainer.cpp | 10 ++++++++-- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index ce61016..b197250 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -166,7 +166,10 @@ std::size_t BaseConfig::getNbColumns() const std::size_t BaseConfig::getColIndex(const std::string & colName) const { - return colName2Index.at(colName); + auto it = colName2Index.find(colName); + if (it == colName2Index.end()) + util::myThrow(fmt::format("unknown column name '{}'", colName)); + return it->second; } bool BaseConfig::hasColIndex(const std::string & colName) const diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index fd8d1bc..5cc4a88 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -476,7 +476,12 @@ bool Config::stateIsDone() const void Config::addPredicted(const std::set<std::string> & predicted) { - this->predicted.insert(predicted.begin(), predicted.end()); + for (auto & col : predicted) + { + if (!hasColIndex(col)) + util::myThrow(fmt::format("unknown column '{}'", col)); + this->predicted.insert(col); + } } bool Config::isPredicted(const std::string & colName) const diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 191ac47..106df3f 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -25,8 +25,14 @@ void Trainer::createDataset(SubConfig & config, bool debug) util::myThrow("No transition appliable !"); } - auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); - contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); + try + { + auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); + contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); + } catch(std::exception & e) + { + util::myThrow(fmt::format("Failed to extract context : {}", e.what())); + } int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)); -- GitLab