Commit 8a9f3bf7 authored by Marjorie Armando

Update

parent 6ed3881e
@@ -133,6 +133,9 @@ int main(int argc, char** argv)
// Build model -----------------------------------------------------------------------------------
ParameterCollection model;
LookupParameter table_embedding = model.add_lookup_parameters(/* number of words in the cff vocabulary (example value for testing) */ 5027,
/* dimension of each embedding vector */ {300});
// Use Adam optimizer
AdamTrainer trainer(model);
trainer.clip_threshold *= batch_size;
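Scaling clip_threshold by batch_size follows the usual DyNet mini-batch pattern: the per-example losses are summed into a single batch loss before the backward pass, so gradient norms grow roughly with the batch size. A minimal sketch of that update step, assuming loss_expr is the batched loss expression built further down:

    Expression batch_loss = sum_batches(loss_expr); // sum the per-example losses of the mini-batch
    cg.forward(batch_loss);
    cg.backward(batch_loss); // gradient magnitudes scale with the number of summed examples
    trainer.update();        // hence clip_threshold is scaled by batch_size above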
@@ -196,11 +199,17 @@ int main(int argc, char** argv)
cur_labels = vector<unsigned int>(bsize);
for (unsigned int idx = 0; idx < bsize; ++idx)
{
cur_batch[idx] = input(cg, {5}, cff_train[id + idx]);
vector<Expression> vect_expr(NB_FEATS);
for (unsigned int i = 0; i < NB_FEATS; ++i)
vect_expr[i] = const_lookup(cg, table_embedding, static_cast<unsigned>(cff_train[id + idx][i]));
//cur_batch[idx] = input(cg, {5}, cff_train[id + idx]);
cur_batch[idx] = concatenate(vect_expr); // concatenate along dim 0 -> one {NB_FEATS * 300} vector per example
cur_labels[idx] = cff_train_labels[id + idx];
}
// Reshape as batch (not very intuitive yet)
Expression x_batch = reshape(concatenate_cols(cur_batch), Dim({5}, bsize));
/** TODO: figure out how to build the mini-batches with the embeddings! */
/**Expression x_batch = reshape(concatenate_cols(cur_batch), Dim({5}, bsize));*/
// Get negative log likelihood on batch
Expression loss_expr = nn.get_nll(x_batch, cur_labels, cg);
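A possible way to address the TODO above, sketched under the assumption that cff_train[k][i] holds the i-th feature id of example k (names reuse the ones in this diff): const_lookup also accepts a vector of indices and then returns a mini-batched embedding, so the per-example loop and the manual reshape can be avoided.

    vector<Expression> feat_embs(NB_FEATS);
    for (unsigned int i = 0; i < NB_FEATS; ++i)
    {
        vector<unsigned> ids(bsize);
        for (unsigned int idx = 0; idx < bsize; ++idx)
            ids[idx] = static_cast<unsigned>(cff_train[id + idx][i]);
        // one {300}-dim expression carrying a batch of bsize embeddings for feature i
        feat_embs[i] = const_lookup(cg, table_embedding, ids);
    }
    // concatenate along dim 0: a {NB_FEATS * 300} expression, already batched over bsize examples
    Expression x_batch = concatenate(feat_embs);

With this sketch, the batch dimension is carried by the lookup itself, so nn.get_nll would receive x_batch without going through reshape.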