I am using JAX and would like to use a progress bar for a class method. The method starts as follows:
```python
from functools import partial
from jax import jit

@partial(jit, static_argnums=(0,))
def _run_step(self, runner, i_step):
    ...
```
I tried adding the scan_tqdm decorator, but it didn't work. This happens because the decorator expects the decorated function to take exactly two inputs: the carry used by the function (in my case runner) and the step counter for the progress bar (in my case i_step). My method also takes self, which appears to cause problems for scan_tqdm.
Any ideas on how to fix this?
Thank you in advance.
EDIT:
I ended up solving the problem by wrapping the bound method with scan_tqdm directly inside the lax.scan call, like this:
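```python
# scan_tqdm comes from the jax_tqdm package
runner, metrics = lax.scan(
    scan_tqdm(self.config["TOTAL_STEPS"])(self._run_step),  # progress-bar wrapper around the bound method
    runner,                                  # initial carry
    jnp.arange(self.config["TOTAL_STEPS"]),  # step indices; scan_tqdm reads these to update the bar
    length=self.config["TOTAL_STEPS"],       # optional here, since xs already fixes the length
)
```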
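Wrapping the bound method inside the lax.scan call is indeed a good approach: self._run_step already has self bound, so the function scan_tqdm receives has exactly the two-argument (carry, x) signature its wrapper calls, and the bar advances with each value from jnp.arange. It also means the number of steps only has to be known when the scan runs, not when the class is defined.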
Another approach:
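You can also keep decorator syntax by applying scan_tqdm to a small function defined inside the method that runs the scan: that function closes over self and takes exactly (carry, x), and the number of steps can be read from self.config at that point. Note that decorating the method itself with @partial(scan_tqdm, 0) does not make scan_tqdm skip self. The first positional parameter of scan_tqdm is the number of iterations n, not an argument index, so partial(scan_tqdm, 0)(f) is scan_tqdm(0, f), which sets n=0 and passes f as print_rate instead of wrapping it. Even with a correct n, the wrapper scan_tqdm returns always calls the wrapped function with exactly two arguments, so a method that also takes self cannot be wrapped directly.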