Skip to content

Commit

Permalink
Resolve upstream tasks when template field is XComArg (#8805)
Browse files Browse the repository at this point in the history
* Resolve upstream tasks when template field is XComArg

closes: #8054

* fixup! Resolve upstream tasks when template field is XComArg

* Resolve task relations in DagRun and DagBag

* Add tests for serialized DAG

* Set dependencies only in bag_dag, refactor tests

* Traverse template_fields attribute

* Use provide_test_dag_bag in all tests

* fixup! Use provide_test_dag_bag in all tests

* Use metaclass + setattr

* Add prepare_for_execution method

* Check signature of __init__ not class

* Apply suggestions from code review

Co-authored-by: Ash Berlin-Taylor <[email protected]>

* Update airflow/models/baseoperator.py

Co-authored-by: Ash Berlin-Taylor <[email protected]>
  • Loading branch information
turbaszek and ashb committed Jun 15, 2020
1 parent aee6ab9 commit 431ea32
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 11 deletions.
51 changes: 51 additions & 0 deletions airflow/example_dags/example_xcomargs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Example DAG demonstrating the usage of the XComArgs."""

from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago

# Default arguments handed to the DAG below via ``default_args``.
args = {
    'owner': 'airflow',
    'start_date': days_ago(2),
}


def dummy(*args, **kwargs):
    """Accept any arguments, ignore them and return the constant string ``"pass"``."""
    result = "pass"
    return result


with DAG(
    dag_id='example_xcom_args',
    default_args=args,
    schedule_interval=None,
    tags=['example'],
) as dag:
    # First task: runs ``dummy`` with no extra keyword arguments.
    task1 = PythonOperator(
        task_id='task1',
        python_callable=dummy,
    )

    # Second task: receives the ``output`` reference of task1 as a keyword argument.
    task2 = PythonOperator(
        task_id='task2',
        python_callable=dummy,
        op_kwargs={"dummy": task1.output},
    )
96 changes: 93 additions & 3 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
Base operator for all operators.
"""
import abc
import copy
import functools
import logging
Expand Down Expand Up @@ -60,9 +61,29 @@
ScheduleInterval = Union[str, timedelta, relativedelta]


class BaseOperatorMeta(abc.ABCMeta):
    """
    Metaclass of BaseOperator that hooks into operator instantiation.
    """

    def __call__(cls, *args, **kwargs):
        """
        Build the instance and run the post-construction steps.

        This is invoked whenever an operator class is called, e.g. ``BaseOperator()``.
        Because it wraps the whole construction, the steps below run after ``__init__``
        has finished, regardless of whether a custom operator calls ``super().__init__``
        before or after assigning its own attributes.
        """
        operator: BaseOperator = type.__call__(cls, *args, **kwargs)

        # Turn XComArgs found in the operator's template fields into upstream relations
        operator.set_xcomargs_dependencies()

        # Flag the instance as fully constructed. The name is spelled in its mangled
        # form because the flag is a private (double underscore) attribute of BaseOperator,
        # see https://docs.python.org/3/tutorial/classes.html#private-variables
        operator._BaseOperator__instantiated = True
        return operator


# pylint: disable=too-many-instance-attributes,too-many-public-methods
@functools.total_ordering
class BaseOperator(Operator, LoggingMixin):
class BaseOperator(Operator, LoggingMixin, metaclass=BaseOperatorMeta):
"""
Abstract base class for all operators. Since operators create objects that
become nodes in the dag, BaseOperator contains many recursive methods for
Expand Down Expand Up @@ -292,6 +313,12 @@ class derived from this one results in the creation of a task object,
# Defines if the operator supports lineage without manual definitions
supports_lineage = False

# If True then the class constructor was called
__instantiated = False

# Set to True before calling execute method
_lock_for_execution = False

# noinspection PyUnusedLocal
# pylint: disable=too-many-arguments,too-many-locals, too-many-statements
@apply_defaults
Expand Down Expand Up @@ -547,6 +574,18 @@ def __lt__(self, other):

return self

def __setattr__(self, key, value):
    """
    Store the attribute and, for template fields assigned after the operator was
    instantiated, re-resolve the upstream tasks defined by XComArgs.

    Example of an assignment that triggers the re-resolution::

        op = CustomOperator(task_id="op")
        op.some_template_field = other_task.output
    """
    super().__setattr__(key, value)

    # No custom behaviour while the task is locked for execution, nor before the
    # constructor call has completed (the metaclass sets the flag afterwards).
    if self._lock_for_execution or not self.__instantiated:
        return

    if key in self.template_fields:
        self.set_xcomargs_dependencies()

def add_inlets(self, inlets: Iterable[Any]):
"""
Sets inlets to this operator
Expand Down Expand Up @@ -633,6 +672,56 @@ def deps(self) -> Set[BaseTIDep]:
NotPreviouslySkippedDep(),
}

def prepare_for_execution(self) -> "BaseOperator":
    """
    Return a shallow copy of the task that is locked for execution.

    The copy has ``_lock_for_execution`` set to ``True``, which disables the
    custom action performed in ``__setattr__``. The original task is left untouched.
    """
    locked_copy = copy.copy(self)
    locked_copy._lock_for_execution = True  # pylint: disable=protected-access
    return locked_copy

def set_xcomargs_dependencies(self) -> None:
    """
    Resolves upstream dependencies of a task. In this way passing an ``XComArg``
    as value for a template field will result in creating upstream relation between
    two tasks.

    The values of all ``template_fields`` of this operator are searched for
    ``XComArg`` instances. Tuples, sets, lists and dict values are searched
    recursively, and so are the template fields of any nested object that
    declares its own ``template_fields``.

    **Example**: ::

        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(..., html_content=generate_content.output)

        # This is equivalent to
        with DAG(...):
            generate_content = GenerateContentOperator(task_id="generate_content")
            send_email = EmailOperator(
                ..., html_content="{{ task_instance.xcom_pull('generate_content') }}"
            )
            generate_content >> send_email

    """
    from airflow.models.xcom_arg import XComArg

    def apply_set_upstream(arg: Any):
        if isinstance(arg, XComArg):
            self.set_upstream(arg.operator)
        elif isinstance(arg, (tuple, set, list)):
            for elem in arg:
                apply_set_upstream(elem)
        elif isinstance(arg, dict):
            for elem in arg.values():
                apply_set_upstream(elem)
        elif hasattr(arg, "template_fields"):
            # ``template_fields`` lists attribute *names*. Recurse into the attribute
            # values; recursing on the names themselves would only ever visit strings
            # and never find a nested XComArg.
            for field_name in arg.template_fields:
                apply_set_upstream(getattr(arg, field_name, None))

    for field in self.template_fields:
        # A template field may be declared on the class without being set on the instance
        if hasattr(self, field):
            apply_set_upstream(getattr(self, field))

@property
def priority_weight_total(self) -> int:
"""
Expand Down Expand Up @@ -1140,7 +1229,7 @@ def set_upstream(self, task_or_task_list: Union['BaseOperator', List['BaseOperat

@property
def output(self):
    """
    Returns reference to XCom pushed by current operator

    The reference is an ``XComArg`` built with this operator as its ``operator``.
    """
    # Imported locally rather than at module level
    from airflow.models.xcom_arg import XComArg
    return XComArg(operator=self)

Expand Down Expand Up @@ -1205,7 +1294,8 @@ def get_serialized_fields(cls):
if not cls.__serialized_fields:
cls.__serialized_fields = frozenset(
vars(BaseOperator(task_id='test')).keys() - {
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag'
'inlets', 'outlets', '_upstream_task_ids', 'default_args', 'dag', '_dag',
'_BaseOperator__instantiated',
} | {'_task_type', 'subdag', 'ui_color', 'ui_fgcolor', 'template_fields'})

return cls.__serialized_fields
Expand Down
5 changes: 2 additions & 3 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# specific language governing permissions and limitations
# under the License.

import copy
import getpass
import hashlib
import logging
Expand Down Expand Up @@ -970,7 +969,7 @@ def _run_raw_task(
if not mark_success:
context = self.get_template_context()

task_copy = copy.copy(task)
task_copy = task.prepare_for_execution()

# Sensors in `poke` mode can block execution of DAGs when running
# with single process executor, thus we change the mode to `reschedule`
Expand Down Expand Up @@ -1154,7 +1153,7 @@ def run(

def dry_run(self):
task = self.task
task_copy = copy.copy(task)
task_copy = task.prepare_for_execution()
self.task = task_copy

self.render_templates()
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/sql_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from airflow.utils.decorators import apply_defaults


class BaseSQLToGCSOperator(BaseOperator, metaclass=abc.ABCMeta):
class BaseSQLToGCSOperator(BaseOperator):
"""
:param sql: The SQL to execute.
:type sql: str
Expand Down
4 changes: 2 additions & 2 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
_decorated_fields = {'executor_config'}

_CONSTRUCTOR_PARAMS = {
k: v.default for k, v in signature(BaseOperator).parameters.items()
k: v.default for k, v in signature(BaseOperator.__init__).parameters.items()
if v.default is not v.empty
}

Expand Down Expand Up @@ -537,7 +537,7 @@ def __get_constructor_defaults(): # pylint: disable=no-method-argument
'access_control': '_access_control',
}
return {
param_to_attr.get(k, k): v.default for k, v in signature(DAG).parameters.items()
param_to_attr.get(k, k): v.default for k, v in signature(DAG.__init__).parameters.items()
if v.default is not v.empty
}

Expand Down
61 changes: 60 additions & 1 deletion tests/models/test_baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import unittest
import uuid
from datetime import date, datetime
from unittest import mock

import jinja2
import pytest
from parameterized import parameterized

from airflow.exceptions import AirflowException
from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import chain, cross_downstream
from airflow.operators.dummy_operator import DummyOperator
from airflow.utils.decorators import apply_defaults
from tests.models import DEFAULT_DATE
from tests.test_utils.mock_operators import MockNamedTuple, MockOperator

Expand Down Expand Up @@ -347,3 +348,61 @@ def test_lineage_composition(self):
task4 = DummyOperator(task_id="op4", dag=dag)
task4 > [inlet, outlet, extra]
self.assertEqual(task4.get_outlet_defs(), [inlet, outlet, extra])


class CustomOp(DummyOperator):
    """Test operator exposing two template fields: ``field`` and ``field2``."""

    template_fields = ("field", "field2")

    @apply_defaults
    def __init__(self, field=None, field2=None, *args, **kwargs):
        """Store both template field values after initializing the parent operator."""
        super().__init__(*args, **kwargs)
        self.field = field
        self.field2 = field2

    def execute(self, context):
        """Assign a template field at execute time; used to exercise ``__setattr__``."""
        self.field = None


class TestXComArgsRelationsAreResolved:
    """Tests for task relations created from XComArgs passed to template fields."""

    def test_setattr_performs_no_custom_action_at_execute_time(self):
        """Assigning a template field on a task prepared for execution resolves nothing."""
        op = CustomOp(task_id="test_task")
        op_copy = op.prepare_for_execution()

        with mock.patch(
            "airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies"
        ) as method_mock:
            op_copy.execute({})
        assert method_mock.call_count == 0

    def test_upstream_is_set_when_template_field_is_xcomarg(self):
        """An XComArg passed to a template field links both tasks."""
        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2", field=op1.output)

        assert op1 in op2.upstream_list
        assert op2 in op1.downstream_list

    def test_set_xcomargs_dependencies_works_recursively(self):
        """XComArgs nested inside a list or a dict are resolved as well."""
        with DAG("xcomargs_test", default_args={"start_date": datetime.today()}):
            op1 = DummyOperator(task_id="op1")
            op2 = DummyOperator(task_id="op2")
            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})

        assert op1 in op3.upstream_list
        assert op2 in op3.upstream_list
        assert op1 in op4.upstream_list
        assert op2 in op4.upstream_list

    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
        """Assigning an XComArg to a template field after init sets the upstream."""
        with DAG(dag_id='xcomargs_test', default_args={"start_date": datetime.today()}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2")
            op2.field = op1.output  # value is set after init

        assert op1 in op2.upstream_list

    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
        """Passing an XComArg to a template field outside of a DAG raises."""
        # Built outside ``pytest.raises`` so that only the construction of the
        # dependent operator is allowed to satisfy the expected exception.
        op1 = DummyOperator(task_id="op1")
        with pytest.raises(AirflowException):
            CustomOp(task_id="op2", field=op1.output)
3 changes: 2 additions & 1 deletion tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,8 @@ def test_no_new_fields_added_to_base_operator(self):
"""
base_operator = BaseOperator(task_id="10")
fields = base_operator.__dict__
self.assertEqual({'_dag': None,
self.assertEqual({'_BaseOperator__instantiated': True,
'_dag': None,
'_downstream_task_ids': set(),
'_inlets': [],
'_log': base_operator.log,
Expand Down

0 comments on commit 431ea32

Please sign in to comment.