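I have a task named final that has multiple upstream connections. When one of the upstreams is skipped by a ShortCircuitOperator, this task gets skipped as well. I don't want the final task to be skipped, because it has to report DAG success.
To avoid it being skipped I used trigger_rule='all_done', but it still gets skipped.
If I use BranchPythonOperator instead of ShortCircuitOperator, the final task is not skipped. The branching workflow looks like a solution, even if not an optimal one, but now final no longer takes failures of upstream tasks into account.
How can I make it run only when the upstream tasks either succeeded or were skipped?
Sample short-circuit DAG: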
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import ShortCircuitOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'shortcircuit_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

def shortcircuit_fn():
    return randint(0, 1) == 1

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
work = DummyOperator(dag=dag, task_id='work')
short = ShortCircuitOperator(dag=dag, task_id='short_circuit', python_callable=shortcircuit_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")
task_1 >> short >> work >> final
task_1 >> task_2 >> final
Sample branch DAG:
from airflow import DAG
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.python_operator import BranchPythonOperator
from datetime import datetime
from random import randint

default_args = {
    'owner': 'airflow',
    'start_date': datetime(2018, 8, 1)}

dag = DAG(
    'branch_test',
    default_args=default_args,
    schedule_interval='* * * * *',
    catchup=False)

# these two are only here to protect tasks from getting skipped as direct dependencies of branch operator
to_do_work = DummyOperator(dag=dag, task_id='to_do_work')
to_skip_work = DummyOperator(dag=dag, task_id='to_skip_work')

def branch_fn():
    return to_do_work.task_id if randint(0, 1) == 1 else to_skip_work.task_id

task_1 = DummyOperator(dag=dag, task_id='task_1')
task_2 = DummyOperator(dag=dag, task_id='task_2')
work = DummyOperator(dag=dag, task_id='work')
branch = BranchPythonOperator(dag=dag, task_id='branch', python_callable=branch_fn)
final = DummyOperator(dag=dag, task_id="final", trigger_rule="all_done")
task_1 >> branch >> to_do_work >> work >> final
branch >> to_skip_work >> final
task_1 >> task_2 >> final
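Answer 0 (score: 2)
This question is still legitimate in Airflow 1.10.
ShortCircuitOperator will skip all downstream tasks, regardless of any trigger_rule that is set.
@michael-spector's solution only works for simple cases and does not work for this one:
With @michael-spector's solution, task L will not be skipped (only tasks E, F, G and H will be skipped).
Here is a solution (based on @michael-spector's proposal):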
# Imports this class needs (Airflow 1.10 module paths)
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperatorOnlyDirectDownStream(PythonOperator, SkipMixin):
    r"""
    Works like a ShortCircuitOperator, but only skips the tasks that have this task in their upstream.
    So if a task has this task in its upstream AND another task, it will not be skipped.

          -> B -> C -> D ------\
         /                      \
        A -> K                   -> Y
         \                      /
          -> F -> G - P -------/

    If K is a normal ShortCircuitOperator and the condition is False, then B, C, D and Y will be skipped.
    If K is a ShortCircuitOperatorOnlyDirectDownStream and the condition is False, then B, C, D will be skipped, but not Y.

    found_tasks_to_skip_names contains the task_ids of the tasks already marked to be skipped
    found_tasks_to_skip contains the task objects already marked to be skipped
    :return found_tasks_to_skip
    """

    def find_tasks_to_skip(self, task, found_tasks_to_skip=None, found_tasks_to_skip_names=None):
        if found_tasks_to_skip is None:  # list of tasks to skip
            found_tasks_to_skip = []
        # necessary because found_tasks_to_skip holds task objects, not their task_ids
        if found_tasks_to_skip_names is None:
            found_tasks_to_skip_names = set()
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            self.log.info("UPSTREAM : " + str(t.upstream_task_ids))
            self.log.info(
                " Do all skipped tasks " +
                str(found_tasks_to_skip_names) +
                " contain the upstream tasks " +
                str(t.upstream_task_ids)
            )
            # if len == 1 then the task is only preceded by a skipped task
            # otherwise check whether ALL upstream tasks are skipped
            if len(t.upstream_task_ids) == 1 or all(elem in found_tasks_to_skip_names for elem in t.upstream_task_ids):
                found_tasks_to_skip.append(t)
                found_tasks_to_skip_names.add(t.task_id)
                self.find_tasks_to_skip(t, found_tasks_to_skip, found_tasks_to_skip_names)
        return found_tasks_to_skip

    def execute(self, context):
        condition = super(ShortCircuitOperatorOnlyDirectDownStream, self).execute(context)
        self.log.info("Condition result is %s", condition)
        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return
        self.log.info(
            'Skipping downstream tasks that only rely on this path...')
        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)
        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)
        self.log.info("Done.")
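To see how this "all upstreams skipped" rule behaves without running Airflow, here is a minimal sketch in plain Python. The `EXAMPLE_DAG` dict, its task names and the `tasks_to_skip` helper are hypothetical stand-ins for a real DAG, not part of the answer's code. The traversal follows the same rule as `find_tasks_to_skip`, with one addition: a guard so a task is never collected twice.

```python
def upstream_map(downstream):
    """Invert a {task: [downstream tasks]} mapping into {task: {upstream tasks}}."""
    upstream = {task: set() for task in downstream}
    for task, children in downstream.items():
        for child in children:
            upstream.setdefault(child, set()).add(task)
    return upstream


def tasks_to_skip(downstream, start):
    """Collect the tasks below `start` whose upstream tasks are all skipped."""
    upstream = upstream_map(downstream)
    skipped = []           # tasks to skip, in discovery order
    skipped_names = set()  # same tasks, for fast membership checks

    def visit(task):
        for child in downstream.get(task, []):
            if child in skipped_names:
                continue  # already collected (guard not present in the answer)
            parents = upstream[child]
            # skip when the only parent is on the skipped path,
            # or when every parent has already been marked as skipped
            if len(parents) == 1 or parents <= skipped_names:
                skipped.append(child)
                skipped_names.add(child)
                visit(child)

    visit(start)
    return skipped


# Hypothetical DAG: "short" fans out to a and b, which join again;
# "final" also depends on an independent task "other".
EXAMPLE_DAG = {
    "start": ["short", "other"],
    "short": ["a", "b"],
    "a": ["join"],
    "b": ["join"],
    "join": ["final"],
    "other": ["final"],
    "final": [],
}

print(tasks_to_skip(EXAMPLE_DAG, "short"))
```

In this sketch `join` is skipped because both of its upstream tasks are skipped, while `final` is kept because `other` still feeds it.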
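Answer 1 (score: 2)
I'm posting another possible workaround, since this approach does not require a custom operator implementation.
I was influenced by the solution in this blog, which uses a PythonOperator that raises an AirflowSkipException; the exception skips the task itself, and the downstream tasks are then skipped individually.
https://godatadriven.com/blog/the-zen-of-python-and-apache-airflow/
This respects the trigger_rule of the final downstream task, which in my case I set to trigger_rule='none_failed'.
Modified example, as per the blog, to include a final task:
from airflow.exceptions import AirflowSkipException

def fn_short_circuit(**context):
    if <<<some condition>>>:
        raise AirflowSkipException("Skip this task and individual downstream tasks while respecting trigger rules.")

# check_date comes from the blog's example; _check_date is not defined in this snippet
check_date = PythonOperator(
    task_id="check_if_min_date",
    python_callable=_check_date,
    provide_context=True,
    dag=dag,
)
task_1 = DummyOperator(task_id="task1", dag=dag)
task_2 = DummyOperator(task_id="task2", dag=dag)
work = DummyOperator(dag=dag, task_id='work')
# PythonOperator rather than ShortCircuitOperator: fn_short_circuit returns None when it does not raise,
# which a ShortCircuitOperator would treat as a false condition and skip everything downstream
short = PythonOperator(dag=dag, task_id='short_circuit', python_callable=fn_short_circuit, provide_context=True)
final_task = DummyOperator(task_id="final_task",
                           trigger_rule='none_failed',
                           dag=dag)
task_1 >> short >> work >> final_task
task_1 >> task_2 >> final_task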
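Answer 2 (score: 1)
I got it working by making the final task check the states of the upstream task instances. It's not pretty, because the only way I found to access their states was to query the Airflow DB.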
# additional imports to the ones in the question's code
from airflow.exceptions import AirflowException
from airflow.models import TaskInstance
from airflow.operators.python_operator import PythonOperator
from airflow.settings import Session
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule

def all_upstreams_either_succeeded_or_skipped(dag, task, task_instance, **context):
    """
    Find directly upstream task instances and count how many are not in preferred states.
    Return True if there are no instances with non-preferred states.
    """
    upstream_task_ids = [t.task_id for t in task.get_direct_relatives(upstream=True)]
    session = Session()
    query = (session
        .query(TaskInstance)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            TaskInstance.execution_date.in_([task_instance.execution_date]),
            TaskInstance.task_id.in_(upstream_task_ids)
        )
    )
    upstream_task_instances = query.all()
    unhappy_task_instances = [ti for ti in upstream_task_instances if ti.state not in [State.SUCCESS, State.SKIPPED]]
    print(unhappy_task_instances)
    return len(unhappy_task_instances) == 0

def final_fn(**context):
    """
    Fail if upstream task instances have unwanted states.
    """
    if not all_upstreams_either_succeeded_or_skipped(**context):
        raise AirflowException("Not all upstream tasks succeeded.")
    # Do things

# will run when upstream task instances are done, including failed ones
final = PythonOperator(
    dag=dag,
    task_id="final",
    trigger_rule=TriggerRule.ALL_DONE,
    python_callable=final_fn,
    provide_context=True)
Answer 3 (score: 1)
I ended up developing a custom ShortCircuitOperator based on the original one:
# Imports this class needs (Airflow 1.10 module paths)
from airflow.models import SkipMixin
from airflow.operators.python_operator import PythonOperator


class ShortCircuitOperator(PythonOperator, SkipMixin):
    """
    Allows a workflow to continue only if a condition is met. Otherwise, the
    workflow "short-circuits" and downstream tasks that only rely on this operator
    are skipped.

    The ShortCircuitOperator is derived from the PythonOperator. It evaluates a
    condition and short-circuits the workflow if the condition is False. Any
    downstream tasks that only rely on this operator are marked with a state of "skipped".
    If the condition is True, downstream tasks proceed as normal.

    The condition is determined by the result of `python_callable`.
    """

    def find_tasks_to_skip(self, task, found_tasks=None):
        if found_tasks is None:
            found_tasks = []
        direct_relatives = task.get_direct_relatives(upstream=False)
        for t in direct_relatives:
            # only follow tasks whose single upstream is on the skipped path
            if len(t.upstream_task_ids) == 1:
                found_tasks.append(t)
                self.find_tasks_to_skip(t, found_tasks)
        return found_tasks

    def execute(self, context):
        condition = super(ShortCircuitOperator, self).execute(context)
        self.log.info("Condition result is %s", condition)
        if condition:
            self.log.info('Proceeding with downstream tasks...')
            return
        self.log.info(
            'Skipping downstream tasks that only rely on this path...')
        tasks_to_skip = self.find_tasks_to_skip(context['task'])
        self.log.debug("Tasks to skip: %s", tasks_to_skip)
        if tasks_to_skip:
            self.skip(context['dag_run'], context['ti'].execution_date,
                      tasks_to_skip)
        self.log.info("Done.")
This operator makes sure that downstream tasks relying on multiple paths are not skipped because of a single skipped task.
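As a rough illustration of that rule, the sketch below replays the traversal on the DAG from the question using plain dicts instead of Airflow objects. The `single_path_descendants` helper and the two dict literals are hypothetical names introduced only for this example; the logic mirrors `find_tasks_to_skip` (follow a downstream task only if it has exactly one upstream task).

```python
def single_path_descendants(downstream, start):
    """Return the tasks below `start` that are reachable only through single-upstream links."""
    # count how many upstream tasks each task has
    upstream_count = {}
    for children in downstream.values():
        for child in children:
            upstream_count[child] = upstream_count.get(child, 0) + 1

    found = []

    def visit(task):
        for child in downstream.get(task, []):
            if upstream_count[child] == 1:  # child depends on this path only
                found.append(child)
                visit(child)

    visit(start)
    return found


# The DAG from the question: final has two upstream tasks (work and task_2).
QUESTION_DAG = {
    "task_1": ["short_circuit", "task_2"],
    "short_circuit": ["work"],
    "work": ["final"],
    "task_2": ["final"],
    "final": [],
}

# A diamond: both upstream tasks of "join" sit on the skipped path.
DIAMOND_DAG = {
    "short_circuit": ["a", "b"],
    "a": ["join"],
    "b": ["join"],
    "join": [],
}

print(single_path_descendants(QUESTION_DAG, "short_circuit"))
print(single_path_descendants(DIAMOND_DAG, "short_circuit"))
```

For the question's DAG only `work` is collected, so `final` keeps running. The diamond shows the trade-off of the single-upstream test: `join` is not collected even though every task above it is skipped.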
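Answer 4 (score: 0)
This may have been added after you asked the original question, but Airflow now has a trigger_rule value of none_failed, which is handy. If you set it on your final task, the task should run whether the upstream tasks were skipped or succeeded, but not if any of them failed.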
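The difference between the trigger rules mentioned in this thread can be sketched as a small pure-Python function. This is a simplified model written for illustration, not Airflow's implementation: `should_run` and `FINISHED` are hypothetical names, and only three rules are covered. It also models the rule evaluation only, not an operator that marks downstream tasks as skipped directly.

```python
# Task states that count as "done" in this simplified model
FINISHED = {"success", "failed", "upstream_failed", "skipped"}


def should_run(trigger_rule, upstream_states):
    """Decide whether a task may run, given the states of its direct upstream tasks."""
    states = list(upstream_states)
    all_finished = all(state in FINISHED for state in states)
    if trigger_rule == "all_success":
        return all(state == "success" for state in states)
    if trigger_rule == "all_done":
        # runs once every upstream task is finished, even if some failed
        return all_finished
    if trigger_rule == "none_failed":
        # runs when every upstream task either succeeded or was skipped
        any_failed = any(state in {"failed", "upstream_failed"} for state in states)
        return all_finished and not any_failed
    raise ValueError("rule not covered by this sketch: %s" % trigger_rule)


for rule in ("all_success", "all_done", "none_failed"):
    print(rule, should_run(rule, ["success", "skipped"]), should_run(rule, ["success", "failed"]))
```

Under this model, none_failed is the only one of the three rules that accepts a skipped upstream task while still rejecting a failed one, which is the behaviour the question asks for.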