
topic-classifier


Multilabel topic classifier for medical articles.

This system learns a topic classifier for articles labelled with multiple topics. The included model uses a variant of BERT pre-trained on medical texts and fine-tunes it on task instances.
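Conceptually, this is a pretrained encoder with one sigmoid output per topic, where each topic is an independent binary decision. A minimal sketch of such a multilabel head (illustrative only; the class and argument names are assumptions, not the project's actual code):

import torch
from torch import nn
from transformers import AutoModel

class MultilabelClassifier(nn.Module):
    # Hypothetical sketch: BERT encoder + linear multilabel head.
    def __init__(self, num_topics, flavor="monologg/biobert_v1.1_pubmed", dropout=0.1):
        super().__init__()
        self.bert = AutoModel.from_pretrained(flavor)
        self.dropout = nn.Dropout(dropout)
        # one logit per topic; the sigmoid is folded into the loss below
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_topics)

    def forward(self, input_ids, attention_mask):
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]  # [CLS] token as document embedding
        return self.classifier(self.dropout(cls))

# multilabel training treats each topic as a separate binary target
loss_fn = nn.BCEWithLogitsLoss()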

Data

Input data is expected to be a JSON-formatted file containing a list of articles. Each article should have a title, an abstract, and a topics field containing a list of topics. Other fields are ignored.

[
  {
    "title": "this is a title",
    "abstract": "this is an abstract",
    "topics": ["topic1", "topic2", "topic3"]
    ...
  },
  ...
]
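Loading this format can be as simple as the following sketch (the function name and signature are illustrative, not the project's API):

import json

def load_articles(path, features=("title", "abstract")):
    # Illustrative loader: concatenate the selected text fields,
    # keep the topic lists as the multilabel targets.
    with open(path) as f:
        articles = json.load(f)
    texts = [" ".join(article.get(feature, "") for feature in features)
             for article in articles]
    topics = [article["topics"] for article in articles]
    return texts, topics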

Installing

virtualenv -p python3 env
source env/bin/activate
pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html

If updates have broken the dependencies, you may use requirements-freeze.txt instead of requirements.txt. Note that we use PyTorch with CUDA 10.1; you may edit requirements.txt to use a different version depending on your setup (see https://pytorch.org/get-started/locally/).
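To check that the installed PyTorch build actually sees your GPU, you can run something like:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"

If this prints False on a CUDA machine, the installed wheel likely does not match your CUDA version.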

Training

usage: python trainer.py --name <name> --train_filename <path> [options]

optional arguments:
  -h, --help            show this help message and exit
  --gpus GPUS           IDs of GPUs to use (use -1 for all available GPUs; defaults to CPU)
  --nodes NODES         number of computation nodes for distributed training (see the Lightning docs; defaults to 1)
  --name NAME           name of the experiment (required)
  --fast_dev_run        run a single batch through the whole training loop to catch bugs
  --seed SEED           global random seed (default=123)
  --stem STEM           stem name of the json files containing training/validation/test instances (<stem>.{train,valid,test})
  --learning_rate LEARNING_RATE
                        learning rate (default=2e-5)
  --batch_size BATCH_SIZE
                        batch size (default=32)
  --epochs EPOCHS       number of epochs (default=20)
  --valid_size_percent VALID_SIZE_PERCENT
                        validation set size in percent (default=10)
  --max_len MAX_LEN     maximum sequence length (default=256)
  --selected_features SELECTED_FEATURES [SELECTED_FEATURES ...]
                        list of features to load from the input (default=title abstract)
  --dropout DROPOUT     dropout applied after BERT
  --loss LOSS           loss function [f1, bce] (default=bce)
  --augment_data        simulate missing abstracts through data augmentation (default=do not augment)
  --transfer TRANSFER   transfer weights from a checkpoint (default=do not transfer)
  --model MODEL         model type [rnn, cnn, bert] (default=bert)
  --bert_flavor BERT_FLAVOR
                        pretrained BERT model (default=monologg/biobert_v1.1_pubmed)
  --rnn_embed_size RNN_EMBED_SIZE
                        RNN embedding size (default=128)
  --rnn_hidden_size RNN_HIDDEN_SIZE
                        RNN hidden size (default=128)
  --rnn_layers RNN_LAYERS
                        number of RNN layers (default=1)
  --cnn_embed_size CNN_EMBED_SIZE
                        CNN embedding size (default=128)
  --cnn_hidden_size CNN_HIDDEN_SIZE
                        CNN hidden size (default=128)
  --cnn_layers CNN_LAYERS
                        number of CNN layers (default=1)
  --cnn_kernel_size CNN_KERNEL_SIZE
                        CNN kernel size (default=3)
  --scheduler SCHEDULER
                        learning rate schedule [warmup_linear] (default=fixed learning rate)
  --scheduler_warmup SCHEDULER_WARMUP
                        number of warmup epochs for the schedule (default=1)

Example training command line:

python trainer.py --gpus=-1 --name test1 --stem ../scrappers/data/20200615/folds/litcovid-0
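The --loss f1 option suggests a differentiable ("soft") F1 objective, in which predicted probabilities stand in for hard 0/1 decisions. A minimal sketch of one common formulation (an assumption; the project's actual implementation may differ):

import torch

def soft_f1_loss(logits, targets, eps=1e-8):
    # Soft counts: probabilities replace hard predictions, so the
    # loss stays differentiable with respect to the logits.
    probs = torch.sigmoid(logits)
    tp = (probs * targets).sum(dim=0)
    fp = (probs * (1 - targets)).sum(dim=0)
    fn = ((1 - probs) * targets).sum(dim=0)
    f1 = 2 * tp / (2 * tp + fp + fn + eps)
    return 1 - f1.mean()  # minimize 1 - F1, averaged over topics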

Logs are saved in logs/. For each experiment there is a run.json file with the hyperparameters and per-epoch metrics, along with two checkpoints: the best and the last. The best checkpoint is used for testing.

The logger provides a simplified TensorBoard-like facility. Run it with

python logger.py

Then point your browser to http://localhost:6006/.

Generating predictions

Provide as input a JSON file containing articles without a "topics" field; the field will be added with the predicted topics.

python predict.py --checkpoint checkpoints/epoch\=0-val_loss\=0.2044.ckpt --test_filename ../scrappers/data/20200529/cord19-metadata.json > predicted.json

Note that, at this time, prediction only works when the original training data is available at the same relative path that was used for training.
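Once predicted.json is written, downstream inspection is straightforward. For example, counting the most frequent predicted topics (assuming the output schema matches the training data format described above):

import json
from collections import Counter

with open("predicted.json") as f:
    articles = json.load(f)

# tally how often each predicted topic occurs across the corpus
counts = Counter(topic for article in articles for topic in article["topics"])
for topic, n in counts.most_common(10):
    print(topic, n)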