Protect against accidental misuse of XCom.get_value() (#22244)
XCom.get_value() was added in 2.3.0, and while there are
cases where it should be used in providers when tasks are mapped,
in order to keep compatibility with earlier versions of Airflow,
XCom.get_value() should only be used when ti_key is not None.

We check for the construct used in community providers automatically
and also add documentation for users who would like to use
dynamic task mapping features in their own providers.
potiuk committed Mar 14, 2022
1 parent 4f6d24f commit c1ab8e2
Showing 15 changed files with 132 additions and 23 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
@@ -311,6 +311,13 @@ repos:
        pass_filenames: false
        entry: ./scripts/ci/pre_commit/pre_commit_check_setup_extra_packages_ref.py
        additional_dependencies: ['rich==9.2.0']
      - id: forbidden-xcom-get-value
        name: Check if XCom.get_value is used in backwards-compatible way
        language: python
        files: ^airflow/providers/.*\.py$
        pass_filenames: true
        entry: ./scripts/ci/pre_commit/pre_commit_check_xcom_get_value.py
        additional_dependencies: ['rich']
      - id: update-breeze-file
        name: Update output of breeze command in BREEZE.rst
        entry: ./scripts/ci/pre_commit/pre_commit_breeze_cmd_line.sh
26 changes: 13 additions & 13 deletions BREEZE.rst
@@ -2294,19 +2294,19 @@ This is the current syntax for `./breeze <./breeze>`_:
check-executables-have-shebangs check-extras-order check-hooks-apply
check-integrations check-merge-conflict check-xml daysago-import-check
debug-statements detect-private-key docstring-params doctoc dont-use-safe-filter
end-of-file-fixer fix-encoding-pragma flake8 flynt codespell forbid-tabs helm-lint
identity incorrect-use-of-LoggingMixin insert-license isort json-schema
language-matters lint-dockerfile lint-openapi markdownlint mermaid
migration-reference mixed-line-ending mypy mypy-helm no-providers-in-core-examples
no-relative-imports persist-credentials-disabled pre-commit-descriptions
pre-commit-hook-names pretty-format-json provide-create-sessions
providers-changelogs providers-init-file providers-subpackages-init-file
provider-yamls pydevd pydocstyle python-no-log-warn pyupgrade restrict-start_date
rst-backticks setup-order setup-extra-packages shellcheck sort-in-the-wild
sort-spelling-wordlist stylelint trailing-whitespace ui-lint update-breeze-file
update-extras update-local-yml-file update-setup-cfg-file update-supported-versions
update-versions vendor-k8s-json-schema verify-db-migrations-documented version-sync
www-lint yamllint yesqa
end-of-file-fixer fix-encoding-pragma flake8 flynt forbidden-xcom-get-value
codespell forbid-tabs helm-lint identity incorrect-use-of-LoggingMixin
insert-license isort json-schema language-matters lint-dockerfile lint-openapi
markdownlint mermaid migration-reference mixed-line-ending mypy mypy-helm
no-providers-in-core-examples no-relative-imports persist-credentials-disabled
pre-commit-descriptions pre-commit-hook-names pretty-format-json
provide-create-sessions providers-changelogs providers-init-file
providers-subpackages-init-file provider-yamls pydevd pydocstyle python-no-log-warn
pyupgrade restrict-start_date rst-backticks setup-order setup-extra-packages
shellcheck sort-in-the-wild sort-spelling-wordlist stylelint trailing-whitespace
ui-lint update-breeze-file update-extras update-local-yml-file update-setup-cfg-file
update-supported-versions update-versions vendor-k8s-json-schema
verify-db-migrations-documented version-sync www-lint yamllint yesqa
You can pass extra arguments including options to the pre-commit framework as
<EXTRA_ARGS> passed after --. For example:
2 changes: 2 additions & 0 deletions STATIC_CODE_CHECKS.rst
@@ -186,6 +186,8 @@ require Breeze Docker images to be installed locally.
------------------------------------ ---------------------------------------------------------------- ------------
``fix-encoding-pragma`` Removes encoding header from python files
------------------------------------ ---------------------------------------------------------------- ------------
``forbidden-xcom-get-value`` Check if XCom.get_value is used in backwards-compatible way
------------------------------------ ---------------------------------------------------------------- ------------
``pyupgrade`` Runs pyupgrade
------------------------------------ ---------------------------------------------------------------- ------------
``flake8`` Runs flake8 *
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/operators/emr.py
@@ -244,7 +244,7 @@ def get_link(
:param dttm: datetime
:return: url link
"""
if ti_key:
if ti_key is not None:
flow_id = XCom.get_value(key="return_value", ti_key=ti_key)
else:
assert dttm
2 changes: 1 addition & 1 deletion airflow/providers/dbt/cloud/operators/dbt.py
@@ -34,7 +34,7 @@ class DbtCloudRunJobOperatorLink(BaseOperatorLink):
name = "Monitor Job Run"

def get_link(self, operator, dttm=None, *, ti_key=None):
if ti_key:
if ti_key is not None:
job_run_url = XCom.get_value(key="job_run_url", ti_key=ti_key)
else:
assert dttm
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/links/base.py
@@ -37,7 +37,7 @@ def get_link(
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
if ti_key is not None:
conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/links/dataproc.py
@@ -69,7 +69,7 @@ def get_link(
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
if ti_key is not None:
conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
@@ -112,7 +112,7 @@ def get_link(
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
if ti_key is not None:
list_conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/bigquery.py
@@ -98,7 +98,7 @@ def get_link(
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
):
if ti_key:
if ti_key is not None:
job_ids = XCom.get_value(key='job_id', ti_key=ti_key)
else:
assert dttm is not None
Original file line number Diff line number Diff line change
@@ -85,7 +85,7 @@ def get_link(
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
if ti_key is not None:
conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
@@ -140,7 +140,7 @@ def get_link(
dttm: Optional[datetime] = None,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
if ti_key is not None:
conf = XCom.get_value(key=self.key, ti_key=ti_key)
else:
assert dttm
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ def get_link(
*,
ti_key: Optional["TaskInstanceKey"] = None,
) -> str:
if ti_key:
if ti_key is not None:
run_id = XCom.get_value(key="run_id", ti_key=ti_key)
else:
assert dttm
2 changes: 1 addition & 1 deletion airflow/providers/qubole/operators/qubole.py
@@ -63,7 +63,7 @@ def get_link(
host = re.sub(r'api$', 'v2/analyze?command_id=', conn.host)
else:
host = 'https://api.qubole.com/v2/analyze?command_id='
if ti_key:
if ti_key is not None:
qds_command_id = XCom.get_value(key='qbol_cmd_id', ti_key=ti_key)
else:
assert dttm
1 change: 1 addition & 0 deletions breeze-complete
@@ -108,6 +108,7 @@ end-of-file-fixer
fix-encoding-pragma
flake8
flynt
forbidden-xcom-get-value
codespell
forbid-tabs
helm-lint
1 change: 1 addition & 0 deletions dev/breeze/src/airflow_breeze/pre_commit_ids.py
@@ -53,6 +53,7 @@
'flake8',
'flynt',
'forbid-tabs',
'forbidden-xcom-get-value',
'helm-lint',
'identity',
'incorrect-use-of-LoggingMixin',
38 changes: 38 additions & 0 deletions docs/apache-airflow-providers/howto/create-update-providers.rst
@@ -351,6 +351,44 @@ should be implemented to keep compatibility with Airflow 2.1 and 2.2
raise AirflowOptionalProviderFeatureException(e)
Using Providers with dynamic task mapping
-----------------------------------------

Airflow 2.3 added `Dynamic Task Mapping <https://cwiki.apache.org/confluence/display/AIRFLOW/AIP-42+Dynamic+Task+Mapping>`_
and with it the possibility of assigning a unique key to each task. This means that when such a dynamically
mapped task wants to retrieve a value from XCom (for example, when an extra link should be calculated),
it should always check that the ti_key value passed is not None and only then retrieve the XCom value using
XCom.get_value. This keeps backwards compatibility with earlier versions of Airflow.

Typical code to access an XCom value in providers that want to keep backwards compatibility should look similar to
this (note the ``if ti_key is not None:`` condition).

.. code-block:: python

    def get_link(
        self,
        operator,
        dttm: Optional[datetime] = None,
        ti_key: Optional["TaskInstanceKey"] = None,
    ):
        if ti_key is not None:
            job_ids = XCom.get_value(key="job_id", ti_key=ti_key)
        else:
            assert dttm is not None
            job_ids = XCom.get_one(
                key="job_id",
                dag_id=operator.dag.dag_id,
                task_id=operator.task_id,
                execution_date=dttm,
            )
        if not job_ids:
            return None
        if len(job_ids) < self.index:
            return None
        job_id = job_ids[self.index]
        return BIGQUERY_JOB_DETAILS_LINK_FMT.format(job_id=job_id)

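
The guard described above can be exercised without an Airflow installation. The following is only a minimal sketch: ``FakeXCom``, ``FakeTaskInstanceKey`` and ``read_job_id`` are hypothetical stand-ins invented for illustration, not Airflow classes, and the returned strings merely show which branch was taken.

```python
from datetime import datetime
from typing import NamedTuple, Optional


class FakeTaskInstanceKey(NamedTuple):
    """Hypothetical stand-in for a task instance key (not Airflow's class)."""

    dag_id: str
    task_id: str
    run_id: str
    map_index: int = -1


class FakeXCom:
    """Hypothetical stand-in for XCom; it only reports which lookup was used."""

    @staticmethod
    def get_value(*, key: str, ti_key: FakeTaskInstanceKey) -> str:
        # Lookup by task instance key (the path used when ti_key is passed).
        return f"value-by-ti_key:{ti_key.task_id}:{key}"

    @staticmethod
    def get_one(*, key: str, dag_id: str, task_id: str, execution_date: datetime) -> str:
        # Lookup by execution date (the fallback path for older callers).
        return f"value-by-date:{task_id}:{key}"


def read_job_id(
    dag_id: str,
    task_id: str,
    dttm: Optional[datetime] = None,
    ti_key: Optional[FakeTaskInstanceKey] = None,
) -> str:
    """Use the ti_key lookup only when a key was passed, else fall back to dttm."""
    if ti_key is not None:
        return FakeXCom.get_value(key="job_id", ti_key=ti_key)
    assert dttm is not None
    return FakeXCom.get_one(key="job_id", dag_id=dag_id, task_id=task_id, execution_date=dttm)


key = FakeTaskInstanceKey("dag", "task", "run")
print(read_job_id("dag", "task", ti_key=key))                 # value-by-ti_key:task:job_id
print(read_job_id("dag", "task", dttm=datetime(2022, 3, 14)))  # value-by-date:task:job_id
```

A caller that passes only ``dttm`` never reaches the ``get_value`` call, which is the property the guard is meant to preserve.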
How-to Update a community provider
----------------------------------

60 changes: 60 additions & 0 deletions scripts/ci/pre_commit/pre_commit_check_xcom_get_value.py
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from pathlib import Path
from typing import List

from rich.console import Console

if __name__ not in ("__main__", "__mp_main__"):
    raise SystemExit(
        "This file is intended to be executed as an executable program. You cannot use it as a module."
        f"To run this script, run the ./{__file__} command [FILE] ..."
    )


console = Console(color_system="standard", width=200)

errors: List[str] = []


def _check_file(_file: Path):
    lines = _file.read_text().splitlines()
    for index, line in enumerate(lines):
        if "XCom.get_value(" in line:
            if "if ti_key is not None:" not in lines[index - 1]:
                errors.append(
                    f"[red]In {_file}:{index} there is a forbidden construct "
                    f"(Airflow 2.3.0 only):[/]\n\n"
                    f"{lines[index-1]}\n{lines[index]}\n\n"
                    f"[yellow]When you use XCom.get_value( in providers, it should be in the form:[/]\n\n"
                    f"if ti_key is not None:\n"
                    f"    value = XCom.get_value(...., ti_key=ti_key)\n\n"
                    f"See: https://airflow.apache.org/docs/apache-airflow-providers/"
                    f"howto/create-update-providers.html#using-providers-with-dynamic-task-mapping\n"
                )


if __name__ == '__main__':
    for file in sys.argv[1:]:
        _check_file(Path(file))
    if errors:
        console.print("[red]Found forbidden usage of XCom.get_value( in providers:[/]\n")
        for error in errors:
            console.print(f"{error}")
        sys.exit(1)
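
The check in the script above is purely textual: a line containing `XCom.get_value(` is accepted only if the line directly before it contains `if ti_key is not None:`. The sketch below restates that rule as a pure function so its behaviour is easy to try on sample input. The name `find_unguarded_calls` is hypothetical and not part of the commit, and the sketch deliberately differs from the script in two ways: it reports 1-based line numbers (the script formats the 0-based `index`), and it treats a call on the very first line as unguarded instead of comparing against `lines[-1]`.

```python
from typing import List

GUARD = "if ti_key is not None:"
CALL = "XCom.get_value("


def find_unguarded_calls(lines: List[str]) -> List[int]:
    """Return 1-based numbers of lines that call XCom.get_value( without the guard.

    A call counts as guarded only when the immediately preceding line
    contains the exact guard text, mirroring the textual rule of the hook.
    """
    flagged = []
    for index, line in enumerate(lines):
        if CALL not in line:
            continue
        # Assumption of this sketch: a call on the first line has no
        # preceding line and is therefore reported as unguarded.
        previous = lines[index - 1] if index > 0 else ""
        if GUARD not in previous:
            flagged.append(index + 1)
    return flagged


guarded = [
    "if ti_key is not None:",
    "    value = XCom.get_value(key='k', ti_key=ti_key)",
]
truthiness_check = [
    "if ti_key:",
    "    value = XCom.get_value(key='k', ti_key=ti_key)",
]

print(find_unguarded_calls(guarded))           # []
print(find_unguarded_calls(truthiness_check))  # [2]
```

Because the rule matches text rather than syntax, the older `if ti_key:` form is flagged even though it guards the same call, which is consistent with the provider files in this commit being rewritten to `if ti_key is not None:`.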
