I have a subclassed flax.linen.Module that takes a boolean argument in its __call__ method. I want to use gradient checkpointing to reduce the GPU memory footprint of this layer, so I am using flax.linen.checkpoint, which promises to be a "lifted" version of jax.checkpoint.
What I want is to use the standard pattern of passing a boolean flag into the __call__ method to indicate whether we are in training or inference mode. In training mode, operations like dropout are random; in inference mode they are deterministic. So the flag variable is often named deterministic. So far this is all standard stuff.
The deterministic flag cannot be a traced value, since it is used in Python control flow, so when applying function transformations it is usually marked as static via static_argnums, as in this example from the jax.checkpoint documentation. Again, standard JAX. What I want is to use this same pattern, but with the Flax module.
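For concreteness, this is the plain-JAX pattern I mean; a minimal sketch along the lines of the example in the jax.checkpoint docs (the function f and its is_training flag are just illustrative names):

from functools import partial
import jax
import jax.numpy as jnp

# Marking the boolean as static lets it drive Python control flow
# inside the checkpointed function; JAX retraces once per static value.
@partial(jax.checkpoint, static_argnums=(1,))
def f(x, is_training):
    if is_training:
        return 2.0 * jnp.sin(x)
    return jnp.sin(x)

print(jax.grad(f)(1.0, True))
print(jax.grad(f)(1.0, False))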
I take the use of "lifting" in the documentation of flax.linen.checkpoint to mean that it can be used on a Module the same way jax.checkpoint can be used on a function. And indeed the documentation says it has a static_argnums parameter, which I take to refer to the arguments of the (callable) Module's __call__ method, allowing non-traced values in the same way jax.checkpoint does.
However, I can't get this to work. See for example the following code, directly modeled on the example in the flax.linen.checkpoint docs:
import jax
import jax.numpy as jnp
import flax.linen as nn
class MLPWithDropout(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic=False):
        x = nn.Dense(128)(x)
        x = nn.Dropout(rate=0.5, deterministic=deterministic)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x
# The following fails with jax.errors.ConcretizationTypeError, and the error message suggests using static_argnums:
CheckpointedMLPWithDropout = nn.checkpoint(MLPWithDropout)
# Specifically the error is this:
# jax.errors.ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[].
# The following fails with a ValueError saying that static_argnums (which has been remapped to 3) is too large:
# CheckpointedMLPWithDropout = nn.checkpoint(MLPWithDropout, static_argnums=2)
# The following does not seem to do anything, i.e. it raises a ConcretizationTypeError again:
# CheckpointedMLPWithDropout = nn.checkpoint(MLPWithDropout, static_argnums=-1)
# Same for this (ConcretizationTypeError):
# CheckpointedMLPWithDropout = nn.checkpoint(MLPWithDropout, static_argnums=1)
model = CheckpointedMLPWithDropout()
x = jnp.ones((1, 16))
rng = jax.random.PRNGKey(0)
vars = model.init(rng, x, deterministic=True)
print(model.apply(vars, x, deterministic=False, rngs={'dropout': rng}))
This turned out to be a very basic Python mistake: since I was passing deterministic as a keyword argument, it did not get counted as a positional argument, so static_argnums never applied to it. The following works:
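(A minimal sketch of the working version, assuming static_argnums=2, which counts self as argument 0, and deterministic passed positionally in both init and apply:)

CheckpointedMLPWithDropout = nn.checkpoint(MLPWithDropout, static_argnums=2)
model = CheckpointedMLPWithDropout()
x = jnp.ones((1, 16))
rng = jax.random.PRNGKey(0)
# deterministic is now positional, so the (remapped) static_argnums can refer to it
variables = model.init(rng, x, True)
print(model.apply(variables, x, False, rngs={'dropout': rng}))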