LSTM

As a basis, we provide a pure PyTorch implementation of the LSTM module. It extends the regular torch.nn.Module interface.

class lstm.LSTM(input_size, hidden_size)

LSTM implementation in PyTorch.

Note

This is a batch_first=True implementation, so the forward() method expects inputs of shape=(batch, seq_len, input_size).
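If your data is stored sequence-first, which is the default layout of torch.nn.LSTM, a single transpose converts it to the layout this implementation expects. A minimal sketch; the sizes are arbitrary illustrative values:

```python
import torch

# Hypothetical sizes, chosen only for illustration.
seq_len, batch, input_size = 7, 4, 3

# Data stored sequence-first: (seq_len, batch, input_size).
x_seq_first = torch.randn(seq_len, batch, input_size)

# Swap the first two dimensions to obtain the (batch, seq_len, input_size)
# layout expected by a batch_first=True implementation.
x_batch_first = x_seq_first.transpose(0, 1)

print(x_batch_first.shape)  # torch.Size([4, 7, 3])
```

transpose() returns a non-contiguous view; call .contiguous() on the result if a later operation requires contiguous memory.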

input_size

Size of input dimension

Type

int

hidden_size

Size of hidden dimension

Type

int

i2h

Linear layer transforming input to hidden state

Type

nn.Linear

h2h

Linear layer transforming the hidden state to the hidden state

Type

nn.Linear
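For illustration, the attributes above might be set up as in the sketch below. This is a hypothetical stand-in, not the module's actual source: the class name LSTMSketch is invented, and giving both layers 4 * hidden_size output features (one block per gate) is an assumption based on the usual way a fused LSTM layer is built.

```python
import torch.nn as nn


class LSTMSketch(nn.Module):
    """Hypothetical stand-in for lstm.LSTM, showing only its attributes."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Assumption: each layer produces the pre-activations of all four
        # gates (input, forget, cell candidate, output) in one matrix product.
        self.i2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)


model = LSTMSketch(input_size=10, hidden_size=32)
# nn.Linear stores its weight as (out_features, in_features).
print(model.i2h.weight.shape)  # torch.Size([128, 10])
print(model.h2h.weight.shape)  # torch.Size([128, 32])
```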

Initialization

LSTM.__init__(input_size, hidden_size)

Parameters
  • input_size (int) – Size of input dimension

  • hidden_size (int) – Size of hidden dimension

Forward

Like all nn.Module objects, the LSTM implements a forward() method. This method passes every sequence in x through the forward_cell() method, one time step at a time.

LSTM.forward(x, hidden=None)

Forward all sequences through the network.

Parameters
  • x (torch.Tensor of shape=(batch, seq_len, input_size)) – Tensor to pass through network

  • hidden (tuple) –

    Tuple consisting of (hidden, state) to use as the initial vectors. If None is given, both the hidden and cell state vectors are initialized to zero.

    hidden torch.Tensor of shape (batch, hidden_size), default=0 vector

    Tensor containing the hidden state

    state torch.Tensor of shape (batch, hidden_size), default=0 vector

    Tensor containing the cell state

Returns

  • outputs (torch.Tensor of shape=(batch, seq_len, hidden_size)) – Outputs for each input of sequence

  • hidden (tuple) – Tuple consisting of (hidden, state) after the final time step.

    hidden torch.Tensor of shape (batch, hidden_size)

    Tensor containing the hidden state

    state torch.Tensor of shape (batch, hidden_size)

    Tensor containing the cell state
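A rough sketch of how forward() can loop over the sequence and delegate each time step to forward_cell(). The class LoopLSTM is hypothetical, and it uses torch.nn.LSTMCell as a stand-in cell; the documented module computes the cell with its own i2h and h2h layers instead.

```python
import torch
import torch.nn as nn


class LoopLSTM(nn.Module):
    """Hypothetical sketch of forward() delegating to forward_cell()."""

    def __init__(self, input_size: int, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Stand-in cell (assumption); lstm.LSTM uses i2h/h2h instead.
        self.cell = nn.LSTMCell(input_size, hidden_size)

    def forward_cell(self, x, hidden, state):
        # One time step: x is (batch, input_size).
        return self.cell(x, (hidden, state))

    def forward(self, x, hidden=None):
        batch, seq_len, _ = x.shape  # batch_first layout
        if hidden is None:
            # Default: zero hidden and cell states of shape (batch, hidden_size).
            zeros = x.new_zeros((batch, self.hidden_size))
            hidden = (zeros, zeros.clone())
        h, c = hidden
        outputs = []
        for t in range(seq_len):
            h, c = self.forward_cell(x[:, t], h, c)
            outputs.append(h)
        # Stack along dim=1 to obtain (batch, seq_len, hidden_size).
        return torch.stack(outputs, dim=1), (h, c)


model = LoopLSTM(input_size=3, hidden_size=5)
outputs, (h, c) = model(torch.randn(2, 7, 3))
print(outputs.shape)  # torch.Size([2, 7, 5])
print(h.shape, c.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
```

Because forward() only calls self.forward_cell(), a subclass can change the per-step computation by overriding that single method.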

A single LSTM cell is implemented by the forward_cell() method. Note that subclasses also override this method to implement their own custom forward computations.
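For reference, the standard LSTM cell equations are given below, writing hidden as h and state as c, with σ the logistic sigmoid and ⊙ element-wise multiplication. Splitting the summed i2h and h2h outputs into four gate pre-activations is an assumption about how this implementation uses those layers; the documentation does not state it.

```latex
\begin{aligned}
z_t &= \mathrm{i2h}(x_t) + \mathrm{h2h}(h_{t-1}) = \left[\, z^{i}_t;\ z^{f}_t;\ z^{g}_t;\ z^{o}_t \,\right] \\
i_t &= \sigma(z^{i}_t), \quad f_t = \sigma(z^{f}_t), \quad g_t = \tanh(z^{g}_t), \quad o_t = \sigma(z^{o}_t) \\
c_t &= f_t \odot c_{t-1} + i_t \odot g_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```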

LSTM.forward_cell(x, hidden, state)

Perform a single forward pass through the LSTM cell for one time step.

Parameters
  • x (torch.Tensor of shape=(batch, input_size)) – Tensor to pass through network

  • hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the hidden state

  • state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the cell state

Returns

  • hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next hidden state

  • state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next cell state
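A minimal sketch of a single step built from i2h and h2h, assuming both layers output 4 * hidden_size features in the gate order (input, forget, cell candidate, output). The free function forward_cell below is hypothetical, not the module's method; it is cross-checked against torch.nn.LSTMCell, which uses that same gate order, by copying the weights across.

```python
import torch
import torch.nn as nn


def forward_cell(i2h: nn.Linear, h2h: nn.Linear, x, hidden, state):
    """Hypothetical single LSTM step built from i2h and h2h.

    x: (batch, input_size); hidden, state: (batch, hidden_size).
    Assumes i2h/h2h output 4 * hidden_size features in the gate order
    (input, forget, cell candidate, output).
    """
    gates = i2h(x) + h2h(hidden)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)
    state = f * state + i * g       # next cell state
    hidden = o * torch.tanh(state)  # next hidden state
    return hidden, state


torch.manual_seed(0)
input_size, hidden_size, batch = 3, 5, 2
i2h = nn.Linear(input_size, 4 * hidden_size)
h2h = nn.Linear(hidden_size, 4 * hidden_size)

# Reference cell with identical weights (LSTMCell gate order: i, f, g, o).
ref = nn.LSTMCell(input_size, hidden_size)
with torch.no_grad():
    ref.weight_ih.copy_(i2h.weight)
    ref.bias_ih.copy_(i2h.bias)
    ref.weight_hh.copy_(h2h.weight)
    ref.bias_hh.copy_(h2h.bias)

x = torch.randn(batch, input_size)
h0 = torch.randn(batch, hidden_size)
c0 = torch.randn(batch, hidden_size)
h1, c1 = forward_cell(i2h, h2h, x, h0, c0)
h_ref, c_ref = ref(x, (h0, c0))
print(torch.allclose(h1, h_ref, atol=1e-5), torch.allclose(c1, c_ref, atol=1e-5))  # True True
```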

Hidden state

The LSTM provides a method for initializing the hidden state and cell state. Note that subclasses also override this method to implement their own custom initializations.

LSTM.initHidden(x)

Initialize the hidden state and cell state.
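A minimal sketch of what initHidden(x) might return, consistent with the zero default described for forward(). The function name init_hidden and its explicit hidden_size argument are hypothetical (a method would read self.hidden_size), and taking the batch size from x.shape[0] assumes a batch-first input.

```python
import torch


def init_hidden(x: torch.Tensor, hidden_size: int):
    """Hypothetical initHidden: zero (hidden, state) for a batch-first input.

    Assumes x has shape (batch, ...) and that the zero tensors should
    match x's dtype and device.
    """
    batch = x.shape[0]
    hidden = torch.zeros(batch, hidden_size, dtype=x.dtype, device=x.device)
    # Separate allocation, so in-place updates to one do not affect the other.
    state = torch.zeros_like(hidden)
    return hidden, state


hidden, state = init_hidden(torch.randn(4, 7, 3), hidden_size=5)
print(hidden.shape, state.shape)  # torch.Size([4, 5]) torch.Size([4, 5])
```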