Skip to content

Commit

Permalink
Correctly restore upstream_task_ids when deserializing Operators (#8775)
Browse files Browse the repository at this point in the history
This test exposed a bug in one of the example dags, that wasn't caught
by #6549. That will be a fixed in a separate issue, but it caused the
round-trip tests to fail here

Fixes #8720
  • Loading branch information
ashb committed May 10, 2020
1 parent a715aa6 commit 280f1f0
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/example_dags/example_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@
)

# [START howto_operator_gcs_delete_bucket]
delete_bucket_1 = GCSDeleteBucketOperator(task_id="delete_bucket", bucket_name=BUCKET_1)
delete_bucket_2 = GCSDeleteBucketOperator(task_id="delete_bucket", bucket_name=BUCKET_2)
delete_bucket_1 = GCSDeleteBucketOperator(task_id="delete_bucket_1", bucket_name=BUCKET_1)
delete_bucket_2 = GCSDeleteBucketOperator(task_id="delete_bucket_2", bucket_name=BUCKET_2)
# [END howto_operator_gcs_delete_bucket]

[create_bucket1, create_bucket2] >> list_buckets >> list_buckets_result
Expand Down
2 changes: 1 addition & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ def deserialize_dag(cls, encoded_dag: Dict[str, Any]) -> 'SerializedDAG':
for task_id in serializable_task.downstream_task_ids:
# Bypass set_upstream etc here - it does more than we want
# noinspection PyProtectedMember
dag.task_dict[task_id]._upstream_task_ids.add(task_id) # pylint: disable=protected-access
dag.task_dict[task_id]._upstream_task_ids.add(serializable_task.task_id) # noqa: E501 # pylint: disable=protected-access

return dag

Expand Down
3 changes: 3 additions & 0 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,9 @@ def validate_deserialized_task(self, serialized_task, task,):
assert serialized_task.task_type == task.task_type
assert set(serialized_task.template_fields) == set(task.template_fields)

assert serialized_task.upstream_task_ids == task.upstream_task_ids
assert serialized_task.downstream_task_ids == task.downstream_task_ids

for field in fields_to_check:
assert getattr(serialized_task, field) == getattr(task, field), \
f'{task.dag.dag_id}.{task.task_id}.{field} does not match'
Expand Down

0 comments on commit 280f1f0

Please sign in to comment.