Skip to content
Snippets Groups Projects
Select Git revision
  • 92baf4b9f1ce3072deb3f974c070c9e08ad077d6
  • master default
  • object
  • develop protected
  • private_algos
  • cuisine
  • SMOTE
  • revert-76c4cca5
  • archive protected
  • no_graphviz
  • 0.0.1
11 results

ExecClassifMonoView.py

Blame
  • predict.py 1.75 KiB
    from argparse import ArgumentParser
    import json
    
    from pytorch_lightning import LightningModule
    from torch.utils.data import DataLoader
    import torch
    
    import data
    from model import Model
    
    def main(hparams):
      """Score the articles of a JSON file with a trained model and print them.

      Loads the checkpoint named by ``hparams.checkpoint``, reads the JSON
      document at ``hparams.test_filename``, runs the model over it in batches
      of ``hparams.batch_size`` and prints the annotated articles to stdout as
      indented JSON.

      Each article gains two keys:
        * ``'topic-scores'``: mapping from label to the model's score.
        * ``'topic-pred'``: labels whose score is ``>= 0``.

      Args:
        hparams: namespace providing ``checkpoint``, ``test_filename`` and
          ``batch_size`` attributes.
      """
      model = Model.load_from_checkpoint(hparams.checkpoint)
      model.freeze()

      # Order labels by their vocabulary value so they line up with the
      # positions of the model's output scores.
      label_vocab = model.label_vocab
      sorted_labels = sorted(label_vocab, key=label_vocab.get)

      # JSON is UTF-8 by specification; do not rely on the platform default.
      with open(hparams.test_filename, encoding='utf-8') as fp:
        dataset = json.load(fp)

      # NOTE(review): only the first 32 articles are scored. This looks like a
      # debugging leftover -- confirm before relying on full-file predictions.
      dataset = dataset[:32]

      int_texts, int_labels = data.to_int(model.tokenizer, model.label_vocab, model.hparams, dataset)

      test_set = data.CustomDataset(int_texts, int_labels)
      test_loader = DataLoader(test_set, batch_size=hparams.batch_size, collate_fn=model.collate_fn, shuffle=False)

      def generate_predictions(model, loader):
        """Return one list of scores per example, in loader order."""
        predictions = []
        model.eval()
        # Inference only: no gradients are needed for any batch.
        with torch.no_grad():
          for x, _ in loader:
            y_scores = model(x)
            predictions.extend(y_scores.cpu().tolist())
        return predictions

      predictions = generate_predictions(model, test_loader)
      for i, article in enumerate(dataset):
        scores_by_label = dict(zip(sorted_labels, predictions[i]))
        article['topic-scores'] = scores_by_label
        # Presumably the scores are logits, so 0 is the decision boundary
        # (probability 0.5) -- TODO confirm against the model definition.
        article['topic-pred'] = [label for label, score in scores_by_label.items() if score >= 0]

      print(json.dumps(dataset, indent=2))
    
    
    if __name__ == '__main__':
      # Command-line entry point: collect the options and hand them to main().
      # The automatic -h/--help flag is disabled (add_help=False); a --gpus
      # option is not currently exposed.
      cli = ArgumentParser(add_help=False)
      option_specs = (
        ('--checkpoint', {'type': str, 'required': True}),
        ('--test_filename', {'type': str, 'required': True}),
        ('--batch_size', {'type': int, 'default': 32}),
      )
      for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

      main(cli.parse_args())