# Source code for machine.evaluator.predictor

import torch

# Run on the GPU whenever PyTorch reports one is usable; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


class Predictor(object):
    """Run a trained sequence model on a single source sequence of tokens."""

    def __init__(self, model, src_vocab, tgt_vocab):
        """
        Predictor class to evaluate for a given model.

        Args:
            model (machine.models): trained model. This can be loaded from a
                checkpoint using `machine.util.checkpoint.load`
            src_vocab (machine.dataset.vocabulary.Vocabulary): source sequence vocabulary
            tgt_vocab (machine.dataset.vocabulary.Vocabulary): target sequence vocabulary
        """
        # Move the model to the module-level `device` once and switch it to
        # evaluation mode, since this object is only used for inference.
        self.model = model.to(device)
        self.model.eval()
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab

    def predict(self, src_seq):
        """
        Make prediction given `src_seq` as input.

        Args:
            src_seq (list): list of tokens in source language

        Returns:
            tgt_seq (list): list of tokens in target language as predicted
                by the pre-trained model
        """
        # Map tokens to ids and shape them as a batch of one: (1, len(src_seq)).
        src_id_seq = torch.tensor(
            [self.src_vocab.stoi[tok] for tok in src_seq],
            dtype=torch.long,
            device=device,
        ).view(1, -1)

        # Pure inference: do not record an autograd graph.
        with torch.no_grad():
            _, _, other = self.model(src_id_seq, [len(src_seq)])

        # Only the first (and only) batch element is used. int() turns a
        # possible 0-d tensor into a plain Python int.
        length = int(other['length'][0])
        # NOTE(review): each other['sequence'][di] is indexed as [batch][0],
        # i.e. presumably shaped (batch, 1) -- confirm against the model.
        tgt_id_seq = [int(other['sequence'][di][0][0]) for di in range(length)]
        return [self.tgt_vocab.itos[tok] for tok in tgt_id_seq]