Skip to content
Snippets Groups Projects
Commit 1efb791a authored by Franck Dary
Browse files

The neural network now sees multiwords; ID can now also be a focused column

parent ba5742bd
No related branches found
No related tags found
No related merge requests found
......@@ -3,7 +3,7 @@
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{
constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 512;
constexpr int hiddenSize = 1024;
constexpr int nbFiltersContext = 512;
constexpr int nbFiltersFocused = 64;
......@@ -152,6 +152,15 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
else
elements.emplace_back(Dict::nullValueStr);
}
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else
{
elements.emplace_back(config.getAsFeature(col, index));
......
......@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
{
std::stack<long> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
if (config.isToken(index))
if (!config.isComment(index))
leftContext.push(index);
std::vector<long> context;
......@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
}
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
if (config.isToken(index))
if (!config.isComment(index))
context.emplace_back(index);
while (context.size() < leftBorder+rightBorder+1)
......
......@@ -19,7 +19,7 @@ class Trainer
DataLoader devDataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0};
int batchSize{50};
int batchSize{64};
int nbExamples{0};
private :
......
......@@ -16,7 +16,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999)));
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
}
void Trainer::createDevDataset(SubConfig & config, bool debug)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment