Source code for arrayLSTM

# Import PyTorch library
import torch
import torch.nn as nn
import torch.nn.functional as F
# Import the base LSTM implementation from the arrayLSTM package
from arrayLSTM import LSTM

class ArrayLSTM(LSTM):
    """Implementation of ArrayLSTM

    From `Recurrent Memory Array Structures`_ by Kamil Rocki

    .. _`Recurrent Memory Array Structures`: https://arxiv.org/abs/1607.03085

    Note
    ----
    This is a `batch_first=True` implementation, hence the `forward()`
    method expects inputs of `shape=(batch, seq_len, input_size)`.

    Attributes
    ----------
    input_size : int
        Size of input dimension

    hidden_size : int
        Size of hidden dimension

    k : int
        Number of parallel memory structures, i.e. cell states to use

    i2h : nn.Linear
        Linear layer transforming input to hidden state

    h2h : nn.Linear
        Linear layer updating hidden state to hidden state
    """
    def __init__(self, input_size, hidden_size, k):
        """Implementation of ArrayLSTM

        Note
        ----
        This is a `batch_first=True` implementation, hence the `forward()`
        method expects inputs of `shape=(batch, seq_len, input_size)`.

        Parameters
        ----------
        input_size : int
            Size of input dimension

        hidden_size : int
            Size of hidden dimension

        k : int
            Number of parallel memory structures, i.e. cell states to use
        """
        # Call super
        super().__init__(input_size, hidden_size)

        # Set dimensions
        self.input_size  = input_size
        self.hidden_size = hidden_size
        self.k           = k

        # Set layers; each of the k memory structures gets its own 4 gates
        self.i2h = nn.Linear(input_size , 4*hidden_size*k)
        self.h2h = nn.Linear(hidden_size, 4*hidden_size*k)
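    # A minimal construction sketch (illustrative; sizes are arbitrary):
    #
    #   >>> lstm = ArrayLSTM(input_size=10, hidden_size=32, k=4)
    #   >>> lstm.i2h.weight.shape  # 4 gates * hidden_size * k output units
    #   torch.Size([512, 10])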
    ########################################################################
    #                         Pass through network                         #
    ########################################################################
    def forward_cell(self, x, hidden, state):
        """Perform a single forward pass through the network.

        Parameters
        ----------
        x : torch.Tensor of shape=(batch, input_size)
            Tensor to pass through network

        hidden : torch.Tensor of shape=(1, batch, hidden_size)
            Tensor containing the hidden state

        state : torch.Tensor of shape=(k, batch, hidden_size)
            Tensor containing the cell states

        Returns
        -------
        hidden : torch.Tensor of shape=(1, batch, hidden_size)
            Tensor containing the next hidden state

        state : torch.Tensor of shape=(k, batch, hidden_size)
            Tensor containing the next cell states
        """
        # Reshape hidden state to shape=(batch, hidden_size) for a single cell
        hidden = hidden.view(hidden.size(1), -1)

        # Initialise outputs
        outputs = torch.zeros(self.k, x.shape[0], self.hidden_size,
                              device=x.device)

        # Apply linear mapping
        linear = self.i2h(x) + self.h2h(hidden)
        # View linear in terms of k
        linear = linear.view(x.shape[0], self.k, -1)

        # Loop over all k memory structures
        for k, linear_ in enumerate(torch.unbind(linear, dim=1)):
            # Perform activation functions
            gates = linear_[:, :3*self.hidden_size ].sigmoid()
            c_t   = linear_[:,  3*self.hidden_size:].tanh()

            # Extract gates
            f_t = gates[:, :self.hidden_size                   ]
            i_t = gates[:,  self.hidden_size:2*self.hidden_size]
            o_t = gates[:, -self.hidden_size:                  ]

            # Update cell state k
            state[k] = torch.mul(state[k].clone(), f_t) + torch.mul(i_t, c_t)
            # Store output gate activation for structure k
            outputs[k] = o_t

        # Update hidden state
        hidden = self.update_hidden(outputs, state)

        # Return result
        return hidden, state
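    # In equations, a single forward_cell step computes, for each memory
    # structure k (writing * for elementwise multiplication):
    #
    #   gates: W x_t + U h_{t-1} is split into f_t^k, i_t^k, o_t^k (sigmoid)
    #          and the candidate cell state c~_t^k (tanh)
    #   cell:  c_t^k = f_t^k * c_{t-1}^k + i_t^k * c~_t^k
    #
    # after which update_hidden() below folds the k structures into one h_t.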
    ########################################################################
    #                         Update hidden state                          #
    ########################################################################
    def update_hidden(self, outputs, states):
        """Default hidden state as sum of outputs and cells

        Parameters
        ----------
        outputs : torch.Tensor of shape=(k, batch_size, hidden_size)
            Tensor containing the result of output gates o

        states : torch.Tensor of shape=(k, batch_size, hidden_size)
            Tensor containing the cell states

        Returns
        -------
        hidden : torch.Tensor of shape=(1, batch_size, hidden_size)
            Hidden tensor as computed from outputs and states
        """
        # Initialise hidden state
        hidden = torch.zeros(1, outputs.shape[1], self.hidden_size,
                             device=states.device)

        # Loop over all outputs and cell states
        for output, state in zip(torch.unbind(outputs, dim=0),
                                 torch.unbind(states , dim=0)):
            # Update hidden state
            hidden += torch.mul(output, state.tanh())

        # Return hidden state
        return hidden
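    # Note: the loop above computes h_t = sum_k o_t^k * tanh(c_t^k), which
    # could equivalently be written as the vectorised one-liner
    #
    #   hidden = torch.mul(outputs, states.tanh()).sum(dim=0, keepdim=True)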
    ########################################################################
    #                     Hidden state initialisation                      #
    ########################################################################
    def initHidden(self, x):
        """Initialise hidden state and cell states"""
        return (torch.zeros(     1, x.shape[0], self.hidden_size).to(x.device),
                torch.zeros(self.k, x.shape[0], self.hidden_size).to(x.device))
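# A minimal usage sketch (illustrative, not part of the original module):
# rolls a random batch through the cell one timestep at a time, using only
# initHidden() and forward_cell() defined above. Dimensions are arbitrary,
# and the base LSTM constructor from the arrayLSTM package is assumed to
# set the module up like a standard nn.Module.
if __name__ == "__main__":
    torch.manual_seed(42)

    # Construct an ArrayLSTM with k=4 parallel cell states
    lstm = ArrayLSTM(input_size=10, hidden_size=32, k=4)

    # Random input of shape=(batch, seq_len, input_size)
    X = torch.randn(2, 5, 10)

    # Hidden state of shape=(1, batch, hidden_size) and
    # cell states of shape=(k, batch, hidden_size)
    hidden, state = lstm.initHidden(X)

    # Step through the sequence one timestep at a time
    for x_t in torch.unbind(X, dim=1):
        hidden, state = lstm.forward_cell(x_t, hidden, state)

    print(hidden.shape)  # torch.Size([1, 2, 32])
    print(state.shape)   # torch.Size([4, 2, 32])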