diff --git a/torch_modules/src/CNNNetwork.cpp b/torch_modules/src/CNNNetwork.cpp index 2f0cc9f8a13dfd00fd3d69e8d0cc8170d5b1afdf..285ad8c40fcf7ba164a7a32c2dc5501f02ab0af7 100644 --- a/torch_modules/src/CNNNetwork.cpp +++ b/torch_modules/src/CNNNetwork.cpp @@ -3,7 +3,7 @@ CNNNetworkImpl::CNNNetworkImpl(int nbOutputs, int leftBorder, int rightBorder, int nbStackElements, std::vector<std::string> columns, std::vector<int> focusedBufferIndexes, std::vector<int> focusedStackIndexes, std::vector<std::string> focusedColumns, std::vector<int> maxNbElements, int leftWindowRawInput, int rightWindowRawInput) : focusedBufferIndexes(focusedBufferIndexes), focusedStackIndexes(focusedStackIndexes), focusedColumns(focusedColumns), maxNbElements(maxNbElements), leftWindowRawInput(leftWindowRawInput), rightWindowRawInput(rightWindowRawInput) { constexpr int embeddingsSize = 64; - constexpr int hiddenSize = 512; + constexpr int hiddenSize = 1024; constexpr int nbFiltersContext = 512; constexpr int nbFiltersFocused = 64; @@ -152,6 +152,15 @@ std::vector<long> CNNNetworkImpl::extractContext(Config & config, Dict & dict) c else elements.emplace_back(Dict::nullValueStr); } + else if (col == "ID") + { + if (config.isTokenPredicted(index)) + elements.emplace_back("ID(TOKEN)"); + else if (config.isMultiwordPredicted(index)) + elements.emplace_back("ID(MULTIWORD)"); + else if (config.isEmptyNodePredicted(index)) + elements.emplace_back("ID(EMPTYNODE)"); + } else { elements.emplace_back(config.getAsFeature(col, index)); diff --git a/torch_modules/src/NeuralNetwork.cpp b/torch_modules/src/NeuralNetwork.cpp index 3f69b4a3e63753738ac7b7b0921877da03ce6317..0ef9f8dd03b2be4e90afedea75fe5cedfb66ce41 100644 --- a/torch_modules/src/NeuralNetwork.cpp +++ b/torch_modules/src/NeuralNetwork.cpp @@ -6,7 +6,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config { std::stack<long> leftContext; for (int index = config.getWordIndex()-1; config.has(0,index,0) && leftContext.size() < leftBorder; --index) - if (config.isToken(index)) + if (!config.isComment(index)) leftContext.push(index); std::vector<long> context; @@ -20,7 +20,7 @@ std::vector<long> NeuralNetworkImpl::extractContextIndexes(const Config & config } for (int index = config.getWordIndex(); config.has(0,index,0) && context.size() < leftBorder+rightBorder+1; ++index) - if (config.isToken(index)) + if (!config.isComment(index)) context.emplace_back(index); while (context.size() < leftBorder+rightBorder+1) diff --git a/trainer/include/Trainer.hpp b/trainer/include/Trainer.hpp index e04f3e37dcee7bd29ea47dd226d144d9574d1e40..259a150e2fa60cd619689fc6fbc2a8be89b81e39 100644 --- a/trainer/include/Trainer.hpp +++ b/trainer/include/Trainer.hpp @@ -19,7 +19,7 @@ class Trainer DataLoader devDataLoader{nullptr}; std::unique_ptr<torch::optim::Adam> optimizer; std::size_t epochNumber{0}; - int batchSize{50}; + int batchSize{64}; int nbExamples{0}; private : diff --git a/trainer/src/Trainer.cpp b/trainer/src/Trainer.cpp index 450c41660246845d180b250d97d479c3d3fd8eb8..2b68072465a599038ba0420003514bf9581f2db3 100644 --- a/trainer/src/Trainer.cpp +++ b/trainer/src/Trainer.cpp @@ -16,7 +16,7 @@ void Trainer::createDataset(SubConfig & config, bool debug) dataLoader = torch::data::make_data_loader(Dataset(contexts, classes).map(torch::data::transforms::Stack<>()), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); - optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.001).amsgrad(true).beta1(0.9).beta2(0.999))); + optimizer.reset(new torch::optim::Adam(machine.getClassifier()->getNN()->parameters(), torch::optim::AdamOptions(0.0005).amsgrad(true).beta1(0.9).beta2(0.999))); } void Trainer::createDevDataset(SubConfig & config, bool debug)