diff --git a/dev/src/dev.cpp b/dev/src/dev.cpp index b5bd81f47d2f1b37427c699e3eb1d3eb48b15c27..2aac98cf5171f86e9bbd93d953390d246e21a0e7 100644 --- a/dev/src/dev.cpp +++ b/dev/src/dev.cpp @@ -67,18 +67,20 @@ int main(int argc, char * argv[]) auto dataset = ConfigDataset(contexts, classes).map(torch::data::transforms::Stack<>()); - fmt::print("Done! size={}\n", *dataset.size()); + int nbExamples = *dataset.size(); + fmt::print("Done! size={}\n", nbExamples); int batchSize = 100; - auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize)); + auto dataLoader = torch::data::make_data_loader(std::move(dataset), torch::data::DataLoaderOptions(batchSize).workers(0).max_jobs(0)); TestNetwork nn(machine.getTransitionSet().size(), 5); torch::optim::Adam optimizer(nn->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5)); - for (int epoch = 1; epoch <= 5; ++epoch) + for (int epoch = 1; epoch <= 1; ++epoch) { float totalLoss = 0.0; torch::Tensor example; + int currentBatchNumber = 0; for (auto & batch : *dataLoader) { @@ -94,10 +96,15 @@ int main(int argc, char * argv[]) totalLoss += loss.item<float>(); loss.backward(); optimizer.step(); + + if (++currentBatchNumber*batchSize % 1000 == 0) + { + fmt::print("\rcurrent epoch : {:6.2f}%", 100.0*currentBatchNumber*batchSize/nbExamples); + std::fflush(stdout); + } } fmt::print("Epoch {} : loss={:.2f}\n", epoch, totalLoss); - std::cout << example << std::endl; } return 0; diff --git a/torch_modules/src/TestNetwork.cpp b/torch_modules/src/TestNetwork.cpp index 3e3c010a7634824181e205b69161afbaefd96585..f379c735a92aa17832bb24774f296e5f45b6aa7a 100644 --- a/torch_modules/src/TestNetwork.cpp +++ b/torch_modules/src/TestNetwork.cpp @@ -2,7 +2,7 @@ TestNetworkImpl::TestNetworkImpl(int nbOutputs, int focusedIndex) { - constexpr int embeddingsSize = 100; + constexpr int embeddingsSize = 30; wordEmbeddings = register_module("word_embeddings", torch::nn::Embedding(200000, embeddingsSize)); linear = register_module("linear", torch::nn::Linear(embeddingsSize, nbOutputs)); this->focusedIndex = focusedIndex;