diff --git a/reading_machine/src/BaseConfig.cpp b/reading_machine/src/BaseConfig.cpp index ce61016ed8a6ca399e63a5cbd44d75268300de01..b197250a1f8146d8b7da3b58cc69a272b1b529ee 100644 --- a/reading_machine/src/BaseConfig.cpp +++ b/reading_machine/src/BaseConfig.cpp @@ -166,7 +166,10 @@ std::size_t BaseConfig::getNbColumns() const std::size_t BaseConfig::getColIndex(const std::string & colName) const { - return colName2Index.at(colName); + auto it = colName2Index.find(colName); + if (it == colName2Index.end()) + util::myThrow(fmt::format("unknown column name '{}'", colName)); + return it->second; } bool BaseConfig::hasColIndex(const std::string & colName) const diff --git a/reading_machine/src/Config.cpp b/reading_machine/src/Config.cpp index fd8d1bca540f96c15c9a7ead5306ea9e4e87ac50..5cc4a88aec4c604d1bad91a0946b5a30b17c625e 100644 --- a/reading_machine/src/Config.cpp +++ b/reading_machine/src/Config.cpp @@ -476,7 +476,12 @@ bool Config::stateIsDone() const void Config::addPredicted(const std::set<std::string> & predicted) { - this->predicted.insert(predicted.begin(), predicted.end()); + for (auto & col : predicted) + { + if (!hasColIndex(col)) + util::myThrow(fmt::format("unknown column '{}'", col)); + this->predicted.insert(col); + } } bool Config::isPredicted(const std::string & colName) const diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 191ac47278e1c6d17efd8fa9663271a6620c654f..106df3ffbdf3d90ab397d78a6fbff11619e77dc9 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -25,8 +25,14 @@ void Trainer::createDataset(SubConfig & config, bool debug) util::myThrow("No transition appliable !"); } - auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); - contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); + try + { + auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); + contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); + } catch(std::exception & e) + { + util::myThrow(fmt::format("Failed to extract context : {}", e.what())); + } int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));