[AIRFLOW-2911] Add job cancellation capability to Dataflow service (#…
mik-laj committed Mar 9, 2020
1 parent faf0df4 commit e5130dc
Showing 6 changed files with 398 additions and 95 deletions.
@@ -54,7 +54,7 @@
start_java_job = DataflowCreateJavaJobOperator(
task_id="start-java-job",
jar=GCS_JAR,
job_name='{{task.task_id}}22222255sss{{ macros.uuid.uuid4() }}',
job_name='{{task.task_id}}',
options={
'output': GCS_OUTPUT,
},
225 changes: 169 additions & 56 deletions airflow/providers/google/cloud/hooks/dataflow.py
@@ -27,7 +27,7 @@
import uuid
from copy import deepcopy
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar

from googleapiclient.discovery import build

@@ -105,17 +105,34 @@ class DataflowJobType:


class _DataflowJobsController(LoggingMixin):
"""
Interface for communication with Google API.
It does not use Apache Beam, but only the Google Dataflow API.
:param dataflow: Discovery resource
:param project_number: The Google Cloud Platform Project ID.
:param location: Job location.
:param poll_sleep: The status refresh rate for pending operations.
:param name: The Job ID prefix used when the ``multiple_jobs`` option is set to True.
:param job_id: ID of a single job.
:param num_retries: Maximum number of retries in case of connection problems.
:param multiple_jobs: If set to True, jobs will be searched by the name prefix (``name`` parameter)
rather than by a specific job ID, and actions will be performed on all matching jobs.
"""
def __init__(
self,
dataflow: Any,
project_number: str,
name: str,
location: str,
poll_sleep: int = 10,
name: Optional[str] = None,
job_id: Optional[str] = None,
num_retries: int = 0,
multiple_jobs: bool = False
) -> None:

super().__init__()
self._dataflow = dataflow
self._project_number = project_number
self._job_name = name
@@ -135,7 +152,7 @@ def is_job_running(self) -> bool:
"""
self._refresh_jobs()
if not self._jobs:
raise ValueError("Could not read _jobs")
return False

for job in self._jobs:
if job['currentState'] not in DataflowJobStatus.END_STATES:
@@ -243,8 +260,10 @@ def wait_for_done(self) -> None:
"""
Helper method to wait for result of submitted job.
"""
self.log.info("Start waiting for done.")
self._refresh_jobs()
while self._jobs and not all(self._check_dataflow_job_state(job) for job in self._jobs):
self.log.info("Waiting for done. Sleep %s s", self._poll_sleep)
time.sleep(self._poll_sleep)
self._refresh_jobs()

@@ -262,10 +281,36 @@ def get_jobs(self) -> List[Dict]:

return self._jobs

def cancel(self) -> None:
"""
Cancels the current job(s).
"""
jobs = self._get_current_jobs()
batch = self._dataflow.new_batch_http_request()
job_ids = [job['id'] for job in jobs]
self.log.info("Canceling jobs: %s", ", ".join(job_ids))
for job_id in job_ids:
batch.add(
self._dataflow.projects().locations().jobs().update(
projectId=self._project_number,
location=self._job_location,
jobId=job_id,
body={
"requestedState": DataflowJobStatus.JOB_STATE_CANCELLED
}
)
)
batch.execute()
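
For orientation, here is a minimal sketch (not part of the commit) of how this controller's cancel() could be driven directly; the Discovery client setup is standard, while the project ID, job name prefix, and region below are illustrative assumptions:

    from googleapiclient.discovery import build

    # Build the Dataflow Discovery client; credentials come from the default
    # Google auth environment. All literal values below are hypothetical.
    dataflow = build("dataflow", "v1b3", cache_discovery=False)

    controller = _DataflowJobsController(
        dataflow=dataflow,
        project_number="example-project",   # assumed GCP project ID
        name="start-python-job",            # job name prefix to match
        location="us-central1",
        multiple_jobs=True,                 # act on every job matching the prefix
    )
    controller.cancel()  # batches requestedState=JOB_STATE_CANCELLED updates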


class _DataflowRunner(LoggingMixin):
def __init__(self, cmd: Union[List, str]) -> None:
def __init__(
self,
cmd: List[str],
on_new_job_id_callback: Optional[Callable[[str], None]] = None
) -> None:
self.log.info("Running command: %s", ' '.join(cmd))
self.on_new_job_id_callback = on_new_job_id_callback
self._proc = subprocess.Popen(
cmd,
shell=False,
@@ -302,6 +347,8 @@ def _extract_job(self, line: str) -> Optional[str]:
if matched_job:
job_id = matched_job.group(1)
self.log.info("Found Job ID: %s", job_id)
if self.on_new_job_id_callback:
self.on_new_job_id_callback(job_id)
return job_id
return None
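
As a hedged illustration of how the new on_new_job_id_callback hook-point is meant to be used: the runner invokes it as soon as the job ID is parsed from the process output, so a caller can remember the ID for a later cancel. The class below is a hypothetical caller, not code from this commit:

    class ExampleDataflowCaller:
        """Hypothetical caller that captures the job ID for later cancellation."""

        def __init__(self, hook):
            self.hook = hook      # a DataflowHook instance
            self.job_id = None

        def set_current_job_id(self, job_id: str) -> None:
            # Passed as on_new_job_id_callback; invoked once the ID is known.
            self.job_id = job_id

        def cancel(self) -> None:
            if self.job_id:
                self.hook.cancel_job(job_id=self.job_id)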

@@ -376,10 +423,14 @@ def _start_dataflow(
label_formatter: Callable[[Dict], List[str]],
project_id: str,
multiple_jobs: bool = False,
on_new_job_id_callback: Optional[Callable[[str], None]] = None
) -> None:
variables = self._set_variables(variables)
cmd = command_prefix + self._build_cmd(variables, label_formatter, project_id)
runner = _DataflowRunner(cmd)
runner = _DataflowRunner(
cmd=cmd,
on_new_job_id_callback=on_new_job_id_callback
)
job_id = runner.wait_for_done()
job_controller = _DataflowJobsController(
dataflow=self.get_conn(),
@@ -409,7 +460,8 @@ def start_java_dataflow(
project_id: Optional[str] = None,
job_class: Optional[str] = None,
append_job_name: bool = True,
multiple_jobs: bool = False
multiple_jobs: bool = False,
on_new_job_id_callback: Optional[Callable[[str], None]] = None
) -> None:
"""
Starts a Dataflow Java job.
@@ -428,6 +480,8 @@ def start_java_dataflow(
:type append_job_name: bool
:param multiple_jobs: Set to True to check for multiple jobs in Dataflow.
:type multiple_jobs: bool
:param on_new_job_id_callback: Callback called when the job ID is known.
:type on_new_job_id_callback: callable
"""
if not project_id:
raise ValueError("The project_id should be set")
@@ -441,7 +495,15 @@ def label_formatter(labels_dict):

command_prefix = (["java", "-cp", jar, job_class] if job_class
else ["java", "-jar", jar])
self._start_dataflow(variables, name, command_prefix, label_formatter, project_id, multiple_jobs)
self._start_dataflow(
variables=variables,
name=name,
command_prefix=command_prefix,
label_formatter=label_formatter,
project_id=project_id,
multiple_jobs=multiple_jobs,
on_new_job_id_callback=on_new_job_id_callback
)

@_fallback_to_project_id_from_variables
@CloudBaseHook.fallback_to_default_project_id
@@ -452,8 +514,9 @@ def start_template_dataflow(
parameters: Dict,
dataflow_template: str,
project_id: Optional[str] = None,
append_job_name: bool = True
) -> None:
append_job_name: bool = True,
on_new_job_id_callback: Optional[Callable[[str], None]] = None
) -> Dict:
"""
Starts a Dataflow template job.
@@ -469,18 +532,53 @@
If set to None or missing, the default project_id from the GCP connection is used.
:param append_job_name: True if unique suffix has to be appended to job name.
:type append_job_name: bool
:param on_new_job_id_callback: Callback called when the job ID is known.
:type on_new_job_id_callback: callable
"""
if not project_id:
raise ValueError("The project_id should be set")

variables = self._set_variables(variables)
name = self._build_dataflow_job_name(job_name, append_job_name)
self._start_template_dataflow(
name, variables, parameters, dataflow_template, project_id)
# Builds RuntimeEnvironment from variables dictionary
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
environment = {}
for key in ['numWorkers', 'maxWorkers', 'zone', 'serviceAccountEmail',
'tempLocation', 'bypassTempDirValidation', 'machineType',
'additionalExperiments', 'network', 'subnetwork', 'additionalUserLabels']:
if key in variables:
environment.update({key: variables[key]})
body = {"jobName": name,
"parameters": parameters,
"environment": environment}
service = self.get_conn()
request = service.projects().locations().templates().launch( # pylint: disable=no-member
projectId=project_id,
location=variables['region'],
gcsPath=dataflow_template,
body=body
)
response = request.execute(num_retries=self.num_retries)

job_id = response['job']['id']
if on_new_job_id_callback:
on_new_job_id_callback(job_id)

variables = self._set_variables(variables)
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
name=name,
job_id=job_id,
location=variables['region'],
poll_sleep=self.poll_sleep,
num_retries=self.num_retries)
jobs_controller.wait_for_done()
return response["job"]

@_fallback_to_project_id_from_variables
@CloudBaseHook.fallback_to_default_project_id
def start_python_dataflow(
def start_python_dataflow( # pylint: disable=too-many-arguments
self,
job_name: str,
variables: Dict,
@@ -491,6 +589,7 @@ def start_python_dataflow(
py_system_site_packages: bool = False,
project_id: Optional[str] = None,
append_job_name: bool = True,
on_new_job_id_callback: Optional[Callable[[str], None]] = None
):
"""
Starts a Dataflow job.
@@ -523,6 +622,8 @@ def start_python_dataflow(
:type append_job_name: bool
:param project_id: Optional, the GCP project ID in which to start a job.
If set to None or missing, the default project_id from the GCP connection is used.
:param on_new_job_id_callback: Callback called when the job ID is known.
:type on_new_job_id_callback: callable
"""
if not project_id:
raise ValueError("The project_id should be set")
@@ -542,12 +643,27 @@ def label_formatter(labels_dict):
system_site_packages=py_system_site_packages,
requirements=py_requirements,
)

self._start_dataflow(variables, name, [py_interpreter] + py_options + [dataflow],
label_formatter, project_id)
command_prefix = [py_interpreter] + py_options + [dataflow]

self._start_dataflow(
variables=variables,
name=name,
command_prefix=command_prefix,
label_formatter=label_formatter,
project_id=project_id,
on_new_job_id_callback=on_new_job_id_callback
)
else:
self._start_dataflow(variables, name, [py_interpreter] + py_options + [dataflow],
label_formatter, project_id)
command_prefix = [py_interpreter] + py_options + [dataflow]

self._start_dataflow(
variables=variables,
name=name,
command_prefix=command_prefix,
label_formatter=label_formatter,
project_id=project_id,
on_new_job_id_callback=on_new_job_id_callback
)

@staticmethod
def _build_dataflow_job_name(job_name: str, append_job_name: bool = True) -> str:
@@ -582,45 +698,6 @@ def _build_cmd(variables: Dict, label_formatter: Callable, project_id: str) -> L
command.append("--" + attr + "=" + value)
return command

def _start_template_dataflow(
self,
name: str,
variables: Dict[str, Any],
parameters: Dict,
dataflow_template: str,
project_id: str
) -> Dict:
# Builds RuntimeEnvironment from variables dictionary
# https://cloud.google.com/dataflow/docs/reference/rest/v1b3/RuntimeEnvironment
environment = {}
for key in ['numWorkers', 'maxWorkers', 'zone', 'serviceAccountEmail',
'tempLocation', 'bypassTempDirValidation', 'machineType',
'additionalExperiments', 'network', 'subnetwork', 'additionalUserLabels']:
if key in variables:
environment.update({key: variables[key]})
body = {"jobName": name,
"parameters": parameters,
"environment": environment}
service = self.get_conn()
request = service.projects().locations().templates().launch( # pylint: disable=no-member
projectId=project_id,
location=variables['region'],
gcsPath=dataflow_template,
body=body
)
response = request.execute(num_retries=self.num_retries)
variables = self._set_variables(variables)
jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
name=name,
job_id=response['job']['id'],
location=variables['region'],
poll_sleep=self.poll_sleep,
num_retries=self.num_retries)
jobs_controller.wait_for_done()
return response

@_fallback_to_project_id_from_variables
@CloudBaseHook.fallback_to_default_project_id
def is_job_dataflow_running(self, name: str, variables: Dict, project_id: Optional[str] = None) -> bool:
@@ -648,3 +725,39 @@ def is_job_dataflow_running(self, name: str, variables: Dict, project_id: Option
poll_sleep=self.poll_sleep
)
return jobs_controller.is_job_running()

@CloudBaseHook.fallback_to_default_project_id
def cancel_job(
self,
job_name: Optional[str] = None,
job_id: Optional[str] = None,
location: Optional[str] = None,
project_id: Optional[str] = None
) -> None:
"""
Cancels the job with the specified name prefix or Job ID.
Parameters ``job_name`` and ``job_id`` are mutually exclusive.
:param job_name: Name prefix specifying which jobs are to be canceled.
:type job_name: str
:param job_id: Job ID specifying which jobs are to be canceled.
:type job_id: str
:param location: Job location.
:type location: str
:param project_id: Optional, the GCP project ID of the job(s) to cancel.
If set to None or missing, the default project_id from the GCP connection is used.
:type project_id: str
"""
if not project_id:
raise ValueError("The project_id should be set")

jobs_controller = _DataflowJobsController(
dataflow=self.get_conn(),
project_number=project_id,
name=job_name,
job_id=job_id,
location=location or DEFAULT_DATAFLOW_LOCATION,
poll_sleep=self.poll_sleep
)
jobs_controller.cancel()
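
Finally, a minimal usage sketch of the new hook method; the connection ID, job ID, name prefix, and project below are placeholders, not values from the commit:

    hook = DataflowHook(gcp_conn_id="google_cloud_default")

    # Cancel a single job by ID (location defaults to DEFAULT_DATAFLOW_LOCATION).
    hook.cancel_job(
        job_id="2020-03-09_00_00_00-1234567890123456789",  # hypothetical job ID
        location="us-central1",
        project_id="example-project",
    )

    # Or cancel every non-terminal job whose name matches a prefix.
    hook.cancel_job(job_name="start-python-job", project_id="example-project")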
