[AIRFLOW-6405] Add GCP BigQuery Table Upsert Operator (#7126)
* [AIRFLOW-6405] Add GCP BigQuery Table Property Upsert Operator

* [AIRFLOW-6405] Remove unnecessary checks from BQ hook
jithin-sukumar committed Feb 16, 2020
1 parent cc1bd64 commit 946bdc2
Showing 4 changed files with 205 additions and 12 deletions.
12 changes: 11 additions & 1 deletion airflow/providers/google/cloud/example_dags/example_bigquery.py
@@ -28,7 +28,7 @@
BigQueryCreateEmptyDatasetOperator, BigQueryCreateEmptyTableOperator, BigQueryCreateExternalTableOperator,
BigQueryDeleteDatasetOperator, BigQueryDeleteTableOperator, BigQueryExecuteQueryOperator,
BigQueryGetDataOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator,
BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator,
BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator, BigQueryUpsertTableOperator,
)
from airflow.providers.google.cloud.operators.bigquery_to_bigquery import BigQueryToBigQueryOperator
from airflow.providers.google.cloud.operators.bigquery_to_gcs import BigQueryToGCSOperator
@@ -253,6 +253,15 @@
delete_contents=True
)

update_table = BigQueryUpsertTableOperator(
task_id="update_table", dataset_id=DATASET_NAME, table_resource={
"tableReference": {
"tableId": "test-table-id"
},
"expirationTime": 12345678
}
)

create_dataset >> execute_query_save >> delete_dataset
create_dataset >> get_empty_dataset_tables >> create_table >> get_dataset_tables >> delete_dataset
create_dataset >> get_dataset >> delete_dataset
@@ -264,3 +273,4 @@
execute_query_external_table >> bigquery_to_gcs >> delete_dataset
create_table >> create_view >> delete_view >> delete_table >> delete_dataset
create_dataset_with_location >> create_table_with_location >> delete_dataset_with_location
create_dataset >> create_table >> update_table >> delete_table >> delete_dataset
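For quick reference, a minimal self-contained DAG sketch using the new operator follows. The DAG id, start date, and DATASET_NAME value are illustrative assumptions and not part of this commit; only the operator call mirrors the example above.

from airflow import models
from airflow.providers.google.cloud.operators.bigquery import BigQueryUpsertTableOperator
from airflow.utils.dates import days_ago

DATASET_NAME = "test_dataset"  # assumed example value

with models.DAG(
    dag_id="example_bigquery_upsert",  # assumed DAG id
    start_date=days_ago(1),
    schedule_interval=None,
) as dag:
    # Creates the table when it is missing, otherwise updates the existing one.
    update_table = BigQueryUpsertTableOperator(
        task_id="update_table",
        dataset_id=DATASET_NAME,
        table_resource={
            "tableReference": {"tableId": "test-table-id"},
            "expirationTime": 12345678,
        },
    )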
75 changes: 70 additions & 5 deletions airflow/providers/google/cloud/operators/bigquery.py
@@ -79,8 +79,8 @@ class BigQueryCheckOperator(CheckOperator):
:type use_legacy_sql: bool
"""

template_fields = ('sql', 'gcp_conn_id', )
template_ext = ('.sql', )
template_fields = ('sql', 'gcp_conn_id',)
template_ext = ('.sql',)

@apply_defaults
def __init__(self,
@@ -121,8 +121,8 @@ class BigQueryValueCheckOperator(ValueCheckOperator):
:type bigquery_conn_id: str
"""

template_fields = ('sql', 'gcp_conn_id', 'pass_value', )
template_ext = ('.sql', )
template_fields = ('sql', 'gcp_conn_id', 'pass_value',)
template_ext = ('.sql',)

@apply_defaults
def __init__(self, sql: str,
@@ -179,7 +179,7 @@ class BigQueryIntervalCheckOperator(IntervalCheckOperator):
:type bigquery_conn_id: str
"""

template_fields = ('table', 'gcp_conn_id', )
template_fields = ('table', 'gcp_conn_id',)

@apply_defaults
def __init__(self,
@@ -1387,3 +1387,68 @@ def execute(self, context):
hook.run_table_delete(
deletion_dataset_table=self.deletion_dataset_table,
ignore_if_missing=self.ignore_if_missing)


class BigQueryUpsertTableOperator(BaseOperator):
"""
Upsert a BigQuery table: if the table already exists in the dataset it is updated,
otherwise it is created.

:param dataset_id: A dotted
``(<project>.|<project>:)<dataset>`` that indicates the dataset
the table will be upserted into. (templated)
:type dataset_id: str
:param table_resource: a table resource. see
https://cloud.google.com/bigquery/docs/reference/v2/tables#resource
:type table_resource: dict
:param project_id: The name of the project where the table will be upserted.
If not provided, the project ID defined in the GCP connection is used.
:type project_id: str
:param gcp_conn_id: (Optional) The connection ID used to connect to Google Cloud Platform.
:type gcp_conn_id: str
:param bigquery_conn_id: (Deprecated) The connection ID used to connect to Google Cloud Platform.
This parameter has been deprecated. You should pass the gcp_conn_id parameter instead.
:type bigquery_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have domain-wide
delegation enabled.
:type delegate_to: str
:param location: The location used for the operation.
:type location: str
"""
template_fields = ('dataset_id', 'table_resource',)

@apply_defaults
def __init__(self,
dataset_id: str,
table_resource: dict,
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
bigquery_conn_id: Optional[str] = None,
delegate_to: Optional[str] = None,
location: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__(*args, **kwargs)

if bigquery_conn_id:
warnings.warn(
"The bigquery_conn_id parameter has been deprecated. You should pass "
"the gcp_conn_id parameter.", DeprecationWarning, stacklevel=3)
gcp_conn_id = bigquery_conn_id

self.dataset_id = dataset_id
self.table_resource = table_resource
self.project_id = project_id
self.gcp_conn_id = gcp_conn_id
self.delegate_to = delegate_to
self.location = location

def execute(self, context):
self.log.info('Upserting Dataset: %s with table_resource: %s', self.dataset_id, self.table_resource)
hook = BigQueryHook(bigquery_conn_id=self.gcp_conn_id,
delegate_to=self.delegate_to,
location=self.location)
hook.run_table_upsert(
dataset_id=self.dataset_id,
table_resource=self.table_resource,
project_id=self.project_id)
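For context, a rough sketch of the hook-side flow the operator delegates to. This is inferred from the new hook tests below, not taken from the hook source: list the dataset's tables, update when the tableId is already present, insert otherwise. The real run_table_upsert may handle pagination and other details differently.

def run_table_upsert_sketch(service, dataset_id, table_resource, project_id):
    # Illustrative only; mirrors what the tests assert rather than the hook itself.
    table_id = table_resource["tableReference"]["tableId"]
    tables = service.tables().list(
        projectId=project_id, datasetId=dataset_id
    ).execute().get("tables", [])
    if any(t["tableReference"]["tableId"] == table_id for t in tables):
        # Table already exists in the dataset -> update it in place.
        return service.tables().update(
            projectId=project_id, datasetId=dataset_id,
            tableId=table_id, body=table_resource,
        ).execute()
    # Table not found -> create it.
    return service.tables().insert(
        projectId=project_id, datasetId=dataset_id, body=table_resource,
    ).execute()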
102 changes: 97 additions & 5 deletions tests/providers/google/cloud/hooks/test_bigquery.py
@@ -1326,6 +1326,94 @@ def test_get_tables_list(self, mock_get_service, mock_get_creds_and_proj_id):
)
self.assertEqual(result, expected_result)

@mock.patch(
'airflow.providers.google.cloud.hooks.base.CloudBaseHook._get_credentials_and_project_id',
return_value=(CREDENTIALS, PROJECT_ID)
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_table_upsert_on_insert(self, mock_get_service, mock_get_creds_and_proj_id):
table_resource = {
"tableReference": {
"tableId": "test-table-id"
},
"expirationTime": 123456
}
mock_service = mock_get_service.return_value
method = mock_service.tables.return_value.insert
bq_hook = hook.BigQueryHook()
bq_hook.run_table_upsert(
dataset_id=DATASET_ID,
table_resource=table_resource,
project_id=PROJECT_ID
)

method.assert_called_once_with(
body=table_resource,
datasetId=DATASET_ID,
projectId=PROJECT_ID
)

@mock.patch(
'airflow.providers.google.cloud.hooks.base.CloudBaseHook._get_credentials_and_project_id',
return_value=(CREDENTIALS, PROJECT_ID),
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
def test_table_upsert_on_update(self, mock_get_service, mock_get_creds_and_proj_id):
table_resource = {
"tableReference": {
"tableId": "table1"
},
"expirationTime": 123456
}
table1 = "table1"
table2 = "table2"
expected_tables_list = {'tables': [
{
"creationTime": "12345678",
"kind": "bigquery#table",
"type": "TABLE",
"id": "{project}:{dataset}.{table}".format(
project=PROJECT_ID,
dataset=DATASET_ID,
table=table1),
"tableReference": {
"projectId": PROJECT_ID,
"tableId": table1,
"datasetId": DATASET_ID
}
},
{
"creationTime": "12345678",
"kind": "bigquery#table",
"type": "TABLE",
"id": "{project}:{dataset}.{table}".format(
project=PROJECT_ID,
dataset=DATASET_ID,
table=table2),
"tableReference": {
"projectId": PROJECT_ID,
"tableId": table2,
"datasetId": DATASET_ID
}
}
]}
mock_service = mock_get_service.return_value
mock_service.tables.return_value.list.return_value.execute.return_value = expected_tables_list
method = mock_service.tables.return_value.update
bq_hook = hook.BigQueryHook()
bq_hook.run_table_upsert(
dataset_id=DATASET_ID,
table_resource=table_resource,
project_id=PROJECT_ID
)

method.assert_called_once_with(
body=table_resource,
datasetId=DATASET_ID,
projectId=PROJECT_ID,
tableId=table1
)
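The two tests above exercise both branches of the upsert. For completeness, a hedged sketch of calling the hook directly; the connection ID, dataset, and project values are assumed examples:

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook

hook = BigQueryHook(bigquery_conn_id="google_cloud_default")  # assumed connection id
hook.run_table_upsert(
    dataset_id="test_dataset",  # assumed dataset
    table_resource={
        "tableReference": {"tableId": "test-table-id"},
        "expirationTime": 123456,
    },
    project_id="example-project",  # assumed project id
)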


class TestBigQueryCursor(unittest.TestCase):

@@ -1611,11 +1699,11 @@ class TestLabelsInRunJob(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration")
def test_run_query_with_arg(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):

def run_with_config(config):
self.assertEqual(
config['labels'], {'label1': 'test1', 'label2': 'test2'}
)

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -1892,6 +1980,7 @@ class TestTimePartitioningInRunJob(unittest.TestCase):
def test_run_load_default(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):
def run_with_config(config):
self.assertIsNone(config['load'].get('timePartitioning'))

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -1933,6 +2022,7 @@ def run_with_config(config):
'expirationMs': 1000
}
)

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -1954,6 +2044,7 @@ def run_with_config(config):
def test_run_query_default(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):
def run_with_config(config):
self.assertIsNone(config['query'].get('timePartitioning'))

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -1977,6 +2068,7 @@ def run_with_config(config):
'expirationMs': 1000
}
)

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -2019,9 +2111,9 @@ class TestClusteringInRunJob(unittest.TestCase):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration")
def test_run_load_default(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):

def run_with_config(config):
self.assertIsNone(config['load'].get('clustering'))

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -2040,14 +2132,14 @@ def run_with_config(config):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration")
def test_run_load_with_arg(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):

def run_with_config(config):
self.assertEqual(
config['load']['clustering'],
{
'fields': ['field1', 'field2']
}
)

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -2068,9 +2160,9 @@ def run_with_config(config):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration")
def test_run_query_default(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):

def run_with_config(config):
self.assertIsNone(config['query'].get('clustering'))

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
@@ -2085,14 +2177,14 @@ def run_with_config(config):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_service")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_with_configuration")
def test_run_query_with_arg(self, mocked_rwc, mock_get_service, mock_get_creds_and_proj_id):

def run_with_config(config):
self.assertEqual(
config['query']['clustering'],
{
'fields': ['field1', 'field2']
}
)

mocked_rwc.side_effect = run_with_config

bq_hook = hook.BigQueryHook()
28 changes: 27 additions & 1 deletion tests/providers/google/cloud/operators/test_bigquery.py
@@ -32,7 +32,7 @@
BigQueryDeleteDatasetOperator, BigQueryDeleteTableOperator, BigQueryExecuteQueryOperator,
BigQueryGetDataOperator, BigQueryGetDatasetOperator, BigQueryGetDatasetTablesOperator,
BigQueryIntervalCheckOperator, BigQueryPatchDatasetOperator, BigQueryUpdateDatasetOperator,
BigQueryValueCheckOperator,
BigQueryUpsertTableOperator, BigQueryValueCheckOperator,
)
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.settings import Session
@@ -49,6 +49,12 @@
TEST_SOURCE_FORMAT = 'CSV'
DEFAULT_DATE = datetime(2015, 1, 1)
TEST_DAG_ID = 'test-bigquery-operators'
TEST_TABLE_RESOURCES = {
"tableReference": {
"tableId": TEST_TABLE_ID
},
"expirationTime": 1234567
}
VIEW_DEFINITION = {
"query": "SELECT * FROM `{}.{}`".format(TEST_DATASET, TEST_TABLE_ID),
"useLegacySql": False
@@ -757,3 +763,23 @@ def test_bigquery_conn_id_deprecation_warning(self, operator_class, kwargs):
**kwargs
)
self.assertEqual(bigquery_conn_id, operator.gcp_conn_id)


class TestBigQueryUpsertTableOperator(unittest.TestCase):
@mock.patch('airflow.providers.google.cloud.operators.bigquery.BigQueryHook')
def test_execute(self, mock_hook):
operator = BigQueryUpsertTableOperator(
task_id=TASK_ID,
dataset_id=TEST_DATASET,
table_resource=TEST_TABLE_RESOURCES,
project_id=TEST_GCP_PROJECT_ID,
)

operator.execute(None)
mock_hook.return_value \
.run_table_upsert \
.assert_called_once_with(
dataset_id=TEST_DATASET,
project_id=TEST_GCP_PROJECT_ID,
table_resource=TEST_TABLE_RESOURCES
)
