// MLP.hpp
// Include guard. The previous macro name (MLP__H) contained a double
// underscore; such identifiers are reserved to the implementation in C++.
#ifndef MLP_HPP
#define MLP_HPP

// Standard headers for the std:: names and FILE used in the declarations
// below, so this header does not depend on transitive includes.
#include <cstdio>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include <dynet/nodes.h>
#include <dynet/dynet.h>
#include <dynet/training.h>
#include <dynet/timing.h>
#include <dynet/expr.h>

#include "FeatureModel.hpp"

/// Layered network built on top of dynet (the name suggests a multi-layer
/// perceptron). This header only declares the interface; all behaviour is
/// defined in the implementation file.
class MLP
{
  public:

  /// Container handed to trainOnBatch() / getScoreOnBatch().
  /// NOTE(review): the meaning of the int vector and of the int paired with
  /// each FeatureDescription is not visible from this header - confirm
  /// against the implementation.
  using Examples = std::pair<std::vector<int>,
                             std::vector<std::pair<int, FeatureModel::FeatureDescription>>>;

  /// Identifiers of the activation functions a Layer can be configured with.
  /// Enumerator order fixes the underlying integer values; do not reorder.
  enum Activation
  {
    SIGMOID,
    TANH,
    RELU,
    ELU,
    LINEAR,
    SPARSEMAX,
    CUBE,
    SOFTMAX
  };

  /// Converts an Activation to a string (spelling defined in the implementation).
  static std::string activation2str(Activation a);

  /// Converts a string to an Activation (accepted spellings and the handling
  /// of unknown strings are defined in the implementation).
  static Activation str2activation(std::string s);

  /// Description of a single layer: its dimensions, dropout rate and activation.
  struct Layer
  {
    int input_dim;
    int output_dim;
    float dropout_rate;
    Activation activation;

    /// @param input_dim     value for the input_dim field
    /// @param output_dim    value for the output_dim field
    /// @param dropout_rate  value for the dropout_rate field
    /// @param activation    value for the activation field
    Layer(int input_dim, int output_dim,
          float dropout_rate, Activation activation);

    /// Prints this layer to `file` (output format defined in the implementation).
    void print(FILE * file);
  };

  private:

  // NOTE(review): presumably the capacity used when creating lookup
  // parameters - confirm in the implementation.
  static const unsigned int MAXLOOKUPSIZE = 200000;

  // Data members. Their declaration order is their initialization order;
  // `trainer` is presumably constructed from `model`, so keep it declared
  // after `model` (TODO confirm in the constructors).
  std::vector<Layer> layers;
  // Groups of dynet parameters (presumably one inner vector per layer - confirm).
  std::vector<std::vector<dynet::Parameter>> parameters;
  // For each Dict: a dynet lookup parameter together with a map from an
  // opaque pointer key to an unsigned index.
  std::map<Dict*, std::pair<dynet::LookupParameter, std::map<void*, unsigned int>>> lookupParameters;
  dynet::ParameterCollection model;
  dynet::AmsgradTrainer trainer;
  bool trainMode;
  bool dropoutActive;

  // Internal helpers; all are defined in the implementation file.
  void addLayerToModel(Layer & layer);
  void checkLayersCompatibility();
  dynet::DynetParams & getDefaultParams();
  dynet::Expression featValue2Expression(dynet::ComputationGraph & cg, const FeatureModel::FeatureValue & fv);
  dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
  inline dynet::Expression activate(dynet::Expression h, Activation f);
  void printParameters(FILE * output);
  void saveStruct(const std::string & filename);
  void saveParameters(const std::string & filename);
  void loadStruct(const std::string & filename);
  void loadParameters(const std::string & filename);
  void load(const std::string & filename);

  public:

  /// Constructs the network from a list of layer descriptions.
  MLP(std::vector<Layer> layers);

  /// Constructs the network from `filename` (presumably a file previously
  /// written by save() - confirm in the implementation).
  MLP(const std::string & filename);

  /// Returns a vector of floats computed for the feature description `fd`
  /// (presumably one score per output unit - confirm).
  std::vector<float> predict(FeatureModel::FeatureDescription & fd);

  /// Processes the range [start, end) or [start, end] of `examples` for
  /// training. NOTE(review): the inclusiveness of `end` and the meaning of
  /// the returned int are not visible here - confirm in the implementation.
  int trainOnBatch(Examples & examples, int start, int end);

  /// Same parameters as trainOnBatch(); the name suggests it only evaluates
  /// the batch. NOTE(review): confirm the meaning of the returned int.
  int getScoreOnBatch(Examples & examples, int start, int end);

  /// Saves the network to `filename`.
  void save(const std::string & filename);

  /// Prints the network topology to `output`.
  void printTopology(FILE * output);
};

#endif