Skip to content

Commit

Permalink
[AIRFLOW-7073] GKEStartPodOperator always use connection credentials (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj committed Mar 17, 2020
1 parent 2a54512 commit 91557c6
Show file tree
Hide file tree
Showing 8 changed files with 299 additions and 187 deletions.
107 changes: 80 additions & 27 deletions airflow/providers/google/cloud/hooks/base.py
Expand Up @@ -26,9 +26,11 @@
import os
import tempfile
from contextlib import contextmanager
from typing import Any, Callable, Dict, Optional, Sequence, TypeVar
from subprocess import check_output
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar

import google.auth
import google.auth.credentials
import google.oauth2.service_account
import google_auth_httplib2
import httplib2
Expand All @@ -37,13 +39,15 @@
AlreadyExists, Forbidden, GoogleAPICallError, ResourceExhausted, RetryError, TooManyRequests,
)
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth import _cloud_sdk
from google.auth.environment_vars import CREDENTIALS
from googleapiclient.errors import HttpError
from googleapiclient.http import set_user_agent

from airflow import version
from airflow.exceptions import AirflowException
from airflow.hooks.base_hook import BaseHook
from airflow.utils.process_utils import patch_environ

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -138,7 +142,7 @@ def __init__(self, gcp_conn_id: str = 'google_cloud_default', delegate_to: Optio
self.delegate_to = delegate_to
self.extras = self.get_connection(self.gcp_conn_id).extra_dejson # type: Dict

def _get_credentials_and_project_id(self) -> google.auth.credentials.Credentials:
def _get_credentials_and_project_id(self) -> Tuple[google.auth.credentials.Credentials, Optional[str]]:
"""
Returns the Credentials object for Google API and the associated project_id
"""
Expand Down Expand Up @@ -387,28 +391,77 @@ def provide_gcp_credential_file_as_context(self):
It can be used to provide credentials for external programs (e.g. gcloud) that expect an authorization
file in the ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
"""
with tempfile.NamedTemporaryFile(mode='w+t') as conf_file:
key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access
keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access
current_env_state = os.environ.get(CREDENTIALS)
try:
if key_path:
if key_path.endswith('.p12'):
raise AirflowException(
'Legacy P12 key file are not supported, use a JSON key file.'
)
os.environ[CREDENTIALS] = key_path
elif keyfile_dict:
conf_file.write(keyfile_dict)
conf_file.flush()
os.environ[CREDENTIALS] = conf_file.name
else:
# We will use the default service account credentials.
pass
yield conf_file
finally:
if current_env_state is None:
if CREDENTIALS in os.environ:
del os.environ[CREDENTIALS]
else:
os.environ[CREDENTIALS] = current_env_state
key_path = self._get_field('key_path', None) # type: Optional[str] # noqa: E501 # pylint: disable=protected-access
keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict] # noqa: E501 # pylint: disable=protected-access
if key_path and keyfile_dict:
raise AirflowException(
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
"Please provide only one value."
)
elif key_path:
if key_path.endswith('.p12'):
raise AirflowException(
'Legacy P12 key file are not supported, use a JSON key file.'
)
with patch_environ({CREDENTIALS: key_path}):
yield key_path
elif keyfile_dict:
with tempfile.NamedTemporaryFile(mode='w+t') as conf_file:
conf_file.write(keyfile_dict)
conf_file.flush()
with patch_environ({CREDENTIALS: conf_file.name}):
yield conf_file.name
else:
# We will use the default service account credentials.
yield None

@contextmanager
def provide_authorized_gcloud(self):
    """
    Runs the wrapped code against a throw-away gcloud configuration that is
    authorized with the credentials of this hook.

    gcloud keeps two separate logins: ``gcloud auth login`` for the tool itself and
    ``gcloud auth application-default login`` for Application Default Credentials (ADC).
    All commands should use only the credentials from ADC, so they are configured
    in gcloud manually here.
    """
    # Resolved before ``CLOUDSDK_CONFIG`` is patched below.
    # NOTE(review): presumably the returned path depends on that variable - confirm.
    adc_file_path = _cloud_sdk.get_application_default_credentials_path()
    project_id = self.project_id

    with self.provide_gcp_credential_file_as_context(), \
            tempfile.TemporaryDirectory() as gcloud_config_tmp, \
            patch_environ({'CLOUDSDK_CONFIG': gcloud_config_tmp}):

        # The output of every gcloud call below is captured and discarded on purpose
        # (security): it is never logged.
        if project_id:
            check_output(["gcloud", "config", "set", "core/project", project_id])

        if CREDENTIALS in os.environ:
            # Covers most cases: Airflow is logged in with a service account key file.
            check_output([
                "gcloud", "auth", "activate-service-account", f"--key-file={os.environ[CREDENTIALS]}",
            ])
        elif os.path.exists(adc_file_path):
            # Logged in through `gcloud auth application-default`: log in manually so that
            # the `gcloud auth` credentials become equal to the application-default ones.
            with open(adc_file_path) as adc_file:
                adc_content = json.load(adc_file)
            for option, field in (("auth/client_id", "client_id"), ("auth/client_secret", "client_secret")):
                check_output(["gcloud", "config", "set", option, adc_content[field]])
            check_output([
                "gcloud",
                "auth",
                "activate-refresh-token",
                adc_content["client_id"],
                adc_content["refresh_token"],
            ])

        # Hand control to the caller while the temporary configuration is still active.
        yield
39 changes: 19 additions & 20 deletions airflow/providers/google/cloud/operators/kubernetes_engine.py
Expand Up @@ -21,7 +21,6 @@
"""

import os
import subprocess
import tempfile
from typing import Dict, Optional, Union

Expand All @@ -33,6 +32,7 @@
from airflow.providers.google.cloud.hooks.base import CloudBaseHook
from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEHook
from airflow.utils.decorators import apply_defaults
from airflow.utils.process_utils import execute_in_subprocess, patch_environ


class GKEDeleteClusterOperator(BaseOperator):
Expand Down Expand Up @@ -254,22 +254,21 @@ def execute(self, context):

# Write config to a temp file and set the environment variable to point to it.
# This is to avoid race conditions of reading/writing a single file
with tempfile.NamedTemporaryFile() as conf_file:
os.environ[KUBE_CONFIG_ENV_VAR] = conf_file.name

with hook.provide_gcp_credential_file_as_context():
# Attempt to get/update credentials
# We call gcloud directly instead of using google-cloud-python api
# because there is no way to write kubernetes config to a file, which is
# required by KubernetesPodOperator.
# The gcloud command looks at the env variable `KUBECONFIG` for where to save
# the kubernetes config file.
subprocess.check_call(
["gcloud", "container", "clusters", "get-credentials",
self.cluster_name,
"--zone", self.location,
"--project", self.project_id])

# Tell `KubernetesPodOperator` where the config file is located
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
return super().execute(context)
with tempfile.NamedTemporaryFile() as conf_file,\
patch_environ({KUBE_CONFIG_ENV_VAR: conf_file.name}), \
hook.provide_authorized_gcloud():
# Attempt to get/update credentials
# We call gcloud directly instead of using google-cloud-python api
# because there is no way to write kubernetes config to a file, which is
# required by KubernetesPodOperator.
# The gcloud command looks at the env variable `KUBECONFIG` for where to save
# the kubernetes config file.
execute_in_subprocess(
["gcloud", "container", "clusters", "get-credentials",
self.cluster_name,
"--zone", self.location,
"--project", self.project_id])

# Tell `KubernetesPodOperator` where the config file is located
self.config_file = os.environ[KUBE_CONFIG_ENV_VAR]
return super().execute(context)
31 changes: 3 additions & 28 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Expand Up @@ -20,7 +20,6 @@
Google Cloud Platform authentication.
"""
import json
import os
import tempfile
from contextlib import contextmanager
from typing import Dict, Optional, Sequence
Expand All @@ -29,6 +28,7 @@
from google.auth.environment_vars import CREDENTIALS

from airflow.exceptions import AirflowException
from airflow.utils.process_utils import patch_environ

AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT"

Expand Down Expand Up @@ -66,31 +66,6 @@ def build_gcp_conn(
return conn.format(query)


@contextmanager
def temporary_environment_variable(variable_name: str, value: str):
    """
    Context manager that sets a temporary value for a given environment
    variable and then restores the initial state.

    If the variable did not exist before, it is removed on exit. If it existed,
    its previous value is put back, including the case of an empty string.

    :param variable_name: Name of the environment variable
    :type variable_name: str
    :param value: The temporary value
    :type value: str
    """
    # ``None`` means "was not set at all"; an empty string is a real value to restore.
    init_value = os.environ.get(variable_name)
    try:
        os.environ[variable_name] = value
        yield
    finally:
        if init_value is None:
            os.environ.pop(variable_name, None)
        else:
            os.environ[variable_name] = init_value


@contextmanager
def provide_gcp_credentials(
key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None
Expand Down Expand Up @@ -121,7 +96,7 @@ def provide_gcp_credentials(
conf_file.flush()
key_file_path = conf_file.name
if key_file_path:
with temporary_environment_variable(CREDENTIALS, key_file_path):
with patch_environ({CREDENTIALS: key_file_path}):
yield
else:
# We will use the default service account credentials.
Expand Down Expand Up @@ -155,7 +130,7 @@ def provide_gcp_connection(
scopes=scopes, key_file_path=key_file_path, project_id=project_id
)

with temporary_environment_variable(AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT, conn):
with patch_environ({AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT: conn}):
yield


Expand Down
26 changes: 25 additions & 1 deletion airflow/utils/process_utils.py
Expand Up @@ -25,7 +25,8 @@
import shlex
import signal
import subprocess
from typing import List
from contextlib import contextmanager
from typing import Dict, List

import psutil

Expand Down Expand Up @@ -184,3 +185,26 @@ def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> N
log.info("Killing child PID: %s", child.pid)
child.kill()
child.wait()


@contextmanager
def patch_environ(new_env_variables: Dict[str, str]):
    """
    Sets environment variables in context. After leaving the context, it restores its original state.

    The previous value of every key is recorded before anything is modified, and the
    new values are applied inside the guarded section. A failure while applying them
    (for example a non-string value) therefore does not leave part of the variables set.

    :param new_env_variables: Environment variables to set
    :type new_env_variables: Dict[str, str]
    """
    # ``None`` marks a variable that did not exist before entering the context.
    current_env_state = {key: os.environ.get(key) for key in new_env_variables}
    try:
        os.environ.update(new_env_variables)
        yield
    finally:
        for key, old_value in current_env_state.items():
            if old_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old_value

0 comments on commit 91557c6

Please sign in to comment.