Skip to content

Commit 66d0222

Browse files
authored
Add CloudRunHook and operators (#33067)
1 parent 4bd0b56 commit 66d0222

File tree

11 files changed

+1829
-0
lines changed

11 files changed

+1829
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
import itertools
21+
from typing import Iterable, Sequence
22+
23+
from google.api_core import operation
24+
from google.cloud.run_v2 import (
25+
CreateJobRequest,
26+
DeleteJobRequest,
27+
GetJobRequest,
28+
Job,
29+
JobsAsyncClient,
30+
JobsClient,
31+
ListJobsRequest,
32+
RunJobRequest,
33+
UpdateJobRequest,
34+
)
35+
from google.cloud.run_v2.services.jobs import pagers
36+
from google.longrunning import operations_pb2
37+
38+
from airflow.exceptions import AirflowException
39+
from airflow.providers.google.common.consts import CLIENT_INFO
40+
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook
41+
42+
43+
class CloudRunHook(GoogleBaseHook):
44+
"""
45+
Hook for the Google Cloud Run service.
46+
47+
:param gcp_conn_id: The connection ID to use when fetching connection info.
48+
:param impersonation_chain: Optional service account to impersonate using short-term
49+
credentials, or chained list of accounts required to get the access_token
50+
of the last account in the list, which will be impersonated in the request.
51+
If set as a string, the account must grant the originating account
52+
the Service Account Token Creator IAM role.
53+
If set as a sequence, the identities from the list must grant
54+
Service Account Token Creator IAM role to the directly preceding identity, with first
55+
account from the list granting this role to the originating account.
56+
"""
57+
58+
def __init__(
59+
self,
60+
gcp_conn_id: str = "google_cloud_default",
61+
impersonation_chain: str | Sequence[str] | None = None,
62+
) -> None:
63+
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
64+
self._client: JobsClient | None = None
65+
66+
def get_conn(self):
67+
"""
68+
Retrieves connection to Cloud Run.
69+
70+
:return: Cloud Run Jobs client object.
71+
"""
72+
if self._client is None:
73+
self._client = JobsClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
74+
return self._client
75+
76+
@GoogleBaseHook.fallback_to_default_project_id
77+
def delete_job(self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID) -> Job:
78+
delete_request = DeleteJobRequest()
79+
delete_request.name = f"projects/{project_id}/locations/{region}/jobs/{job_name}"
80+
81+
operation = self.get_conn().delete_job(delete_request)
82+
return operation.result()
83+
84+
@GoogleBaseHook.fallback_to_default_project_id
85+
def create_job(
86+
self, job_name: str, job: Job | dict, region: str, project_id: str = PROVIDE_PROJECT_ID
87+
) -> Job:
88+
if isinstance(job, dict):
89+
job = Job(job)
90+
91+
create_request = CreateJobRequest()
92+
create_request.job = job
93+
create_request.job_id = job_name
94+
create_request.parent = f"projects/{project_id}/locations/{region}"
95+
96+
operation = self.get_conn().create_job(create_request)
97+
return operation.result()
98+
99+
@GoogleBaseHook.fallback_to_default_project_id
100+
def update_job(
101+
self, job_name: str, job: Job | dict, region: str, project_id: str = PROVIDE_PROJECT_ID
102+
) -> Job:
103+
if isinstance(job, dict):
104+
job = Job(job)
105+
106+
update_request = UpdateJobRequest()
107+
job.name = f"projects/{project_id}/locations/{region}/jobs/{job_name}"
108+
update_request.job = job
109+
operation = self.get_conn().update_job(update_request)
110+
return operation.result()
111+
112+
@GoogleBaseHook.fallback_to_default_project_id
113+
def execute_job(
114+
self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID
115+
) -> operation.Operation:
116+
run_job_request = RunJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
117+
operation = self.get_conn().run_job(request=run_job_request)
118+
return operation
119+
120+
@GoogleBaseHook.fallback_to_default_project_id
121+
def get_job(self, job_name: str, region: str, project_id: str = PROVIDE_PROJECT_ID):
122+
get_job_request = GetJobRequest(name=f"projects/{project_id}/locations/{region}/jobs/{job_name}")
123+
return self.get_conn().get_job(get_job_request)
124+
125+
@GoogleBaseHook.fallback_to_default_project_id
126+
def list_jobs(
127+
self,
128+
region: str,
129+
project_id: str = PROVIDE_PROJECT_ID,
130+
show_deleted: bool = False,
131+
limit: int | None = None,
132+
) -> Iterable[Job]:
133+
134+
if limit is not None and limit < 0:
135+
raise AirflowException("The limit for the list jobs request should be greater or equal to zero")
136+
137+
list_jobs_request: ListJobsRequest = ListJobsRequest(
138+
parent=f"projects/{project_id}/locations/{region}", show_deleted=show_deleted
139+
)
140+
141+
jobs: pagers.ListJobsPager = self.get_conn().list_jobs(request=list_jobs_request)
142+
143+
return list(itertools.islice(jobs, limit))
144+
145+
146+
class CloudRunAsyncHook(GoogleBaseHook):
147+
"""
148+
Async hook for the Google Cloud Run service.
149+
150+
:param gcp_conn_id: The connection ID to use when fetching connection info.
151+
:param impersonation_chain: Optional service account to impersonate using short-term
152+
credentials, or chained list of accounts required to get the access_token
153+
of the last account in the list, which will be impersonated in the request.
154+
If set as a string, the account must grant the originating account
155+
the Service Account Token Creator IAM role.
156+
If set as a sequence, the identities from the list must grant
157+
Service Account Token Creator IAM role to the directly preceding identity, with first
158+
account from the list granting this role to the originating account.
159+
"""
160+
161+
def __init__(
162+
self,
163+
gcp_conn_id: str = "google_cloud_default",
164+
impersonation_chain: str | Sequence[str] | None = None,
165+
):
166+
self._client: JobsAsyncClient = JobsAsyncClient()
167+
super().__init__(gcp_conn_id=gcp_conn_id, impersonation_chain=impersonation_chain)
168+
169+
def get_conn(self):
170+
if self._client is None:
171+
self._client = JobsAsyncClient(credentials=self.get_credentials(), client_info=CLIENT_INFO)
172+
173+
return self._client
174+
175+
async def get_operation(self, operation_name: str) -> operations_pb2.Operation:
176+
return await self.get_conn().get_operation(operations_pb2.GetOperationRequest(name=operation_name))

0 commit comments

Comments
 (0)