Skip to content

Commit

Permalink
Add scopes into a GCP token (#36974)
Browse files Browse the repository at this point in the history
* Add scopes into a GCP token

* Update airflow/providers/google/common/hooks/base_google.py

Co-authored-by: Gopal Dirisala <[email protected]>

---------

Co-authored-by: Hussein Awala <[email protected]>
Co-authored-by: Gopal Dirisala <[email protected]>
  • Loading branch information
3 people committed Jan 24, 2024
1 parent c401d58 commit 241b50a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 3 deletions.
5 changes: 4 additions & 1 deletion airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,8 +642,10 @@ def __init__(
*,
project: str | None = None,
session: ClientSession | None = None,
scopes: Sequence[str] | None = None,
) -> None:
super().__init__(session=cast(Session, session))
_scopes: list[str] | None = list(scopes) if scopes else None
super().__init__(session=cast(Session, session), scopes=_scopes)
self.credentials = credentials
self.project = project

Expand All @@ -659,6 +661,7 @@ async def from_hook(
credentials=credentials,
project=project,
session=session,
scopes=hook.scopes,
)

async def get_project(self) -> str | None:
Expand Down
5 changes: 3 additions & 2 deletions tests/providers/google/common/hooks/test_base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
MODULE_NAME = "airflow.providers.google.common.hooks.base_google"
PROJECT_ID = "PROJECT_ID"
ENV_VALUE = "/tmp/a"
SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]


class NoForbiddenAfterCount:
Expand Down Expand Up @@ -881,14 +882,14 @@ class TestCredentialsToken:
@pytest.mark.asyncio
async def test_get_project(self):
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID)
token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID, scopes=SCOPES)
assert await token.get_project() == PROJECT_ID

@pytest.mark.asyncio
async def test_get(self):
mock_credentials = mock.MagicMock(spec=google.auth.compute_engine.Credentials)
mock_credentials.token = "ACCESS_TOKEN"
token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID)
token = hook._CredentialsToken(mock_credentials, project=PROJECT_ID, scopes=SCOPES)
assert await token.get() == "ACCESS_TOKEN"
mock_credentials.refresh.assert_called_once()

Expand Down

0 comments on commit 241b50a

Please sign in to comment.