Skip to content

Commit

Permalink
feat: full support for google credentials in gcloud-aio clients (#36849)
Browse files Browse the repository at this point in the history
* feat: full support for google credentials in gcloud-aio clients

The class CredentialsToken implements the ability to generate access
tokens to be used in gcloud-aio clients from Google credentials objects
provided by instances of Google Cloud hooks.

With this change we provide all credentials-based capabilities of Google
Cloud hooks (for example, impersonation) to gcloud-aio clients.

* test: add tests for CredentialsToken

* Update tests/providers/google/common/hooks/test_base_google.py

Co-authored-by: Wei Lee <[email protected]>

* refactor: make CredentialsToken private

This class is only intended to be used within the Google provider and
might need to change in the future. Making it private in order to avoid
a potential breaking change in the future.

* Revert removal of service_file_as_context

The method `service_file_as_context` is no longer used within Airflow,
but it is public, and removing it would be a breaking change for users
of the Google provider. Therefore we keep it.

---------

Co-authored-by: Wei Lee <[email protected]>
  • Loading branch information
m1racoli and Lee-W committed Jan 23, 2024
1 parent b7f84c4 commit fbd21ed
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 14 deletions.
25 changes: 15 additions & 10 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3287,8 +3287,13 @@ async def get_job_instance(
self, project_id: str | None, job_id: str | None, session: ClientSession
) -> Job:
"""Get the specified job resource by job ID and project ID."""
with await self.service_file_as_context() as f:
return Job(job_id=job_id, project=project_id, service_file=f, session=cast(Session, session))
token = await self.get_token(session=session)
return Job(
job_id=job_id,
project=project_id,
token=token,
session=cast(Session, session),
)

async def get_job_status(self, job_id: str | None, project_id: str | None = None) -> dict[str, str]:
async with ClientSession() as s:
Expand Down Expand Up @@ -3532,11 +3537,11 @@ async def get_table_client(
access to the specified project.
:param session: aiohttp ClientSession
"""
with await self.service_file_as_context() as file:
return Table_async(
dataset_name=dataset,
table_name=table_id,
project=project_id,
service_file=file,
session=cast(Session, session),
)
token = await self.get_token(session=session)
return Table_async(
dataset_name=dataset,
table_name=table_id,
project=project_id,
token=token,
session=cast(Session, session),
)
7 changes: 5 additions & 2 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,5 +1398,8 @@ class GCSAsyncHook(GoogleBaseAsyncHook):

async def get_storage_client(self, session: ClientSession) -> Storage:
"""Returns a Google Cloud Storage service object."""
with await self.service_file_as_context() as file:
return Storage(service_file=file, session=cast(Session, session))
token = await self.get_token(session=session)
return Storage(
token=token,
session=cast(Session, session),
)
56 changes: 56 additions & 0 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""This module contains a Google Cloud API base hook."""
from __future__ import annotations

import datetime
import functools
import json
import logging
Expand All @@ -35,6 +36,7 @@
import requests
import tenacity
from asgiref.sync import sync_to_async
from gcloud.aio.auth.token import Token
from google.api_core.exceptions import Forbidden, ResourceExhausted, TooManyRequests
from google.auth import _cloud_sdk, compute_engine # type: ignore[attr-defined]
from google.auth.environment_vars import CLOUD_SDK_CONFIG_DIR, CREDENTIALS
Expand All @@ -43,6 +45,7 @@
from googleapiclient import discovery
from googleapiclient.errors import HttpError
from googleapiclient.http import MediaIoBaseDownload, build_http, set_user_agent
from requests import Session

from airflow import version
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
Expand All @@ -56,7 +59,9 @@
from airflow.utils.process_utils import patch_environ

if TYPE_CHECKING:
from aiohttp import ClientSession
from google.api_core.gapic_v1.client_info import ClientInfo
from google.auth.credentials import Credentials

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -623,6 +628,51 @@ def test_connection(self):
return status, message


class _CredentialsToken(Token):
    """A token implementation which makes Google credentials objects accessible to [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients.

    This class allows us to create token instances from credentials objects and thus supports a variety of
    use cases for Google credentials in Airflow (e.g. impersonation chain). By relying on an existing
    credentials object we leverage functionality provided by the GoogleBaseHook for generating credentials
    objects.

    :param credentials: Google credentials object used to obtain access tokens.
    :param project: Project ID reported by :meth:`get_project`; may be ``None``.
    :param session: Optional client session forwarded to the base ``Token`` class.
    """

    def __init__(
        self,
        credentials: Credentials,
        *,
        project: str | None = None,
        session: ClientSession | None = None,
    ) -> None:
        super().__init__(session=cast(Session, session))
        self.credentials = credentials
        self.project = project

    @classmethod
    async def from_hook(
        cls,
        hook: GoogleBaseHook,
        *,
        session: ClientSession | None = None,
    ) -> _CredentialsToken:
        """Create a token from the credentials and project ID resolved by ``hook``.

        :param hook: Synchronous Google hook providing ``get_credentials_and_project_id``.
        :param session: Optional client session passed on to the constructor.
        :return: A new token wrapping the hook's credentials and project ID.
        """
        # ``get_credentials_and_project_id`` is a synchronous hook method. Run it through
        # ``sync_to_async`` (as ``acquire_access_token`` does for ``refresh``) so that any
        # blocking work it performs does not stall the event loop.
        credentials, project = await sync_to_async(hook.get_credentials_and_project_id)()
        return cls(
            credentials=credentials,
            project=project,
            session=session,
        )

    async def get_project(self) -> str | None:
        """Return the project ID this token was created with."""
        return self.project

    async def acquire_access_token(self, timeout: int = 10) -> None:
        """Refresh the wrapped credentials and store the resulting access token on this instance.

        :param timeout: Not used by this implementation.
        """
        # ``Credentials.refresh`` is synchronous, so it is executed via ``sync_to_async``.
        await sync_to_async(self.credentials.refresh)(google.auth.transport.requests.Request())

        self.access_token = cast(str, self.credentials.token)
        # The lifetime is fixed at one hour; the credentials' own expiry is not consulted.
        self.access_token_duration = 3600
        self.access_token_acquired_at = datetime.datetime.utcnow()
        # NOTE(review): presumably clears the base class's in-flight acquisition marker -- confirm
        # against the gcloud-aio ``Token`` implementation.
        self.acquiring = None


class GoogleBaseAsyncHook(BaseHook):
"""GoogleBaseAsyncHook inherits from BaseHook class, run on the trigger worker."""

Expand All @@ -639,6 +689,12 @@ async def get_sync_hook(self) -> Any:
self._sync_hook = await sync_to_async(self.sync_hook_class)(**self._hook_kwargs)
return self._sync_hook

async def get_token(self, *, session: ClientSession | None = None) -> _CredentialsToken:
    """Build a Token for [gcloud-aio](https://talkiq.github.io/gcloud-aio/) clients from the sync hook.

    :param session: Optional client session handed to the created token.
    :return: A ``_CredentialsToken`` created from the synchronous hook.
    """
    return await _CredentialsToken.from_hook(await self.get_sync_hook(), session=session)

async def service_file_as_context(self) -> Any:
    """Async counterpart of the non-async GoogleBaseHook's `provide_gcp_credential_file_as_context` method.

    :return: Whatever the sync hook's ``provide_gcp_credential_file_as_context`` returns.
    """
    hook = await self.get_sync_hook()
    # Wrap the synchronous hook method so it can be awaited from the coroutine.
    provide_context = sync_to_async(hook.provide_gcp_credential_file_as_context)
    return await provide_context()
12 changes: 10 additions & 2 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from unittest import mock
from unittest.mock import AsyncMock

import google.auth
import pytest
from gcloud.aio.bigquery import Job, Table as Table_async
from google.api_core import page_iterator
Expand Down Expand Up @@ -2143,8 +2144,12 @@ def get_credentials_and_project_id(self):
class TestBigQueryAsyncHookMethods(_BigQueryBaseAsyncTestClass):
@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch("google.auth.default")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.ClientSession")
async def test_get_job_instance(self, mock_session):
async def test_get_job_instance(self, mock_session, mock_auth_default):
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
mock_credentials.token = "ACCESS_TOKEN"
mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
hook = BigQueryAsyncHook()
result = await hook.get_job_instance(project_id=PROJECT_ID, job_id=JOB_ID, session=mock_session)
assert isinstance(result, Job)
Expand Down Expand Up @@ -2315,10 +2320,13 @@ def test_convert_to_float_if_possible(self, test_input, expected):

@pytest.mark.db_test
@pytest.mark.asyncio
@mock.patch("google.auth.default")
@mock.patch("aiohttp.client.ClientSession")
async def test_get_table_client(self, mock_session):
async def test_get_table_client(self, mock_session, mock_auth_default):
"""Test get_table_client async function and check whether the return value is a
Table instance object"""
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
mock_auth_default.return_value = (mock_credentials, PROJECT_ID)
hook = BigQueryTableAsyncHook()
result = await hook.get_table_client(
dataset=DATASET_ID, project_id=PROJECT_ID, table_id=TABLE_ID, session=mock_session
Expand Down
93 changes: 93 additions & 0 deletions tests/providers/google/common/hooks/test_base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from unittest.mock import patch

import google.auth
import google.auth.compute_engine
import pytest
import tenacity
from google.auth.environment_vars import CREDENTIALS
Expand Down Expand Up @@ -874,3 +875,95 @@ def test_should_fallback_when_empty_string_in_env_var(self):
instance = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default")
assert isinstance(instance.num_retries, int)
assert 5 == instance.num_retries


class TestCredentialsToken:
    """Tests for the private ``_CredentialsToken`` class."""

    @pytest.mark.asyncio
    async def test_get_project(self):
        """The project passed to the constructor is returned by ``get_project``."""
        creds = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
        token = hook._CredentialsToken(creds, project=PROJECT_ID)

        project = await token.get_project()

        assert project == PROJECT_ID

    @pytest.mark.asyncio
    async def test_get(self):
        """``get`` refreshes the credentials once and returns their token."""
        creds = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
        creds.token = "ACCESS_TOKEN"
        token = hook._CredentialsToken(creds, project=PROJECT_ID)

        access_token = await token.get()

        assert access_token == "ACCESS_TOKEN"
        creds.refresh.assert_called_once()

    @pytest.mark.asyncio
    @mock.patch(f"{MODULE_NAME}.get_credentials_and_project_id", return_value=("CREDENTIALS", "PROJECT_ID"))
    async def test_from_hook(self, get_creds_and_project, monkeypatch):
        """``from_hook`` stores the credentials and project resolved by the hook."""
        monkeypatch.setenv("AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT", "google-cloud-platform://")
        base_hook = hook.GoogleBaseHook(gcp_conn_id="google_cloud_default")

        token = await hook._CredentialsToken.from_hook(base_hook)

        assert token.credentials == "CREDENTIALS"
        assert token.project == "PROJECT_ID"


class TestGoogleBaseAsyncHook:
    """Tests for ``GoogleBaseAsyncHook.get_token``, with and without impersonation."""

    # Values shared by the test cases below.
    _CONN_ENV_VAR = "AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT"
    _SERVICE_ACCOUNT = "SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com"
    _GENERATE_TOKEN_URL = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com:generateAccessToken"
    _GENERATE_TOKEN_RESPONSE = '{"accessToken": "IMPERSONATED_ACCESS_TOKEN", "expireTime": "2014-10-02T15:01:23Z"}'

    @pytest.mark.asyncio
    @mock.patch("google.auth.default")
    async def test_get_token(self, mock_auth_default, monkeypatch) -> None:
        """Without impersonation, the token comes from the default credentials."""
        creds = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
        creds.token = "ACCESS_TOKEN"
        mock_auth_default.return_value = (creds, "PROJECT_ID")
        monkeypatch.setenv(self._CONN_ENV_VAR, "google-cloud-platform://?project=CONN_PROJECT_ID")

        async_hook = hook.GoogleBaseAsyncHook(gcp_conn_id="google_cloud_default")
        async_hook.sync_hook_class = hook.GoogleBaseHook
        token = await async_hook.get_token()

        assert await token.get_project() == "CONN_PROJECT_ID"
        assert await token.get() == "ACCESS_TOKEN"
        creds.refresh.assert_called_once()

    @pytest.mark.asyncio
    @mock.patch("google.auth.default")
    async def test_get_token_impersonation(self, mock_auth_default, monkeypatch, requests_mock) -> None:
        """An impersonation chain passed to the hook yields the impersonated access token."""
        creds = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
        creds.token = "ACCESS_TOKEN"
        mock_auth_default.return_value = (creds, "PROJECT_ID")
        monkeypatch.setenv(self._CONN_ENV_VAR, "google-cloud-platform://?project=CONN_PROJECT_ID")
        requests_mock.post(self._GENERATE_TOKEN_URL, text=self._GENERATE_TOKEN_RESPONSE)

        async_hook = hook.GoogleBaseAsyncHook(
            gcp_conn_id="google_cloud_default",
            impersonation_chain=self._SERVICE_ACCOUNT,
        )
        async_hook.sync_hook_class = hook.GoogleBaseHook
        token = await async_hook.get_token()

        assert await token.get_project() == "CONN_PROJECT_ID"
        assert await token.get() == "IMPERSONATED_ACCESS_TOKEN"

    @pytest.mark.asyncio
    @mock.patch("google.auth.default")
    async def test_get_token_impersonation_conn(self, mock_auth_default, monkeypatch, requests_mock) -> None:
        """An impersonation chain set on the connection yields the impersonated access token."""
        creds = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
        mock_auth_default.return_value = (creds, "PROJECT_ID")
        monkeypatch.setenv(
            self._CONN_ENV_VAR,
            "google-cloud-platform://?project=CONN_PROJECT_ID&impersonation_chain=SERVICE_ACCOUNT@SA_PROJECT.iam.gserviceaccount.com",
        )
        requests_mock.post(self._GENERATE_TOKEN_URL, text=self._GENERATE_TOKEN_RESPONSE)

        async_hook = hook.GoogleBaseAsyncHook(gcp_conn_id="google_cloud_default")
        async_hook.sync_hook_class = hook.GoogleBaseHook
        token = await async_hook.get_token()

        assert await token.get_project() == "CONN_PROJECT_ID"
        assert await token.get() == "IMPERSONATED_ACCESS_TOKEN"

0 comments on commit fbd21ed

Please sign in to comment.