Commit 8118cf69 authored by Franck Dary

Added featureExtraction option

parent 46ff4287
@@ -70,6 +70,7 @@ struct ProgramParameters
   static bool printOutputEntropy;
   static std::string tapeToMask;
   static float maskRate;
+  static bool featureExtraction;
   private :
...
@@ -64,4 +64,5 @@ bool ProgramParameters::printOutputEntropy;
 int ProgramParameters::dictCapacity;
 std::string ProgramParameters::tapeToMask;
 float ProgramParameters::maskRate;
+bool ProgramParameters::featureExtraction;
@@ -52,7 +52,7 @@ void Trainer::computeScoreOnDev()
   else
   {
     // Print current iter advancement in percentage
-    if (ProgramParameters::interactive)
+    if (ProgramParameters::interactive && !ProgramParameters::featureExtraction)
     {
       int totalSize = ProgramParameters::devTapeSize;
       int steps = devConfig->getHead();
@@ -200,7 +200,7 @@ void Trainer::train()
     }
     // Print current iter advancement in percentage
-    if (ProgramParameters::interactive)
+    if (ProgramParameters::interactive && !ProgramParameters::featureExtraction)
     {
       int totalSize = ProgramParameters::iterationSize == -1 ? ProgramParameters::tapeSize : ProgramParameters::iterationSize;
       int steps = ProgramParameters::iterationSize == -1 ? trainConfig.getHead() : nbSteps;
@@ -211,12 +211,15 @@ void Trainer::train()
     }
   }
-  auto weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
   std::string pAction = "";
   std::string oAction = "";
   bool pActionIsZeroCost = false;
+  Classifier::WeightedActions weightedActions;
+  if (!ProgramParameters::featureExtraction)
+  {
+    weightedActions = tm.getCurrentClassifier()->weightActions(trainConfig);
   for (auto & it : weightedActions)
     if (it.first)
     {
@@ -232,6 +235,11 @@ void Trainer::train()
   if (pAction == oAction)
     pActionIsZeroCost = true;
+  }
+  else
+  {
+    oAction = tm.getCurrentClassifier()->getZeroCostActions(trainConfig)[0];
+  }
   if (oAction.empty())
     oAction = tm.getCurrentClassifier()->getDefaultAction();
@@ -252,6 +260,7 @@ void Trainer::train()
     exit(1);
   }
+  if (!ProgramParameters::featureExtraction)
   tm.getCurrentClassifier()->trainOnExample(trainConfig, tm.getCurrentClassifier()->getActionIndex(oAction));
   TI.addTrainExample(tm.getCurrentClassifier()->name);
@@ -262,6 +271,13 @@ void Trainer::train()
   std::string actionName = "";
+// here
+  if (ProgramParameters::featureExtraction)
+  {
+    auto features = tm.getCurrentClassifier()->getFeatureModel()->getFeatureDescription(trainConfig).featureValues();
+    fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str());
+  }
   if (TI.getEpoch() >= k && choiceWithProbability(ProgramParameters::dynamicProbability))
   {
     actionName = pAction;
...
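Note on the fprintf above: in featureExtraction mode each processed configuration produces exactly one stdout line, holding the zero-cost (oracle) action followed by the tab-separated feature values. The standalone sketch below only illustrates the shape of such a line; the action name and feature values are invented for the example and do not come from the project.

// Illustration only: the shape of one featureExtraction output line.
// "SHIFT" and the feature values are hypothetical placeholders.
#include <cstdio>
#include <string>

int main()
{
  std::string oAction = "SHIFT";                 // hypothetical oracle action
  std::string features = "le\tDET\tchat\tNOUN";  // hypothetical tab-separated feature values
  std::fprintf(stdout, "%s\t%s\n", oAction.c_str(), features.c_str());
  return 0;
}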
@@ -85,7 +85,8 @@ po::options_description getOptionsDescription()
       "The name of the Tape for which some of the elements will be masked.")
     ("maskRate", po::value<float>()->default_value(0.0),
       "The rate of elements of the Tape that will be masked.")
-    ("printTime", "Print time on stderr")
+    ("printTime", "Print time on stderr.")
+    ("featureExtraction", "Use macaon only as a feature extractor, print corpus to stdout.")
     ("shuffle", po::value<bool>()->default_value(true),
       "Shuffle examples after each iteration");
@@ -268,6 +269,7 @@ int main(int argc, char * argv[])
   ProgramParameters::debug = vm.count("debug") == 0 ? false : true;
   ProgramParameters::printEntropy = vm.count("printEntropy") == 0 ? false : true;
   ProgramParameters::printTime = vm.count("printTime") == 0 ? false : true;
+  ProgramParameters::featureExtraction = vm.count("featureExtraction") == 0 ? false : true;
   ProgramParameters::trainName = vm["train"].as<std::string>();
   ProgramParameters::devName = vm["dev"].as<std::string>();
   ProgramParameters::lang = vm["lang"].as<std::string>();
...
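For readers less familiar with boost::program_options, here is a minimal, self-contained sketch of how a presence-only switch such as --featureExtraction is declared and read, mirroring the two hunks above. The option name and description follow the diff, but this main() is purely illustrative and is not the project's actual entry point.

// Sketch only: presence-only boost::program_options flag, as in the diff above.
#include <boost/program_options.hpp>
#include <iostream>

namespace po = boost::program_options;

int main(int argc, char * argv[])
{
  po::options_description desc("Options");
  desc.add_options()
    ("featureExtraction", "Use the program only as a feature extractor, print corpus to stdout.");

  po::variables_map vm;
  po::store(po::parse_command_line(argc, argv, desc), vm);
  po::notify(vm);

  // The switch carries no value; its mere presence turns the mode on.
  bool featureExtraction = vm.count("featureExtraction") == 0 ? false : true;
  std::cout << "featureExtraction = " << std::boolalpha << featureExtraction << "\n";

  return 0;
}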
@@ -197,6 +197,8 @@ class Classifier
   ///
   /// @return The number of actions.
   unsigned int getNbActions();
+  /// @brief Get a pointer to the FeatureModel.
+  FeatureModel * getFeatureModel();
 };
 #endif
@@ -64,6 +64,10 @@ class FeatureModel
     ///
     /// @return The string representing this FeatureDescription
     std::string toString();
+    /// @brief Return a string representing the values of the features
+    ///
+    /// @return The string representing the values of the features
+    std::string featureValues();
   };
   private :
...
@@ -328,3 +328,8 @@ unsigned int Classifier::getNbActions()
   return as->size();
 }
+
+FeatureModel * Classifier::getFeatureModel()
+{
+  return fm.get();
+}
@@ -133,3 +133,17 @@ FeatureModel::FeatureValue::FeatureValue()
 {
 }
+
+std::string FeatureModel::FeatureDescription::featureValues()
+{
+  std::string res;
+
+  for (auto & feature : values)
+    for (auto & value : feature.values)
+      res += value + "\t";
+
+  if (!res.empty())
+    res.pop_back();
+
+  return res;
+}
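As a quick illustration of what featureValues() returns, the standalone sketch below reproduces the same tab-join and trailing-tab removal on hypothetical data; it is not project code, and the nested vector stands in for the FeatureDescription's feature/value structure.

// Illustration only: tab-join behaviour of FeatureDescription::featureValues().
#include <iostream>
#include <string>
#include <vector>

int main()
{
  // Hypothetical features, each holding one or more string values.
  std::vector<std::vector<std::string>> values = {{"the"}, {"cat"}, {"NOUN", "SG"}};

  std::string res;
  for (auto & feature : values)
    for (auto & value : feature)
      res += value + "\t";

  if (!res.empty())
    res.pop_back(); // drop the final '\t' so the line ends cleanly

  std::cout << res << "\n"; // prints: the<TAB>cat<TAB>NOUN<TAB>SG
  return 0;
}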