I noticed that when I use Flax's TensorBoard writer (from flax.metrics import tensorboard) to log the loss, GPU usage explodes.
To compute the loss metrics I use the has_aux option of jax.grad, as explained here.
These are the functions I am using. In the model class:
# imports used by these methods
from functools import partial
import jax
import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnums=(0,))
def compute_losses(self, params, x_train, y_train):
    # data-fit loss
    y_pred = self.forward(params, x_train)
    rec_loss = jnp.mean(jnp.abs(y_pred - y_train) ** 2)
    # regularizations
    # ...
    loss_dict = {"rec": rec_loss}
    # total loss is the sum of the individual terms
    loss = sum(loss_dict.values())
    return loss, loss_dict

@partial(jit, static_argnums=(0,))
def grad_loss(self, params, x_rec, y_rec):
    # has_aux=True makes jax.grad return (grads, loss_dict)
    return jax.grad(self.compute_losses, has_aux=True)(params, x_rec, y_rec)
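As an aside, jax.value_and_grad with has_aux=True returns the loss and the auxiliary dict together with the gradients, so the loss terms from the update step could be reused for logging. A minimal sketch (the method name grad_and_losses is hypothetical, not from the code above):

@partial(jit, static_argnums=(0,))
def grad_and_losses(self, params, x_train, y_train):
    # value_and_grad(..., has_aux=True) returns ((loss, aux), grads)
    (loss, loss_dict), grads = jax.value_and_grad(
        self.compute_losses, has_aux=True)(params, x_train, y_train)
    return grads, loss, loss_dict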
In the main training loop:
grads, loss_dict = model.grad_loss(params, input, target)
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)

# log loss
if step % config.logging.log_every_steps == 0 and config.logging.log_loss:
    # losses_no_grad is a jitted helper (not shown) that returns only loss_dict
    loss_dict = model.losses_no_grad(params, input, target)
    for term in loss_dict.keys():
        writer.scalar(f'loss/{term}', loss_dict[term], step)
Without the writer, GPU utilization is around 25%. With it activated, it goes up to 100%. Why is this happening?
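For reference, the values handed to writer.scalar above are JAX device arrays; making the device-to-host transfer explicit with jax.device_get (which copies a pytree of device arrays to host NumPy arrays) would look like this sketch:

if step % config.logging.log_every_steps == 0 and config.logging.log_loss:
    # copy the whole dict of device arrays to host NumPy values first
    host_losses = jax.device_get(loss_dict)
    for term, value in host_losses.items():
        writer.scalar(f'loss/{term}', float(value), step)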
EDIT:
This does not happen with PyTorch's writer (from torch.utils import tensorboard).
You can see the GPU consumption here.
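The PyTorch logging loop is assumed to be the straightforward equivalent, roughly:

from torch.utils import tensorboard

writer = tensorboard.SummaryWriter(log_dir="logs")
# add_scalar takes a tag, a scalar value, and a global step
for term, value in loss_dict.items():
    writer.add_scalar(f'loss/{term}', float(value), step)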