Airflow: How to retry a task if its corresponding sensor fails?

I have a simple DAG with a task (start_job) that starts a job via REST API.

A sensor task (wait_for_job) waits for the job to complete.

If the job doesn't complete within the configured sensor timeout, it is considered failed, and I want both the start_job and the wait_for_job tasks to be retried.

I managed to retry the start_job task via the on_failure_callback of the wait_for_job sensor task. But after the retried start_job run finishes, the wait_for_job task is not triggered again.

The last log message of the retried start_job task is "INFO - 0 downstream tasks scheduled from follow-on schedule check". I expected 1 downstream task to be scheduled, as in the first run of start_job.

Here is a minimal DAG without the REST API stuff:

import time
import logging
from datetime import timedelta
from typing import Any, Dict, List, Optional

import pendulum
from sqlalchemy.orm.session import Session

from airflow.decorators import dag, task
from airflow.sensors.base import PokeReturnValue
from airflow.models import taskinstance
from airflow.utils.state import State
from airflow.utils.session import NEW_SESSION, provide_session

logger = logging.getLogger("airflow.task")


@provide_session
def on_failure_callback(context: Dict[str, Any], session: Session = NEW_SESSION) -> None:
    logger.info("on_failure_callback()")

    start_job_task = _get_task(context, "start_job")
    wait_for_job_task = _get_task(context, "wait_for_job")

    # other variants, kept for reference: clearing the task instances ...
    # _clear_task(wait_for_job_task, session, context)
    # _clear_task(start_job_task, session, context)

    # ... or setting different state combinations
    # start_job_task.set_state(State.FAILED)
    # wait_for_job_task.set_state(State.UP_FOR_RETRY)

    logger.info("set state of start_job_task to UP_FOR_RETRY ...")
    start_job_task.set_state(State.UP_FOR_RETRY)



def _clear_task(task, session, context):
    logger.info(f"run clear_task_instances() for task: {task.task_id}")
    taskinstance.clear_task_instances(
        tis=[task, ],
        session=session,
        dag=context["dag"])


def _get_task(
    context: Dict[str, Any],
    task_id: str,
) -> Optional[taskinstance.TaskInstance]:
    
    task_instances: List[taskinstance.TaskInstance] = context["dag_run"].get_task_instances()
    logger.info(f"task_instances: {task_instances}")
    for ti in task_instances:
        logger.info(f"    ti.task_id: {ti.task_id}")
        if ti.task_id == task_id:
            return ti


@provide_session
def on_retry_callback(context: Dict[str, Any], session: Session = NEW_SESSION) -> None:
    logger.info("on_retry_callback()")


@dag(
    schedule=None,
    # schedule="@once",
    # schedule="*/5 * * * *",     # At every 5th minute.
    start_date=pendulum.datetime(2023, 1, 1, tz="UTC"),
    catchup=False,
    tags=["job"],
)
def start_jobs_minimal_dag():

    @task(
        execution_timeout=timedelta(seconds=30),
        retries=3,
        retry_delay=timedelta(seconds=10),
    )
    def start_job():
        job_id = 1
        time.sleep(1)        
        return job_id
        
        
    @task.sensor(
        execution_timeout=timedelta(seconds=5),
        timeout=60,
        retries=3,
        retry_delay=timedelta(seconds=2),
        # mode='reschedule' leads to an endless loop here, so the
        # default mode ('poke') is kept
        # mode='reschedule',
        on_failure_callback=on_failure_callback,
        on_retry_callback=on_retry_callback,
    )
    def wait_for_job(job_id: int) -> PokeReturnValue:
        logger.info(f"wait_for_job(): job_id: {job_id}")
        time.sleep(2)
        # always report "not done" so that the sensor eventually fails
        return PokeReturnValue(is_done=False, xcom_value=None)

    job_id = start_job()
    wait_for_job(job_id)
    
start_jobs_minimal_dag()
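
In case it helps to reproduce the behaviour: a single DAG run can also be executed outside the scheduler with DAG.test() (available since Airflow 2.5). This is just a convenience snippet for local debugging, not part of my actual setup:

if __name__ == "__main__":
    # run one DAG run locally for debugging (Airflow 2.5+ only)
    start_jobs_minimal_dag().test()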