Refactor DataprocCreateCluster operator to use simpler interface (#10403)

DataprocCreateCluster now requires:
- cluster config
- cluster name
- project id

This way users no longer have to pass project_id twice (once in the
cluster definition and once as an operator parameter). The cluster
object is now built inside the hook's create_cluster method (see the
before/after sketch below).
turbaszek committed Sep 7, 2020
1 parent 1959d6a commit c8ee455
Showing 7 changed files with 276 additions and 310 deletions.
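
For illustration, here is the operator call before and after this change, using the placeholder values from the example DAG below (a sketch of the two interfaces, not the full DAG):

    # Before: project_id and cluster_name had to be embedded in the CLUSTER
    # dict *and* passed to the operator separately.
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster", project_id=PROJECT_ID, cluster=CLUSTER, region=REGION
    )

    # After: the operator takes only the config plus name and project id;
    # the hook assembles the full Cluster object (config, labels, ids) itself.
    create_cluster = DataprocCreateClusterOperator(
        task_id="create_cluster",
        project_id=PROJECT_ID,
        cluster_config=CLUSTER_CONFIG,
        region=REGION,
        cluster_name=CLUSTER_NAME,
    )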
42 changes: 21 additions & 21 deletions airflow/providers/google/cloud/example_dags/example_dataproc.py
@@ -35,7 +35,7 @@
 PROJECT_ID = os.environ.get("GCP_PROJECT_ID", "an-id")
 CLUSTER_NAME = os.environ.get("GCP_DATAPROC_CLUSTER_NAME", "example-project")
 REGION = os.environ.get("GCP_LOCATION", "europe-west1")
-ZONE = os.environ.get("GCP_REGION", "europe-west-1b")
+ZONE = os.environ.get("GCP_REGION", "europe-west1-b")
 BUCKET = os.environ.get("GCP_DATAPROC_BUCKET", "dataproc-system-tests")
 OUTPUT_FOLDER = "wordcount"
 OUTPUT_PATH = "gs://{}/{}/".format(BUCKET, OUTPUT_FOLDER)
@@ -47,20 +47,16 @@
 # Cluster definition
 # [START how_to_cloud_dataproc_create_cluster]

-CLUSTER = {
-    "project_id": PROJECT_ID,
-    "cluster_name": CLUSTER_NAME,
-    "config": {
-        "master_config": {
-            "num_instances": 1,
-            "machine_type_uri": "n1-standard-4",
-            "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
-        },
-        "worker_config": {
-            "num_instances": 2,
-            "machine_type_uri": "n1-standard-4",
-            "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
-        },
-    },
+CLUSTER_CONFIG = {
+    "master_config": {
+        "num_instances": 1,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
+    },
+    "worker_config": {
+        "num_instances": 2,
+        "machine_type_uri": "n1-standard-4",
+        "disk_config": {"boot_disk_type": "pd-standard", "boot_disk_size_gb": 1024},
+    },
 }

@@ -69,10 +65,10 @@
 # Update options
 # [START how_to_cloud_dataproc_updatemask_cluster_operator]
 CLUSTER_UPDATE = {
-    "config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 3},}
+    "config": {"worker_config": {"num_instances": 3}, "secondary_worker_config": {"num_instances": 3}}
 }
 UPDATE_MASK = {
-    "paths": ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances",]
+    "paths": ["config.worker_config.num_instances", "config.secondary_worker_config.num_instances"]
 }
 # [END how_to_cloud_dataproc_updatemask_cluster_operator]
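
(For context, these two dicts are consumed by a cluster scale-up task elsewhere in this DAG. A minimal sketch of that usage follows; the DataprocUpdateClusterOperator argument names and the timeout value are assumptions, not part of this diff.)

    scale_cluster = DataprocUpdateClusterOperator(
        task_id="scale_cluster",
        cluster_name=CLUSTER_NAME,
        cluster=CLUSTER_UPDATE,    # desired state of the cluster
        update_mask=UPDATE_MASK,   # which fields of that state to apply
        graceful_decommission_timeout={"seconds": 600},  # hypothetical value
        project_id=PROJECT_ID,
        location=REGION,
    )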

@@ -141,10 +137,14 @@
 }
 # [END how_to_cloud_dataproc_hadoop_config]

-with models.DAG("example_gcp_dataproc", start_date=days_ago(1), schedule_interval=None,) as dag:
+with models.DAG("example_gcp_dataproc", start_date=days_ago(1), schedule_interval=None) as dag:
     # [START how_to_cloud_dataproc_create_cluster_operator]
     create_cluster = DataprocCreateClusterOperator(
-        task_id="create_cluster", project_id=PROJECT_ID, cluster=CLUSTER, region=REGION
+        task_id="create_cluster",
+        project_id=PROJECT_ID,
+        cluster_config=CLUSTER_CONFIG,
+        region=REGION,
+        cluster_name=CLUSTER_NAME,
     )
     # [END how_to_cloud_dataproc_create_cluster_operator]

@@ -164,7 +164,7 @@
         task_id="pig_task", job=PIG_JOB, location=REGION, project_id=PROJECT_ID
     )
     spark_sql_task = DataprocSubmitJobOperator(
-        task_id="spark_sql_task", job=SPARK_SQL_JOB, location=REGION, project_id=PROJECT_ID,
+        task_id="spark_sql_task", job=SPARK_SQL_JOB, location=REGION, project_id=PROJECT_ID
     )

     spark_task = DataprocSubmitJobOperator(
@@ -205,7 +205,7 @@

     # [START how_to_cloud_dataproc_delete_cluster_operator]
     delete_cluster = DataprocDeleteClusterOperator(
-        task_id="delete_cluster", project_id=PROJECT_ID, cluster_name=CLUSTER_NAME, region=REGION,
+        task_id="delete_cluster", project_id=PROJECT_ID, cluster_name=CLUSTER_NAME, region=REGION
     )
     # [END how_to_cloud_dataproc_delete_cluster_operator]

29 changes: 24 additions & 5 deletions airflow/providers/google/cloud/hooks/dataproc.py
@@ -63,7 +63,7 @@ def __init__(
         self.job_type = job_type
         self.job = {
             "job": {
-                "reference": {"project_id": project_id, "job_id": name,},
+                "reference": {"project_id": project_id, "job_id": name},
                 "placement": {"cluster_name": cluster_name},
                 "labels": {'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')},
                 job_type: {},
@@ -250,8 +250,10 @@ def get_job_client(self, location: Optional[str] = None) -> JobControllerClient:
     def create_cluster(
         self,
         region: str,
-        cluster: Union[Dict, Cluster],
         project_id: str,
+        cluster_name: str,
+        cluster_config: Union[Dict, Cluster],
+        labels: Optional[Dict[str, str]] = None,
         request_id: Optional[str] = None,
         retry: Optional[Retry] = None,
         timeout: Optional[float] = None,
@@ -264,10 +266,14 @@ def create_cluster(
         :type project_id: str
         :param region: Required. The Cloud Dataproc region in which to handle the request.
         :type region: str
-        :param cluster: Required. The cluster to create.
+        :param cluster_name: Name of the cluster to create
+        :type cluster_name: str
+        :param labels: Labels that will be assigned to created cluster
+        :type labels: Dict[str, str]
+        :param cluster_config: Required. The cluster config to create.
             If a dict is provided, it must be of the same form as the protobuf message
-            :class:`~google.cloud.dataproc_v1.types.Cluster`
-        :type cluster: Union[Dict, google.cloud.dataproc_v1.types.Cluster]
+            :class:`~google.cloud.dataproc_v1.types.ClusterConfig`
+        :type cluster_config: Union[Dict, google.cloud.dataproc_v1.types.ClusterConfig]
         :param request_id: Optional. A unique id used to identify the request. If the server receives two
             ``CreateClusterRequest`` requests with the same id, then the second request will be ignored and
             the first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -281,6 +287,19 @@ def create_cluster(
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        # Dataproc labels must conform to the following regex:
+        # [a-z]([-a-z0-9]*[a-z0-9])? (current airflow version string follows
+        # semantic versioning spec: x.y.z).
+        labels = labels or {}
+        labels.update({'airflow-version': 'v' + airflow_version.replace('.', '-').replace('+', '-')})
+
+        cluster = {
+            "project_id": project_id,
+            "cluster_name": cluster_name,
+            "config": cluster_config,
+            "labels": labels,
+        }
+
         client = self.get_cluster_client(location=region)
         result = client.create_cluster(
             project_id=project_id,
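
To see the new hook interface end to end, here is a minimal sketch of calling the refactored method directly. The connection id, project and cluster names, and the tiny config are placeholders, and the return value is assumed to be a long-running operation as in google-cloud-dataproc:

    from airflow.providers.google.cloud.hooks.dataproc import DataprocHook

    hook = DataprocHook(gcp_conn_id="google_cloud_default")
    operation = hook.create_cluster(
        region="europe-west1",
        project_id="an-id",
        cluster_name="example-cluster",
        cluster_config={"master_config": {"num_instances": 1}},
        labels={"env": "test"},  # merged with the auto-added airflow-version label
    )
    operation.result()  # wait for cluster creation to finish

    # Note on labels: with airflow_version == "2.0.0.dev0" (hypothetical), the
    # auto-added label becomes {"airflow-version": "v2-0-0-dev0"}, since dots
    # and pluses are replaced to satisfy [a-z]([-a-z0-9]*[a-z0-9])?.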
