Skip to content

Commit

Permalink
Change BaseOperatorLink interface to take a ti_key, not a datetime (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ashb committed Mar 1, 2022
1 parent 5befc7f commit 08575dd
Show file tree
Hide file tree
Showing 36 changed files with 508 additions and 513 deletions.
10 changes: 10 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,16 @@ This setting is also used for the deprecated experimental API, which only uses t

To allow the Airflow UI to use the API, the previous default authorization backend `airflow.api.auth.backend.deny_all` is changed to `airflow.api.auth.backend.session`, and this is automatically added to the list of API authorization backends if a non-default value is set.

### BaseOperatorLink's `get_link` method changed to take a `ti_key` keyword argument

In v2.2 we "deprecated" passing an execution date to XCom.get methods, but there was no other option for operator links as they were only passed an execution_date.

Now in 2.3 as part of Dynamic Task Mapping (AIP-42) we will need to add map_index to the XCom row to support the "reduce" part of the API.

In order to support that cleanly we have changed the interface for BaseOperatorLink to take a TaskInstanceKey as the `ti_key` keyword argument (as execution_date + task is no longer unique for mapped operators).

The existing signature will be detected (by the absence of the `ti_key` argument) and continue to work.

## Airflow 2.2.4

### Smart sensors deprecated
Expand Down
18 changes: 13 additions & 5 deletions airflow/api_connexion/endpoints/extra_link_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from airflow.api_connexion.types import APIResponse
from airflow.exceptions import TaskNotFound
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
from airflow.security import permissions
from airflow.utils.session import NEW_SESSION, provide_session

Expand All @@ -45,6 +44,8 @@ def get_extra_links(
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get extra links for task instance"""
from airflow.models.taskinstance import TaskInstance

dagbag: DagBag = current_app.dag_bag
dag: DAG = dagbag.get_dag(dag_id)
if not dag:
Expand All @@ -55,14 +56,21 @@ def get_extra_links(
except TaskNotFound:
raise NotFound("Task not found", detail=f'Task with ID = "{task_id}" not found')

execution_date = (
session.query(DR.execution_date).filter(DR.dag_id == dag_id).filter(DR.run_id == dag_run_id).scalar()
ti = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag_id,
TaskInstance.run_id == dag_run_id,
TaskInstance.task_id == task_id,
)
.one_or_none()
)
if not execution_date:

if not ti:
raise NotFound("DAG Run not found", detail=f'DAG Run with ID = "{dag_run_id}" not found')

all_extra_link_pairs = (
(link_name, task.get_extra_links(execution_date, link_name)) for link_name in task.extra_links
(link_name, task.get_extra_links(ti, link_name)) for link_name in task.extra_links
)
all_extra_links = {
link_name: link_url if link_url else None for link_name, link_url in all_extra_link_pairs
Expand Down
46 changes: 30 additions & 16 deletions airflow/models/abstractoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.

import datetime
import inspect
from typing import TYPE_CHECKING, Any, Callable, Collection, Dict, Iterable, List, Optional, Set, Type, Union

from sqlalchemy.orm import Session
Expand All @@ -32,23 +33,26 @@
from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule

TaskStateChangeCallback = Callable[[Context], None]

if TYPE_CHECKING:
import jinja2 # Slow import.

from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance

DEFAULT_OWNER = conf.get("operators", "default_owner")
DEFAULT_POOL_SLOTS = 1
DEFAULT_PRIORITY_WEIGHT = 1
DEFAULT_QUEUE = conf.get("operators", "default_queue")
DEFAULT_RETRIES = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY = datetime.timedelta(seconds=300)
DEFAULT_WEIGHT_RULE = conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
DEFAULT_TRIGGER_RULE = TriggerRule.ALL_SUCCESS

TaskStateChangeCallback = Callable[[Context], None]
DEFAULT_OWNER: str = conf.get("operators", "default_owner")
DEFAULT_POOL_SLOTS: int = 1
DEFAULT_PRIORITY_WEIGHT: int = 1
DEFAULT_QUEUE: str = conf.get("operators", "default_queue")
DEFAULT_RETRIES: int = conf.getint("core", "default_task_retries", fallback=0)
DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(seconds=300)
DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM)
)
DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS


class AbstractOperator(LoggingMixin, DAGNode):
Expand Down Expand Up @@ -239,19 +243,29 @@ def global_operator_extra_link_dict(self) -> Dict[str, Any]:
def extra_links(self) -> List[str]:
return list(set(self.operator_extra_link_dict).union(self.global_operator_extra_link_dict))

def get_extra_links(self, dttm: datetime.datetime, link_name: str) -> Optional[Dict[str, Any]]:
def get_extra_links(self, ti: "TaskInstance", link_name: str) -> Optional[str]:
"""For an operator, gets the URLs that the ``extra_links`` entry points to.
:meta private:
:raise ValueError: The error message of a ValueError will be passed on through to
the frontend to show up as a tooltip on the disabled link.
:param dttm: The datetime parsed execution date for the URL being searched for.
:param ti: The TaskInstance for the URL being searched for.
:param link_name: The name of the link we're looking for the URL for. Should be
one of the options specified in ``extra_links``.
"""
if link_name in self.operator_extra_link_dict:
return self.operator_extra_link_dict[link_name].get_link(self, dttm)
elif link_name in self.global_operator_extra_link_dict:
return self.global_operator_extra_link_dict[link_name].get_link(self, dttm)
link: Optional["BaseOperatorLink"] = self.operator_extra_link_dict.get(link_name)
if not link:
link = self.global_operator_extra_link_dict.get(link_name)
if not link:
return None
# Check for old function signature
parameters = inspect.signature(link.get_link).parameters
args = [name for name, p in parameters.items() if p.kind != p.VAR_KEYWORD]
if "ti_key" in args:
return link.get_link(self, ti_key=ti.key)
else:
return link.get_link(self, ti.dag_run.logical_date) # type: ignore[misc]
return None

def render_template_fields(
Expand Down
8 changes: 6 additions & 2 deletions airflow/models/baseoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import jinja2 # Slow import.

from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.task_group import TaskGroup

ScheduleInterval = Union[str, timedelta, relativedelta]
Expand Down Expand Up @@ -1730,11 +1731,14 @@ def name(self) -> str:
"""

@abstractmethod
def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
def get_link(self, operator: AbstractOperator, *, ti_key: "TaskInstanceKey") -> str:
"""
Link to external system.
Note: The old signature of this function was ``(self, operator, dttm: datetime)``. That is still
supported at runtime but is deprecated.
:param operator: airflow operator
:param dttm: datetime
:param ti_key: TaskInstance ID to return link for
:return: link to external system
"""
51 changes: 36 additions & 15 deletions airflow/models/xcom.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
# run without storing it in the database.
IN_MEMORY_DAGRUN_ID = "__airflow_in_memory_dagrun__"

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey


class BaseXCom(Base, LoggingMixin):
"""Base class for XCom objects."""
Expand Down Expand Up @@ -205,11 +208,8 @@ def set(
def get_one(
cls,
*,
run_id: str,
key: Optional[str] = None,
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
include_prior_dates: bool = False,
ti_key: "TaskInstanceKey",
session: Session = NEW_SESSION,
) -> Optional[Any]:
"""Retrieve an XCom value, optionally meeting certain criteria.
Expand All @@ -223,20 +223,29 @@ def get_one(
A deprecated form of this function accepts ``execution_date`` instead of
``run_id``. The two arguments are mutually exclusive.
:param run_id: DAG run ID for the task.
:param ti_key: The TaskInstanceKey to look up the XCom for
:param key: A key for the XCom. If provided, only XCom with matching
keys will be returned. Pass *None* (default) to remove the filter.
:param task_id: Only XCom from task with matching ID will be pulled.
Pass *None* (default) to remove the filter.
:param dag_id: Only pull XCom from this DAG. If *None* (default), the
DAG of the calling task is used.
:param include_prior_dates: If *False* (default), only XCom from the
specified DAG run is returned. If *True*, the latest matching XCom is
returned regardless of the run it belongs to.
:param session: Database session. If not given, a new session will be
created for this function.
"""

@overload
@classmethod
def get_one(
cls,
*,
key: Optional[str] = None,
task_id: str,
dag_id: str,
run_id: str,
session: Session = NEW_SESSION,
) -> Optional[Any]:
...

@overload
@classmethod
def get_one(
Expand All @@ -256,24 +265,35 @@ def get_one(
cls,
execution_date: Optional[datetime.datetime] = None,
key: Optional[str] = None,
task_id: Optional[Union[str, Iterable[str]]] = None,
dag_id: Optional[Union[str, Iterable[str]]] = None,
task_id: Optional[str] = None,
dag_id: Optional[str] = None,
include_prior_dates: bool = False,
session: Session = NEW_SESSION,
*,
run_id: Optional[str] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> Optional[Any]:
""":sphinx-autoapi-skip:"""
if not exactly_one(execution_date is not None, run_id is not None):
raise ValueError("Exactly one of run_id or execution_date must be passed")

if run_id is not None:
if not exactly_one(execution_date is not None, ti_key is not None, run_id is not None):
raise ValueError("Exactly one of ti_key, run_id, or execution_date must be passed")

if ti_key is not None:
query = session.query(cls).filter_by(
dag_id=ti_key.dag_id,
run_id=ti_key.run_id,
task_id=ti_key.task_id,
)
if key:
query = query.filter_by(key=key)
query = query.limit(1)
elif run_id:
query = cls.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
)
elif execution_date is not None:
Expand All @@ -288,6 +308,7 @@ def get_one(
task_ids=task_id,
dag_ids=dag_id,
include_prior_dates=include_prior_dates,
limit=1,
session=session,
)
else:
Expand Down
21 changes: 14 additions & 7 deletions airflow/operators/trigger_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import datetime
import json
import time
from typing import Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union, cast

from airflow.api.common.trigger_dag import trigger_dag
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists
Expand All @@ -35,6 +35,11 @@
XCOM_RUN_ID = "trigger_run_id"


if TYPE_CHECKING:
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.taskinstance import TaskInstanceKey


class TriggerDagRunLink(BaseOperatorLink):
"""
Operator link for TriggerDagRunOperator. It allows users to access
Expand All @@ -43,14 +48,16 @@ class TriggerDagRunLink(BaseOperatorLink):

name = 'Triggered DAG'

def get_link(self, operator, dttm):
def get_link(
self,
operator: "AbstractOperator",
*,
ti_key: "TaskInstanceKey",
) -> str:
# Fetch the correct execution date for the triggerED dag which is
# stored in xcom during execution of the triggerING task.
trigger_execution_date_iso = XCom.get_one(
execution_date=dttm, key=XCOM_EXECUTION_DATE_ISO, task_id=operator.task_id, dag_id=operator.dag_id
)

query = {"dag_id": operator.trigger_dag_id, "base_date": trigger_execution_date_iso}
when = XCom.get_one(ti_key=ti_key, key=XCOM_EXECUTION_DATE_ISO)
query = {"dag_id": cast(TriggerDagRunOperator, operator).trigger_dag_id, "base_date": when}
return build_airflow_url_with_query(query)


Expand Down
18 changes: 14 additions & 4 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from airflow.providers.amazon.aws.hooks.emr import EmrHook

if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstanceKey
from airflow.utils.context import Context


Expand Down Expand Up @@ -230,17 +231,26 @@ class EmrClusterLink(BaseOperatorLink):

name = 'EMR Cluster'

def get_link(self, operator: BaseOperator, dttm: datetime) -> str:
def get_link(
self,
operator,
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
"""
Get link to EMR cluster.
:param operator: operator
:param dttm: datetime
:return: url link
"""
flow_id = XCom.get_one(
key="return_value", dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm
)
if ti_key:
flow_id = XCom.get_one(key="return_value", ti_key=ti_key)
else:
assert dttm
flow_id = XCom.get_one(
key="return_value", dag_id=operator.dag_id, task_id=operator.task_id, execution_date=dttm
)
return (
f'https://console.aws.amazon.com/elasticmapreduce/home#cluster-details:{flow_id}'
if flow_id
Expand Down
12 changes: 8 additions & 4 deletions airflow/providers/dbt/cloud/operators/dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,14 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink):

name = "Monitor Job Run"

def get_link(self, operator, dttm):
job_run_url = XCom.get_one(
dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm, key="job_run_url"
)
def get_link(self, operator, dttm=None, *, ti_key=None):
if ti_key:
job_run_url = XCom.get_one(key="job_run_url", ti_key=ti_key)
else:
assert dttm
job_run_url = XCom.get_one(
dag_id=operator.dag.dag_id, task_id=operator.task_id, execution_date=dttm, key="job_run_url"
)

return job_run_url

Expand Down

0 comments on commit 08575dd

Please sign in to comment.