Fix StackdriverTaskHandler + add system tests (#9761)
Co-authored-by: Tomek Urbaszek <[email protected]>
mik-laj and turbaszek committed Jul 11, 2020
1 parent bade1d3 commit 092d33f
Showing 7 changed files with 250 additions and 57 deletions.
@@ -51,7 +51,6 @@ class StackdriverTaskHandler(logging.Handler):
This handler supports both an asynchronous and synchronous transport.
:param gcp_key_path: Path to GCP Credential JSON file.
If omitted, authorization based on `the Application Default Credentials
<https://cloud.google.com/docs/authentication/production#finding_credentials_automatically>`__ will
@@ -104,12 +103,14 @@ def __init__(
@cached_property
def _client(self) -> gcp_logging.Client:
"""Google Cloud Library API client"""
credentials, _ = get_credentials_and_project_id(
credentials, project = get_credentials_and_project_id(
key_path=self.gcp_key_path,
scopes=self.scopes,
disable_logging=True
)
client = gcp_logging.Client(
credentials=credentials,
project=project,
client_info=ClientInfo(client_library_version='airflow_v' + version.version)
)
return client
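For context, here is a hedged sketch of what the fixed property now does end to end: the project returned alongside the credentials is forwarded to the Stackdriver client instead of being discarded, so the client no longer has to infer the project from the environment. The key path and scope below are illustrative only.

from google.cloud import logging as gcp_logging

from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id

# Fetch credentials and the project they belong to; disable_logging avoids
# emitting log messages while the logging stack itself is being configured.
credentials, project = get_credentials_and_project_id(
    key_path='/files/airflow-gcp.json',  # illustrative path
    scopes=['https://www.googleapis.com/auth/logging.write'],
    disable_logging=True,
)

# Pass the project explicitly, mirroring the change above.
client = gcp_logging.Client(credentials=credentials, project=project)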
141 changes: 93 additions & 48 deletions airflow/providers/google/cloud/utils/credentials_provider.py
@@ -27,10 +27,12 @@
from urllib.parse import urlencode

import google.auth
import google.auth.credentials
import google.oauth2.service_account
from google.auth.environment_vars import CREDENTIALS, LEGACY_PROJECT, PROJECT

from airflow.exceptions import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.process_utils import patch_environ

log = logging.getLogger(__name__)
@@ -176,15 +178,9 @@ def provide_gcp_conn_and_credentials(
yield


def get_credentials_and_project_id(
key_path: Optional[str] = None,
keyfile_dict: Optional[Dict[str, str]] = None,
# See: https://github.com/PyCQA/pylint/issues/2377
scopes: Optional[Collection[str]] = None, # pylint: disable=unsubscriptable-object
delegate_to: Optional[str] = None
) -> Tuple[google.auth.credentials.Credentials, str]:
class _CredentialProvider(LoggingMixin):
"""
Returns the Credentials object for Google API and the associated project_id
Prepare the Credentials object for Google API and the associated project_id
Only either `key_path` or `keyfile_dict` should be provided, or an exception will
occur. If neither of them is provided, return default credentials for the current environment
@@ -194,64 +190,113 @@ def get_credentials_and_project_id(
:param keyfile_dict: A dict representing GCP Credential as in the Credential JSON file
:type keyfile_dict: Dict[str, str]
:param scopes: OAuth scopes for the connection
:type scopes: Sequence[str]
:type scopes: Collection[str]
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:return: Google Auth Credentials
:type: google.auth.credentials.Credentials
:param disable_logging: If true, disable all log messages, which allows you to use this
class to configure the logger.
"""
if key_path and keyfile_dict:
raise AirflowException(
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
"Please provide only one value."
)
if not key_path and not keyfile_dict:
log.info(
'Getting connection using `google.auth.default()` since no key file is defined for hook.'
)
credentials, project_id = google.auth.default(scopes=scopes)
elif key_path:
# Get credentials from a JSON file.
if key_path.endswith('.json'):
log.debug('Getting connection using JSON key file %s', key_path)
credentials = (
google.oauth2.service_account.Credentials.from_service_account_file(
key_path, scopes=scopes)
)
project_id = credentials.project_id
elif key_path.endswith('.p12'):
def __init__(
self,
key_path: Optional[str] = None,
keyfile_dict: Optional[Dict[str, str]] = None,
# See: https://github.com/PyCQA/pylint/issues/2377
scopes: Optional[Collection[str]] = None, # pylint: disable=unsubscriptable-object
delegate_to: Optional[str] = None,
disable_logging: bool = False
):
super().__init__()
if key_path and keyfile_dict:
raise AirflowException(
'Legacy P12 key file are not supported, use a JSON key file.'
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
"Please provide only one value."
)
self.key_path = key_path
self.keyfile_dict = keyfile_dict
self.scopes = scopes
self.delegate_to = delegate_to
self.disable_logging = disable_logging

def get_credentials_and_project(self):
"""
Get current credentials and project ID.
:return: Google Auth Credentials
:rtype: Tuple[google.auth.credentials.Credentials, str]
"""
if self.key_path:
credentials, project_id = self._get_credentials_using_key_path()
elif self.keyfile_dict:
credentials, project_id = self._get_credentials_using_keyfile_dict()
else:
raise AirflowException('Unrecognised extension for key file.')
else:
if not keyfile_dict:
raise ValueError("The keyfile_dict should be set")
credentials, project_id = self._get_credentials_using_adc()

if self.delegate_to:
if hasattr(credentials, 'with_subject'):
credentials = credentials.with_subject(self.delegate_to)
else:
raise AirflowException(
"The `delegate_to` parameter cannot be used here as the current "
"authentication method does not support account impersonate. "
"Please use service-account for authorization."
)

return credentials, project_id
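The delegate_to branch above only works for service-account credentials, which expose with_subject(); Application Default Credentials for a user account do not, hence the exception. A minimal sketch of the delegation path, with a placeholder key path and delegated account:

from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id

# Domain-wide delegation: the service account impersonates the given user.
credentials, project_id = get_credentials_and_project_id(
    key_path='/files/airflow-gcp.json',   # placeholder path
    delegate_to='user@example.com',       # placeholder delegated account
)
# The same call with credentials that lack with_subject() raises AirflowException.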

def _get_credentials_using_keyfile_dict(self):
self._log_debug('Getting connection using JSON Dict')
# Depending on how the JSON was formatted, it may contain
# escaped newlines. Convert those to actual newlines.
keyfile_dict['private_key'] = keyfile_dict['private_key'].replace(
'\\n', '\n')

self.keyfile_dict['private_key'] = self.keyfile_dict['private_key'].replace('\\n', '\n')
credentials = (
google.oauth2.service_account.Credentials.from_service_account_info(
keyfile_dict, scopes=scopes)
self.keyfile_dict, scopes=self.scopes)
)
project_id = credentials.project_id
return credentials, project_id

if delegate_to:
if hasattr(credentials, 'with_subject'):
credentials = credentials.with_subject(delegate_to)
else:
def _get_credentials_using_key_path(self):
if self.key_path.endswith('.p12'):
raise AirflowException(
"The `delegate_to` parameter cannot be used here as the current "
"authentication method does not support account impersonate. "
"Please use service-account for authorization."
'Legacy P12 key files are not supported, use a JSON key file.'
)

return credentials, project_id
if not self.key_path.endswith('.json'):
raise AirflowException('Unrecognised extension for key file.')

self._log_debug('Getting connection using JSON key file %s', self.key_path)
credentials = (
google.oauth2.service_account.Credentials.from_service_account_file(
self.key_path, scopes=self.scopes)
)
project_id = credentials.project_id
return credentials, project_id

def _get_credentials_using_adc(self):
self._log_info(
'Getting connection using `google.auth.default()` since no key file is defined for hook.'
)
credentials, project_id = google.auth.default(scopes=self.scopes)
return credentials, project_id

def _log_info(self, *args, **kwargs):
if not self.disable_logging:
self.log.info(*args, **kwargs)

def _log_debug(self, *args, **kwargs):
if not self.disable_logging:
self.log.debug(*args, **kwargs)


def get_credentials_and_project_id(
*args, **kwargs
) -> Tuple[google.auth.credentials.Credentials, str]:
"""
Returns the Credentials object for Google API and the associated project_id.
"""
return _CredentialProvider(*args, **kwargs).get_credentials_and_project()
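From a caller's perspective the wrapper still supports the same three credential sources as before, now implemented by the class above. A hedged sketch with placeholder values:

from airflow.providers.google.cloud.utils.credentials_provider import get_credentials_and_project_id

# 1. Service-account JSON key file.
creds, project = get_credentials_and_project_id(key_path='/files/airflow-gcp.json')

# 2. Key material already loaded as a dict, e.g. from a secrets backend.
# creds, project = get_credentials_and_project_id(keyfile_dict=service_account_dict)

# 3. Neither argument: falls back to google.auth.default() (Application Default Credentials).
creds, project = get_credentials_and_project_id()

# Supplying both key_path and keyfile_dict raises AirflowException.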


def _get_scopes(scopes: Optional[str] = None) -> Sequence[str]:
1 change: 0 additions & 1 deletion docs/howto/write-logs.rst
@@ -300,7 +300,6 @@ example:
# configuration requirements.
remote_logging = True
remote_base_log_folder = stackdriver://logs-name
remote_log_conn_id = custom-conn-id
All configuration options are in the ``[logging]`` section.
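For reference, the same setup can also be expressed through environment variables, which is exactly what the new system test below does; the key path is illustrative and optional (without it, Application Default Credentials are used).

import os

os.environ['AIRFLOW__LOGGING__REMOTE_LOGGING'] = 'true'
os.environ['AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER'] = 'stackdriver://logs-name'
os.environ['AIRFLOW__LOGGING__STACKDRIVER_KEY_PATH'] = '/files/airflow-gcp.json'  # illustrative path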

@@ -58,7 +58,8 @@ def test_should_pass_message_to_client(self, mock_client, mock_get_creds_and_pro
)
mock_client.assert_called_once_with(
credentials='creds',
client_info=mock.ANY
client_info=mock.ANY,
project="project_id"
)


@@ -279,6 +280,7 @@ def test_should_use_credentials(self, mock_client, mock_get_creds_and_project_id
client = stackdriver_task_handler._client

mock_get_creds_and_project_id.assert_called_once_with(
disable_logging=True,
key_path='KEY_PATH',
scopes=frozenset(
{
@@ -289,6 +291,7 @@
)
mock_client.assert_called_once_with(
credentials='creds',
client_info=mock.ANY
client_info=mock.ANY,
project="project_id"
)
self.assertEqual(mock_client.return_value, client)
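Both assertions above rely on the patched credentials provider returning a (credentials, project_id) pair; that setup sits outside the visible hunks, but presumably looks like this (a reconstruction, not part of the diff):

mock_get_creds_and_project_id.return_value = ('creds', 'project_id')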
@@ -0,0 +1,102 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import importlib
import random
import string
import subprocess
import unittest
from unittest import mock

import pytest

from airflow import settings
from airflow.example_dags import example_complex
from airflow.models import TaskInstance
from airflow.utils.log.log_reader import TaskLogReader
from airflow.utils.session import provide_session
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_STACKDDRIVER
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_runs
from tests.test_utils.gcp_system_helpers import provide_gcp_context, resolve_full_gcp_key_path


@pytest.mark.system("google")
@pytest.mark.credential_file(GCP_STACKDDRIVER)
class TestStackdriverLoggingHandlerSystemTest(unittest.TestCase):

def setUp(self) -> None:
clear_db_runs()
self.log_name = 'stackdriver-tests-'.join(random.sample(string.ascii_lowercase, 16))

def tearDown(self) -> None:
from airflow.config_templates import airflow_local_settings
importlib.reload(airflow_local_settings)
settings.configure_logging()
clear_db_runs()

@provide_session
def test_should_support_key_auth(self, session):
with mock.patch.dict(
'os.environ',
AIRFLOW__LOGGING__REMOTE_LOGGING="true",
AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER=f"stackdriver://{self.log_name}",
AIRFLOW__LOGGING__STACKDRIVER_KEY_PATH=resolve_full_gcp_key_path(GCP_STACKDDRIVER),
AIRFLOW__CORE__LOAD_EXAMPLES="false",
AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__
):
self.assertEqual(0, subprocess.Popen(
["airflow", "dags", "trigger", "example_complex"]
).wait())
self.assertEqual(0, subprocess.Popen(
["airflow", "scheduler", "--num-runs", "1"]
).wait())
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()

self.assert_remote_logs("INFO - Task exited with return code 0", ti)

@provide_session
def test_should_support_adc(self, session):
with mock.patch.dict(
'os.environ',
AIRFLOW__LOGGING__REMOTE_LOGGING="true",
AIRFLOW__LOGGING__REMOTE_BASE_LOG_FOLDER=f"stackdriver://{self.log_name}",
AIRFLOW__CORE__LOAD_EXAMPLES="false",
AIRFLOW__CORE__DAGS_FOLDER=example_complex.__file__,
GOOGLE_APPLICATION_CREDENTIALS=resolve_full_gcp_key_path(GCP_STACKDDRIVER)
):
self.assertEqual(0, subprocess.Popen(
["airflow", "dags", "trigger", "example_complex"]
).wait())
self.assertEqual(0, subprocess.Popen(
["airflow", "scheduler", "--num-runs", "1"]
).wait())
ti = session.query(TaskInstance).filter(TaskInstance.task_id == "create_entry_group").first()

self.assert_remote_logs("INFO - Task exited with return code 0", ti)

def assert_remote_logs(self, expected_message, ti):
with provide_gcp_context(GCP_STACKDDRIVER), conf_vars({
('logging', 'remote_logging'): 'True',
('logging', 'remote_base_log_folder'): f"stackdriver://{self.log_name}",
}):
from airflow.config_templates import airflow_local_settings
importlib.reload(airflow_local_settings)
settings.configure_logging()

task_log_reader = TaskLogReader()
logs = "\n".join(task_log_reader.read_log_stream(ti, try_number=None, metadata={}))
self.assertIn(expected_message, logs)
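As a side note, the entries written by the handler could also be inspected with the Stackdriver client directly rather than through TaskLogReader; a hedged sketch, assuming the handler writes under the configured log name (the name and filter below are illustrative):

from google.cloud import logging as gcp_logging

client = gcp_logging.Client()
log_name = 'stackdriver-tests-example'  # the real test generates a random name
entries = client.list_entries(
    filter_=f'logName="projects/{client.project}/logs/{log_name}"'
)
for entry in entries:
    print(entry.payload)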
3 changes: 2 additions & 1 deletion tests/providers/google/cloud/utils/gcp_authenticator.py
@@ -46,8 +46,9 @@
GCP_LIFE_SCIENCES_KEY = 'gcp_life_sciences.json'
GCP_MEMORYSTORE = 'gcp_memorystore.json'
GCP_PUBSUB_KEY = "gcp_pubsub.json"
GCP_SPANNER_KEY = 'gcp_spanner.json'
GCP_SECRET_MANAGER_KEY = 'gcp_secret_manager.json'
GCP_SPANNER_KEY = 'gcp_spanner.json'
GCP_STACKDDRIVER = 'gcp_stackdriver.json'
GCP_TASKS_KEY = 'gcp_tasks.json'
GMP_KEY = 'gmp.json'
G_FIREBASE_KEY = 'g_firebase.json'
