Memory usage increases while using jax.value_and_grad

I'm using JAX and Flax. My problem is that RAM usage keeps increasing (by about 8 MB) while using jax.value_and_grad(loss_fn, has_aux=True), so the process dies before training finishes. Is there any solution?
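The post does not include the training loop, so the actual cause cannot be confirmed from it. Two patterns that can make memory grow step by step in a JAX training loop are creating a new jax.value_and_grad / jax.jit function inside the loop (each new function object may be traced and compiled again and its compiled code cached) and keeping every step's loss or aux output as a device array in a Python list. The sketch below is a minimal, hedged illustration of avoiding both. It assumes a pure-JAX toy linear model in place of the asker's Flax model; loss_fn, train_step and train are hypothetical names, not code from the question.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy model: linear regression with params stored as a dict (a pytree).
def loss_fn(params, x, y):
    pred = x @ params["w"] + params["b"]
    loss = jnp.mean((pred - y) ** 2)
    # Auxiliary output returned alongside the loss, as expected with has_aux=True.
    return loss, {"pred_mean": jnp.mean(pred)}


# Build the gradient function ONCE, outside the training loop.
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)


# Jit the whole step once so it is compiled a single time, not on every iteration.
@jax.jit
def train_step(params, x, y, lr):
    (loss, aux), grads = grad_fn(params, x, y)
    # Plain SGD update applied leaf-by-leaf over the params pytree.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss, aux


def train(num_steps=100, lr=0.1):
    key = jax.random.PRNGKey(0)
    kx, _ = jax.random.split(key)
    x = jax.random.normal(kx, (64, 3))
    true_w = jnp.array([1.0, -2.0, 0.5])
    y = x @ true_w + 0.3  # synthetic, noise-free targets
    params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
    history = []
    for _ in range(num_steps):
        params, loss, aux = train_step(params, x, y, lr)
        # Store a Python float rather than the device array, so the list
        # does not keep a reference to each step's buffer.
        history.append(float(loss))
    return params, history
```

If memory still grows with a loop shaped like this, the growth is more likely coming from somewhere outside the gradient computation itself (for example data loading or logging), which is worth checking next.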
