Commit 1ac56cf3 authored by Franck Dary

Made training much faster by using sparse updates over lookup parameters

parent b8e657ce
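
For context, a minimal hedged sketch (not code from this repository; names, dimensions and the toy loss are illustrative) of the dynet behaviour the commit relies on: a plain dynet::Parameter gets a dense update on every trainer.update(), whereas a dynet::LookupParameter only accumulates gradient for the rows actually looked up in the computation graph, so the optimizer can update those rows sparsely.

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/training.h>
#include <cstdio>

int main(int argc, char ** argv)
{
  dynet::initialize(argc, argv);

  dynet::ParameterCollection model;
  dynet::AmsgradTrainer trainer(model);

  // Room for 200000 rows of dimension 50; only rows that appear in the
  // graph receive gradient, so the update after backward() is sparse.
  dynet::LookupParameter table = model.add_lookup_parameters(200000, {50});

  dynet::ComputationGraph cg;
  dynet::Expression e = dynet::lookup(cg, table, 42);      // trainable row
  dynet::Expression c = dynet::const_lookup(cg, table, 7); // frozen row
  dynet::Expression loss = dynet::squared_distance(e, c);

  printf("loss = %f\n", dynet::as_scalar(cg.forward(loss)));
  cg.backward(loss);
  trainer.update(); // only row 42 of "table" is modified here

  return 0;
}
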
@@ -41,9 +41,11 @@ class MLP
private :
+ static const unsigned int MAXLOOKUPSIZE = 200000;
std::vector<Layer> layers;
std::vector< std::vector<dynet::Parameter> > parameters;
std::map<void*,dynet::Parameter> ptr2parameter;
+ std::map< Dict*, std::pair<dynet::LookupParameter, std::map<void*, unsigned int> > > lookupParameters;
dynet::ParameterCollection model;
dynet::AmsgradTrainer trainer;
@@ -54,7 +56,7 @@ class MLP
void addLayerToModel(Layer & layer);
void checkLayersCompatibility();
dynet::DynetParams & getDefaultParams();
- dynet::Parameter & featValue2parameter(const FeatureModel::FeatureValue & fv);
+ dynet::Expression featValue2Expression(dynet::ComputationGraph & cg, const FeatureModel::FeatureValue & fv);
dynet::Expression run(dynet::ComputationGraph & cg, dynet::Expression x);
inline dynet::Expression activate(dynet::Expression h, Activation f);
void printParameters(FILE * output);
@@ -114,10 +114,7 @@ std::vector<float> MLP::predict(FeatureModel::FeatureDescription & fd, int goldC
for (auto & featValue : fd.values)
{
- if(featValue.policy == FeatureModel::Policy::Final)
- expressions.emplace_back(dynet::const_parameter(cg, featValue2parameter(featValue)));
- else
- expressions.emplace_back(dynet::parameter(cg, featValue2parameter(featValue)));
+ expressions.emplace_back(featValue2Expression(cg, featValue));
}
dynet::Expression input = dynet::concatenate(expressions);
@@ -143,19 +140,50 @@ dynet::DynetParams & MLP::getDefaultParams()
return params;
}
- dynet::Parameter & MLP::featValue2parameter(const FeatureModel::FeatureValue & fv)
+ dynet::Expression MLP::featValue2Expression(dynet::ComputationGraph & cg, const FeatureModel::FeatureValue & fv)
{
+ Dict * dict = fv.dict;
+ auto entry = lookupParameters.find(dict);
+ if(entry == lookupParameters.end())
+ {
- auto it = ptr2parameter.find(fv.vec);
+ lookupParameters[dict].first = model.add_lookup_parameters(MAXLOOKUPSIZE, {(unsigned)dict->getDimension(),1});
+ }
- if(it != ptr2parameter.end())
- return it->second;
+ auto & ptr2index = lookupParameters[dict].second;
+ auto & lu = lookupParameters[dict].first;
- ptr2parameter[fv.vec] = model.add_parameters({(unsigned)fv.vec->size(),1});
- it = ptr2parameter.find(fv.vec);
+ bool isConst = fv.policy == FeatureModel::Policy::Final;
- it->second.values()->v = fv.vec->data();
+ auto it = ptr2index.find(fv.vec);
- return it->second;
+ if(it != ptr2index.end())
+ {
+ if(isConst)
+ return dynet::const_lookup(cg, lu, it->second);
+ else
+ return dynet::lookup(cg, lu, it->second);
+ }
+ ptr2index[fv.vec] = ptr2index.size();
+ it = ptr2index.find(fv.vec);
+ unsigned int lookupSize = (int)(*lu.values()).size();
+ if(it->second >= lookupSize)
+ {
+ fprintf(stderr, "ERROR (%s) : MAXLOOKUPSIZE (%d) is too small. Aborting.\n", ERRINFO, MAXLOOKUPSIZE);
+ exit(1);
+ }
+ // Horrible trick : directly set Dict data as Tensor values
+ // Works only on CPU
+ (*lu.values())[it->second].v = fv.vec->data();
+ if(isConst)
+ return dynet::const_lookup(cg, lu, it->second);
+ else
+ return dynet::lookup(cg, lu, it->second);
}
dynet::Expression MLP::run(dynet::ComputationGraph & cg, dynet::Expression x)
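
Isolated for clarity, a hedged illustration of the "Horrible trick" above (assuming the dynet CPU backend, where dynet::Tensor::v is the raw host float pointer; the helper name is hypothetical): the lookup row's storage is redirected to the std::vector<float> owned by the Dict, so training the row updates the Dict's embedding in place with no copy. On GPU the tensor memory lives on the device, so this aliasing would not work.

#include <dynet/dynet.h>
#include <vector>

// Point row "row" of lookup table "lu" at externally owned storage. The
// vector must outlive the lookup parameter and match the row dimension.
void aliasRowToExternalStorage(dynet::LookupParameter & lu,
                               unsigned row,
                               std::vector<float> & externalVec)
{
  std::vector<dynet::Tensor> & rows = *lu.values(); // per-row value tensors
  rows[row].v = externalVec.data();                 // CPU-only pointer swap
}
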
@@ -225,17 +253,7 @@ inline dynet::Expression MLP::activate(dynet::Expression h, Activation f)
void MLP::printParameters(FILE * output)
{
- for(auto & it : ptr2parameter)
- {
- auto & param = it.second;
- dynet::Tensor * tensor = param.values();
- float * value = tensor->v;
- int dim = tensor->d.size();
- fprintf(output, "Param : ");
- for(int i = 0; i < dim; i++)
- fprintf(output, "%.2f ", value[i]);
- fprintf(output, "\n");
- }
+ fprintf(output, "Parameters : NOT IMPLEMENTED\n");
}
int MLP::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & start, std::vector<std::pair<int, FeatureModel::FeatureDescription> >::iterator & end)
@@ -252,12 +270,7 @@ int MLP::trainOnBatch(std::vector<std::pair<int, FeatureModel::FeatureDescriptio
expressions.clear();
for (auto & featValue : it->second.values)
- {
- if(featValue.policy == FeatureModel::Policy::Final)
- expressions.emplace_back(dynet::const_parameter(cg, featValue2parameter(featValue)));
- else
- expressions.emplace_back(dynet::parameter(cg, featValue2parameter(featValue)));
- }
+ expressions.emplace_back(featValue2Expression(cg, featValue));
inputs.emplace_back(dynet::concatenate(expressions));
inputDim = inputs.back().dim().rows();
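
For completeness, a hedged, generic sketch of a batched dynet step (this is not the repository's trainOnBatch; the loss construction and all names are assumptions): the per-example losses are summed into one expression, a single backward pass accumulates gradients, and trainer.update() then only touches the lookup rows the batch actually used.

#include <dynet/dynet.h>
#include <dynet/expr.h>
#include <dynet/training.h>
#include <utility>
#include <vector>

// scoredExamples: pairs of (score expression over the classes, gold class).
float batchedStepSketch(dynet::ComputationGraph & cg,
                        dynet::AmsgradTrainer & trainer,
                        const std::vector< std::pair<dynet::Expression, unsigned> > & scoredExamples)
{
  std::vector<dynet::Expression> losses;
  for (auto & ex : scoredExamples)
    losses.emplace_back(dynet::pickneglogsoftmax(ex.first, ex.second));

  dynet::Expression batchLoss = dynet::sum(losses);
  float l = dynet::as_scalar(cg.forward(batchLoss));
  cg.backward(batchLoss);
  trainer.update(); // sparse for lookup parameters: only used rows change
  return l;
}
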
@@ -58,6 +58,7 @@ class Dict
void save();
std::vector<float> * getValue(const std::string & s);
std::vector<float> * getNullValue();
+ int getDimension();
};
#endif
@@ -190,3 +190,8 @@ Dict * Dict::getDict(Policy policy, const std::string & filename)
return str2dict[filename].get();
}
+ int Dict::getDimension()
+ {
+ return dimension;
+ }
@@ -17,6 +17,7 @@ class FeatureModel
struct FeatureValue
{
+ Dict * dict;
std::string name;
std::string * value;
std::vector<float> * vec;
@@ -36,9 +36,9 @@ FeatureModel::FeatureValue FeatureBank::simpleBufferAccess(Config & config, int
int index = config.head + relativeIndex;
if(index < 0 || index >= (int)tape.size())
- return {featName+"(null)", &Dict::nullValueStr, dict->getNullValue(), policy};
+ return {dict, featName+"(null)", &Dict::nullValueStr, dict->getNullValue(), policy};
- return {featName, &tape[index], dict->getValue(tape[index]), policy};
+ return {dict, featName, &tape[index], dict->getValue(tape[index]), policy};
}
FeatureModel::FeatureValue FeatureBank::simpleStackAccess(Config & config, int relativeIndex, const std::string & tapeName, const std::string & featName)
@@ -48,14 +48,14 @@ FeatureModel::FeatureValue FeatureBank::simpleStackAccess(Config & config, int r
auto policy = dictPolicy2FeaturePolicy(dict->policy);
if(relativeIndex < 0 || relativeIndex >= (int)config.stack.size())
- return {featName+"(null)", &Dict::nullValueStr, dict->getNullValue(), policy};
+ return {dict, featName+"(null)", &Dict::nullValueStr, dict->getNullValue(), policy};
int index = config.stack[config.stack.size()-1-relativeIndex];
if(index < 0 || index >= (int)tape.size())
- return {featName+"(null)", &Dict::nullValueStr, dict->getNullValue(), policy};
+ return {dict, featName+"(null)", &Dict::nullValueStr, dict->getNullValue(), policy};
- return {featName, &tape[index], dict->getValue(tape[index]), policy};
+ return {dict, featName, &tape[index], dict->getValue(tape[index]), policy};
}
FeatureModel::Policy FeatureBank::dictPolicy2FeaturePolicy(Dict::Policy policy)
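
A hedged note on why every return statement above changed: FeatureValue is built with positional brace-initialization, so once Dict * dict became the first member, each initializer list has to lead with the dict pointer. The sketch below uses stand-in types (Dict, Policy and the exact member list are assumptions based on the hunks above).

#include <string>
#include <vector>

struct Dict;                              // stand-in forward declaration
enum class Policy { Final, Modifiable };  // stand-in for FeatureModel::Policy

struct FeatureValue
{
  Dict * dict;               // member added first by this commit
  std::string name;
  std::string * value;
  std::vector<float> * vec;
  Policy policy;
};

// Aggregate initialization maps values onto members in declaration order,
// so "dict" must now come first in every braced return.
FeatureValue makeValue(Dict * dict, const std::string & name,
                       std::string * value, std::vector<float> * vec, Policy policy)
{
  return {dict, name, value, vec, policy};
}
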
@@ -55,6 +55,8 @@ void Trainer::trainBatched()
std::map<Classifier*, std::vector<Example> > examples;
fprintf(stderr, "Training of \'%s\' :\n", tm.name.c_str());
while (!config.isFinal())
{
TapeMachine::State * currentState = tm.getCurrentState();