Skip to content

Commit

Permalink
add description method in BigQueryCursor class (#25366)
Browse files: browse the repository at this point in the history
  • Loading branch information
sophiely committed Aug 4, 2022
1 parent e84d753 commit 7d2c2ee
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 17 deletions.
62 changes: 48 additions & 14 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2663,11 +2663,16 @@ def __init__(
self.job_id = None # type: Optional[str]
self.buffer = [] # type: list
self.all_pages_loaded = False # type: bool
self._description = [] # type: List

@property
def description(self) -> None:
"""The schema description method is not currently implemented"""
raise NotImplementedError
def description(self) -> List:
"""Return the cursor description"""
return self._description

@description.setter
def description(self, value):
self._description = value

def close(self) -> None:
    """Close the cursor. This implementation is a no-op: the body contains no statements."""
Expand All @@ -2688,6 +2693,10 @@ def execute(self, operation: str, parameters: Optional[dict] = None) -> None:
self.flush_results()
self.job_id = self.hook.run_query(sql)

query_results = self._get_query_result()
description = _format_schema_for_description(query_results["schema"])
self.description = description

def executemany(self, operation: str, seq_of_parameters: list) -> None:
"""
Execute a BigQuery query multiple times with different parameters.
Expand Down Expand Up @@ -2723,17 +2732,7 @@ def next(self) -> Union[List, None]:
if self.all_pages_loaded:
return None

query_results = (
self.service.jobs()
.getQueryResults(
projectId=self.project_id,
jobId=self.job_id,
location=self.location,
pageToken=self.page_token,
)
.execute(num_retries=self.num_retries)
)

query_results = self._get_query_result()
if 'rows' in query_results and query_results['rows']:
self.page_token = query_results.get('pageToken')
fields = query_results['schema']['fields']
Expand Down Expand Up @@ -2805,6 +2804,21 @@ def setinputsizes(self, sizes: Any) -> None:
def setoutputsize(self, size: Any, column: Any = None) -> None:
    """No-op by default: ``size`` and ``column`` are accepted but not used."""

def _get_query_result(self) -> Dict:
"""Get job query results like data, schema, job type..."""
query_results = (
self.service.jobs()
.getQueryResults(
projectId=self.project_id,
jobId=self.job_id,
location=self.location,
pageToken=self.page_token,
)
.execute(num_retries=self.num_retries)
)

return query_results


def _bind_parameters(operation: str, parameters: dict) -> str:
"""Helper method that binds parameters to a SQL query"""
Expand Down Expand Up @@ -2973,3 +2987,23 @@ def _validate_src_fmt_configs(
raise ValueError(f"{k} is not a valid src_fmt_configs for type {source_format}.")

return src_fmt_configs


def _format_schema_for_description(schema: Dict) -> List:
"""
Reformat the schema to match cursor description standard which is a tuple
of 7 elemenbts (name, type, display_size, internal_size, precision, scale, null_ok)
"""
description = []
for field in schema["fields"]:
field_description = (
field["name"],
field["type"],
None,
None,
None,
None,
field["mode"] == "NULLABLE",
)
description.append(field_description)
return description
29 changes: 26 additions & 3 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
BigQueryHook,
_api_resource_configs_duplication_check,
_cleanse_time_partitioning,
_format_schema_for_description,
_validate_src_fmt_configs,
_validate_value,
split_tablename,
Expand Down Expand Up @@ -1239,11 +1240,33 @@ def test_execute_many(self, mock_insert, _):
]
)

def test_format_schema_for_description(self):
    """A single NULLABLE STRING field maps to one 7-item description tuple."""
    schema = {
        "fields": [
            {"name": "field_1", "type": "STRING", "mode": "NULLABLE"},
        ]
    }
    expected = [('field_1', 'STRING', None, None, None, None, True)]
    assert _format_schema_for_description(schema) == expected

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_description(self, mock_get_service):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_description(self, mock_insert, mock_get_service):
mock_get_query_results = mock_get_service.return_value.jobs.return_value.getQueryResults
mock_execute = mock_get_query_results.return_value.execute
mock_execute.return_value = {
"schema": {
"fields": [
{"name": "ts", "type": "TIMESTAMP", "mode": "NULLABLE"},
]
},
}

bq_cursor = self.hook.get_cursor()
with pytest.raises(NotImplementedError):
bq_cursor.description
bq_cursor.execute("SELECT CURRENT_TIMESTAMP() as ts")
assert bq_cursor.description == [("ts", "TIMESTAMP", None, None, None, None, True)]

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_close(self, mock_get_service):
Expand Down

0 comments on commit 7d2c2ee

Please sign in to comment.