Skip to content

Commit 99370fe

Browse files
authored
[AIRFLOW-7045] Update SQL query to delete RenderedTaskInstanceFields (#8051)
This is because "The composite IN construct is not supported by all backends". Based on the discussion in #6788 (comment).
1 parent 6602160 commit 99370fe

File tree

2 files changed

+72
-41
lines changed

2 files changed

+72
-41
lines changed

airflow/models/renderedtifields.py

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Optional
2020

2121
import sqlalchemy_jsonfield
22-
from sqlalchemy import Column, String, and_, tuple_
22+
from sqlalchemy import Column, String, and_, not_, tuple_
2323
from sqlalchemy.orm import Session
2424

2525
from airflow.configuration import conf
@@ -109,29 +109,52 @@ def delete_old_records(
109109
if num_to_keep <= 0:
110110
return
111111

112-
# Fetch Top X records given dag_id & task_id ordered by Execution Date
113-
subq1 = (
114-
session
115-
.query(cls.dag_id, cls.task_id, cls.execution_date)
116-
.filter(cls.dag_id == dag_id, cls.task_id == task_id)
117-
.order_by(cls.execution_date.desc())
112+
tis_to_keep_query = session \
113+
.query(cls.dag_id, cls.task_id, cls.execution_date) \
114+
.filter(cls.dag_id == dag_id, cls.task_id == task_id) \
115+
.order_by(cls.execution_date.desc()) \
118116
.limit(num_to_keep)
119-
.subquery('subq1')
120-
)
121-
122-
# Second Subquery
123-
# Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525)
124-
# Limitation: This version of MySQL does not yet support
125-
# LIMIT & IN/ALL/ANY/SOME subquery
126-
subq2 = (
127-
session
128-
.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.execution_date)
129-
.subquery('subq2')
130-
)
131-
132-
session.query(cls) \
133-
.filter(and_(
134-
cls.dag_id == dag_id,
135-
cls.task_id == task_id,
136-
tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq2))) \
137-
.delete(synchronize_session=False)
117+
118+
if session.bind.dialect.name in ["postgresql", "sqlite"]:
119+
# Fetch Top X records given dag_id & task_id ordered by Execution Date
120+
subq1 = tis_to_keep_query.subquery('subq1')
121+
122+
session.query(cls) \
123+
.filter(and_(
124+
cls.dag_id == dag_id,
125+
cls.task_id == task_id,
126+
tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq1))) \
127+
.delete(synchronize_session=False)
128+
elif session.bind.dialect.name in ["mysql"]:
129+
# Fetch Top X records given dag_id & task_id ordered by Execution Date
130+
subq1 = tis_to_keep_query.subquery('subq1')
131+
132+
# Second Subquery
133+
# Workaround for MySQL Limitation (https://stackoverflow.com/a/19344141/5691525)
134+
# Limitation: This version of MySQL does not yet support
135+
# LIMIT & IN/ALL/ANY/SOME subquery
136+
subq2 = (
137+
session
138+
.query(subq1.c.dag_id, subq1.c.task_id, subq1.c.execution_date)
139+
.subquery('subq2')
140+
)
141+
142+
session.query(cls) \
143+
.filter(and_(
144+
cls.dag_id == dag_id,
145+
cls.task_id == task_id,
146+
tuple_(cls.dag_id, cls.task_id, cls.execution_date).notin_(subq2))) \
147+
.delete(synchronize_session=False)
148+
else:
149+
# Fetch Top X records given dag_id & task_id ordered by Execution Date
150+
tis_to_keep = tis_to_keep_query.all()
151+
152+
filter_tis = [not_(and_(
153+
cls.dag_id == ti.dag_id,
154+
cls.task_id == ti.task_id,
155+
cls.execution_date == ti.execution_date
156+
)) for ti in tis_to_keep]
157+
158+
session.query(cls) \
159+
.filter(and_(*filter_tis)) \
160+
.delete(synchronize_session=False)

tests/models/test_renderedtifields.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from airflow.operators.bash import BashOperator
3232
from airflow.utils.session import create_session
3333
from airflow.utils.timezone import datetime
34+
from tests.test_utils.asserts import assert_queries_count
3435
from tests.test_utils.db import clear_rendered_ti_fields
3536

3637
TEST_DAG = DAG("example_rendered_ti_field", schedule_interval=None)
@@ -129,7 +130,15 @@ def test_get_templated_fields(self, templated_field, expected_rendered_field):
129130
ti2 = TI(task_2, EXECUTION_DATE)
130131
self.assertIsNone(RTIF.get_templated_fields(ti=ti2))
131132

132-
def test_delete_old_records(self):
133+
@parameterized.expand([
134+
(0, 1, 0, 1),
135+
(1, 1, 1, 1),
136+
(1, 0, 1, 0),
137+
(3, 1, 1, 1),
138+
(4, 2, 2, 1),
139+
(5, 2, 2, 1),
140+
])
141+
def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count):
133142
"""
134143
Test that old records are deleted from rendered_task_instance_fields table
135144
for a given task_id and dag_id.
@@ -139,29 +148,28 @@ def test_delete_old_records(self):
139148
with dag:
140149
task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
141150

142-
rtif_1 = RTIF(TI(task=task, execution_date=EXECUTION_DATE))
143-
rtif_2 = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=1)))
144-
rtif_3 = RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=2)))
151+
rtif_list = [
152+
RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num)))
153+
for num in range(rtif_num)
154+
]
145155

146-
session.add(rtif_1)
147-
session.add(rtif_2)
148-
session.add(rtif_3)
156+
session.add_all(rtif_list)
149157
session.commit()
150158

151159
result = session.query(RTIF)\
152160
.filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
153161

154-
self.assertIn(rtif_1, result)
155-
self.assertIn(rtif_2, result)
156-
self.assertIn(rtif_3, result)
157-
self.assertEqual(3, len(result))
162+
for rtif in rtif_list:
163+
self.assertIn(rtif, result)
164+
165+
self.assertEqual(rtif_num, len(result))
158166

159-
# Verify old records are deleted and only 1 record is kept
160-
RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=1)
167+
# Verify old records are deleted and only 'num_to_keep' records are kept
168+
with assert_queries_count(expected_query_count):
169+
RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
161170
result = session.query(RTIF) \
162171
.filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
163-
self.assertEqual(1, len(result))
164-
self.assertEqual(rtif_3.execution_date, result[0].execution_date)
172+
self.assertEqual(remaining_rtifs, len(result))
165173

166174
def test_write(self):
167175
"""

0 commit comments

Comments
 (0)