What are the requirements for allocating a TPU Pod under the TPU VM architecture?

When allocating a TPU under the TPU VM architecture, pod versions such as tpu-vm-tf-2.6.2-pod are available as the TPU software version. When I select a pod version as the software version and follow the instructions at Run JAX code on TPU Pod Slices, jax.device_count() cannot find the TPU.

Is selecting a pod version sufficient to allocate a TPU Pod, or are there additional steps or requirements? How can I select which TPU VMs run under the pod?

There is 1 best solution below

Gagik On

If you are using JAX, use the base images tpu-vm-base and tpu-vm-v4-base instead of a TensorFlow image (e.g. tpu-vm-tf-2.12.0-pod).

gcloud compute tpus tpu-vm create tpu-name \
  --zone europe-west4-a \
  --accelerator-type v2-32 \
  --version tpu-vm-base

Note: For JAX, use the same image (--version tpu-vm-base) for both single TPU VM devices (v2-8, v3-8) and TPU VM Pod slices (e.g. v3-32, v3-64, v2-32).

For TPU versions v2 and v3, use --version tpu-vm-base, then install JAX on the pod slice:

gcloud compute tpus tpu-vm ssh tpu-name \
  --zone europe-west4-a \
  --worker=all \
  --command="pip install 'jax[tpu]>=0.2.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"

For details, see Run JAX code on TPU Pod slices. For TPU v4, use tpu-vm-v4-base; for details, see v4-users-guide.
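
To confirm that jax.device_count() now finds the TPUs, a small check like the one below may help. It is only a sketch: the helper names expected_counts and check are made up for illustration, and it assumes that on v2 and v3 the number in the accelerator type counts TensorCores, that each worker VM drives 8 of them, and that JAX reports one device per core. It does not cover v4. On a pod slice the script most likely has to be started on all workers at once, otherwise JAX may wait for the missing hosts.

```python
"""Sanity check for a TPU VM slice: expected vs. actual JAX device counts."""
import sys

# Assumption: on v2/v3 the number in the accelerator type counts TensorCores,
# each host (worker VM) drives 8 of them, and JAX reports one device per core.
CORES_PER_HOST = 8


def expected_counts(accelerator_type):
    """Return (global_devices, local_devices, hosts) for a type such as 'v2-32'."""
    generation, _, cores = accelerator_type.partition("-")
    if generation not in ("v2", "v3"):
        raise ValueError("this sketch only covers v2 and v3 accelerator types")
    total = int(cores)
    local = min(total, CORES_PER_HOST)
    return total, local, total // local


def check(accelerator_type):
    """Compare the expectation with what JAX sees on this worker."""
    import jax  # imported lazily so expected_counts works without JAX installed

    want_global, want_local, _ = expected_counts(accelerator_type)
    got_global, got_local = jax.device_count(), jax.local_device_count()
    print(f"global devices: {got_global} (expected {want_global})")
    print(f"local devices:  {got_local} (expected {want_local})")
    return (got_global, got_local) == (want_global, want_local)


# Only runs the JAX check when an accelerator type is passed on the command line.
if __name__ == "__main__" and len(sys.argv) == 2:
    sys.exit(0 if check(sys.argv[1]) else 1)
```

Saved on every worker under a name of your choice (check_tpu.py here is hypothetical), it could be started with the same gcloud compute tpus tpu-vm ssh ... --worker=all --command="python3 check_tpu.py v2-32" pattern as the install step.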