Skip to content

Commit

Permalink
Add more accurate typing for DbApiHook.run method (#31846)
Browse files (browse the repository at this point in the history)
Co-authored-by: eladkal <[email protected]>
  • Loading branch information
dwreeves and eladkal committed Jul 18, 2023
1 parent 7ed791d commit 60c49ab
Show file tree
Hide file tree
Showing 16 changed files with 181 additions and 86 deletions.
2 changes: 1 addition & 1 deletion airflow/providers/apache/hive/hooks/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def to_csv(
self.log.info("Done. Loaded a total of %s rows.", i)

def get_records(
self, sql: str | list[str], parameters: Iterable | Mapping | None = None, **kwargs
self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
) -> Any:
"""
Get a set of records from a Hive query; optionally pass a 'schema' kwarg to specify target schema.
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/apache/pinot/hooks/pinot.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ def get_uri(self) -> str:
return f"{conn_type}://{host}/{endpoint}"

def get_records(
self, sql: str | list[str], parameters: Iterable | Mapping | None = None, **kwargs
self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
) -> Any:
"""
Executes the sql and returns a set of records.
Expand All @@ -301,7 +301,7 @@ def get_records(
cur.execute(sql)
return cur.fetchall()

def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
Expand Down
56 changes: 48 additions & 8 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@

from contextlib import closing
from datetime import datetime
from typing import TYPE_CHECKING, Any, Callable, Iterable, Mapping, Protocol, Sequence, cast
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Mapping,
Protocol,
Sequence,
TypeVar,
cast,
overload,
)
from urllib.parse import urlparse

import sqlparse
Expand All @@ -34,6 +45,9 @@
from airflow.providers.openlineage.sqlparser import DatabaseInfo


# Type variable for the value returned by the ``handler`` callable accepted by ``run``.
T = TypeVar("T")


def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool):
"""
Determines when only the results of a single query should be returned.
Expand Down Expand Up @@ -184,7 +198,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
engine_kwargs = {}
return create_engine(self.get_uri(), **engine_kwargs)

def get_pandas_df(self, sql, parameters=None, **kwargs):
def get_pandas_df(self, sql, parameters: Iterable | Mapping[str, Any] | None = None, **kwargs):
"""
Executes the sql and returns a pandas dataframe.
Expand All @@ -204,7 +218,9 @@ def get_pandas_df(self, sql, parameters=None, **kwargs):
with closing(self.get_conn()) as conn:
return psql.read_sql(sql, con=conn, params=parameters, **kwargs)

def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs):
def get_pandas_df_by_chunks(
self, sql, parameters: Iterable | Mapping[str, Any] | None = None, *, chunksize: int | None, **kwargs
):
"""
Executes the sql and returns a generator.
Expand All @@ -228,7 +244,7 @@ def get_pandas_df_by_chunks(self, sql, parameters=None, *, chunksize, **kwargs):
def get_records(
self,
sql: str | list[str],
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
) -> Any:
"""
Executes the sql and returns a set of records.
Expand All @@ -238,7 +254,7 @@ def get_records(
"""
return self.run(sql=sql, parameters=parameters, handler=fetch_all_handler)

def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
"""
Executes the sql and returns the first resulting row.
Expand Down Expand Up @@ -268,15 +284,39 @@ def last_description(self) -> Sequence[Sequence] | None:
return None
return self.descriptions[-1]

# Typing-only overload: with ``handler`` left as None, ``run`` is declared to return None.
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> None:
...

# Typing-only overload: with a ``handler`` callable that returns T, ``run`` is declared
# to return either a single T or a list of T.
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
...

def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping | None = None,
handler: Callable | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> Any | list[Any] | None:
) -> T | list[T] | None:
"""Run a command or a list of commands.
Pass a list of SQL statements to the sql parameter to get them to
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/common/sql/operators/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,7 @@ def __init__(
sql: str,
conn_id: str | None = None,
database: str | None = None,
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
Expand Down Expand Up @@ -1129,7 +1129,7 @@ def __init__(
follow_task_ids_if_false: list[str],
conn_id: str = "default_conn_id",
database: str | None = None,
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(conn_id=conn_id, database=database, **kwargs)
Expand Down
35 changes: 31 additions & 4 deletions airflow/providers/databricks/hooks/databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from contextlib import closing
from copy import copy
from typing import Any, Callable, Iterable, Mapping
from typing import Any, Callable, Iterable, Mapping, TypeVar, overload

from databricks import sql # type: ignore[attr-defined]
from databricks.sql.client import Connection # type: ignore[attr-defined]
Expand All @@ -30,6 +30,9 @@
LIST_SQL_ENDPOINTS_ENDPOINT = ("GET", "api/2.0/sql/endpoints")


# Type variable for the value returned by the ``handler`` callable accepted by ``run``.
T = TypeVar("T")


class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""Hook to interact with Databricks SQL.
Expand Down Expand Up @@ -138,15 +141,39 @@ def get_conn(self) -> Connection:
)
return self._sql_conn

# Typing-only overload: with ``handler`` left as None, ``run`` is declared to return None.
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> None:
...

# Typing-only overload: with a ``handler`` callable that returns T, ``run`` is declared
# to return either a single T or a list of T.
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
...

def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping | None = None,
handler: Callable | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
handler: Callable[[Any], T] | None = None,
split_statements: bool = True,
return_last: bool = True,
) -> Any | list[Any] | None:
) -> T | list[T] | None:
"""Runs a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
Expand Down
42 changes: 35 additions & 7 deletions airflow/providers/exasol/hooks/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@
from __future__ import annotations

from contextlib import closing
from typing import Any, Callable, Iterable, Mapping, Sequence
from typing import Any, Callable, Iterable, Mapping, Sequence, TypeVar, overload

import pandas as pd
import pyexasol
from pyexasol import ExaConnection, ExaStatement

from airflow.providers.common.sql.hooks.sql import DbApiHook, return_single_query_results

# Type variable for the value returned by the ``handler`` callable accepted by ``run``.
T = TypeVar("T")


class ExasolHook(DbApiHook):
"""Interact with Exasol.
Expand Down Expand Up @@ -66,7 +68,9 @@ def get_conn(self) -> ExaConnection:
conn = pyexasol.connect(**conn_args)
return conn

def get_pandas_df(self, sql: str, parameters: dict | None = None, **kwargs) -> pd.DataFrame:
def get_pandas_df(
self, sql, parameters: Iterable | Mapping[str, Any] | None = None, **kwargs
) -> pd.DataFrame:
"""Execute the SQL and return a Pandas dataframe.
:param sql: The sql statement to be executed (str) or a list of
Expand All @@ -83,7 +87,7 @@ def get_pandas_df(self, sql: str, parameters: dict | None = None, **kwargs) -> p
def get_records(
self,
sql: str | list[str],
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
) -> list[dict | tuple[Any, ...]]:
"""Execute the SQL and return a set of records.
Expand All @@ -95,7 +99,7 @@ def get_records(
with closing(conn.execute(sql, parameters)) as cur:
return cur.fetchall()

def get_first(self, sql: str | list[str], parameters: Iterable | Mapping | None = None) -> Any:
def get_first(self, sql: str | list[str], parameters: Iterable | Mapping[str, Any] | None = None) -> Any:
"""Execute the SQL and return the first resulting row.
:param sql: the sql statement to be executed (str) or a list of
Expand Down Expand Up @@ -157,15 +161,39 @@ def get_description(statement: ExaStatement) -> Sequence[Sequence]:
)
return cols

# Typing-only overload: with ``handler`` left as None, ``run`` is declared to return None.
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: None = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> None:
...

# Typing-only overload: with a ``handler`` callable that returns T, ``run`` is declared
# to return either a single T or a list of T.
@overload
def run(
self,
sql: str | Iterable[str],
autocommit: bool = ...,
parameters: Iterable | Mapping[str, Any] | None = ...,
handler: Callable[[Any], T] = ...,
split_statements: bool = ...,
return_last: bool = ...,
) -> T | list[T]:
...

def run(
self,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping | None = None,
handler: Callable | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
handler: Callable[[Any], T] | None = None,
split_statements: bool = False,
return_last: bool = True,
) -> Any | list[Any] | None:
) -> T | list[T] | None:
"""Run a command or a list of commands.
Pass a list of SQL statements to the SQL parameter to get them to
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def insert_rows(
def get_pandas_df(
self,
sql: str,
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
dialect: str | None = None,
**kwargs,
) -> DataFrame:
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/operators/cloud_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""This module contains Google Cloud SQL operators."""
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence

from googleapiclient.errors import HttpError

Expand Down Expand Up @@ -1189,7 +1189,7 @@ def __init__(
*,
sql: str | Iterable[str],
autocommit: bool = False,
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
gcp_conn_id: str = "google_cloud_default",
gcp_cloudsql_conn_id: str = "google_cloud_sql_default",
sql_proxy_binary_path: str | None = None,
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/google/suite/transfers/sql_to_sheets.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
sql: str,
spreadsheet_id: str,
sql_conn_id: str,
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
database: str | None = None,
spreadsheet_range: str = "Sheet1",
gcp_conn_id: str = "google_cloud_default",
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/neo4j/operators/neo4j.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence

from airflow.models import BaseOperator
from airflow.providers.neo4j.hooks.neo4j import Neo4jHook
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
*,
sql: str,
neo4j_conn_id: str = "neo4j_default",
parameters: Iterable | Mapping | None = None,
parameters: Iterable | Mapping[str, Any] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand Down

0 comments on commit 60c49ab

Please sign in to comment.