Skip to content

Commit

Permalink
Change DAG.clear to take dag_run_state (#9824)
Browse files Browse the repository at this point in the history
* Change DAG.clear to take dag_run_state

* fix lint

* fix tests

* assign var

* extend original clause
  • Loading branch information
milton0825 committed Jul 15, 2020
1 parent 6d65c15 commit b01d95e
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 41 deletions.
4 changes: 3 additions & 1 deletion airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from airflow.utils.cli import get_dag, get_dag_by_file_location, process_subdir, sigint_handler
from airflow.utils.dot_renderer import render_dag
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State


def _tabulate_dag_runs(dag_runs: List[DagRun], tablefmt: str = "fancy_grid") -> str:
Expand Down Expand Up @@ -123,6 +124,7 @@ def dag_backfill(args, dag=None):
end_date=args.end_date,
confirm_prompt=not args.yes,
include_subdags=True,
dag_run_state=State.NONE,
)

dag.run(
Expand Down Expand Up @@ -381,7 +383,7 @@ def dag_list_dag_runs(args, dag=None):
def dag_test(args, session=None):
"""Execute one single DagRun for a given DAG and execution date, using the DebugExecutor."""
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
dag.clear(start_date=args.execution_date, end_date=args.execution_date, reset_dag_runs=True)
dag.clear(start_date=args.execution_date, end_date=args.execution_date, dag_run_state=State.NONE)
try:
dag.run(executor=DebugExecutor(), start_date=args.execution_date, end_date=args.execution_date)
except BackfillUnfinished as e:
Expand Down
55 changes: 28 additions & 27 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import warnings
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union
from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union, cast

import jinja2
import pendulum
Expand Down Expand Up @@ -297,7 +297,7 @@ def __init__(
template_searchpath = [template_searchpath]
self.template_searchpath = template_searchpath
self.template_undefined = template_undefined
self.parent_dag = None # Gets set when DAGs are loaded
self.parent_dag: Optional[DAG] = None # Gets set when DAGs are loaded
self.last_loaded = timezone.utcnow()
self.safe_dag_id = dag_id.replace('.', '__dot__')
self.max_active_runs = max_active_runs
Expand Down Expand Up @@ -966,7 +966,7 @@ def clear(
confirm_prompt=False,
include_subdags=True,
include_parentdag=True,
reset_dag_runs=True,
dag_run_state: str = State.RUNNING,
dry_run=False,
session=None,
get_tis=False,
Expand All @@ -993,8 +993,7 @@ def clear(
:type include_subdags: bool
:param include_parentdag: Clear tasks in the parent dag of the subdag.
:type include_parentdag: bool
:param reset_dag_runs: Set state of dag to RUNNING
:type reset_dag_runs: bool
:param dag_run_state: state to set DagRun to
:param dry_run: Find the tasks to clear but don't clear them.
:type dry_run: bool
:param session: The sqlalchemy session to use
Expand Down Expand Up @@ -1025,8 +1024,7 @@ def clear(
tis = session.query(TI).filter(TI.dag_id == self.dag_id)
tis = tis.filter(TI.task_id.in_(self.task_ids))

if include_parentdag and self.is_subdag:

if include_parentdag and self.is_subdag and self.parent_dag is not None:
p_dag = self.parent_dag.sub_dag(
task_regex=r"^{}$".format(self.dag_id.split('.')[1]),
include_upstream=False,
Expand All @@ -1039,7 +1037,7 @@ def clear(
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth,
Expand All @@ -1065,12 +1063,13 @@ def clear(
instances = tis.all()
for ti in instances:
if ti.operator == ExternalTaskMarker.__name__:
ti.task = self.get_task(ti.task_id)
task: ExternalTaskMarker = cast(ExternalTaskMarker, self.get_task(ti.task_id))
ti.task = task

if recursion_depth == 0:
# Maximum recursion depth allowed is the recursion_depth of the first
# ExternalTaskMarker in the tasks to be cleared.
max_recursion_depth = ti.task.recursion_depth
max_recursion_depth = task.recursion_depth

if recursion_depth + 1 > max_recursion_depth:
# Prevent cycles or accidents.
Expand All @@ -1080,10 +1079,10 @@ def clear(
.format(max_recursion_depth,
ExternalTaskMarker.__name__, ti.task_id))
ti.render_templates()
external_tis = session.query(TI).filter(TI.dag_id == ti.task.external_dag_id,
TI.task_id == ti.task.external_task_id,
external_tis = session.query(TI).filter(TI.dag_id == task.external_dag_id,
TI.task_id == task.external_task_id,
TI.execution_date ==
pendulum.parse(ti.task.execution_date))
pendulum.parse(task.execution_date))

for tii in external_tis:
if not dag_bag:
Expand All @@ -1103,7 +1102,7 @@ def clear(
confirm_prompt=confirm_prompt,
include_subdags=include_subdags,
include_parentdag=False,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
get_tis=True,
session=session,
recursion_depth=recursion_depth + 1,
Expand Down Expand Up @@ -1134,16 +1133,18 @@ def clear(
do_it = utils.helpers.ask_yesno(question)

if do_it:
clear_task_instances(tis,
session,
dag=self,
)
if reset_dag_runs:
self.set_dag_runs_state(session=session,
start_date=start_date,
end_date=end_date,
state=State.NONE,
)
clear_task_instances(
tis,
session,
dag=self,
activate_dag_runs=False, # We will set DagRun state later.
)
self.set_dag_runs_state(
session=session,
start_date=start_date,
end_date=end_date,
state=dag_run_state,
)
else:
count = 0
print("Bail. Nothing was cleared.")
Expand All @@ -1161,7 +1162,7 @@ def clear_dags(
confirm_prompt=False,
include_subdags=True,
include_parentdag=False,
reset_dag_runs=True,
dag_run_state=State.RUNNING,
dry_run=False,
):
all_tis = []
Expand All @@ -1174,7 +1175,7 @@ def clear_dags(
confirm_prompt=False,
include_subdags=include_subdags,
include_parentdag=include_parentdag,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
dry_run=True)
all_tis.extend(tis)

Expand Down Expand Up @@ -1202,7 +1203,7 @@ def clear_dags(
only_running=only_running,
confirm_prompt=False,
include_subdags=include_subdags,
reset_dag_runs=reset_dag_runs,
dag_run_state=dag_run_state,
dry_run=False,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
CloudDataFusionStopPipelineOperator, CloudDataFusionUpdateInstanceOperator,
)
from airflow.utils import dates
from airflow.utils.state import State

# [START howto_data_fusion_env_variables]
LOCATION = "europe-north1"
Expand Down Expand Up @@ -227,5 +228,5 @@
delete_pipeline >> delete_instance

if __name__ == "__main__":
dag.clear(reset_dag_runs=True)
dag.clear(dag_run_state=State.NONE)
dag.run()
3 changes: 2 additions & 1 deletion airflow/providers/google/cloud/example_dags/example_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from airflow.providers.google.cloud.transfers.gcs_to_local import GCSToLocalFilesystemOperator
from airflow.providers.google.cloud.transfers.local_to_gcs import LocalFilesystemToGCSOperator
from airflow.utils.dates import days_ago
from airflow.utils.state import State

default_args = {"start_date": days_ago(1)}

Expand Down Expand Up @@ -155,5 +156,5 @@


if __name__ == '__main__':
dag.clear(reset_dag_runs=True)
dag.clear(dag_run_state=State.NONE)
dag.run()
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
GoogleCampaignManagerReportSensor,
)
from airflow.utils import dates
from airflow.utils.state import State

PROFILE_ID = os.environ.get("MARKETING_PROFILE_ID", "123456789")
FLOODLIGHT_ACTIVITY_ID = os.environ.get("FLOODLIGHT_ACTIVITY_ID", 12345)
Expand Down Expand Up @@ -157,5 +158,5 @@
insert_conversion >> update_conversion

if __name__ == "__main__":
dag.clear(reset_dag_runs=True)
dag.clear(dag_run_state=State.NONE)
dag.run()
7 changes: 5 additions & 2 deletions tests/cli/commands/test_dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,8 @@ def test_dag_test(self, mock_get_dag, mock_executor):
subdir=cli_args.subdir, dag_id='example_bash_operator'
),
mock.call().clear(
start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True
start_date=cli_args.execution_date, end_date=cli_args.execution_date,
dag_run_state=State.NONE,
),
mock.call().run(
executor=mock_executor.return_value,
Expand Down Expand Up @@ -461,7 +462,9 @@ def test_dag_test_show_dag(self, mock_get_dag, mock_executor, mock_render_dag):
subdir=cli_args.subdir, dag_id='example_bash_operator'
),
mock.call().clear(
start_date=cli_args.execution_date, end_date=cli_args.execution_date, reset_dag_runs=True
start_date=cli_args.execution_date,
end_date=cli_args.execution_date,
dag_run_state=State.NONE,
),
mock.call().run(
executor=mock_executor.return_value,
Expand Down
60 changes: 52 additions & 8 deletions tests/models/test_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import unittest
from contextlib import redirect_stdout
from tempfile import NamedTemporaryFile
from typing import Optional
from unittest import mock
from unittest.mock import patch

Expand Down Expand Up @@ -55,6 +56,12 @@

class TestDag(unittest.TestCase):

def setUp(self) -> None:
clear_db_runs()

def tearDown(self) -> None:
clear_db_runs()

@staticmethod
def _clean_up(dag_id: str):
with create_session() as session:
Expand Down Expand Up @@ -1355,8 +1362,14 @@ def test_create_dagrun_run_type_is_obtained_from_run_id(self):
dr = dag.create_dagrun(run_id="custom_is_set_to_manual", state=State.NONE)
assert dr.run_type == DagRunType.MANUAL.value

def test_clear_reset_dagruns(self):
dag_id = 'test_clear_dag_reset_dagruns'
@parameterized.expand(
[
(State.NONE,),
(State.RUNNING,),
]
)
def test_clear_set_dagrun_state(self, dag_run_state):
dag_id = 'test_clear_set_dagrun_state'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
Expand All @@ -1365,7 +1378,7 @@ def test_clear_reset_dagruns(self):
session = settings.Session()
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.RUNNING,
state=State.FAILED,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
Expand All @@ -1378,7 +1391,7 @@ def test_clear_reset_dagruns(self):
dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
reset_dag_runs=True,
dag_run_state=dag_run_state,
include_subdags=False,
include_parentdag=False,
session=session,
Expand All @@ -1392,17 +1405,48 @@ def test_clear_reset_dagruns(self):

self.assertEqual(len(dagruns), 1)
dagrun = dagruns[0] # type: DagRun
self.assertEqual(dagrun.state, State.NONE)
self.assertEqual(dagrun.state, dag_run_state)

@parameterized.expand([
(state, State.NONE)
for state in State.task_states if state != State.RUNNING
] + [(State.RUNNING, State.SHUTDOWN)]) # type: ignore
def test_clear_dag(self, ti_state_begin, ti_state_end: Optional[str]):
dag_id = 'test_clear_dag'
self._clean_up(dag_id)
task_id = 't1'
dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
t_1 = DummyOperator(task_id=task_id, dag=dag)

session = settings.Session() # type: ignore
dagrun_1 = dag.create_dagrun(
run_type=DagRunType.BACKFILL_JOB,
state=State.RUNNING,
start_date=DEFAULT_DATE,
execution_date=DEFAULT_DATE,
)
session.merge(dagrun_1)

task_instance_1 = TI(t_1, execution_date=DEFAULT_DATE, state=ti_state_begin)
task_instance_1.job_id = 123
session.merge(task_instance_1)
session.commit()

dag.clear(
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=1),
session=session,
)

task_instances = session.query(
DagRun,
TI,
).filter(
DagRun.dag_id == dag_id,
TI.dag_id == dag_id,
).all()

self.assertEqual(len(task_instances), 1)
task_instance = task_instances[0] # type: TI
self.assertEqual(task_instance.state, State.NONE)
self.assertEqual(task_instance.state, ti_state_end)
self._clean_up(dag_id)


Expand Down

0 comments on commit b01d95e

Please sign in to comment.