LSTM
As a basis, we provide a pure PyTorch implementation of the LSTM module. It extends the regular torch.nn.Module interface.
- class lstm.LSTM(*args: Any, **kwargs: Any)
LSTM implementation in PyTorch.
Note
This is a batch_first=True implementation, hence the forward() method expects inputs of shape=(batch, seq_len, input_size).
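The batch-first convention can be illustrated with plain tensors. The snippet below only uses core torch calls and arbitrary sizes. Since torch.nn.LSTM defaults to batch_first=False, data stored time-major has to be transposed before it is passed to this implementation.

```python
import torch

batch, seq_len, input_size = 4, 7, 3

# Batch-first layout expected by LSTM.forward(): (batch, seq_len, input_size).
x = torch.randn(batch, seq_len, input_size)

# Time-major layout, torch.nn.LSTM's default: (seq_len, batch, input_size).
x_time_major = torch.randn(seq_len, batch, input_size)

# Swapping the first two dimensions converts it to batch-first.
x_batch_first = x_time_major.transpose(0, 1)

# Slicing one time step gives a per-step input of shape (batch, input_size),
# the shape forward_cell() takes as its x argument.
x_t = x_batch_first[:, 0]

print(x.shape)              # torch.Size([4, 7, 3])
print(x_batch_first.shape)  # torch.Size([4, 7, 3])
print(x_t.shape)            # torch.Size([4, 3])
```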
- input_size
Size of input dimension
- Type
int
- hidden_size
Size of hidden dimension
- Type
int
- i2h
Linear layer transforming input to hidden state
- Type
nn.Linear
- h2h
Linear layer transforming hidden state to hidden state
- Type
nn.Linear
Initialization
- LSTM.__init__(input_size, hidden_size)
- Parameters
input_size (int) – Size of input dimension
hidden_size (int) – Size of hidden dimension
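A minimal sketch of what __init__ might look like, assuming (as torch.nn.LSTMCell does) that i2h and h2h each produce the pre-activations of all four gates in one matrix multiplication, so both output 4 * hidden_size features. The class name LSTMSketch is hypothetical, and the actual layer sizes in lstm.LSTM may differ.

```python
import torch.nn as nn


class LSTMSketch(nn.Module):
    """Hypothetical stand-in for lstm.LSTM showing a plausible __init__."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Assumption: one output slice per gate (input, forget, cell, output),
        # so each linear layer produces 4 * hidden_size features.
        self.i2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)


model = LSTMSketch(input_size=10, hidden_size=20)
n_params = sum(p.numel() for p in model.parameters())
# i2h: 80 * 10 + 80 parameters, h2h: 80 * 20 + 80 parameters
print(n_params)  # 2560
```

Under this assumed layout the parameter count equals that of a single-layer torch.nn.LSTM with the same sizes, which is a quick sanity check.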
Forward
Like all nn.Module objects, LSTM implements a forward() method. This method passes every sequence in x through the forward_cell() method, one time step at a time.
- LSTM.forward(x, hidden=None)
Forward all sequences through the network.
- Parameters
x (torch.Tensor of shape=(batch, seq_len, input_size)) – Tensor to pass through network
hidden (tuple) –
Tuple (hidden, state) used as the initial hidden and cell states. If None is given, both are initialised as zero vectors.
- hidden torch.Tensor of shape (batch, hidden_size), default=0 vector
Tensor containing the hidden state
- state torch.Tensor of shape (batch, hidden_size), default=0 vector
Tensor containing the cell state
- Returns
outputs (torch.Tensor of shape=(batch, seq_len, hidden_size)) – Outputs for each input of sequence
hidden (tuple) – Tuple (hidden, state) after the final time step.
- hidden torch.Tensor of shape (batch, hidden_size)
Tensor containing the hidden state
- state torch.Tensor of shape (batch, hidden_size)
Tensor containing the cell state
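The forward() contract described above (zero default state, one forward_cell() call per time step, outputs stacked along the sequence dimension) can be sketched as follows. LSTMSketch is hypothetical, and its forward_cell() uses the standard LSTM gate equations with the gate order (input, forget, cell, output) borrowed from torch.nn.LSTM; the real lstm.LSTM may compute its cell differently.

```python
from typing import Optional, Tuple

import torch
import torch.nn as nn


class LSTMSketch(nn.Module):
    """Hypothetical stand-in for lstm.LSTM; the gate layout is an assumption."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)

    def forward_cell(self, x, hidden, state):
        # Standard LSTM cell, gate slices ordered (input, forget, cell, output).
        i, f, g, o = (self.i2h(x) + self.h2h(hidden)).chunk(4, dim=1)
        state = torch.sigmoid(f) * state + torch.sigmoid(i) * torch.tanh(g)
        hidden = torch.sigmoid(o) * torch.tanh(state)
        return hidden, state

    def forward(
        self,
        x: torch.Tensor,
        hidden: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ):
        batch, seq_len, _ = x.shape
        if hidden is None:
            # Default: zero hidden and cell states of shape (batch, hidden_size).
            zeros = x.new_zeros((batch, self.hidden_size))
            hidden = (zeros, zeros)
        h, c = hidden
        outputs = []
        for t in range(seq_len):
            # Feed time step t of every sequence in the batch through the cell.
            h, c = self.forward_cell(x[:, t], h, c)
            outputs.append(h)
        # Stack along dim 1 to keep the batch-first layout.
        return torch.stack(outputs, dim=1), (h, c)


model = LSTMSketch(input_size=3, hidden_size=5)
x = torch.randn(2, 7, 3)
outputs, (h, c) = model(x)
print(outputs.shape)     # torch.Size([2, 7, 5])
print(h.shape, c.shape)  # torch.Size([2, 5]) torch.Size([2, 5])
```

Because the assumed gate order matches torch.nn.LSTM, copying its weight_ih_l0, weight_hh_l0, bias_ih_l0 and bias_hh_l0 into i2h and h2h should reproduce the built-in module's outputs up to floating-point error, which makes a convenient test.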
A single step of the LSTM cell is implemented by the forward_cell() method. Subclasses override this method to implement their custom forward computation.
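To illustrate the subclassing pattern, the sketch below overrides only forward_cell() and inherits forward() unchanged. Both classes are hypothetical: the base class uses an assumed standard cell, and the subclass implements a coupled input-forget gate (CIFG) variant, in which the input gate is replaced by 1 - f. For brevity the subclass keeps the parent's layers and ignores the input-gate slice.

```python
import torch
import torch.nn as nn


class LSTMSketch(nn.Module):
    """Hypothetical base class: standard cell plus a generic forward()."""

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.i2h = nn.Linear(input_size, 4 * hidden_size)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size)

    def forward_cell(self, x, hidden, state):
        i, f, g, o = (self.i2h(x) + self.h2h(hidden)).chunk(4, dim=1)
        state = torch.sigmoid(f) * state + torch.sigmoid(i) * torch.tanh(g)
        return torch.sigmoid(o) * torch.tanh(state), state

    def forward(self, x, hidden=None):
        if hidden is None:
            zeros = x.new_zeros((x.shape[0], self.hidden_size))
            hidden = (zeros, zeros)
        h, c = hidden
        outputs = []
        for t in range(x.shape[1]):
            # forward() only ever calls forward_cell(), so overriding the
            # cell is enough to change the whole recurrence.
            h, c = self.forward_cell(x[:, t], h, c)
            outputs.append(h)
        return torch.stack(outputs, dim=1), (h, c)


class CIFGLSTMSketch(LSTMSketch):
    """Hypothetical subclass: coupled input-forget gate variant."""

    def forward_cell(self, x, hidden, state):
        # The input-gate slice is computed but ignored in this variant.
        _, f, g, o = (self.i2h(x) + self.h2h(hidden)).chunk(4, dim=1)
        f = torch.sigmoid(f)
        # Coupled gates: the input gate is 1 - f.
        state = f * state + (1 - f) * torch.tanh(g)
        return torch.sigmoid(o) * torch.tanh(state), state


model = CIFGLSTMSketch(input_size=3, hidden_size=5)
outputs, (h, c) = model(torch.randn(2, 6, 3))
print(outputs.shape)  # torch.Size([2, 6, 5])
```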
- LSTM.forward_cell(x, hidden, state)
Perform a single forward pass through the LSTM cell, i.e. one time step.
- Parameters
x (torch.Tensor of shape=(batch, input_size)) – Tensor to pass through network
hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the hidden state
state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the cell state
- Returns
hidden (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next hidden state
state (torch.Tensor of shape (batch, hidden_size)) – Tensor containing the next cell state
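A single step with the forward_cell() signature can be sketched with the standard LSTM equations: the sum i2h(x) + h2h(hidden) is split into four slices i, f, g, o; the next cell state is sigmoid(f) * state + sigmoid(i) * tanh(g), and the next hidden state is sigmoid(o) * tanh(next state). The free function lstm_cell_step and the (input, forget, cell, output) slice order are assumptions borrowed from torch.nn.LSTMCell, not necessarily what lstm.LSTM.forward_cell() does internally.

```python
from typing import Tuple

import torch
import torch.nn as nn


def lstm_cell_step(
    x: torch.Tensor,       # (batch, input_size)
    hidden: torch.Tensor,  # (batch, hidden_size)
    state: torch.Tensor,   # (batch, hidden_size)
    i2h: nn.Linear,        # input_size -> 4 * hidden_size
    h2h: nn.Linear,        # hidden_size -> 4 * hidden_size
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical single LSTM step; returns (next_hidden, next_state)."""
    # Pre-activations of all four gates, shape (batch, 4 * hidden_size).
    gates = i2h(x) + h2h(hidden)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)                    # candidate cell values
    next_state = f * state + i * g       # forget part of the old state, write new
    next_hidden = o * torch.tanh(next_state)
    return next_hidden, next_state


input_size, hidden_size, batch = 3, 4, 2
i2h = nn.Linear(input_size, 4 * hidden_size)
h2h = nn.Linear(hidden_size, 4 * hidden_size)
x = torch.randn(batch, input_size)
h0 = torch.zeros(batch, hidden_size)
c0 = torch.zeros(batch, hidden_size)
h1, c1 = lstm_cell_step(x, h0, c0, i2h, h2h)
print(h1.shape, c1.shape)  # torch.Size([2, 4]) torch.Size([2, 4])
```

Copying weight_ih, weight_hh, bias_ih and bias_hh from a torch.nn.LSTMCell into i2h and h2h should give the same outputs as that cell, since it uses the same gate order.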