Source code for machine.models.TopKDecoder

import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _inflate(tensor, times, dim):
    """
    Given a tensor, 'inflates' it along the given dimension by replicating each slice specified number of times (in-place)

    Args:
        tensor: A :class:`Tensor` to inflate
        times: number of repetitions
        dim: axis for inflation (default=0)

    Returns:
        A :class:`Tensor`

    Examples::
        >> a = torch.LongTensor([[1, 2], [3, 4]])
        >> a
        1   2
        3   4
        [torch.LongTensor of size 2x2]
        >> b = ._inflate(a, 2, dim=1)
        >> b
        1   2   1   2
        3   4   3   4
        [torch.LongTensor of size 2x4]
        >> c = _inflate(a, 2, dim=0)
        >> c
        1   2
        3   4
        1   2
        3   4
        [torch.LongTensor of size 4x2]

    """
    repeat_dims = [1] * tensor.dim()
    repeat_dims[dim] = times
    return tensor.repeat(*repeat_dims)
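
# The layout produced by inflation matters for beam search bookkeeping:
# `Tensor.repeat` tiles whole slices, whereas `torch.repeat_interleave`
# repeats each slice consecutively. A short comparison of the two layouts
# (illustrative only)::
#
#     >>> a = torch.tensor([[1, 2], [3, 4]])
#     >>> a.repeat(2, 1)                        # same as _inflate(a, 2, dim=0)
#     tensor([[1, 2],
#             [3, 4],
#             [1, 2],
#             [3, 4]])
#     >>> torch.repeat_interleave(a, 2, dim=0)  # block layout, for contrast
#     tensor([[1, 2],
#             [1, 2],
#             [3, 4],
#             [3, 4]])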


class TopKDecoder(torch.nn.Module):
    r"""
    Top-K decoding with beam search.

    Args:
        decoder_rnn (DecoderRNN): An object of DecoderRNN used for decoding.
        k (int): Size of the beam.

    Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio
        - **inputs** (seq_len, batch, input_size): list of sequences, whose
          length is the batch size and within which each sequence is a list
          of token IDs. It is used for teacher forcing when provided.
          (default: `None`)
        - **encoder_hidden** (batch, seq_len, hidden_size): tensor containing
          the features in the hidden state `h` of the encoder. Used as the
          initial hidden state of the decoder.
        - **encoder_outputs** (batch, seq_len, hidden_size): tensor containing
          the outputs of the encoder. Used for the attention mechanism
          (default: `None`).
        - **function** (torch.nn.Module): A function used to generate symbols
          from the RNN hidden state
          (default: `torch.nn.functional.log_softmax`).
        - **teacher_forcing_ratio** (float): The probability that teacher
          forcing will be used. A random number is drawn uniformly from 0-1
          for every decoding token, and if the sample is smaller than the
          given value, teacher forcing is used (default: 0).

    Outputs: decoder_outputs, decoder_hidden, ret_dict
        - **decoder_outputs** (batch): batch-length list of tensors with size
          (max_length, hidden_size) containing the outputs of the decoder.
        - **decoder_hidden** (num_layers * num_directions, batch, hidden_size):
          tensor containing the last hidden state of the decoder.
        - **ret_dict**: dictionary containing additional information as
          follows {*length*: list of integers representing lengths of output
          sequences, *topk_length*: list of integers representing lengths of
          beam search sequences, *sequence*: list of sequences, where each
          sequence is a list of predicted token IDs, *topk_sequence*: list of
          beam search sequences, each beam is a list of token IDs,
          *inputs*: target outputs if provided for decoding}.
    """

    def __init__(self, decoder_rnn, k):
        super(TopKDecoder, self).__init__()
        self.rnn = decoder_rnn
        self.k = k
        self.hidden_size = self.rnn.hidden_size
        self.V = self.rnn.output_size
        self.SOS = self.rnn.sos_id
        self.EOS = self.rnn.eos_id
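
    # Example usage (a minimal sketch following the docstring above; the
    # DecoderRNN constructor arguments shown are illustrative assumptions,
    # check machine.models.DecoderRNN for the actual signature)::
    #
    #     >>> decoder = DecoderRNN(vocab_size, max_len, hidden_size,
    #     ...                      sos_id=1, eos_id=2)  # hypothetical args
    #     >>> beam_decoder = TopKDecoder(decoder, k=5)
    #     >>> outputs, hidden, meta = beam_decoder(
    #     ...     encoder_hidden=encoder_hidden,
    #     ...     encoder_outputs=encoder_outputs)
    #     >>> meta['topk_sequence']  # k best sequences per batch element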

    def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None,
                function=F.log_softmax, teacher_forcing_ratio=0,
                retain_output_probs=True):
        """
        Forward rnn for MAX_LENGTH steps.
        Look at :func:`machine.models.DecoderRNN.DecoderRNN.forward_rnn`
        for details.
        """
        inputs, batch_size, max_length = self.rnn._validate_args(
            inputs, encoder_hidden, encoder_outputs, function,
            teacher_forcing_ratio)

        self.pos_index = (torch.tensor(range(batch_size), dtype=torch.long,
                                       device=device) * self.k).view(-1, 1)

        # Inflate the initial hidden states to be of size: b*k x h
        encoder_hidden = self.rnn._init_state(encoder_hidden)
        if encoder_hidden is None:
            hidden = None
        else:
            if isinstance(encoder_hidden, tuple):
                hidden = tuple([_inflate(h, self.k, 1)
                                for h in encoder_hidden])
            else:
                hidden = _inflate(encoder_hidden, self.k, 1)

        # ... same idea for encoder_outputs and decoder_outputs
        if self.rnn.use_attention:
            inflated_encoder_outputs = _inflate(encoder_outputs, self.k, 0)
        else:
            inflated_encoder_outputs = None

        # Initialize the scores; for the first step,
        # ignore the inflated copies to avoid duplicate entries in the top k
        sequence_scores = torch.full([batch_size * self.k, 1],
                                     fill_value=-float('inf'), device=device)
        sequence_scores.index_fill_(0, torch.tensor(
            [i * self.k for i in range(0, batch_size)],
            dtype=torch.long, device=device), 0.0)

        # Initialize the input vector
        input_var = torch.transpose(torch.tensor(
            [[self.SOS] * batch_size * self.k],
            dtype=torch.long, device=device), 0, 1)

        # Store decisions for backtracking
        stored_outputs = list()
        stored_scores = list()
        stored_predecessors = list()
        stored_emitted_symbols = list()
        stored_hidden = list()

        for _ in range(0, max_length):
            # Run the RNN one step forward
            log_softmax_output, hidden, _ = self.rnn.forward_step(
                input_var, hidden, inflated_encoder_outputs,
                function=function)

            # If doing local backprop (e.g. supervised training),
            # retain the output layer
            if retain_output_probs:
                stored_outputs.append(log_softmax_output)

            # To get the full sequence scores for the new candidates, add the
            # local scores for t_i to the predecessor scores for t_(i-1)
            sequence_scores = _inflate(sequence_scores, self.V, 1)
            sequence_scores += log_softmax_output.squeeze(1)
            scores, candidates = sequence_scores.view(
                batch_size, -1).topk(self.k, dim=1)

            # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
            input_var = (candidates % self.V).view(batch_size * self.k, 1)
            sequence_scores = scores.view(batch_size * self.k, 1)

            # Update fields for next timestep; floor division by the
            # vocabulary size recovers the predecessor beam of each candidate
            predecessors = (candidates // self.V +
                            self.pos_index.expand_as(candidates)
                            ).view(batch_size * self.k, 1)
            if isinstance(hidden, tuple):
                hidden = tuple([h.index_select(1, predecessors.squeeze())
                                for h in hidden])
            else:
                hidden = hidden.index_select(1, predecessors.squeeze())

            # Update sequence scores and erase scores for end-of-sentence
            # symbol so that they aren't expanded
            stored_scores.append(sequence_scores.clone())
            eos_indices = input_var.data.eq(self.EOS)
            if eos_indices.any():
                sequence_scores.data.masked_fill_(eos_indices, -float('inf'))

            # Cache results for backtracking
            stored_predecessors.append(predecessors)
            stored_emitted_symbols.append(input_var)
            stored_hidden.append(hidden)

        # Do backtracking to return the optimal values
        output, h_t, h_n, s, l, p = self._backtrack(
            stored_outputs, stored_hidden, stored_predecessors,
            stored_emitted_symbols, stored_scores, batch_size,
            self.hidden_size)

        # Build return objects
        decoder_outputs = [step[:, 0, :] for step in output]
        if isinstance(h_n, tuple):
            decoder_hidden = tuple([h[:, :, 0, :] for h in h_n])
        else:
            decoder_hidden = h_n[:, :, 0, :]
        metadata = {}
        metadata['inputs'] = inputs
        metadata['output'] = output
        metadata['h_t'] = h_t
        metadata['score'] = s
        metadata['topk_length'] = l
        metadata['topk_sequence'] = p
        metadata['length'] = [seq_len[0] for seq_len in l]
        metadata['sequence'] = [seq[0] for seq in p]
        return decoder_outputs, decoder_hidden, metadata
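
    # The per-step selection in `forward` packs beam and vocabulary indices
    # into a single topk over the k*V flattened candidate scores: `% V`
    # recovers the emitted token and `// V` recovers the predecessor beam.
    # A standalone sketch of that index arithmetic (illustrative only)::
    #
    #     >>> k, V = 2, 5
    #     >>> candidate_scores = torch.tensor([0.1, 2.0, 0.3, 0.2, 0.5,
    #     ...                                  1.0, 0.4, 3.0, 0.6, 0.7])
    #     >>> top_scores, top_idx = candidate_scores.topk(k)
    #     >>> top_idx
    #     tensor([7, 1])
    #     >>> top_idx % V    # emitted vocabulary symbols
    #     tensor([2, 1])
    #     >>> top_idx // V   # predecessor beams they extend
    #     tensor([1, 0])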

    def _backtrack(self, nw_output, nw_hidden, predecessors, symbols, scores,
                   b, hidden_size):
        """Backtracks over batch to generate optimal k-sequences.

        Args:
            nw_output [(batch*k, vocab_size)] * sequence_length: A Tensor of
                outputs from network
            nw_hidden [(num_layers, batch*k, hidden_size)] * sequence_length:
                A Tensor of hidden states from network
            predecessors [(batch*k)] * sequence_length: A Tensor of
                predecessors
            symbols [(batch*k)] * sequence_length: A Tensor of predicted
                tokens
            scores [(batch*k)] * sequence_length: A Tensor containing
                sequence scores for every token t = [0, ... , seq_len - 1]
            b: Size of the batch
            hidden_size: Size of the hidden state

        Returns:
            output [(batch, k, vocab_size)] * sequence_length: A list of the
                output probabilities (p_n) from the last layer of the RNN,
                for every n = [0, ... , seq_len - 1]
            h_t [(batch, k, hidden_size)] * sequence_length: A list
                containing the output features (h_n) from the last layer of
                the RNN, for every n = [0, ... , seq_len - 1]
            h_n (batch, k, hidden_size): A Tensor containing the last hidden
                state for all top-k sequences.
            score [batch, k]: A list containing the final scores for all
                top-k sequences
            length [batch, k]: A list specifying the length of each sequence
                in the top-k candidates
            p (batch, k, sequence_len): A Tensor containing predicted sequence
        """
        lstm = isinstance(nw_hidden[0], tuple)

        # initialize return variables given different types
        output = list()
        h_t = list()
        p = list()

        # Placeholder for last hidden state of top-k sequences.
        # If a (top-k) sequence ends early in decoding, `h_n` contains
        # its hidden state when it sees EOS. Otherwise, `h_n` contains
        # the last hidden state of decoding.
        if lstm:
            state_size = nw_hidden[0][0].size()
            h_n = tuple([torch.zeros(state_size, device=device),
                         torch.zeros(state_size, device=device)])
        else:
            h_n = torch.zeros(nw_hidden[0].size(), device=device)

        # Placeholder for lengths of top-k sequences; similar to `h_n`
        l = [[self.rnn.max_length] * self.k for _ in range(b)]

        # the last step outputs of the beams are not sorted,
        # thus they are sorted here
        sorted_score, sorted_idx = scores[-1].view(b, self.k).topk(self.k)

        # initialize the sequence scores with the sorted last step beam scores
        s = sorted_score.clone()

        # the number of EOS found in the backward loop below for each batch
        batch_eos_found = [0] * b

        t = self.rnn.max_length - 1
        # initialize the back pointer with the sorted order of the last step
        # beams; add self.pos_index for indexing variables with b*k as the
        # first dimension.
        t_predecessors = (
            sorted_idx + self.pos_index.expand_as(sorted_idx)).view(b * self.k)
        while t >= 0:
            # Re-order the variables with the back pointer
            current_output = nw_output[t].index_select(0, t_predecessors)
            if lstm:
                current_hidden = tuple(
                    [h.index_select(1, t_predecessors) for h in nw_hidden[t]])
            else:
                current_hidden = nw_hidden[t].index_select(1, t_predecessors)
            current_symbol = symbols[t].index_select(0, t_predecessors)

            # Re-order the back pointer of the previous step with the back
            # pointer of the current step
            t_predecessors = predecessors[t].index_select(
                0, t_predecessors).squeeze()

            # This tricky block handles dropped sequences that see EOS
            # earlier. The basic idea is summarized below:
            #
            # Terms:
            #   Ended sequences = sequences that see EOS early and are dropped
            #   Survived sequences = sequences in the last step of the beams
            #
            # Although the ended sequences are dropped during decoding, their
            # generated symbols and complete backtracking information are
            # still in the backtracking variables.
            # For each batch, every time we see an EOS in the backtracking
            # process:
            #   1. If there are survived sequences in the return variables,
            #      replace the one with the lowest survived sequence score
            #      with the new ended sequence.
            #   2. Otherwise, replace the ended sequence with the lowest
            #      sequence score with the new ended sequence.
            eos_indices = symbols[t].data.squeeze(1).eq(self.EOS).nonzero()
            if eos_indices.numel() > 0:
                for i in range(eos_indices.size(0) - 1, -1, -1):
                    # Indices of the EOS symbol for both variables
                    # with b*k as the first dimension, and b, k for
                    # the first two dimensions
                    idx = eos_indices[i].item()
                    b_idx = int(idx / self.k)
                    # The indices of the replacing position
                    # according to the replacement strategy noted above
                    res_k_idx = self.k - (batch_eos_found[b_idx] % self.k) - 1
                    batch_eos_found[b_idx] += 1
                    res_idx = b_idx * self.k + res_k_idx

                    # Replace the old information in return variables
                    # with the new ended sequence information
                    # TODO: Check this still works (this if was added for
                    # torch 1.0 but might have unforeseen consequences)
                    if t_predecessors.dim() > 0:
                        t_predecessors[res_idx] = predecessors[t][idx]
                    else:
                        t_predecessors = predecessors[t][idx]
                    current_output[res_idx, :] = nw_output[t][idx, :]
                    if lstm:
                        current_hidden[0][:, res_idx, :] = \
                            nw_hidden[t][0][:, idx, :]
                        current_hidden[1][:, res_idx, :] = \
                            nw_hidden[t][1][:, idx, :]
                        h_n[0][:, res_idx, :] = nw_hidden[t][0][:, idx, :].data
                        h_n[1][:, res_idx, :] = nw_hidden[t][1][:, idx, :].data
                    else:
                        current_hidden[:, res_idx, :] = nw_hidden[t][:, idx, :]
                        h_n[:, res_idx, :] = nw_hidden[t][:, idx, :].data
                    current_symbol[res_idx, :] = symbols[t][idx]
                    s[b_idx, res_k_idx] = scores[t][idx].item()
                    l[b_idx][res_k_idx] = t + 1

            # record the back tracked results
            output.append(current_output)
            h_t.append(current_hidden)
            p.append(current_symbol)
            t -= 1

        # Sort and re-order again as the added ended sequences may change
        # the order (very unlikely)
        s, re_sorted_idx = s.topk(self.k)
        for b_idx in range(b):
            l[b_idx] = [l[b_idx][k_idx.item()]
                        for k_idx in re_sorted_idx[b_idx, :]]

        re_sorted_idx = (
            re_sorted_idx +
            self.pos_index.expand_as(re_sorted_idx)).view(b * self.k)

        # Reverse the sequences and re-order at the same time.
        # They are reversed because backtracking happens in reverse time order
        output = [step.index_select(0, re_sorted_idx).view(b, self.k, -1)
                  for step in reversed(output)]
        p = [step.index_select(0, re_sorted_idx).view(b, self.k, -1)
             for step in reversed(p)]
        if lstm:
            h_t = [tuple([h.index_select(1, re_sorted_idx).view(
                -1, b, self.k, hidden_size) for h in step])
                for step in reversed(h_t)]
            h_n = tuple([h.index_select(1, re_sorted_idx.data).view(
                -1, b, self.k, hidden_size) for h in h_n])
        else:
            h_t = [step.index_select(1, re_sorted_idx).view(
                -1, b, self.k, hidden_size) for step in reversed(h_t)]
            h_n = h_n.index_select(
                1, re_sorted_idx.data).view(-1, b, self.k, hidden_size)
        s = s.data

        return output, h_t, h_n, s, l, p

    def _mask_symbol_scores(self, score, idx, masking_score=-float('inf')):
        score[idx] = masking_score

    def _mask(self, tensor, idx, dim=0, masking_score=-float('inf')):
        if len(idx.size()) > 0:
            indices = idx[:, 0]
            tensor.index_fill_(dim, indices, masking_score)
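
# A minimal illustration of the back-pointer reconstruction performed by
# `_backtrack`, stripped of batching and score bookkeeping (toy data only)::
#
#     >>> predecessors = [[0, 0], [0, 1], [1, 0]]  # beam extended from t-1
#     >>> symbols = [[5, 7], [3, 9], [4, 2]]       # token emitted at step t
#     >>> def backtrack(beam):
#     ...     seq = []
#     ...     for t in range(len(symbols) - 1, -1, -1):
#     ...         seq.append(symbols[t][beam])
#     ...         beam = predecessors[t][beam]
#     ...     return list(reversed(seq))
#     >>> backtrack(0)
#     [7, 9, 4]
#     >>> backtrack(1)
#     [5, 3, 2]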