AttentionArrayLSTM
The AttentionArrayLSTM implements an ArrayLSTM with the deterministic Array-LSTM extension “Lane selection: Soft attention” from Rocki’s Recurrent Memory Array Structures. This module is built as an extension of the basic ArrayLSTM implementation.
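The soft-attention idea can be summarised as follows. This is a sketch under stated assumptions, not a transcription of Rocki’s equations or of this implementation. Each of the k lanes j = 1, …, k is assumed to be a standard LSTM cell that shares the input x_t and the previous hidden state h_{t-1}. A softmax over k extra pre-activations a_t then weights each lane’s contribution to the new hidden state. All symbols (W, U, b, a_t, s_t) are illustrative.

```latex
\begin{aligned}
f_t^{(j)} &= \sigma\left(W_f^{(j)} x_t + U_f^{(j)} h_{t-1} + b_f^{(j)}\right) \\
i_t^{(j)} &= \sigma\left(W_i^{(j)} x_t + U_i^{(j)} h_{t-1} + b_i^{(j)}\right) \\
o_t^{(j)} &= \sigma\left(W_o^{(j)} x_t + U_o^{(j)} h_{t-1} + b_o^{(j)}\right) \\
\tilde{c}_t^{(j)} &= \tanh\left(W_c^{(j)} x_t + U_c^{(j)} h_{t-1} + b_c^{(j)}\right) \\
c_t^{(j)} &= f_t^{(j)} \odot c_{t-1}^{(j)} + i_t^{(j)} \odot \tilde{c}_t^{(j)} \\
a_t &= W_s x_t + U_s h_{t-1} + b_s \in \mathbb{R}^{k}, \qquad s_t = \operatorname{softmax}(a_t) \\
h_t &= \sum_{j=1}^{k} s_t^{(j)} \, o_t^{(j)} \odot \tanh\left(c_t^{(j)}\right)
\end{aligned}
```

Because the weights s_t^{(j)} are positive and sum to 1, h_t is a convex combination of the per-lane outputs.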
- class extensions.AttentionArrayLSTM(*args: Any, **kwargs: Any)
Implementation of ArrayLSTM with Lane selection: Soft attention
From Recurrent Memory Array Structures by Kamil Rocki
Note
This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).
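If your data is sequence-first, which is the default layout of torch.nn.LSTM (batch_first=False), transpose it before passing it in. Below is a minimal sketch with random data and arbitrary sizes. The per-step slicing only shows the (batch, input_size) tensors that a batch-first loop would hand to forward_cell(). It is not necessarily how this implementation loops over time.

```python
import torch

# Hypothetical data laid out sequence-first: (seq_len, batch, input_size).
seq_len, batch, input_size = 20, 8, 16
x_seq_first = torch.randn(seq_len, batch, input_size)

# AttentionArrayLSTM expects batch-first input: (batch, seq_len, input_size).
x = x_seq_first.transpose(0, 1)

# Splitting along the time dimension yields one (batch, input_size) tensor per step.
steps = torch.unbind(x, dim=1)
```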
- input_size
Size of input dimension
- Type
int
- hidden_size
Size of hidden dimension
- Type
int
- k
Number of parallel memory structures, i.e. cell states to use
- Type
int
- max_pooling
If True, uses max pooling instead of soft attention for lane selection
- Type
boolean, default=False
- i2h
Linear layer transforming input to hidden state
- Type
nn.Linear
- h2h
Linear layer transforming the previous hidden state (hidden-to-hidden connection)
- Type
nn.Linear
Initialization
- AttentionArrayLSTM.__init__(input_size, hidden_size, k, max_pooling=False)
- Parameters
input_size (int) – Size of input dimension
hidden_size (int) – Size of hidden dimension
k (int) – Number of parallel memory structures, i.e. cell states to use
max_pooling (boolean, default=False) – If True, uses max pooling instead of soft attention for lane selection
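The documentation does not specify exactly how max pooling replaces soft attention. One plausible reading, sketched here as an assumption, starts from one attention score per lane. Soft attention then turns the scores into softmax weights, and max pooling puts all the weight on the highest-scoring lane. The helper lane_weights is hypothetical, and the library might instead pool element-wise over the lane outputs.

```python
import torch
import torch.nn.functional as F


def lane_weights(scores: torch.Tensor, max_pooling: bool = False) -> torch.Tensor:
    """Hypothetical helper: map per-lane scores (batch, k) to lane weights (batch, k)."""
    if max_pooling:
        # Assumed max pooling: a one-hot mask on the argmax lane of each batch element.
        return F.one_hot(scores.argmax(dim=1), num_classes=scores.size(1)).to(scores.dtype)
    # Soft attention: positive weights that sum to 1 over the k lanes.
    return F.softmax(scores, dim=1)


scores = torch.tensor([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]])
soft = lane_weights(scores)
hard = lane_weights(scores, max_pooling=True)
```

Soft weights keep every lane differentiable with respect to its score. The one-hot variant routes the output through a single lane per step.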
Forward
The AttentionArrayLSTM overrides ArrayLSTM’s forward_cell() method to include an attention mechanism. The API is therefore identical to that of ArrayLSTM; only the per-step implementation differs.
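The override pattern works as follows: a base class owns the loop over time, and subclasses change only the single-step update. The sketch below is minimal and uses two hypothetical classes, BaseRecurrent and Accumulator, with a toy update rule. The real ArrayLSTM.forward() signature and hidden-state handling may differ.

```python
import torch
from torch import nn


class BaseRecurrent(nn.Module):
    """Hypothetical base class: forward() loops over time, forward_cell() does one step."""

    def forward(self, x, hidden, state):
        # x has shape (batch, seq_len, input_size) because the layout is batch-first.
        outputs = []
        for x_t in torch.unbind(x, dim=1):
            hidden, state = self.forward_cell(x_t, hidden, state)
            outputs.append(hidden)
        return torch.stack(outputs, dim=1), (hidden, state)

    def forward_cell(self, x, hidden, state):
        raise NotImplementedError


class Accumulator(BaseRecurrent):
    """Hypothetical subclass: only the per-step update is overridden."""

    def forward_cell(self, x, hidden, state):
        state = state + x  # toy update rule standing in for a real cell
        hidden = torch.tanh(state)
        return hidden, state


model = Accumulator()
x = torch.ones(2, 5, 3)
out, (h, c) = model(x, torch.zeros(2, 3), torch.zeros(2, 3))
```

Callers use the same forward() regardless of which subclass they instantiate.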
- AttentionArrayLSTM.forward_cell(x, hidden, state)
Perform a single forward pass through the cell, i.e., process one time step.
- Parameters
x (torch.Tensor of shape=(batch, input_size)) – Tensor to pass through network
hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the hidden state
state (torch.Tensor of shape (batch, input_size)) – Tensor containing the cell state
- Returns
hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next hidden state
state (torch.Tensor of shape (batch, input_size)) – Tensor containing the next cell state
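Below is a minimal, self-contained sketch of a single soft-attention Array-LSTM step. It is not this library’s code. The class SoftAttentionArrayCell is hypothetical, and it reuses the documented attribute names i2h and h2h, but the following choices are all assumptions:

- i2h and h2h produce 4 * k * hidden_size gate pre-activations plus k attention scores.
- The gate order is forget, input, output, candidate.
- state carries an explicit lane dimension, (batch, k, hidden_size).
- Max pooling is read as hard selection of the best-scoring lane.

```python
import torch
from torch import nn
import torch.nn.functional as F


class SoftAttentionArrayCell(nn.Module):
    """Hypothetical Array-LSTM step with soft-attention lane selection.

    Assumed shapes: x (batch, input_size), hidden (batch, hidden_size),
    state (batch, k, hidden_size) holding one cell state per lane.
    """

    def __init__(self, input_size: int, hidden_size: int, k: int, max_pooling: bool = False):
        super().__init__()
        self.hidden_size, self.k, self.max_pooling = hidden_size, k, max_pooling
        # Assumed layout: 4 gate blocks per lane, plus one attention score per lane.
        out_features = 4 * k * hidden_size + k
        self.i2h = nn.Linear(input_size, out_features)
        self.h2h = nn.Linear(hidden_size, out_features)

    def forward_cell(self, x, hidden, state):
        pre = self.i2h(x) + self.h2h(hidden)
        gates, scores = pre[:, :-self.k], pre[:, -self.k:]
        # (batch, 4, k, hidden_size): gate type, lane, unit.
        gates = gates.reshape(-1, 4, self.k, self.hidden_size)
        f = torch.sigmoid(gates[:, 0])
        i = torch.sigmoid(gates[:, 1])
        o = torch.sigmoid(gates[:, 2])
        g = torch.tanh(gates[:, 3])

        # Standard LSTM cell update, applied to every lane independently.
        state = f * state + i * g

        if self.max_pooling:
            # Assumed max pooling: all weight on the highest-scoring lane.
            weights = F.one_hot(scores.argmax(dim=1), self.k).to(x.dtype)
        else:
            # Soft attention: softmax over the k lanes.
            weights = F.softmax(scores, dim=1)

        # Weight each lane's output by its attention weight and sum over lanes.
        hidden = (weights.unsqueeze(-1) * o * torch.tanh(state)).sum(dim=1)
        return hidden, state


cell = SoftAttentionArrayCell(input_size=10, hidden_size=6, k=3)
x = torch.randn(4, 10)
h0, c0 = torch.zeros(4, 6), torch.zeros(4, 3, 6)
h1, c1 = cell.forward_cell(x, h0, c0)
```

A batch-first forward() would call forward_cell() once per time step. Since each lane output o ⊙ tanh(c) lies in (-1, 1) and the weights sum to 1, the soft-attention hidden state stays in (-1, 1).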