Skip to content

Commit

Permalink
Generalize MLEngineStartTrainingJobOperator to custom images (#13318)
Browse files Browse the repository at this point in the history
  • Loading branch information
IvanUkhov committed Jan 2, 2021
1 parent 6e1a6ff commit f6518dd
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 37 deletions.
92 changes: 56 additions & 36 deletions airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,26 +1080,30 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
:param job_id: A unique templated id for the submitted Google MLEngine
training job. (templated)
:type job_id: str
:param package_uris: A list of package locations for MLEngine training job,
which should include the main training program + any additional
dependencies. (templated)
:type package_uris: List[str]
:param training_python_module: The Python module name to run within MLEngine
training job after installing 'package_uris' packages. (templated)
:type training_python_module: str
:param training_args: A list of templated command line arguments to pass to
the MLEngine training program. (templated)
:type training_args: List[str]
:param region: The Google Compute Engine region to run the MLEngine training
job in (templated).
:type region: str
:param package_uris: A list of Python package locations for the training
job, which should include the main training program and any additional
dependencies. This is mutually exclusive with a custom image specified
via master_config. (templated)
:type package_uris: List[str]
:param training_python_module: The name of the Python module to run within
the training job after installing the packages. This is mutually
exclusive with a custom image specified via master_config. (templated)
:type training_python_module: str
:param training_args: A list of command-line arguments to pass to the
training program. (templated)
:type training_args: List[str]
:param scale_tier: Resource tier for MLEngine training job. (templated)
:type scale_tier: str
:param master_type: Cloud ML Engine machine name.
Must be set when scale_tier is CUSTOM. (templated)
:param master_type: The type of virtual machine to use for the master
worker. It must be set whenever scale_tier is CUSTOM. (templated)
:type master_type: str
:param master_config: Cloud ML Engine master config.
master_type must be set if master_config is provided. (templated)
:param master_config: The configuration for the master worker. If this is
provided, master_type must be set as well. If a custom image is
specified, this is mutually exclusive with package_uris and
training_python_module. (templated)
:type master_config: dict
:param runtime_version: The Google Cloud ML runtime version to use for
training. (templated)
Expand Down Expand Up @@ -1147,10 +1151,10 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
template_fields = [
'_project_id',
'_job_id',
'_region',
'_package_uris',
'_training_python_module',
'_training_args',
'_region',
'_scale_tier',
'_master_type',
'_master_config',
Expand All @@ -1168,10 +1172,10 @@ def __init__(
self, # pylint: disable=too-many-arguments
*,
job_id: str,
package_uris: List[str],
training_python_module: str,
training_args: List[str],
region: str,
package_uris: List[str] = None,
training_python_module: str = None,
training_args: List[str] = None,
scale_tier: Optional[str] = None,
master_type: Optional[str] = None,
master_config: Optional[Dict] = None,
Expand All @@ -1190,10 +1194,10 @@ def __init__(
super().__init__(**kwargs)
self._project_id = project_id
self._job_id = job_id
self._region = region
self._package_uris = package_uris
self._training_python_module = training_python_module
self._training_args = training_args
self._region = region
self._scale_tier = scale_tier
self._master_type = master_type
self._master_config = master_config
Expand All @@ -1207,37 +1211,56 @@ def __init__(
self._labels = labels
self._impersonation_chain = impersonation_chain

custom = self._scale_tier is not None and self._scale_tier.upper() == 'CUSTOM'
custom_image = (
custom
and self._master_config is not None
and self._master_config.get('imageUri', None) is not None
)

if not self._project_id:
raise AirflowException('Google Cloud project id is required.')
if not self._job_id:
raise AirflowException('An unique job id is required for Google MLEngine training job.')
if not package_uris:
raise AirflowException('At least one python package is required for MLEngine Training job.')
if not training_python_module:
raise AirflowException(
'Python module name to run after installing required packages is required.'
)
if not self._region:
raise AirflowException('Google Compute Engine region is required.')
if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM" and not self._master_type:
if custom and not self._master_type:
raise AirflowException('master_type must be set when scale_tier is CUSTOM')
if self._master_config and not self._master_type:
raise AirflowException('master_type must be set when master_config is provided')
if not (package_uris and training_python_module) and not custom_image:
raise AirflowException(
'Either a Python package with a Python module or a custom Docker image should be provided.'
)
if (package_uris or training_python_module) and custom_image:
raise AirflowException(
'Either a Python package with a Python module or '
'a custom Docker image should be provided but not both.'
)

def execute(self, context):
job_id = _normalize_mlengine_job_id(self._job_id)
training_request = {
'jobId': job_id,
'trainingInput': {
'scaleTier': self._scale_tier,
'packageUris': self._package_uris,
'pythonModule': self._training_python_module,
'region': self._region,
'args': self._training_args,
},
}
if self._labels:
training_request['labels'] = self._labels
if self._package_uris:
training_request['trainingInput']['packageUris'] = self._package_uris

if self._training_python_module:
training_request['trainingInput']['pythonModule'] = self._training_python_module

if self._training_args:
training_request['trainingInput']['args'] = self._training_args

if self._master_type:
training_request['trainingInput']['masterType'] = self._master_type

if self._master_config:
training_request['trainingInput']['masterConfig'] = self._master_config

if self._runtime_version:
training_request['trainingInput']['runtimeVersion'] = self._runtime_version
Expand All @@ -1251,11 +1274,8 @@ def execute(self, context):
if self._service_account:
training_request['trainingInput']['serviceAccount'] = self._service_account

if self._scale_tier is not None and self._scale_tier.upper() == "CUSTOM":
training_request['trainingInput']['masterType'] = self._master_type

if self._master_config:
training_request['trainingInput']['masterConfig'] = self._master_config
if self._labels:
training_request['labels'] = self._labels

if self._mode == 'DRY_RUN':
self.log.info('In dry_run mode.')
Expand Down
48 changes: 47 additions & 1 deletion tests/providers/google/cloud/operators/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_failed_job_error(self, mock_hook):
self.assertEqual('A failure message', str(context.exception))


class TestMLEngineTrainingOperator(unittest.TestCase):
class TestMLEngineStartTrainingJobOperator(unittest.TestCase):
TRAINING_DEFAULT_ARGS = {
'project_id': 'test-project',
'job_id': 'test_training',
Expand Down Expand Up @@ -407,6 +407,52 @@ def test_success_create_training_job_with_master_config(self, mock_hook):
project_id='test-project', job=training_input, use_existing_job_fn=ANY
)

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_create_training_job_with_master_image(self, hook):
    """Check that a job using a custom master image is submitted as expected.

    The operator is built without ``package_uris``, ``training_python_module``
    or ``training_args``; only a CUSTOM scale tier, a master type and a master
    config carrying ``imageUri`` are given. The request handed to the hook's
    ``create_job`` must then equal the expected job exactly, i.e. contain no
    package, module or argument entries.
    """
    # The request the operator is expected to pass to the hook.
    expected_job = {
        'jobId': 'test_training',
        'trainingInput': {
            'region': 'europe-west1',
            'scaleTier': 'CUSTOM',
            'masterType': 'n1-standard-8',
            'masterConfig': {
                'imageUri': 'eu.gcr.io/test-project/test-image:test-version',
            },
        },
    }

    # The mocked service echoes the job back (shallow copy) marked as succeeded.
    mocked_hook = hook.return_value
    mocked_hook.create_job.return_value = dict(expected_job, state='SUCCEEDED')

    task = MLEngineStartTrainingJobOperator(
        project_id='test-project',
        job_id='test_training',
        region='europe-west1',
        scale_tier='CUSTOM',
        master_type='n1-standard-8',
        master_config={
            'imageUri': 'eu.gcr.io/test-project/test-image:test-version',
        },
        task_id='test-training',
        start_date=days_ago(1),
    )
    task.execute(MagicMock())

    # The hook is created once with the default connection settings ...
    hook.assert_called_once_with(
        gcp_conn_id='google_cloud_default',
        delegate_to=None,
        impersonation_chain=None,
    )
    # ... and the job submission is the only call made on it.
    self.assertEqual(len(mocked_hook.mock_calls), 1)
    mocked_hook.create_job.assert_called_once_with(
        project_id='test-project',
        job=expected_job,
        use_existing_job_fn=ANY,
    )

@patch('airflow.providers.google.cloud.operators.mlengine.MLEngineHook')
def test_success_create_training_job_with_optional_args(self, mock_hook):
training_input = copy.deepcopy(self.TRAINING_INPUT)
Expand Down

0 comments on commit f6518dd

Please sign in to comment.