Select Git revision
ExecClassifMonoView.py
-
Baptiste Bauvin authored
predict.py 1.75 KiB
from argparse import ArgumentParser
import json
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader
import torch
import data
from model import Model
def _generate_predictions(model, loader):
    """Run ``model`` over every batch of ``loader`` and collect its outputs.

    Args:
        model: Callable module; switched to eval mode and invoked as
            ``model(x)`` on the input part of each batch.
        loader: Iterable yielding ``(x, y)`` pairs. Only ``x`` is used.

    Returns:
        A list built by extending with ``model(x).cpu().tolist()`` for each
        batch, i.e. one entry per leading-dimension row of the model output.
    """
    model.eval()
    predictions = []
    # Inference only, so gradient tracking is disabled for the whole pass.
    # Batches are fed to the model as the loader produced them; nothing is
    # moved to another device here.
    with torch.no_grad():
        for x, _ in loader:
            predictions.extend(model(x).cpu().tolist())
    return predictions


def main(hparams):
    """Score articles with a saved model and print them as JSON on stdout.

    The checkpoint is loaded and frozen, the JSON file is read, the articles
    are converted with ``data.to_int`` and batched through the model. Each
    article then receives two new keys: ``'topic-scores'`` (label -> score)
    and ``'topic-pred'`` (labels whose score is >= 0).

    Args:
        hparams: Namespace providing ``checkpoint`` (path given to
            ``Model.load_from_checkpoint``), ``test_filename`` (path of the
            JSON file to annotate) and ``batch_size`` (DataLoader batch size).
    """
    model = Model.load_from_checkpoint(hparams.checkpoint)
    model.freeze()

    # Order labels by the value they map to in the vocabulary so that they
    # line up positionally with the model's score vector.
    # NOTE(review): assumes vocabulary values are the output indices - confirm.
    label_vocab = model.label_vocab
    sorted_labels = sorted(label_vocab, key=label_vocab.get)

    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(hparams.test_filename, encoding="utf-8") as fp:
        dataset = json.load(fp)

    # NOTE(review): only the first 32 entries are processed. This looks like
    # a debugging leftover; kept as-is so the script's output is unchanged.
    dataset = dataset[:32]

    int_texts, int_labels = data.to_int(
        model.tokenizer, model.label_vocab, model.hparams, dataset
    )
    test_set = data.CustomDataset(int_texts, int_labels)
    # shuffle=False keeps predictions in the same order as ``dataset``.
    test_loader = DataLoader(
        test_set,
        batch_size=hparams.batch_size,
        collate_fn=model.collate_fn,
        shuffle=False,
    )

    predictions = _generate_predictions(model, test_loader)

    for i, article in enumerate(dataset):
        scores = predictions[i]
        article['topic-scores'] = dict(zip(sorted_labels, scores))
        # A label is predicted when its score is non-negative.
        # NOTE(review): presumably scores are logits, making 0 the 0.5
        # probability boundary - confirm against the model.
        article['topic-pred'] = [
            label for label, score in zip(sorted_labels, scores) if score >= 0
        ]

    print(json.dumps(dataset, indent=2))
if __name__ == '__main__':
    # Command-line entry point: parse the options and run the prediction.
    # Built-in help is left enabled (the argparse default) so that
    # ``-h/--help`` prints usage instead of failing as an unknown argument.
    parser = ArgumentParser(
        description='Annotate a JSON file of articles with model predictions '
                    'and print the result to stdout.'
    )
    parser.add_argument('--checkpoint', type=str, required=True,
                        help='path of the model checkpoint to load')
    parser.add_argument('--test_filename', type=str, required=True,
                        help='path of the JSON file containing the articles')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='number of examples per batch (default: 32)')
    hparams = parser.parse_args()
    main(hparams)