diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp
index b98111b00c26e711516ba72db2b10f82db6c2cf5..f056ce2a28d64fa4ec9a95cd60a47893b3a52904 100644
--- a/dev/src/dev.cpp
+++ b/dev/src/dev.cpp
@@ -8,155 +8,177 @@
 #include "TestNetwork.hpp"
 #include "ConfigDataset.hpp"
 
-constexpr int batchSize = 50;
-constexpr int nbExamples = 350000;
-constexpr int embeddingSize = 20;
-constexpr int nbClasses = 15;
-constexpr int nbWordsPerDatapoint = 5;
-constexpr int maxNbEmbeddings = 1000000;
-
-//3m15s
-struct NetworkImpl : torch::nn::Module
+//constexpr int batchSize = 50;
+//constexpr int nbExamples = 350000;
+//constexpr int embeddingSize = 20;
+//constexpr int nbClasses = 15;
+//constexpr int nbWordsPerDatapoint = 5;
+//constexpr int maxNbEmbeddings = 1000000;
+//
+//struct NetworkImpl : torch::nn::Module
+//{
+//  torch::nn::Linear linear{nullptr};
+//  torch::nn::Embedding wordEmbeddings{nullptr};
+//
+//  std::vector<torch::Tensor> _sparseParameters;
+//  std::vector<torch::Tensor> _denseParameters;
+//  NetworkImpl()
+//  {
+//    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
+//    auto params = linear->parameters();
+//    _denseParameters.insert(_denseParameters.end(), params.begin(), params.end());
+//
+//    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
+//    params = wordEmbeddings->parameters();
+//    _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end());
+//  };
+//  const std::vector<torch::Tensor> & denseParameters()
+//  {
+//    return _denseParameters;
+//  }
+//  const std::vector<torch::Tensor> & sparseParameters()
+//  {
+//    return _sparseParameters;
+//  }
+//  torch::Tensor forward(const torch::Tensor & input)
+//  {
+//    // The input is a batch of sentences (lists of word indices): the sentence embedding is the mean of its word embeddings
+//    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
+//    return torch::softmax(linear(embeddingsOfInput),1);
+//  }
+//};
+//TORCH_MODULE(Network);
+
+//int main(int argc, char * argv[])
+//{
+//  auto nn = Network();
+//  torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
+//  torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
+//  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
+//  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
+//    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
+//
+//  for (auto & batch : batches)
+//  {
+//    sparseOptimizer.zero_grad();
+//    denseOptimizer.zero_grad();
+//    auto prediction = nn(batch.first);
+//    auto loss = torch::nll_loss(torch::log(prediction), batch.second);
+//    loss.backward();
+//    sparseOptimizer.step();
+//    denseOptimizer.step();
+//  }
+//  return 0;
+//}
+
+int main(int argc, char * argv[])
 {
-  torch::nn::Linear linear{nullptr};
-  torch::nn::Embedding wordEmbeddings{nullptr};
-  NetworkImpl()
+  if (argc != 5)
   {
-    linear = register_module("dense_linear", torch::nn::Linear(embeddingSize, nbClasses));
-    wordEmbeddings = register_module("sparse_word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(maxNbEmbeddings, embeddingSize).sparse(true)));
-  };
-  torch::Tensor forward(const torch::Tensor & input)
+    fmt::print(stderr, "This program needs 4 arguments.\n");
+    exit(1);
+  }
+
+  at::init_num_threads();
+
+  std::string machineFile = argv[1];
+  std::string mcdFile = argv[2];
+  std::string tsvFile = argv[3];
+  //std::string rawFile = argv[4];
+  std::string rawFile = "";
+
+  ReadingMachine machine(machineFile);
+
+  BaseConfig goldConfig(mcdFile, tsvFile, rawFile);
+  SubConfig config(goldConfig);
+
+  config.setState(machine.getStrategy().getInitialState());
+
+  std::vector<torch::Tensor> contexts;
+  std::vector<torch::Tensor> classes;
+
+  fmt::print("Generating dataset...\n");
+
+  Dict dict(Dict::State::Open);
+
+  while (true)
   {
-    // I have a batch of sentences (list of word embeddings), so as the sentence embedding I take the mean of the embedding of its words
-    auto embeddingsOfInput = wordEmbeddings(input).mean(1);
-    return torch::softmax(linear(embeddingsOfInput),1);
+    auto * transition = machine.getTransitionSet().getBestAppliableTransition(config);
+    if (!transition)
+      util::myThrow("No applicable transition!");
+
+    auto context = config.extractContext(5,5,dict);
+    contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone());
+
+    int goldIndex = 3;
+    auto gold = torch::zeros(1, at::kLong);
+    gold[0] = goldIndex;
+
+    classes.emplace_back(gold);
+
+    transition->apply(config);
+    config.addToHistory(transition->getName());
+
+    auto movement = machine.getStrategy().getMovement(config, transition->getName());
+    if (movement == Strategy::endMovement)
+      break;
+
+    config.setState(movement.first);
+    if (!config.moveWordIndex(movement.second))
+      util::myThrow("Cannot move word index!");
+
+    if (config.needsUpdate())
+      config.update();
   }
-};
-TORCH_MODULE(Network);
 
-int main(int argc, char * argv[])
-{
-  auto nn = Network();
-  torch::optim::SparseAdam sparseOptimizer(nn->parameters(), torch::optim::SparseAdamOptions(2e-4).beta1(0.5));
-  torch::optim::Adam denseOptimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
-  std::vector<std::pair<torch::Tensor,torch::Tensor>> batches;
-  for (int nbBatch = 0; nbBatch < nbExamples / batchSize; ++nbBatch)
-    batches.emplace_back(std::make_pair(torch::randint(maxNbEmbeddings,{batchSize,nbWordsPerDatapoint}, at::kLong), torch::randint(nbClasses, batchSize, at::kLong)));
-
-  for (auto & batch : batches)
+  auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>());
+
+  int nbExamples = *dataset.size();
+  fmt::print("Done! size={}\n", nbExamples);
size={}\n", nbExamples); + + int batchSize = 100; + auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); + + TestNetwork nn(machine.getTransitionSet().size(), 5); + torch::optim::Adam denseOptimizer(nn->denseParameters(), torch::optim::AdamOptions(2e-1).beta1(0.5)); + torch::optim::SparseAdam sparseOptimizer(nn->sparseParameters(), torch::optim::SparseAdamOptions(2e-1).beta1(0.5)); + + for (int epoch = 1; epoch <= 2; ++epoch) { - sparseOptimizer.zero_grad(); - denseOptimizer.zero_grad(); - auto prediction = nn(batch.first); - auto loss = torch::nll_loss(torch::log(prediction), batch.second); - loss.backward(); - sparseOptimizer.step(); - denseOptimizer.step(); + float totalLoss = 0.0; + float lossSoFar = 0.0; + torch::Tensor example; + int currentBatchNumber = 0; + + for (auto & batch : *dataLoader) + { + denseOptimizer.zero_grad(); + sparseOptimizer.zero_grad(); + + auto data = batch.data; + auto labels = batch.target.squeeze(); + + auto prediction = nn(data); + example = prediction[0]; + + auto loss = torch::nll_loss(torch::log(prediction), labels); + totalLoss += loss.item<float>(); + lossSoFar += loss.item<float>(); + loss.backward(); + denseOptimizer.step(); + sparseOptimizer.step(); + + if (++currentBatchNumber*batchSize % 1000 == 0) + { + fmt::print("\rcurrent epoch : {:6.2f}% loss={:<15}", 100.0*currentBatchNumber*batchSize/nbExamples, lossSoFar); + std::fflush(stdout); + lossSoFar = 0; + } + } + + fmt::print("\nEpoch {} : loss={:.2f}\n", epoch, totalLoss); } + return 0; } -//int main(int argc, char * argv[]) -//{ -// if (argc != 5) -// { -// fmt::print(stderr, "needs 4 arguments.\n"); -// exit(1); -// } -// -// at::init_num_threads(); -// -// std::string machineFile = argv[1]; -// std::string mcdFile = argv[2]; -// std::string tsvFile = argv[3]; -// //std::string rawFile = argv[4]; -// std::string rawFile = ""; -// -// ReadingMachine machine(machineFile); -// -// BaseConfig goldConfig(mcdFile, tsvFile, rawFile); -// SubConfig config(goldConfig); -// -// config.setState(machine.getStrategy().getInitialState()); -// -// std::vector<torch::Tensor> contexts; -// std::vector<torch::Tensor> classes; -// -// fmt::print("Generating dataset...\n"); -// -// Dict dict(Dict::State::Open); -// -// while (true) -// { -// auto * transition = machine.getTransitionSet().getBestAppliableTransition(config); -// if (!transition) -// util::myThrow("No transition appliable !"); -// -// auto context = config.extractContext(5,5,dict); -// contexts.push_back(torch::from_blob(context.data(), {(long)context.size()}, at::kLong).clone()); -// -// int goldIndex = 3; -// auto gold = torch::from_blob(&goldIndex, {1}, at::kLong).clone(); -// -// classes.emplace_back(gold); -// -// transition->apply(config); -// config.addToHistory(transition->getName()); -// -// auto movement = machine.getStrategy().getMovement(config, transition->getName()); -// if (movement == Strategy::endMovement) -// break; -// -// config.setState(movement.first); -// if (!config.moveWordIndex(movement.second)) -// util::myThrow("Cannot move word index !"); -// -// if (config.needsUpdate()) -// config.update(); -// } -// -// auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>()); -// -// int nbExamples = *dataset.size(); -// fmt::print("Done! 
size={}\n", nbExamples); -// -// int batchSize = 100; -// auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); -// -// TestNetwork nn(machine.getTransitionSet().size(), 5); -// torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); -// -// for (int epoch = 1; epoch <= 1; ++epoch) -// { -// float totalLoss = 0.0; -// torch::Tensor example; -// int currentBatchNumber = 0; -// -// for (auto & batch : *dataLoader) -// { -// optimizer.zero_grad(); -// -// auto data = batch.data; -// auto labels = batch.target.squeeze(); -// -// auto prediction = nn(data); -// example = prediction[0]; -// -// auto loss = torch::nll_loss(torch::log(prediction), labels); -// totalLoss += loss.item<float>(); -// loss.backward(); -// optimizer.step(); -// -// if (++currentBatchNumber*batchSize % 1000 == 0) -// { -// fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples); -// std::fflush(stdout); -// } -// } -// -// fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss); -// } -// -// return 0; -//} -// diff --git a/torch_modules/include/TestNetwork.hpp b/torch_modules/include/TestNetwork.hpp index b1bb4e11200197f029556f7f3d93338ca0ce072e..27b92e8a00ac3430567ba35c7846ada1aa076d4a 100644 --- a/torch_modules/include/TestNetwork.hpp +++ b/torch_modules/include/TestNetwork.hpp @@ -12,10 +12,15 @@ class TestNetworkImpl : public torch::nn::Module torch::nn::Linear linear{nullptr}; int focusedIndex; + std::vector<torch::Tensor> _denseParameters; + std::vector<torch::Tensor> _sparseParameters; + public : TestNetworkImpl(int nbOutputs, int focusedIndex); torch::Tensor forward(torch::Tensor input); + std::vector<torch::Tensor> & denseParameters(); + std::vector<torch::Tensor> & sparseParameters(); }; TORCH_MODULE(TestNetwork); diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp index f379c735a92aa17832bb24774f296e5f45b6aa7a..63257082bdd3dd0d51d29063726325fd862c4e1c 100644 --- a/torch_modules/src/TestNetwork.cpp +++ b/torch_modules/src/TestNetwork.cpp @@ -3,11 +3,28 @@ TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex) { constexpr int embeddingsSize = 30; - wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize)); + + wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(torch::nn::EmbeddingOptions(200000, embeddingsSize).sparse(true))); + auto params = wordEmbeddings->parameters(); + _sparseParameters.insert(_sparseParameters.end(), params.begin(), params.end()); + linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs)); + params = linear->parameters(); + _denseParameters.insert(_denseParameters.end(), params.begin(), params.end()); + this->focusedIndex = focusedIndex; } +std::vector<torch::Tensor> & TestNetworkImpl::denseParameters() +{ + return _denseParameters; +} + +std::vector<torch::Tensor> & TestNetworkImpl::sparseParameters() +{ + return _sparseParameters; +} + torch::Tensor TestNetworkImpl::forward(torch::Tensor input) { // input dim = {batch, sequence, embeddings}