Skip to content

Move BaseHook class to task SDK #51873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 55 commits into
base: main
Choose a base branch
from

Conversation

amoghrajesh
Copy link
Contributor

@amoghrajesh amoghrajesh commented Jun 18, 2025

closes: #51672

Moving the BaseHook class into task SDK, exactly where it should live similar to other base classes in sdk like notifiers, operators etc.

Testing: Running by all the methods defined on BaseHook -- old vs new

Running with the new path

DAG:

from __future__ import annotations

from contextlib import suppress

from airflow.models.baseoperator import BaseOperator
from airflow import DAG
from airflow.sdk import BaseHook


class CustomOperator(BaseOperator):
    def execute(self, context):
        conn = BaseHook.get_connection("athena_default")
        print("got connection from basehook", conn)


        hook = BaseHook.get_hook("athena_default")
        print("got hook from basehook", hook)

        with suppress(NotImplementedError):
            BaseHook.get_conn(hook)
            print("Raising a NotImplementedError, trying to access get_conn")

        logger = BaseHook.logger()
        print("default logger is", logger)


with DAG("get_connection_basehook", schedule=None, catchup=False) as dag:
    CustomOperator(task_id="set_var")

image

Running with the older path to check backcompat

from __future__ import annotations

from contextlib import suppress

from airflow.models.baseoperator import BaseOperator
from airflow import DAG
from airflow.hooks.base import BaseHook


class CustomOperator(BaseOperator):
    def execute(self, context):
        conn = BaseHook.get_connection("athena_default")
        print("got connection from basehook", conn)


        hook = BaseHook.get_hook("athena_default")
        print("got hook from basehook", hook)

        with suppress(NotImplementedError):
            BaseHook.get_conn(hook)
            print("Raising a NotImplementedError, trying to access get_conn")

        logger = BaseHook.logger()
        print("default logger is", logger)


with DAG("get_connection_basehook", schedule=None, catchup=False) as dag:
    CustomOperator(task_id="set_var")

image


^ Add meaningful description above
Read the Pull Request Guidelines for more information.
In case of fundamental code changes, an Airflow Improvement Proposal (AIP) is needed.
In case of a new dependency, check compliance with the ASF 3rd Party License Policy.
In case of backwards incompatible changes please leave a note in a newsfragment file, named {pr_number}.significant.rst or {issue_number}.significant.rst, in airflow-core/newsfragments.

@amoghrajesh amoghrajesh requested a review from josh-fell as a code owner June 27, 2025 13:12
@amoghrajesh amoghrajesh force-pushed the move-basehook-to-task-sdk branch from 4106b95 to f166ee4 Compare June 30, 2025 07:18
Comment on lines +378 to +379
# TODO: Revisit this
status, message = conn.test_connection() # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this work or fail currently -- with this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I basically haven't made any change here, just to think a bit loud, this one can use Connection from task SDK right?

Right now its not

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be stricter:

Suggested change
from airflow.hooks.base import BaseHook # type: ignore
from airflow.hooks.base import BaseHook # type: ignore[no-redef]

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from airflow.hooks.base import BaseHook # type: ignore
from airflow.hooks.base import BaseHook # type: ignore[no-redef]

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from airflow.hooks.base import BaseHook # type: ignore
from airflow.hooks.base import BaseHook # type: ignore[no-redef]

@@ -93,7 +97,7 @@ def __init__(self, region: str | None = None, oss_conn_id="oss_default", *args,

def get_conn(self) -> Connection:
"""Return connection for the hook."""
return self.oss_conn
return self.oss_conn # type: ignore[return-value]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can/should be fixed -- if you change L37 to be Connection from SDK

@@ -127,7 +127,10 @@ def conn_config(self) -> AwsConnectionWrapper:
)

return AwsConnectionWrapper(
conn=connection, region_name=self._region_name, botocore_config=self._config, verify=self._verify
conn=connection, # type: ignore[arg-type]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Comment on lines +66 to +69
try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth moving to version_compat.py -- same as #52528

@@ -69,7 +69,7 @@ def _get_webhook_endpoint(self, conn_id: str) -> str:
token = conn.password
if token is None:
raise AirflowException("Webhook token field is missing and is required.")
url = conn.schema + "://" + conn.host
url = cast("str", conn.schema) + "://" + cast("str", conn.host)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add an error above, since you expect conn.schema and conn.host to be non-None -- similar to token field above

try:
from airflow.sdk import BaseHook
except ImportError:
from airflow.hooks.base import BaseHook # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Strictier typer ignore please for all

@amoghrajesh
Copy link
Contributor Author

@kaxil thanks for the review, pretty positive ill be getting a green run now. Will look at your comments tomorrow!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Move BaseHook to TaskSDK under airflow.sdk.bases
4 participants