ArrayLSTM
The ArrayLSTM implements the basic ArrayLSTM of Rocki's Recurrent Memory Array Structures. The module is built as an extension of the normal LSTM implementation.
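For reference, here are the Array-LSTM update equations from Rocki's paper as we understand them. The index j runs over the k parallel memory cells, and this is a summary rather than a transcription of the module's code. Each cell has its own gates and cell state. The cells are combined only when the hidden state is formed:

```latex
\begin{aligned}
i_t^{(j)} &= \sigma\!\left(W_i^{(j)} x_t + U_i^{(j)} h_{t-1} + b_i^{(j)}\right) \\
f_t^{(j)} &= \sigma\!\left(W_f^{(j)} x_t + U_f^{(j)} h_{t-1} + b_f^{(j)}\right) \\
o_t^{(j)} &= \sigma\!\left(W_o^{(j)} x_t + U_o^{(j)} h_{t-1} + b_o^{(j)}\right) \\
\tilde{c}_t^{(j)} &= \tanh\!\left(W_c^{(j)} x_t + U_c^{(j)} h_{t-1} + b_c^{(j)}\right) \\
c_t^{(j)} &= f_t^{(j)} \odot c_{t-1}^{(j)} + i_t^{(j)} \odot \tilde{c}_t^{(j)}, \qquad j = 1, \dots, k \\
h_t &= \sum_{j=1}^{k} o_t^{(j)} \odot \tanh\!\left(c_t^{(j)}\right)
\end{aligned}
```

With k = 1, these equations reduce to a standard LSTM cell.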
- class arraylstm.ArrayLSTM(*args: Any, **kwargs: Any)
Implementation of ArrayLSTM
From Recurrent Memory Array Structures by Kamil Rocki
Note
This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).
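The following minimal NumPy sketch shows what the batch_first=True layout means in practice. The sizes are hypothetical. A recurrent forward pass visits one time step at a time, and with this layout each step is the slice x[:, t, :], which has shape (batch, input_size), the shape forward_cell() takes.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
batch, seq_len, input_size = 2, 5, 3

# batch_first=True: the batch dimension comes first.
X = np.random.default_rng(0).standard_normal((batch, seq_len, input_size))

# One slice per time step; each keeps every sequence in the batch.
steps = [X[:, t, :] for t in range(seq_len)]

# The equivalent sequence-first layout (as used when batch_first=False).
X_seq_first = np.transpose(X, (1, 0, 2))
```

Passing a sequence-first tensor to a batch-first module silently swaps the meaning of the first two axes. Check the layout before calling forward().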
- input_size
Size of input dimension
- Type
int
- hidden_size
Size of hidden dimension
- Type
int
- k
Number of parallel memory structures, i.e., the number of cell states to use
- Type
int
- i2h
Linear layer transforming input to hidden state
- Type
nn.Linear
- h2h
Linear layer transforming the previous hidden state to the next hidden state
- Type
nn.Linear
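The output sizes of i2h and h2h are not documented here. The sketch below assumes that each layer emits the four gate pre-activations (input, forget, output, candidate) for every one of the k memory cells, which gives 4 * k * hidden_size outputs. That layout is an assumption. The function names are hypothetical and are used only to count parameters under it.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Parameters of a fully connected layer with bias (weight + bias)."""
    return in_features * out_features + out_features


def arraylstm_params(input_size: int, hidden_size: int, k: int) -> int:
    """Parameter count of i2h + h2h, ASSUMING each layer produces
    4 gate pre-activations for each of the k memory cells."""
    out = 4 * k * hidden_size
    i2h = linear_params(input_size, out)   # input  -> gates
    h2h = linear_params(hidden_size, out)  # hidden -> gates
    return i2h + h2h


# With k = 1 this matches a single-layer torch.nn.LSTM, which also keeps
# two bias vectors (b_ih and b_hh) of size 4 * hidden_size.
print(arraylstm_params(10, 20, 1))  # 2560
print(arraylstm_params(10, 20, 4))  # 10240
```

Under this assumption, the parameter count grows linearly with k, since every memory cell has its own set of gates.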
Initialization
- ArrayLSTM.__init__(input_size, hidden_size, k)
- Parameters
input_size (int) – Size of input dimension
hidden_size (int) – Size of hidden dimension
k (int) – Number of parallel memory structures, i.e., the number of cell states to use
Forward
A single ArrayLSTM cell is implemented by the forward_cell() method.
This method overrides the corresponding method of its LSTM superclass.
- ArrayLSTM.forward_cell(x, hidden, state)
Perform a single forward step (one time step) through the ArrayLSTM cell.
- Parameters
x (torch.Tensor of shape=(batch, input_size)) – Tensor to pass through network
hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the hidden state
state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the cell state
- Returns
hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next hidden state
state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next cell state
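The following minimal NumPy sketch covers one ArrayLSTM time step. It follows Rocki's Array-LSTM equations and is not the module's torch code. The function name array_lstm_step, the packed weights W, U, b, and their [cell][gate][unit] ordering are hypothetical. The k cell states are kept as one array of shape (k, batch, hidden_size). The sketch returns the hidden state as (batch, hidden_size), without the leading axis of size 1 that the documented hidden-state update uses.

```python
import numpy as np


def sigmoid(z: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-z))


def array_lstm_step(x, h, c, W, U, b, k):
    """One ArrayLSTM time step (hypothetical NumPy sketch).

    x: (batch, input_size)       current input
    h: (batch, hidden_size)      previous hidden state
    c: (k, batch, hidden_size)   previous cell states, one per memory cell
    W: (4*k*hidden_size, input_size), U: (4*k*hidden_size, hidden_size),
    b: (4*k*hidden_size,)        ASSUMED packed parameter layout
    """
    batch, hidden_size = h.shape
    # Gate pre-activations for all k cells at once (like i2h(x) + h2h(h)).
    z = x @ W.T + h @ U.T + b                       # (batch, 4*k*hidden_size)
    # ASSUMED ordering [cell][gate][unit]; move the cell axis to the front.
    z = z.reshape(batch, k, 4, hidden_size).transpose(1, 0, 2, 3)
    i = sigmoid(z[:, :, 0])                         # input gates  (k, batch, H)
    f = sigmoid(z[:, :, 1])                         # forget gates
    o = sigmoid(z[:, :, 2])                         # output gates
    g = np.tanh(z[:, :, 3])                         # candidate values
    c_next = f * c + i * g                          # each cell updated independently
    # Hidden state: sum over the k cells of o * tanh(c).
    h_next = (o * np.tanh(c_next)).sum(axis=0)      # (batch, H)
    return h_next, c_next


rng = np.random.default_rng(0)
batch, input_size, hidden_size, k = 2, 3, 4, 5
W = rng.standard_normal((4 * k * hidden_size, input_size))
U = rng.standard_normal((4 * k * hidden_size, hidden_size))
b = np.zeros(4 * k * hidden_size)
x = rng.standard_normal((batch, input_size))
h = np.zeros((batch, hidden_size))
c = np.zeros((k, batch, hidden_size))
h, c = array_lstm_step(x, h, c, W, U, b, k)
print(h.shape, c.shape)  # (2, 4) (5, 2, 4)
```

The memory cells share their input and previous hidden state, but their gates and cell states never mix. The only interaction between cells is the final sum.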
As variations of the ArrayLSTM update their hidden state differently, we also add a separate method that computes the hidden state from the output gates and cell states.
This method can be overridden by subclasses to update the hidden state in different ways.
Default hidden state: combines the output gates with the cell states and sums the result over all k memory cells.
- Parameters
outputs (torch.Tensor of shape=(k, batch_size, hidden_size)) – Tensor containing the result of output gates o
states (torch.Tensor of shape=(k, batch_size, hidden_size)) – Tensor containing the cell states
- Returns
hidden – Hidden tensor as computed from outputs and states
- Return type
torch.Tensor of shape=(1, batch_size, hidden_size)
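The following NumPy sketch shows the override pattern described above. The class names and the method name update_hidden are hypothetical, since the documentation above does not show the method's name. The default follows Rocki's sum over cells of o * tanh(c) and keeps the documented (1, batch_size, hidden_size) return shape. The max-over-cells subclass is an invented variation, used only to show that a subclass needs to replace this one method.

```python
import numpy as np


class ArrayCellSketch:
    """Hypothetical stand-in for the ArrayLSTM hidden-state update hook."""

    def update_hidden(self, outputs: np.ndarray, states: np.ndarray) -> np.ndarray:
        # outputs, states: (k, batch_size, hidden_size)
        # Default: h = sum over k of o_k * tanh(c_k), keeping a leading
        # axis of size 1 -> (1, batch_size, hidden_size).
        return (outputs * np.tanh(states)).sum(axis=0, keepdims=True)


class MaxOverCellsSketch(ArrayCellSketch):
    """Hypothetical variation: element-wise maximum over the k cells instead
    of the sum. Only the hidden-state update is replaced."""

    def update_hidden(self, outputs: np.ndarray, states: np.ndarray) -> np.ndarray:
        return (outputs * np.tanh(states)).max(axis=0, keepdims=True)


k, batch_size, hidden_size = 3, 2, 4
outputs = np.full((k, batch_size, hidden_size), 0.5)
states = np.ones((k, batch_size, hidden_size))
h_sum = ArrayCellSketch().update_hidden(outputs, states)
h_max = MaxOverCellsSketch().update_hidden(outputs, states)
```

With identical cells, as in this usage example, the default sum is k times larger than the max. Summing lets the magnitude of the hidden state grow with k, and pooling keeps it bounded.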