Skip to content
Snippets Groups Projects
Commit 1efb791a authored by Franck Dary
Browse files

Neural network now sees multiwords, also ID can now be a focused column

parent ba5742bd
No related branches found
No related tags found
No related merge requests found
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput)
{ {
constexpr int embeddingsSize = 64; constexpr int embeddingsSize = 64;
constexpr int hiddenSize = 512; constexpr int hiddenSize = 1024;
constexpr int nbFiltersContext = 512; constexpr int nbFiltersContext = 512;
constexpr int nbFiltersFocused = 64; constexpr int nbFiltersFocused = 64;
...@@ -152,6 +152,15 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c ...@@ -152,6 +152,15 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c
else else
elements.emplace_back(Dict::nullValueStr); elements.emplace_back(Dict::nullValueStr);
} }
else if (col == "ID")
{
if (config.isTokenPredicted(index))
elements.emplace_back("ID(TOKEN)");
else if (config.isMultiwordPredicted(index))
elements.emplace_back("ID(MULTIWORD)");
else if (config.isEmptyNodePredicted(index))
elements.emplace_back("ID(EMPTYNODE)");
}
else else
{ {
elements.emplace_back(config.getAsFeature(col, index)); elements.emplace_back(config.getAsFeature(col, index));
......
...@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config ...@@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
{ {
std::stack<long> leftContext; std::stack<long> leftContext;
for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index) for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index)
if (config.isToken(index)) if (!config.isComment(index))
leftContext.push(index); leftContext.push(index);
std::vector<long> context; std::vector<long> context;
...@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config ...@@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config
} }
for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index) for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index)
if (config.isToken(index)) if (!config.isComment(index))
context.emplace_back(index); context.emplace_back(index);
while (context.size() < leftBorder+rightBorder+1) while (context.size() < leftBorder+rightBorder+1)
......
...@@ -19,7 +19,7 @@ class Trainer ...@@ -19,7 +19,7 @@ class Trainer
DataLoader devDataLoader{nullptr}; DataLoader devDataLoader{nullptr};
std::unique_ptr<torch::optim::Adam> optimizer; std::unique_ptr<torch::optim::Adam> optimizer;
std::size_t epochNumber{0}; std::size_t epochNumber{0};
int batchSize{50}; int batchSize{64};
int nbExamples{0}; int nbExamples{0};
private : private :
......
...@@ -16,7 +16,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) ...@@ -16,7 +16,7 @@ void Trainer::createDataset(SubConfig & config, bool debug)
dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0));
optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999))); optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999)));
} }
void Trainer::createDevDataset(SubConfig & config, bool debug) void Trainer::createDevDataset(SubConfig & config, bool debug)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment