StochasticArrayLSTM
The StochasticArrayLSTM implements an ArrayLSTM with the non-deterministic “Stochastic Output Pooling” extension from Rocki’s Recurrent Memory Array Structures. The module is built as an extension of the basic ArrayLSTM implementation.
- class extensions.StochasticArrayLSTM(*args: Any, **kwargs: Any)[source]
Implementation of ArrayLSTM with Stochastic Output Pooling
From Recurrent Memory Array Structures by Kamil Rocki
Note
This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).
- input_size
Size of input dimension
- Type
int
- hidden_size
Size of hidden dimension
- Type
int
- k
Number of parallel memory structures, i.e. cell states to use
- Type
int
- i2h
Linear layer transforming input to hidden state
- Type
nn.Linear
- h2h
Linear layer transforming hidden state to hidden state
- Type
nn.Linear
Initialization
- StochasticArrayLSTM.__init__(*args: Any, **kwargs: Any) → None
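As a rough sketch of the batch_first=True calling convention described in the note above, the snippet below uses torch.nn.LSTM(batch_first=True) as a stand-in, since its input and output shapes follow the same convention; the constructor arguments shown in the comment (input_size, hidden_size, k) are assumptions based on the attribute list, not a confirmed signature.

```python
import torch
import torch.nn as nn

# Stand-in module with the same batch_first=True convention; the real
# StochasticArrayLSTM would be constructed analogously, e.g. (hypothetical):
#   StochasticArrayLSTM(input_size=10, hidden_size=20, k=4)
lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)

batch, seq_len, input_size = 32, 5, 10
# Inputs must be of shape=(batch, seq_len, input_size)
x = torch.randn(batch, seq_len, input_size)
output, (h_n, c_n) = lstm(x)
# output: (batch, seq_len, hidden_size); h_n: (1, batch, hidden_size)
```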
Forward
The StochasticArrayLSTM overrides ArrayLSTM’s update_hidden()
method to update the hidden state using stochastic output pooling.
The API is equivalent to that of ArrayLSTM, but the implementations differ.
Update hidden state based on most likely output
- Parameters
outputs (torch.Tensor of shape=(k, batch_size, hidden_size)) – Tensor containing the result of output gates o
states (torch.Tensor of shape=(k, batch_size, hidden_size)) – Tensor containing the cell states
- Returns
hidden – Hidden tensor as computed from outputs and states
- Return type
torch.Tensor of shape=(1, batch_size, hidden_size)
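As a rough illustration of the pooling step described above, the sketch below applies a softmax over the k parallel memory structures, selects the most likely output gate per hidden unit, and computes the hidden state from the matching cell state. The standalone function name, the per-element argmax selection, and the h = o * tanh(c) combination are assumptions consistent with the docstring and shapes above, not the library's actual implementation; a sampling variant could use torch.multinomial instead of argmax.

```python
import torch

def update_hidden(outputs: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
    """Sketch of output pooling over k parallel memory structures.

    outputs, states: shape (k, batch_size, hidden_size).
    Returns hidden of shape (1, batch_size, hidden_size).
    """
    # Softmax over the k dimension turns the output gates into a
    # distribution over the parallel memory structures
    probabilities = torch.softmax(outputs, dim=0)
    # Pick the most likely structure per (batch, hidden) element
    indices = probabilities.argmax(dim=0, keepdim=True)  # (1, batch, hidden)
    # Gather the matching output gate and cell state
    o = outputs.gather(0, indices)
    c = states.gather(0, indices)
    # Standard LSTM output equation: h = o * tanh(c)
    return o * torch.tanh(c)
```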