Skip to content

Commit

Permalink
[AIRFLOW-6759] Added MLEngine operator/hook to cancel MLEngine jobs (#…
Browse files Browse the repository at this point in the history
…7400)

* [AIRFLOW-6759] Added MLEngine operator/hook to cancel MLEngine jobs

* Update airflow/providers/google/cloud/hooks/mlengine.py

Added types for `job_id`

Co-Authored-By: Tomek Urbaszek <[email protected]>

* Updates cancel_job doc

* Update airflow/providers/google/cloud/hooks/mlengine.py

cleaner formating

Co-Authored-By: Tomek Urbaszek <[email protected]>

* removed redundant error checking

Co-authored-by: Tomek Urbaszek <[email protected]>
  • Loading branch information
MarkYHZhang and turbaszek committed Feb 17, 2020
1 parent 946bdc2 commit 2c9345a
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 0 deletions.
45 changes: 45 additions & 0 deletions airflow/providers/google/cloud/hooks/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,51 @@ def create_job(

return self._wait_for_job_done(project_id, job_id)

@CloudBaseHook.fallback_to_default_project_id
def cancel_job(
self,
job_id: str,
project_id: Optional[str] = None
) -> Dict:

"""
Cancels a MLEngine job.
:param project_id: The Google Cloud project id within which MLEngine
job will be cancelled. If set to None or missing, the default project_id from the GCP
connection is used.
:type project_id: str
:param job_id: A unique id for the want-to-be cancelled Google MLEngine training job.
:type job_id: str
:return: Empty dict if cancelled successfully
:rtype: dict
:raises: googleapiclient.errors.HttpError
"""

if not project_id:
raise ValueError("The project_id should be set")

hook = self.get_conn()

request = hook.projects().jobs().cancel( # pylint: disable=no-member
name=f'projects/{project_id}/jobs/{job_id}')

try:
return request.execute()
except HttpError as e:
if e.resp.status == 404:
self.log.error('Job with job_id %s does not exist. ', job_id)
raise
elif e.resp.status == 400:
self.log.info(
'Job with job_id %s is already complete, cancellation aborted.',
job_id)
return {}
else:
self.log.error('Failed to cancel MLEngine job: %s', e)
raise

def _get_job(self, project_id: str, job_id: str) -> Dict:
"""
Gets a MLEngine job based on the job id.
Expand Down
51 changes: 51 additions & 0 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1015,3 +1015,54 @@ def check_existing_job(existing_job):
if finished_training_job['state'] != 'SUCCEEDED':
self.log.error('MLEngine training job failed: %s', str(finished_training_job))
raise RuntimeError(finished_training_job['errorMessage'])


class MLEngineTrainingJobFailureOperator(BaseOperator):

"""
Operator for cleaning up failed MLEngine training job.
:param job_id: A unique templated id for the submitted Google MLEngine
training job. (templated)
:type job_id: str
:param project_id: The Google Cloud project name within which MLEngine training job should run.
If set to None or missing, the default project_id from the GCP connection is used. (templated)
:type project_id: str
:param gcp_conn_id: The connection ID to use when fetching connection info.
:type gcp_conn_id: str
:param delegate_to: The account to impersonate, if any.
For this to work, the service account making the request must have
domain-wide delegation enabled.
:type delegate_to: str
"""

template_fields = [
'_project_id',
'_job_id',
]

@apply_defaults
def __init__(self,
job_id: str,
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
*args,
**kwargs) -> None:
super().__init__(*args, **kwargs)
self._project_id = project_id
self._job_id = job_id
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to

if not self._project_id:
raise AirflowException('Google Cloud project id is required.')

def execute(self, context):

hook = MLEngineHook(
gcp_conn_id=self._gcp_conn_id,
delegate_to=self._delegate_to
)

hook.cancel_job(project_id=self._project_id, job_id=_normalize_mlengine_job_id(self._job_id))
121 changes: 121 additions & 0 deletions tests/providers/google/cloud/hooks/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,88 @@ def check_input(existing_job):

self.assertEqual(create_job_response, my_job)

@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_cancel_mlengine_job(self, mock_get_conn):
project_id = "test-project"
job_id = 'test-job-id'
job_path = 'projects/{}/jobs/{}'.format(project_id, job_id)

job_cancelled = {}

(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
cancel.return_value.
execute.return_value
) = job_cancelled

cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=project_id)

self.assertEqual(cancel_job_response, job_cancelled)
mock_get_conn.assert_has_calls([
mock.call().projects().jobs().cancel(name=job_path),
], any_order=True)

@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_cancel_mlengine_job_nonexistent_job(self, mock_get_conn):
project_id = "test-project"
job_id = 'test-job-id'
job_cancelled = {}

error_job_does_not_exist = HttpError(resp=mock.MagicMock(status=404), content=b'Job does not exist')

(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
cancel.return_value.
execute.side_effect
) = error_job_does_not_exist
(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
cancel.return_value.
execute.return_value
) = job_cancelled

with self.assertRaises(HttpError):
self.hook.cancel_job(job_id=job_id, project_id=project_id)

@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_cancel_mlengine_job_completed_job(self, mock_get_conn):
project_id = "test-project"
job_id = 'test-job-id'
job_path = 'projects/{}/jobs/{}'.format(project_id, job_id)
job_cancelled = {}

error_job_already_completed = HttpError(
resp=mock.MagicMock(status=400),
content=b'Job already completed')

(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
cancel.return_value.
execute.side_effect
) = error_job_already_completed
(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
cancel.return_value.
execute.return_value
) = job_cancelled

cancel_job_response = self.hook.cancel_job(job_id=job_id, project_id=project_id)

self.assertEqual(cancel_job_response, job_cancelled)
mock_get_conn.assert_has_calls([
mock.call().projects().jobs().cancel(name=job_path),
], any_order=True)


class TestMLEngineHookWithDefaultProjectId(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -987,6 +1069,33 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep, mock_project_id):
mock.call().projects().jobs().get().execute()
], any_order=True)

@mock.patch(
'airflow.providers.google.cloud.hooks.base.CloudBaseHook.project_id',
new_callable=PropertyMock,
return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST
)
@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id):
job_id = 'test-job-id'
job_path = 'projects/{}/jobs/{}'.format(GCP_PROJECT_ID_HOOK_UNIT_TEST, job_id)

job_cancelled = {}

(
mock_get_conn.return_value.
projects.return_value.
jobs.return_value.
cancel.return_value.
execute.return_value
) = job_cancelled

cancel_job_response = self.hook.cancel_job(job_id=job_id)

self.assertEqual(cancel_job_response, job_cancelled)
mock_get_conn.assert_has_calls([
mock.call().projects().jobs().cancel(name=job_path),
], any_order=True)


class TestMLEngineHookWithoutProjectId(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -1110,3 +1219,15 @@ def test_create_mlengine_job(self, mock_get_conn, mock_sleep, mock_project_id):

with self.assertRaises(AirflowException):
self.hook.create_job(job=new_job)

@mock.patch(
'airflow.providers.google.cloud.hooks.base.CloudBaseHook.project_id',
new_callable=PropertyMock,
return_value=None
)
@mock.patch("airflow.providers.google.cloud.hooks.mlengine.MLEngineHook.get_conn")
def test_cancel_mlengine_job(self, mock_get_conn, mock_project_id):
job_id = 'test-job-id'

with self.assertRaises(AirflowException):
self.hook.cancel_job(job_id=job_id)
50 changes: 50 additions & 0 deletions tests/providers/google/cloud/operators/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
MLEngineDeleteVersionOperator, MLEngineGetModelOperator, MLEngineListVersionsOperator,
MLEngineManageModelOperator, MLEngineManageVersionOperator, MLEngineSetDefaultVersionOperator,
MLEngineStartBatchPredictionJobOperator, MLEngineStartTrainingJobOperator,
MLEngineTrainingJobFailureOperator,
)

DEFAULT_DATE = datetime.datetime(2017, 6, 6)
Expand Down Expand Up @@ -404,6 +405,55 @@ def test_failed_job_error(self, mock_hook):
self.assertEqual('A failure message', str(context.exception))


class TestMLEngineTrainingJobFailureOperator(unittest.TestCase):

TRAINING_DEFAULT_ARGS = {
'project_id': 'test-project',
'job_id': 'test_training',
'task_id': 'test-training'
}

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_cancel_training_job(self, mock_hook):
success_response = {}
hook_instance = mock_hook.return_value
hook_instance.cancel_job.return_value = success_response

cancel_training_op = MLEngineTrainingJobFailureOperator(
**self.TRAINING_DEFAULT_ARGS)
cancel_training_op.execute(None)

mock_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default', delegate_to=None)
# Make sure only 'cancel_job' is invoked on hook instance
self.assertEqual(len(hook_instance.mock_calls), 1)
hook_instance.cancel_job.assert_called_once_with(
project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id'])

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_http_error(self, mock_hook):
http_error_code = 403
hook_instance = mock_hook.return_value
hook_instance.cancel_job.side_effect = HttpError(
resp=httplib2.Response({
'status': http_error_code
}),
content=b'Forbidden')

with self.assertRaises(HttpError) as context:
cancel_training_op = MLEngineTrainingJobFailureOperator(
**self.TRAINING_DEFAULT_ARGS)
cancel_training_op.execute(None)

mock_hook.assert_called_once_with(
gcp_conn_id='google_cloud_default', delegate_to=None)
# Make sure only 'create_job' is invoked on hook instance
self.assertEqual(len(hook_instance.mock_calls), 1)
hook_instance.cancel_job.assert_called_once_with(
project_id=self.TRAINING_DEFAULT_ARGS['project_id'], job_id=self.TRAINING_DEFAULT_ARGS['job_id'])
self.assertEqual(http_error_code, context.exception.resp.status)


class TestMLEngineModelOperator(unittest.TestCase):

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
Expand Down

0 comments on commit 2c9345a

Please sign in to comment.