I'm trying to run this simple introduction to score-based generative modeling. The code uses flax.optim, which seems to have been replaced by optax in the meantime (https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html).
I've made a copy of the Colab notebook with the changes I think are needed (the only thing I'm unsure about is how to replace optimizer = flax.jax_utils.replicate(optimizer)).
Now, in the training section, I get the error:
pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())
at the line loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state). It clearly comes from the line return jax.pmap(step_fn, axis_name='device') in the "Define the loss function" section.
How can I fix this error? I've googled it, but have no idea what's going wrong here.
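This happens because you are passing a scalar (a rank-0 array with shape ()) to a pmapped function. By default, pmap maps over axis 0 of every argument, so each argument needs at least one dimension.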
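For example, here is a minimal reproduction with a toy squaring function (square is a hypothetical stand-in, not code from the question's notebook):

```python
import jax
import jax.numpy as jnp


def square(x):
    # Hypothetical toy function, used only to illustrate the error.
    return x ** 2


pmapped_square = jax.pmap(square)

scalar = jnp.float32(2.0)  # shape (), i.e. rank 0

try:
    pmapped_square(scalar)
    error_message = None
except ValueError as err:
    # pmap needs a leading axis to split across devices; a rank-0 array has none.
    error_message = str(err)

print(error_message)
```

The call fails before square ever runs, with the same kind of rank-0 error as in the question, regardless of how many devices are available.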
If you want to operate on a scalar, call the function directly instead of wrapping it in pmap.

Alternatively, if you want to use pmap, pass arrays whose leading dimension matches the number of devices. This applies to every argument you map over, including every leaf of a pytree such as params or opt_state.
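A sketch of both fixes, again using the hypothetical square function. The input for the pmapped version gets a leading axis of size jax.local_device_count(), which is 1 on a CPU-only machine, so the sketch runs anywhere.

The second half applies the same rule to a pytree. In the question's training loop, one likely culprit (an assumption, not checked against the notebook) is an unreplicated opt_state: the state returned by optax.adam(...).init(params) contains a step counter of shape (), and passing it straight to the pmapped step triggers exactly this error. The usual fix is to replicate params and opt_state once before the loop, for example with flax.jax_utils.replicate(params) and flax.jax_utils.replicate(opt_state) in place of the old replicate(optimizer). To avoid a Flax dependency, the sketch uses a hypothetical replicate_tree helper built on jnp.broadcast_to, together with a hypothetical state dict standing in for params and optimizer state.

```python
import jax
import jax.numpy as jnp


def square(x):
    # Hypothetical toy function.
    return x ** 2


scalar = jnp.float32(2.0)

# Option 1: operate on the scalar directly, without pmap.
direct = square(scalar)

# Option 2: keep pmap, but give the input a leading axis of size n_devices.
n_devices = jax.local_device_count()
batched = jnp.full((n_devices,), scalar)  # shape (n_devices,)
per_device = jax.pmap(square)(batched)    # one result per device


def replicate_tree(tree, n):
    # Hypothetical helper: add a leading axis of size n to every leaf,
    # similar in effect to flax.jax_utils.replicate.
    return jax.tree_util.tree_map(
        lambda leaf: jnp.broadcast_to(leaf, (n,) + leaf.shape), tree
    )


# Hypothetical stand-in for params plus optimizer state, including a
# rank-0 step counter like the one in optax's Adam state.
state = {"w": jnp.ones((3,)), "count": jnp.zeros((), jnp.int32)}


def step(state):
    # Inside pmap, each leaf has its leading device axis removed again.
    return {"w": state["w"] * 0.5, "count": state["count"] + 1}


replicated = replicate_tree(state, n_devices)
new_state = jax.pmap(step)(replicated)

print(direct, per_device, new_state["count"])
```

Replicating once before the training loop means every leaf, including the scalar counter, already carries the device axis on each call. The step function then sees the same shapes it would see without pmap.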