Skip to content

Commit

Permalink
[Airflow-15245] - passing custom image family name to the DataProcClu…
Browse files Browse the repository at this point in the history
…sterCreateoperator (#15250)

* [airflow-15245] - custom_image_family added as a parameter to DataprocCreateClusterOperator

Signed-off-by: ashish <[email protected]>

* [airflow-15245] - test added to check that custom_image and custom_image_family cannot both be passed

Signed-off-by: ashish <[email protected]>

* [airflow-#15245] - typo fixed in documentation

Signed-off-by: ashish <[email protected]>

* [Airflow-15245] - comments updated, more info provided.

* [Airflow-15245] - sanity check added for image_version and custom_image_family.

* Update airflow/providers/google/cloud/operators/dataproc.py

Co-authored-by: Xinbin Huang <[email protected]>

* Apply suggestions from code review

Co-authored-by: Xinbin Huang <[email protected]>

* [Airflow-15245] - added a test case to verify the generated cluster config is as expected with custom_image_family and single_node.

* Remove print() from test case

Co-authored-by: Ashish Patel <[email protected]>
Co-authored-by: Xinbin Huang <[email protected]>
  • Loading branch information
3 people committed Apr 18, 2021
1 parent 99ec208 commit 6da36ba
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 0 deletions.
22 changes: 22 additions & 0 deletions airflow/providers/google/cloud/operators/dataproc.py
Expand Up @@ -75,6 +75,10 @@ class ClusterGenerator:
:param custom_image_project_id: project id for the custom Dataproc image, for more info see
https://cloud.google.com/dataproc/docs/guides/dataproc-images
:type custom_image_project_id: str
:param custom_image_family: family for the custom Dataproc image,
the family name can be provided using the --family flag when creating a custom image; for more info see
https://cloud.google.com/dataproc/docs/guides/dataproc-images
:type custom_image_family: str
:param autoscaling_policy: The autoscaling policy used by the cluster. Only resource names
including projectid and location (region) are valid. Example:
``projects/[projectId]/locations/[dataproc_region]/autoscalingPolicies/[policy_id]``
Expand Down Expand Up @@ -163,6 +167,7 @@ def __init__(
metadata: Optional[Dict] = None,
custom_image: Optional[str] = None,
custom_image_project_id: Optional[str] = None,
custom_image_family: Optional[str] = None,
image_version: Optional[str] = None,
autoscaling_policy: Optional[str] = None,
properties: Optional[Dict] = None,
Expand Down Expand Up @@ -194,6 +199,7 @@ def __init__(
self.metadata = metadata
self.custom_image = custom_image
self.custom_image_project_id = custom_image_project_id
self.custom_image_family = custom_image_family
self.image_version = image_version
self.properties = properties or {}
self.optional_components = optional_components
Expand All @@ -220,6 +226,12 @@ def __init__(
if self.custom_image and self.image_version:
raise ValueError("The custom_image and image_version can't be both set")

if self.custom_image_family and self.image_version:
raise ValueError("The image_version and custom_image_family can't be both set")

if self.custom_image_family and self.custom_image:
raise ValueError("The custom_image and custom_image_family can't be both set")

if self.single_node and self.num_preemptible_workers > 0:
raise ValueError("Single node cannot have preemptible workers.")

Expand Down Expand Up @@ -346,6 +358,16 @@ def _build_cluster_data(self):
if not self.single_node:
cluster_data['worker_config']['image_uri'] = custom_image_url

elif self.custom_image_family:
project_id = self.custom_image_project_id or self.project_id
custom_image_url = (
'https://www.googleapis.com/compute/beta/projects/'
f'{project_id}/global/images/family/{self.custom_image_family}'
)
cluster_data['master_config']['image_uri'] = custom_image_url
if not self.single_node:
cluster_data['worker_config']['image_uri'] = custom_image_url

cluster_data = self._build_gce_cluster_config(cluster_data)

if self.single_node:
Expand Down
101 changes: 101 additions & 0 deletions tests/providers/google/cloud/operators/test_dataproc.py
Expand Up @@ -101,6 +101,50 @@
],
}

# Expected cluster config when the generator is given ``custom_image_family``:
# master and worker nodes carry the image-family URI, the secondary workers do not.
CONFIG_WITH_CUSTOM_IMAGE_FAMILY = {
    "gce_cluster_config": {
        "zone_uri": "https://www.googleapis.com/compute/v1/projects/project_id/zones/zone",
        "metadata": {"metadata": "data"},
        "network_uri": "network_uri",
        "subnetwork_uri": "subnetwork_uri",
        "internal_ip_only": True,
        "tags": ["tags"],
        "service_account": "service_account",
        "service_account_scopes": ["service_account_scopes"],
    },
    "master_config": {
        "num_instances": 2,
        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/master_machine_type",
        "disk_config": {
            "boot_disk_type": "master_disk_type",
            "boot_disk_size_gb": 128,
        },
        "image_uri": (
            "https://www.googleapis.com/compute/beta/projects/"
            "custom_image_project_id/global/images/family/custom_image_family"
        ),
    },
    "worker_config": {
        "num_instances": 2,
        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
        "disk_config": {
            "boot_disk_type": "worker_disk_type",
            "boot_disk_size_gb": 256,
        },
        "image_uri": (
            "https://www.googleapis.com/compute/beta/projects/"
            "custom_image_project_id/global/images/family/custom_image_family"
        ),
    },
    "secondary_worker_config": {
        "num_instances": 4,
        "machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
        "disk_config": {
            "boot_disk_type": "worker_disk_type",
            "boot_disk_size_gb": 256,
        },
        "is_preemptible": True,
    },
    "software_config": {
        "properties": {"properties": "data"},
        "optional_components": ["optional_components"],
    },
    "lifecycle_config": {
        "idle_delete_ttl": {"seconds": 60},
        "auto_delete_time": "2019-09-12T00:00:00.000000Z",
    },
    "encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
    "autoscaling_config": {"policy_uri": "autoscaling_policy"},
    "config_bucket": "storage_bucket",
    "initialization_actions": [
        {
            "executable_file": "init_actions_uris",
            "execution_timeout": {"seconds": 600},
        },
    ],
}

# Labels used by the tests; the "airflow-version" entry is immediately replaced
# with a "v"-prefixed version string in which "." and "+" become "-".
LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}

LABELS["airflow-version"] = 'v' + airflow_version.replace('.', '-').replace('+', '-')
Expand Down Expand Up @@ -144,6 +188,26 @@ def test_image_version(self):
)
assert "custom_image and image_version" in str(ctx.value)

def test_custom_image_family_error_with_image_version(self):
    """ClusterGenerator rejects ``image_version`` combined with ``custom_image_family``."""
    with pytest.raises(ValueError) as ctx:
        ClusterGenerator(
            image_version="image_version",
            custom_image_family="custom_image_family",
            project_id=GCP_PROJECT,
            cluster_name=CLUSTER_NAME,
        )
    # The message check sits after the ``with`` block on purpose: a statement
    # placed after the raising call inside the block would never execute.
    assert "image_version and custom_image_family" in str(ctx.value)

def test_custom_image_family_error_with_custom_image(self):
    """ClusterGenerator rejects ``custom_image`` combined with ``custom_image_family``."""
    with pytest.raises(ValueError) as ctx:
        ClusterGenerator(
            custom_image="custom_image",
            custom_image_family="custom_image_family",
            project_id=GCP_PROJECT,
            cluster_name=CLUSTER_NAME,
        )
    # The message check sits after the ``with`` block on purpose: a statement
    # placed after the raising call inside the block would never execute.
    assert "custom_image and custom_image_family" in str(ctx.value)

def test_nodes_number(self):
with pytest.raises(AssertionError) as ctx:
ClusterGenerator(
Expand Down Expand Up @@ -188,6 +252,43 @@ def test_build(self):
cluster = generator.make()
assert CONFIG == cluster

def test_build_with_custom_image_family(self):
    """Generate a cluster config from a custom image family and compare it with the expected dict."""
    # Keyword arguments are kept in the original order and unpacked unchanged.
    generator_kwargs = dict(
        project_id="project_id",
        num_workers=2,
        zone="zone",
        network_uri="network_uri",
        subnetwork_uri="subnetwork_uri",
        internal_ip_only=True,
        tags=["tags"],
        storage_bucket="storage_bucket",
        init_actions_uris=["init_actions_uris"],
        init_action_timeout="10m",
        metadata={"metadata": "data"},
        custom_image_family="custom_image_family",
        custom_image_project_id="custom_image_project_id",
        autoscaling_policy="autoscaling_policy",
        properties={"properties": "data"},
        optional_components=["optional_components"],
        num_masters=2,
        master_machine_type="master_machine_type",
        master_disk_type="master_disk_type",
        master_disk_size=128,
        worker_machine_type="worker_machine_type",
        worker_disk_type="worker_disk_type",
        worker_disk_size=256,
        num_preemptible_workers=4,
        region="region",
        service_account="service_account",
        service_account_scopes=["service_account_scopes"],
        idle_delete_ttl=60,
        auto_delete_time=datetime(2019, 9, 12),
        auto_delete_ttl=250,
        customer_managed_key="customer_managed_key",
    )
    built_config = ClusterGenerator(**generator_kwargs).make()
    assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == built_config


class TestDataprocClusterCreateOperator(unittest.TestCase):
def test_deprecation_warning(self):
Expand Down

0 comments on commit 6da36ba

Please sign in to comment.