Skip to content
Snippets Groups Projects
Commit 7616ceba authored by Franck Dary's avatar Franck Dary
Browse files

Added new option for training

parent 6afe9f26
No related branches found
No related tags found
No related merge requests found
......@@ -76,6 +76,7 @@ struct ProgramParameters
static bool randomDebug;
static float randomDebugProbability;
static bool alwaysSave;
static bool noNeuralNetwork;
private :
......
......@@ -70,3 +70,4 @@ bool ProgramParameters::devLoss;
// Out-of-class definitions for ProgramParameters' static members
// (declared `static` inside the struct in the header).
bool ProgramParameters::randomDebug;
float ProgramParameters::randomDebugProbability;
bool ProgramParameters::alwaysSave;
// When true, skip the neural network entirely — per the CLI help text,
// "useful to speed up debug".
bool ProgramParameters::noNeuralNetwork;
......@@ -42,6 +42,7 @@ po::options_description getTrainOptionsDescription()
("help,h", "Produce this help message")
("debug,d", "Print infos on stderr")
("alwaysSave", "Save the model at every iteration")
("noNeuralNetwork", "Don't use any neural network, useful to speed up debug")
("randomDebug", "Print infos on stderr with a probability of randomDebugProbability")
("randomDebugProbability", po::value<float>()->default_value(0.001),
"Probability that debug infos will be printed")
......@@ -270,6 +271,7 @@ void loadTrainProgramParameters(int argc, char * argv[])
ProgramParameters::mcdName = vm["mcd"].as<std::string>();
ProgramParameters::debug = vm.count("debug") == 0 ? false : true;
ProgramParameters::alwaysSave = vm.count("alwaysSave") == 0 ? false : true;
ProgramParameters::noNeuralNetwork = vm.count("noNeuralNetwork") == 0 ? false : true;
ProgramParameters::randomDebug = vm.count("randomDebug") == 0 ? false : true;
ProgramParameters::printEntropy = vm.count("printEntropy") == 0 ? false : true;
ProgramParameters::printTime = vm.count("printTime") == 0 ? false : true;
......
......@@ -94,6 +94,14 @@ Classifier::WeightedActions Classifier::weightActions(Config & config)
{
WeightedActions result;
if (ProgramParameters::noNeuralNetwork)
{
for (unsigned int i = 0; i < as->actions.size(); i++)
result.emplace_back(as->actions[i].appliable(config), std::pair<float, std::string>(1.0, as->actions[i].name));
return result;
}
if(type == Type::Prediction)
{
initClassifier(config);
......@@ -269,17 +277,26 @@ std::string Classifier::getDefaultAction() const
/// @brief Train the underlying network on one example extracted from `config`.
/// @param config Current parsing configuration to extract features from.
/// @param gold   Index of the gold (correct) action.
/// @return The loss reported by the network update, or 0.0 in
///         noNeuralNetwork debug mode (features are not even extracted).
float Classifier::trainOnExample(Config & config, int gold)
{
  // Debug short-circuit: avoid both feature extraction and the update.
  if (ProgramParameters::noNeuralNetwork)
    return 0.0;

  return nn->update(fm->getFeatureDescription(config), gold);
}
/// @brief Train the underlying network on one pre-extracted example.
/// @param fd   Feature description of the example.
/// @param gold Index of the gold (correct) action.
/// @return The loss reported by the network update, or 0.0 when the
///         network is disabled via the noNeuralNetwork debug option.
float Classifier::trainOnExample(FeatureModel::FeatureDescription & fd, int gold)
{
  if (!ProgramParameters::noNeuralNetwork)
    return nn->update(fd, gold);

  // noNeuralNetwork debug mode: report a zero loss without touching nn.
  return 0.0;
}
/// @brief Compute the loss of the current network for one example, without
///        updating any weights.
/// @param config Current parsing configuration to extract features from.
/// @param gold   Index of the gold (correct) action.
/// @return The network loss, or 0.0 in noNeuralNetwork debug mode
///         (feature extraction is skipped as well).
float Classifier::getLoss(Config & config, int gold)
{
  // Debug short-circuit mirrors trainOnExample: no features, no network call.
  if (ProgramParameters::noNeuralNetwork)
    return 0.0;

  return nn->getLoss(fm->getFeatureDescription(config), gold);
}
......
0% loaded — still loading, or reload the page.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment