
Commit

Fix MyPy errors in Google Cloud (again) (#20469)
Part of #19891

The .py additions handle "default_args" passed in the example DAGs. Some of the obligatory parameters are (correctly) supplied via default_args, but MyPy cannot see them there, so the operator calls appear to be missing required arguments. Handling this properly would require a custom MyPy plugin; for now we have no better way to deal with it.
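
To illustrate the default_args problem described above, here is a minimal, hypothetical sketch (not part of this commit; the operator and parameter names are made up). The obligatory parameter is supplied through default_args, which works at runtime because Airflow fills missing operator arguments from the DAG's default_args, but MyPy cannot see through that mechanism without a custom plugin and flags the call as missing a required argument:

from datetime import datetime

from airflow import models
from airflow.models.baseoperator import BaseOperator


class ExampleGcpOperator(BaseOperator):
    """Hypothetical operator with an obligatory ``project_id`` parameter."""

    def __init__(self, *, project_id: str, **kwargs) -> None:
        super().__init__(**kwargs)
        self.project_id = project_id


with models.DAG(
    "example_default_args_mypy",
    start_date=datetime(2021, 1, 1),
    # The obligatory parameter is (correctly) passed via default_args ...
    default_args={"project_id": "example-project"},
) as dag:
    # ... so this call is fine at runtime, but MyPy reports a missing
    # named argument "project_id" because it cannot look inside default_args.
    task = ExampleGcpOperator(task_id="example_task")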
potiuk committed Dec 29, 2021
1 parent afd84f6 commit da88ed1
Showing 15 changed files with 94 additions and 66 deletions.
@@ -23,6 +23,7 @@
from datetime import datetime

from google.cloud.datacatalog_v1beta1 import FieldType, TagField, TagTemplateField
from google.protobuf.field_mask_pb2 import FieldMask

from airflow import models
from airflow.models.baseoperator import chain
@@ -242,7 +243,7 @@
task_id="get_entry_group",
location=LOCATION,
entry_group=ENTRY_GROUP_ID,
read_mask={"paths": ["name", "display_name"]},
read_mask=FieldMask(paths=["name", "display_name"]),
)
# [END howto_operator_gcp_datacatalog_get_entry_group]

@@ -23,6 +23,9 @@
import datetime
import os

from google.cloud.metastore_v1 import MetadataImport
from google.protobuf.field_mask_pb2 import FieldMask

from airflow import models
from airflow.models.baseoperator import chain
from airflow.providers.google.cloud.operators.dataproc_metastore import (
@@ -66,7 +69,7 @@
"systemtest": "systemtest",
}
}
UPDATE_MASK = {"paths": ["labels"]}
UPDATE_MASK = FieldMask(paths=["labels"])
# [END how_to_cloud_dataproc_metastore_update_service]

# Backup definition
@@ -78,13 +81,15 @@

# Metadata import definition
# [START how_to_cloud_dataproc_metastore_create_metadata_import]
METADATA_IMPORT = {
"name": "test-metadata-import",
"database_dump": {
"gcs_uri": GCS_URI,
"database_type": DB_TYPE,
},
}
METADATA_IMPORT = MetadataImport(
{
"name": "test-metadata-import",
"database_dump": {
"gcs_uri": GCS_URI,
"database_type": DB_TYPE,
},
}
)
# [END how_to_cloud_dataproc_metastore_create_metadata_import]


2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/cloud_memorystore.py
@@ -939,7 +939,7 @@ def update_parameters(
parameters: Union[Dict, cloud_memcache.MemcacheParameters],
project_id: str,
location: str,
instance_id: Optional[str] = None,
instance_id: str,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -505,15 +505,15 @@ def _download_sql_proxy_if_needed(self) -> None:
if "follow_redirects" in signature(httpx.get).parameters.keys():
response = httpx.get(download_url, follow_redirects=True)
else:
response = httpx.get(download_url, allow_redirects=True)
response = httpx.get(download_url, allow_redirects=True) # type: ignore[call-arg]
# Downloading to .tmp file first to avoid case where partially downloaded
# binary is used by parallel operator which uses the same fixed binary path
with open(proxy_path_tmp, 'wb') as file:
file.write(response.content)
if response.status_code != 200:
raise AirflowException(
"The cloud-sql-proxy could not be downloaded. "
f"Status code = {response.status_code}. Reason = {response.reason}"
f"Status code = {response.status_code}. Reason = {response.reason_phrase}"
)

self.log.info("Moving sql_proxy binary from %s to %s", proxy_path_tmp, self.sql_proxy_path)
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/compute.py
@@ -406,7 +406,7 @@ def get_instance_address(
if use_internal_ip:
return instance_info["networkInterfaces"][0].get("networkIP")

access_config = instance_info.get("networkInterfaces")[0].get("accessConfigs")
access_config = instance_info["networkInterfaces"][0].get("accessConfigs")
if access_config:
return access_config[0].get("natIP")
raise AirflowException("The target instance does not have external IP")
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/hooks/dataflow.py
@@ -503,7 +503,8 @@ def cancel(self) -> None:
timeout_error_message = (
f"Canceling jobs failed due to timeout ({self._cancel_timeout}s): {', '.join(job_ids)}"
)
with timeout(seconds=self._cancel_timeout, error_message=timeout_error_message):
tm = timeout(seconds=self._cancel_timeout, error_message=timeout_error_message)
with tm:
self._wait_for_states({DataflowJobStatus.JOB_STATE_CANCELLED})
else:
self.log.info("No jobs to cancel")
9 changes: 6 additions & 3 deletions airflow/providers/google/cloud/operators/datacatalog.py
@@ -375,9 +375,12 @@ def execute(self, context: dict):
)
except AlreadyExists:
self.log.info("Tag already exists. Skipping create operation.")
project_id = self.project_id or hook.project_id
if project_id is None:
raise RuntimeError("The project id must be set here")
if self.template_id:
template_name = DataCatalogClient.tag_template_path(
self.project_id or hook.project_id, self.location, self.template_id
project_id, self.location, self.template_id
)
else:
if isinstance(self.tag, Tag):
Expand All @@ -390,7 +393,7 @@ def execute(self, context: dict):
entry_group=self.entry_group,
template_name=template_name,
entry=self.entry,
project_id=self.project_id,
project_id=project_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
@@ -1265,7 +1268,7 @@ def __init__(
*,
location: str,
entry_group: str,
read_mask: Union[Dict, FieldMask],
read_mask: FieldMask,
project_id: Optional[str] = None,
retry: Optional[Retry] = None,
timeout: Optional[float] = None,
6 changes: 5 additions & 1 deletion airflow/providers/google/cloud/operators/dataprep.py
@@ -16,10 +16,14 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains a Google Dataprep operator."""
from typing import TYPE_CHECKING

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.dataprep import GoogleDataprepHook

if TYPE_CHECKING:
from airflow.utils.context import Context


class DataprepGetJobsForJobGroupOperator(BaseOperator):
"""
@@ -121,7 +125,7 @@ def __init__(self, *, dataprep_conn_id: str = "dataprep_default", body_request:
self.body_request = body_request
self.dataprep_conn_id = dataprep_conn_id

def execute(self, context: None) -> dict:
def execute(self, context: "Context") -> dict:
self.log.info("Creating a job...")
hook = GoogleDataprepHook(dataprep_conn_id=self.dataprep_conn_id)
response = hook.run_job_group(body_request=self.body_request)
43 changes: 23 additions & 20 deletions airflow/providers/google/cloud/operators/gcs.py
@@ -22,7 +22,10 @@
import warnings
from pathlib import Path
from tempfile import NamedTemporaryFile, TemporaryDirectory
from typing import Dict, Iterable, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Union

if TYPE_CHECKING:
from airflow.utils.context import Context

from google.api_core.exceptions import Conflict
from google.cloud.exceptions import GoogleCloudError
@@ -152,7 +155,7 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain

def execute(self, context) -> None:
def execute(self, context: "Context") -> None:
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -258,7 +261,7 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain

def execute(self, context) -> list:
def execute(self, context: "Context") -> list:

hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
@@ -320,7 +323,7 @@ def __init__(
self,
*,
bucket_name: str,
objects: Optional[Iterable[str]] = None,
objects: Optional[List[str]] = None,
prefix: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
google_cloud_storage_conn_id: Optional[str] = None,
@@ -350,7 +353,7 @@ def __init__(

super().__init__(**kwargs)

def execute(self, context):
def execute(self, context: "Context") -> None:
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
@@ -443,7 +446,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context) -> None:
def execute(self, context: "Context") -> None:
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
@@ -541,7 +544,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context) -> None:
def execute(self, context: "Context") -> None:
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
@@ -620,7 +623,7 @@ def __init__(
self.output_encoding = sys.getdefaultencoding()
self.impersonation_chain = impersonation_chain

def execute(self, context: dict) -> None:
def execute(self, context: "Context") -> None:
hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)

with NamedTemporaryFile() as source_file, NamedTemporaryFile() as destination_file:
@@ -742,7 +745,7 @@ class GCSTimeSpanFileTransformOperator(BaseOperator):
)

@staticmethod
def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[datetime.datetime]:
def interpolate_prefix(prefix: str, dt: datetime.datetime) -> Optional[str]:
"""Interpolate prefix with datetime.
:param prefix: The prefix to interpolate
@@ -792,7 +795,7 @@ def __init__(
self.upload_continue_on_fail = upload_continue_on_fail
self.upload_num_attempts = upload_num_attempts

def execute(self, context: dict) -> None:
def execute(self, context: "Context") -> List[str]:
# Define intervals and prefixes.
try:
timespan_start = context["data_interval_start"]
@@ -838,12 +841,12 @@ def execute(self, context: dict) -> None:
)

with TemporaryDirectory() as temp_input_dir, TemporaryDirectory() as temp_output_dir:
temp_input_dir = Path(temp_input_dir)
temp_output_dir = Path(temp_output_dir)
temp_input_dir_path = Path(temp_input_dir)
temp_output_dir_path = Path(temp_output_dir)

# TODO: download in parallel.
for blob_to_transform in blobs_to_transform:
destination_file = temp_input_dir / blob_to_transform
destination_file = temp_input_dir_path / blob_to_transform
destination_file.parent.mkdir(parents=True, exist_ok=True)
try:
source_hook.download(
@@ -861,8 +864,8 @@ def execute(self, context: dict) -> None:
self.log.info("Starting the transformation")
cmd = [self.transform_script] if isinstance(self.transform_script, str) else self.transform_script
cmd += [
str(temp_input_dir),
str(temp_output_dir),
str(temp_input_dir_path),
str(temp_output_dir_path),
timespan_start.replace(microsecond=0).isoformat(),
timespan_end.replace(microsecond=0).isoformat(),
]
@@ -878,16 +881,16 @@ def execute(self, context: dict) -> None:
if process.returncode:
raise AirflowException(f"Transform script failed: {process.returncode}")

self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir)
self.log.info("Transformation succeeded. Output temporarily located at %s", temp_output_dir_path)

files_uploaded = []

# TODO: upload in parallel.
for upload_file in temp_output_dir.glob("**/*"):
for upload_file in temp_output_dir_path.glob("**/*"):
if upload_file.is_dir():
continue

upload_file_name = str(upload_file.relative_to(temp_output_dir))
upload_file_name = str(upload_file.relative_to(temp_output_dir_path))

if self.destination_prefix is not None:
upload_file_name = f"{destination_prefix_interp}/{upload_file_name}"
@@ -959,7 +962,7 @@ def __init__(
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context) -> None:
def execute(self, context: "Context") -> None:
hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
hook.delete_bucket(bucket_name=self.bucket_name, force=self.force)

@@ -1056,7 +1059,7 @@ def __init__(
self.delegate_to = delegate_to
self.impersonation_chain = impersonation_chain

def execute(self, context) -> None:
def execute(self, context: "Context") -> None:
hook = GCSHook(
gcp_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
6 changes: 3 additions & 3 deletions airflow/providers/google/cloud/operators/mlengine.py
@@ -1164,9 +1164,9 @@ def __init__(
*,
job_id: str,
region: str,
package_uris: List[str] = None,
training_python_module: str = None,
training_args: List[str] = None,
package_uris: Optional[List[str]] = None,
training_python_module: Optional[str] = None,
training_args: Optional[List[str]] = None,
scale_tier: Optional[str] = None,
master_type: Optional[str] = None,
master_config: Optional[Dict] = None,
12 changes: 6 additions & 6 deletions airflow/providers/google/cloud/sensors/datafusion.py
@@ -122,12 +122,12 @@ def poke(self, context: dict) -> bool:
pipeline_status = pipeline_workflow["status"]
except AirflowException:
pass # Because the pipeline may not be visible in system yet

if self.failure_statuses and pipeline_status in self.failure_statuses:
raise AirflowException(
f"Pipeline with id '{self.pipeline_id}' state is: {pipeline_status}. "
f"Terminating sensor..."
)
if pipeline_status is not None:
if self.failure_statuses and pipeline_status in self.failure_statuses:
raise AirflowException(
f"Pipeline with id '{self.pipeline_id}' state is: {pipeline_status}. "
f"Terminating sensor..."
)

self.log.debug(
"Current status of the pipeline workflow for %s: %s.", self.pipeline_id, pipeline_status
9 changes: 6 additions & 3 deletions airflow/providers/google/cloud/sensors/dataproc.py
@@ -19,7 +19,7 @@
# pylint: disable=C0302
import time
import warnings
from typing import Dict, Optional
from typing import TYPE_CHECKING, Optional

from google.api_core.exceptions import ServerError
from google.cloud.dataproc_v1.types import JobStatus
@@ -28,6 +28,9 @@
from airflow.providers.google.cloud.hooks.dataproc import DataprocHook
from airflow.sensors.base import BaseSensorOperator

if TYPE_CHECKING:
from airflow.utils.context import Context


class DataprocJobSensor(BaseSensorOperator):
"""
@@ -81,14 +84,14 @@ def __init__(
self.wait_timeout = wait_timeout
self.start_sensor_time: Optional[float] = None

def execute(self, context: Dict):
def execute(self, context: "Context") -> None:
self.start_sensor_time = time.monotonic()
super().execute(context)

def _duration(self):
return time.monotonic() - self.start_sensor_time

def poke(self, context: Dict) -> bool:
def poke(self, context: "Context") -> bool:
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id)
if self.wait_timeout:
try:
