Commit 8a9f3bf7 authored by Marjorie Armando's avatar Marjorie Armando
Browse files

maj

parent 6ed3881e
......@@ -133,6 +133,9 @@ int main(int argc, char** argv)
// Build model -----------------------------------------------------------------------------------
ParameterCollection model;
LookupParameter table_embedding = model.add_lookup_parameters(/*nb of words in cff (this is an example for testing)*/5027,
/*dim of the vector embedding*/{300});
// Use Adam optimizer
AdamTrainer trainer(model);
trainer.clip_threshold *= batch_size;
......@@ -196,11 +199,17 @@ int main(int argc, char** argv)
cur_labels = vector<unsigned int>(bsize);
for (unsigned int idx = 0; idx < bsize; ++idx)
{
cur_batch[idx] = input(cg, {5}, cff_train[id + idx]);
vector<Expression> vect_expr(NB_FEATS);
for(unsigned int i=0; i<NB_FEATS; ++i)
vect_expr[i] = const_lookup(cg, table_embedding, static_cast<unsigned>(cff_train[id + idx][i]) );
//cur_batch[idx] = input(cg, {5}, cff_train[id + idx]);
cur_batch[idx] = concatenate(vect_expr, 300);
cur_labels[idx] = cff_train_labels[id + idx];
}
// Reshape as batch (not very intuitive yet)
Expression x_batch = reshape(concatenate_cols(cur_batch), Dim({5}, bsize));
/**Trouver comment faire les batchs avec les embeddings !*/
/**Expression x_batch = reshape(concatenate_cols(cur_batch), Dim({5}, bsize));*/
// Get negative log likelihood on batch
Expression loss_expr = nn.get_nll(x_batch, cur_labels, cg);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment