Skip to content

Commit

Permalink
Add support for extra links coming from the providers (#12472)
Browse files Browse the repository at this point in the history
Closes: #11431
  • Loading branch information
potiuk committed Dec 5, 2020
1 parent 6150e26 commit 1dcd3e1
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 12 deletions.
6 changes: 6 additions & 0 deletions airflow/cli/cli_parser.py
Expand Up @@ -1169,6 +1169,12 @@ class GroupCommand(NamedTuple):
func=lazy_load_command('airflow.cli.commands.provider_command.provider_get'),
args=(ARG_OUTPUT, ARG_FULL, ARG_COLOR, ARG_PROVIDER_NAME),
),
ActionCommand(
name='links',
help='List extra links registered by the providers',
func=lazy_load_command('airflow.cli.commands.provider_command.extra_links_list'),
args=(ARG_OUTPUT,),
),
)

USERS_COMMANDS = (
Expand Down
11 changes: 11 additions & 0 deletions airflow/cli/commands/provider_command.py
Expand Up @@ -72,3 +72,14 @@ def hooks_list(args):
"conn_attribute_name": x[1][1],
},
)


def extra_links_list(args):
"""Lists all extra links at the command line"""
AirflowConsole().print_as(
data=ProvidersManager().extra_links_class_names,
output=args.output,
mapper=lambda x: {
"extra_link_class_name": x,
},
)
7 changes: 7 additions & 0 deletions airflow/provider.yaml.schema.json
Expand Up @@ -180,6 +180,13 @@
"items": {
"type": "string"
}
},
"extra-links": {
"type": "array",
"description": "Class name that provide extra link functionality",
"items": {
"type": "string"
}
}
},
"additionalProperties": false,
Expand Down
5 changes: 5 additions & 0 deletions airflow/providers/google/provider.yaml
Expand Up @@ -636,3 +636,8 @@ hook-class-names:
- airflow.providers.google.cloud.hooks.cloud_sql.CloudSQLDatabaseHook
- airflow.providers.google.cloud.hooks.compute_ssh.ComputeEngineSSHHook
- airflow.providers.google.cloud.hooks.bigquery.BigQueryHook

extra-links:
- airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink
- airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink
- airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink
3 changes: 3 additions & 0 deletions airflow/providers/qubole/provider.yaml
Expand Up @@ -45,3 +45,6 @@ hooks:
python-modules:
- airflow.providers.qubole.hooks.qubole
- airflow.providers.qubole.hooks.qubole_check

extra-links:
- airflow.providers.qubole.operators.qubole.QDSLink
35 changes: 34 additions & 1 deletion airflow/providers_manager.py
Expand Up @@ -22,7 +22,7 @@
import logging
import os
from collections import OrderedDict
from typing import Dict, Tuple
from typing import Dict, Set, Tuple

import jsonschema
import yaml
Expand Down Expand Up @@ -68,6 +68,7 @@ def __init__(self):
# Keeps dict of hooks keyed by connection type and value is
# Tuple: connection class, connection_id_attribute_name
self._hooks_dict: Dict[str, Tuple[str, str]] = {}
self._extra_link_class_name_set: Set[str] = set()
self._validator = _create_validator()
# Local source folders are loaded first. They should take precedence over the package ones for
# Development purpose. In production provider.yaml files are not present in the 'airflow" directory
Expand All @@ -78,6 +79,7 @@ def __init__(self):
self._discover_hooks()
self._provider_dict = OrderedDict(sorted(self.providers.items()))
self._hooks_dict = OrderedDict(sorted(self.hooks.items()))
self._discover_extra_links()

def _discover_all_providers_from_packages(self) -> None:
"""
Expand Down Expand Up @@ -224,6 +226,32 @@ def _add_hook(self, hook_class_name, provider_package) -> None:

self._hooks_dict[conn_type] = (hook_class_name, connection_id_attribute_name)

def _discover_extra_links(self) -> None:
"""Retrieves all extra links defined in the providers"""
for provider_package, (_, provider) in self._provider_dict.items():
if provider.get("extra-links"):
for extra_link in provider["extra-links"]:
self._add_extra_link(extra_link, provider_package)

def _add_extra_link(self, extra_link_class_name, provider_package) -> None:
"""
Adds extra link class name to the list of classes
:param extra_link_class_name: name of the class to add
:param provider_package: provider package adding the link
:return:
"""
if provider_package.startswith("apache-airflow"):
provider_path = provider_package[len("apache-") :].replace("-", ".")
if not extra_link_class_name.startswith(provider_path):
log.warning(
"Sanity check failed when importing '%s' from '%s' package. It should start with '%s'",
extra_link_class_name,
provider_package,
provider_path,
)
return
self._extra_link_class_name_set.add(extra_link_class_name)

@property
def providers(self):
"""Returns information about available providers."""
Expand All @@ -233,3 +261,8 @@ def providers(self):
def hooks(self):
"""Returns dictionary of connection_type-to-hook mapping"""
return self._hooks_dict

@property
def extra_links_class_names(self):
"""Returns set of extra link class names."""
return sorted(list(self._extra_link_class_name_set))
32 changes: 24 additions & 8 deletions airflow/serialization/serialized_objects.py
Expand Up @@ -20,17 +20,25 @@
import enum
import logging
from inspect import Parameter, signature
from typing import Any, Dict, Iterable, List, Optional, Set, Union
from typing import Any, Dict, Iterable, Optional, Set, Union

import cattr
import pendulum
from dateutil import relativedelta

try:
from functools import cache
except ImportError:
from functools import lru_cache

cache = lru_cache(maxsize=None)
from pendulum.tz.timezone import Timezone

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.connection import Connection
from airflow.models.dag import DAG
from airflow.providers_manager import ProvidersManager
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
from airflow.serialization.helpers import serialize_template_field
from airflow.serialization.json_schema import Validator, load_dag_schema
Expand All @@ -53,14 +61,22 @@
log = logging.getLogger(__name__)
FAILED = 'serialization_failed'

BUILTIN_OPERATOR_EXTRA_LINKS: List[str] = [
"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink",
"airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink",
"airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink",
"airflow.providers.qubole.operators.qubole.QDSLink",
_OPERATOR_EXTRA_LINKS: Set[str] = {
"airflow.operators.dagrun_operator.TriggerDagRunLink",
"airflow.sensors.external_task_sensor.ExternalTaskSensorLink",
]
}


@cache
def get_operator_extra_links():
"""
Returns operator extra links - both the ones that are built in and the ones that come from
the providers.
:return: set of extra links
"""
_OPERATOR_EXTRA_LINKS.update(ProvidersManager().extra_links_class_names)
return _OPERATOR_EXTRA_LINKS


class BaseSerialization:
Expand Down Expand Up @@ -498,7 +514,7 @@ def _deserialize_operator_extra_links(cls, encoded_op_links: list) -> Dict[str,
# )

_operator_link_class_path, data = list(_operator_links_source.items())[0]
if _operator_link_class_path in BUILTIN_OPERATOR_EXTRA_LINKS:
if _operator_link_class_path in get_operator_extra_links():
single_op_link_class = import_string(_operator_link_class_path)
elif _operator_link_class_path in plugins_manager.registered_operator_link_classes:
single_op_link_class = plugins_manager.registered_operator_link_classes[
Expand Down
31 changes: 28 additions & 3 deletions scripts/in_container/run_install_and_test_provider_packages.sh
Expand Up @@ -81,16 +81,17 @@ function discover_all_provider_packages() {

airflow providers list

local expected_number_of_providers=60
local expected_number_of_providers=59
local actual_number_of_providers
actual_number_of_providers=$(airflow providers list --output table | grep -c apache-airflow-providers | xargs)
if [[ ${actual_number_of_providers} != "${expected_number_of_providers}" ]]; then
echo
echo "${COLOR_RED_ERROR} Number of providers installed is wrong ${COLOR_RESET}"
echo "Expected number was '${expected_number_of_providers}' and got '${actual_number_of_providers}'"
echo
echo "Either increase the number of providers if you added one or fix the problem with imports if you see one."
echo "Either increase the number of providers if you added one or diagnose and fix the problem."
echo
exit 1
fi
}

Expand All @@ -109,12 +110,36 @@ function discover_all_hooks() {
echo "${COLOR_RED_ERROR} Number of hooks registered is wrong ${COLOR_RESET}"
echo "Expected number was '${expected_number_of_hooks}' and got '${actual_number_of_hooks}'"
echo
echo "Either increase the number of hooks if you added one or fix problem with imports if you see one."
echo "Either increase the number of hooks if you added one or diagnose and fix the problem."
echo
exit 1
fi
}

function discover_all_extra_links() {
echo
echo Listing available extra links via 'airflow providers links'
echo

airflow providers links

local expected_number_of_extra_links=4
local actual_number_of_extra_links
actual_number_of_extra_links=$(airflow providers links --output table | grep -c ^airflow.providers | xargs)
if [[ ${actual_number_of_extra_links} != "${expected_number_of_extra_links}" ]]; then
echo
echo "${COLOR_RED_ERROR} Number of links registered is wrong ${COLOR_RESET}"
echo "Expected number was '${expected_number_of_extra_links}' and got '${actual_number_of_extra_links}'"
echo
echo "Either increase the number of links if you added one or diagnose and fix the problem."
echo
exit 1
fi
}


if [[ ${BACKPORT_PACKAGES} != "true" ]]; then
discover_all_provider_packages
discover_all_hooks
discover_all_extra_links
fi
12 changes: 12 additions & 0 deletions tests/core/test_providers_manager.py
Expand Up @@ -119,6 +119,13 @@
'wasb',
]

EXTRA_LINKS = [
'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleIndexableLink',
'airflow.providers.google.cloud.operators.bigquery.BigQueryConsoleLink',
'airflow.providers.google.cloud.operators.mlengine.AIPlatformConsoleLink',
'airflow.providers.qubole.operators.qubole.QDSLink',
]


class TestProviderManager(unittest.TestCase):
def test_providers_are_loaded(self):
Expand All @@ -137,3 +144,8 @@ def test_hooks(self):
provider_manager = ProvidersManager()
connections_list = list(provider_manager.hooks.keys())
self.assertEqual(CONNECTIONS_LIST, connections_list)

def test_extra_links(self):
provider_manager = ProvidersManager()
extra_link_class_names = list(provider_manager.extra_links_class_names)
self.assertEqual(EXTRA_LINKS, extra_link_class_names)

0 comments on commit 1dcd3e1

Please sign in to comment.