Skip to content

Commit cdf1809

Browse files
authored
[AIRFLOW-7104] Add Secret backend for GCP Secrets Manager (#7795)
1 parent d372f23 commit cdf1809

File tree

14 files changed

+583
-125
lines changed

14 files changed

+583
-125
lines changed

CONTRIBUTING.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,7 @@ apache.livy http
413413
dingding http
414414
discord http
415415
google amazon,apache.cassandra,cncf.kubernetes,microsoft.azure,microsoft.mssql,mysql,postgres,presto,sftp
416+
hashicorp google
416417
microsoft.azure oracle
417418
microsoft.mssql odbc
418419
mysql amazon,presto,vertica

airflow/providers/dependencies.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
"presto",
3939
"sftp"
4040
],
41+
"hashicorp": [
42+
"google"
43+
],
4144
"microsoft.azure": [
4245
"oracle"
4346
],

airflow/providers/google/cloud/hooks/base.py

Lines changed: 19 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@
4444
from airflow import version
4545
from airflow.exceptions import AirflowException
4646
from airflow.hooks.base_hook import BaseHook
47+
from airflow.providers.google.cloud.utils.credentials_provider import (
48+
_get_scopes, get_credentials_and_project_id,
49+
)
4750
from airflow.utils.process_utils import patch_environ
4851

4952
log = logging.getLogger(__name__)
5053

5154

52-
_DEFAULT_SCOPES = ('https://www.googleapis.com/auth/cloud-platform',) # type: Sequence[str]
53-
5455
# Constants used by the mechanism of repeating requests in reaction to exceeding the temporary quota.
5556
INVALID_KEYS = [
5657
'DefaultRequestsPerMinutePerProject',
@@ -167,55 +168,21 @@ def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Crede
167168
if self._cached_credentials is not None:
168169
return self._cached_credentials, self._cached_project_id
169170

170-
key_path = self._get_field('key_path', None) # type: Optional[str]
171-
keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[str]
172-
if key_path and keyfile_dict:
173-
raise AirflowException(
174-
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
175-
"Please provide only one value."
176-
)
177-
if not key_path and not keyfile_dict:
178-
self.log.info(
179-
'Getting connection using `google.auth.default()` since no key file is defined for hook.'
180-
)
181-
credentials, project_id = google.auth.default(scopes=self.scopes)
182-
elif key_path:
183-
# Get credentials from a JSON file.
184-
if key_path.endswith('.json'):
185-
self.log.debug('Getting connection using JSON key file %s', key_path)
186-
credentials = (
187-
google.oauth2.service_account.Credentials.from_service_account_file(
188-
key_path, scopes=self.scopes)
189-
)
190-
project_id = credentials.project_id
191-
elif key_path.endswith('.p12'):
192-
raise AirflowException(
193-
'Legacy P12 key file are not supported, use a JSON key file.'
194-
)
195-
else:
196-
raise AirflowException('Unrecognised extension for key file.')
197-
else:
198-
# Get credentials from JSON data provided in the UI.
199-
try:
200-
if not keyfile_dict:
201-
raise ValueError("The keyfile_dict should be set")
202-
keyfile_dict_json: Dict[str, str] = json.loads(keyfile_dict)
203-
204-
# Depending on how the JSON was formatted, it may contain
205-
# escaped newlines. Convert those to actual newlines.
206-
keyfile_dict_json['private_key'] = keyfile_dict_json['private_key'].replace(
207-
'\\n', '\n')
208-
209-
credentials = (
210-
google.oauth2.service_account.Credentials.from_service_account_info(
211-
keyfile_dict_json, scopes=self.scopes)
212-
)
213-
project_id = credentials.project_id
214-
except json.decoder.JSONDecodeError:
215-
raise AirflowException('Invalid key JSON.')
216-
217-
if self.delegate_to:
218-
credentials = credentials.with_subject(self.delegate_to)
171+
key_path: Optional[str] = self._get_field('key_path', None)
172+
try:
173+
keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None)
174+
keyfile_dict_json: Optional[Dict[str, str]] = None
175+
if keyfile_dict:
176+
keyfile_dict_json = json.loads(keyfile_dict)
177+
except json.decoder.JSONDecodeError:
178+
raise AirflowException('Invalid key JSON.')
179+
180+
credentials, project_id = get_credentials_and_project_id(
181+
key_path=key_path,
182+
keyfile_dict=keyfile_dict_json,
183+
scopes=self.scopes,
184+
delegate_to=self.delegate_to
185+
)
219186

220187
overridden_project_id = self._get_field('project')
221188
if overridden_project_id:
@@ -308,8 +275,7 @@ def scopes(self) -> Sequence[str]:
308275
"""
309276
scope_value = self._get_field('scope', None) # type: Optional[str]
310277

311-
return [s.strip() for s in scope_value.split(',')] \
312-
if scope_value else _DEFAULT_SCOPES
278+
return _get_scopes(scope_value)
313279

314280
@staticmethod
315281
def quota_retry(*args, **kwargs) -> Callable:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""
19+
Objects relating to sourcing connections from GCP Secrets Manager
20+
"""
21+
from typing import List, Optional
22+
23+
from cached_property import cached_property
24+
from google.api_core.exceptions import NotFound
25+
from google.api_core.gapic_v1.client_info import ClientInfo
26+
from google.cloud.secretmanager_v1 import SecretManagerServiceClient
27+
28+
from airflow import version
29+
from airflow.models import Connection
30+
from airflow.providers.google.cloud.utils.credentials_provider import (
31+
_get_scopes, get_credentials_and_project_id,
32+
)
33+
from airflow.secrets import BaseSecretsBackend
34+
from airflow.utils.log.logging_mixin import LoggingMixin
35+
36+
37+
class CloudSecretsManagerSecretsBackend(BaseSecretsBackend, LoggingMixin):
38+
"""
39+
Retrieves Connection object from GCP Secrets Manager
40+
41+
Configurable via ``airflow.cfg`` as follows:
42+
43+
.. code-block:: ini
44+
45+
[secrets]
46+
backend = airflow.providers.google.cloud.secrets.secrets_manager.CloudSecretsManagerSecretsBackend
47+
backend_kwargs = {"connections_prefix": "airflow/connections"}
48+
49+
For example, if secret id is ``airflow/connections/smtp_default``, this would be accessible
50+
if you provide ``{"connections_prefix": "airflow/connections"}`` and request conn_id ``smtp_default``.
51+
52+
:param connections_prefix: Specifies the prefix of the secret to read to get Connections.
53+
:type connections_prefix: str
54+
:param gcp_key_path: Path to GCP Credential JSON file;
55+
use default credentials in the current environment if not provided.
56+
:type gcp_key_path: str
57+
:param gcp_scopes: Comma-separated string containing GCP scopes
58+
:type gcp_scopes: str
59+
"""
60+
def __init__(
61+
self,
62+
connections_prefix: str = "airflow/connections",
63+
gcp_key_path: Optional[str] = None,
64+
gcp_scopes: Optional[str] = None,
65+
**kwargs
66+
):
67+
self.connections_prefix = connections_prefix.rstrip("/")
68+
self.gcp_key_path = gcp_key_path
69+
self.gcp_scopes = gcp_scopes
70+
self.credentials: Optional[str] = None
71+
self.project_id: Optional[str] = None
72+
super().__init__(**kwargs)
73+
74+
@cached_property
75+
def client(self) -> SecretManagerServiceClient:
76+
"""
77+
Create an authenticated KMS client
78+
"""
79+
scopes = _get_scopes(self.gcp_scopes)
80+
self.credentials, self.project_id = get_credentials_and_project_id(
81+
key_path=self.gcp_key_path,
82+
scopes=scopes
83+
)
84+
_client = SecretManagerServiceClient(
85+
credentials=self.credentials,
86+
client_info=ClientInfo(client_library_version='airflow_v' + version.version)
87+
)
88+
return _client
89+
90+
def build_secret_id(self, conn_id: str) -> str:
91+
"""
92+
Given conn_id, build path for Secrets Manager
93+
94+
:param conn_id: connection id
95+
:type conn_id: str
96+
"""
97+
secret_id = f"{self.connections_prefix}/{conn_id}"
98+
return secret_id
99+
100+
def get_conn_uri(self, conn_id: str) -> Optional[str]:
101+
"""
102+
Get secret value from Secrets Manager.
103+
104+
:param conn_id: connection id
105+
:type conn_id: str
106+
"""
107+
secret_id = self.build_secret_id(conn_id=conn_id)
108+
# always return the latest version of the secret
109+
secret_version = "latest"
110+
name = self.client.secret_version_path(self.project_id, secret_id, secret_version)
111+
try:
112+
response = self.client.access_secret_version(name)
113+
value = response.payload.data.decode('UTF-8')
114+
return value
115+
except NotFound:
116+
self.log.error(
117+
"GCP API Call Error (NotFound): Secret ID %s not found.", secret_id
118+
)
119+
return None
120+
121+
def get_connections(self, conn_id: str) -> List[Connection]:
122+
"""
123+
Create connection object from GCP Secrets Manager
124+
125+
:param conn_id: connection id
126+
:type conn_id: str
127+
"""
128+
conn_uri = self.get_conn_uri(conn_id=conn_id)
129+
if not conn_uri:
130+
return []
131+
conn = Connection(conn_id=conn_id, uri=conn_uri)
132+
return [conn]

airflow/providers/google/cloud/utils/credentials_provider.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,23 @@
2020
Google Cloud Platform authentication.
2121
"""
2222
import json
23+
import logging
2324
import tempfile
2425
from contextlib import contextmanager
25-
from typing import Dict, Optional, Sequence
26+
from typing import Dict, Optional, Sequence, Tuple
2627
from urllib.parse import urlencode
2728

29+
import google.auth
30+
import google.oauth2.service_account
2831
from google.auth.environment_vars import CREDENTIALS
2932

3033
from airflow.exceptions import AirflowException
3134
from airflow.utils.process_utils import patch_environ
3235

36+
log = logging.getLogger(__name__)
37+
3338
AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT"
39+
_DEFAULT_SCOPES: Sequence[str] = ('https://www.googleapis.com/auth/cloud-platform',)
3440

3541

3642
def build_gcp_conn(
@@ -158,3 +164,87 @@ def provide_gcp_conn_and_credentials(
158164
key_file_path, scopes, project_id
159165
):
160166
yield
167+
168+
169+
def get_credentials_and_project_id(
170+
key_path: Optional[str] = None,
171+
keyfile_dict: Optional[Dict[str, str]] = None,
172+
scopes: Optional[Sequence[str]] = None,
173+
delegate_to: Optional[str] = None
174+
) -> Tuple[google.auth.credentials.Credentials, str]:
175+
"""
176+
Returns the Credentials object for Google API and the associated project_id
177+
178+
Only either `key_path` or `keyfile_dict` should be provided, or an exception will
179+
occur. If neither of them are provided, return default credentials for the current environment
180+
181+
:param key_path: Path to GCP Credential JSON file
182+
:type key_path: str
183+
:param key_dict: A dict representing GCP Credential as in the Credential JSON file
184+
:type key_dict: Dict[str, str]
185+
:param scopes: OAuth scopes for the connection
186+
:type scopes: Sequence[str]
187+
:param delegate_to: The account to impersonate, if any.
188+
For this to work, the service account making the request must have
189+
domain-wide delegation enabled.
190+
:type delegate_to: str
191+
:return: Google Auth Credentials
192+
:type: google.auth.credentials.Credentials
193+
"""
194+
if key_path and keyfile_dict:
195+
raise AirflowException(
196+
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
197+
"Please provide only one value."
198+
)
199+
if not key_path and not keyfile_dict:
200+
log.info(
201+
'Getting connection using `google.auth.default()` since no key file is defined for hook.'
202+
)
203+
credentials, project_id = google.auth.default(scopes=scopes)
204+
elif key_path:
205+
# Get credentials from a JSON file.
206+
if key_path.endswith('.json'):
207+
log.debug('Getting connection using JSON key file %s', key_path)
208+
credentials = (
209+
google.oauth2.service_account.Credentials.from_service_account_file(
210+
key_path, scopes=scopes)
211+
)
212+
project_id = credentials.project_id
213+
elif key_path.endswith('.p12'):
214+
raise AirflowException(
215+
'Legacy P12 key file are not supported, use a JSON key file.'
216+
)
217+
else:
218+
raise AirflowException('Unrecognised extension for key file.')
219+
else:
220+
if not keyfile_dict:
221+
raise ValueError("The keyfile_dict should be set")
222+
# Depending on how the JSON was formatted, it may contain
223+
# escaped newlines. Convert those to actual newlines.
224+
keyfile_dict['private_key'] = keyfile_dict['private_key'].replace(
225+
'\\n', '\n')
226+
227+
credentials = (
228+
google.oauth2.service_account.Credentials.from_service_account_info(
229+
keyfile_dict, scopes=scopes)
230+
)
231+
project_id = credentials.project_id
232+
233+
if delegate_to:
234+
credentials = credentials.with_subject(delegate_to)
235+
236+
return credentials, project_id
237+
238+
239+
def _get_scopes(scopes: Optional[str] = None) -> Sequence[str]:
240+
"""
241+
Parse a comma-separated string containing GCP scopes if `scopes` is provided.
242+
Otherwise, default scope will be returned.
243+
244+
:param scopes: A comma-separated string containing GCP scopes
245+
:type scopes: Optional[str]
246+
:return: Returns the scope defined in the connection configuration, or the default scope
247+
:rtype: Sequence[str]
248+
"""
249+
return [s.strip() for s in scopes.split(',')] \
250+
if scopes else _DEFAULT_SCOPES

0 commit comments

Comments
 (0)