Gradient clipping with multiple losses

I need to train a neural network using multiple losses. The most basic way is to sum the losses and then do a gradient step:

optimizer.zero_grad()
total_loss = loss_1 + loss_2
total_loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()

However, sometimes one loss may dominate, and I want both to contribute equally. I thought about clipping the gradients after each individual backward pass, like the following:

optimizer.zero_grad()
loss_1.backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
loss_2.backward()  # accumulates on top of the already-clipped gradient of loss_1
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()

But this is still a problem, because if the first loss is very big, the accumulated gradient will already be close to max_norm before the second loss adds anything.
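
One way to see how lopsided the two contributions are is to compare the raw gradient norms of the two losses before any clipping. This is only a rough, untested sketch (grad_norm is a helper I made up):

import torch

def grad_norm(model, loss):
    # Global L2 norm of d(loss)/d(params), computed without touching .grad
    params = [p for p in model.parameters() if p.requires_grad]
    grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
    return torch.norm(torch.stack([g.norm() for g in grads if g is not None]))

print("grad norm of loss_1:", grad_norm(model, loss_1).item())
print("grad norm of loss_2:", grad_norm(model, loss_2).item())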

So I thought about something like this instead:

optimizer.zero_grad()
loss_1.backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
grad_1 = ... # save gradient

optimizer.zero_grad()
loss_2.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
grad_2 = ... # save gradient

optimizer.zero_grad()
grad = grad_1 + grad_2
# manually apply grad (how?)

optimizer.step()
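
The part I can't figure out is how to implement "save gradient" and "manually apply grad". My best guess looks roughly like this (untested; save_grads and apply_grads are just helper names I made up):

def save_grads(model):
    # Copy every parameter's current gradient (zeros for parameters that got none)
    return [p.grad.detach().clone() if p.grad is not None else torch.zeros_like(p)
            for p in model.parameters()]

def apply_grads(model, grads):
    # Write the given tensors back into the parameters' .grad fields
    for p, g in zip(model.parameters(), grads):
        p.grad = g.clone()

optimizer.zero_grad()
loss_1.backward(retain_graph=True)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
grad_1 = save_grads(model)

optimizer.zero_grad()
loss_2.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
grad_2 = save_grads(model)

grad = [g1 + g2 for g1, g2 in zip(grad_1, grad_2)]
apply_grads(model, grad)
optimizer.step()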

Is there a nicer way to do this? How would I manually apply the summed gradient to the model cleanly?
