Skip to content

Commit

Permalink
Change TaskInstance and TaskReschedule PK from execution_date to run_…
Browse files Browse the repository at this point in the history
…id (#17719)

Since TaskReschedule had an existing FK to TaskInstance we had to
change both of these at the same time.

This puts an explicit FK constraint between TaskInstance and DagRun,
meaning that we can remove a lot of "find TIs without DagRun" code in
the scheduler too, as that is no longer a possible situation.

Since there is now an explicit foreign key between TaskInstance and
DagRun, we can remove a lot of the "cleanup" code in the scheduler that
was dealing with this.

This change was made as part of AIP-39

Co-authored-by: Tzu-ping Chung <[email protected]>
  • Loading branch information
ashb and uranusjr committed Sep 7, 2021
1 parent 022b4e0 commit 944dcfb
Show file tree
Hide file tree
Showing 132 changed files with 4,284 additions and 4,513 deletions.
10 changes: 10 additions & 0 deletions UPDATING.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ Now that the DAG parser syncs DAG permissions there is no longer a need for manu
In addition, the `/refresh` and `/refresh_all` webserver endpoints have also been removed.
### TaskInstances now *require* a DagRun
Under normal operation every TaskInstance row in the database would have DagRun row too, but it was possible to manually delete the DagRun and Airflow would still schedule the TaskInstances.
In Airflow 2.2 we have changed this and now there is a database-level foreign key constraint ensuring that every TaskInstance has a DagRun row.
Before updating to this 2.2 release you will have to manually resolve any inconsistencies (add back DagRun rows, or delete TaskInstances) if you have any "dangling" TaskInstance rows.
As part of this change the `clean_tis_without_dagrun_interval` config option under the `[scheduler]` section has been removed; setting it no longer has any effect.
## Airflow 2.1.3
No breaking changes.
Expand Down
5 changes: 4 additions & 1 deletion airflow/api/common/experimental/mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Iterable

from sqlalchemy import or_
from sqlalchemy.orm import contains_eager

from airflow.models.baseoperator import BaseOperator
from airflow.models.dagrun import DagRun
Expand Down Expand Up @@ -148,12 +149,14 @@ def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates):
"""Get all tasks of the main dag that will be affected by a state change"""
qry_dag = (
session.query(TaskInstance)
.join(TaskInstance.dag_run)
.filter(
TaskInstance.dag_id == dag.dag_id,
TaskInstance.execution_date.in_(confirmed_dates),
DagRun.execution_date.in_(confirmed_dates),
TaskInstance.task_id.in_(task_ids),
)
.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
.options(contains_eager(TaskInstance.dag_run))
)
return qry_dag

Expand Down
18 changes: 10 additions & 8 deletions airflow/api_connexion/endpoints/log_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@
from flask import Response, current_app, request
from itsdangerous.exc import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
from sqlalchemy.orm import eagerload

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import BadRequest, NotFound
from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema
from airflow.exceptions import TaskNotFound
from airflow.models import DagRun
from airflow.models import TaskInstance
from airflow.security import permissions
from airflow.utils.log.log_reader import TaskLogReader
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -60,15 +61,16 @@ def get_log(session, dag_id, dag_run_id, task_id, task_try_number, full_content=
if not task_log_reader.supports_read:
raise BadRequest("Task log handler does not support read logs.")

query = session.query(DagRun).filter(DagRun.dag_id == dag_id)
dag_run = query.filter(DagRun.run_id == dag_run_id).first()
if not dag_run:
raise NotFound("DAG Run not found")

ti = dag_run.get_task_instance(task_id, session)
ti = (
session.query(TaskInstance)
.filter(TaskInstance.task_id == task_id, TaskInstance.run_id == dag_run_id)
.join(TaskInstance.dag_run)
.options(eagerload(TaskInstance.dag_run))
.one_or_none()
)
if ti is None:
metadata['end_of_log'] = True
raise BadRequest(detail="Task instance did not exist in the DB")
raise NotFound(title="TaskInstance not found")

dag = current_app.dag_bag.get_dag(dag_id)
if dag:
Expand Down
39 changes: 14 additions & 25 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from flask import current_app, request
from marshmallow import ValidationError
from sqlalchemy import and_, func
from sqlalchemy.orm import eagerload

from airflow.api_connexion import security
from airflow.api_connexion.exceptions import BadRequest, NotFound
Expand Down Expand Up @@ -54,15 +55,14 @@ def get_task_instance(dag_id: str, dag_run_id: str, task_id: str, session=None):
"""Get task instance"""
query = (
session.query(TI)
.filter(TI.dag_id == dag_id)
.join(DR, and_(TI.dag_id == DR.dag_id, TI.execution_date == DR.execution_date))
.filter(DR.run_id == dag_run_id)
.filter(TI.task_id == task_id)
.filter(TI.dag_id == dag_id, DR.run_id == dag_run_id, TI.task_id == task_id)
.join(TI.dag_run)
.options(eagerload(TI.dag_run))
.outerjoin(
SlaMiss,
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.execution_date == TI.execution_date,
SlaMiss.execution_date == DR.execution_date,
SlaMiss.task_id == TI.task_id,
),
)
Expand Down Expand Up @@ -127,13 +127,12 @@ def get_task_instances(
session=None,
):
"""Get list of task instances."""
base_query = session.query(TI)
base_query = session.query(TI).join(TI.dag_run).options(eagerload(TI.dag_run))

if dag_id != "~":
base_query = base_query.filter(TI.dag_id == dag_id)
if dag_run_id != "~":
base_query = base_query.join(DR, and_(TI.dag_id == DR.dag_id, TI.execution_date == DR.execution_date))
base_query = base_query.filter(DR.run_id == dag_run_id)
base_query = base_query.filter(TI.run_id == dag_run_id)
base_query = _apply_range_filter(
base_query,
key=DR.execution_date,
Expand All @@ -156,7 +155,7 @@ def get_task_instances(
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == TI.execution_date,
SlaMiss.execution_date == DR.execution_date,
),
isouter=True,
)
Expand All @@ -183,12 +182,12 @@ def get_task_instances_batch(session=None):
data = task_instance_batch_form.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))
base_query = session.query(TI)
base_query = session.query(TI).join(TI.dag_run).options(eagerload(TI.dag_run))

base_query = _apply_array_filter(base_query, key=TI.dag_id, values=data["dag_ids"])
base_query = _apply_range_filter(
base_query,
key=TI.execution_date,
key=DR.execution_date,
value_range=(data["execution_date_gte"], data["execution_date_lte"]),
)
base_query = _apply_range_filter(
Expand All @@ -214,7 +213,7 @@ def get_task_instances_batch(session=None):
and_(
SlaMiss.dag_id == TI.dag_id,
SlaMiss.task_id == TI.task_id,
SlaMiss.execution_date == TI.execution_date,
SlaMiss.execution_date == DR.execution_date,
),
isouter=True,
)
Expand Down Expand Up @@ -254,9 +253,7 @@ def post_clear_task_instances(dag_id: str, session=None):
clear_task_instances(
task_instances.all(), session, dag=dag, dag_run_state=State.RUNNING if reset_dag_runs else False
)
task_instances = task_instances.join(
DR, and_(DR.dag_id == TI.dag_id, DR.execution_date == TI.execution_date)
).add_column(DR.run_id)
task_instances = task_instances.join(TI.dag_run).options(eagerload(TI.dag_run))
return task_instance_reference_collection_schema.dump(
TaskInstanceReferenceCollection(task_instances=task_instances.all())
)
Expand Down Expand Up @@ -303,14 +300,6 @@ def post_set_task_instances_state(dag_id, session):
future=data["include_future"],
past=data["include_past"],
commit=not data["dry_run"],
session=session,
)
execution_dates = {ti.execution_date for ti in tis}
execution_date_to_run_id_map = dict(
session.query(DR.execution_date, DR.run_id).filter(
DR.dag_id == dag_id, DR.execution_date.in_(execution_dates)
)
)
tis_with_run_id = [(ti, execution_date_to_run_id_map.get(ti.execution_date)) for ti in tis]
return task_instance_reference_collection_schema.dump(
TaskInstanceReferenceCollection(task_instances=tis_with_run_id)
)
return task_instance_reference_collection_schema.dump(TaskInstanceReferenceCollection(task_instances=tis))
10 changes: 1 addition & 9 deletions airflow/api_connexion/schemas/task_instance_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,10 @@ class TaskInstanceReferenceSchema(Schema):
"""Schema for the task instance reference schema"""

task_id = fields.Str()
dag_run_id = fields.Str()
run_id = fields.Str(data_key="dag_run_id")
dag_id = fields.Str()
execution_date = fields.DateTime()

def get_attribute(self, obj, attr, default):
"""Overwritten marshmallow function"""
task_instance_attr = ['task_id', 'execution_date', 'dag_id']
if attr in task_instance_attr:
obj = obj[0] # As object is a tuple of task_instance and dag_run_id
return get_value(obj, attr, default)
return obj[1]


class TaskInstanceReferenceCollection(NamedTuple):
"""List of objects with metadata about taskinstance and dag_run_id"""
Expand Down
4 changes: 3 additions & 1 deletion airflow/cli/commands/dag_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ def dag_backfill(args, dag=None):

if args.dry_run:
print(f"Dry run of DAG {args.dag_id} on {args.start_date}")
dr = DagRun(dag.dag_id, execution_date=args.start_date)
for task in dag.tasks:
print(f"Task {task.task_id}")
ti = TaskInstance(task, args.start_date)
ti = TaskInstance(task, run_id=None)
ti.dag_run = dr
ti.dry_run()
else:
if args.reset_dagruns:
Expand Down
6 changes: 4 additions & 2 deletions airflow/cli/commands/kubernetes_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from airflow.kubernetes import pod_generator
from airflow.kubernetes.kube_client import get_kube_client
from airflow.kubernetes.pod_generator import PodGenerator
from airflow.models import TaskInstance
from airflow.models import DagRun, TaskInstance
from airflow.settings import pod_mutation_hook
from airflow.utils import cli as cli_utils, yaml
from airflow.utils.cli import get_dag
Expand All @@ -38,9 +38,11 @@ def generate_pod_yaml(args):
execution_date = args.execution_date
dag = get_dag(subdir=args.subdir, dag_id=args.dag_id)
yaml_output_path = args.output_path
dr = DagRun(dag.dag_id, execution_date=execution_date)
kube_config = KubeConfig()
for task in dag.tasks:
ti = TaskInstance(task, execution_date)
ti = TaskInstance(task, None)
ti.dag_run = dr
pod = PodGenerator.construct_pod(
dag_id=args.dag_id,
task_id=ti.task_id,
Expand Down

0 comments on commit 944dcfb

Please sign in to comment.