AttentionArrayLSTM

The AttentionArrayLSTM implements an ArrayLSTM with the Deterministic Array-LSTM extension “Lane selection: Soft Attention” from Rocki’s Recurrent Memory Array Structures. The module is built as an extension of the basic ArrayLSTM implementation.

class extensions.AttentionArrayLSTM(*args: Any, **kwargs: Any)

Implementation of ArrayLSTM with Lane selection: Soft attention

From Recurrent Memory Array Structures by Kamil Rocki

Note

This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).

input_size

Size of input dimension

Type

int

hidden_size

Size of hidden dimension

Type

int

k

Number of parallel memory structures, i.e. cell states to use

Type

int

max_pooling

If True, uses max pooling for attention instead of the default soft attention

Type

boolean, default=False

i2h

Linear layer transforming input to hidden state

Type

nn.Linear

h2h

Linear layer transforming hidden state to hidden state

Type

nn.Linear

Initialization

AttentionArrayLSTM.__init__(input_size, hidden_size, k, max_pooling=False)

Implementation of ArrayLSTM with Lane selection: Soft attention

Note

This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).

Parameters
  • input_size (int) – Size of input dimension

  • hidden_size (int) – Size of hidden dimension

  • k (int) – Number of parallel memory structures, i.e. cell states to use

  • max_pooling (boolean, default=False) – If True, uses max pooling for attention instead of the default soft attention
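
The batch_first convention from the note above can be illustrated with PyTorch's built-in nn.LSTM as a stand-in. This sketch does not use AttentionArrayLSTM itself, and the sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions, not taken from the library)
batch, seq_len, input_size, hidden_size = 4, 10, 8, 16

# Stand-in recurrent layer, used only to show the shape convention
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

# With batch_first=True the batch dimension comes first
x = torch.randn(batch, seq_len, input_size)
output, (h_n, c_n) = lstm(x)

print(output.shape)  # torch.Size([4, 10, 16])
```

An input laid out as (seq_len, batch, input_size) would need to be transposed with x.transpose(0, 1) before being passed to a batch_first module.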

Forward

The AttentionArrayLSTM overrides ArrayLSTM’s forward_cell() method to include an attention mechanism. The API is equivalent to that of ArrayLSTM, but the implementations differ.
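
As a rough illustration of what lane selection by soft attention can look like, the sketch below weights k parallel memory lanes with a softmax over per-lane attention scores. It is a minimal sketch under stated assumptions, not the library's code: the function names lane_weights and attention_cell_step are hypothetical, the gate values are random stand-ins for the outputs of i2h and h2h, the choice to scale the input and output gates is an assumption, and the max_pooling branch reflects one possible reading of that flag.

```python
import torch
import torch.nn.functional as F


def lane_weights(a, max_pooling=False):
    """Turn attention scores of shape (k, batch, hidden_size) into lane weights."""
    if max_pooling:
        # Assumed meaning of max_pooling: keep only the highest-scoring lane
        winners = a.argmax(dim=0)                             # (batch, hidden_size)
        one_hot = F.one_hot(winners, num_classes=a.shape[0])  # (batch, hidden_size, k)
        return one_hot.permute(2, 0, 1).to(a.dtype)           # (k, batch, hidden_size)
    # Soft attention: softmax over the lane dimension
    return torch.softmax(a, dim=0)


def attention_cell_step(i, f, o, g, a, c_prev, max_pooling=False):
    """One hypothetical cell step; every argument has shape (k, batch, hidden_size)."""
    s = lane_weights(a, max_pooling)
    # Assumption: attention scales how strongly each lane is written and read
    c = f * c_prev + (s * i) * g
    h = ((s * o) * torch.tanh(c)).sum(dim=0)  # combine lanes into one hidden state
    return h, c


k, batch, hidden_size = 3, 2, 5
torch.manual_seed(0)
# Random stand-ins for gate activations and attention scores
i, f, o, a = (torch.sigmoid(torch.randn(k, batch, hidden_size)) for _ in range(4))
g = torch.tanh(torch.randn(k, batch, hidden_size))
c_prev = torch.zeros(k, batch, hidden_size)

h, c = attention_cell_step(i, f, o, g, a, c_prev)
print(h.shape, c.shape)  # torch.Size([2, 5]) torch.Size([3, 2, 5])
```

In this sketch the weights sum to 1 across the k lanes for every hidden unit, so the softmax variant blends all lanes while the max pooling variant selects exactly one.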

AttentionArrayLSTM.forward_cell(x, hidden, state)

Perform a single forward pass through the network.

Parameters
  • x (torch.Tensor of shape (batch, input_size)) – Tensor to pass through network

  • hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the hidden state

  • state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the cell state

Returns

  • hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next hidden state

  • state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next cell state
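
A single-step method with this signature is typically applied once per time step of a batch_first sequence. The sketch below shows that unrolling pattern with PyTorch's nn.LSTMCell as a stand-in for the cell step. Whether forward() in this library loops over forward_cell() in exactly this way is an assumption, as are the zero initial states and the (batch, hidden_size) state shapes.

```python
import torch
import torch.nn as nn

# Illustrative sizes (assumptions)
batch, seq_len, input_size, hidden_size = 4, 10, 8, 16

# Stand-in for a forward_cell-style single step
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, seq_len, input_size)  # batch_first input
hidden = torch.zeros(batch, hidden_size)     # initial hidden state
state = torch.zeros(batch, hidden_size)      # initial cell state

outputs = []
for t in range(seq_len):
    # Slice out time step t: shape (batch, input_size)
    hidden, state = cell(x[:, t], (hidden, state))
    outputs.append(hidden)

# Stack along dim=1 to restore the batch_first layout
outputs = torch.stack(outputs, dim=1)
print(outputs.shape)  # torch.Size([4, 10, 16])
```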