Skip to content

Commit

Permalink
Implement impersonation in google operators (#10052)
Browse files Browse the repository at this point in the history
Co-authored-by: Kamil Olszewski <[email protected]>
  • Loading branch information
olchas and Kamil Olszewski committed Aug 24, 2020
1 parent b0598b5 commit 3734876
Show file tree
Hide file tree
Showing 118 changed files with 6,845 additions and 1,258 deletions.
6 changes: 6 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,12 @@ of this provider.
This section describes the changes that have been made, and what you need to do to update your code if
you use operators or hooks which integrate with Google services (including Google Cloud Platform - GCP).

#### Direct impersonation added to operators communicating with Google services
[Directly impersonating a service account](https://cloud.google.com/iam/docs/understanding-service-accounts#directly_impersonating_a_service_account)
has been made possible for operators communicating with Google services via a new argument called `impersonation_chain`
(`google_impersonation_chain` in case of operators that also communicate with services of other cloud providers).
As a result, GCSToS3Operator no longer derives from GCSListObjectsOperator.

#### Normalize gcp_conn_id for Google Cloud Platform

Previously not all hooks and operators related to Google Cloud Platform use
Expand Down
59 changes: 38 additions & 21 deletions airflow/providers/amazon/aws/transfers/gcs_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
This module contains Google Cloud Storage to S3 operator.
"""
import warnings
from typing import Iterable
from typing import Iterable, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.providers.google.cloud.operators.gcs import GCSListObjectsOperator
from airflow.utils.decorators import apply_defaults


class GCSToS3Operator(GCSListObjectsOperator):
class GCSToS3Operator(BaseOperator):
"""
Synchronizes a Google Cloud Storage bucket with an S3 bucket.
Expand All @@ -45,8 +45,8 @@ class GCSToS3Operator(GCSListObjectsOperator):
:param google_cloud_storage_conn_id: (Deprecated) The connection ID used to connect to Google Cloud
Platform. This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:type google_cloud_storage_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
:param delegate_to: Google account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param dest_aws_conn_id: The destination S3 connection
Expand All @@ -73,8 +73,18 @@ class GCSToS3Operator(GCSListObjectsOperator):
If set to False, will upload only the files that are in the origin but not
in the destination bucket.
:type replace: bool
:param google_impersonation_chain: Optional Google service account to impersonate using
short-term credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
"""
template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'dest_s3_key')
template_fields: Iterable[str] = ('bucket', 'prefix', 'delimiter', 'dest_s3_key',
'google_impersonation_chain',)
ui_color = '#f0eee4'

@apply_defaults
Expand All @@ -89,31 +99,42 @@ def __init__(self, *, # pylint: disable=too-many-arguments
dest_s3_key=None,
dest_verify=None,
replace=False,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs):
super().__init__(**kwargs)

if google_cloud_storage_conn_id:
warnings.warn(
"The google_cloud_storage_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
gcp_conn_id = google_cloud_storage_conn_id

super().__init__(
bucket=bucket,
prefix=prefix,
delimiter=delimiter,
gcp_conn_id=gcp_conn_id,
delegate_to=delegate_to,
**kwargs
)

self.bucket = bucket
self.prefix = prefix
self.delimiter = delimiter
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.dest_aws_conn_id = dest_aws_conn_id
self.dest_s3_key = dest_s3_key
self.dest_verify = dest_verify
self.replace = replace
self.google_impersonation_chain = google_impersonation_chain

def execute(self, context):
# use the super to list all files in an Google Cloud Storage bucket
files = super().execute(context)
# list all files in a Google Cloud Storage bucket
hook = GCSHook(
google_cloud_storage_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
impersonation_chain=self.google_impersonation_chain,
)

self.log.info('Getting list of the files. Bucket: %s; Delimiter: %s; Prefix: %s',
self.bucket, self.delimiter, self.prefix)

files = hook.list(bucket_name=self.bucket,
prefix=self.prefix,
delimiter=self.delimiter)

s3_hook = S3Hook(aws_conn_id=self.dest_aws_conn_id, verify=self.dest_verify)

if not self.replace:
Expand All @@ -131,10 +152,6 @@ def execute(self, context):
files = list(set(files) - set(existing_files))

if files:
hook = GCSHook(
google_cloud_storage_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to
)

for file in files:
file_bytes = hook.download(self.bucket, file)
Expand Down
20 changes: 17 additions & 3 deletions airflow/providers/amazon/aws/transfers/google_api_to_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""
import json
import sys
from typing import Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.models.xcom import MAX_XCOM_SIZE
Expand Down Expand Up @@ -68,17 +69,27 @@ class GoogleApiToS3Operator(BaseOperator):
:type s3_overwrite: bool
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
:param delegate_to: Google account to impersonate using domain-wide delegation of authority,
if any. For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
:param aws_conn_id: The connection id specifying the authentication information for the S3 Bucket.
:type aws_conn_id: str
:param google_impersonation_chain: Optional Google service account to impersonate using
short-term credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
"""

template_fields = (
'google_api_endpoint_params',
's3_destination_key',
'google_impersonation_chain',
)
template_ext = ()
ui_color = '#cc181e'
Expand All @@ -100,6 +111,7 @@ def __init__(
gcp_conn_id='google_cloud_default',
delegate_to=None,
aws_conn_id='aws_default',
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -117,6 +129,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.aws_conn_id = aws_conn_id
self.google_impersonation_chain = google_impersonation_chain

def execute(self, context):
"""
Expand All @@ -142,7 +155,8 @@ def _retrieve_data_from_google_api(self):
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
api_service_name=self.google_api_service_name,
api_version=self.google_api_service_version
api_version=self.google_api_service_version,
impersonation_chain=self.google_impersonation_chain,
)
google_api_response = google_discovery_api_hook.query(
endpoint=self.google_api_endpoint_path,
Expand Down
21 changes: 17 additions & 4 deletions airflow/providers/google/ads/operators/ads.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
"""
import csv
from tempfile import NamedTemporaryFile
from typing import Dict
from typing import Dict, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
Expand Down Expand Up @@ -57,9 +57,18 @@ class GoogleAdsListAccountsOperator(BaseOperator):
:type page_size: int
:param gzip: Option to compress local file or file data for upload
:type gzip: bool
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
"""

template_fields = ("bucket", "object_name")
template_fields = ("bucket", "object_name", "impersonation_chain",)

@apply_defaults
def __init__(
Expand All @@ -69,6 +78,7 @@ def __init__(
gcp_conn_id: str = "google_cloud_default",
google_ads_conn_id: str = "google_ads_default",
gzip: bool = False,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -77,6 +87,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.google_ads_conn_id = google_ads_conn_id
self.gzip = gzip
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
uri = f"gs://{self.bucket}/{self.object_name}"
Expand All @@ -86,8 +97,10 @@ def execute(self, context: Dict):
google_ads_conn_id=self.google_ads_conn_id
)

gcs_hook = GCSHook(gcp_conn_id=self.gcp_conn_id)

gcs_hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
with NamedTemporaryFile("w+") as temp_file:
# Download accounts
accounts = ads_hook.list_accessible_customers()
Expand Down
20 changes: 17 additions & 3 deletions airflow/providers/google/ads/transfers/ads_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import csv
from operator import attrgetter
from tempfile import NamedTemporaryFile
from typing import Dict, List
from typing import Dict, List, Optional, Sequence, Union

from airflow.models import BaseOperator
from airflow.providers.google.ads.hooks.ads import GoogleAdsHook
Expand Down Expand Up @@ -58,9 +58,18 @@ class GoogleAdsToGcsOperator(BaseOperator):
:type page_size: int
:param gzip: Option to compress local file or file data for upload
:type gzip: bool
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
:type impersonation_chain: Union[str, Sequence[str]]
"""

template_fields = ("client_ids", "query", "attributes", "bucket", "obj")
template_fields = ("client_ids", "query", "attributes", "bucket", "obj", "impersonation_chain",)

@apply_defaults
def __init__(
Expand All @@ -74,6 +83,7 @@ def __init__(
google_ads_conn_id: str = "google_ads_default",
page_size: int = 10000,
gzip: bool = False,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -86,6 +96,7 @@ def __init__(
self.google_ads_conn_id = google_ads_conn_id
self.page_size = page_size
self.gzip = gzip
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
service = GoogleAdsHook(
Expand All @@ -108,7 +119,10 @@ def execute(self, context: Dict):
writer.writerows(converted_rows)
csvfile.flush()

hook = GCSHook(gcp_conn_id=self.gcp_conn_id)
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain
)
hook.upload(
bucket_name=self.bucket,
object_name=self.obj,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ class BigQueryHook(GoogleBaseHook, DbApiHook):
def __init__(self,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
use_legacy_sql: bool = True,
location: Optional[str] = None,
bigquery_conn_id: Optional[str] = None,
api_resource_configs: Optional[Dict] = None) -> None:
api_resource_configs: Optional[Dict] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,) -> None:
# To preserve backward compatibility
# TODO: remove one day
if bigquery_conn_id:
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,8 @@ def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
poll_sleep: int = 10,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
poll_sleep: int = 10
) -> None:
self.poll_sleep = poll_sleep
super().__init__(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
api_version: str = 'v1',
datastore_conn_id: Optional[str] = None
datastore_conn_id: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
) -> None:
if datastore_conn_id:
warnings.warn(
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
google_cloud_storage_conn_id: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
google_cloud_storage_conn_id: Optional[str] = None
) -> None:
# To preserve backward compatibility
# TODO: remove one day
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/kubernetes_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
delegate_to: Optional[str] = None,
location: Optional[str] = None,
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
location: Optional[str] = None
) -> None:
super().__init__(
gcp_conn_id=gcp_conn_id,
Expand Down

0 comments on commit 3734876

Please sign in to comment.