Conditionally call vs. don't call a function using flax.linen


Based on a boolean flag, I want to either 1) call or 2) not call the following function (which operates on a flax linen module).

def true_fn(module, carry, inputs):
    # scan function_to_scan_over along inputs, sharing one set of params
    carry, _ = flax.linen.scan(
        function_to_scan_over,
        variable_broadcast='params',
        split_rngs={'params': False},
    )(module, carry, inputs)
    return carry

I have tried to use flax.linen.cond as follows:

carry = flax.linen.cond(pred, true_fn, false_fn, module, carry, inputs)

where false_fn is the identity function with respect to carry:

def false_fn(module, carry, inputs):
    return carry

But when I do this, I get the following error message:

jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have same type structure

The outputs of both branches have the same type structure (each returns carry). Based on the flax.linen.cond documentation, I assume the error occurs because true_fn creates variables that false_fn does not; this is a problem for flax.linen.cond but not for jax.lax.cond, which has no notion of Flax variables. If it matters, the module I pass to true_fn is always called later in my code.
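For comparison, the same "update or identity" pattern written with plain jax.lax.cond is accepted, because jax.lax.cond only requires both branches to return the same pytree structure with matching shapes and dtypes. In this sketch the scanned update is replaced by a hypothetical stand-in (summing the inputs into carry); it is not the asker's function_to_scan_over.

```python
import jax
import jax.numpy as jnp


def true_fn(carry, inputs):
    # Hypothetical stand-in for the scanned update: fold the inputs into carry.
    return carry + jnp.sum(inputs)


def false_fn(carry, inputs):
    # Identity with respect to carry; inputs are ignored.
    return carry


@jax.jit
def step(pred, carry, inputs):
    # Both branches return a scalar of the same dtype as carry, so lax.cond
    # accepts them even though pred is a traced value under jit.
    return jax.lax.cond(pred, true_fn, false_fn, carry, inputs)


carry = jnp.zeros(())
inputs = jnp.arange(4.0)
print(step(True, carry, inputs))   # prints 6.0
print(step(False, carry, inputs))  # prints 0.0
```

With no Flax variables involved, the only structure jax.lax.cond compares is that of the returned values, which is why the identity branch is fine here.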

Any advice on what I should do here?

Edit: minimal working example (MWE) added:

from flax import linen as nn
import jax

class MLP(nn.Module):
    dim: int
    def setup(self):
        self.dense = nn.Dense(self.dim)
    def __call__(self, x):
        return self.dense(x)

class Dummy(nn.Module):
    dim: int
    def setup(self):
        self.mlp = MLP(self.dim)
    def __call__(self, x):
        def true_fn(module, x):
            return module(x)
        def false_fn(module, x):
            return x
        y = nn.cond(True, true_fn, false_fn, self.mlp, x)
        return y + self.mlp(x)

dim_in = 3
dim_out = dim_in

dummy = Dummy(dim_in)
init_vars = dummy.init(x=jax.numpy.ones((dim_in,)), rngs={'params': jax.random.PRNGKey(0)})
dummy.apply(init_vars, x=jax.numpy.ones((dim_in,)))

I am using flax 0.7.2 and jax/jaxlib 0.4.13.
