Source code for extensions.attention

# Import PyTorch library
import torch
import torch.nn as nn
import torch.nn.functional as F
# Import standard arrayLSTM implementation
from arrayLSTM import ArrayLSTM

class AttentionArrayLSTM(ArrayLSTM):
    """Implementation of ArrayLSTM with Lane selection: Soft attention

    From `Recurrent Memory Array Structures`_ by Kamil Rocki.

    .. _`Recurrent Memory Array Structures`: https://arxiv.org/abs/1607.03085

    Note
    ----
    This is a `batch_first=True` implementation, hence the `forward()` method
    expects inputs of `shape=(batch, seq_len, input_size)`.

    Attributes
    ----------
    input_size : int
        Size of input dimension

    hidden_size : int
        Size of hidden dimension

    k : int
        Number of parallel memory structures, i.e. cell states to use

    max_pooling : boolean, default=False
        If True, uses max pooling over the attention lanes instead of
        soft attention

    i2h : nn.Linear
        Linear layer transforming input to hidden state

    h2h : nn.Linear
        Linear layer updating hidden state to hidden state

    i2a : nn.Linear
        Linear layer computing the attention signal from the input
    """
    def __init__(self, input_size, hidden_size, k, max_pooling=False):
        """Implementation of ArrayLSTM with Lane selection: Soft attention

        Note
        ----
        This is a `batch_first=True` implementation, hence the `forward()`
        method expects inputs of `shape=(batch, seq_len, input_size)`.

        Parameters
        ----------
        input_size : int
            Size of input dimension

        hidden_size : int
            Size of hidden dimension

        k : int
            Number of parallel memory structures, i.e. cell states to use

        max_pooling : boolean, default=False
            If True, uses max pooling over the attention lanes instead of
            soft attention
        """
        # Call super
        super().__init__(input_size, hidden_size, k)

        # Set max_pooling
        self.max_pooling = max_pooling

        # Set attention layer, mapping the input to k attention vectors of
        # size hidden_size (one per parallel memory structure)
        self.i2a = nn.Linear(input_size, hidden_size*k)
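
    # ------------------------------------------------------------------
    # Lane selection (soft attention): i2a maps an input of
    # shape=(batch, input_size) to shape=(batch, k*hidden_size); after a
    # sigmoid, forward_cell() below reshapes this to (batch, k, hidden_size)
    # and applies a softmax over the k lanes, so the gate activations of
    # each parallel cell state are weighted by its attention score.
    # ------------------------------------------------------------------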
    ########################################################################
    #                         Pass through network                         #
    ########################################################################
    def forward_cell(self, x, hidden, state):
        """Perform a single forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor of shape=(batch, input_size)
            Tensor to pass through network

        hidden : torch.Tensor of shape=(1, batch, hidden_size)
            Tensor containing the hidden state

        state : torch.Tensor of shape=(k, batch, hidden_size)
            Tensor containing the cell states

        Returns
        -------
        hidden : torch.Tensor of shape=(1, batch, hidden_size)
            Tensor containing the next hidden state

        state : torch.Tensor of shape=(k, batch, hidden_size)
            Tensor containing the next cell states
        """
        # Reshape hidden state to work for a single cell
        hidden = hidden.view(hidden.size(1), -1)

        # Initialise outputs
        outputs = torch.zeros(self.k, x.shape[0], self.hidden_size,
                              device=x.device)

        # Compute attention signal
        attention = self.i2a(x).sigmoid()
        # View attention in terms of k
        attention = attention.view(x.shape[0], self.k, -1)
        # Compute softmax over the k memory lanes
        softmax = F.softmax(attention, dim=1)

        # Use max pooling if necessary
        if self.max_pooling:
            # Take the maximum attention over the k lanes
            softmax = softmax.max(dim=1).values
            # Broadcast the pooled attention back to all k lanes
            softmax = softmax.unsqueeze_(1)
            softmax = softmax.expand(-1, self.k, -1)

        # Apply linear mapping
        linear = self.i2h(x) + self.h2h(hidden)
        # View linear in terms of k
        linear = linear.view(x.shape[0], self.k, -1)

        # Loop over all k
        for k, (linear_, softmax_) in enumerate(zip(
                torch.unbind(linear , dim=1),
                torch.unbind(softmax, dim=1))):

            # Perform activation functions
            gates = linear_[:, :3*self.hidden_size ].sigmoid()
            c_t   = linear_[:,  3*self.hidden_size:].tanh()

            # Extract gates, weighted by the attention signal
            f_t = torch.mul(softmax_, gates[:,                  :  self.hidden_size])
            i_t = torch.mul(softmax_, gates[:,  self.hidden_size:2*self.hidden_size])
            o_t = torch.mul(softmax_, gates[:, -self.hidden_size:                  ])

            # Update state
            state[k] = torch.mul(state[k].clone(), (1-f_t)) + torch.mul(i_t, c_t)
            # Update outputs
            outputs[k] = o_t

        # Update hidden state
        hidden = self.update_hidden(outputs, state)

        # Return result
        return hidden, state
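

# ----------------------------------------------------------------------
# Example usage (sketch, not part of the original module). It assumes that
# the inherited ArrayLSTM.forward() follows the batch_first=True convention
# described above and returns a tuple of (output, (hidden, state)),
# analogous to torch.nn.LSTM; check the arrayLSTM package if unsure.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Attention ArrayLSTM with 10 input features, 32 hidden units and
    # k=4 parallel memory structures, using soft attention
    lstm = AttentionArrayLSTM(input_size=10, hidden_size=32, k=4)

    # Random batch of 8 sequences of length 5, each step with 10 features
    X = torch.rand(8, 5, 10)

    # Forward pass through the network (assumed return signature, see above)
    output, (hidden, state) = lstm(X)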