Topic classifier for biomedical articles
Multilabel topic classifier for medical articles.
This system trains a topic classifier for articles labelled with multiple topics. The included model uses a variant of BERT pre-trained on medical texts and fine-tunes it on the task instances.
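Conceptually, this amounts to a pretrained encoder followed by a multilabel head with one sigmoid output per topic. Below is a minimal sketch of that idea; it is illustrative only, and names such as TopicClassifier are not taken from this repository:

import torch
from torch import nn
from transformers import AutoModel

class TopicClassifier(nn.Module):
    """Illustrative multilabel head on top of a pretrained BERT encoder."""

    def __init__(self, num_topics, bert_flavor="monologg/biobert_v1.1_pubmed", dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(bert_flavor)
        self.dropout = nn.Dropout(dropout)
        # One logit per topic; sigmoids (not a softmax) because topics are not mutually exclusive.
        self.head = nn.Linear(self.encoder.config.hidden_size, num_topics)

    def forward(self, input_ids, attention_mask):
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
        pooled = hidden[:, 0]  # [CLS] token representation
        return self.head(self.dropout(pooled))  # train with e.g. BCEWithLogitsLoss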
Data
Input data is expected to be a JSON-formatted file containing a list of articles. Each article should have a title, an abstract, and a topics field containing a list of topics. Other fields are ignored.
[
  {
    "title": "this is a title",
    "abstract": "this is an abstract",
    "topics": ["topic1", "topic2", "topic3"]
    ...
  },
  ...
]
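For reference, a minimal sketch of reading such a file and checking the expected fields (illustrative only; the trainer has its own loading code, and the file name here is hypothetical):

import json

with open("articles.json") as f:
    articles = json.load(f)

for article in articles:
    assert "title" in article and "abstract" in article
    # "topics" is a list of strings; it may be absent for prediction inputs.
    topics = article.get("topics", [])
    print(article["title"], topics)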
Installing
virtualenv -p python3 env
source env/bin/activate
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
If updates have broken the dependencies, you can use requirements-freeze.txt instead of requirements.txt. Note that we use PyTorch with CUDA 10.1; you may edit requirements.txt to use a different version depending on your setup (see https://pytorch.org/get-started/locally/).
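For example, switching to CUDA 10.2 builds would mean changing the torch pin in requirements.txt to something like the line below (a hypothetical version string; check the PyTorch wheel index for what is actually available):

torch==1.5.1+cu102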
Training
usage: python trainer.py --name <name> --stem <stem> [options]
optional arguments:
-h, --help show this help message and exit
--gpus GPUS ids of GPUs to use (use -1 for all available GPUs, defaults to CPU)
--nodes NODES number of computation nodes for distributed training (see lightning docs, defaults to 1)
--name NAME name of experiment (required)
--fast_dev_run run a single batch through the whole training loop for catching bugs
--seed SEED set global random seed (defaults to 123)
--stem STEM stem name of json files containing training/validation/test instances (<stem>.{train,valid,test})
--learning_rate LEARNING_RATE learning rate (default=2e-5)
--batch_size BATCH_SIZE size of batch (default=32)
--epochs EPOCHS number of epochs (default=20)
--valid_size_percent VALID_SIZE_PERCENT validation set size in % (default=10)
--max_len MAX_LEN max sequence length (default=256)
--selected_features SELECTED_FEATURES [SELECTED_FEATURES ...] list of features to load from input (default=title abstract)
--dropout DROPOUT dropout after bert
--loss LOSS choose loss function [f1, bce] (default=bce)
--augment_data simulate missing abstract through augmentation (default=do not augment data)
--transfer TRANSFER transfer weights from checkpoint (default=do not transfer)
--model MODEL model type [rnn, cnn, bert] (default=bert)
--bert_flavor BERT_FLAVOR pretrained bert model (default=monologg/biobert_v1.1_pubmed)
--rnn_embed_size RNN_EMBED_SIZE rnn embedding size (default=128)
--rnn_hidden_size RNN_HIDDEN_SIZE rnn hidden size (default=128)
--rnn_layers RNN_LAYERS rnn number of layers (default=1)
--cnn_embed_size CNN_EMBED_SIZE cnn embedding size (default=128)
--cnn_hidden_size CNN_HIDDEN_SIZE cnn hidden size (default=128)
--cnn_layers CNN_LAYERS cnn number of layers (default=1)
--cnn_kernel_size CNN_KERNEL_SIZE cnn kernel size (default=3)
--scheduler SCHEDULER learning rate schedule [warmup_linear] (default=fixed learning rate)
--scheduler_warmup SCHEDULER_WARMUP learning rate schedule warmup epochs (default=1)
Example training command line:
python trainer.py --gpus=-1 --name test1 --stem ../scrappers/data/20200615/folds/litcovid-0
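The remaining flags compose in the same way. For instance, a hypothetical run training the CNN variant with data augmentation and the f1 loss on the same data:

python trainer.py --gpus=-1 --name cnn-aug --stem ../scrappers/data/20200615/folds/litcovid-0 --model cnn --augment_data --loss f1 --epochs 30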
Logs are saved in logs/. For each experiment, there is a run.json file with the hyperparameters and the metrics at each epoch, plus two checkpoints: the best and the last. The best checkpoint is used for testing.
The logger provides a simplified TensorBoard-like facility. Run it with
python logger.py
Then point your browser to http://localhost:6006/.
Generating predictions
Give as input a JSON file containing articles without a "topics" field. The field will be added to each article with the predicted topics.
python predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json
Note that, at this time, prediction only works if the original training data is available at the same relative path that was used during training.
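A minimal sketch of inspecting the resulting predicted.json, assuming the output is a JSON list of articles in the same format as the training data with the "topics" field filled in by the model:

import json
from collections import Counter

with open("predicted.json") as f:
    articles = json.load(f)

# Tally how often each predicted topic appears across the articles.
counts = Counter(topic for article in articles for topic in article.get("topics", []))
for topic, n in counts.most_common():
    print(f"{n:6d}  {topic}")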