ArrayLSTM

The ArrayLSTM implements the basic ArrayLSTM of Rocki’s Recurrent Memory Array Structures. This module is built as an extension of the normal LSTM implementation.
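For reference, the basic ArrayLSTM keeps k parallel cell states c^k while sharing a single hidden state h. The following is a sketch of the per-step update following the paper’s formulation (notation lightly adapted, so treat it as a summary rather than the exact equations):

f_t^k = \sigma(W_f^k x_t + U_f^k h_{t-1} + b_f^k)
i_t^k = \sigma(W_i^k x_t + U_i^k h_{t-1} + b_i^k)
o_t^k = \sigma(W_o^k x_t + U_o^k h_{t-1} + b_o^k)
\tilde{c}_t^k = \tanh(W_c^k x_t + U_c^k h_{t-1} + b_c^k)
c_t^k = f_t^k \odot c_{t-1}^k + i_t^k \odot \tilde{c}_t^k
h_t = \sum_k o_t^k \odot \tanh(c_t^k)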

class arraylstm.ArrayLSTM(*args: Any, **kwargs: Any)[source]

Implementation of ArrayLSTM

From Recurrent Memory Array Structures by Kamil Rocki

Note

This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).

Attributes
  • input_size (int) – Size of input dimension

  • hidden_size (int) – Size of hidden dimension

  • k (int) – Number of parallel memory structures, i.e. cell states to use

  • i2h (nn.Linear) – Linear layer transforming input to hidden state

  • h2h (nn.Linear) – Linear layer updating hidden state to hidden state
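As an illustration only: one plausible way to set up the i2h and h2h layers is to pack the four gates of all k memory arrays into a single projection of size 4 * k * hidden_size. The actual layer shapes in this implementation may differ.

import torch.nn as nn

# Hypothetical layer shapes, packing the gates of all k arrays into one projection.
input_size, hidden_size, k = 10, 128, 4
i2h = nn.Linear(input_size, 4 * k * hidden_size)   # input  -> gate pre-activations
h2h = nn.Linear(hidden_size, 4 * k * hidden_size)  # hidden -> gate pre-activations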

Initialization

ArrayLSTM.__init__(input_size, hidden_size, k)[source]

Implementation of ArrayLSTM

Note

This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).

Parameters
  • input_size (int) – Size of input dimension

  • hidden_size (int) – Size of hidden dimension

  • k (int) – Number of parallel memory structures, i.e. cell states to use
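As a usage sketch: the import path below follows the class path shown above, and the return value of the call is an assumption here, so check the forward() documentation for the exact signature.

import torch
from arraylstm import ArrayLSTM

# Construct an ArrayLSTM with k=4 parallel cell states.
model = ArrayLSTM(input_size=10, hidden_size=128, k=4)

# batch_first=True: inputs have shape (batch, seq_len, input_size).
X = torch.randn(32, 20, 10)

# Hypothetical call; the exact return values are defined by forward().
result = model(X)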

Forward

A single ArrayLSTM cell step is implemented by the forward_cell() method. This method overrides the corresponding method of the LSTM superclass.

ArrayLSTM.forward_cell(x, hidden, state)[source]

Perform a single forward pass through the network.

Parameters
  • x (torch.Tensor of shape=(batch, input_size)) – Tensor to pass through network

  • hidden (torch.Tensor of shape (batch, input_size)) – Tensor containing the hidden state

  • state (torch.Tensor of shape (batch, input_size)) – Tensor containing the cell state

Returns

  • hidden (torch.Tensor of shape (batch, input_size)) – Tensor containing the next hidden state

  • state (torch.Tensor of shape (batch, input_size)) – Tensor containing the next cell state
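To make the data flow concrete, the sketch below spells out what a basic ArrayLSTM cell step computes, following Rocki’s formulation. It is illustrative only: the function name, the packed 4 * k * hidden_size gate layout, and the (k, batch, hidden_size) state shape are assumptions, and the actual implementation delegates the final combination step to update_hidden().

import torch

def forward_cell_sketch(x, hidden, state, i2h, h2h, k, hidden_size):
    # x      : (batch, input_size)
    # hidden : (batch, hidden_size)
    # state  : (k, batch, hidden_size) -- one cell state per memory array (assumption)
    # i2h/h2h: linear layers projecting to 4*k*hidden_size gate pre-activations (assumption)
    gates = i2h(x) + h2h(hidden)                 # (batch, 4*k*hidden_size)
    gates = gates.view(-1, k, 4, hidden_size)    # split per array and per gate

    outputs = torch.empty_like(state)
    states  = torch.empty_like(state)
    for j in range(k):
        i = torch.sigmoid(gates[:, j, 0])        # input gate
        f = torch.sigmoid(gates[:, j, 1])        # forget gate
        g = torch.tanh(gates[:, j, 2])           # candidate cell state
        o = torch.sigmoid(gates[:, j, 3])        # output gate
        states[j]  = f * state[j] + i * g        # next cell state of array j
        outputs[j] = o                           # output gate activation of array j

    # Combine the k arrays into a single hidden state (done by update_hidden() in the package).
    hidden = (outputs * torch.tanh(states)).sum(dim=0)
    return hidden, states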

As variations of the ArrayLSTM update their hidden state differently, we also add a method update_hidden(). This method can be overridden by subclasses to update the hidden state in different ways.

ArrayLSTM.update_hidden(outputs, states)[source]

Compute the default hidden state as a sum over the outputs and cell states

Parameters
  • outputs (torch.Tensor of shape=(k, batch_size, hidden_size)) – Tensor containing the result of output gates o

  • states (torch.Tensor of shape=(k, batch_size, hidden_size)) – Tensor containing the cell states

Returns

hidden – Hidden tensor as computed from outputs and states

Return type

torch.Tensor of shape=(1, batch_size, hidden_size)
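A minimal sketch of this default combination, assuming the usual o ⊙ tanh(c) coupling from Rocki’s formulation summed over the k arrays (the function name is illustrative, not the package’s API):

import torch

def update_hidden_sketch(outputs, states):
    # outputs: (k, batch_size, hidden_size) -- output gate activations per memory array
    # states : (k, batch_size, hidden_size) -- cell states per memory array
    # Sum the per-array contributions into one (1, batch_size, hidden_size) hidden state.
    return (outputs * torch.tanh(states)).sum(dim=0, keepdim=True)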

Hidden state

The ArrayLSTM requires multiple cell states instead of a single one, therefore it overrides the corresponding method of its LSTM superclass.

ArrayLSTM.initHidden(x)[source]

Initialise the hidden state and cell states
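A minimal sketch of such an initialisation, assuming zero-initialised tensors with one hidden state and k cell states (the shapes and helper name are assumptions, not the documented API):

import torch

def init_hidden_sketch(x, hidden_size, k):
    # x: (batch, seq_len, input_size); only its batch size and device are used here.
    batch = x.shape[0]
    hidden = torch.zeros(1, batch, hidden_size, device=x.device)   # single hidden state
    state  = torch.zeros(k, batch, hidden_size, device=x.device)   # one cell state per array
    return hidden, state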