Commit 6e9186d7 authored by Benoit Favre (parent d2bf6dfb): update doc
Training
--------
```
usage: python trainer.py --name <name> --train_filename <path> [options]

optional arguments:
  -h, --help                  show this help message and exit
  --gpus <int>                list of gpus to use (-1 = all in CUDA_VISIBLE_DEVICES)
  --nodes <int>               number of nodes for distributed training (see the pytorch_lightning doc)
  --name <str>                experiment name
  --fast_dev_run              run a single batch to check that training works
  --train_filename <path>     name of the json file containing training/validation instances
  --learning_rate <float>     learning rate (default=2e-5)
  --batch_size <int>          batch size (default=32)
  --epochs <int>              number of epochs (default=20)
  --valid_size_percent <int>  validation set size in % (default=10)
  --max_len <int>             max sequence length (default=256)
  --bert_flavor <path>        pretrained bert model (default=monologg/biobert_v1.1_pubmed)
  --selected_features <list>  list of features to load from the input (default=title abstract)
  --dropout <float>           dropout applied after bert
  --loss <bce|f1>             loss function, bce or f1 (default=f1)
  --augment_data              simulate missing abstracts through data augmentation (default=no augmentation)
```
Example training command line:
```
python trainer.py --gpus=-1 --name test1 --train_filename ../scrappers/data/20200529/litcovid.json
```
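The exact schema of the training file is defined by `trainer.py` and is not spelled out here; as a minimal sketch, assuming each instance is a json object carrying the selected features plus a `topics` label list (the field names `title`, `abstract`, and `topics` follow the `--selected_features` defaults and are an assumption, not the documented format):

```python
# Hypothetical minimal training file; the authoritative schema is whatever
# trainer.py actually loads. Field names are assumptions based on the
# --selected_features defaults (title, abstract) plus a "topics" label list.
import json

instances = [
    {"title": "Example title", "abstract": "Example abstract.", "topics": ["Treatment"]},
    {"title": "Second title", "abstract": "Second abstract.", "topics": ["Diagnosis", "Prevention"]},
]
with open("litcovid_sample.json", "w") as f:
    json.dump(instances, f, indent=2)
```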
Logs are saved in `lightning_logs/`; the checkpoints with the best `val_loss` are saved in `checkpoints/`.
pytorch-lightning provides a tensorboard logger. You can check it with
```
tensorboard --logdir lightning_logs
```
Then point your browser to http://localhost:6006/.
Generating predictions
----------------------
Give as input a json file containing articles without a "topics" field; the field will be added with the predicted topics.
```
python predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json
```
Note that at this time, this only works when the original training data is available at the same relative path as was used for training.
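A small sketch of downstream use, assuming `predicted.json` is a json array of article objects each carrying the added `topics` list (the output format is an assumption; the inline sample below is fabricated for illustration):

```python
# Tally predicted topics from predict.py's output. The inline sample stands in
# for predicted.json; the real output format is assumed to be a json array of
# article objects with a "topics" list added by the model.
import json
from collections import Counter

sample = '''[
  {"title": "a", "topics": ["Treatment"]},
  {"title": "b", "topics": ["Treatment", "Diagnosis"]}
]'''
articles = json.loads(sample)
counts = Counter(topic for article in articles for topic in article.get("topics", []))
print(counts.most_common())
```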
In `trainer.py`, `class Model(LightningModule)`:
parser.add_argument('--learning_rate', default=2e-5, type=float, help='learning rate (default=2e-5)')
parser.add_argument('--batch_size', default=32, type=int, help='size of batch (default=32)')
parser.add_argument('--epochs', default=20, type=int, help='number of epochs (default=20)')
parser.add_argument('--valid_size_percent', default=10, type=int, help='validation set size in %% (default=10)')
parser.add_argument('--max_len', default=256, type=int, help='max sequence length (default=256)')
parser.add_argument('--bert_flavor', default='monologg/biobert_v1.1_pubmed', type=str, help='pretrained bert model (default=monologg/biobert_v1.1_pubmed)')
parser.add_argument('--selected_features', default=['title', 'abstract'], nargs='+', type=str, help='list of features to load from input (default=title abstract)')
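The `%%` in the `--valid_size_percent` help string matters: argparse renders help text with old-style `%` interpolation, so a bare `%` makes `--help` raise. A minimal demonstration:

```python
# argparse formats help strings with %-interpolation (against a dict of the
# action's attributes), so a literal percent sign must be escaped as %%.
import argparse

parser = argparse.ArgumentParser(prog='trainer.py')
parser.add_argument('--valid_size_percent', default=10, type=int,
                    help='validation set size in %% (default=10)')
# The escaped %% renders as a single literal % in the help output.
assert 'validation set size in % (default=10)' in parser.format_help()

# With an unescaped %, rendering the help text fails:
bad = argparse.ArgumentParser(prog='trainer.py')
bad.add_argument('--valid_size_percent', help='validation set size in % (default=10)')
try:
    bad.format_help()
    raise AssertionError('expected a ValueError from %-interpolation')
except ValueError:
    pass
```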