Refactor dedent nested loops (#34409)
eumiro committed Sep 27, 2023
1 parent b316be1 commit 07fe1d2
Showing 10 changed files with 66 additions and 80 deletions.
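Most of the hunks below follow the same dedenting pattern: two nested for loops over independent iterables become a single loop over itertools.product, which yields the same pairs in the same order, one indentation level shallower. The remaining hunks swap index-based pairing for zip() or a loop-and-break for any(). A minimal sketch of the product equivalence (not part of the commit; the list values are invented):

import itertools

dag_ids = ["dag_a", "dag_b"]
run_ids = ["run_1", "run_2"]

# nested form: iterate run_ids once per dag_id
nested = []
for dag_id in dag_ids:
    for run_id in run_ids:
        nested.append((dag_id, run_id))

# flattened form: same (dag_id, run_id) pairs, same order, one level shallower
flattened = list(itertools.product(dag_ids, run_ids))

assert nested == flattened  # [('dag_a', 'run_1'), ('dag_a', 'run_2'), ...]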
48 changes: 24 additions & 24 deletions airflow/models/taskinstance.py
@@ -20,6 +20,7 @@
import collections.abc
import contextlib
import hashlib
+import itertools
import logging
import math
import operator
@@ -3073,32 +3074,31 @@ def filter_for_tis(tis: Iterable[TaskInstance | TaskInstanceKey]) -> BooleanClau

# this assumes that most dags have dag_id as the largest grouping, followed by run_id. even
# if its not, this is still a significant optimization over querying for every single tuple key
-for cur_dag_id in dag_ids:
-    for cur_run_id in run_ids:
-        # we compare the group size between task_id and map_index and use the smaller group
-        dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
-        dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
-
-        if len(dag_task_id_groups) <= len(dag_map_index_groups):
-            for cur_task_id, cur_map_indices in dag_task_id_groups.items():
-                filter_condition.append(
-                    and_(
-                        TaskInstance.dag_id == cur_dag_id,
-                        TaskInstance.run_id == cur_run_id,
-                        TaskInstance.task_id == cur_task_id,
-                        TaskInstance.map_index.in_(cur_map_indices),
-                    )
-                )
-        else:
-            for cur_map_index, cur_task_ids in dag_map_index_groups.items():
-                filter_condition.append(
-                    and_(
-                        TaskInstance.dag_id == cur_dag_id,
-                        TaskInstance.run_id == cur_run_id,
-                        TaskInstance.task_id.in_(cur_task_ids),
-                        TaskInstance.map_index == cur_map_index,
-                    )
-                )
+for cur_dag_id, cur_run_id in itertools.product(dag_ids, run_ids):
+    # we compare the group size between task_id and map_index and use the smaller group
+    dag_task_id_groups = task_id_groups[(cur_dag_id, cur_run_id)]
+    dag_map_index_groups = map_index_groups[(cur_dag_id, cur_run_id)]
+
+    if len(dag_task_id_groups) <= len(dag_map_index_groups):
+        for cur_task_id, cur_map_indices in dag_task_id_groups.items():
+            filter_condition.append(
+                and_(
+                    TaskInstance.dag_id == cur_dag_id,
+                    TaskInstance.run_id == cur_run_id,
+                    TaskInstance.task_id == cur_task_id,
+                    TaskInstance.map_index.in_(cur_map_indices),
+                )
+            )
+    else:
+        for cur_map_index, cur_task_ids in dag_map_index_groups.items():
+            filter_condition.append(
+                and_(
+                    TaskInstance.dag_id == cur_dag_id,
+                    TaskInstance.run_id == cur_run_id,
+                    TaskInstance.task_id.in_(cur_task_ids),
+                    TaskInstance.map_index == cur_map_index,
+                )
+            )

return or_(*filter_condition)
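For context on the hunk above: the loop builds one and_() clause per group and ORs them all together, preferring whichever grouping (by task_id or by map_index) produces fewer clauses. The following is a standalone approximation, not the committed code: plain sqlalchemy column() expressions stand in for the TaskInstance model and the group dictionaries are hard-coded.

import itertools
from sqlalchemy import and_, column, or_

dag_id_col, run_id_col = column("dag_id"), column("run_id")
task_id_col, map_index_col = column("task_id"), column("map_index")

# hard-coded stand-ins for the real group dictionaries
task_id_groups = {("dag_a", "run_1"): {"extract": [0, 1, 2], "load": [0]}}
map_index_groups = {("dag_a", "run_1"): {0: ["extract", "load"], 1: ["extract"], 2: ["extract"]}}

filter_condition = []
for cur_dag_id, cur_run_id in itertools.product(["dag_a"], ["run_1"]):
    by_task = task_id_groups[(cur_dag_id, cur_run_id)]
    by_map = map_index_groups[(cur_dag_id, cur_run_id)]
    # grouping by task_id needs 2 clauses here; grouping by map_index would need 3
    if len(by_task) <= len(by_map):
        for cur_task_id, map_indices in by_task.items():
            filter_condition.append(
                and_(
                    dag_id_col == cur_dag_id,
                    run_id_col == cur_run_id,
                    task_id_col == cur_task_id,
                    map_index_col.in_(map_indices),
                )
            )
    else:
        for cur_map_index, task_ids in by_map.items():
            filter_condition.append(
                and_(
                    dag_id_col == cur_dag_id,
                    run_id_col == cur_run_id,
                    task_id_col.in_(task_ids),
                    map_index_col == cur_map_index,
                )
            )

print(or_(*filter_condition))  # one OR of per-task AND clauses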

4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/transfers/sql_to_gcs.py
@@ -222,8 +222,8 @@ def convert_types(self, schema, col_type_dict, row) -> list:
def _write_rows_to_parquet(parquet_writer: pq.ParquetWriter, rows):
rows_pydic: dict[str, list[Any]] = {col: [] for col in parquet_writer.schema.names}
for row in rows:
-    for ind, col in enumerate(parquet_writer.schema.names):
-        rows_pydic[col].append(row[ind])
+    for cell, col in zip(row, parquet_writer.schema.names):
+        rows_pydic[col].append(cell)
tbl = pa.Table.from_pydict(rows_pydic, parquet_writer.schema)
parquet_writer.write_table(tbl)
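The zip() rewrite pairs each cell of a row with its column name directly instead of indexing by position. A tiny standalone illustration (column names and rows are invented; the pyarrow writer is omitted):

columns = ["id", "name"]
rows = [(1, "airflow"), (2, "breeze")]

rows_pydic = {col: [] for col in columns}
for row in rows:
    # zip pairs each cell with its column name; no positional indexing needed
    for cell, col in zip(row, columns):
        rows_pydic[col].append(cell)

assert rows_pydic == {"id": [1, 2], "name": ["airflow", "breeze"]}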

6 changes: 3 additions & 3 deletions airflow/www/security_manager.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

+import itertools
import warnings
from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence

@@ -731,9 +732,8 @@ def _revoke_stale_permissions(resource: Resource):
def create_perm_vm_for_all_dag(self) -> None:
"""Create perm-vm if not exist and insert into FAB security model for all-dags."""
# create perm for global logical dag
-for resource_name in self.DAG_RESOURCES:
-    for action_name in self.DAG_ACTIONS:
-        self._merge_perm(action_name, resource_name)
+for resource_name, action_name in itertools.product(self.DAG_RESOURCES, self.DAG_ACTIONS):
+    self._merge_perm(action_name, resource_name)

def check_authorization(
self,
6 changes: 2 additions & 4 deletions airflow/www/views.py
@@ -1223,10 +1223,8 @@ def task_stats(self, session: Session = NEW_SESSION):
)
data = get_task_stats_from_query(qry)
payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list)
-for dag_id in filter_dag_ids:
-    for state in State.task_states:
-        count = data.get(dag_id, {}).get(state, 0)
-        payload[dag_id].append({"state": state, "count": count})
+for dag_id, state in itertools.product(filter_dag_ids, State.task_states):
+    payload[dag_id].append({"state": state, "count": data.get(dag_id, {}).get(state, 0)})
return flask.json.jsonify(payload)

@expose("/last_dagruns", methods=["POST"])
7 changes: 3 additions & 4 deletions dev/breeze/src/airflow_breeze/utils/exclude_from_matrix.py
@@ -16,6 +16,8 @@
# under the License.
from __future__ import annotations

+import itertools


def representative_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str]]:
"""
@@ -40,8 +42,5 @@ def excluded_combos(list_1: list[str], list_2: list[str]) -> list[tuple[str, str
:param list_2: second list
:return: list of exclusions = list 1 x list 2 - representative_combos
"""
-all_combos: list[tuple[str, str]] = []
-for item_1 in list_1:
-    for item_2 in list_2:
-        all_combos.append((item_1, item_2))
+all_combos: list[tuple[str, str]] = list(itertools.product(list_1, list_2))
return [item for item in all_combos if item not in set(representative_combos(list_1, list_2))]
6 changes: 2 additions & 4 deletions dev/breeze/src/airflow_breeze/utils/selective_checks.py
@@ -470,10 +470,8 @@ def kubernetes_combos_list_as_string(self) -> str:

def _match_files_with_regexps(self, matched_files, regexps):
for file in self._files:
-    for regexp in regexps:
-        if re.match(regexp, file):
-            matched_files.append(file)
-            break
+    if any(re.match(regexp, file) for regexp in regexps):
+        matched_files.append(file)

@lru_cache(maxsize=None)
def _matching_files(self, match_group: T, match_dict: dict[T, list[str]]) -> list[str]:
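The any() rewrite keeps the original short-circuit behaviour: the generator expression stops at the first regexp that matches, just as the old break did. A self-contained sketch with invented file names and patterns:

import re

files = ["airflow/models/dag.py", "docs/index.rst"]
regexps = [r"^airflow/.*\.py$", r"^tests/"]

matched_files = []
for file in files:
    # any() short-circuits on the first matching regexp, like the old break
    if any(re.match(regexp, file) for regexp in regexps):
        matched_files.append(file)

assert matched_files == ["airflow/models/dag.py"]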
26 changes: 13 additions & 13 deletions dev/retag_docker_images.py
@@ -27,6 +27,7 @@
# * when starting new release branch (for example `v2-1-test`)
# * when renaming a branch
#
+import itertools
import subprocess

import rich_click as click
@@ -52,19 +53,18 @@ def pull_push_all_images(
target_branch: str,
target_repo: str,
):
-    for python in PYTHON_VERSIONS:
-        for image in images:
-            source_image = image.format(
-                prefix=source_prefix, branch=source_branch, repo=source_repo, python=python
-            )
-            target_image = image.format(
-                prefix=target_prefix, branch=target_branch, repo=target_repo, python=python
-            )
-            print(f"Copying image: {source_image} -> {target_image}")
-            subprocess.run(
-                ["regctl", "image", "copy", "--force-recursive", "--digest-tags", source_image, target_image],
-                check=True,
-            )
+    for python, image in itertools.product(PYTHON_VERSIONS, images):
+        source_image = image.format(
+            prefix=source_prefix, branch=source_branch, repo=source_repo, python=python
+        )
+        target_image = image.format(
+            prefix=target_prefix, branch=target_branch, repo=target_repo, python=python
+        )
+        print(f"Copying image: {source_image} -> {target_image}")
+        subprocess.run(
+            ["regctl", "image", "copy", "--force-recursive", "--digest-tags", source_image, target_image],
+            check=True,
+        )


@click.group(invoke_without_command=True)
7 changes: 3 additions & 4 deletions docs/exts/airflow_intersphinx.py
@@ -151,10 +151,9 @@ def domain_and_object_type_to_role(domain: str, object_type: str) -> str:
def inspect_main(inv_data, name) -> None:
try:
for key in sorted(inv_data or {}):
-    for entry, _ in sorted(inv_data[key].items()):
-        domain, object_type = key.split(":")
-        role_name = domain_and_object_type_to_role(domain, object_type)
-
+    domain, object_type = key.split(":")
+    role_name = domain_and_object_type_to_role(domain, object_type)
+    for entry in sorted(inv_data[key].keys()):
print(f":{role_name}:`{name}:{entry}`")
except ValueError as exc:
print(exc.args[0] % exc.args[1:])
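Besides reordering the loops, this hunk hoists work that depends only on the outer key, namely key.split(":") and the role lookup, out of the inner loop so it runs once per key rather than once per entry. A toy version follows; the inventory data and the role function body are stand-ins, not the real Sphinx inventory:

def domain_and_object_type_to_role(domain: str, object_type: str) -> str:
    # stand-in for the real mapping in airflow_intersphinx.py
    return f"{domain}:{object_type}"

inv_data = {"py:class": {"airflow.DAG": None, "airflow.models.BaseOperator": None}}

for key in sorted(inv_data):
    # depends only on key, so it is computed once per key, not once per entry
    domain, object_type = key.split(":")
    role_name = domain_and_object_type_to_role(domain, object_type)
    for entry in sorted(inv_data[key].keys()):
        print(f":{role_name}:`apache-airflow:{entry}`")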
28 changes: 10 additions & 18 deletions tests/models/test_dag.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import datetime
+import itertools
import logging
import os
import pickle
@@ -342,12 +343,9 @@ def test_dag_task_priority_weight_total(self):
[EmptyOperator(task_id=f"stage{i}.{j}", priority_weight=weight) for j in range(width)]
for i in range(depth)
]
-for i, stage in enumerate(pipeline):
-    if i == 0:
-        continue
-    for current_task in stage:
-        for prev_task in pipeline[i - 1]:
-            current_task.set_upstream(prev_task)
+for upstream, downstream in zip(pipeline, pipeline[1:]):
+    for up_task, down_task in itertools.product(upstream, downstream):
+        down_task.set_upstream(up_task)

for task in dag.task_dict.values():
match = pattern.match(task.task_id)
@@ -376,12 +374,9 @@ def test_dag_task_priority_weight_total_using_upstream(self):
]
for i in range(depth)
]
-for i, stage in enumerate(pipeline):
-    if i == 0:
-        continue
-    for current_task in stage:
-        for prev_task in pipeline[i - 1]:
-            current_task.set_upstream(prev_task)
+for upstream, downstream in zip(pipeline, pipeline[1:]):
+    for up_task, down_task in itertools.product(upstream, downstream):
+        down_task.set_upstream(up_task)

for task in dag.task_dict.values():
match = pattern.match(task.task_id)
@@ -409,12 +404,9 @@ def test_dag_task_priority_weight_total_using_absolute(self):
]
for i in range(depth)
]
-for i, stage in enumerate(pipeline):
-    if i == 0:
-        continue
-    for current_task in stage:
-        for prev_task in pipeline[i - 1]:
-            current_task.set_upstream(prev_task)
+for upstream, downstream in zip(pipeline, pipeline[1:]):
+    for up_task, down_task in itertools.product(upstream, downstream):
+        down_task.set_upstream(up_task)

for task in dag.task_dict.values():
# the sum of each stages after this task + itself
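The test refactor chains consecutive pipeline stages by zipping the list with itself offset by one, then wiring every cross pair with itertools.product, which is equivalent to the old enumerate and index bookkeeping. A standalone sketch with invented task names:

import itertools

pipeline = [["a1", "a2"], ["b1", "b2"], ["c1"]]

edges = []
# zip(pipeline, pipeline[1:]) yields (upstream_stage, downstream_stage) pairs
for upstream, downstream in zip(pipeline, pipeline[1:]):
    for up_task, down_task in itertools.product(upstream, downstream):
        edges.append((up_task, down_task))

assert edges == [
    ("a1", "b1"), ("a1", "b2"), ("a2", "b1"), ("a2", "b2"),
    ("b1", "c1"), ("b2", "c1"),
]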
8 changes: 4 additions & 4 deletions tests/sensors/test_external_task_sensor.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

+import itertools
import logging
import os
import tempfile
@@ -1159,10 +1160,9 @@ def test_external_task_marker_clear_activate(dag_bag_parent_child, session):
run_tasks(dag_bag, execution_date=day_2)

# Assert that dagruns of all the affected dags are set to SUCCESS before tasks are cleared.
-for dag in dag_bag.dags.values():
-    for execution_date in [day_1, day_2]:
-        dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
-        dagrun.set_state(State.SUCCESS)
+for dag, execution_date in itertools.product(dag_bag.dags.values(), [day_1, day_2]):
+    dagrun = dag.get_dagrun(execution_date=execution_date, session=session)
+    dagrun.set_state(State.SUCCESS)
session.flush()

dag_0 = dag_bag.get_dag("parent_dag_0")
