Skip to content
Snippets Groups Projects
Commit 38f7f9ad authored by Franck Dary's avatar Franck Dary
Browse files

Added first version of error evaluation

parent 2ef613e3
No related branches found
No related tags found
No related merge requests found
FILE(GLOB SOURCES src/*.cpp)
add_executable(macaon_error_correction src/macaon_error_correction.cpp)
target_link_libraries(macaon_error_correction errors)
target_link_libraries(macaon_error_correction transition_machine)
target_link_libraries(macaon_error_correction ${Boost_PROGRAM_OPTIONS_LIBRARY})
install(TARGETS macaon_error_correction DESTINATION bin)
......@@ -14,3 +15,6 @@ add_executable(macaon_decode_error_detector src/macaon_decode_error_detector.cpp
target_link_libraries(macaon_decode_error_detector transition_machine)
target_link_libraries(macaon_decode_error_detector ${Boost_PROGRAM_OPTIONS_LIBRARY})
install(TARGETS macaon_decode_error_detector DESTINATION bin)
#compiling library
add_library(errors STATIC ${SOURCES})
/// @file Error.hpp
/// @author Franck Dary
/// @version 1.0
/// @date 2018-12-12
#ifndef ERROR__H
#define ERROR__H
#include "Classifier.hpp"
#include <vector>
#include <string>
class Error
{
private :
std::string prediction;
std::string gold;
Classifier::WeightedActions weightedActions;
std::string type;
public :
Error(std::string &, std::string &, Classifier::WeightedActions &);
bool isError() const;
const std::string & getType() const;
};
class ErrorSequence
{
private :
std::vector<Error> sequence;
std::map<std::string, bool> types;
public :
void add(const Error & error);
const std::map<std::string, bool> & getTypes() const;
const std::vector<Error> & getSequence() const;
};
class Errors
{
private :
std::vector<ErrorSequence> sequences;
public :
void newSequence();
void add(const Error & error);
void printStats();
};
#endif
#include "Error.hpp"
Error::Error(std::string & prediction, std::string & gold, Classifier::WeightedActions & weightedActions) :
prediction(prediction), gold(gold), weightedActions(weightedActions)
{
type = prediction + "->" + gold;
}
const std::string & Error::getType() const
{
return type;
}
void ErrorSequence::add(const Error & error)
{
sequence.emplace_back(error);
types[error.getType()] = true;
}
const std::map<std::string, bool> & ErrorSequence::getTypes() const
{
return types;
}
const std::vector<Error> & ErrorSequence::getSequence() const
{
return sequence;
}
void Errors::newSequence()
{
sequences.emplace_back();
}
void Errors::add(const Error & error)
{
sequences.back().add(error);
}
bool Error::isError() const
{
return prediction != gold;
}
void Errors::printStats()
{
std::map<std::string, int> nbOccurencesByType;
int nbErrorsTotal = 0;
for (auto & sequence : sequences)
for (auto & error : sequence.getSequence())
{
if (!error.isError())
{
}
else
{
nbOccurencesByType[error.getType()]++;
nbErrorsTotal++;
}
}
std::vector< std::pair<std::string,int> > typesOccurences;
for (auto & it : nbOccurencesByType)
typesOccurences.emplace_back(std::pair<std::string,int>(it.first,it.second));
std::sort(typesOccurences.begin(), typesOccurences.end(),
[](const std::pair<std::string,int> & a, const std::pair<std::string,int> & b)
{
return a.second > b.second;
});
std::vector< std::vector<std::string> > columns;
columns.clear();
columns.resize(4);
for (auto & it : typesOccurences)
{
columns[0].emplace_back(it.first);
columns[1].emplace_back("= " + float2str(it.second*100.0/nbErrorsTotal,"%.2f%%"));
columns[2].emplace_back(" of errors (" + std::to_string(it.second));
columns[3].emplace_back(" / " + std::to_string(nbErrorsTotal) + ")");
}
printColumns(stderr, columns, 1);
}
......@@ -10,6 +10,7 @@
#include "Config.hpp"
#include "TransitionMachine.hpp"
#include "util.hpp"
#include "Error.hpp"
namespace po = boost::program_options;
......@@ -132,6 +133,8 @@ int main(int argc, char * argv[])
bool configIsError = false;
int actionIndex = 0;
int errorIndex = 0;
Errors errors;
errors.newSequence();
while (!config.isFinal())
{
TransitionMachine::State * currentState = tm.getCurrentState();
......@@ -147,7 +150,6 @@ int main(int argc, char * argv[])
auto weightedActions = classifier->weightActions(config);
if (ProgramParameters::debug)
{
Classifier::printWeightedActions(stderr, weightedActions);
......@@ -184,8 +186,8 @@ int main(int argc, char * argv[])
if (classifier->name == ProgramParameters::classifierName)
{
fprintf(stderr, "%d\t%d\n", configIsError ? 1 : 0, errorIndex - actionIndex);
config.printAsExample(stderr);
//fprintf(stderr, "%d\t%d\n", configIsError ? 1 : 0, errorIndex - actionIndex);
//config.printAsExample(stderr);
actionIndex++;
auto zeroCostActions = classifier->getZeroCostActions(config);
......@@ -212,8 +214,14 @@ int main(int argc, char * argv[])
configIsError = false;
errorIndex = 0;
}
if (configIsError)
{
errors.add({action->name, zeroCostActions[0], weightedActions});
}
}
action->apply(config);
TransitionMachine::Transition * transition = tm.getTransition(predictedAction);
......@@ -236,11 +244,16 @@ int main(int argc, char * argv[])
{
justFlipped = true;
entropyAccumulator = 0.0;
errors.newSequence();
configIsError = false;
errorIndex = 0;
}
}
}
errors.printStats();
return 0;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment