Skip to content
Snippets Groups Projects
Commit 59751292 authored by Franck Dary's avatar Franck Dary
Browse files

Fixed dynamic oracle

parent bfd53a17
No related branches found
No related tags found
No related merge requests found
...@@ -88,6 +88,8 @@ bool isEmpty(const boost::flyweight<T> & s) ...@@ -88,6 +88,8 @@ bool isEmpty(const boost::flyweight<T> & s)
bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::function<void(const std::smatch &)> & f); bool doIfNameMatch(const std::regex & reg, std::string_view name, const std::function<void(const std::smatch &)> & f);
bool choiceWithProbability(float probability);
}; };
template <> template <>
......
...@@ -197,3 +197,11 @@ std::string util::getTime() ...@@ -197,3 +197,11 @@ std::string util::getTime()
return std::string(buffer); return std::string(buffer);
} }
bool util::choiceWithProbability(float probability)
{
int maxVal = 100000;
int threshold = maxVal * probability;
return (std::rand() % maxVal) < threshold;
}
...@@ -84,8 +84,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -84,8 +84,11 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
} }
Transition * transition = nullptr; Transition * transition = nullptr;
Transition * goldTransition = nullptr;
goldTransition = machine.getTransitionSet().getBestAppliableTransition(config);
if (dynamicOracle and config.getState() != "tokenizer" and config.getState() != "parser") if (dynamicOracle and util::choiceWithProbability(0.8) and config.getState() != "tokenizer" and config.getState() != "parser")
{ {
auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device); auto neuralInput = torch::from_blob(context[0].data(), {(long)context[0].size()}, torch::kLong).clone().to(NeuralNetworkImpl::device);
auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze(); auto prediction = machine.getClassifier()->getNN()(neuralInput).squeeze();
...@@ -107,16 +110,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p ...@@ -107,16 +110,16 @@ void Trainer::extractExamples(SubConfig & config, bool debug, std::filesystem::p
} }
else else
{ {
transition = machine.getTransitionSet().getBestAppliableTransition(config); transition = goldTransition;
} }
if (!transition) if (!transition or !goldTransition)
{ {
config.printForDebug(stderr); config.printForDebug(stderr);
util::myThrow("No transition appliable !"); util::myThrow("No transition appliable !");
} }
int goldIndex = machine.getTransitionSet().getTransitionIndex(transition); int goldIndex = machine.getTransitionSet().getTransitionIndex(goldTransition);
totalNbExamples += context.size(); totalNbExamples += context.size();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment