Tracking progress of a celery chord task with tqdm? (Python)


Is there a way to track the progress of a chord, preferably in a tqdm bar?

For example, if we take the example from the documentation, we would create this file:

#proj/tasks.py
from proj.celery import app  # assumes the Celery app is created in proj/celery.py

@app.task
def add(x, y):
    return x + y

@app.task
def tsum(numbers):
    return sum(numbers)

and then run this script:

from celery import chord
from proj.tasks import add, tsum

chord(add.s(i, i)
      for i in range(100))(tsum.s()).get()

How could we track the progress of the chord?

  • We cannot use update_state since the chord() object is not a function.
  • We cannot use collect() since chord()(callback) blocks the script until the results are ready.

Ideally, I would envision something like this custom tqdm subclass for Dask; however, I've been unable to find a similar solution.

Any help or hint much appreciated!


There is 1 best solution below

etincelaisse On

So I found a way around it.

First, chord()(callback) doesn't actually block the script; only the .get() part does. It just might take a long time to publish all the tasks to the broker. Luckily, there's a simple way to track this publishing process through signals. We can create a progress bar before publishing begins and modify the example handler from the documentation to update it:

from tqdm import tqdm
from celery import chord
from celery.signals import after_task_publish

from proj.tasks import add, tsum

publish_pbar = tqdm(total=100, desc="Publishing tasks")

# Runs in the publishing process each time a 'tasks.add' message is sent
@after_task_publish.connect(sender='tasks.add')
def task_sent_handler(sender=None, headers=None, body=None, **kwargs):
    publish_pbar.update(1)

c = chord(add.s(i, i) for i in range(100))(tsum.s())

# The script resumes once all tasks are published, so close the pbar
publish_pbar.close()

However, this only works for publishing, since this signal is executed in the process that sent the task. The task_success signal is executed in the worker process, so the same trick there could only report progress in the worker log (to the best of my understanding).

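So to track progress once all tasks have been published and the script resumes, I turned to worker stats from app.control.inspect().stats(). This returns a dict keyed by worker name with various stats, among which is 'total': a per-task-name count of the tasks each worker has accepted since it started, which serves as a proxy for progress. Here's my implementation: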

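import time

from proj.celery import app  # assumes the Celery app is created in proj/celery.py

tasks_pbar = tqdm(total=100, desc="Executing tasks")

previous_total = 0
current_total = 0

while current_total < 100:
    # Query the workers once per iteration; stats() returns None if no worker replies
    stats = app.control.inspect().stats() or {}

    # 'tasks.add' is absent from 'total' until a worker has handled one, hence .get()
    current_total = sum(worker['total'].get('tasks.add', 0) for worker in stats.values())

    if current_total > previous_total:
        tasks_pbar.update(current_total - previous_total)
        previous_total = current_total

    time.sleep(0.5)  # avoid flooding the workers with inspect requests

results = c.get()
tasks_pbar.close()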
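Since the 'total' counters are cumulative since worker start-up, a run against workers that have already processed tasks.add would start with an offset. The following is a minimal sketch, not part of the original answer, of the same polling idea with the stats source and the progress callback factored out so that a baseline can be subtracted. The names count_accepted, poll_progress, get_stats, on_advance and baseline are hypothetical helpers introduced for illustration; only the shape of the stats dict is taken from the loop above.

```python
import time


def count_accepted(stats, task_name):
    """Sum the per-worker 'total' counter for one task name.

    `stats` is assumed to look like {worker: {"total": {task_name: n}}}.
    It may be None or empty when no worker replies, and the task name is
    missing until a worker has handled a task of that type.
    """
    if not stats:
        return 0
    return sum(w.get("total", {}).get(task_name, 0) for w in stats.values())


def poll_progress(get_stats, task_name, total, on_advance,
                  baseline=0, interval=0.5, sleep=time.sleep):
    """Poll get_stats() until `total` tasks are counted, reporting deltas.

    `baseline` is the count read before publishing, so that tasks handled
    in earlier runs are not counted. Note there is no timeout: the loop
    only ends once `total` is reached.
    """
    seen = 0
    while seen < total:
        current = min(count_accepted(get_stats(), task_name) - baseline, total)
        if current > seen:
            on_advance(current - seen)  # report only the increase
            seen = current
        if seen < total:
            sleep(interval)
    return seen
```

With the setup above, this could presumably be wired up as poll_progress(app.control.inspect().stats, 'tasks.add', 100, tasks_pbar.update), with the baseline read through count_accepted before the chord is published.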

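Finally, the task name has to match both the sender filter of the signal handler and the key in the stats() dict. By default Celery names a task after its module (here proj.tasks.add), so I think it is necessary to give the tasks explicit names; do not forget to add this to your tasks: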

#proj/tasks.py

@app.task(name='tasks.add')
def add(x, y):
    return x + y

If someone can find a better solution, please do share!