Flax JIT error for inherited nn.Module class methods

78 Views Asked by At

Based on this answer I am trying to make a class jit compatible by creating a pytree node, but I get:

TypeError: Cannot interpret value of type <class '__main__.TestModel'> as an abstract array; it does not have a dtype attribute

The error line is in the fit function when calling self.step. Is there anything wrong with my implementation?

import jax
import flax.linen as nn
import optax
from jax.tree_util import register_pytree_node_class
from dataclasses import dataclass
from typing import Callable

def data_loader(X, Y, batch_size):
    for i in range(0, len(X), batch_size):
        yield X[i : i + batch_size], Y[i : i + batch_size]

@register_pytree_node_class
@dataclass
class Parent(nn.Module):
    key: jax.random.PRNGKey
    params: dict = None

    @jax.jit
    def step(self, loss_fn, optimizer, opt_state, x, y):
        loss, grads = jax.value_and_grad(loss_fn)(y, self.predict(x))
        opt_grads, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(self.params, opt_grads)
        return params, opt_state, loss

    @jax.jit
    def predict(self, x):
        return self.apply(self.params, x)

    def fit(
        self,
        X,
        Y,
        optimizer: Callable,
        loss: Callable,
        batch_size=32,
        epochs=10,
        verbose=True,
    ):
        opt_state = optimizer.init(self.params)
        self.params = self.init(self.key, X)
        history = []
        for i in range(epochs):
            epoch_loss = 0
            for x, y in data_loader(X, Y, batch_size):
                self.params, opt_state, loss_value = self.step(
                    loss, optimizer, opt_state, x, y
                )
                epoch_loss += loss_value
            history.append(epoch_loss / (len(X) // batch_size))
            if verbose:
                print(f"Epoch {i+1}/{epochs} - loss: {history[-1]}")
        return history

    def tree_flatten(self):
        return (self.params,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, aux_data)

class TestModel(Parent):
 d_hidden: int = 64
 d_out: int = 1

 @nn.compact
 def __call__(self, x):
     x = nn.Dense(self.d_hidden)(x)
     x = nn.relu(x)
     x = nn.Dense(self.d_out)(x)
     x = nn.sigmoid(x)
     return x


x_train = jax.random.normal(jax.random.PRNGKey(0), (209, 12288))
y_train = jax.random.randint(jax.random.PRNGKey(0), (209, 1), 0, 2)

model = TestModel(key=jax.random.PRNGKey(0))
model.fit(
 x_train,
 y_train,
 optimizer=optax.adam(1e-3),
 loss=optax.sigmoid_binary_cross_entropy,
)
0

There are 0 best solutions below