Using a progress bar with a method in JAX


I am using JAX and would like to use a progress bar for a class method. The method starts as follows:

from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def _run_step(self, runner, i_step):
    ...

I tried adding the scan_tqdm decorator, but it didn't work. As far as I can tell, the decorator expects the decorated function to take exactly two inputs: a value used by the function (in my case runner) and a step counter for the bar (in my case i_step). My method also takes self, which appears to cause problems for scan_tqdm.
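The arity problem can be reproduced in plain Python. The sketch below uses a hypothetical two_arg_wrapper as a stand-in for scan_tqdm (it is not the jax_tqdm source); it assumes only that the real wrapper calls the decorated function as func(carry, x), which is how lax.scan calls its body function.

```python
import inspect


def two_arg_wrapper(func):
    """Hypothetical stand-in with the same calling shape as scan_tqdm's
    wrapper: the wrapped function is only ever called as func(carry, x)."""

    def wrapper(carry, x):
        # A real progress bar would update here, using x as the step counter.
        return func(carry, x)

    return wrapper


class Model:
    def _run_step(self, runner, i_step):
        return runner + i_step, i_step


class DecoratedModel:
    @two_arg_wrapper  # decorating inside the class body, as attempted in the question
    def _run_step(self, runner, i_step):
        return runner + i_step, i_step


print(inspect.signature(Model._run_step))    # (self, runner, i_step)
print(inspect.signature(Model()._run_step))  # (runner, i_step): self is already bound

# Wrapping the bound method works: the wrapper's (carry, x) fills (runner, i_step).
result = two_arg_wrapper(Model()._run_step)(10, 3)
print(result)  # (13, 3)

# Decorating in the class body fails: the plain-function wrapper is itself bound
# to the instance, so the call passes (self, 10, 3) to a two-parameter function.
try:
    DecoratedModel()._run_step(10, 3)
except TypeError as err:
    print("TypeError:", err)
```

Under that assumption, the fix is to hand the decorator something whose self is already bound, rather than decorating the method where it is defined.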

Any ideas on how to fix this?

Thank you in advance.

EDIT: I ended up solving the problem by using scan_tqdm within lax.scan like this:

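from jax import lax
import jax.numpy as jnp
from jax_tqdm import scan_tqdm

runner, metrics = lax.scan(
    scan_tqdm(self.config["TOTAL_STEPS"])(self._run_step),  # wrap the bound method
    runner,
    jnp.arange(self.config["TOTAL_STEPS"]),  # step indices, also the bar's counter
    self.config["TOTAL_STEPS"],  # length
)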

Answer by TheHungryCub:

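Using scan_tqdm within lax.scan is indeed a good approach to integrating progress tracking with your method. It works because self._run_step is a bound method: self is already supplied, so the function handed to scan_tqdm takes exactly the two arguments (the carry and the step counter) that its wrapper passes.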

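Another approach is to keep decorator syntax, but apply it to a plain two-argument function instead of the method itself. Note that @partial(scan_tqdm, 0) does not make scan_tqdm treat self as a placeholder: the first positional parameter of scan_tqdm is the total number of steps n, not an argument index, so this would pass 0 as n and the method as the next parameter instead of wrapping it. Instead, define the scanned function as a closure inside the method that runs the scan: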

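from jax import lax
import jax.numpy as jnp
from jax_tqdm import scan_tqdm  # pip install jax-tqdm

def run(self, runner):
    n_steps = self.config["TOTAL_STEPS"]

    @scan_tqdm(n_steps)  # decorates a plain (carry, x) function, not the method
    def step(runner, i_step):
        return self._run_step(runner, i_step)  # _run_step keeps its jit decorator

    return lax.scan(step, runner, jnp.arange(n_steps))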