Source code for extensions.stochastic

# Import PyTorch library
import torch
import torch.nn.functional as F
# Import standard arrayLSTM implementation
from arrayLSTM import ArrayLSTM

class StochasticArrayLSTM(ArrayLSTM):
    """Implementation of ArrayLSTM with Stochastic Output Pooling

        From `Recurrent Memory Array Structures`_ by Kamil Rocki

        .. _`Recurrent Memory Array Structures`: https://arxiv.org/abs/1607.03085

        Note
        ----
        This is a `batch_first=True` implementation, hence the `forward()`
        method expects inputs of `shape=(batch, seq_len, input_size)`.

        Attributes
        ----------
        input_size : int
            Size of input dimension

        hidden_size : int
            Size of hidden dimension

        k : int
            Number of parallel memory structures, i.e. cell states to use

        i2h : nn.Linear
            Linear layer transforming input to hidden state

        h2h : nn.Linear
            Linear layer updating hidden state to hidden state
        """
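    # Output pooling idea (explanatory note, not in the original source):
    # each of the k memory structures produces its own output gate o^k and
    # cell state c^k. Rather than combining all k contributions, one
    # structure is selected per hidden unit and the hidden state becomes
    #
    #   p   = softmax(o^1, ..., o^k)    (softmax over the k dimension)
    #   k*  = argmax_k p                (most likely output gate)
    #   h_t = o^{k*} * tanh(c^{k*})
    #
    # `update_hidden` below implements this selection with a boolean mask
    # per memory structure.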
    def update_hidden(self, outputs, states):
        """Update hidden state based on most likely output

            Parameters
            ----------
            outputs : torch.Tensor of shape=(k, batch_size, hidden_size)
                Tensor containing the result of output gates o

            states : torch.Tensor of shape=(k, batch_size, hidden_size)
                Tensor containing the cell states

            Returns
            -------
            hidden : torch.Tensor of shape=(1, batch_size, hidden_size)
                Hidden tensor as computed from outputs and states
            """
        # Compute softmax over the k output gates
        probability = F.softmax(outputs, dim=0)
        # Get index of the most likely output gate per unit
        probability = probability.max(dim=0).indices

        # Initialise output and state
        output = torch.zeros(probability.shape, device=states.device)
        state  = torch.zeros(probability.shape, device=states.device)

        # Fill output and state from the selected memory structure
        for k in range(self.k):
            # Create mask of units for which structure k is most likely
            mask = probability == k
            # Fill output and state
            output[mask] = outputs[k][mask]
            state [mask] = states [k][mask]

        # Compute hidden state from selected output gates and cell states
        hidden = torch.mul(output, state.tanh())
        # Reshape hidden to shape=(1, batch_size, hidden_size)
        hidden = hidden.unsqueeze(0)

        # Return hidden
        return hidden
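# Minimal usage sketch (not part of the original module). The constructor
# arguments are assumed to follow the ArrayLSTM base class
# (input_size, hidden_size, k), and forward() is assumed to return
# (output, (hidden, state)) like torch.nn.LSTM; check the arrayLSTM
# documentation before relying on these signatures.
if __name__ == "__main__":
    # 4 parallel memory structures over a 32-dimensional hidden state
    lstm = StochasticArrayLSTM(input_size=10, hidden_size=32, k=4)
    # batch_first=True: input of shape=(batch, seq_len, input_size)
    X = torch.rand(2, 5, 10)
    output, (hidden, state) = lstm(X)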