Skip to content

Commit

Permalink
Google cloud operator strict type check (#11450)
Browse files Browse the repository at this point in the history
import optimisation
  • Loading branch information
mlgruby committed Oct 12, 2020
1 parent 358e61d commit 06141d6
Show file tree
Hide file tree
Showing 26 changed files with 237 additions and 232 deletions.
31 changes: 18 additions & 13 deletions airflow/providers/google/cloud/operators/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import re
import uuid
import warnings
from datetime import datetime
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, SupportsAbs, Union

import attr
Expand Down Expand Up @@ -81,7 +82,7 @@ class BigQueryConsoleIndexableLink(BaseOperatorLink):
def name(self) -> str:
return f'BigQuery Console #{self.index + 1}'

def get_link(self, operator, dttm):
def get_link(self, operator: BaseOperator, dttm: datetime):
ti = TaskInstance(task=operator, execution_date=dttm)
job_ids = ti.xcom_pull(task_ids=operator.task_id, key='job_id')
if not job_ids:
Expand Down Expand Up @@ -466,7 +467,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain

def execute(self, context):
def execute(self, context) -> list:
self.log.info(
'Fetching Data from %s.%s max results: %s', self.dataset_id, self.table_id, self.max_results
)
Expand Down Expand Up @@ -741,7 +742,7 @@ def execute(self, context):
)
context['task_instance'].xcom_push(key='job_id', value=job_id)

def on_kill(self):
def on_kill(self) -> None:
super().on_kill()
if self.hook is not None:
self.log.info('Cancelling running query')
Expand Down Expand Up @@ -931,7 +932,7 @@ def __init__(
self.table_resource = table_resource
self.impersonation_chain = impersonation_chain

def execute(self, context):
def execute(self, context) -> None:
bq_hook = BigQueryHook(
gcp_conn_id=self.bigquery_conn_id,
delegate_to=self.delegate_to,
Expand All @@ -946,7 +947,9 @@ def execute(self, context):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
schema_fields = json.loads(gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8"))
schema_fields = json.loads(
gcs_hook.download(gcs_bucket, gcs_object).decode("utf-8") # type: ignore[attr-defined]
) # type: ignore[attr-defined]
else:
schema_fields = self.schema_fields

Expand Down Expand Up @@ -1172,7 +1175,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain

def execute(self, context):
def execute(self, context) -> None:
bq_hook = BigQueryHook(
gcp_conn_id=self.bigquery_conn_id,
delegate_to=self.delegate_to,
Expand All @@ -1187,7 +1190,7 @@ def execute(self, context):
impersonation_chain=self.impersonation_chain,
)
schema_object = gcs_hook.download(self.bucket, self.schema_object)
schema_fields = json.loads(schema_object.decode("utf-8"))
schema_fields = json.loads(schema_object.decode("utf-8")) # type: ignore[attr-defined]
else:
schema_fields = self.schema_fields

Expand Down Expand Up @@ -1309,7 +1312,7 @@ def __init__(

super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
self.log.info('Dataset id: %s Project id: %s', self.dataset_id, self.project_id)

bq_hook = BigQueryHook(
Expand Down Expand Up @@ -1413,7 +1416,7 @@ def __init__(

super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
bq_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
Expand Down Expand Up @@ -1828,7 +1831,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain

def execute(self, context):
def execute(self, context) -> None:
self.log.info('Deleting: %s', self.deletion_dataset_table)
hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
Expand Down Expand Up @@ -1919,7 +1922,7 @@ def __init__(
self.location = location
self.impersonation_chain = impersonation_chain

def execute(self, context):
def execute(self, context) -> None:
self.log.info('Upserting Dataset: %s with table_resource: %s', self.dataset_id, self.table_resource)
hook = BigQueryHook(
bigquery_conn_id=self.gcp_conn_id,
Expand Down Expand Up @@ -2107,6 +2110,8 @@ def execute(self, context: Any):
self.job_id = job.job_id
return job.job_id

def on_kill(self):
def on_kill(self) -> None:
if self.job_id and self.cancel_on_kill:
self.hook.cancel_job(job_id=self.job_id, project_id=self.project_id, location=self.location)
self.hook.cancel_job( # type: ignore[union-attr]
job_id=self.job_id, project_id=self.project_id, location=self.location
)
8 changes: 4 additions & 4 deletions airflow/providers/google/cloud/operators/bigquery_dts.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
gcp_conn_id="google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.transfer_config = transfer_config
self.authorization_code = authorization_code
Expand Down Expand Up @@ -172,7 +172,7 @@ def __init__(
gcp_conn_id="google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.transfer_config_id = transfer_config_id
Expand All @@ -182,7 +182,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context):
def execute(self, context) -> None:
hook = BiqQueryDataTransferServiceHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -265,7 +265,7 @@ def __init__(
gcp_conn_id="google_cloud_default",
impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.project_id = project_id
self.transfer_config_id = transfer_config_id
Expand Down
14 changes: 7 additions & 7 deletions airflow/providers/google/cloud/operators/bigtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
hook = BigtableHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -259,7 +259,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
hook = BigtableHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -335,7 +335,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
hook = BigtableHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -423,7 +423,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

def _compare_column_families(self, hook, instance):
def _compare_column_families(self, hook, instance) -> bool:
table_column_families = hook.get_column_families_for_table(instance, self.table_id)
if set(table_column_families.keys()) != set(self.column_families.keys()):
self.log.error("Table '%s' has different set of Column Families", self.table_id)
Expand All @@ -444,7 +444,7 @@ def _compare_column_families(self, hook, instance):
return False
return True

def execute(self, context):
def execute(self, context) -> None:
hook = BigtableHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -533,7 +533,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
hook = BigtableHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down Expand Up @@ -619,7 +619,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

def execute(self, context):
def execute(self, context) -> None:
hook = BigtableHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
Expand Down
14 changes: 7 additions & 7 deletions airflow/providers/google/cloud/operators/cloud_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ class BuildProcessor:
:type body: dict
"""

def __init__(self, body: Dict) -> None:
def __init__(self, body: dict) -> None:
self.body = deepcopy(body)

def _verify_source(self):
def _verify_source(self) -> None:
is_storage = "storageSource" in self.body["source"]
is_repo = "repoSource" in self.body["source"]

Expand All @@ -61,11 +61,11 @@ def _verify_source(self):
"storageSource and repoSource."
)

def _reformat_source(self):
def _reformat_source(self) -> None:
self._reformat_repo_source()
self._reformat_storage_source()

def _reformat_repo_source(self):
def _reformat_repo_source(self) -> None:
if "repoSource" not in self.body["source"]:
return

Expand All @@ -76,7 +76,7 @@ def _reformat_repo_source(self):

self.body["source"]["repoSource"] = self._convert_repo_url_to_dict(source)

def _reformat_storage_source(self):
def _reformat_storage_source(self) -> None:
if "storageSource" not in self.body["source"]:
return

Expand All @@ -87,7 +87,7 @@ def _reformat_storage_source(self):

self.body["source"]["storageSource"] = self._convert_storage_url_to_dict(source)

def process_body(self):
def process_body(self) -> dict:
"""
Processes the body passed in the constructor
Expand Down Expand Up @@ -228,7 +228,7 @@ def prepare_template(self) -> None:
if self.body_raw.endswith('.json'):
self.body = json.loads(file.read())

def _validate_inputs(self):
def _validate_inputs(self) -> None:
if not self.body:
raise AirflowException("The required parameter 'body' is missing")

Expand Down
22 changes: 11 additions & 11 deletions airflow/providers/google/cloud/operators/cloud_memorystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict):
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -204,7 +204,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -298,7 +298,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -391,7 +391,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -476,7 +476,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict):
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -572,7 +572,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -661,7 +661,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict):
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -771,7 +771,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -863,7 +863,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -978,7 +978,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down Expand Up @@ -1085,7 +1085,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Dict):
def execute(self, context: dict) -> None:
hook = CloudMemorystoreHook(
gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain
)
Expand Down

0 comments on commit 06141d6

Please sign in to comment.