"all placeholder ndarrays should have been allocated" with jax-metal using jaxopt


I'm able to run other code, such as optax, on my MacBook Pro M3 GPU just fine, but I haven't been able to get jaxopt working. I'm trying to run this simple code:

from jax import numpy as jnp
from jaxopt import LBFGS

def loss(x):
  return x**2

opt = LBFGS(loss)
print(opt.run(jnp.array(3.0)))

But it gives this error:

/AppleInternal/Library/BuildRoots/0032d1ee-80fd-11ee-8227-6aecfccc70fe/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Runtimes/MPSRuntime/MPSRuntime.mm:1422:
failed assertion `MPSGraphKernelDAG: all placeholder ndarrays should have been allocated'

Am I doing something wrong, or is this just a jax-metal bug? After all, they do loudly warn that "JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!"
