Skip to content

Commit 08d15d0

Browse files
authored
Add support for driver pool, instance flexibility policy, and min_num_instances for Dataproc (#34172)
1 parent 5983506 commit 08d15d0

File tree

4 files changed

+228
-1
lines changed

4 files changed

+228
-1
lines changed

airflow/providers/google/cloud/operators/dataproc.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import time
2626
import uuid
2727
import warnings
28+
from dataclasses import dataclass
2829
from datetime import datetime, timedelta
2930
from enum import Enum
3031
from typing import TYPE_CHECKING, Any, Sequence
@@ -77,6 +78,38 @@ class PreemptibilityType(Enum):
7778
NON_PREEMPTIBLE = "NON_PREEMPTIBLE"
7879

7980

81+
@dataclass
82+
class InstanceSelection:
83+
"""Defines machines types and a rank to which the machines types belong.
84+
85+
Representation for
86+
google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.InstanceSelection.
87+
88+
:param machine_types: Full machine-type names, e.g. "n1-standard-16".
89+
:param rank: Preference of this instance selection. Lower number means higher preference.
90+
Dataproc will first try to create a VM based on the machine-type with priority rank and fallback
91+
to next rank based on availability. Machine types and instance selections with the same priority have
92+
the same preference.
93+
"""
94+
95+
machine_types: list[str]
96+
rank: int = 0
97+
98+
99+
@dataclass
100+
class InstanceFlexibilityPolicy:
101+
"""
102+
Instance flexibility Policy allowing a mixture of VM shapes and provisioning models.
103+
104+
Representation for google.cloud.dataproc.v1#google.cloud.dataproc.v1.InstanceFlexibilityPolicy.
105+
106+
:param instance_selection_list: List of instance selection options that the group will use when
107+
creating new VMs.
108+
"""
109+
110+
instance_selection_list: list[InstanceSelection]
111+
112+
80113
class ClusterGenerator:
81114
"""Create a new Dataproc Cluster.
82115
@@ -85,6 +118,11 @@ class ClusterGenerator:
85118
to create the cluster. (templated)
86119
:param num_workers: The # of workers to spin up. If set to zero will
87120
spin up cluster in a single node mode
121+
:param min_num_workers: The minimum number of primary worker instances to create.
122+
If more than ``min_num_workers`` VMs are created out of ``num_workers``, the failed VMs will be
123+
deleted, cluster is resized to available VMs and set to RUNNING.
124+
If created VMs are less than ``min_num_workers``, the cluster is placed in ERROR state. The failed
125+
VMs are not deleted.
88126
:param storage_bucket: The storage bucket to use, setting to None lets dataproc
89127
generate a custom one for you
90128
:param init_actions_uris: List of GCS uri's containing
@@ -153,12 +191,18 @@ class ClusterGenerator:
153191
``projects/[PROJECT_STORING_KEYS]/locations/[LOCATION]/keyRings/[KEY_RING_NAME]/cryptoKeys/[KEY_NAME]`` # noqa
154192
:param enable_component_gateway: Provides access to the web interfaces of default and selected optional
155193
components on the cluster.
194+
:param driver_pool_size: The number of driver nodes in the node group.
195+
:param driver_pool_id: The ID for the driver pool. Must be unique within the cluster. Use this ID to
196+
identify the driver group in future operations, such as resizing the node group.
197+
:param secondary_worker_instance_flexibility_policy: Instance flexibility Policy allowing a mixture of VM
198+
shapes and provisioning models.
156199
"""
157200

158201
def __init__(
159202
self,
160203
project_id: str,
161204
num_workers: int | None = None,
205+
min_num_workers: int | None = None,
162206
zone: str | None = None,
163207
network_uri: str | None = None,
164208
subnetwork_uri: str | None = None,
@@ -191,11 +235,15 @@ def __init__(
191235
auto_delete_ttl: int | None = None,
192236
customer_managed_key: str | None = None,
193237
enable_component_gateway: bool | None = False,
238+
driver_pool_size: int = 0,
239+
driver_pool_id: str | None = None,
240+
secondary_worker_instance_flexibility_policy: InstanceFlexibilityPolicy | None = None,
194241
**kwargs,
195242
) -> None:
196243
self.project_id = project_id
197244
self.num_masters = num_masters
198245
self.num_workers = num_workers
246+
self.min_num_workers = min_num_workers
199247
self.num_preemptible_workers = num_preemptible_workers
200248
self.preemptibility = self._set_preemptibility_type(preemptibility)
201249
self.storage_bucket = storage_bucket
@@ -228,6 +276,9 @@ def __init__(
228276
self.customer_managed_key = customer_managed_key
229277
self.enable_component_gateway = enable_component_gateway
230278
self.single_node = num_workers == 0
279+
self.driver_pool_size = driver_pool_size
280+
self.driver_pool_id = driver_pool_id
281+
self.secondary_worker_instance_flexibility_policy = secondary_worker_instance_flexibility_policy
231282

232283
if self.custom_image and self.image_version:
233284
raise ValueError("The custom_image and image_version can't be both set")
@@ -241,6 +292,15 @@ def __init__(
241292
if self.single_node and self.num_preemptible_workers > 0:
242293
raise ValueError("Single node cannot have preemptible workers.")
243294

295+
if self.min_num_workers:
296+
if not self.num_workers:
297+
raise ValueError("Must specify num_workers when min_num_workers are provided.")
298+
if self.min_num_workers > self.num_workers:
299+
raise ValueError(
300+
"The value of min_num_workers must be less than or equal to num_workers. "
301+
f"Provided {self.min_num_workers}(min_num_workers) and {self.num_workers}(num_workers)."
302+
)
303+
244304
def _set_preemptibility_type(self, preemptibility: str):
245305
return PreemptibilityType(preemptibility.upper())
246306

@@ -307,6 +367,17 @@ def _build_lifecycle_config(self, cluster_data):
307367

308368
return cluster_data
309369

370+
def _build_driver_pool(self):
371+
driver_pool = {
372+
"node_group": {
373+
"roles": ["DRIVER"],
374+
"node_group_config": {"num_instances": self.driver_pool_size},
375+
},
376+
}
377+
if self.driver_pool_id:
378+
driver_pool["node_group_id"] = self.driver_pool_id
379+
return driver_pool
380+
310381
def _build_cluster_data(self):
311382
if self.zone:
312383
master_type_uri = (
@@ -344,6 +415,10 @@ def _build_cluster_data(self):
344415
"autoscaling_config": {},
345416
"endpoint_config": {},
346417
}
418+
419+
if self.min_num_workers:
420+
cluster_data["worker_config"]["min_num_instances"] = self.min_num_workers
421+
347422
if self.num_preemptible_workers > 0:
348423
cluster_data["secondary_worker_config"] = {
349424
"num_instances": self.num_preemptible_workers,
@@ -355,6 +430,13 @@ def _build_cluster_data(self):
355430
"is_preemptible": True,
356431
"preemptibility": self.preemptibility.value,
357432
}
433+
if self.secondary_worker_instance_flexibility_policy:
434+
cluster_data["secondary_worker_config"]["instance_flexibility_policy"] = {
435+
"instance_selection_list": [
436+
vars(s)
437+
for s in self.secondary_worker_instance_flexibility_policy.instance_selection_list
438+
]
439+
}
358440

359441
if self.storage_bucket:
360442
cluster_data["config_bucket"] = self.storage_bucket
@@ -382,6 +464,9 @@ def _build_cluster_data(self):
382464
if not self.single_node:
383465
cluster_data["worker_config"]["image_uri"] = custom_image_url
384466

467+
if self.driver_pool_size > 0:
468+
cluster_data["auxiliary_node_groups"] = [self._build_driver_pool()]
469+
385470
cluster_data = self._build_gce_cluster_config(cluster_data)
386471

387472
if self.single_node:

airflow/providers/google/provider.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ dependencies:
102102
- google-cloud-dataflow-client>=0.8.2
103103
- google-cloud-dataform>=0.5.0
104104
- google-cloud-dataplex>=1.4.2
105-
- google-cloud-dataproc>=5.4.0
105+
- google-cloud-dataproc>=5.5.0
106106
- google-cloud-dataproc-metastore>=1.12.0
107107
- google-cloud-dlp>=3.12.0
108108
- google-cloud-kms>=2.15.0

docs/spelling_wordlist.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,9 @@ InspectContentResponse
792792
InspectTemplate
793793
instafail
794794
installable
795+
InstanceFlexibilityPolicy
795796
InstanceGroupConfig
797+
InstanceSelection
796798
instanceTemplates
797799
instantiation
798800
integrations

tests/providers/google/cloud/operators/test_dataproc.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
DataprocSubmitSparkJobOperator,
6161
DataprocSubmitSparkSqlJobOperator,
6262
DataprocUpdateClusterOperator,
63+
InstanceFlexibilityPolicy,
64+
InstanceSelection,
6365
)
6466
from airflow.providers.google.cloud.triggers.dataproc import (
6567
DataprocBatchTrigger,
@@ -112,6 +114,7 @@
112114
"disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
113115
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
114116
"custom_image_project_id/global/images/custom_image",
117+
"min_num_instances": 1,
115118
},
116119
"secondary_worker_config": {
117120
"num_instances": 4,
@@ -132,6 +135,17 @@
132135
{"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
133136
],
134137
"endpoint_config": {},
138+
"auxiliary_node_groups": [
139+
{
140+
"node_group": {
141+
"roles": ["DRIVER"],
142+
"node_group_config": {
143+
"num_instances": 2,
144+
},
145+
},
146+
"node_group_id": "cluster_driver_pool",
147+
}
148+
],
135149
}
136150
VIRTUAL_CLUSTER_CONFIG = {
137151
"kubernetes_cluster_config": {
@@ -197,6 +211,64 @@
197211
},
198212
}
199213

214+
CONFIG_WITH_FLEX_MIG = {
215+
"gce_cluster_config": {
216+
"zone_uri": "https://www.googleapis.com/compute/v1/projects/project_id/zones/zone",
217+
"metadata": {"metadata": "data"},
218+
"network_uri": "network_uri",
219+
"subnetwork_uri": "subnetwork_uri",
220+
"internal_ip_only": True,
221+
"tags": ["tags"],
222+
"service_account": "service_account",
223+
"service_account_scopes": ["service_account_scopes"],
224+
},
225+
"master_config": {
226+
"num_instances": 2,
227+
"machine_type_uri": "projects/project_id/zones/zone/machineTypes/master_machine_type",
228+
"disk_config": {"boot_disk_type": "master_disk_type", "boot_disk_size_gb": 128},
229+
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
230+
"custom_image_project_id/global/images/custom_image",
231+
},
232+
"worker_config": {
233+
"num_instances": 2,
234+
"machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
235+
"disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
236+
"image_uri": "https://www.googleapis.com/compute/beta/projects/"
237+
"custom_image_project_id/global/images/custom_image",
238+
},
239+
"secondary_worker_config": {
240+
"num_instances": 4,
241+
"machine_type_uri": "projects/project_id/zones/zone/machineTypes/worker_machine_type",
242+
"disk_config": {"boot_disk_type": "worker_disk_type", "boot_disk_size_gb": 256},
243+
"is_preemptible": True,
244+
"preemptibility": "SPOT",
245+
"instance_flexibility_policy": {
246+
"instance_selection_list": [
247+
{
248+
"machine_types": [
249+
"projects/project_id/zones/zone/machineTypes/machine1",
250+
"projects/project_id/zones/zone/machineTypes/machine2",
251+
],
252+
"rank": 0,
253+
},
254+
{"machine_types": ["projects/project_id/zones/zone/machineTypes/machine3"], "rank": 1},
255+
],
256+
},
257+
},
258+
"software_config": {"properties": {"properties": "data"}, "optional_components": ["optional_components"]},
259+
"lifecycle_config": {
260+
"idle_delete_ttl": {"seconds": 60},
261+
"auto_delete_time": "2019-09-12T00:00:00.000000Z",
262+
},
263+
"encryption_config": {"gce_pd_kms_key_name": "customer_managed_key"},
264+
"autoscaling_config": {"policy_uri": "autoscaling_policy"},
265+
"config_bucket": "storage_bucket",
266+
"initialization_actions": [
267+
{"executable_file": "init_actions_uris", "execution_timeout": {"seconds": 600}}
268+
],
269+
"endpoint_config": {},
270+
}
271+
200272
LABELS = {"labels": "data", "airflow-version": AIRFLOW_VERSION}
201273

202274
LABELS.update({"airflow-version": "v" + airflow_version.replace(".", "-").replace("+", "-")})
@@ -361,10 +433,26 @@ def test_nodes_number(self):
361433
)
362434
assert "num_workers == 0 means single" in str(ctx.value)
363435

436+
def test_min_num_workers_less_than_num_workers(self):
437+
with pytest.raises(ValueError) as ctx:
438+
ClusterGenerator(
439+
num_workers=3, min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME
440+
)
441+
assert (
442+
"The value of min_num_workers must be less than or equal to num_workers. "
443+
"Provided 4(min_num_workers) and 3(num_workers)." in str(ctx.value)
444+
)
445+
446+
def test_min_num_workers_without_num_workers(self):
447+
with pytest.raises(ValueError) as ctx:
448+
ClusterGenerator(min_num_workers=4, project_id=GCP_PROJECT, cluster_name=CLUSTER_NAME)
449+
assert "Must specify num_workers when min_num_workers are provided." in str(ctx.value)
450+
364451
def test_build(self):
365452
generator = ClusterGenerator(
366453
project_id="project_id",
367454
num_workers=2,
455+
min_num_workers=1,
368456
zone="zone",
369457
network_uri="network_uri",
370458
subnetwork_uri="subnetwork_uri",
@@ -395,6 +483,8 @@ def test_build(self):
395483
auto_delete_time=datetime(2019, 9, 12),
396484
auto_delete_ttl=250,
397485
customer_managed_key="customer_managed_key",
486+
driver_pool_id="cluster_driver_pool",
487+
driver_pool_size=2,
398488
)
399489
cluster = generator.make()
400490
assert CONFIG == cluster
@@ -438,6 +528,56 @@ def test_build_with_custom_image_family(self):
438528
cluster = generator.make()
439529
assert CONFIG_WITH_CUSTOM_IMAGE_FAMILY == cluster
440530

531+
def test_build_with_flex_migs(self):
532+
generator = ClusterGenerator(
533+
project_id="project_id",
534+
num_workers=2,
535+
zone="zone",
536+
network_uri="network_uri",
537+
subnetwork_uri="subnetwork_uri",
538+
internal_ip_only=True,
539+
tags=["tags"],
540+
storage_bucket="storage_bucket",
541+
init_actions_uris=["init_actions_uris"],
542+
init_action_timeout="10m",
543+
metadata={"metadata": "data"},
544+
custom_image="custom_image",
545+
custom_image_project_id="custom_image_project_id",
546+
autoscaling_policy="autoscaling_policy",
547+
properties={"properties": "data"},
548+
optional_components=["optional_components"],
549+
num_masters=2,
550+
master_machine_type="master_machine_type",
551+
master_disk_type="master_disk_type",
552+
master_disk_size=128,
553+
worker_machine_type="worker_machine_type",
554+
worker_disk_type="worker_disk_type",
555+
worker_disk_size=256,
556+
num_preemptible_workers=4,
557+
preemptibility="Spot",
558+
region="region",
559+
service_account="service_account",
560+
service_account_scopes=["service_account_scopes"],
561+
idle_delete_ttl=60,
562+
auto_delete_time=datetime(2019, 9, 12),
563+
auto_delete_ttl=250,
564+
customer_managed_key="customer_managed_key",
565+
secondary_worker_instance_flexibility_policy=InstanceFlexibilityPolicy(
566+
[
567+
InstanceSelection(
568+
[
569+
"projects/project_id/zones/zone/machineTypes/machine1",
570+
"projects/project_id/zones/zone/machineTypes/machine2",
571+
],
572+
0,
573+
),
574+
InstanceSelection(["projects/project_id/zones/zone/machineTypes/machine3"], 1),
575+
]
576+
),
577+
)
578+
cluster = generator.make()
579+
assert CONFIG_WITH_FLEX_MIG == cluster
580+
441581

442582
class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
443583
def test_deprecation_warning(self):

0 commit comments

Comments
 (0)