I've implemented a custom RNN cell using Flax and JAX, and after training for only a few epochs, all of the model parameters turn into NaN. I'm seeking advice on potential causes and solutions.
```python
from functools import partial

import jax
import jax.numpy as jnp
import optax

# model, compute_loss_, skey, s, and obs are defined earlier (omitted here)
params = model.init(skey, s)
compute_loss = partial(compute_loss_, s=jnp.array(s), obs=jnp.array(obs))
loss_grad = jax.grad(compute_loss)

# Clip each gradient element to [-0.2, 0.2] before the AdamW update
optimizer = optax.chain(optax.clip(0.2), optax.adamw(learning_rate=1e-3))
opt_state = optimizer.init(params)

for epoch in range(10):
    grads = loss_grad(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

xi = model.apply(params, s)
```
After this training loop, all of the values in `params` have turned into NaN.
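To narrow down where the NaNs first appear, one option is a per-epoch check over the loss and the gradient pytree. This is a minimal sketch reusing the names from above (`compute_loss`, `loss_grad`, `optimizer`, `opt_state`, `params`), not part of my original code:

```python
import jax
import jax.numpy as jnp


def tree_has_nan(tree):
    """Return True if any leaf of a pytree contains a NaN."""
    return any(bool(jnp.isnan(leaf).any()) for leaf in jax.tree_util.tree_leaves(tree))


for epoch in range(10):
    loss = compute_loss(params)
    grads = loss_grad(params)
    # Report the first epoch at which the loss or the gradients go bad
    if jnp.isnan(loss) or tree_has_nan(grads):
        print(f"NaN detected at epoch {epoch}: loss={loss}")
        break
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
```

If the loss is already NaN before the gradients are, the problem is in the forward pass rather than the update step.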
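JAX also has a built-in debug mode that raises an error at the first operation producing a NaN, which should point at the offending op (at some runtime cost, since jitted code is re-run de-optimized):

```python
import jax

# Raise an error as soon as any operation produces a NaN
jax.config.update("jax_debug_nans", True)
```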