Fix cached_property MyPy declaration and related MyPy errors (#20226)
Part of #19891
potiuk committed Dec 15, 2021
1 parent 21b8661 commit 2fb5e1d
Showing 31 changed files with 174 additions and 97 deletions.
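The change repeated across most of these files replaces a try/except ImportError import of cached_property with an explicit Python version check. MyPy resolves sys.version_info checks against its configured python_version and type-checks only the matching branch, whereas the try/except form gives it two competing declarations of the same name to reconcile. A minimal, self-contained sketch of the pattern outside Airflow (the Greeter class below is purely illustrative):

import sys

# functools.cached_property exists only on Python 3.8+; older interpreters
# fall back to the `cached_property` backport package. MyPy follows the
# version check and sees a single, well-defined declaration.
if sys.version_info >= (3, 8):
    from functools import cached_property
else:
    from cached_property import cached_property


class Greeter:
    """Illustrative class showing the decorator in use."""

    def __init__(self, name: str) -> None:
        self.name = name

    @cached_property
    def greeting(self) -> str:
        # Computed on first access, then cached on the instance.
        return f"Hello, {self.name}!"


if __name__ == "__main__":
    g = Greeter("Airflow")
    print(g.greeting)  # computed
    print(g.greeting)  # served from the per-instance cache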
7 changes: 5 additions & 2 deletions airflow/providers/alibaba/cloud/sensors/oss_key.py
@@ -15,10 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
try:
import sys

if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from typing import Optional
from urllib.parse import urlparse

24 changes: 19 additions & 5 deletions airflow/providers/amazon/aws/hooks/base_aws.py
@@ -27,6 +27,7 @@
import configparser
import datetime
import logging
import sys
import warnings
from functools import wraps
from typing import Any, Callable, Dict, Optional, Tuple, Union
@@ -40,9 +41,9 @@
from botocore.credentials import ReadOnlyCredentials
from slugify import slugify

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from dateutil.tz import tzlocal
@@ -60,8 +61,8 @@ def __init__(self, conn: Connection, region_name: Optional[str], config: Config)
self.region_name = region_name
self.config = config
self.extra_config = self.conn.extra_dejson
self.basic_session = None
self.role_arn = None
self.basic_session: Optional[boto3.session.Session] = None
self.role_arn: Optional[str] = None

def create_session(self) -> boto3.session.Session:
"""Create AWS session."""
@@ -128,6 +129,8 @@ def _create_session_with_assume_role(self, session_kwargs: Dict[str, Any]) -> bo
)
session = botocore.session.get_session()
session._credentials = credentials
if self.basic_session is None:
raise RuntimeError("The basic session should be created here!")
region_name = self.basic_session.region_name
session.set_config_variable("region", region_name)
return boto3.session.Session(botocore_session=session, **session_kwargs)
@@ -137,16 +140,25 @@ def _refresh_credentials(self) -> Dict[str, Any]:
assume_role_method = self.extra_config.get('assume_role_method', 'assume_role')
sts_session = self.basic_session
if assume_role_method == 'assume_role':
if sts_session is None:
raise RuntimeError(
"Session should be initialized when refresh credentials with assume_role is used!"
)
sts_client = sts_session.client("sts", config=self.config)
sts_response = self._assume_role(sts_client=sts_client)
elif assume_role_method == 'assume_role_with_saml':
if sts_session is None:
raise RuntimeError(
"Session should be initialized when refresh "
"credentials with assume_role_with_saml is used!"
)
sts_client = sts_session.client("sts", config=self.config)
sts_response = self._assume_role_with_saml(sts_client=sts_client)
else:
raise NotImplementedError(f'assume_role_method={assume_role_method} not expected')
sts_response_http_status = sts_response['ResponseMetadata']['HTTPStatusCode']
if not sts_response_http_status == 200:
raise Exception(f'sts_response_http_status={sts_response_http_status}')
raise RuntimeError(f'sts_response_http_status={sts_response_http_status}')
credentials = sts_response['Credentials']
expiry_time = credentials.get('Expiration').isoformat()
self.log.info(f'New credentials expiry_time:{expiry_time}')
@@ -305,6 +317,8 @@ def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, An
def _get_web_identity_credential_fetcher(
self,
) -> botocore.credentials.AssumeRoleWithWebIdentityCredentialFetcher:
if self.basic_session is None:
raise Exception("Session should be set where identity is fetched!")
base_session = self.basic_session._session or botocore.session.get_session()
client_creator = base_session.create_client
federation = self.extra_config.get('assume_role_with_web_identity_federation')
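The other base_aws.py changes declare basic_session and role_arn as Optional and then guard each use with an explicit None check, which lets MyPy narrow the Optional type at the point of use. A minimal sketch of that narrowing pattern (the class and attribute below are illustrative, not the hook's real API):

from typing import Optional


class LazySession:
    def __init__(self) -> None:
        # Created lazily, so the attribute starts out as None.
        self.session: Optional[str] = None

    def describe(self) -> str:
        # Without this guard, MyPy reports that `self.session` may be None.
        if self.session is None:
            raise RuntimeError("Session should be created before it is used!")
        # After the check, MyPy narrows Optional[str] to str.
        return self.session.upper()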
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue_crawler.py
@@ -15,12 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
from time import sleep

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.exceptions import AirflowException
@@ -73,7 +73,7 @@ def get_crawler(self, crawler_name: str) -> dict:
"""
return self.glue_client.get_crawler(Name=crawler_name)['Crawler']

def update_crawler(self, **crawler_kwargs) -> str:
def update_crawler(self, **crawler_kwargs) -> bool:
"""
Updates crawler configurations
12 changes: 6 additions & 6 deletions airflow/providers/amazon/aws/hooks/redshift.py
@@ -16,12 +16,12 @@
# specific language governing permissions and limitations
# under the License.
"""Interact with AWS Redshift clusters."""

import sys
from typing import Dict, List, Optional, Union

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

import redshift_connector
@@ -212,7 +212,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):

return create_engine(self.get_uri(), **engine_kwargs)

def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> List[str]:
def get_table_primary_key(self, table: str, schema: Optional[str] = "public") -> Optional[List[str]]:
"""
Helper method that returns the table primary key
:param table: Name of the target table
@@ -239,8 +239,8 @@ def get_table_primary_key(self, table: str, schema: Optional[str] = "public") ->
def get_conn(self) -> RedshiftConnection:
"""Returns a redshift_connector.Connection object"""
conn_params = self._get_conn_params()
conn_kwargs = self.conn.extra_dejson
conn_kwargs: Dict = {**conn_params, **conn_kwargs}
conn_kwargs_dejson = self.conn.extra_dejson
conn_kwargs: Dict = {**conn_params, **conn_kwargs_dejson}
conn: RedshiftConnection = redshift_connector.connect(**conn_kwargs)

return conn
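
Renaming the intermediate variable in get_conn keeps each name bound to a single type instead of re-annotating a name that was already assigned, a pattern MyPy flags. A tiny sketch of the same fix with made-up parameters:

from typing import Dict


def build_conn_kwargs(conn_params: Dict[str, str], extra_dejson: Dict[str, str]) -> Dict[str, str]:
    # Give the raw extras their own name instead of reusing `conn_kwargs`
    # for both the extras and the merged result; each variable then keeps
    # a single, stable type.
    conn_kwargs_dejson = extra_dejson
    conn_kwargs: Dict[str, str] = {**conn_params, **conn_kwargs_dejson}
    return conn_kwargs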
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/log/cloudwatch_task_handler.py
@@ -15,14 +15,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
from datetime import datetime

import watchtower

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.configuration import conf
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -16,10 +16,11 @@
# specific language governing permissions and limitations
# under the License.
import os
import sys

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.configuration import conf
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/athena.py
@@ -16,12 +16,13 @@
# specific language governing permissions and limitations
# under the License.
#
import sys
from typing import Any, Dict, Optional
from uuid import uuid4

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.models import BaseOperator
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/operators/emr_containers.py
@@ -14,15 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
from typing import Any, Optional
from uuid import uuid4

from airflow.exceptions import AirflowException

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.models import BaseOperator
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/operators/glue_crawler.py
@@ -15,10 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.models import BaseOperator
6 changes: 4 additions & 2 deletions airflow/providers/amazon/aws/operators/sagemaker_base.py
@@ -17,13 +17,15 @@
# under the License.

import json
import sys
from typing import Iterable

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property


from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook

8 changes: 5 additions & 3 deletions airflow/providers/amazon/aws/secrets/secrets_manager.py
@@ -19,14 +19,15 @@

import ast
import json
import sys
from typing import Optional
from urllib.parse import urlencode

import boto3

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.secrets import BaseSecretsBackend
@@ -194,7 +195,8 @@ def get_conn_uri(self, conn_id: str):
else:
try:
secret_string = self._get_secret(self.connections_prefix, conn_id)
secret = ast.literal_eval(secret_string) # json.loads gives error
# json.loads gives error
secret = ast.literal_eval(secret_string) if secret_string else None
except ValueError: # 'malformed node or string: ' error, for empty conns
connection = None
secret = None
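In secrets_manager.py the looked-up secret string can be None, so parsing is now guarded with a conditional expression before ast.literal_eval runs (the in-source comment notes json.loads cannot parse these values). A short sketch of the same guard with an illustrative function name:

import ast
from typing import Any, Optional


def parse_secret(secret_string: Optional[str]) -> Optional[Any]:
    # ast.literal_eval needs a string; skip parsing when nothing was found.
    return ast.literal_eval(secret_string) if secret_string else None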
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/secrets/systems_manager.py
@@ -16,13 +16,14 @@
# specific language governing permissions and limitations
# under the License.
"""Objects relating to sourcing connections from AWS SSM Parameter Store"""
import sys
from typing import Optional

import boto3

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.secrets import BaseSecretsBackend
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/sensors/athena.py
@@ -15,11 +15,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from typing import Any, Optional

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.exceptions import AirflowException
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/sensors/cloud_formation.py
@@ -16,11 +16,12 @@
# specific language governing permissions and limitations
# under the License.
"""This module contains sensors for AWS CloudFormation."""
import sys
from typing import Optional

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.providers.amazon.aws.hooks.cloud_formation import AWSCloudFormationHook
6 changes: 3 additions & 3 deletions airflow/providers/amazon/aws/sensors/emr_containers.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import sys
from typing import Any, Optional

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from airflow.exceptions import AirflowException
2 changes: 1 addition & 1 deletion airflow/providers/asana/hooks/asana.py
@@ -107,7 +107,7 @@ def _merge_create_task_parameters(self, task_name: str, task_params: Optional[di
:param task_params: Other task parameters which should override defaults from the connection
:return: A dict of merged parameters to use in the new task
"""
merged_params = {"name": task_name}
merged_params: Dict[str, Any] = {"name": task_name}
if self.project:
merged_params["projects"] = [self.project]
# Only use default workspace if user did not provide a project id
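The asana hook needs the explicit Dict[str, Any] annotation because MyPy infers Dict[str, str] from the {"name": task_name} literal and then rejects the later assignment of a list value. A minimal sketch of the same situation:

from typing import Any, Dict


def merge_create_task_parameters(task_name: str, project: str) -> Dict[str, Any]:
    # Without the annotation, MyPy infers Dict[str, str] from this literal
    # and flags the list assignment on the next line.
    merged_params: Dict[str, Any] = {"name": task_name}
    merged_params["projects"] = [project]
    return merged_params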
9 changes: 6 additions & 3 deletions airflow/providers/cncf/kubernetes/hooks/kubernetes.py
@@ -14,19 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
import tempfile
from typing import Any, Dict, Generator, Optional, Tuple, Union

try:
if sys.version_info >= (3, 8):
from functools import cached_property
except ImportError:
else:
from cached_property import cached_property

from kubernetes import client, config, watch

try:
import airflow.utils.yaml as yaml
except ImportError:
import yaml
import yaml # type: ignore[no-redef]

from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
@@ -252,6 +254,7 @@ def get_namespace(self) -> Optional[str]:
extras = connection.extra_dejson
namespace = extras.get("extra__kubernetes__namespace", "default")
return namespace
return None

def get_pod_log_stream(
self,
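Two smaller fixes land in the kubernetes hook: the fallback import of yaml is marked # type: ignore[no-redef] because MyPy otherwise reports the second import as redefining the name bound in the try branch, and get_namespace gains an explicit return None so the Optional[str] return type is satisfied on every path. A small sketch of both patterns with illustrative module and function names:

from typing import Optional

try:
    # Preferred module first; the ignore silences only the "already defined"
    # diagnostic that MyPy raises for the re-import in the fallback branch.
    import importlib.metadata as metadata
except ImportError:
    import importlib_metadata as metadata  # type: ignore[no-redef]


def pick_namespace(extras: Optional[dict]) -> Optional[str]:
    if extras is not None:
        return str(extras.get("namespace", "default"))
    # Explicit return None spells out the remaining path of the
    # Optional[str] signature instead of falling off the end.
    return None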
