Unify DbApiHook.run() method with the methods which override it (#23971)
kazanzhy committed Jul 22, 2022
1 parent 31705ed commit df00436
Showing 33 changed files with 307 additions and 264 deletions.
2 changes: 1 addition & 1 deletion airflow/operators/sql.py
@@ -496,7 +496,7 @@ def __init__(
         follow_task_ids_if_false: List[str],
         conn_id: str = "default_conn_id",
         database: Optional[str] = None,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         **kwargs,
     ) -> None:
         super().__init__(conn_id=conn_id, database=database, **kwargs)
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/operators/redshift_sql.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.

-from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union

 from airflow.models import BaseOperator
 from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
@@ -55,7 +55,7 @@ def __init__(
         *,
         sql: Union[str, Iterable[str]],
         redshift_conn_id: str = 'redshift_default',
-        parameters: Optional[dict] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         autocommit: bool = True,
         **kwargs,
     ) -> None:
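
The widened annotation follows the DB-API convention that a sequence binds positional placeholders while a mapping binds named ones. A minimal sketch of the two styles, with the stdlib sqlite3 driver standing in for the Redshift connection (table and values are illustrative):

    # A minimal sketch of the two parameter styles the widened annotation
    # allows. sqlite3 stands in for the Redshift driver here.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE t (id INTEGER, name TEXT)")

    # Iterable: values bind to positional placeholders, in order.
    conn.execute("INSERT INTO t VALUES (?, ?)", (1, "a"))

    # Mapping: values bind to named placeholders, by key.
    conn.execute("INSERT INTO t VALUES (:id, :name)", {"id": 2, "name": "b"})

    print(conn.execute("SELECT * FROM t").fetchall())  # [(1, 'a'), (2, 'b')]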
2 changes: 1 addition & 1 deletion airflow/providers/amazon/aws/transfers/redshift_to_s3.py
@@ -93,7 +93,7 @@ def __init__(
         unload_options: Optional[List] = None,
         autocommit: bool = False,
         include_header: bool = False,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         table_as_file_name: bool = True,  # Set to True by default for not breaking current workflows
         **kwargs,
     ) -> None:
4 changes: 2 additions & 2 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -16,7 +16,7 @@
 # under the License.

 import warnings
-from typing import TYPE_CHECKING, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Union

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -140,7 +140,7 @@ def execute(self, context: 'Context') -> None:

         copy_statement = self._build_copy_query(copy_destination, credentials_block, copy_options)

-        sql: Union[list, str]
+        sql: Union[str, Iterable[str]]

         if self.method == 'REPLACE':
             sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
8 changes: 2 additions & 6 deletions airflow/providers/apache/drill/operators/drill.py
@@ -17,8 +17,6 @@
 # under the License.
 from typing import TYPE_CHECKING, Iterable, Mapping, Optional, Sequence, Union

-import sqlparse
-
 from airflow.models import BaseOperator
 from airflow.providers.apache.drill.hooks.drill import DrillHook

@@ -52,7 +50,7 @@ def __init__(
         *,
         sql: str,
         drill_conn_id: str = 'drill_default',
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -64,6 +62,4 @@ def __init__(
     def execute(self, context: 'Context'):
         self.log.info('Executing: %s on %s', self.sql, self.drill_conn_id)
         self.hook = DrillHook(drill_conn_id=self.drill_conn_id)
-        sql = sqlparse.split(sqlparse.format(self.sql, strip_comments=True))
-        no_term_sql = [s[:-1] for s in sql if s[-1] == ';']
-        self.hook.run(no_term_sql, parameters=self.parameters)
+        self.hook.run(self.sql, parameters=self.parameters, split_statements=True)
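
The operator previously stripped comments, split statements, and trimmed semicolons itself; that logic now lives behind run(..., split_statements=True). A minimal sketch of what the splitting does, assuming sqlparse is installed (the statements are illustrative):

    # A minimal sketch of the statement splitting the operator used to do
    # inline and now delegates to the hook. Comments are stripped first, then
    # statements are split and trailing semicolons removed.
    import sqlparse

    sql = """
    -- two illustrative statements
    CREATE TABLE a AS SELECT 1;
    CREATE TABLE b AS SELECT 2;
    """

    splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
    statements = [s.rstrip(';') for s in splits if s.endswith(';')]
    print(statements)  # ['CREATE TABLE a AS SELECT 1', 'CREATE TABLE b AS SELECT 2']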
1 change: 0 additions & 1 deletion airflow/providers/apache/drill/provider.yaml
@@ -34,7 +34,6 @@ dependencies:
   - apache-airflow>=2.2.0
   - apache-airflow-providers-common-sql
   - sqlalchemy-drill>=1.1.0
-  - sqlparse>=0.4.1

 integrations:
   - integration-name: Apache Drill
6 changes: 3 additions & 3 deletions airflow/providers/apache/pinot/hooks/pinot.py
@@ -18,7 +18,7 @@

 import os
 import subprocess
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Iterable, List, Mapping, Optional, Union

 from pinotdb import connect

@@ -275,7 +275,7 @@ def get_uri(self) -> str:
         endpoint = conn.extra_dejson.get('endpoint', 'query/sql')
         return f'{conn_type}://{host}/{endpoint}'

-    def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+    def get_records(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
         """
         Executes the sql and returns a set of records.
@@ -287,7 +287,7 @@ def get_records(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
             cur.execute(sql)
             return cur.fetchall()

-    def get_first(self, sql: str, parameters: Optional[Union[Dict[str, Any], Iterable[Any]]] = None) -> Any:
+    def get_first(self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None) -> Any:
         """
         Executes the sql and returns the first resulting row.
68 changes: 54 additions & 14 deletions airflow/providers/common/sql/hooks/sql.py
@@ -17,8 +17,9 @@
 import warnings
 from contextlib import closing
 from datetime import datetime
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union

+import sqlparse
 from sqlalchemy import create_engine
 from typing_extensions import Protocol

@@ -27,6 +28,17 @@
 from airflow.providers_manager import ProvidersManager
 from airflow.utils.module_loading import import_string

+if TYPE_CHECKING:
+    from sqlalchemy.engine import CursorResult
+
+
+def fetch_all_handler(cursor: 'CursorResult') -> Optional[List[Tuple]]:
+    """Handler for DbApiHook.run() to return results"""
+    if cursor.returns_rows:
+        return cursor.fetchall()
+    else:
+        return None
+

 def _backported_get_hook(connection, *, hook_params=None):
     """Return hook based on conn_type
@@ -201,7 +213,31 @@ def get_first(self, sql, parameters=None):
                 cur.execute(sql)
                 return cur.fetchone()

-    def run(self, sql, autocommit=False, parameters=None, handler=None):
+    @staticmethod
+    def strip_sql_string(sql: str) -> str:
+        return sql.strip().rstrip(';')
+
+    @staticmethod
+    def split_sql_string(sql: str) -> List[str]:
+        """
+        Splits string into multiple SQL expressions
+
+        :param sql: SQL string potentially consisting of multiple expressions
+        :return: list of individual expressions
+        """
+        splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
+        statements = [s.rstrip(';') for s in splits if s.endswith(';')]
+        return statements
+
+    def run(
+        self,
+        sql: Union[str, Iterable[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
+        handler: Optional[Callable] = None,
+        split_statements: bool = False,
+        return_last: bool = True,
+    ) -> Optional[Union[Any, List[Any]]]:
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -213,14 +249,19 @@ def run(self, sql, autocommit=False, parameters=None, handler=None):
         before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
-        :return: query results if handler was provided.
+        :param split_statements: Whether to split a single SQL string into statements and run separately
+        :param return_last: Whether to return result for only last statement or for all after split
+        :return: return only result of the ALL SQL expressions if handler was provided.
         """
-        scalar = isinstance(sql, str)
-        if scalar:
-            sql = [sql]
+        scalar_return_last = isinstance(sql, str) and return_last
+        if isinstance(sql, str):
+            if split_statements:
+                sql = self.split_sql_string(sql)
+            else:
+                sql = [self.strip_sql_string(sql)]

         if sql:
-            self.log.debug("Executing %d statements", len(sql))
+            self.log.debug("Executing following statements against DB: %s", list(sql))
         else:
             raise ValueError("List of SQL statements is empty")

@@ -232,22 +273,21 @@
             results = []
             for sql_statement in sql:
                 self._run_command(cur, sql_statement, parameters)
+
                 if handler is not None:
                     result = handler(cur)
                     results.append(result)

-            # If autocommit was set to False for db that supports autocommit,
-            # or if db does not supports autocommit, we do a manual commit.
+            # If autocommit was set to False or db does not support autocommit, we do a manual commit.
             if not self.get_autocommit(conn):
                 conn.commit()

         if handler is None:
             return None
-
-        if scalar:
-            return results[0]
-
-        return results
+        elif scalar_return_last:
+            return results[-1]
+        else:
+            return results

     def _run_command(self, cur, sql_statement, parameters):
         """Runs a statement using an already open cursor."""
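
Together, split_statements and return_last give the unified run() a small decision matrix: a scalar string is either split or merely stripped, and the handler decides whether anything comes back at all. A minimal usage sketch against the SQLite hook, assuming a configured Airflow installation (connection id and table name are illustrative):

    # A minimal usage sketch of the unified DbApiHook.run() contract, with
    # the SQLite hook as a stand-in; names here are illustrative.
    from airflow.providers.sqlite.hooks.sqlite import SqliteHook

    hook = SqliteHook()  # resolves the 'sqlite_default' connection

    # Without a handler, statements execute and run() returns None.
    hook.run(
        "CREATE TABLE IF NOT EXISTS t (x INT); INSERT INTO t VALUES (1);",
        split_statements=True,
    )

    # A scalar string with return_last=True (the default) yields one result.
    rows = hook.run("SELECT x FROM t", handler=lambda cur: cur.fetchall())

    # A list of statements with a handler yields one result per statement.
    both = hook.run(["SELECT 1", "SELECT 2"], handler=lambda cur: cur.fetchall())
    print(rows, both)  # [(1,)] [[(1,)], [(2,)]]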
3 changes: 2 additions & 1 deletion airflow/providers/common/sql/provider.yaml
@@ -24,7 +24,8 @@ description: |
 versions:
   - 1.0.0

-dependencies: []
+dependencies:
+  - sqlparse>=0.4.2

 additional-extras:
   - name: pandas
76 changes: 37 additions & 39 deletions airflow/providers/databricks/hooks/databricks_sql.py
@@ -15,10 +15,9 @@
 # specific language governing permissions and limitations
 # under the License.

-import re
 from contextlib import closing
 from copy import copy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union

 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]
@@ -139,19 +138,15 @@ def get_conn(self) -> Connection:
         )
         return self._sql_conn

-    @staticmethod
-    def maybe_split_sql_string(sql: str) -> List[str]:
-        """
-        Splits strings consisting of multiple SQL expressions into an
-        TODO: do we need something more sophisticated?
-
-        :param sql: SQL string potentially consisting of multiple expressions
-        :return: list of individual expressions
-        """
-        splits = [s.strip() for s in re.split(";\\s*\r?\n", sql) if s.strip() != ""]
-        return splits
-
-    def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
+    def run(
+        self,
+        sql: Union[str, Iterable[str]],
+        autocommit: bool = False,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
+        handler: Optional[Callable] = None,
+        split_statements: bool = True,
+        return_last: bool = True,
+    ) -> Optional[Union[Tuple[str, Any], List[Tuple[str, Any]]]]:
         """
         Runs a command or a list of commands. Pass a list of sql
         statements to the sql parameter to get them to execute
@@ -163,41 +158,44 @@ def run(self, sql: Union[str, List[str]], autocommit=True, parameters=None, handler=None):
         before executing the query.
         :param parameters: The parameters to render the SQL query with.
         :param handler: The result handler which is called with the result of each statement.
-        :return: query results.
+        :param split_statements: Whether to split a single SQL string into statements and run separately
+        :param return_last: Whether to return result for only last statement or for all after split
+        :return: return only result of the LAST SQL expression if handler was provided.
         """
+        scalar_return_last = isinstance(sql, str) and return_last
         if isinstance(sql, str):
-            sql = self.maybe_split_sql_string(sql)
+            if split_statements:
+                sql = self.split_sql_string(sql)
+            else:
+                sql = [self.strip_sql_string(sql)]

         if sql:
-            self.log.debug("Executing %d statements", len(sql))
+            self.log.debug("Executing following statements against Databricks DB: %s", list(sql))
         else:
             raise ValueError("List of SQL statements is empty")

-        conn = None
+        results = []
         for sql_statement in sql:
             # when using AAD tokens, it could expire if previous query run longer than token lifetime
-            conn = self.get_conn()
-            with closing(conn.cursor()) as cur:
-                self.log.info("Executing statement: '%s', parameters: '%s'", sql_statement, parameters)
-                if parameters:
-                    cur.execute(sql_statement, parameters)
-                else:
-                    cur.execute(sql_statement)
-                schema = cur.description
-                results = []
-                if handler is not None:
-                    cur = handler(cur)
-                for row in cur:
-                    self.log.debug("Statement results: %s", row)
-                    results.append(row)
-
-                self.log.info("Rows affected: %s", cur.rowcount)
-        if conn:
-            conn.close()
+            with closing(self.get_conn()) as conn:
+                self.set_autocommit(conn, autocommit)
+
+                with closing(conn.cursor()) as cur:
+                    self._run_command(cur, sql_statement, parameters)
+
+                    if handler is not None:
+                        result = handler(cur)
+                        schema = cur.description
+                        results.append((schema, result))
+
         self._sql_conn = None

-        # Return only result of the last SQL expression
-        return schema, results
+        if handler is None:
+            return None
+        elif scalar_return_last:
+            return results[-1]
+        else:
+            return results

     def test_connection(self):
         """Test the Databricks SQL connection by running a simple query."""
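
Under the reworked contract each handler result is paired with the cursor's schema, so callers unpack (schema, rows) tuples instead of relying on the old schema-plus-results return. A minimal sketch mirroring the operator's use of fetch_all_handler below, assuming a reachable Databricks SQL endpoint (connection id and queries are illustrative):

    # A minimal sketch of the reworked DatabricksSqlHook.run() return shape;
    # needs a configured Databricks SQL endpoint.
    from airflow.providers.common.sql.hooks.sql import fetch_all_handler
    from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

    hook = DatabricksSqlHook(databricks_conn_id="databricks_default")

    # split_statements defaults to True here; return_last=True (also the
    # default) keeps only the last statement's (schema, rows) pair.
    schema, rows = hook.run("SELECT 1 AS a; SELECT 2 AS b", handler=fetch_all_handler)
    print([col[0] for col in schema])  # ['b'], plus the fetched rows in `rows`

    # return_last=False returns one (schema, rows) pair per statement instead.
    pairs = hook.run("SELECT 1; SELECT 2", handler=fetch_all_handler, return_last=False)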
12 changes: 7 additions & 5 deletions airflow/providers/databricks/operators/databricks_sql.py
@@ -20,12 +20,13 @@

 import csv
 import json
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union, cast

 from databricks.sql.utils import ParamEscaper

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
 from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

 if TYPE_CHECKING:
@@ -71,11 +72,11 @@ class DatabricksSqlOperator(BaseOperator):
     def __init__(
         self,
         *,
-        sql: Union[str, List[str]],
+        sql: Union[str, Iterable[str]],
         databricks_conn_id: str = DatabricksSqlHook.default_conn_name,
         http_path: Optional[str] = None,
         sql_endpoint_name: Optional[str] = None,
-        parameters: Optional[Union[Mapping, Iterable]] = None,
+        parameters: Optional[Union[Iterable, Mapping]] = None,
         session_configuration=None,
         http_headers: Optional[List[Tuple[str, str]]] = None,
         catalog: Optional[str] = None,
@@ -147,10 +148,11 @@ def _format_output(self, schema, results):
         else:
             raise AirflowException(f"Unsupported output format: '{self._output_format}'")

-    def execute(self, context: 'Context') -> Any:
+    def execute(self, context: 'Context'):
         self.log.info('Executing: %s', self.sql)
         hook = self._get_hook()
-        schema, results = hook.run(self.sql, parameters=self.parameters)
+        response = hook.run(self.sql, parameters=self.parameters, handler=fetch_all_handler)
+        schema, results = cast(List[Tuple[Any, Any]], response)[0]
         # self.log.info('Schema: %s', schema)
         # self.log.info('Results: %s', results)
         self._format_output(schema, results)
