I'm able to run other code, such as optax, on my MacBook Pro's M3 GPU just fine, but I haven't been able to get jaxopt working. I'm trying to run this simple code:
from jax import numpy as jnp
from jaxopt import LBFGS
def loss(x):
    return x**2
opt = LBFGS(loss)
print(opt.run(jnp.array(3.0)))
But it gives this error:
/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Runtimes/MPSRuntime/MPSRuntime.mm:1422:
failed assertion `MPSGraphKernelDAG: all placeholder ndarrays should have been allocated'
Am I doing something wrong, or is this just a jax-metal bug? After all, jax-metal does loudly warn that Apple GPU support is experimental and that "not all JAX functionality is correctly supported!"