Commit 9df2da0f authored by Franck Dary's avatar Franck Dary
Browse files

Added case for regression in extractExamples

parent bb7307d4
......@@ -92,8 +92,9 @@ void Trainer::extractExamples(std::vector<SubConfig> & configs, bool debug, std:
if (dynamicOracle and util::choiceWithProbability(1.0) and config.getState() != "tokenizer" and config.getState() != "segmenter")
{
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = torch::softmax(machine.getClassifier(config.getState())->getNN()(neuralInput), -1).squeeze(0);
auto & classifier = *machine.getClassifier(config.getState());
auto neuralInput = torch::from_blob(context.data(), {(long)context.size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = classifier.isRegression() ? classifier.getNN()(neuralInput).squeeze(0) : torch::softmax(classifier.getNN()(neuralInput).squeeze(0), 0);
entropy = NeuralNetworkImpl::entropy(prediction);
std::vector<int> candidates;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment