"""Service DAG that removes every trace of another DagRun.

Triggered with conf {"remove_dag_id": ..., "remove_run_id": ...}; for each
matching DagRun it stops still-running tasks, deletes their temporary output
folders, and purges the related XCom / TaskInstance / DagRun DB rows.
"""
import logging
import shutil

import psutil
from airflow import configuration
from airflow.models import DAG, DagRun, TaskInstance
from airflow.operators.python_operator import PythonOperator
from airflow.utils.dates import days_ago
from airflow.utils.db import provide_session
from airflow.utils.state import State

from cwl_airflow.utils.notifier import dag_on_success, dag_on_failure

logger = logging.getLogger(__name__)

# Seconds Airflow grants a killed task for cleanup; used below as the base
# interval when waiting for a task's process to exit.
TIMEOUT = configuration.conf.getint('core', 'KILLED_TASK_CLEANUP_TIME')


@provide_session
def clean_db(dr, session=None):
    """Delete all DB records of DagRun ``dr``.

    Per task instance: XCom rows (via ``clear_xcom_data``) and the
    TaskInstance row itself; finally the dag_run row. ``session`` is
    injected by the ``provide_session`` decorator.
    """
    logger.debug(f"""Clean DB for {dr.dag_id} - {dr.run_id}""")
    for ti in dr.get_task_instances():
        logger.debug(f"""process {ti.dag_id} - {ti.task_id} - {ti.execution_date}""")
        ti.clear_xcom_data()
        logger.debug(" - clean Xcom table")
        session.query(TaskInstance).filter(
            TaskInstance.task_id == ti.task_id,
            TaskInstance.dag_id == ti.dag_id,
            TaskInstance.execution_date == dr.execution_date).delete(synchronize_session='fetch')
        session.commit()
        logger.debug(" - clean TaskInstance table")
    session.query(DagRun).filter(
        DagRun.dag_id == dr.dag_id,
        DagRun.run_id == dr.run_id,
    ).delete(synchronize_session='fetch')
    session.commit()
    logger.debug(" - clean dag_run table")


def stop_tasks(dr):
    """Force-fail every RUNNING task instance of DagRun ``dr``.

    The task's state is set to FAILED first; if its OS process can be
    located by PID, we then wait for it to exit so its cleanup handlers
    get a chance to run before the DB rows are purged.
    """
    logger.debug(f"""Stop tasks for {dr.dag_id} - {dr.run_id}""")
    for ti in dr.get_task_instances():
        logger.debug(f"""process {ti.dag_id} - {ti.task_id} - {ti.execution_date} - {ti.pid}""")
        if ti.state == State.RUNNING:
            try:
                # ti.pid may be None (task never started a local process)
                process = psutil.Process(ti.pid) if ti.pid else None
            except Exception:
                logger.debug(f" - cannot find process by PID {ti.pid}")
                process = None
            ti.set_state(State.FAILED)
            logger.debug(" - set state to FAILED")
            if process:
                logger.debug(f" - wait for process {ti.pid} to exit")
                try:
                    # raises psutil.TimeoutExpired if timeout. Makes task fail -> DagRun fails
                    process.wait(timeout=TIMEOUT * 2)
                except psutil.TimeoutExpired:
                    logger.debug(f" - Done waiting for process {ti.pid} to die")


def remove_tmp_data(dr):
    """Remove the temporary output directories of DagRun ``dr``.

    Each task instance may publish an ``outdir`` path through XCom; the
    distinct set of those folders is deleted best-effort (failures are
    logged, not raised).
    """
    logger.debug(f"""Remove tmp data for {dr.dag_id} - {dr.run_id}""")
    tmp_folder_set = set()
    for ti in dr.get_task_instances():
        ti_xcom_data = ti.xcom_pull(task_ids=ti.task_id)  # can be None
        if ti_xcom_data and "outdir" in ti_xcom_data:
            tmp_folder_set.add(ti_xcom_data["outdir"])
    for tmp_folder in tmp_folder_set:
        try:
            shutil.rmtree(tmp_folder)
            logger.debug(f"""Successfully removed {tmp_folder}""")
        except Exception as ex:
            logger.error(f"""Failed to delete temporary output directory {tmp_folder}\n {ex}""")


def clean_dag_run(**context):
    """Entry point: clean every DagRun named by the trigger conf.

    Expects ``remove_dag_id`` and ``remove_run_id`` in the triggering
    DagRun's conf; a missing key raises KeyError and fails this run.
    """
    dag_id = context['dag_run'].conf['remove_dag_id']
    run_id = context['dag_run'].conf['remove_run_id']
    dr_list = DagRun.find(dag_id=dag_id, run_id=run_id)
    for dr in dr_list:
        stop_tasks(dr)
        remove_tmp_data(dr)
        clean_db(dr)


# Manually-triggered service DAG (no schedule); notifier callbacks report
# overall success/failure.
dag = DAG(dag_id="clean_dag_run",
          start_date=days_ago(1),
          on_failure_callback=dag_on_failure,
          on_success_callback=dag_on_success,
          schedule_interval=None)

run_this = PythonOperator(task_id='clean_dag_run',
                          python_callable=clean_dag_run,
                          provide_context=True,
                          dag=dag)