I'm trying to run this simple introduction to score-based generative modeling. The code uses flax.optim, which has since been replaced by optax (https://flax.readthedocs.io/en/latest/guides/converting_and_upgrading/optax_update_guide.html).

I've made a copy of the Colab code with the changes I think are needed (I'm only unsure how to replace optimizer = flax.jax_utils.replicate(optimizer)).

Now, in the training section, I get the error

pmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

at the line loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state). This evidently comes from the return jax.pmap(step_fn, axis_name='device') in the "Define the loss function" section.

How can I fix this error? I've googled it, but have no idea what's going wrong here.


Accepted answer (jakevdp):

This happens because you are passing a scalar argument to a pmapped function. For example:

import jax
func = lambda x: x ** 2
pfunc = jax.pmap(func)

pfunc(1.0)
# ValueError: pmap was requested to map its argument along axis 0, which implies
# that its rank should be at least 1, but is only 0 (its shape is ())

If you want to operate on a scalar, you should use the function without wrapping it in pmap:

func(1.0)
# 1.0

Alternatively, if you want to use pmap, you should operate on an array whose leading dimension matches the number of devices:

num_devices = len(jax.devices())
x = jax.numpy.arange(num_devices)
pfunc(x)
# e.g. on a host with 8 devices:
# Array([ 0,  1,  4,  9, 16, 25, 36, 49], dtype=int32)
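Applied to the training loop in the question, the scalar argument is the PRNG key: a single key has no leading device axis, so the pmapped train_step_fn cannot map it. A sketch of the fix (assuming x, params, and opt_state already carry a matching leading device axis from replication/sharding) is to split the key into one key per device:

```python
import jax

num_devices = jax.local_device_count()
rng = jax.random.PRNGKey(0)

# jax.random.split produces a batch of keys with a leading axis of
# num_devices, so each device receives its own independent key.
step_rng = jax.random.split(rng, num_devices)

# Now the call from the question maps cleanly over devices:
# loss, params, opt_state = train_step_fn(step_rng, x, params, opt_state)
```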