More strict rules in mypy (#9705) (#9906)
Signed-off-by: Raymond Etornam <[email protected]>
retornam committed Jul 22, 2020
1 parent 24a951e commit c2db0df
Showing 28 changed files with 156 additions and 158 deletions.
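
The bulk of the diff removes `# type: ignore` comments that no longer suppress a real error, which the newly enabled `warn_unused_ignores` flag would otherwise report as errors in their own right. As a hedged illustration (a standalone sketch, not code from this commit):

```python
def double(x: int) -> int:
    return x * 2


# This call type-checks on its own, so the suppression below is dead weight.
# With warn_unused_ignores = True, mypy flags it as an unused "type: ignore"
# comment instead of silently accepting it.
result = double(21)  # type: ignore
```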
2 changes: 1 addition & 1 deletion airflow/cli/cli_parser.py
@@ -51,7 +51,7 @@ def command(*args, **kwargs):
         func = import_string(import_path)
         return func(*args, **kwargs)
 
-    command.__name__ = name  # type: ignore
+    command.__name__ = name
 
     return command
12 changes: 6 additions & 6 deletions airflow/configuration.py
@@ -467,7 +467,7 @@ def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float,
         if (section not in self._sections and section not in self.airflow_defaults._sections):  # type: ignore
             return None
 
-        _section = copy.deepcopy(self.airflow_defaults._sections[section])  # type: ignore
+        _section = copy.deepcopy(self.airflow_defaults._sections[section])
 
         if section in self._sections:  # type: ignore
             _section.update(copy.deepcopy(self._sections[section]))  # type: ignore
@@ -481,7 +481,7 @@ def getsection(self, section: str) -> Optional[Dict[str, Union[str, int, float,
                 key = key.lower()
                 _section[key] = self._get_env_var_option(section, key)
 
-        for key, val in _section.items():  # type: ignore
+        for key, val in _section.items():
             try:
                 val = int(val)
             except ValueError:
@@ -499,13 +499,13 @@ def write(self, fp, space_around_delimiters=True):
         # This is based on the configparser.RawConfigParser.write method code to add support for
         # reading options from environment variables.
         if space_around_delimiters:
-            d = " {} ".format(self._delimiters[0])  # type: ignore
+            d = " {} ".format(self._delimiters[0])
         else:
-            d = self._delimiters[0]  # type: ignore
+            d = self._delimiters[0]
         if self._defaults:
-            self._write_section(fp, self.default_section, self._defaults.items(), d)  # type: ignore
+            self._write_section(fp, self.default_section, self._defaults.items(), d)
         for section in self._sections:
-            self._write_section(fp, section, self.getsection(section).items(), d)  # type: ignore
+            self._write_section(fp, section, self.getsection(section).items(), d)
 
     def as_dict(
         self, display_source=False, display_sensitive=False, raw=False,
10 changes: 4 additions & 6 deletions airflow/models/baseoperator.py
@@ -26,9 +26,7 @@
 import warnings
 from abc import ABCMeta, abstractmethod
 from datetime import datetime, timedelta
-from typing import (
-    Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union, cast,
-)
+from typing import Any, Callable, ClassVar, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple, Type, Union
 
 import attr
 import jinja2
@@ -1168,7 +1166,7 @@ def _set_relatives(self,
             task_list = [task_or_task_list]  # type: ignore
 
         task_list = [
-            t.operator if isinstance(t, XComArg) else t  # type: ignore
+            t.operator if isinstance(t, XComArg) else t
             for t in task_list
         ]
 
@@ -1381,8 +1379,8 @@ def chain(*tasks: Union[BaseOperator, List[BaseOperator]]):
             raise TypeError(
                 'Chain not supported between instances of {up_type} and {down_type}'.format(
                     up_type=type(up_task), down_type=type(down_task)))
-        up_task_list = cast(List[BaseOperator], up_task)
-        down_task_list = cast(List[BaseOperator], down_task)
+        up_task_list = up_task
+        down_task_list = down_task
         if len(up_task_list) != len(down_task_list):
             raise AirflowException(
                 f'Chain not supported different length Iterable '
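
The `cast` removals above pair with the new `warn_redundant_casts` flag: once a value already has the target type, casting it again carries no information, and mypy now reports that rather than ignoring it. A minimal sketch of the pattern (hypothetical names, not Airflow code):

```python
from typing import List, cast


def first_label(labels: List[str]) -> str:
    # labels is already List[str]; with warn_redundant_casts = True,
    # mypy reports this cast as redundant instead of accepting it.
    typed = cast(List[str], labels)
    return typed[0]
```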
4 changes: 1 addition & 3 deletions airflow/models/dagrun.py
@@ -16,7 +16,7 @@
 # specific language governing permissions and limitations
 # under the License.
 from datetime import datetime
-from typing import Any, List, Optional, Tuple, Union, cast
+from typing import Any, List, Optional, Tuple, Union
 
 from sqlalchemy import (
     Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_,
@@ -266,8 +266,6 @@ def get_dag(self):
     def get_previous_dagrun(self, state: Optional[str] = None, session: Session = None) -> Optional['DagRun']:
         """The previous DagRun, if there is one"""
 
-        session = cast(Session, session)  # mypy
-
         filters = [
             DagRun.dag_id == self.dag_id,
             DagRun.execution_date < self.execution_date,
6 changes: 3 additions & 3 deletions airflow/models/taskinstance.py
@@ -1771,9 +1771,9 @@ def filter_for_tis(
                                     for tik in tis])
             return or_(*filter_for_tis)
         if all(isinstance(t, TaskInstance) for t in tis):
-            filter_for_tis = ([and_(TI.dag_id == ti.dag_id,  # type: ignore
-                                    TI.task_id == ti.task_id,  # type: ignore
-                                    TI.execution_date == ti.execution_date)  # type: ignore
+            filter_for_tis = ([and_(TI.dag_id == ti.dag_id,
+                                    TI.task_id == ti.task_id,
+                                    TI.execution_date == ti.execution_date)
                               for ti in tis])
             return or_(*filter_for_tis)
 
4 changes: 2 additions & 2 deletions airflow/plugins_manager.py
@@ -29,8 +29,8 @@
 
 import pkg_resources
 
-from airflow import settings  # type: ignore
-from airflow.utils.file import find_path_from_directory  # type: ignore
+from airflow import settings
+from airflow.utils.file import find_path_from_directory
 
 log = logging.getLogger(__name__)
 
2 changes: 1 addition & 1 deletion airflow/providers/apache/hive/transfers/mssql_to_hive.py
@@ -119,7 +119,7 @@ def execute(self, context: Dict[str, str]):
         cursor.execute(self.sql)
         with NamedTemporaryFile("w") as tmp_file:
             csv_writer = csv.writer(tmp_file, delimiter=self.delimiter, encoding='utf-8')
-            field_dict = OrderedDict()  # type:ignore
+            field_dict = OrderedDict()
             col_count = 0
             for field in cursor.description:
                 col_count += 1
2 changes: 1 addition & 1 deletion airflow/providers/apache/hive/transfers/mysql_to_hive.py
@@ -150,7 +150,7 @@ def execute(self, context: Dict[str, str]):
                                    quotechar=self.quotechar,
                                    escapechar=self.escapechar,
                                    encoding="utf-8")
-            field_dict = OrderedDict()  # type:ignore
+            field_dict = OrderedDict()
             for field in cursor.description:
                 field_dict[field[0]] = self.type_map(field[1])
             csv_writer.writerows(cursor)
4 changes: 2 additions & 2 deletions airflow/providers/cncf/kubernetes/operators/kubernetes_pod.py
@@ -370,9 +370,9 @@ def create_new_pod_for_operator(self, labels, launcher) -> Tuple[State, k8s.V1Po
         # noinspection PyTypeChecker
         pod = append_to_pod(
             pod,
-            self.pod_runtime_info_envs +  # type: ignore
+            self.pod_runtime_info_envs +
             self.ports +  # type: ignore
-            self.resources +  # type: ignore
+            self.resources +
             self.secrets +  # type: ignore
             self.volumes +  # type: ignore
             self.volume_mounts  # type: ignore
2 changes: 1 addition & 1 deletion airflow/providers/docker/operators/docker.py
@@ -327,7 +327,7 @@ def __get_tls_config(self):
                 ca_cert=self.tls_ca_cert,
                 client_cert=(self.tls_client_cert, self.tls_client_key),
                 verify=True,
-                ssl_version=self.tls_ssl_version,  # type: ignore
+                ssl_version=self.tls_ssl_version,
                 assert_hostname=self.tls_hostname
             )
             self.docker_url = self.docker_url.replace('tcp://', 'https://')
18 changes: 9 additions & 9 deletions airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -142,7 +142,7 @@ def create_instance(self, body: Dict, project_id: str) -> None:
             body=body
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -170,7 +170,7 @@ def patch_instance(self, body: Dict, instance: str, project_id: str) -> None:
             body=body
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -191,7 +191,7 @@ def delete_instance(self, instance: str, project_id: str) -> None:
             instance=instance,
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -238,7 +238,7 @@ def create_database(self, instance: str, body: Dict, project_id: str) -> None:
             body=body
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -275,7 +275,7 @@ def patch_database(
             body=body
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -299,7 +299,7 @@ def delete_database(self, instance: str, database: str, project_id: str) -> None
             database=database
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -326,7 +326,7 @@ def export_instance(self, instance: str, body: Dict, project_id: str) -> None:
             body=body
         ).execute(num_retries=self.num_retries)
         operation_name = response["name"]
-        self._wait_for_operation_to_complete(project_id=project_id,  # type:ignore
+        self._wait_for_operation_to_complete(project_id=project_id,
                                              operation_name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
@@ -353,7 +353,7 @@ def import_instance(self, instance: str, body: Dict, project_id: str) -> None:
                 body=body
             ).execute(num_retries=self.num_retries)
             operation_name = response["name"]
-            self._wait_for_operation_to_complete(project_id=project_id,  # type: ignore
+            self._wait_for_operation_to_complete(project_id=project_id,
                                                  operation_name=operation_name)
         except HttpError as ex:
             raise AirflowException(
@@ -984,7 +984,7 @@ def cleanup_database_hook(self) -> None:
             raise ValueError("The db_hook should be set")
         if not isinstance(self.db_hook, PostgresHook):
             raise ValueError(f"The db_hook should be PostrgresHook and is {type(self.db_hook)}")
-        conn = getattr(self.db_hook, 'conn')  # type: ignore
+        conn = getattr(self.db_hook, 'conn')
         if conn and conn.notices:
             for output in self.db_hook.conn.notices:
                 self.log.info(output)
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/translate.py
@@ -107,7 +107,7 @@ def translate(
         """
         client = self.get_conn()
 
-        return client.translate(  # type: ignore
+        return client.translate(
             values=values,
             target_language=target_language,
             format_=format_,
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/operators/bigquery.py
@@ -153,7 +153,7 @@ def __init__(
         super().__init__(sql=sql, *args, **kwargs)
         if bigquery_conn_id:
             warnings.warn(_DEPRECATION_MSG, DeprecationWarning, stacklevel=3)
-            gcp_conn_id = bigquery_conn_id  # type: ignore
+            gcp_conn_id = bigquery_conn_id
 
         self.gcp_conn_id = gcp_conn_id
         self.sql = sql
@@ -314,7 +314,7 @@ def execute(self, context: Dict):
         self.log.info("Inserting Campaign Manager report.")
         response = hook.insert_report(
             profile_id=self.profile_id, report=self.report
-        )  # type: ignore
+        )
         report_id = response.get("id")
         self.xcom_push(context, key="report_id", value=report_id)
         self.log.info("Report successfully inserted. Report id: %s", report_id)
4 changes: 2 additions & 2 deletions airflow/utils/file.py
@@ -22,7 +22,7 @@
 import zipfile
 from typing import Dict, Generator, List, Optional, Pattern
 
-from airflow.configuration import conf  # type: ignore
+from airflow.configuration import conf
 
 log = logging.getLogger(__name__)
 
@@ -159,7 +159,7 @@ def list_py_file_paths(directory: str,
     elif os.path.isdir(directory):
         find_dag_file_paths(directory, file_paths, safe_mode)
     if include_examples:
-        from airflow import example_dags  # type: ignore
+        from airflow import example_dags
         example_dag_folder = example_dags.__path__[0]  # type: ignore
         file_paths.extend(list_py_file_paths(example_dag_folder, safe_mode, False))
     return file_paths
8 changes: 4 additions & 4 deletions docs/exts/exampleinclude.py
@@ -140,7 +140,7 @@ def register_source(app, env, modname):
     :param modname: name of the module to load
     :return: True if the code is registered successfully, False otherwise
     """
-    entry = env._viewcode_modules.get(modname, None)  # type: ignore
+    entry = env._viewcode_modules.get(modname, None)
    if entry is False:
        print("[%s] Entry is false for " % modname)
        return False
@@ -153,7 +153,7 @@ def register_source(app, env, modname):
     except Exception as ex:  # pylint: disable=broad-except
         logger.info("Module \"%s\" could not be loaded. Full source will not be available. \"%s\"",
                     modname, ex)
-        env._viewcode_modules[modname] = False  # type: ignore
+        env._viewcode_modules[modname] = False
         return False
 
     if not isinstance(analyzer.code, str):
@@ -169,7 +169,7 @@ def register_source(app, env, modname):
 
     if entry is None or entry[0] != code:
         entry = code, tags, {}, ""
-        env._viewcode_modules[modname] = entry  # type: ignore
+        env._viewcode_modules[modname] = entry
 
     return True
 # pylint: enable=protected-access
@@ -222,7 +222,7 @@ def doctree_read(app, doctree):
     """
     env = app.builder.env
     if not hasattr(env, "_viewcode_modules"):
-        env._viewcode_modules = {}  # type: ignore
+        env._viewcode_modules = {}
 
     if app.builder.name == "singlehtml":
         return
2 changes: 1 addition & 1 deletion scripts/perf/sql_queries.py
@@ -200,7 +200,7 @@ def main() -> None:
         times.append(exec_time)
         for qry in queries:
             info = qry.to_dict()
-            info["test_no"] = i  # type: ignore
+            info["test_no"] = i
             rows.append(info)
 
     rows_to_csv(rows, name="/files/sql_after_remote.csv")
2 changes: 2 additions & 0 deletions setup.cfg
@@ -58,6 +58,8 @@ packages = airflow
 [mypy]
 ignore_missing_imports = True
 no_implicit_optional = True
+warn_redundant_casts = True
+warn_unused_ignores = True
 plugins =
     airflow.mypy.plugin.decorators
 pretty = True
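
With the two flags above enabled, mypy turns stale suppressions and no-op casts into hard errors instead of passing over them. A hedged sketch of what a failing file would look like (illustrative file name and diagnostics; exact message wording varies across mypy versions):

```python
# stale_suppressions.py -- checked under the [mypy] section above
from typing import cast


def half(x: float) -> float:
    return x / 2.0


y = half(4.0)  # type: ignore
# mypy (warn_unused_ignores) reports something like:
#   error: unused 'type: ignore' comment

z = cast(float, half(4.0))
# half() already returns float, so mypy (warn_redundant_casts) reports:
#   error: Redundant cast to "float"
```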
4 changes: 2 additions & 2 deletions tests/api_connexion/endpoints/test_dag_endpoint.py
@@ -127,9 +127,9 @@ def test_should_response_200(self):
 
     def test_should_response_200_serialized(self):
         # Create empty app with empty dagbag to check if DAG is read from db
-        app_serialized = app.create_app(testing=True)  # type:ignore
+        app_serialized = app.create_app(testing=True)
         dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True)
-        app_serialized.dag_bag = dag_bag  # type:ignore
+        app_serialized.dag_bag = dag_bag
         client = app_serialized.test_client()
 
         SerializedDagModel.write_dag(self.dag)
4 changes: 2 additions & 2 deletions tests/api_connexion/endpoints/test_task_endpoint.py
@@ -94,9 +94,9 @@ def test_should_response_200(self):
     @conf_vars({("core", "store_serialized_dags"): "True"})
     def test_should_response_200_serialized(self):
         # Create empty app with empty dagbag to check if DAG is read from db
-        app_serialized = app.create_app(testing=True)  # type:ignore
+        app_serialized = app.create_app(testing=True)
         dag_bag = DagBag(os.devnull, include_examples=False, read_dags_from_db=True)
-        app_serialized.dag_bag = dag_bag  # type:ignore
+        app_serialized.dag_bag = dag_bag
         client = app_serialized.test_client()
 
         SerializedDagModel.write_dag(self.dag)
2 changes: 1 addition & 1 deletion tests/cli/commands/test_task_command.py
@@ -163,7 +163,7 @@ def test_cli_run_mutually_exclusive(self):
             AirflowException,
             "Option --raw and --local are mutually exclusive."
         ):
-            task_command.task_run(self.parser.parse_args([  # type: ignore
+            task_command.task_run(self.parser.parse_args([
                 'tasks', 'run', 'example_bash_operator', 'runme_0', DEFAULT_DATE.isoformat(), '--raw',
                 '--local'
             ]))
4 changes: 2 additions & 2 deletions tests/plugins/test_plugin_ignore.py
@@ -23,8 +23,8 @@
 import unittest
 from unittest.mock import patch
 
-from airflow import settings  # type: ignore
-from airflow.utils.file import find_path_from_directory  # type: ignore
+from airflow import settings
+from airflow.utils.file import find_path_from_directory
 
 
 class TestIgnorePluginFile(unittest.TestCase):
