Skip to content
Snippets Groups Projects
Commit 764c2fc7 authored by Franck Dary's avatar Franck Dary
Browse files

Improved error messages in case of trying to use an unknown column

parent e6ca3164
No related branches found
No related tags found
No related merge requests found
...@@ -166,7 +166,10 @@ std::size_t BaseConfig::getNbColumns() const ...@@ -166,7 +166,10 @@ std::size_t BaseConfig::getNbColumns() const
std::size_t BaseConfig::getColIndex(const std::string & colName) const std::size_t BaseConfig::getColIndex(const std::string & colName) const
{ {
return colName2Index.at(colName); auto it = colName2Index.find(colName);
if (it == colName2Index.end())
util::myThrow(fmt::format("unknown column name '{}'", colName));
return it->second;
} }
bool BaseConfig::hasColIndex(const std::string & colName) const bool BaseConfig::hasColIndex(const std::string & colName) const
......
...@@ -476,7 +476,12 @@ bool Config::stateIsDone() const ...@@ -476,7 +476,12 @@ bool Config::stateIsDone() const
void Config::addPredicted(const std::set<std::string> & predicted) void Config::addPredicted(const std::set<std::string> & predicted)
{ {
this->predicted.insert(predicted.begin(), predicted.end()); for (auto & col : predicted)
{
if (!hasColIndex(col))
util::myThrow(fmt::format("unknown column '{}'", col));
this->predicted.insert(col);
}
} }
bool Config::isPredicted(const std::string & colName) const bool Config::isPredicted(const std::string & colName) const
......
...@@ -25,8 +25,14 @@ void Trainer::createDataset(SubConfig & config, bool debug) ...@@ -25,8 +25,14 @@ void Trainer::createDataset(SubConfig & config, bool debug)
util::myThrow("No transition appliable !"); util::myThrow("No transition appliable !");
} }
try
{
auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState())); auto context = machine.getClassifier()->getNN()->extractContext(config,machine.getDict(config.getState()));
contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device)); contexts.emplace_back(torch::from_blob(context.data(), {(long)context.size()}, torch::TensorOptions(at::kLong)).clone().to(NeuralNetworkImpl::device));
} catch(std::exception & e)
{
util::myThrow(fmt::format("Failed to extract context : {}", e.what()));
}
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); int goldIndex = machine.getTransitionSet().getTransitionIndex(transition);
auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device)); auto gold = torch::zeros(1, torch::TensorOptions(at::kLong).device(NeuralNetworkImpl::device));
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment