AttributeError: module 'flax.linen' has no attribute 'transforms'


I'm getting the following error with flax 0.7.5. Could you help me?

  File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\__init__.py:5
    from jVMC.nets.rnn import *

  File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\rnn.py:92
    class RNN(nn.Module):

  File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\rnn.py:138 in RNN
    @partial(nn.transforms.scan,

AttributeError: module 'flax.linen' has no attribute 'transforms'

Many thanks!

I also tried flax 0.7.4, after first solving a different error: AttributeError: module 'flax' has no attribute 'nn'

Answer by jakevdp:

From your traceback, it looks like you are using vmc_jax (imported as jVMC). This package appears to require flax v0.6.4 through v0.6.11, so I suspect the error occurs because your flax version is too new for it.

I would suggest installing flax 0.6.11, together with the jax and jaxlib versions the package requires:

$ pip install flax==0.6.11 jax==0.4.11 jaxlib==0.4.11
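To confirm which flax version is actually active in your environment (before or after reinstalling), a small check like the one below may help. This is a minimal sketch, not part of jVMC: the helper names `parse_release` and `in_range` are hypothetical, and the 0.6.4 to 0.6.11 bounds are simply the range mentioned above. It uses only the standard library (`importlib.metadata`) and compares release numbers, ignoring suffixes such as `.post1` or `rc1`.

```python
# Hypothetical helper (not part of jVMC or flax): check whether the installed
# flax version falls inside the 0.6.4-0.6.11 range mentioned above.
from importlib.metadata import PackageNotFoundError, version


def parse_release(v: str) -> tuple[int, ...]:
    """Turn '0.6.11' or '0.7.5.post1' into (0, 6, 11) / (0, 7, 5),
    keeping only the leading numeric part of each dot-separated piece."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:  # e.g. 'post1': stop at the first non-numeric piece
            break
        parts.append(int(digits))
    return tuple(parts)


def in_range(v: str, low: str = "0.6.4", high: str = "0.6.11") -> bool:
    """True if low <= v <= high, comparing release numbers as integer tuples."""
    return parse_release(low) <= parse_release(v) <= parse_release(high)


if __name__ == "__main__":
    try:
        installed = version("flax")
    except PackageNotFoundError:
        print("flax is not installed in this environment")
    else:
        status = "OK" if in_range(installed) else "outside 0.6.4-0.6.11"
        print(f"flax {installed}: {status}")
```

Comparing integer tuples rather than raw strings matters here: as strings, "0.6.11" sorts before "0.6.4", which would wrongly mark 0.6.11 as out of range.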