from __future__ import print_function, division
import copy
import torch
import torchtext
from machine.loss import NLLLoss
from machine.metrics import WordAccuracy, SequenceAccuracy
# Module-wide default device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class Evaluator(object):
    """Evaluate a model on a dataset, accumulating losses and metrics.

    Args:
        loss (list, optional): list of machine.loss objects to accumulate
            during evaluation (default: ``[NLLLoss()]``)
        metrics (list, optional): list of machine.metrics objects to
            accumulate during evaluation (default:
            ``[WordAccuracy(), SequenceAccuracy()]``)
    """

    def __init__(self, loss=None, metrics=None):
        # The defaults are built per instance rather than in the signature:
        # loss and metric objects are reset and updated in place, so a default
        # list created once at definition time would be shared (and mutated)
        # by every Evaluator that relies on it.
        self.losses = [NLLLoss()] if loss is None else loss
        self.metrics = ([WordAccuracy(), SequenceAccuracy()]
                        if metrics is None else metrics)

    @staticmethod
    def update_batch_metrics(metrics, other, target_variable):
        """
        Update a list of metrics with the results of the current batch.

        Args:
            metrics (list): list of machine.metrics objects
            other (dict): dict generated by the forward pass of the model to
                be evaluated; must contain the key ``'sequence'``
            target_variable (dict): map of keys to different targets of model

        Returns:
            metrics (list): the same list, with every metric updated in place
        """
        # The metrics are evaluated on the output symbols of the model.
        outputs = other['sequence']

        for metric in metrics:
            metric.eval_batch(outputs, target_variable)

        return metrics

    def compute_batch_loss(self, decoder_outputs,
                           decoder_hidden, other, target_variable):
        """
        Compute the loss for the current batch only.

        The evaluator's own loss objects (``self.losses``) are reset and then
        updated in place with this single batch.

        Args:
            decoder_outputs (torch.Tensor): decoder outputs of a batch
            decoder_hidden (torch.Tensor): decoder hidden states for a batch
            other (dict): maps extra outputs to torch.Tensors
            target_variable (dict): map of keys to different targets

        Returns:
            losses (list): ``self.losses``, a list with machine.loss objects
        """
        losses = self.losses
        for loss in losses:
            loss.reset()

        return self.update_loss(
            losses, decoder_outputs, decoder_hidden, other, target_variable)

    @staticmethod
    def update_loss(losses, decoder_outputs,
                    decoder_hidden, other, target_variable):
        """
        Update a list of losses with the results of the current batch.

        Args:
            losses (list): a list with machine.loss objects
            decoder_outputs (torch.Tensor): decoder outputs of a batch
            decoder_hidden (torch.Tensor): decoder hidden states for a batch
                (accepted for interface symmetry; not passed to the losses)
            other (dict): maps extra outputs to torch.Tensors
            target_variable (dict): map of keys to different targets

        Returns:
            losses (list): the same list, with every loss updated in place
        """
        for loss in losses:
            loss.eval_batch(decoder_outputs, other, target_variable)

        return losses

    def evaluate(self, model, data_iterator, get_batch_data):
        """ Evaluate a model on given dataset and return performance.

        The model is switched to eval mode for the duration of the call and
        its previous train/eval mode is restored afterwards, also when an
        exception is raised while processing a batch.

        Args:
            model (machine.models): model to evaluate
            data_iterator (torchtext.data.Iterator): data iterator to evaluate
                against
            get_batch_data (callable): maps a batch from ``data_iterator`` to
                a tuple ``(input_variable, input_lengths, target_variable)``;
                ``input_lengths`` must provide a ``tolist()`` method

        Returns:
            losses (list): deep copies of ``self.losses``, updated with every
                batch of ``data_iterator``
            metrics (list): deep copies of ``self.metrics``, updated with
                every batch of ``data_iterator``

        Raises:
            Warning: if ``data_iterator`` is passed in part-way through its
                iterations
        """
        # Evaluation runs over whole passes of the data, so refuse an iterator
        # that has been partially consumed.
        initial_iteration = data_iterator.iterations
        if initial_iteration > 1 and initial_iteration != len(data_iterator):
            raise Warning("Passed in data_iterator in middle of iterations")

        # Accumulate into fresh copies so that the evaluator's own loss and
        # metric objects are only reset, never updated, by this method.
        for loss in self.losses:
            loss.reset()
        losses = copy.deepcopy(self.losses)

        for metric in self.metrics:
            metric.reset()
        metrics = copy.deepcopy(self.metrics)

        # If the model was in train mode before this method was called, we
        # make sure it still is after this method, even on failure.
        previous_train_mode = model.training
        model.eval()
        try:
            with torch.no_grad():
                # loop over batches
                for batch in data_iterator:
                    input_variable, input_lengths, target_variable = \
                        get_batch_data(batch)

                    decoder_outputs, decoder_hidden, other = model(
                        input_variable, input_lengths.tolist(),
                        target_variable)

                    # Compute metric(s) over one batch
                    metrics = self.update_batch_metrics(
                        metrics, other, target_variable)

                    # Compute loss(es) over one batch
                    losses = self.update_loss(
                        losses, decoder_outputs, decoder_hidden, other,
                        target_variable)
        finally:
            model.train(previous_train_mode)

        return losses, metrics