Skip to content

Commit

Permalink
Fix Google operators' handling of impersonation chain (#36903)
Browse files Browse the repository at this point in the history
* fix(MetastoreHivePartitionSensor): pass impersonation chain to GCSHook

The sensor already accepts `impersonation_chain`, but does not pass
it to the GCSHook.

* fix(BigQueryGetDataOperator): pass impersonation chain to BigQueryGetDataTrigger

The operator already accepts `impersonation_chain`, but does not pass
it to the BigQueryGetDataTrigger.

* fix(BigQueryInsertJobOperator): pass impersonation chain to BigQueryInsertJobTrigger

The operator already accepts `impersonation_chain`, but does not pass
it to the BigQueryInsertJobTrigger.

* fix(BigQueryToGCSOperator): pass impersonation chain to BigQueryInsertJobTrigger

The operator already accepts `impersonation_chain`, but does not pass
it to the BigQueryInsertJobTrigger.

* fix(GCSToBigQueryOperator): pass impersonation chain to BigQueryInsertJobTrigger

The operator already accepts `impersonation_chain`, but does not pass
it to the BigQueryInsertJobTrigger.
  • Loading branch information
m1racoli committed Jan 23, 2024
1 parent 3473739 commit 1c14767
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 4 deletions.
11 changes: 9 additions & 2 deletions airflow/providers/google/cloud/hooks/gcs.py
Expand Up @@ -1336,15 +1336,22 @@ def gcs_object_is_directory(bucket: str) -> bool:
return len(blob) == 0 or blob.endswith("/")


def parse_json_from_gcs(gcp_conn_id: str, file_uri: str) -> Any:
def parse_json_from_gcs(
gcp_conn_id: str,
file_uri: str,
impersonation_chain: str | Sequence[str] | None = None,
) -> Any:
"""
Downloads and parses json file from Google cloud Storage.
:param gcp_conn_id: Airflow Google Cloud connection ID.
:param file_uri: full path to json file
example: ``gs://test-bucket/dir1/dir2/file``
"""
gcs_hook = GCSHook(gcp_conn_id=gcp_conn_id)
gcs_hook = GCSHook(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
bucket, blob = _parse_gcs_url(file_uri)
with NamedTemporaryFile(mode="w+b") as file:
try:
Expand Down
2 changes: 2 additions & 0 deletions airflow/providers/google/cloud/operators/bigquery.py
Expand Up @@ -1069,6 +1069,7 @@ def execute(self, context: Context):
project_id=self.job_project_id or hook.project_id,
poll_interval=self.poll_interval,
as_dict=self.as_dict,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down Expand Up @@ -2878,6 +2879,7 @@ def execute(self, context: Any):
job_id=self.job_id,
project_id=self.project_id,
poll_interval=self.poll_interval,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down
14 changes: 12 additions & 2 deletions airflow/providers/google/cloud/sensors/dataproc_metastore.py
Expand Up @@ -93,7 +93,11 @@ def poke(self, context: Context) -> bool:
self.log.info("Received result manifest URI: %s", result_manifest_uri)

self.log.info("Extracting result manifest")
manifest: dict = parse_json_from_gcs(gcp_conn_id=self.gcp_conn_id, file_uri=result_manifest_uri)
manifest: dict = parse_json_from_gcs(
gcp_conn_id=self.gcp_conn_id,
file_uri=result_manifest_uri,
impersonation_chain=self.impersonation_chain,
)
if not (manifest and isinstance(manifest, dict)):
# TODO: remove this if check when min_airflow_version is set to higher than 2.7.1
message = (
Expand All @@ -115,7 +119,13 @@ def poke(self, context: Context) -> bool:
result_base_uri = result_manifest_uri.rsplit("/", 1)[0]
results = (f"{result_base_uri}//{filename}" for filename in manifest.get("filenames", []))
found_partitions = sum(
len(parse_json_from_gcs(gcp_conn_id=self.gcp_conn_id, file_uri=uri).get("rows", []))
len(
parse_json_from_gcs(
gcp_conn_id=self.gcp_conn_id,
file_uri=uri,
impersonation_chain=self.impersonation_chain,
).get("rows", [])
)
for uri in results
)

Expand Down
Expand Up @@ -261,6 +261,7 @@ def execute(self, context: Context):
conn_id=self.gcp_conn_id,
job_id=self._job_id,
project_id=self.project_id or self.hook.project_id,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down
Expand Up @@ -435,6 +435,7 @@ def execute(self, context: Context):
conn_id=self.gcp_conn_id,
job_id=self.job_id,
project_id=self.project_id or self.hook.project_id,
impersonation_chain=self.impersonation_chain,
),
method_name="execute_complete",
)
Expand Down

0 comments on commit 1c14767

Please sign in to comment.