Skip to content

Commit

Permalink
Fix template_fields type to have MyPy friendly Sequence type (#20571)
Browse files Browse the repository at this point in the history
Part of #19891
  • Loading branch information
potiuk committed Dec 30, 2021
1 parent bd9e8ce commit d56e7b5
Show file tree
Hide file tree
Showing 225 changed files with 818 additions and 766 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/airbyte/operators/airbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Sequence

from airflow.models import BaseOperator
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
Expand Down Expand Up @@ -51,7 +51,7 @@ class AirbyteTriggerSyncOperator(BaseOperator):
:type timeout: float
"""

template_fields = ('connection_id',)
template_fields: Sequence[str] = ('connection_id',)

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/airbyte/sensors/airbyte.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains a Airbyte Job sensor."""
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Sequence

from airflow.exceptions import AirflowException
from airflow.providers.airbyte.hooks.airbyte import AirbyteHook
Expand All @@ -39,7 +39,7 @@ class AirbyteJobSensor(BaseSensorOperator):
:type api_version: str
"""

template_fields = ('airbyte_job_id',)
template_fields: Sequence[str] = ('airbyte_job_id',)
ui_color = '#6C51FD'

def __init__(
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/alibaba/cloud/sensors/oss_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
else:
from cached_property import cached_property

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Sequence
from urllib.parse import urlparse

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -51,7 +51,7 @@ class OSSKeySensor(BaseSensorOperator):
:type oss_conn_id: Optional[str]
"""

template_fields = ('bucket_key', 'bucket_name')
template_fields: Sequence[str] = ('bucket_key', 'bucket_name')

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#
import sys
import warnings
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence
from uuid import uuid4

if sys.version_info >= (3, 8):
Expand Down Expand Up @@ -64,7 +64,7 @@ class AthenaOperator(BaseOperator):
"""

ui_color = '#44b5e2'
template_fields = ('query', 'database', 'output_location')
template_fields: Sequence[str] = ('query', 'database', 'output_location')
template_ext = ('.sql',)
template_fields_renderers = {"query": "sql"}

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
- http://boto3.readthedocs.io/en/latest/reference/services/batch.html
- https://docs.aws.amazon.com/batch/latest/APIReference/Welcome.html
"""
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Sequence

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -97,7 +97,7 @@ class AwsBatchOperator(BaseOperator):

ui_color = "#c3dae0"
arn = None # type: Optional[str]
template_fields = (
template_fields: Sequence[str] = (
"job_name",
"overrides",
"parameters",
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/cloud_formation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains CloudFormation create/delete stack operators."""
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, Optional, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.cloud_formation import CloudFormationHook
Expand All @@ -40,7 +40,7 @@ class CloudFormationCreateStackOperator(BaseOperator):
:type aws_conn_id: str
"""

template_fields: List[str] = ['stack_name']
template_fields: Sequence[str] = ('stack_name',)
template_ext = ()
ui_color = '#6b9659'

Expand Down Expand Up @@ -72,7 +72,7 @@ class CloudFormationDeleteStackOperator(BaseOperator):
:type aws_conn_id: str
"""

template_fields: List[str] = ['stack_name']
template_fields: Sequence[str] = ('stack_name',)
template_ext = ()
ui_color = '#1d472b'
ui_fgcolor = '#FFF'
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/datasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import random
import warnings
from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Sequence

from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.models import BaseOperator
Expand Down Expand Up @@ -111,7 +111,7 @@ class DataSyncOperator(BaseOperator):
:raises AirflowException: If Task creation, update, execution or delete fails.
"""

template_fields = (
template_fields: Sequence[str] = (
"task_arn",
"source_location_uri",
"destination_location_uri",
Expand Down
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/operators/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.


from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.dms import DmsHook
Expand Down Expand Up @@ -56,7 +56,7 @@ class DmsCreateTaskOperator(BaseOperator):
:type aws_conn_id: Optional[str]
"""

template_fields = (
template_fields: Sequence[str] = (
'replication_task_id',
'source_endpoint_arn',
'target_endpoint_arn',
Expand Down Expand Up @@ -134,7 +134,7 @@ class DmsDeleteTaskOperator(BaseOperator):
:type aws_conn_id: Optional[str]
"""

template_fields = ('replication_task_arn',)
template_fields: Sequence[str] = ('replication_task_arn',)
template_ext = ()
template_fields_renderers: Dict[str, str] = {}

Expand Down Expand Up @@ -174,7 +174,7 @@ class DmsDescribeTasksOperator(BaseOperator):
:type aws_conn_id: Optional[str]
"""

template_fields = ('describe_tasks_kwargs',)
template_fields: Sequence[str] = ('describe_tasks_kwargs',)
template_ext = ()
template_fields_renderers: Dict[str, str] = {'describe_tasks_kwargs': 'json'}

Expand Down Expand Up @@ -223,7 +223,7 @@ class DmsStartTaskOperator(BaseOperator):
:type aws_conn_id: Optional[str]
"""

template_fields = (
template_fields: Sequence[str] = (
'replication_task_arn',
'start_replication_task_type',
'start_task_kwargs',
Expand Down Expand Up @@ -276,7 +276,7 @@ class DmsStopTaskOperator(BaseOperator):
:type aws_conn_id: Optional[str]
"""

template_fields = ('replication_task_arn',)
template_fields: Sequence[str] = ('replication_task_arn',)
template_ext = ()
template_fields_renderers: Dict[str, str] = {}

Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/ec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
#

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook
Expand All @@ -41,7 +41,7 @@ class EC2StartInstanceOperator(BaseOperator):
:type check_interval: float
"""

template_fields = ("instance_id", "region_name")
template_fields: Sequence[str] = ("instance_id", "region_name")
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"

Expand Down Expand Up @@ -87,7 +87,7 @@ class EC2StopInstanceOperator(BaseOperator):
:type check_interval: float
"""

template_fields = ("instance_id", "region_name")
template_fields: Sequence[str] = ("instance_id", "region_name")
ui_color = "#eeaa11"
ui_fgcolor = "#ffffff"

Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from datetime import datetime, timedelta
from logging import Logger
from threading import Event, Thread
from typing import Dict, Generator, Optional
from typing import Dict, Generator, Optional, Sequence

from botocore.exceptions import ClientError
from botocore.waiter import Waiter
Expand Down Expand Up @@ -225,7 +225,7 @@ class ECSOperator(BaseOperator):
"""

ui_color = '#f0ede4'
template_fields = ('overrides',)
template_fields: Sequence[str] = ('overrides',)
template_fields_renderers = {
"overrides": "json",
"network_configuration": "json",
Expand Down
33 changes: 18 additions & 15 deletions airflow/providers/amazon/aws/operators/eks.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""This module contains Amazon EKS operators."""
import warnings
from time import sleep
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

from airflow import AirflowException
from airflow.models import BaseOperator
Expand Down Expand Up @@ -109,7 +109,7 @@ class EksCreateClusterOperator(BaseOperator):
"""

template_fields: Iterable[str] = (
template_fields: Sequence[str] = (
"cluster_name",
"cluster_role_arn",
"resources_vpc_config",
Expand Down Expand Up @@ -246,7 +246,7 @@ class EksCreateNodegroupOperator(BaseOperator):
"""

template_fields: Iterable[str] = (
template_fields: Sequence[str] = (
"cluster_name",
"nodegroup_subnets",
"nodegroup_role_arn",
Expand Down Expand Up @@ -316,7 +316,7 @@ class EksCreateFargateProfileOperator(BaseOperator):
:type region: str
"""

template_fields: Iterable[str] = (
template_fields: Sequence[str] = (
"cluster_name",
"pod_execution_role_arn",
"selectors",
Expand Down Expand Up @@ -382,7 +382,7 @@ class EksDeleteClusterOperator(BaseOperator):
"""

template_fields: Iterable[str] = (
template_fields: Sequence[str] = (
"cluster_name",
"force_delete_compute",
"aws_conn_id",
Expand Down Expand Up @@ -506,7 +506,7 @@ class EksDeleteNodegroupOperator(BaseOperator):
"""

template_fields: Iterable[str] = (
template_fields: Sequence[str] = (
"cluster_name",
"nodegroup_name",
"aws_conn_id",
Expand Down Expand Up @@ -559,7 +559,7 @@ class EksDeleteFargateProfileOperator(BaseOperator):
:type region: str
"""

template_fields: Iterable[str] = (
template_fields: Sequence[str] = (
"cluster_name",
"fargate_profile_name",
"aws_conn_id",
Expand Down Expand Up @@ -623,14 +623,17 @@ class EksPodOperator(KubernetesPodOperator):
:type aws_conn_id: str
"""

template_fields: Iterable[str] = {
"cluster_name",
"in_cluster",
"namespace",
"pod_name",
"aws_conn_id",
"region",
} | set(KubernetesPodOperator.template_fields)
template_fields: Sequence[str] = tuple(
{
"cluster_name",
"in_cluster",
"namespace",
"pod_name",
"aws_conn_id",
"region",
}
| set(KubernetesPodOperator.template_fields)
)

def __init__(
self,
Expand Down
18 changes: 12 additions & 6 deletions airflow/providers/amazon/aws/operators/emr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import ast
import sys
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from uuid import uuid4

from airflow.exceptions import AirflowException
Expand Down Expand Up @@ -59,7 +59,7 @@ class EmrAddStepsOperator(BaseOperator):
:type do_xcom_push: bool
"""

template_fields = ['job_flow_id', 'job_flow_name', 'cluster_states', 'steps']
template_fields: Sequence[str] = ('job_flow_id', 'job_flow_name', 'cluster_states', 'steps')
template_ext = ('.json',)
template_fields_renderers = {"steps": "json"}
ui_color = '#f9c915'
Expand Down Expand Up @@ -149,7 +149,13 @@ class EmrContainerOperator(BaseOperator):
:type max_tries: int
"""

template_fields = ["name", "virtual_cluster_id", "execution_role_arn", "release_label", "job_driver"]
template_fields: Sequence[str] = (
"name",
"virtual_cluster_id",
"execution_role_arn",
"release_label",
"job_driver",
)
ui_color = "#f9c915"

def __init__(
Expand Down Expand Up @@ -274,7 +280,7 @@ class EmrCreateJobFlowOperator(BaseOperator):
:type region_name: Optional[str]
"""

template_fields = ['job_flow_overrides']
template_fields: Sequence[str] = ('job_flow_overrides',)
template_ext = ('.json',)
template_fields_renderers = {"job_flow_overrides": "json"}
ui_color = '#f9c915'
Expand Down Expand Up @@ -333,7 +339,7 @@ class EmrModifyClusterOperator(BaseOperator):
:type do_xcom_push: bool
"""

template_fields = ['cluster_id', 'step_concurrency_level']
template_fields: Sequence[str] = ('cluster_id', 'step_concurrency_level')
template_ext = ()
ui_color = '#f9c915'

Expand Down Expand Up @@ -377,7 +383,7 @@ class EmrTerminateJobFlowOperator(BaseOperator):
:type aws_conn_id: str
"""

template_fields = ['job_flow_id']
template_fields: Sequence[str] = ('job_flow_id',)
template_ext = ()
ui_color = '#f9c915'

Expand Down

0 comments on commit d56e7b5

Please sign in to comment.