Skip to content

Re-create PR for Enable Serde for Pydantic BaseModel and Subclasses #52360

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

sjyangkevin
Copy link
Contributor

@sjyangkevin sjyangkevin commented Jun 27, 2025

The chage in #51059 broke canary tests in https://github.com/apache/airflow/actions/runs/15910813471/job/44877861492. The PR is reverted through #52312.

Create this PR to fix the issue.

The issue in #51059

First, the refactored deserialize take the actual class instead of classname. Therefore, the following test case is modified from "zi = deserialize("backports.zoneinfo.ZoneInfo", 1, "Asia/Taipei")" to zi = deserialize(ZoneInfo, 1, "Asia/Taipei"). I think the "backports.zoneinfo.ZoneInfo" is a backport of the standard library module zoneinfo in Python 3.9, to suppor lower version of Python.

In the airflow-core/src/airflow/serialization/serializers/timezone.py, the line below is looking for the backport ZoneInfo, instead of the standard library zoneinfo. Therefore, when a ZoneInfo class is passed into the deserialize, the qualname(cls) is resolved to "zoneinfo.ZoneInfo" instead of "backports.zoneinfo.ZoneInfo". Therefore, the data will be deserialized through parse_timezone(data), resulting in a different object. Therefore the test case failed.

if qualname(cls) == "backports.zoneinfo.ZoneInfo" and isinstance(data, str):
def test_timezone_deserialize_zoneinfo(self):
    from airflow.serialization.serializers.timezone import deserialize

    zi = deserialize(ZoneInfo, 1, "Asia/Taipei")
    assert isinstance(zi, ZoneInfo)
    assert zi.key == "Asia/Taipei"

Solution

To resolve this issue, I updated the if condition to if cls is ZoneInfo and isinstance(data, str):. I think the minimum version of Python we support is 3.9. So, "backports.zoneinfo.ZoneInfo" should probably be removed.

After making the changes, I ran the following tests and all checks passed.

breeze --python 3.9 testing core-tests --test-type Serialization
breeze --python 3.9 testing core-tests --test-type Serialization --downgrade-pendulum

breeze --python 3.10 testing core-tests --test-type Serialization
breeze --python 3.10 testing core-tests --test-type Serialization --downgrade-pendulum

breeze --python 3.11 testing core-tests --test-type Serialization
breeze --python 3.11 testing core-tests --test-type Serialization --downgrade-pendulum

but I got some container issues when running for 3.12. Will look into the full test results and action accordingly.


^ Add meaningful description above
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in airflow-core/newsfragments.

@potiuk potiuk added the full tests needed We need to run full set of tests for this PR to merge label Jun 27, 2025
@potiuk
Copy link
Member

potiuk commented Jun 27, 2025

I applied "full tests needed" and resolved conflict with Python 3.9 removal to trigger the build

@@ -92,7 +92,7 @@ def deserialize(cls: type, version: int, data: dict | str) -> datetime.date | da
if cls is DateTime and isinstance(data, dict):
return DateTime.fromtimestamp(float(data[TIMESTAMP]), tz=tz)

if cls is datetime.timedelta and isinstance(data, (str, float)):
if classname == qualname(datetime.timedelta) and isinstance(data, str | float):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to make an update to this. I think if cls is datetime.timedelta and isinstance(data, (str, float)): should overwrite this line. the use of classname is before refactoring, and the breeze test script raised an error for str | float. Will push a commit soon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isintance(data, str | float) seems like an invalid syntax (correct me if I am wrong), the second argument of isinstance accept a class or tuple. The change is introduced #52072.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I think I am wrong, this is actually valid > 3.9, I ran on 3.9 and get the error.

@sjyangkevin sjyangkevin force-pushed the issues/50867/cohere-serde branch 2 times, most recently from 5164de8 to 8d4d944 Compare June 27, 2025 21:54
@potiuk potiuk force-pushed the issues/50867/cohere-serde branch from 8d4d944 to c79af1f Compare June 28, 2025 08:54
@potiuk
Copy link
Member

potiuk commented Jun 28, 2025

REbased again - we had some microsoft kiota new version released that broke main

if not is_pydantic_model(o):
return "", "", 0, False

model = cast("BaseModel", o) # for mypy
Copy link
Contributor Author

@sjyangkevin sjyangkevin Jun 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would like to see if I it is a good practice to replace this line, and another one in the deserialize with # type: ignore. This line is merely to convince mypy that the model is actually a pydantic model that has the model_dump method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove this cast and use # type: ignore

@sjyangkevin sjyangkevin force-pushed the issues/50867/cohere-serde branch from c79af1f to 85f95a7 Compare June 30, 2025 04:49
@sjyangkevin
Copy link
Contributor Author

Hi @bolkedebruin , @potiuk , @amoghrajesh , I would appreciate if I could get your help again to review this PR. I also wanted to share some findings while doing testing.

Issue with serializing numpy.bool

It looks like the serialize method in the numpy.py module has issue with serializing numpy.bool objects. I think it's because the qualified name (i.e., numpy.bool) is not matched to the ones set in the list. I attempted to add an entry numpy.bool into the serializers list, and it could solve the issue. However, it failed the test airflow-core/tests/unit/serialization/test_serde.py::TestSerDe::test_serializers_importable_and_str and details can be found here: https://github.com/apache/airflow/actions/runs/15964146039/job/45021686182.

serializers = [
    "numpy.int8",
    "numpy.int16",
    "numpy.int32",
    "numpy.int64",
    "numpy.uint8",
    "numpy.uint16",
    "numpy.uint32",
    "numpy.uint64",
    "numpy.bool_",
    "numpy.float64",
    "numpy.float16",
    "numpy.complex128",
    "numpy.complex64",
]

Screenshot from 2025-06-30 01-31-05

Screenshot from 2025-06-30 00-29-47

Issue with an unit test case

I also found a failed test case when I ran pytest in the breeze environment on the airflow-core/tests/unit/serialization. It looks like the DAG is missing, and not sure if we need some clean up for this, and I didn't see this failure when running breeze testing core-tests --test-type Serialization.

Screenshot from 2025-06-30 01-27-00

Test DAG code

I also updated the DAG code to test for as many objects as possible in serialize/deserialize, except for iceberg and deltalake.

from airflow.decorators import dag, task
from pendulum import datetime

@dag(
    start_date=datetime(2025, 5, 23),
    schedule=None,
    catchup=False,
    tags=["serialization", "pydantic", "airflow"],
)
def pydantic_serde():

    # 1. Pandas DataFrame
    @task
    def get_pandas():
        import pandas as pd
        import numpy as np
        df = pd.DataFrame(np.random.randn(3, 2), columns=list("AB"))
        return df

    @task
    def print_pandas(df):
        print("Pandas DataFrame:", df)

    # 2. Decimal
    @task
    def get_bignum():
        from decimal import Decimal
        return Decimal(1) / Decimal(7)

    @task
    def print_bignum(n):
        print("Decimal:", n, type(n))

    # 3. Built-in collections
    @task
    def get_all_builtins():
        return {
            "list": [1, 2, 3],
            "set": {4, 5},
            "tuple": (6, 7),
            "frozenset": frozenset([8, 9])
        }

    @task
    def print_all_builtins(obj):
        print("Built-in Types:")
        for k, v in obj.items():
            print(f"{k}: {v} ({type(v)})")

    # 4. NumPy scalar types - integers
    @task
    def get_numpy_ints():
        import numpy as np
        return {
            "int8": np.int8(8),
            "int16": np.int16(16),
            "int32": np.int32(32),
            "int64": np.int64(64),
            "uint8": np.uint8(8),
            "uint16": np.uint16(16),
            "uint32": np.uint32(32),
            "uint64": np.uint64(64),
        }

    @task
    def print_numpy_ints(obj):
        print("NumPy Integers:")
        for k, v in obj.items():
            print(f"{k}: {v} ({type(v)})")

    # 5. NumPy scalar types - misc
    @task
    def get_numpy_misc():
        import numpy as np
        return {
            "bool_": np.bool_(0),
            "float16": np.float16(0.125),
            "float64": np.float64(3.14159),
            "complex64": np.complex64(1 + 2j),
            "complex128": np.complex128(3 + 4j),
        }

    @task
    def print_numpy_misc(obj):
        print("NumPy Misc Types:")
        for k, v in obj.items():
            print(f"{k}: {v} ({type(v)})")

    # 6. Python datetime types
    @task
    def get_python_datetime_types():
        import datetime
        return {
            "date": datetime.date(2025, 6, 29),
            "datetime": datetime.datetime(2025, 6, 29, 12, 34, 56),
            "timedelta": datetime.timedelta(days=1, seconds=3600)
        }

    @task
    def print_python_datetime_types(obj):
        print("Python datetime types:")
        for k, v in obj.items():
            print(f"{k}: {v} ({type(v)})")

    # 7. Pendulum datetime
    @task
    def get_pendulum_datetime_type():
        import pendulum
        dt = pendulum.datetime(2025, 6, 29, 12, 34, 56, tz="Europe/Paris")
        return dt

    @task
    def print_pendulum_datetime_type(dt):
        print("Pendulum DateTime:", dt, type(dt))

    # 8. Timezone-aware datetime (ZoneInfo)
    @task
    def get_timezone_aware():
        from zoneinfo import ZoneInfo
        from datetime import datetime as dt
        return dt(2025, 6, 29, 12, 0, tzinfo=ZoneInfo("America/New_York"))

    @task
    def print_timezone_aware(tz_dt):
        print("Timezone-aware datetime:", tz_dt, type(tz_dt))

    # 9. Pendulum timezone object
    @task
    def get_pendulum_tz():
        import pendulum
        return pendulum.timezone("Asia/Tokyo")

    @task
    def print_pendulum_tz(tz):
        print("Pendulum timezone:", tz, type(tz))

    # 10. ZoneInfo timezone object
    @task
    def get_zoneinfo_tz():
        from zoneinfo import ZoneInfo
        return ZoneInfo("America/Toronto")
    
    @task
    def print_zoneinfo_tz(tz):
        print("ZoneInfo timezone:", tz, type(tz))

    # 11. Cohere embeddings (Pydantic model)
    @task
    def get_embeddings():
        import pydantic
        from airflow.providers.cohere.hooks.cohere import CohereHook

        hook = CohereHook()
        embeddings = hook.create_embeddings(["gruyere"])

        print("Cohere embeddings type:", type(embeddings))
        print("Is Pydantic model?", isinstance(embeddings, pydantic.BaseModel))
        return embeddings

    @task
    def print_embeddings(obj):
        print("Cohere Embeddings (Pydantic Model):", obj)

    # DAG chaining
    print_pandas(get_pandas())
    print_bignum(get_bignum())
    print_all_builtins(get_all_builtins())
    print_numpy_ints(get_numpy_ints())
    print_numpy_misc(get_numpy_misc())
    print_python_datetime_types(get_python_datetime_types())
    print_pendulum_datetime_type(get_pendulum_datetime_type())
    print_timezone_aware(get_timezone_aware())
    print_pendulum_tz(get_pendulum_tz())
    print_zoneinfo_tz(get_zoneinfo_tz())
    print_embeddings(get_embeddings())

pydantic_serde()

DAG Test Results

I modified the Cohere provider and let it return a Pydantic class in the breeze environment. 1.) When the Pydantic model is in the whitelist. All the tests defined in the DAG passed. 2.) When the Pydantic model is removed from the whitelist, the print_embeddings task failed due to ImportError.

To whitelist the Pydantic model. Add the following to files/airflow-breeze-config/environment_variables.env

AIRFLOW__CORE__ALLOWED_DESERIALIZATION_CLASSES=cohere.types.embed_by_type_response_embeddings.EmbedByTypeResponseEmbeddings

Screenshot from 2025-06-30 00-12-07
The numpy check passed because I added numpy.bool into the serializers list, but this change is not added to this PR because it failed the checks.

Screenshot from 2025-06-30 00-14-53

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
area:serialization full tests needed We need to run full set of tests for this PR to merge
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants