How to compute truncated backpropagation through time (BPTT) for an RNN cell in PyTorch


For simplicity, assume I have a sequence of N inputs (e.g. words) and an RNN cell. I want to compute truncated backpropagation through time (BPTT) over a sliding window of K words inside the loop:

optimizer.zero_grad()
h = torch.zeros(hidden_size)
loss = 0  # loss must be initialized before accumulating
for i in range(N):
    out, h = rnn_cell.forward(data[i], h)
    if i > K:
        loss += compute_loss(out, target)

loss.backward()
optimizer.step()

but obviously it will compute the gradient over all previous steps. I also tried this approach:

h = torch.zeros(hidden_size)
loss = 0
for i in range(N):
    optimizer.zero_grad()
    out, h = rnn_cell.forward(data[i], h.detach())
    loss += compute_loss(out, target)
    loss.backward(retain_graph=True)
    optimizer.step()

but it will compute the gradient only for the last step. I also tried keeping the previous hidden states for only the last K steps in a deque(maxlen=K), because I thought that once the reference to an h state is dropped from the deque it would also be removed from the graph:

from collections import deque

optimizer.zero_grad()
h = torch.zeros(hidden_size)
loss = 0
last_h = deque(maxlen=K)

for i in range(N):
    last_h.append(h)
    out, h = rnn_cell.forward(data[i], h)
    if i > K:
        optimizer.zero_grad()
        loss += compute_loss(out, target)
        loss.backward(retain_graph=True)
        optimizer.step()

but I doubt that any approach here works as I intended. As a very naive workaround I can do this:

h = torch.zeros(hidden_size)
states = [h]  # hidden state before each step, stored detached so a window can be replayed

for i in range(N):
    optimizer.zero_grad()
    # replay the last K steps starting from a stored, detached state
    start = max(0, i - K + 1)
    h = states[start].detach()
    for j in range(start, i + 1):
        out, h = rnn_cell.forward(data[j], h)
    states.append(h.detach())  # state after step i, for future windows
    loss = compute_loss(out, target)
    loss.backward()
    optimizer.step()

but it requires recomputing each step up to K times. Alternatively, I can detach h every K steps, but this way the gradient will be inaccurate:

h = torch.zeros(hidden_size)
loss = 0
optimizer.zero_grad()

for i in range(N):
    out, h = rnn_cell.forward(data[i], h)
    loss += compute_loss(out, target)
    if i % K == 0 and i > 0:
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        h = h.detach()  # cut the graph: gradient never crosses this boundary
        loss = 0

If you have any idea how to implement such a sliding gradient window in a better way, I would be very glad for your help.

1 Answer

Karl:

Is there a specific reason you're using RNNCell over RNN? Also, you should call rnn_cell(data[i], h) instead of rnn_cell.forward(data[i], h): calling the module directly goes through __call__, which also runs PyTorch's hooks. Unless you specifically need to add custom logic at every time step, RNN will make your life easier for batch processing and for stacking multiple layers.
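
For example, a minimal sketch of the difference (all sizes are made-up placeholders):

import torch
import torch.nn as nn

bs, sl, d_in, d_h = 4, 10, 16, 32  # made-up sizes

# RNNCell: you drive the time loop yourself
cell = nn.RNNCell(d_in, d_h)
x = torch.randn(bs, sl, d_in)
h = torch.zeros(bs, d_h)
for t in range(sl):
    h = cell(x[:, t], h)  # call the module, not .forward

# RNN: the loop over the sequence happens inside the module
rnn = nn.RNN(d_in, d_h, batch_first=True)
out, h_n = rnn(x)  # out: (bs, sl, d_h), h_n: (1, bs, d_h)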

Regardless:

Typically, setting the BPTT length is done at the data-processing level. RNNs take in a tensor of size (bs, sl, d_in) (I'm using batch-first format, but the same applies to sequence-first format). The BPTT length is just a fancy way of specifying the maximum value of sl in your input.

Say you have a total sequence length of N and want a BPTT length of K. You would choose a shift O between the starts of consecutive chunks (so adjacent chunks overlap by K - O tokens). For example, O=1 means chunk n+1 is shifted one token from chunk n; if O=K, there is no overlap. You would preprocess your entire dataset into chunks of size K with the desired shift O.
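
As a sketch, that preprocessing could look like this (make_chunks is a hypothetical helper; data is an (N, d_in) tensor of made-up size):

import torch

def make_chunks(data, K, O):
    # data: (N, d_in) -> (num_chunks, K, d_in); chunk starts are O apart
    starts = range(0, data.size(0) - K + 1, O)
    return torch.stack([data[s:s + K] for s in starts])

data = torch.randn(100, 8)              # N=100, d_in=8
chunks = make_chunks(data, K=20, O=20)  # O=K: no overlap
print(chunks.shape)                     # torch.Size([5, 20, 8])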

Then when training, you would process a full sequence of length K, compute your loss, then backprop. If you're wondering about tracking the hidden state between chunks, the answer is you don't. That's a tradeoff when using BPTT that you make for the sake of compute efficiency. Each chunk starts with a fresh hidden state - each chunk is blind to whatever state existed before it.
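
A sketch of one such training step, assuming a toy linear head and MSE loss (head, loss_fn, and all sizes are placeholders):

import torch
import torch.nn as nn

bs, K, d_in, d_h = 4, 20, 8, 32
rnn = nn.RNN(d_in, d_h, batch_first=True)
head = nn.Linear(d_h, 1)
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()))
loss_fn = nn.MSELoss()

chunk = torch.randn(bs, K, d_in)   # one preprocessed chunk
target = torch.randn(bs, K, 1)

optimizer.zero_grad()
out, h_n = rnn(chunk)              # no h passed in: fresh zero state per chunk
loss = loss_fn(head(out), target)
loss.backward()                    # gradient flows through at most K time steps
optimizer.step()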

If the hidden state thing concerns you, you can look into Truncated BPTT. With Truncated BPTT, you first run a sequence of length K1 without grad tracking to build up a hidden state, then run a sequence of length K2 with grad tracking, starting from the hidden state left by the K1 segment. You then backprop through K2 and update.
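
A sketch of that two-phase scheme under the same made-up setup (K1 and K2 are the segment lengths):

import torch
import torch.nn as nn

bs, d_in, d_h, K1, K2 = 4, 8, 32, 30, 20
rnn = nn.RNN(d_in, d_h, batch_first=True)
head = nn.Linear(d_h, 1)
optimizer = torch.optim.Adam(list(rnn.parameters()) + list(head.parameters()))
loss_fn = nn.MSELoss()

warmup = torch.randn(bs, K1, d_in)    # K1 warm-up steps: state only
tracked = torch.randn(bs, K2, d_in)   # K2 tracked steps: loss + backprop
target = torch.randn(bs, K2, 1)

optimizer.zero_grad()
with torch.no_grad():                 # build up the hidden state, no graph
    _, h = rnn(warmup)

out, _ = rnn(tracked, h)              # h carries context but no backprop history
loss = loss_fn(head(out), target)
loss.backward()                       # gradient covers only the K2 tracked steps
optimizer.step()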