NaN values in all parameters after training a custom RNN with Flax/JAX


I've implemented a custom RNN cell using Flax and JAX. After training for only a small number of epochs, all model parameters turn into NaN. I'm seeking advice on potential causes and solutions.

# model, skey, s, obs, and compute_loss_ are defined earlier (not shown)
params = model.init(skey, s)

compute_loss = partial(compute_loss_, s=jnp.array(s), obs=jnp.array(obs))
loss_grad = jax.grad(compute_loss)

# clip gradients to [-0.2, 0.2] before the AdamW update
optimizer = optax.chain(optax.clip(0.2), optax.adamw(learning_rate=1e-3))
opt_state = optimizer.init(params)

for epoch in range(10):
    grads = loss_grad(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

xi = model.apply(params, s)

After this training loop, the parameter values have all turned into NaN.
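One way to narrow this down is to check, on each step, whether the gradients (rather than the parameters) already contain NaNs, which would point to the loss or its derivative rather than the optimizer. Below is a minimal sketch of such a check over an arbitrary parameter pytree; the helper name `tree_has_nan` and the example `params` dict are mine, not from the post:

```python
import jax
import jax.numpy as jnp

def tree_has_nan(tree):
    # True if any leaf array in the pytree contains at least one NaN
    leaves = jax.tree_util.tree_leaves(tree)
    return any(bool(jnp.isnan(leaf).any()) for leaf in leaves)

# hypothetical nested params, shaped like a Flax param dict
clean = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros(2)}}
bad = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.array([0.0, jnp.nan])}}

print(tree_has_nan(clean))  # False
print(tree_has_nan(bad))    # True
```

Calling this on `grads` inside the training loop shows whether the NaNs originate in the backward pass. JAX can also raise on the first NaN-producing operation via `jax.config.update("jax_debug_nans", True)`, which is often the fastest way to locate the offending op.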
