Source code for machine.models.DecoderRNN

import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import Attention
from .baseRNN import BaseRNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DecoderRNN(BaseRNN):
    """
    Provides functionality for decoding in a seq2seq framework, with an option for attention.

    Args:
        vocab_size (int): size of the vocabulary
        max_len (int): a maximum allowed length for the sequence to be processed
        hidden_size (int): the number of features in the hidden state `h`
        sos_id (int): index of the start of sentence symbol
        eos_id (int): index of the end of sentence symbol
        n_layers (int, optional): number of recurrent layers (default: 1)
        rnn_cell (str, optional): type of RNN cell (default: gru)
        bidirectional (bool, optional): if the encoder is bidirectional (default: False)
        input_dropout_p (float, optional): dropout probability for the input sequence (default: 0)
        dropout_p (float, optional): dropout probability for the output sequence (default: 0)
        use_attention (str, optional): 'pre-rnn' or 'post-rnn' to apply attention before or after the
            decoder RNN, or False to disable attention (default: False)
        attention_method (str, optional): method used to compute the attention; required when
            `use_attention` is set (default: None)
        full_focus (bool, optional): flag indicating whether to use the full focus mechanism (default: False)

    Attributes:
        KEY_ATTN_SCORE (str): key used to indicate attention weights in `ret_dict`
        KEY_LENGTH (str): key used to indicate a list representing lengths of output sequences in `ret_dict`
        KEY_SEQUENCE (str): key used to indicate a list of sequences in `ret_dict`

    Inputs: inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio
        - **inputs** (batch, seq_len, input_size): list of sequences, whose length is the batch size and
          within which each sequence is a list of token IDs. It is used for teacher forcing when provided.
          (default: `None`)
        - **encoder_hidden** (num_layers * num_directions, batch_size, hidden_size): tensor containing the
          features in the hidden state `h` of the encoder. Used as the initial hidden state of the decoder.
          (default: `None`)
        - **encoder_outputs** (batch, seq_len, hidden_size): tensor containing the outputs of the encoder.
          Used by the attention mechanism (default: `None`).
        - **function** (callable): a function used to generate symbols from the RNN hidden state
          (default: `torch.nn.functional.log_softmax`).
        - **teacher_forcing_ratio** (float): the probability that teacher forcing will be used. A random number
          is drawn uniformly from 0-1 for every decoding token, and if the sample is smaller than the given
          value, teacher forcing is used (default: 0).

    Outputs: decoder_outputs, decoder_hidden, ret_dict
        - **decoder_outputs** (seq_len, batch, vocab_size): list of tensors with size (batch_size, vocab_size)
          containing the outputs of the decoding function.
        - **decoder_hidden** (num_layers * num_directions, batch, hidden_size): tensor containing the last
          hidden state of the decoder.
        - **ret_dict**: dictionary containing additional information as follows {*KEY_LENGTH* : list of integers
          representing lengths of output sequences, *KEY_SEQUENCE* : list of sequences, where each sequence is
          a list of predicted token IDs}.
""" KEY_ATTN_SCORE = 'attention_score' KEY_LENGTH = 'length' KEY_SEQUENCE = 'sequence' def __init__(self, vocab_size, max_len, hidden_size, sos_id, eos_id, n_layers=1, rnn_cell='gru', bidirectional=False, input_dropout_p=0, dropout_p=0, use_attention=False, attention_method=None, full_focus=False): super(DecoderRNN, self).__init__(vocab_size, max_len, hidden_size, input_dropout_p, dropout_p, n_layers, rnn_cell) self.bidirectional_encoder = bidirectional input_size = hidden_size if use_attention and attention_method is None: raise ValueError( "Method for computing attention should be provided") self.attention_method = attention_method self.full_focus = full_focus # increase input size decoder if attention is applied before decoder # rnn if use_attention == 'pre-rnn' and not full_focus: input_size *= 2 self.rnn = self.rnn_cell(input_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p) self.output_size = vocab_size self.max_length = max_len self.use_attention = use_attention self.eos_id = eos_id self.sos_id = sos_id self.init_input = None self.embedding = nn.Embedding(self.output_size, self.hidden_size) if use_attention: self.attention = Attention(self.hidden_size, self.attention_method) else: self.attention = None if use_attention == 'post-rnn': self.out = nn.Linear(2 * self.hidden_size, self.output_size) else: self.out = nn.Linear(self.hidden_size, self.output_size) if self.full_focus: self.ffocus_merge = nn.Linear( 2 * self.hidden_size, hidden_size)

    def forward_step(self, input_var, hidden, encoder_outputs, function, **kwargs):
        """
        Performs one or multiple forward decoder steps.

        Args:
            input_var (torch.tensor): Variable containing the input(s) to the decoder RNN
            hidden (torch.tensor): Variable containing the previous decoder hidden state.
            encoder_outputs (torch.tensor): Variable containing the outputs of the encoder, used by the attention mechanism.
            function (callable): Activation function applied to the last output of the decoder RNN at every time step.

        Returns:
            predicted_softmax: The output softmax distribution at every time step of the decoder RNN
            hidden: The hidden state at every time step of the decoder RNN
            attn: The attention distribution at every time step of the decoder RNN
        """
        batch_size = input_var.size(0)
        output_size = input_var.size(1)
        embedded = self.embedding(input_var)
        embedded = self.input_dropout(embedded)

        if self.use_attention == 'pre-rnn':
            h = hidden
            if isinstance(hidden, tuple):
                h, c = hidden
            # Apply the attention method to get the attention vector and
            # weighted context vector. Provide the decoder step for hard attention.
            # Transpose to get the batch at the second index.
            context, attn = self.attention(
                h[-1:].transpose(0, 1), encoder_outputs, **kwargs)
            combined_input = torch.cat((context, embedded), dim=2)
            if self.full_focus:
                merged_input = F.relu(self.ffocus_merge(combined_input))
                combined_input = torch.mul(context, merged_input)
            output, hidden = self.rnn(combined_input, hidden)

        elif self.use_attention == 'post-rnn':
            output, hidden = self.rnn(embedded, hidden)
            # Apply the attention method to get the attention vector and
            # weighted context vector. Provide the decoder step for hard attention.
            context, attn = self.attention(output, encoder_outputs, **kwargs)
            output = torch.cat((context, output), dim=2)

        elif not self.use_attention:
            attn = None
            output, hidden = self.rnn(embedded, hidden)

        predicted_softmax = function(self.out(
            output.contiguous().view(-1, self.out.in_features)), dim=1).view(batch_size, output_size, -1)

        return predicted_softmax, hidden, attn
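
    # Shape sketch (added for illustration, not part of the original module):
    # with hidden_size=H, batch=B, a single decoder step and 'pre-rnn' attention
    # without full_focus, forward_step roughly sees
    #   input_var:         (B, 1)              token IDs
    #   embedded:          (B, 1, H)
    #   context:           (B, 1, H)           attention over encoder_outputs
    #   combined_input:    (B, 1, 2H)          hence input_size *= 2 in __init__
    #   output:            (B, 1, H)
    #   predicted_softmax: (B, 1, vocab_size)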

    def forward(self, inputs=None, encoder_hidden=None, encoder_outputs=None,
                function=F.log_softmax, teacher_forcing_ratio=0):
        ret_dict = dict()
        if self.use_attention:
            ret_dict[DecoderRNN.KEY_ATTN_SCORE] = list()

        inputs, batch_size, max_length = self._validate_args(inputs, encoder_hidden, encoder_outputs,
                                                             function, teacher_forcing_ratio)

        decoder_hidden = self._init_state(encoder_hidden)

        use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

        decoder_outputs = []
        sequence_symbols = []
        lengths = np.array([max_length] * batch_size)

        def decode(step, step_output, step_attn):
            decoder_outputs.append(step_output)
            if self.use_attention:
                ret_dict[DecoderRNN.KEY_ATTN_SCORE].append(step_attn)
            symbols = decoder_outputs[-1].topk(1)[1]
            sequence_symbols.append(symbols)

            eos_batches = symbols.data.eq(self.eos_id)
            if eos_batches.dim() > 0:
                eos_batches = eos_batches.cpu().view(-1).numpy()
                update_idx = ((lengths > step) & eos_batches) != 0
                lengths[update_idx] = len(sequence_symbols)
            return symbols

        # When we use pre-rnn attention we must unroll the decoder: the attention has to be
        # computed from the previous hidden state before the next hidden state can be calculated.
        # We also need to unroll when we don't use teacher forcing, since the decoder steps must
        # be performed one by one so that the output can be copied to the input of the next step.
        if self.use_attention == 'pre-rnn' or not use_teacher_forcing:
            unrolling = True
        else:
            unrolling = False

        if unrolling:
            symbols = None
            for di in range(max_length):
                # We always start with the SOS symbol as input. We need to add an extra dimension
                # of length 1 for the number of decoder steps (1 in this case).
                # When we use teacher forcing, we always use the target input.
                if di == 0 or use_teacher_forcing:
                    decoder_input = inputs[:, di].unsqueeze(1)
                # If we don't use teacher forcing (and we are beyond the first
                # SOS step), we use the last output as the new input.
                else:
                    decoder_input = symbols

                # Perform one forward step
                decoder_output, decoder_hidden, step_attn = self.forward_step(decoder_input, decoder_hidden,
                                                                              encoder_outputs, function=function)
                # Remove the unnecessary dimension.
                step_output = decoder_output.squeeze(1)
                # Get the actual symbol
                symbols = decode(di, step_output, step_attn)

        else:
            # Remove the last token of the longest output target in the batch. We don't have to run
            # the last decoder step where the teacher-forcing input is EOS (or the last output).
            # It is still run for shorter output targets in the batch.
            decoder_input = inputs[:, :-1]

            # Forward step without unrolling
            decoder_output, decoder_hidden, attn = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs, function=function)

            for di in range(decoder_output.size(1)):
                step_output = decoder_output[:, di, :]
                if attn is not None:
                    step_attn = attn[:, di, :]
                else:
                    step_attn = None
                decode(di, step_output, step_attn)

        ret_dict[DecoderRNN.KEY_SEQUENCE] = sequence_symbols
        ret_dict[DecoderRNN.KEY_LENGTH] = lengths.tolist()

        return decoder_outputs, decoder_hidden, ret_dict
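
    # Teacher-forcing sketch (illustrative, not part of the original module):
    # given a hypothetical DecoderRNN instance `decoder` and gold target token IDs
    # `targets` of shape (batch, trg_len) that start with sos_id, passing them as
    # `inputs` with teacher_forcing_ratio=1 makes the decoder consume the gold
    # tokens (all but the last) instead of its own predictions:
    #
    #   outputs, hidden, ret = decoder(inputs=targets,
    #                                  encoder_hidden=encoder_hidden,
    #                                  encoder_outputs=encoder_outputs,
    #                                  teacher_forcing_ratio=1)
    #   # len(outputs) == targets.size(1) - 1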

    def _init_state(self, encoder_hidden):
        """ Initialize the decoder hidden state with the encoder hidden state. """
        if encoder_hidden is None:
            return None
        if isinstance(encoder_hidden, tuple):
            encoder_hidden = tuple([self._cat_directions(h) for h in encoder_hidden])
        else:
            encoder_hidden = self._cat_directions(encoder_hidden)
        return encoder_hidden

    def _cat_directions(self, h):
        """ If the encoder is bidirectional, do the following transformation.
            (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
        """
        if self.bidirectional_encoder:
            h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
        return h

    def _validate_args(self, inputs, encoder_hidden, encoder_outputs, function, teacher_forcing_ratio):
        if self.use_attention:
            if encoder_outputs is None:
                raise ValueError(
                    "Argument encoder_outputs cannot be None when attention is used.")

        # inference batch size
        if inputs is None and encoder_hidden is None:
            batch_size = 1
        else:
            if inputs is not None:
                batch_size = inputs.size(0)
            else:
                if self.rnn_cell is nn.LSTM:
                    batch_size = encoder_hidden[0].size(1)
                elif self.rnn_cell is nn.GRU:
                    batch_size = encoder_hidden.size(1)

        # set default input and max decoding length
        if inputs is None:
            if teacher_forcing_ratio > 0:
                raise ValueError(
                    "Teacher forcing has to be disabled (set 0) when no inputs is provided.")
            inputs = torch.tensor([self.sos_id] * batch_size, dtype=torch.long,
                                  device=device).view(batch_size, 1)
            max_length = self.max_length
        else:
            # minus the start of sequence symbol
            max_length = inputs.size(1) - 1

        return inputs, batch_size, max_length
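

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). All sizes and
# the 'dot' attention_method below are assumed example values; a real pipeline
# would obtain encoder_hidden/encoder_outputs from an encoder instead of the
# zero tensors used here.
#
#   decoder = DecoderRNN(vocab_size=50, max_len=10, hidden_size=16,
#                        sos_id=1, eos_id=2, rnn_cell='gru',
#                        use_attention='post-rnn', attention_method='dot').to(device)
#   batch_size = 4
#   encoder_hidden = torch.zeros(1, batch_size, 16, device=device)    # (layers, batch, hidden)
#   encoder_outputs = torch.zeros(batch_size, 10, 16, device=device)  # (batch, src_len, hidden)
#   outputs, hidden, ret = decoder(encoder_hidden=encoder_hidden,
#                                  encoder_outputs=encoder_outputs,
#                                  teacher_forcing_ratio=0)
#   predicted_ids = ret[DecoderRNN.KEY_SEQUENCE]   # list of (batch, 1) tensors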