Skip to content

Commit

Permalink
DbApiHook: Support kwargs in get_pandas_df (#9730)
Browse files Browse the repository at this point in the history
* DbApiHook: Support kwargs in get_pandas_df
* BigQueryHook: Support kwargs in get_pandas_df
* ExasolHook: Support kwargs in get_pandas_df
* PrestoHook: Support kwargs in get_pandas_df
* HiveServer2Hook: Support kwargs in get_pandas_df
  • Loading branch information
22quinn committed Aug 12, 2020
1 parent f618cdd commit 8f8db89
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 12 deletions.
6 changes: 4 additions & 2 deletions airflow/hooks/dbapi_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_sqlalchemy_engine(self, engine_kwargs=None):
engine_kwargs = {}
return create_engine(self.get_uri(), **engine_kwargs)

def get_pandas_df(self, sql, parameters=None, **kwargs):
    """
    Executes the sql and returns a pandas dataframe.

    :param sql: the sql statement to be executed (str) or a list of
        sql statements to execute
    :type sql: str or list
    :param parameters: The parameters to render the SQL query with.
    :type parameters: dict or iterable
    :param kwargs: (optional) passed into pandas.io.sql.read_sql method
    :type kwargs: dict
    """
    # Imported lazily so importing the hook does not require pandas.
    from pandas.io import sql as psql

    # closing() guarantees the connection is released even if read_sql raises.
    with closing(self.get_conn()) as conn:
        return psql.read_sql(sql, con=conn, params=parameters, **kwargs)

def get_records(self, sql, parameters=None):
"""
Expand Down
9 changes: 6 additions & 3 deletions airflow/providers/apache/hive/hooks/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,9 +1044,10 @@ def get_records(self, hql: Union[str, Text],
"""
return self.get_results(hql, schema=schema, hive_conf=hive_conf)['data']

def get_pandas_df(self, hql: Union[str, Text],  # type: ignore
                  schema: str = 'default',
                  hive_conf: Optional[Dict[Any, Any]] = None,
                  **kwargs
                  ) -> pandas.DataFrame:
    """
    Get a pandas dataframe from a Hive query.

    :param hql: hql to be executed.
    :type hql: str or list
    :param schema: target schema, default to 'default'.
    :type schema: str
    :param hive_conf: hive_conf to execute along with the hql.
    :type hive_conf: dict
    :param kwargs: (optional) passed into pandas.DataFrame constructor
    :type kwargs: dict
    :return: result of hive execution
    :rtype: pandas.DataFrame
    """
    res = self.get_results(hql, schema=schema, hive_conf=hive_conf)
    # Rows and column names come back separately; attach the header here.
    df = pandas.DataFrame(res['data'], **kwargs)
    df.columns = [c[0] for c in res['header']]
    return df
6 changes: 4 additions & 2 deletions airflow/providers/exasol/hooks/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def get_conn(self):
conn = pyexasol.connect(**conn_args)
return conn

def get_pandas_df(self, sql, parameters=None, **kwargs):
    """
    Executes the sql and returns a pandas dataframe.

    :param sql: the sql statement to be executed (str) or a list of
        sql statements to execute
    :type sql: str or list
    :param parameters: The parameters to render the SQL query with.
    :type parameters: dict or iterable
    :param kwargs: (optional) passed into pyexasol.ExaConnection.export_to_pandas method
    :type kwargs: dict
    :return: the query result as a pandas DataFrame
    """
    with closing(self.get_conn()) as conn:
        # BUG FIX: the result was previously discarded, so this method
        # always returned None; capture and return the dataframe.
        df = conn.export_to_pandas(sql, query_params=parameters, **kwargs)
    return df

def get_records(self, sql, parameters=None):
"""
Expand Down
8 changes: 6 additions & 2 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ def insert_rows(
raise NotImplementedError()

def get_pandas_df(
    self, sql: str, parameters: Optional[Union[Iterable, Mapping]] = None, dialect: Optional[str] = None,
    **kwargs
) -> DataFrame:
    """
    Returns a Pandas DataFrame for the results produced by a BigQuery
    query.

    :param sql: The BigQuery SQL to execute.
    :type sql: str
    :param parameters: The parameters to render the SQL query with
        (accepted to match the DbApiHook interface; presumably unused
        here — verify against the base class).
    :type parameters: mapping or iterable
    :param dialect: Dialect of BigQuery SQL – legacy SQL or standard SQL
        defaults to use `self.use_legacy_sql` if not specified
    :type dialect: str in {'legacy', 'standard'}
    :param kwargs: (optional) passed into pandas_gbq.read_gbq method
    :type kwargs: dict
    """
    if dialect is None:
        dialect = 'legacy' if self.use_legacy_sql else 'standard'

    # NOTE(review): the credentials/project lookup was elided in the diff
    # view; reconstructed from the visible call-site arguments — confirm
    # the helper name matches the surrounding file.
    credentials, project_id = self._get_credentials_and_project_id()

    return read_gbq(sql,
                    project_id=project_id,
                    dialect=dialect,
                    verbose=False,
                    credentials=credentials,
                    **kwargs)

@GoogleBaseHook.fallback_to_default_project_id
def table_exists(self, dataset_id: str, table_id: str, project_id: str) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def get_first(self, hql, parameters=None):
except DatabaseError as e:
raise PrestoException(e)

def get_pandas_df(self, hql, parameters=None, **kwargs):
    """
    Get a pandas dataframe from a sql query.

    :param hql: the SQL to run against Presto
    :param parameters: the parameters to render the SQL query with
    :param kwargs: (optional) passed into the pandas.DataFrame constructor
    """
    # NOTE(review): the cursor/execute lines were elided in the diff view;
    # reconstructed to mirror the sibling error handling visible in
    # get_first — confirm against the full file.
    cursor = self.get_cursor()
    try:
        cursor.execute(self._strip_sql(hql), parameters)
        data = cursor.fetchall()
    except DatabaseError as e:
        raise PrestoException(e)
    column_descriptions = cursor.description
    if data:
        df = pandas.DataFrame(data, **kwargs)
        df.columns = [c[0] for c in column_descriptions]
    else:
        # No rows returned: produce an empty frame (kwargs still honored).
        df = pandas.DataFrame(**kwargs)
    return df

def run(self, hql, parameters=None):
Expand Down

0 comments on commit 8f8db89

Please sign in to comment.