Evaluator
evaluator
class machine.evaluator.evaluator.Evaluator(loss=[<machine.loss.loss.NLLLoss object>], metrics=[<machine.metrics.metrics.WordAccuracy object>, <machine.metrics.metrics.SequenceAccuracy object>])
Class to evaluate models on given datasets.
Parameters: - loss (list of machine.loss.Loss, optional) – losses for the evaluator to compute (default: [machine.loss.NLLLoss])
- metrics (list, optional) – metrics from machine.metrics for the evaluator to compute (default: [machine.metrics.WordAccuracy, machine.metrics.SequenceAccuracy])
compute_batch_loss(decoder_outputs, decoder_hidden, other, target_variable)
Compute the losses for the current batch.
Parameters: - decoder_outputs (torch.Tensor) – decoder outputs of a batch
- decoder_hidden (torch.Tensor) – decoder hidden states for a batch
- other (dict) – maps extra outputs to torch.Tensors
- target_variable (dict) – map of keys to different targets
Returns: losses, a list of machine.loss.Loss objects
Return type: list
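By default the evaluator uses machine.loss.NLLLoss. To illustrate the kind of quantity a batch loss measures, the following plain-Python sketch computes the mean negative log-likelihood of the target tokens in one batch. It does not call machine or PyTorch: the function name batch_nll, the nested-list layout standing in for decoder_outputs, and the pad_index convention are assumptions made for this example, and machine's own NLLLoss may normalize differently.

```python
import math
from typing import Sequence


def batch_nll(step_log_probs: Sequence[Sequence[Sequence[float]]],
              targets: Sequence[Sequence[int]],
              pad_index: int = -1) -> float:
    """Mean negative log-likelihood of the gold tokens in one batch (hypothetical helper).

    step_log_probs[b][t][v] is the log-probability assigned to vocabulary
    item v at step t of sequence b (a list-based stand-in for decoder_outputs).
    targets[b][t] is the gold token index; positions equal to pad_index are ignored.
    """
    total, count = 0.0, 0
    for seq_log_probs, seq_targets in zip(step_log_probs, targets):
        for log_probs, gold in zip(seq_log_probs, seq_targets):
            if gold == pad_index:
                continue  # padding does not contribute to the loss
            total -= log_probs[gold]
            count += 1
    return total / count if count else 0.0


# One sequence, two decoding steps, vocabulary of size 2.
log_probs = [[[math.log(0.5), math.log(0.5)],
              [math.log(0.25), math.log(0.75)]]]
print(batch_nll(log_probs, [[0, 1]]))  # mean of -log(0.5) and -log(0.75)
```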
evaluate(model, data_iterator, get_batch_data)
Evaluate a model on a given dataset and return its performance.
Parameters: - model (machine.models) – model to evaluate
- data_iterator (torchtext.data.Iterator) – data iterator to evaluate against
- get_batch_data (callable) – function that extracts the model inputs and targets from a batch
Returns: - loss (float) – loss of the given model on the given dataset
- accuracy (float) – accuracy of the given model on the given dataset
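The evaluation procedure described above can be pictured as a loop over the data iterator that feeds every batch through the model and updates each loss and metric. The sketch below is plain Python with hypothetical stand-ins: PooledScore, evaluate_sketch, the (inputs, targets) pair returned by get_batch_data, and the reset/update/value method names are assumptions, not machine's API. In PyTorch such a loop would normally run with the model in eval mode and inside torch.no_grad().

```python
from typing import Callable, Iterable, List, Sequence, Tuple


class PooledScore:
    """Hypothetical tracker: pools (total, count) pairs over batches."""

    def __init__(self, score: Callable[[Sequence, Sequence], Tuple[float, int]]):
        self.score = score  # returns (summed score, number of items) for one batch
        self.reset()

    def reset(self) -> None:
        self.total, self.count = 0.0, 0

    def update(self, outputs: Sequence, targets: Sequence) -> None:
        batch_total, batch_count = self.score(outputs, targets)
        self.total += batch_total
        self.count += batch_count

    def value(self) -> float:
        return self.total / self.count if self.count else 0.0


def evaluate_sketch(model: Callable, data_iterator: Iterable, get_batch_data: Callable,
                    losses: List[PooledScore],
                    metrics: List[PooledScore]) -> Tuple[List[float], List[float]]:
    # Clear state left over from any earlier evaluation.
    for tracker in losses + metrics:
        tracker.reset()
    for batch in data_iterator:
        inputs, targets = get_batch_data(batch)  # split the batch
        outputs = model(inputs)                  # forward pass only, no parameter updates
        for metric in metrics:                   # role of update_batch_metrics
            metric.update(outputs, targets)
        for loss in losses:                      # role of update_loss
            loss.update(outputs, targets)
    return [loss.value() for loss in losses], [metric.value() for metric in metrics]


# Toy "model" that predicts the parity of each input.
def parity_model(xs):
    return [x % 2 for x in xs]


batches = [([1, 2, 3], [1, 0, 0]), ([4, 5], [0, 1])]
accuracy = PooledScore(lambda out, gold: (sum(o == g for o, g in zip(out, gold)), len(gold)))
error_rate = PooledScore(lambda out, gold: (sum(o != g for o, g in zip(out, gold)), len(gold)))
loss_values, metric_values = evaluate_sketch(parity_model, batches, lambda b: b,
                                             [error_rate], [accuracy])
print(loss_values, metric_values)  # [0.2] [0.8]
```

Because every tracker is reset at the start, calling evaluate_sketch twice on the same data gives the same numbers instead of accumulating across runs.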
static update_batch_metrics(metrics, other, target_variable)
Update a list of metrics for the current batch.
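Parameters: - metrics (list) – a list of metric objects (e.g. machine.metrics.WordAccuracy)
- other (dict) – maps extra outputs to torch.Tensors
- target_variable (dict) – map of keys to different targets
Returns: metrics, the list with updated metrics
Return type: list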
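The default metrics are machine.metrics.WordAccuracy and machine.metrics.SequenceAccuracy. The sketch below shows one plausible reading of those two names and of how a per-batch update could accumulate them; it is not machine's implementation. The class names, the reset/update/value methods, the PAD value, and the choice to pass predicted token indices directly (the documented method receives the model's extra outputs in other) are assumptions.

```python
from typing import List, Sequence

PAD = -1  # hypothetical padding index in the targets


class WordAccuracySketch:
    """Share of non-padding target tokens that were predicted correctly."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct, self.total = 0, 0

    def update(self, predictions: Sequence[Sequence[int]],
               targets: Sequence[Sequence[int]]) -> None:
        for pred_seq, gold_seq in zip(predictions, targets):
            for pred, gold in zip(pred_seq, gold_seq):  # assumes equal-length sequences
                if gold != PAD:
                    self.correct += int(pred == gold)
                    self.total += 1

    def value(self) -> float:
        return self.correct / self.total if self.total else 0.0


class SequenceAccuracySketch:
    """Share of sequences in which every non-padding token is correct."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct, self.total = 0, 0

    def update(self, predictions: Sequence[Sequence[int]],
               targets: Sequence[Sequence[int]]) -> None:
        for pred_seq, gold_seq in zip(predictions, targets):
            self.correct += int(all(p == g for p, g in zip(pred_seq, gold_seq) if g != PAD))
            self.total += 1

    def value(self) -> float:
        return self.correct / self.total if self.total else 0.0


def update_batch_metrics_sketch(metrics: List, predictions, targets) -> List:
    # Each metric folds the current batch into its running counts.
    for metric in metrics:
        metric.update(predictions, targets)
    return metrics


metrics = [WordAccuracySketch(), SequenceAccuracySketch()]
update_batch_metrics_sketch(metrics, [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 0, 6]])
update_batch_metrics_sketch(metrics, [[7, 8]], [[7, PAD]])
print([m.value() for m in metrics])
```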
static update_loss(losses, decoder_outputs, decoder_hidden, other, target_variable)
Update a list of losses for the current batch.
Parameters: - losses (list) – a list with machine.loss.Loss objects
- decoder_outputs (torch.Tensor) – decoder outputs of a batch
- decoder_hidden (torch.Tensor) – decoder hidden states for a batch
- other (dict) – maps extra outputs to torch.Tensors
- target_variable (dict) – map of keys to different targets
Returns: losses, a list of machine.loss.Loss objects
Return type: list
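One way to read update_loss is that each loss object keeps running totals that every batch adds to, so the value reported after the last batch covers the whole dataset. The sketch below assumes each loss pools a summed per-token negative log-likelihood and a token count; PooledNLL, update_loss_sketch, and the list-of-probabilities input are hypothetical and not machine's API.

```python
import math
from typing import List, Sequence


class PooledNLL:
    """Hypothetical loss tracker: summed token NLL and token count across batches."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.acc_loss, self.n_tokens = 0.0, 0

    def update(self, step_probs: Sequence[Sequence[Sequence[float]]],
               targets: Sequence[Sequence[int]]) -> None:
        # step_probs[b][t][v]: probability of token v at step t of sequence b
        for seq_probs, seq_targets in zip(step_probs, targets):
            for probs, gold in zip(seq_probs, seq_targets):
                self.acc_loss -= math.log(probs[gold])
                self.n_tokens += 1

    def value(self) -> float:
        # Per-token mean over everything seen since the last reset.
        return self.acc_loss / self.n_tokens if self.n_tokens else 0.0


def update_loss_sketch(losses: List[PooledNLL], step_probs, targets) -> List[PooledNLL]:
    # Each loss folds the current batch into its running totals.
    for loss in losses:
        loss.update(step_probs, targets)
    return losses


losses = [PooledNLL()]
update_loss_sketch(losses, [[[0.5, 0.5]]], [[0]])  # 1 token with probability 0.5
update_loss_sketch(losses, [[[1.0, 0.0], [0.0, 1.0], [1.0, 0.0]]], [[0, 1, 0]])  # 3 tokens with probability 1.0
print(losses[0].value())  # log(2) / 4, about 0.1733
```

Pooling sums and counts weights every token equally. Averaging the two per-batch means instead would give log(2) / 2, letting the one-token batch count as much as the three-token batch.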