Skip to content

Commit

Permalink
Enable Black - Python Auto Formmatter (#9550)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaxil committed Nov 3, 2020
1 parent 1dc7099 commit 4e8f9cc
Show file tree
Hide file tree
Showing 1,070 changed files with 15,413 additions and 15,140 deletions.
5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,7 @@ repos:
rev: 20.8b1
hooks:
- id: black
files: api_connexion/.*\.py|.*providers.*\.py|^chart/tests/.*\.py
exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$
exclude: .*kubernetes_pod\.py|.*google/common/hooks/base_google\.py$|^airflow/configuration.py$
args: [--config=./pyproject.toml]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.3.0
Expand Down Expand Up @@ -203,7 +202,7 @@ repos:
name: Run isort to sort imports
types: [python]
# To keep consistent with the global isort skip config defined in setup.cfg
exclude: ^build/.*$|^.tox/.*$|^venv/.*$|.*api_connexion/.*\.py|.*providers.*\.py
exclude: ^build/.*$|^.tox/.*$|^venv/.*$
- repo: https://github.com/pycqa/pydocstyle
rev: 5.1.1
hooks:
Expand Down
3 changes: 3 additions & 0 deletions airflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,18 @@ def __getattr__(name):
# PEP-562: Lazy loaded attributes on python modules
if name == "DAG":
from airflow.models.dag import DAG # pylint: disable=redefined-outer-name

return DAG
if name == "AirflowException":
from airflow.exceptions import AirflowException # pylint: disable=redefined-outer-name

return AirflowException
raise AttributeError(f"module {__name__} has no attribute {name}")


if not settings.LAZY_LOAD_PLUGINS:
from airflow import plugins_manager

plugins_manager.ensure_plugins_loaded()


Expand Down
5 changes: 1 addition & 4 deletions airflow/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,5 @@ def load_auth():
log.info("Loaded API auth backend: %s", auth_backend)
return auth_backend
except ImportError as err:
log.critical(
"Cannot import %s for API authentication due to: %s",
auth_backend, err
)
log.critical("Cannot import %s for API authentication due to: %s", auth_backend, err)
raise AirflowException(err)
5 changes: 2 additions & 3 deletions airflow/api/auth/backend/basic_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,12 @@ def auth_current_user() -> Optional[User]:

def requires_authentication(function: T):
"""Decorator for functions that require authentication"""

@wraps(function)
def decorated(*args, **kwargs):
if auth_current_user() is not None:
return function(*args, **kwargs)
else:
return Response(
"Unauthorized", 401, {"WWW-Authenticate": "Basic"}
)
return Response("Unauthorized", 401, {"WWW-Authenticate": "Basic"})

return cast(T, decorated)
1 change: 1 addition & 0 deletions airflow/api/auth/backend/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def init_app(_):

def requires_authentication(function: T):
"""Decorator for functions that require authentication"""

@wraps(function)
def decorated(*args, **kwargs):
return function(*args, **kwargs)
Expand Down
5 changes: 3 additions & 2 deletions airflow/api/auth/backend/kerberos_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def _gssapi_authenticate(token):

def requires_authentication(function: T):
"""Decorator for functions that require authentication with Kerberos"""

@wraps(function)
def decorated(*args, **kwargs):
header = request.headers.get("Authorization")
Expand All @@ -144,11 +145,11 @@ def decorated(*args, **kwargs):
response = function(*args, **kwargs)
response = make_response(response)
if ctx.kerberos_token is not None:
response.headers['WWW-Authenticate'] = ' '.join(['negotiate',
ctx.kerberos_token])
response.headers['WWW-Authenticate'] = ' '.join(['negotiate', ctx.kerberos_token])

return response
if return_code != kerberos.AUTH_GSS_CONTINUE:
return _forbidden()
return _unauthorized()

return cast(T, decorated)
2 changes: 1 addition & 1 deletion airflow/api/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ def get_current_api_client() -> Client:
api_client = api_module.Client(
api_base_url=conf.get('cli', 'endpoint_url'),
auth=getattr(auth_backend, 'CLIENT_AUTH', None),
session=session
session=session,
)
return api_client
30 changes: 18 additions & 12 deletions airflow/api/client/json_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,15 @@ def _request(self, url, method='GET', json=None):
def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
endpoint = f'/api/experimental/dags/{dag_id}/dag_runs'
url = urljoin(self._api_base_url, endpoint)
data = self._request(url, method='POST',
json={
"run_id": run_id,
"conf": conf,
"execution_date": execution_date,
})
data = self._request(
url,
method='POST',
json={
"run_id": run_id,
"conf": conf,
"execution_date": execution_date,
},
)
return data['message']

def delete_dag(self, dag_id):
Expand All @@ -74,12 +77,15 @@ def get_pools(self):
def create_pool(self, name, slots, description):
endpoint = '/api/experimental/pools'
url = urljoin(self._api_base_url, endpoint)
pool = self._request(url, method='POST',
json={
'name': name,
'slots': slots,
'description': description,
})
pool = self._request(
url,
method='POST',
json={
'name': name,
'slots': slots,
'description': description,
},
)
return pool['pool'], pool['slots'], pool['description']

def delete_pool(self, name):
Expand Down
7 changes: 3 additions & 4 deletions airflow/api/client/local_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ class Client(api_client.Client):
"""Local API client implementation."""

def trigger_dag(self, dag_id, run_id=None, conf=None, execution_date=None):
dag_run = trigger_dag.trigger_dag(dag_id=dag_id,
run_id=run_id,
conf=conf,
execution_date=execution_date)
dag_run = trigger_dag.trigger_dag(
dag_id=dag_id, run_id=run_id, conf=conf, execution_date=execution_date
)
return f"Created {dag_run}"

def delete_dag(self, dag_id):
Expand Down
8 changes: 2 additions & 6 deletions airflow/api/common/experimental/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ def check_and_get_dag(dag_id: str, task_id: Optional[str] = None) -> DagModel:
if dag_model is None:
raise DagNotFound(f"Dag id {dag_id} not found in DagModel")

dagbag = DagBag(
dag_folder=dag_model.fileloc,
read_dags_from_db=True
)
dagbag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True)
dag = dagbag.get_dag(dag_id)
if not dag:
error_message = f"Dag id {dag_id} not found"
Expand All @@ -47,7 +44,6 @@ def check_and_get_dagrun(dag: DagModel, execution_date: datetime) -> DagRun:
"""Get DagRun object and check that it exists"""
dagrun = dag.get_dagrun(execution_date=execution_date)
if not dagrun:
error_message = ('Dag Run for date {} not found in dag {}'
.format(execution_date, dag.dag_id))
error_message = f'Dag Run for date {execution_date} not found in dag {dag.dag_id}'
raise DagRunNotFound(error_message)
return dagrun
11 changes: 6 additions & 5 deletions airflow/api/common/experimental/delete_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,14 @@ def delete_dag(dag_id: str, keep_records_in_log: bool = True, session=None) -> i
if dag.is_subdag:
parent_dag_id, task_id = dag_id.rsplit(".", 1)
for model in TaskFail, models.TaskInstance:
count += session.query(model).filter(model.dag_id == parent_dag_id,
model.task_id == task_id).delete()
count += (
session.query(model).filter(model.dag_id == parent_dag_id, model.task_id == task_id).delete()
)

# Delete entries in Import Errors table for a deleted DAG
# This handles the case when the dag_id is changed in the file
session.query(models.ImportError).filter(
models.ImportError.filename == dag.fileloc
).delete(synchronize_session='fetch')
session.query(models.ImportError).filter(models.ImportError.filename == dag.fileloc).delete(
synchronize_session='fetch'
)

return count
22 changes: 11 additions & 11 deletions airflow/api/common/experimental/get_dag_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@ def get_dag_runs(dag_id: str, state: Optional[str] = None) -> List[Dict[str, Any
dag_runs = []
state = state.lower() if state else None
for run in DagRun.find(dag_id=dag_id, state=state):
dag_runs.append({
'id': run.id,
'run_id': run.run_id,
'state': run.state,
'dag_id': run.dag_id,
'execution_date': run.execution_date.isoformat(),
'start_date': ((run.start_date or '') and
run.start_date.isoformat()),
'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id,
execution_date=run.execution_date)
})
dag_runs.append(
{
'id': run.id,
'run_id': run.run_id,
'state': run.state,
'dag_id': run.dag_id,
'execution_date': run.execution_date.isoformat(),
'start_date': ((run.start_date or '') and run.start_date.isoformat()),
'dag_run_url': url_for('Airflow.graph', dag_id=run.dag_id, execution_date=run.execution_date),
}
)

return dag_runs
10 changes: 6 additions & 4 deletions airflow/api/common/experimental/get_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ def get_lineage(dag_id: str, execution_date: datetime.datetime, session=None) ->
dag = check_and_get_dag(dag_id)
check_and_get_dagrun(dag, execution_date)

inlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date,
key=PIPELINE_INLETS, session=session).all()
outlets: List[XCom] = XCom.get_many(dag_ids=dag_id, execution_date=execution_date,
key=PIPELINE_OUTLETS, session=session).all()
inlets: List[XCom] = XCom.get_many(
dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_INLETS, session=session
).all()
outlets: List[XCom] = XCom.get_many(
dag_ids=dag_id, execution_date=execution_date, key=PIPELINE_OUTLETS, session=session
).all()

lineage: Dict[str, Dict[str, Any]] = {}
for meta in inlets:
Expand Down
3 changes: 1 addition & 2 deletions airflow/api/common/experimental/get_task_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ def get_task_instance(dag_id: str, task_id: str, execution_date: datetime) -> Ta
# Get task instance object and check that it exists
task_instance = dagrun.get_task_instance(task_id)
if not task_instance:
error_message = ('Task {} instance for date {} not found'
.format(task_id, execution_date))
error_message = f'Task {task_id} instance for date {execution_date} not found'
raise TaskInstanceNotFound(error_message)

return task_instance
73 changes: 36 additions & 37 deletions airflow/api/common/experimental/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def set_state(
past: bool = False,
state: str = State.SUCCESS,
commit: bool = False,
session=None
session=None,
): # pylint: disable=too-many-arguments,too-many-locals
"""
Set the state of a task instance and if needed its relatives. Can set state
Expand Down Expand Up @@ -137,33 +137,24 @@ def set_state(
# Flake and pylint disagree about correct indents here
def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates): # noqa: E123
"""Get *all* tasks of the sub dags"""
qry_sub_dag = session.query(TaskInstance). \
filter(
TaskInstance.dag_id.in_(sub_dag_run_ids),
TaskInstance.execution_date.in_(confirmed_dates)
). \
filter(
or_(
TaskInstance.state.is_(None),
TaskInstance.state != state
)
qry_sub_dag = (
session.query(TaskInstance)
.filter(TaskInstance.dag_id.in_(sub_dag_run_ids), TaskInstance.execution_date.in_(confirmed_dates))
.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
) # noqa: E123
return qry_sub_dag


def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
"""Get all tasks of the main dag that will be affected by a state change"""
qry_dag = session.query(TaskInstance). \
filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date.in_(confirmed_dates),
TaskInstance.task_id.in_(task_ids) # noqa: E123
). \
filter(
or_(
TaskInstance.state.is_(None),
TaskInstance.state != state
qry_dag = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date.in_(confirmed_dates),
TaskInstance.task_id.in_(task_ids), # noqa: E123
)
.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
)
return qry_dag

Expand All @@ -186,10 +177,12 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates):
# this works as a kind of integrity check
# it creates missing dag runs for subdag operators,
# maybe this should be moved to dagrun.verify_integrity
dag_runs = _create_dagruns(current_task.subdag,
execution_dates=confirmed_dates,
state=State.RUNNING,
run_type=DagRunType.BACKFILL_JOB)
dag_runs = _create_dagruns(
current_task.subdag,
execution_dates=confirmed_dates,
state=State.RUNNING,
run_type=DagRunType.BACKFILL_JOB,
)

verify_dagruns(dag_runs, commit, state, session, current_task)

Expand Down Expand Up @@ -279,10 +272,9 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None):
:param state: target state
:param session: database session
"""
dag_run = session.query(DagRun).filter(
DagRun.dag_id == dag_id,
DagRun.execution_date == execution_date
).one()
dag_run = (
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date).one()
)
dag_run.state = state
if state == State.RUNNING:
dag_run.start_date = timezone.utcnow()
Expand Down Expand Up @@ -316,8 +308,9 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None
# Mark all task instances of the dag run to success.
for task in dag.tasks:
task.dag = dag
return set_state(tasks=dag.tasks, execution_date=execution_date,
state=State.SUCCESS, commit=commit, session=session)
return set_state(
tasks=dag.tasks, execution_date=execution_date, state=State.SUCCESS, commit=commit, session=session
)


@provide_session
Expand All @@ -343,10 +336,15 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)

# Mark only RUNNING task instances.
task_ids = [task.task_id for task in dag.tasks]
tis = session.query(TaskInstance).filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date == execution_date,
TaskInstance.task_id.in_(task_ids)).filter(TaskInstance.state == State.RUNNING)
tis = (
session.query(TaskInstance)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date == execution_date,
TaskInstance.task_id.in_(task_ids),
)
.filter(TaskInstance.state == State.RUNNING)
)
task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]

tasks = []
Expand All @@ -356,8 +354,9 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None)
task.dag = dag
tasks.append(task)

return set_state(tasks=tasks, execution_date=execution_date,
state=State.FAILED, commit=commit, session=session)
return set_state(
tasks=tasks, execution_date=execution_date, state=State.FAILED, commit=commit, session=session
)


@provide_session
Expand Down

0 comments on commit 4e8f9cc

Please sign in to comment.