Skip to content
Snippets Groups Projects
Commit b8e657ce authored by Franck Dary's avatar Franck Dary
Browse files

Started decoder

parent 760f9ec8
No related branches found
No related tags found
No related merge requests found
......@@ -17,11 +17,13 @@ set(CMAKE_CXX_FLAGS_RELEASE "-O3")
include_directories(maca_common/include)
include_directories(tape_machine/include)
include_directories(trainer/include)
include_directories(decoder/include)
include_directories(tests/include)
include_directories(MLP/include)
add_subdirectory(maca_common)
add_subdirectory(tape_machine)
add_subdirectory(trainer)
add_subdirectory(decoder)
add_subdirectory(MLP)
add_subdirectory(tests)
......@@ -150,11 +150,10 @@ dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & f
if(it != ptr2parameter.end())
return it->second;
//ptr2parameter[fv.vec] = model.add_parameters({fv.vec->size(),1}, dynet::ParameterInitFromVector(*fv.vec));
ptr2parameter[fv.vec] = model.add_parameters({(unsigned)fv.vec->size(),1});
it = ptr2parameter.find(fv.vec);
// it->second.values()->v = fv.vec->data();
it->second.values()->v = fv.vec->data();
return it->second;
}
......
FILE(GLOB SOURCES src/*.cpp)
#compiling library
add_library(decoder STATIC ${SOURCES})
#ifndef DECODER__H
#define DECODER__H
#include "TapeMachine.hpp"
#include "MCD.hpp"
#include "Config.hpp"
class Decoder
{
private :
TapeMachine & tm;
MCD & mcd;
Config & config;
public :
Decoder(TapeMachine & tm, MCD & mcd, Config & config);
void decode();
};
#endif
#include "Decoder.hpp"
Decoder::Decoder(TapeMachine & tm, MCD & mcd, Config & config)
: tm(tm), mcd(mcd), config(config)
{
}
void Decoder::decode()
{
}
......@@ -33,6 +33,8 @@ class Classifier
public :
static void printWeightedActions(FILE * output, WeightedActions & wa);
static Type str2type(const std::string & filename);
Classifier(const std::string & filename);
WeightedActions weightActions(Config & config, const std::string & goldAction);
......
......@@ -123,3 +123,18 @@ std::string Classifier::getActionName(int actionIndex)
return as->getActionName(actionIndex);
}
void Classifier::printWeightedActions(FILE * output, WeightedActions & wa)
{
int nbCols = 80;
char symbol = '-';
for(int i = 0; i < nbCols; i++)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
for (auto it : wa)
fprintf(output, "%.2f\t%s\n", it.first, it.second.c_str());
for(int i = 0; i < nbCols; i++)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
}
......@@ -15,7 +15,6 @@ class Trainer
private :
void printWeightedActions(FILE * output, Classifier::WeightedActions & wa);
void trainUnbatched();
void trainBatched();
......
......@@ -71,8 +71,8 @@ void Trainer::trainBatched()
config.moveHead(transition->headMvt);
}
int nbIter = 20;
int batchSize = 50;
int nbIter = 5;
int batchSize = 256;
for (int i = 0; i < nbIter; i++)
{
......@@ -109,18 +109,3 @@ void Trainer::train()
trainBatched();
}
void Trainer::printWeightedActions(FILE * output, Classifier::WeightedActions & wa)
{
int nbCols = 80;
char symbol = '-';
for(int i = 0; i < nbCols; i++)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
for (auto it : wa)
fprintf(output, "%.2f\t%s\n", it.first, it.second.c_str());
for(int i = 0; i < nbCols; i++)
fprintf(output, "%c%s", symbol, i == nbCols-1 ? "\n" : "");
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment