Skip to content

Commit

Permalink
Update google hooks to prefer non-prefixed extra fields (#27023)
Browse files Browse the repository at this point in the history
As of airflow 2.3 we no longer need to use prefixed name in extra purely for web UI compat.  So now we update the providers to take advantage of this, while still maintaining backcompat for conns defined the old way.
  • Loading branch information
dstandish committed Oct 22, 2022
1 parent 14a4587 commit de9633f
Show file tree
Hide file tree
Showing 20 changed files with 159 additions and 203 deletions.
7 changes: 5 additions & 2 deletions airflow/providers/google/ads/hooks/ads.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow import AirflowException
from airflow.compat.functools import cached_property
from airflow.hooks.base import BaseHook
from airflow.providers.google.common.hooks.base_google import get_field


class GoogleAdsHook(BaseHook):
Expand Down Expand Up @@ -200,8 +201,10 @@ def _update_config_with_secret(self, secrets_temp: IO[str]) -> None:
Updates google ads config with file path of the temp file containing the secret
Note, the secret must be passed as a file path for Google Ads API
"""
secret_conn = self.get_connection(self.gcp_conn_id)
secret = secret_conn.extra_dejson["extra__google_cloud_platform__keyfile_dict"]
extras = self.get_connection(self.gcp_conn_id).extra_dejson
secret = get_field(extras, 'keyfile_dict')
if not secret:
raise KeyError("secret_conn.extra_dejson does not contain keyfile_dict")
secrets_temp.write(secret)
secrets_temp.flush()

Expand Down
21 changes: 9 additions & 12 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook, get_field
from airflow.utils.helpers import convert_camel_to_snake
from airflow.utils.log.logging_mixin import LoggingMixin

Expand Down Expand Up @@ -156,15 +156,14 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
if engine_kwargs is None:
engine_kwargs = {}
connection = self.get_connection(self.gcp_conn_id)
if connection.extra_dejson.get("extra__google_cloud_platform__key_path"):
credentials_path = connection.extra_dejson['extra__google_cloud_platform__key_path']
extras = self.get_connection(self.gcp_conn_id).extra_dejson
credentials_path = get_field(extras, 'key_path')
if credentials_path:
return create_engine(self.get_uri(), credentials_path=credentials_path, **engine_kwargs)
elif connection.extra_dejson.get("extra__google_cloud_platform__keyfile_dict"):
credential_file_content = json.loads(
connection.extra_dejson["extra__google_cloud_platform__keyfile_dict"]
)
return create_engine(self.get_uri(), credentials_info=credential_file_content, **engine_kwargs)
keyfile_dict = get_field(extras, 'keyfile_dict')
if keyfile_dict:
keyfile_content = keyfile_dict if isinstance(keyfile_dict, dict) else json.loads(keyfile_dict)
return create_engine(self.get_uri(), credentials_info=keyfile_content, **engine_kwargs)
try:
# 1. If the environment variable GOOGLE_APPLICATION_CREDENTIALS is set
# ADC uses the service account key or configuration file that the variable points to.
Expand All @@ -175,9 +174,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
self.log.error(e)
raise AirflowException(
"For now, we only support instantiating SQLAlchemy engine by"
" using ADC"
", extra__google_cloud_platform__key_path"
"and extra__google_cloud_platform__keyfile_dict"
" using ADC or extra fields `key_path` and `keyfile_dict`."
)

def get_records(self, sql, parameters=None):
Expand Down
22 changes: 10 additions & 12 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
# For requests that are "retriable"
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
from airflow.utils.log.logging_mixin import LoggingMixin
Expand Down Expand Up @@ -375,9 +375,6 @@ def _wait_for_operation_to_complete(self, project_id: str, operation_name: str)
"https://storage.googleapis.com/cloudsql-proxy/{}/cloud_sql_proxy.{}.{}"
)

GCP_CREDENTIALS_KEY_PATH = "extra__google_cloud_platform__key_path"
GCP_CREDENTIALS_KEYFILE_DICT = "extra__google_cloud_platform__keyfile_dict"


class CloudSqlProxyRunner(LoggingMixin):
"""
Expand Down Expand Up @@ -484,15 +481,16 @@ def _download_sql_proxy_if_needed(self) -> None:
self.sql_proxy_was_downloaded = True

def _get_credential_parameters(self) -> list[str]:
connection = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id)

if connection.extra_dejson.get(GCP_CREDENTIALS_KEY_PATH):
credential_params = ['-credential_file', connection.extra_dejson[GCP_CREDENTIALS_KEY_PATH]]
elif connection.extra_dejson.get(GCP_CREDENTIALS_KEYFILE_DICT):
credential_file_content = json.loads(connection.extra_dejson[GCP_CREDENTIALS_KEYFILE_DICT])
extras = GoogleBaseHook.get_connection(conn_id=self.gcp_conn_id).extra_dejson
key_path = get_field(extras, 'key_path')
keyfile_dict = get_field(extras, 'keyfile_dict')
if key_path:
credential_params = ['-credential_file', key_path]
elif keyfile_dict:
keyfile_content = keyfile_dict if isinstance(keyfile_dict, dict) else json.loads(keyfile_dict)
self.log.info("Saving credentials to %s", self.credentials_path)
with open(self.credentials_path, "w") as file:
json.dump(credential_file_content, file)
json.dump(keyfile_content, file)
credential_params = ['-credential_file', self.credentials_path]
else:
self.log.info(
Expand All @@ -504,7 +502,7 @@ def _get_credential_parameters(self) -> list[str]:
credential_params = []

if not self.instance_specification:
project_id = connection.extra_dejson.get('extra__google_cloud_platform__project')
project_id = get_field(extras, 'project')
if self.project_id:
project_id = self.project_id
if not project_id:
Expand Down
5 changes: 2 additions & 3 deletions airflow/providers/google/cloud/operators/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from airflow.providers.google.cloud.hooks.cloud_sql import CloudSQLDatabaseHook, CloudSQLHook
from airflow.providers.google.cloud.links.cloud_sql import CloudSQLInstanceDatabaseLink, CloudSQLInstanceLink
from airflow.providers.google.cloud.utils.field_validator import GcpBodyFieldValidator
from airflow.providers.google.common.hooks.base_google import get_field
from airflow.providers.google.common.links.storage import FileDetailsLink
from airflow.providers.mysql.hooks.mysql import MySqlHook
from airflow.providers.postgres.hooks.postgres import PostgresHook
Expand Down Expand Up @@ -1092,9 +1093,7 @@ def execute(self, context: Context):
hook = CloudSQLDatabaseHook(
gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
gcp_conn_id=self.gcp_conn_id,
default_gcp_project_id=self.gcp_connection.extra_dejson.get(
'extra__google_cloud_platform__project'
),
default_gcp_project_id=get_field(self.gcp_connection.extra_dejson, 'project'),
)
hook.validate_ssl_certs()
connection = hook.create_connection()
Expand Down
8 changes: 3 additions & 5 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,14 @@ def build_gcp_conn(
:return: String representing Airflow connection.
"""
conn = "google-cloud-platform://?{}"
extras = "extra__google_cloud_platform"

query_params = {}
if key_file_path:
query_params[f"{extras}__key_path"] = key_file_path
query_params["key_path"] = key_file_path
if scopes:
scopes_string = ",".join(scopes)
query_params[f"{extras}__scope"] = scopes_string
query_params["scope"] = scopes_string
if project_id:
query_params[f"{extras}__projects"] = project_id
query_params["projects"] = project_id

query = urlencode(query_params)
return conn.format(query)
Expand Down
41 changes: 21 additions & 20 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ def __init__(self):
RT = TypeVar('RT')


def get_field(extras: dict, field_name: str):
"""Get field from extra, first checking short name, then for backcompat we check for prefixed name."""
if field_name.startswith('extra__'):
raise ValueError(
f"Got prefixed name {field_name}; please remove the 'extra__google_cloud_platform__' prefix "
"when using this method."
)
if field_name in extras:
return extras[field_name] or None
prefixed_name = f"extra__google_cloud_platform__{field_name}"
return extras.get(prefixed_name) or None


class GoogleBaseHook(BaseHook):
"""
A base hook for Google cloud-related hooks. Google cloud has a shared REST
Expand Down Expand Up @@ -179,25 +192,17 @@ def get_connection_form_widgets() -> dict[str, Any]:
from wtforms.validators import NumberRange

return {
"extra__google_cloud_platform__project": StringField(
lazy_gettext('Project Id'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__key_path": StringField(
lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__keyfile_dict": PasswordField(
lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget()
),
"extra__google_cloud_platform__scope": StringField(
lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__key_secret_name": StringField(
"project": StringField(lazy_gettext('Project Id'), widget=BS3TextFieldWidget()),
"key_path": StringField(lazy_gettext('Keyfile Path'), widget=BS3TextFieldWidget()),
"keyfile_dict": PasswordField(lazy_gettext('Keyfile JSON'), widget=BS3PasswordFieldWidget()),
"scope": StringField(lazy_gettext('Scopes (comma separated)'), widget=BS3TextFieldWidget()),
"key_secret_name": StringField(
lazy_gettext('Keyfile Secret Name (in GCP Secret Manager)'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__key_secret_project_id": StringField(
"key_secret_project_id": StringField(
lazy_gettext('Keyfile Secret Project Id (in GCP Secret Manager)'), widget=BS3TextFieldWidget()
),
"extra__google_cloud_platform__num_retries": IntegerField(
"num_retries": IntegerField(
lazy_gettext('Number of Retries'),
validators=[NumberRange(min=0)],
widget=BS3TextFieldWidget(),
Expand Down Expand Up @@ -325,11 +330,7 @@ def _get_field(self, f: str, default: Any = None) -> Any:
to the hook page, which allow admins to specify service_account,
key_path, etc. They get formatted as shown below.
"""
long_f = f'extra__google_cloud_platform__{f}'
if hasattr(self, 'extras') and long_f in self.extras:
return self.extras[long_f]
else:
return default
return hasattr(self, 'extras') and get_field(self.extras, f) or default

@property
def project_id(self) -> str | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,5 +147,5 @@ For connecting to a google cloud conn, all the fields must be in the extra field

.. code-block:: ini
{'extra__google_cloud_platform__key_path': '/opt/airflow/service_account.json',
'extra__google_cloud_platform__scope': 'https://www.googleapis.com/auth/devstorage.read_only'}
{'key_path': '/opt/airflow/service_account.json',
'scope': 'https://www.googleapis.com/auth/devstorage.read_only'}
24 changes: 15 additions & 9 deletions docs/apache-airflow-providers-google/connections/gcp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -124,21 +124,27 @@ Number of Retries
* query parameters contains information specific to this type of
connection. The following keys are accepted:

* ``extra__google_cloud_platform__project`` - Project Id
* ``extra__google_cloud_platform__key_path`` - Keyfile Path
* ``extra__google_cloud_platform__keyfile_dict`` - Keyfile JSON
* ``extra__google_cloud_platform__key_secret_name`` - Secret name which holds Keyfile JSON
* ``extra__google_cloud_platform__key_secret_project_id`` - Project Id which holds Keyfile JSON
* ``extra__google_cloud_platform__scope`` - Scopes
* ``extra__google_cloud_platform__num_retries`` - Number of Retries
* ``project`` - Project Id
* ``key_path`` - Keyfile Path
* ``keyfile_dict`` - Keyfile JSON
* ``key_secret_name`` - Secret name which holds Keyfile JSON
* ``key_secret_project_id`` - Project Id which holds Keyfile JSON
* ``scope`` - Scopes
* ``num_retries`` - Number of Retries

Note that all components of the URI should be URL-encoded.

For example:
For example, with URI format:

.. code-block:: bash
export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='google-cloud-platform://?extra__google_cloud_platform__key_path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra__google_cloud_platform__project=airflow&extra__google_cloud_platform__num_retries=5'
export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='google-cloud-platform://?key_path=%2Fkeys%2Fkey.json&scope=https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&project=airflow&num_retries=5'
And using JSON format:

.. code-block:: bash
export AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT='{"conn_type": "google-cloud-platform", "key_path": "/keys/key.json", "scope": "https://www.googleapis.com/auth/cloud-platform", "project": "airflow", "num_retries": 5}'
.. _howto/connection:gcp:impersonation:

Expand Down
24 changes: 12 additions & 12 deletions docs/apache-airflow-providers-google/connections/gcp_ssh.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ Extra (optional)
connection. The following parameters are supported in addition to those describing
the Google Cloud connection.

* ``extra__google_cloud_platform__instance_name`` - The name of the Compute Engine instance.
* ``extra__google_cloud_platform__zone`` - The zone of the Compute Engine instance.
* ``extra__google_cloud_platform__use_internal_ip`` - Whether to connect using internal IP.
* ``extra__google_cloud_platform__use_iap_tunnel`` - Whether to connect through IAP tunnel.
* ``extra__google_cloud_platform__use_oslogin`` - Whether to manage keys using OsLogin API. If false, keys are managed using instance metadata.
* ``extra__google_cloud_platform__expire_time`` - The maximum amount of time in seconds before the private key expires.
* ``instance_name`` - The name of the Compute Engine instance.
* ``zone`` - The zone of the Compute Engine instance.
* ``use_internal_ip`` - Whether to connect using internal IP.
* ``use_iap_tunnel`` - Whether to connect through IAP tunnel.
* ``use_oslogin`` - Whether to manage keys using OsLogin API. If false, keys are managed using instance metadata.
* ``expire_time`` - The maximum amount of time in seconds before the private key expires.


Environment variable
Expand All @@ -64,9 +64,9 @@ For example:
.. code-block:: bash
export AIRFLOW_CONN_GOOGLE_CLOUD_SQL_DEFAULT="gcpssh://conn-user@conn-host?\
extra__google_cloud_platform__instance_name=conn-instance-name&\
extra__google_cloud_platform__zone=zone&\
extra__google_cloud_platform__use_internal_ip=True&\
extra__google_cloud_platform__use_iap_tunnel=True&\
extra__google_cloud_platform__use_oslogin=False&\
extra__google_cloud_platform__expire_time=4242"
instance_name=conn-instance-name&\
zone=zone&\
use_internal_ip=True&\
use_iap_tunnel=True&\
use_oslogin=False&\
expire_time=4242"
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ raise an exception. The following is a sample file.
.. code-block:: text
mysql_conn_id=mysql://log:[email protected]:3306/mysqldbrd
google_custom_key=google-cloud-platform://?extra__google_cloud_platform__key_path=%2Fkeys%2Fkey.json
google_custom_key=google-cloud-platform://?key_path=%2Fkeys%2Fkey.json
Storing and Retrieving Variables
""""""""""""""""""""""""""""""""
Expand Down
12 changes: 5 additions & 7 deletions tests/always/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,10 +275,8 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
description='no schema',
),
UriTestCaseConfig(
test_conn_uri='google-cloud-platform://?extra__google_cloud_platform__key_'
'path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope='
'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra'
'__google_cloud_platform__project=airflow',
test_conn_uri='google-cloud-platform://?key_path=%2Fkeys%2Fkey.json&scope='
'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&project=airflow',
test_conn_attributes=dict(
conn_type='google_cloud_platform',
host='',
Expand All @@ -287,9 +285,9 @@ def test_connection_extra_with_encryption_rotate_fernet_key(self):
password=None,
port=None,
extra_dejson=dict(
extra__google_cloud_platform__key_path='/keys/key.json',
extra__google_cloud_platform__scope='https://www.googleapis.com/auth/cloud-platform',
extra__google_cloud_platform__project='airflow',
key_path='/keys/key.json',
scope='https://www.googleapis.com/auth/cloud-platform',
project='airflow',
),
),
description='with underscore',
Expand Down
20 changes: 10 additions & 10 deletions tests/always/test_secrets_local_filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,8 @@ def test_missing_file(self, mock_exists):
extra_dejson:
arbitrary_dict:
a: b
extra__google_cloud_platform__keyfile_dict: '{"a": "b"}'
extra__google_cloud_platform__keyfile_path: asaa""",
keyfile_dict: '{"a": "b"}'
keyfile_path: asaa""",
{
"conn_a": {'conn_type': 'mysql', 'host': 'hosta'},
"conn_b": {
Expand All @@ -270,8 +270,8 @@ def test_missing_file(self, mock_exists):
'port': 1234,
'extra_dejson': {
'arbitrary_dict': {"a": "b"},
'extra__google_cloud_platform__keyfile_dict': '{"a": "b"}',
'extra__google_cloud_platform__keyfile_path': 'asaa',
'keyfile_dict': '{"a": "b"}',
'keyfile_path': 'asaa',
},
},
},
Expand Down Expand Up @@ -314,14 +314,14 @@ def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dic
password: None
port: 1234
extra_dejson:
extra__google_cloud_platform__keyfile_dict:
keyfile_dict:
a: b
extra__google_cloud_platform__key_path: xxx
key_path: xxx
""",
{
"conn_d": {
"extra__google_cloud_platform__keyfile_dict": {"a": "b"},
"extra__google_cloud_platform__key_path": "xxx",
"keyfile_dict": {"a": "b"},
"key_path": "xxx",
}
},
),
Expand All @@ -334,9 +334,9 @@ def test_yaml_file_should_load_connection(self, file_content, expected_attrs_dic
login: Login
password: None
port: 1234
extra: '{\"extra__google_cloud_platform__keyfile_dict\": {\"a\": \"b\"}}'
extra: '{\"keyfile_dict\": {\"a\": \"b\"}}'
""",
{"conn_d": {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}},
{"conn_d": {"keyfile_dict": {"a": "b"}}},
),
],
)
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/google/ads/hooks/test_ads.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
ADS_CLIENT = {"key": "value"}
SECRET = "secret"
EXTRAS = {
"extra__google_cloud_platform__keyfile_dict": SECRET,
"keyfile_dict": SECRET,
"google_ads_client": ADS_CLIENT,
}

Expand Down

0 comments on commit de9633f

Please sign in to comment.