Bump typing-extensions and mypy for ParamSpec (#25088)
* Bump typing-extensions and mypy for ParamSpec

I want to use ParamSpec in some @task signature improvements. Mypy added
support in 0.950, but let's just bump to the latest since why not.

The typing-extensions changelog is spotty before 4.0, but ParamSpec was
introduced some time before that (likely in 2021), and 4.0 seems a
reasonable minimum to bump to.

For more about ParamSpec, read PEP 612: https://peps.python.org/pep-0612/
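
As a quick illustration of what ParamSpec buys us (a decorator that preserves the decorated function's full signature for the type checker), here is a minimal sketch; the log_call decorator is illustrative, not part of this commit:

    from functools import wraps
    from typing import Callable, TypeVar

    from typing_extensions import ParamSpec

    P = ParamSpec("P")
    R = TypeVar("R")

    def log_call(func: Callable[P, R]) -> Callable[P, R]:
        @wraps(func)
        def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # Mypy checks call sites of the wrapper against func's real
            # signature instead of collapsing it to Callable[..., Any].
            print(f"calling {func.__name__}")
            return func(*args, **kwargs)
        return wrapper

    @log_call
    def add(x: int, y: int = 0) -> int:
        return x + y

    add(1, y=2)    # OK
    # add("a")     # rejected by mypy: wrong argument type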
uranusjr committed Jul 18, 2022
1 parent d30885a commit e32e9c5
Showing 12 changed files with 73 additions and 47 deletions.
4 changes: 2 additions & 2 deletions airflow/jobs/scheduler_job.py
@@ -175,7 +175,7 @@ def register_signals(self) -> None:
signal.signal(signal.SIGTERM, self._exit_gracefully)
signal.signal(signal.SIGUSR2, self._debug_dump)

-    def _exit_gracefully(self, signum: int, frame: "FrameType") -> None:
+    def _exit_gracefully(self, signum: int, frame: Optional["FrameType"]) -> None:
"""Helper method to clean up processor_agent to avoid leaving orphan processes."""
if not _is_parent_process():
# Only the parent process should perform the cleanup.
@@ -186,7 +186,7 @@ def _exit_gracefully(self, signum: int, frame: "FrameType") -> None:
self.processor_agent.end()
sys.exit(os.EX_OK)

-    def _debug_dump(self, signum: int, frame: "FrameType") -> None:
+    def _debug_dump(self, signum: int, frame: Optional["FrameType"]) -> None:
if not _is_parent_process():
# Only the parent process should perform the debug dump.
return
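Note: the Optional here matches typeshed, which types a signal handler's frame argument as possibly None. A minimal standalone sketch of a handler written against that contract:

    import signal
    import sys
    from types import FrameType
    from typing import Optional

    def handle_sigterm(signum: int, frame: Optional[FrameType]) -> None:
        # typeshed annotates the frame as Optional, since a handler can be
        # invoked with no Python frame available.
        sys.exit(0)

    signal.signal(signal.SIGTERM, handle_sigterm)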
5 changes: 4 additions & 1 deletion airflow/mypy/plugin/decorators.py
@@ -68,7 +68,10 @@ def _change_decorator_function_type(
# Mark provided arguments as optional
decorator.arg_types = copy.copy(decorated.arg_types)
for argument in provided_arguments:
-        index = decorated.arg_names.index(argument)
+        try:
+            index = decorated.arg_names.index(argument)
+        except ValueError:
+            continue
decorated_type = decorated.arg_types[index]
decorator.arg_types[index] = UnionType.make_union([decorated_type, NoneType()])
decorated.arg_kinds[index] = ARG_NAMED_OPT
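Note: list.index() raises ValueError when the item is absent, so the lookup is now guarded rather than crashing the plugin. The pattern in isolation (names illustrative):

    arg_names = ["self", "multiple_outputs", None]

    for provided in ("multiple_outputs", "not_in_signature"):
        try:
            index = arg_names.index(provided)
        except ValueError:
            continue  # argument absent from the decorated signature; skip it
        print(provided, "is at position", index)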
1 change: 1 addition & 0 deletions airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -114,6 +114,7 @@ def execute(self, context: 'Context') -> None:

scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
err = None
+        f: IO[Any]
with NamedTemporaryFile() as f:
try:
f = self._scan_dynamodb_and_upload_to_s3(f, scan_kwargs, table)
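Note: the bare f: IO[Any] is a variable annotation without assignment; it declares a single type for f, which is first bound by the context manager and then rebound to the helper's return value. A standalone sketch of the idiom (the rotate helper is illustrative):

    from tempfile import NamedTemporaryFile
    from typing import IO, Any

    def rotate(handle: IO[Any]) -> IO[Any]:
        # Stands in for a helper that may return a replacement handle.
        return handle

    f: IO[Any]  # one declared type for `f`, valid for both bindings below
    with NamedTemporaryFile() as f:
        f = rotate(f)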
19 changes: 11 additions & 8 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -16,8 +16,8 @@
# specific language governing permissions and limitations
# under the License.

+import enum
from collections import namedtuple
-from enum import Enum
from tempfile import NamedTemporaryFile
from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union

@@ -35,10 +35,13 @@
from airflow.utils.context import Context


-FILE_FORMAT = Enum(
-    "FILE_FORMAT",
-    "CSV, JSON, PARQUET",
-)
+class FILE_FORMAT(enum.Enum):
+    """Possible file formats."""
+
+    CSV = enum.auto()
+    JSON = enum.auto()
+    PARQUET = enum.auto()


FileOptions = namedtuple('FileOptions', ['mode', 'suffix', 'function'])

@@ -118,9 +121,9 @@ def __init__(
if "path_or_buf" in self.pd_kwargs:
raise AirflowException('The argument path_or_buf is not allowed, please remove it')

-        self.file_format = getattr(FILE_FORMAT, file_format.upper(), None)
-
-        if self.file_format is None:
+        try:
+            self.file_format = FILE_FORMAT[file_format.upper()]
+        except KeyError:
raise AirflowException(f"The argument file_format doesn't support {file_format} value.")

@staticmethod
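Note: the functional Enum(...) API is opaque to mypy, so the class-based form replaces it, and member lookup moves from getattr() to FILE_FORMAT[...], which raises KeyError on a miss. The same pattern in isolation:

    import enum

    class Color(enum.Enum):
        RED = enum.auto()
        GREEN = enum.auto()

    def parse_color(value: str) -> Color:
        try:
            return Color[value.upper()]  # lookup by member name
        except KeyError:
            raise ValueError(f"unsupported color: {value}")

    print(parse_color("red"))  # Color.RED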
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/cloud_sql.py
@@ -37,7 +37,7 @@
SETTINGS = 'settings'
SETTINGS_VERSION = 'settingsVersion'

-CLOUD_SQL_CREATE_VALIDATION = [
+CLOUD_SQL_CREATE_VALIDATION: Sequence[dict] = [
dict(name="name", allow_empty=False),
dict(
name="settings",
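Note: the new annotation states the intended contract up front rather than letting mypy infer a narrow element type from the heterogeneous dict literals. A sketch under that assumption (VALIDATION_SPEC is an illustrative name, not from the commit):

    from typing import Sequence

    # The explicit Sequence[dict] gives mypy one element type for the
    # whole list, regardless of the differing dict value types below.
    VALIDATION_SPEC: Sequence[dict] = [
        dict(name="name", allow_empty=False),
        dict(name="settings", allow_empty=True, fields=["tier"]),
    ]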
62 changes: 38 additions & 24 deletions airflow/providers/microsoft/azure/hooks/cosmos.py
@@ -23,6 +23,7 @@
login (=Endpoint uri), password (=secret key) and extra fields database_name and collection_name to specify
the default database and collection to use (see connection `azure_cosmos_default` for an example).
"""
+import json
import uuid
from typing import Any, Dict, Optional

@@ -140,14 +141,22 @@ def does_collection_exist(self, collection_name: str, database_name: str) -> boo
existing_container = list(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
.query_containers(
"SELECT * FROM r WHERE r.id=@id",
parameters=[json.dumps({"name": "@id", "value": collection_name})],
)
)
if len(existing_container) == 0:
return False

return True

-    def create_collection(self, collection_name: str, database_name: Optional[str] = None) -> None:
+    def create_collection(
+        self,
+        collection_name: str,
+        database_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
+    ) -> None:
"""Creates a new collection in the CosmosDB database."""
if collection_name is None:
raise AirflowBadRequest("Collection name cannot be None.")
@@ -157,13 +166,16 @@ def create_collection(self, collection_name: str, database_name: Optional[str] =
existing_container = list(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.query_containers("SELECT * FROM r WHERE r.id=@id", [{"name": "@id", "value": collection_name}])
.query_containers(
"SELECT * FROM r WHERE r.id=@id",
parameters=[json.dumps({"name": "@id", "value": collection_name})],
)
)

# Only create if we did not find it already existing
if len(existing_container) == 0:
self.get_conn().get_database_client(self.__get_database_name(database_name)).create_container(
-                collection_name
+                collection_name, partition_key=partition_key
)

def does_database_exist(self, database_name: str) -> bool:
@@ -173,10 +185,8 @@ def does_database_exist(self, database_name: str) -> bool:

existing_database = list(
self.get_conn().query_databases(
-                {
-                    "query": "SELECT * FROM r WHERE r.id=@id",
-                    "parameters": [{"name": "@id", "value": database_name}],
-                }
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": database_name})],
)
)
if len(existing_database) == 0:
@@ -193,10 +203,8 @@ def create_database(self, database_name: str) -> None:
# to create it twice
existing_database = list(
self.get_conn().query_databases(
-                {
-                    "query": "SELECT * FROM r WHERE r.id=@id",
-                    "parameters": [{"name": "@id", "value": database_name}],
-                }
+                "SELECT * FROM r WHERE r.id=@id",
+                parameters=[json.dumps({"name": "@id", "value": database_name})],
)
)

@@ -267,18 +275,28 @@ def insert_documents(
return created_documents

def delete_document(
-        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+        self,
+        document_id: str,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
) -> None:
"""Delete an existing document out of a collection in the CosmosDB database."""
if document_id is None:
raise AirflowBadRequest("Cannot delete a document without an id")

-        self.get_conn().get_database_client(self.__get_database_name(database_name)).get_container_client(
-            self.__get_collection_name(collection_name)
-        ).delete_item(document_id)
+        (
+            self.get_conn()
+            .get_database_client(self.__get_database_name(database_name))
+            .get_container_client(self.__get_collection_name(collection_name))
+            .delete_item(document_id, partition_key=partition_key)
+        )

def get_document(
-        self, document_id: str, database_name: Optional[str] = None, collection_name: Optional[str] = None
+        self,
+        document_id: str,
+        database_name: Optional[str] = None,
+        collection_name: Optional[str] = None,
+        partition_key: Optional[str] = None,
):
"""Get a document from an existing collection in the CosmosDB database."""
if document_id is None:
@@ -289,7 +307,7 @@ def get_document(
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
-                .read_item(document_id)
+                .read_item(document_id, partition_key=partition_key)
)
except CosmosHttpResponseError:
return None
@@ -305,17 +323,13 @@ def get_documents(
if sql_string is None:
raise AirflowBadRequest("SQL query string cannot be None")

-        # Query them in SQL
-        query = {'query': sql_string}
-
try:
result_iterable = (
self.get_conn()
.get_database_client(self.__get_database_name(database_name))
.get_container_client(self.__get_collection_name(collection_name))
-                .query_items(query, partition_key)
+                .query_items(sql_string, partition_key=partition_key)
)

return list(result_iterable)
except CosmosHttpResponseError:
return None
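Note: the hook methods now accept an optional partition_key and forward it to the azure-cosmos SDK. A hedged usage sketch; the connection id, database/collection names, and key values are illustrative:

    from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook

    hook = AzureCosmosDBHook(azure_cosmos_conn_id="azure_cosmos_default")
    # For point reads and deletes, the partition key *value* of the target
    # document is forwarded to the SDK's read_item()/delete_item().
    doc = hook.get_document("doc-1", "my_database", "my_collection", partition_key="doc-1")
    hook.delete_document("doc-1", "my_database", "my_collection", partition_key="doc-1")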
6 changes: 3 additions & 3 deletions airflow/utils/context.py
@@ -23,12 +23,12 @@
import functools
import warnings
from typing import (
-    AbstractSet,
Any,
Container,
Dict,
ItemsView,
Iterator,
+    KeysView,
List,
Mapping,
MutableMapping,
@@ -175,7 +175,7 @@ class Context(MutableMapping[str, Any]):
}

def __init__(self, context: Optional[MutableMapping[str, Any]] = None, **kwargs: Any) -> None:
-        self._context = context or {}
+        self._context: MutableMapping[str, Any] = context or {}
if kwargs:
self._context.update(kwargs)
self._deprecation_replacements = self._DEPRECATION_REPLACEMENTS.copy()
@@ -231,7 +231,7 @@ def __ne__(self, other: Any) -> bool:
return NotImplemented
return self._context != other._context

-    def keys(self) -> AbstractSet[str]:
+    def keys(self) -> KeysView[str]:
return self._context.keys()

def items(self):
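Note: dict.keys() returns a KeysView, so the delegate's view is now typed as exactly that (a KeysView is still an AbstractSet, so callers keep set semantics). In isolation:

    from typing import Any, KeysView, MutableMapping

    class Wrapper:
        def __init__(self) -> None:
            self._data: MutableMapping[str, Any] = {}

        def keys(self) -> KeysView[str]:
            # Matches what the underlying mapping actually returns.
            return self._data.keys()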
8 changes: 4 additions & 4 deletions dev/breeze/src/airflow_breeze/commands/testing_commands.py
@@ -197,9 +197,9 @@ def run_with_progress(
) -> RunCommandResult:
title = f"Running tests: {test_type}, Python: {python}, Backend: {backend}:{version}"
try:
-        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as f:
+        with tempfile.NamedTemporaryFile(mode='w+t', delete=False) as tf:
get_console().print(f"[info]Starting test = {title}[/]")
-            thread = MonitoringThread(title=title, file_name=f.name)
+            thread = MonitoringThread(title=title, file_name=tf.name)
thread.start()
try:
result = run_command(
@@ -208,14 +208,14 @@
dry_run=dry_run,
env=env_variables,
check=False,
-                    stdout=f,
+                    stdout=tf,
stderr=subprocess.STDOUT,
)
finally:
thread.stop()
thread.join()
with ci_group(f"Result of {title}", message_type=message_type_from_return_code(result.returncode)):
-            with open(f.name) as f:
+            with open(tf.name) as f:
shutil.copyfileobj(f, sys.stdout)
finally:
os.unlink(f.name)
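Note: f was previously bound to two unrelated file types in one scope, which the newer mypy rejects; renaming the temporary-file handle to tf gives each type its own name. The pitfall in miniature:

    import os
    import shutil
    import sys
    import tempfile

    with tempfile.NamedTemporaryFile(mode="w+t", delete=False) as tf:
        tf.write("captured output\n")

    # Reusing the name `tf` for the plain-text handle below would make mypy
    # reconcile two unrelated file types; a distinct name per type keeps
    # inference clean.
    with open(tf.name) as f:
        shutil.copyfileobj(f, sys.stdout)
    os.unlink(tf.name)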
1 change: 1 addition & 0 deletions scripts/in_container/run_migration_reference.py
@@ -102,6 +102,7 @@ def revision_suffix(rev: "Script"):

def ensure_airflow_version(revisions: Iterable["Script"]):
for rev in revisions:
+        assert rev.module.__file__ is not None  # For Mypy.
file = Path(rev.module.__file__)
content = file.read_text()
if not has_version(content):
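Note: ModuleType.__file__ is Optional[str] in typeshed and Path() rejects None, so the assert narrows the type at the point of use. In isolation:

    import types
    from pathlib import Path

    def module_path(mod: types.ModuleType) -> Path:
        assert mod.__file__ is not None  # narrow Optional[str] to str for mypy
        return Path(mod.__file__)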
2 changes: 1 addition & 1 deletion setup.cfg
@@ -147,7 +147,7 @@ install_requires =
tabulate>=0.7.5
tenacity>=6.2.0
termcolor>=1.1.0
-    typing-extensions>=3.7.4
+    typing-extensions>=4.0.0
unicodecsv>=0.14.1
werkzeug>=2.0

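Note: typing-extensions>=4.0.0 guarantees ParamSpec is importable on the Pythons Airflow supports; on 3.10+ it also lives in the standard library. A common guarded import (a sketch, not Airflow's code):

    import sys

    if sys.version_info >= (3, 10):
        from typing import ParamSpec
    else:
        from typing_extensions import ParamSpec

    P = ParamSpec("P")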
2 changes: 1 addition & 1 deletion setup.py
@@ -325,7 +325,7 @@ def write_version(filename: str = str(AIRFLOW_SOURCES_ROOT / "airflow" / "git_ve
# mypyd which does not support installing the types dynamically with --install-types
mypy_dependencies = [
# TODO: upgrade to newer versions of MyPy continuously as they are released
-    'mypy==0.910',
+    'mypy==0.950',
'types-boto',
'types-certifi',
'types-croniter',
8 changes: 6 additions & 2 deletions tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
@@ -91,7 +91,9 @@ def test_create_container(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name, self.test_database_name)
expected_calls = [
-            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+            mock.call()
+            .get_database_client('test_database_name')
+            .create_container('test_collection_name', partition_key=None)
]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
@@ -101,7 +103,9 @@ def test_create_container_default(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name)
expected_calls = [
-            mock.call().get_database_client('test_database_name').create_container('test_collection_name')
+            mock.call()
+            .get_database_client('test_database_name')
+            .create_container('test_collection_name', partition_key=None)
]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
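Note: the expectations now assert the whole chained call, including the new partition_key=None default. How mock records such chains, in isolation:

    from unittest import mock

    client_factory = mock.MagicMock()
    client_factory().get_database_client("db").create_container("col", partition_key=None)

    # mock_calls records every step of the chain, so the full chained call
    # (with its keyword argument) can be asserted in one expression.
    client_factory.assert_has_calls(
        [
            mock.call()
            .get_database_client("db")
            .create_container("col", partition_key=None)
        ]
    )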
