Skip to content
Snippets Groups Projects
Commit 6e9186d7 authored by Benoit Favre's avatar Benoit Favre
Browse files

update doc

parent d2bf6dfb
No related branches found
No related tags found
No related merge requests found
......@@ -27,22 +27,25 @@ Training
--------
```
python trainier.py [options]
usage: python trainer.py --name <name> --train_filename <path> [options]
optional arguments:
-h, --help show this help message and exit
--gpus GPUS
--nodes NODES
--name NAME
--fast_dev_run
--train_filename TRAIN_FILENAME
--learning_rate LEARNING_RATE
--batch_size BATCH_SIZE
--epochs EPOCHS
--valid_size VALID_SIZE
--max_len MAX_LEN
--bert_flavor BERT_FLAVOR
--selected_features SELECTED_FEATURES
--gpus <int> list of gpus to use (-1 = all in CUDA_VISIBLE_DEVICES)
--nodes <int> number of nodes for distributed training (see pytorch_lightning doc)
--name <str> experiment name
--fast_dev_run run one batch to check that training works
--train_filename <path> name of json file containing training/validation instances
--learning_rate <float> learning rate (default=2e-5)
--batch_size <int> size of batch (default=32)
--epochs <int> number of epochs (default=20)
--valid_size_percent <int> validation set size in % (default=10)
--max_len <int> max sequence length (default=256)
--bert_flavor <path> pretrained bert model (default=monologg/biobert_v1.1_pubmed
--selected_features <list> list of features to load from input (default=title abstract)
--dropout <float> dropout after bert
--loss <bce|f1> choose loss function [f1, bce] (default=f1)
--augment_data simulate missing abstract through augmentation (default=do not augment data)
```
Example training command line:
......@@ -51,6 +54,8 @@ Example training command line:
python trainer.py --gpus=-1 --name test1 --train_filename ../scrappers/data/20200529/litcovid.json
```
Logs are saved in `lightning_logs/`, best `val_loss` checkpoints in `checkpoints/`.
pytorch-lightning provides a tensorboard logger. You can check it with
```
tensorboard --logdir lightning_logs
......@@ -60,6 +65,9 @@ Then point your browser to http://localhost:6006/.
Generating predictions
----------------------
Give as input a json file containing articles without a "topics" field. It will be added with predictions.
```
predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json
python predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json
```
Note that at this time, this only works with original training data available in the same relative path as was used for training.
......@@ -119,7 +119,7 @@ class Model(LightningModule):
parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in % (default=10)')
parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)')
parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed')
parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment