Skip to content

Commit

Permalink
Refresh GKE OAuth2 tokens (#32673)
Browse files Browse the repository at this point in the history
* Refresh token for sync mode
  • Loading branch information
fdemiane committed Jul 20, 2023
1 parent 27b5f69 commit 848c69a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
4 changes: 4 additions & 0 deletions airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,12 @@ def get_xcom_sidecar_container_resources(self):

def get_conn(self) -> client.ApiClient:
configuration = self._get_config()
configuration.refresh_api_key_hook = self._refresh_api_key_hook
return client.ApiClient(configuration)

def _refresh_api_key_hook(self, configuration: client.configuration.Configuration):
configuration.api_key = {"authorization": self._get_token(self.get_credentials())}

def _get_config(self) -> client.configuration.Configuration:
configuration = client.Configuration(
host=self._cluster_url,
Expand Down
46 changes: 46 additions & 0 deletions tests/providers/google/cloud/hooks/test_kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
GKEAsyncHook,
GKEHook,
GKEPodAsyncHook,
GKEPodHook,
)
from airflow.providers.google.common.consts import CLIENT_INFO
from tests.providers.google.cloud.utils.base_gcp_mock import mock_base_gcp_hook_default_project_id
Expand Down Expand Up @@ -397,3 +398,48 @@ async def test_get_operation(self, mock_get_client, async_gke_hook, mock_async_g
mock_async_gke_cluster_client.get_operation.assert_called_once_with(
name=operation_path,
)


class TestGKEPodHook:
def setup_method(self):
with mock.patch(
BASE_STRING.format("GoogleBaseHook.__init__"), new=mock_base_gcp_hook_default_project_id
):
self.gke_hook = GKEPodHook(gcp_conn_id="test", ssl_ca_cert=None, cluster_url=None)
self.gke_hook._client = mock.Mock()

def refresh_token(request):
self.credentials.token = "New"

self.credentials = mock.MagicMock()
self.credentials.token = "Old"
self.credentials.expired = False
self.credentials.refresh = refresh_token

@mock.patch(GKE_STRING.format("google_requests.Request"))
def test_get_connection_update_hook_with_invalid_token(self, mock_request):
self.gke_hook._get_config = self._get_config
self.gke_hook.get_credentials = self._get_credentials
self.gke_hook.get_credentials().expired = True
the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()

the_client.configuration.refresh_api_key_hook(the_client.configuration)

assert self.gke_hook.get_credentials().token == "New"

@mock.patch(GKE_STRING.format("google_requests.Request"))
def test_get_connection_update_hook_with_valid_token(self, mock_request):
self.gke_hook._get_config = self._get_config
self.gke_hook.get_credentials = self._get_credentials
self.gke_hook.get_credentials().expired = False
the_client: kubernetes.client.ApiClient = self.gke_hook.get_conn()

the_client.configuration.refresh_api_key_hook(the_client.configuration)

assert self.gke_hook.get_credentials().token == "Old"

def _get_config(self):
return kubernetes.client.configuration.Configuration()

def _get_credentials(self):
return self.credentials

0 comments on commit 848c69a

Please sign in to comment.