Skip to content

Commit

Permalink
Fix and improve GCP BigTable hook and system test (#13896)
Browse files (browse the repository at this point in the history)
Rename the environment variables read by the GCP Bigtable example DAG and system test to a consistent GCP_BIG_TABLE_* prefix.
This makes the system tests easier to parametrize.
  • Loading branch information
Tobiasz Kędzierski committed Jan 27, 2021
1 parent 6616617 commit 810c15e
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 25 deletions.
32 changes: 16 additions & 16 deletions airflow/providers/google/cloud/example_dags/example_bigtable.py
Expand Up @@ -60,22 +60,22 @@
from airflow.utils.dates import days_ago

GCP_PROJECT_ID = getenv('GCP_PROJECT_ID', 'example-project')
CBT_INSTANCE_ID = getenv('CBT_INSTANCE_ID', 'some-instance-id')
CBT_INSTANCE_DISPLAY_NAME = getenv('CBT_INSTANCE_DISPLAY_NAME', 'Human-readable name')
CBT_INSTANCE_ID = getenv('GCP_BIG_TABLE_INSTANCE_ID', 'some-instance-id')
CBT_INSTANCE_DISPLAY_NAME = getenv('GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME', 'Human-readable name')
CBT_INSTANCE_DISPLAY_NAME_UPDATED = getenv(
"CBT_INSTANCE_DISPLAY_NAME_UPDATED", "Human-readable name - updated"
"GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME_UPDATED", f"{CBT_INSTANCE_DISPLAY_NAME} - updated"
)
CBT_INSTANCE_TYPE = getenv('CBT_INSTANCE_TYPE', '2')
CBT_INSTANCE_TYPE_PROD = getenv('CBT_INSTANCE_TYPE_PROD', '1')
CBT_INSTANCE_LABELS = getenv('CBT_INSTANCE_LABELS', '{}')
CBT_INSTANCE_LABELS_UPDATED = getenv('CBT_INSTANCE_LABELS', '{"env": "prod"}')
CBT_CLUSTER_ID = getenv('CBT_CLUSTER_ID', 'some-cluster-id')
CBT_CLUSTER_ZONE = getenv('CBT_CLUSTER_ZONE', 'europe-west1-b')
CBT_CLUSTER_NODES = getenv('CBT_CLUSTER_NODES', '3')
CBT_CLUSTER_NODES_UPDATED = getenv('CBT_CLUSTER_NODES_UPDATED', '5')
CBT_CLUSTER_STORAGE_TYPE = getenv('CBT_CLUSTER_STORAGE_TYPE', '2')
CBT_TABLE_ID = getenv('CBT_TABLE_ID', 'some-table-id')
CBT_POKE_INTERVAL = getenv('CBT_POKE_INTERVAL', '60')
CBT_INSTANCE_TYPE = getenv('GCP_BIG_TABLE_INSTANCE_TYPE', '2')
CBT_INSTANCE_TYPE_PROD = getenv('GCP_BIG_TABLE_INSTANCE_TYPE_PROD', '1')
CBT_INSTANCE_LABELS = getenv('GCP_BIG_TABLE_INSTANCE_LABELS', '{}')
CBT_INSTANCE_LABELS_UPDATED = getenv('GCP_BIG_TABLE_INSTANCE_LABELS_UPDATED', '{"env": "prod"}')
CBT_CLUSTER_ID = getenv('GCP_BIG_TABLE_CLUSTER_ID', 'some-cluster-id')
CBT_CLUSTER_ZONE = getenv('GCP_BIG_TABLE_CLUSTER_ZONE', 'europe-west1-b')
CBT_CLUSTER_NODES = getenv('GCP_BIG_TABLE_CLUSTER_NODES', '3')
CBT_CLUSTER_NODES_UPDATED = getenv('GCP_BIG_TABLE_CLUSTER_NODES_UPDATED', '5')
CBT_CLUSTER_STORAGE_TYPE = getenv('GCP_BIG_TABLE_CLUSTER_STORAGE_TYPE', '2')
CBT_TABLE_ID = getenv('GCP_BIG_TABLE_TABLE_ID', 'some-table-id')
CBT_POKE_INTERVAL = getenv('GCP_BIG_TABLE_POKE_INTERVAL', '60')


with models.DAG(
Expand All @@ -93,8 +93,8 @@
instance_display_name=CBT_INSTANCE_DISPLAY_NAME,
instance_type=int(CBT_INSTANCE_TYPE),
instance_labels=json.loads(CBT_INSTANCE_LABELS),
cluster_nodes=int(CBT_CLUSTER_NODES),
cluster_storage_type=CBT_CLUSTER_STORAGE_TYPE,
cluster_nodes=None,
cluster_storage_type=int(CBT_CLUSTER_STORAGE_TYPE),
task_id='create_instance_task',
)
create_instance_task2 = BigtableCreateInstanceOperator(
Expand Down
9 changes: 8 additions & 1 deletion airflow/providers/google/cloud/hooks/bigtable.py
Expand Up @@ -169,7 +169,14 @@ def create_instance(
instance_labels,
)

clusters = [instance.cluster(main_cluster_id, main_cluster_zone, cluster_nodes, cluster_storage_type)]
cluster_kwargs = dict(
cluster_id=main_cluster_id,
location_id=main_cluster_zone,
default_storage_type=cluster_storage_type,
)
if instance_type != enums.Instance.Type.DEVELOPMENT and cluster_nodes:
cluster_kwargs["serve_nodes"] = cluster_nodes
clusters = [instance.cluster(**cluster_kwargs)]
if replica_cluster_id and replica_cluster_zone:
warnings.warn(
"The replica_cluster_id and replica_cluster_zone parameter have been deprecated."
Expand Down
58 changes: 55 additions & 3 deletions tests/providers/google/cloud/hooks/test_bigtable.py
Expand Up @@ -309,7 +309,7 @@ def test_create_instance(self, get_client, instance_create, mock_project_id):
@mock.patch('google.cloud.bigtable.instance.Instance.cluster')
@mock.patch('google.cloud.bigtable.instance.Instance.create')
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_create_instance_with_one_replica_cluster(
def test_create_instance_with_one_replica_cluster_production(
self, get_client, instance_create, cluster, mock_project_id
):
operation = mock.Mock()
Expand All @@ -325,10 +325,57 @@ def test_create_instance_with_one_replica_cluster(
cluster_nodes=1,
cluster_storage_type=enums.StorageType.SSD,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
instance_type=enums.Instance.Type.PRODUCTION,
)
cluster.assert_has_calls(
[
unittest.mock.call(CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD),
unittest.mock.call(
cluster_id=CBT_CLUSTER,
location_id=CBT_ZONE,
serve_nodes=1,
default_storage_type=enums.StorageType.SSD,
),
unittest.mock.call(
CBT_REPLICA_CLUSTER_ID, CBT_REPLICA_CLUSTER_ZONE, 1, enums.StorageType.SSD
),
],
any_order=True,
)
get_client.assert_called_once_with(project_id='example-project')
instance_create.assert_called_once_with(clusters=mock.ANY)
assert res.instance_id == 'instance'

@mock.patch(
'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST,
)
@mock.patch('google.cloud.bigtable.instance.Instance.cluster')
@mock.patch('google.cloud.bigtable.instance.Instance.create')
@mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
def test_create_instance_with_one_replica_cluster_development(
self, get_client, instance_create, cluster, mock_project_id
):
operation = mock.Mock()
operation.result_return_value = Instance(instance_id=CBT_INSTANCE, client=get_client)
instance_create.return_value = operation

res = self.bigtable_hook_default_project_id.create_instance(
instance_id=CBT_INSTANCE,
main_cluster_id=CBT_CLUSTER,
main_cluster_zone=CBT_ZONE,
replica_cluster_id=CBT_REPLICA_CLUSTER_ID,
replica_cluster_zone=CBT_REPLICA_CLUSTER_ZONE,
cluster_nodes=1,
cluster_storage_type=enums.StorageType.SSD,
project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
instance_type=enums.Instance.Type.DEVELOPMENT,
)
cluster.assert_has_calls(
[
unittest.mock.call(
cluster_id=CBT_CLUSTER, location_id=CBT_ZONE, default_storage_type=enums.StorageType.SSD
),
unittest.mock.call(
CBT_REPLICA_CLUSTER_ID, CBT_REPLICA_CLUSTER_ZONE, 1, enums.StorageType.SSD
),
Expand Down Expand Up @@ -365,7 +412,12 @@ def test_create_instance_with_multiple_replica_clusters(
)
cluster.assert_has_calls(
[
unittest.mock.call(CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD),
unittest.mock.call(
cluster_id=CBT_CLUSTER,
location_id=CBT_ZONE,
serve_nodes=1,
default_storage_type=enums.StorageType.SSD,
),
unittest.mock.call('replica-1', 'us-west1-a', 1, enums.StorageType.SSD),
unittest.mock.call('replica-2', 'us-central1-f', 1, enums.StorageType.SSD),
unittest.mock.call('replica-3', 'us-east1-d', 1, enums.StorageType.SSD),
Expand Down
Expand Up @@ -15,16 +15,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os

import pytest

from airflow.providers.google.cloud.example_dags.example_bigtable import CBT_INSTANCE_ID, GCP_PROJECT_ID
from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGTABLE_KEY
from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context

GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
CBT_INSTANCE = os.environ.get('CBT_INSTANCE_ID', 'testinstance')


@pytest.mark.backend("mysql", "postgres")
@pytest.mark.credential_file(GCP_BIGTABLE_KEY)
Expand All @@ -45,7 +42,7 @@ def tearDown(self):
'--verbosity=none',
'instances',
'delete',
CBT_INSTANCE,
CBT_INSTANCE_ID,
],
key=GCP_BIGTABLE_KEY,
)
Expand Down

0 comments on commit 810c15e

Please sign in to comment.