Skip to content

Commit

Permalink
Add labels param to Google MLEngine Operators (#10222)
Browse files Browse the repository at this point in the history
  • Loading branch information
coopergillan committed Aug 8, 2020
1 parent 8a655cf commit c295338
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
package_uris=[TRAINER_URI],
training_python_module=TRAINER_PY_MODULE,
training_args=[],
labels={"job_type": "training"},
)
# [END howto_operator_gcp_mlengine_training]

Expand Down Expand Up @@ -169,6 +170,7 @@
data_format="TEXT",
input_paths=[PREDICTION_INPUT],
output_path=PREDICTION_OUTPUT,
labels={"job_type": "prediction"},
)
# [END howto_operator_gcp_mlengine_get_prediction]

Expand Down
14 changes: 13 additions & 1 deletion airflow/providers/google/cloud/operators/mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import re
import warnings
from typing import List, Optional
from typing import Dict, List, Optional

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator, BaseOperatorLink
Expand Down Expand Up @@ -151,6 +151,8 @@ class MLEngineStartBatchPredictionJobOperator(BaseOperator):
For this to work, the service account making the request must
have domain-wide delegation enabled.
:type delegate_to: str
:param labels: a dictionary containing labels for the job; passed to BigQuery
:type labels: Dict[str, str]
:raises: ``ValueError``: if a unique model/version origin cannot be
determined.
"""
Expand Down Expand Up @@ -183,6 +185,7 @@ def __init__(self, # pylint: disable=too-many-arguments
project_id: Optional[str] = None,
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
**kwargs) -> None:
super().__init__(**kwargs)

Expand All @@ -200,6 +203,7 @@ def __init__(self, # pylint: disable=too-many-arguments
self._signature_name = signature_name
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
self._labels = labels

if not self._project_id:
raise AirflowException('Google Cloud project id is required.')
Expand Down Expand Up @@ -234,6 +238,8 @@ def execute(self, context):
'region': self._region
}
}
if self._labels:
prediction_request['labels'] = self._labels

if self._uri:
prediction_request['predictionInput']['uri'] = self._uri
Expand Down Expand Up @@ -953,6 +959,8 @@ class MLEngineStartTrainingJobOperator(BaseOperator):
will be printed out. In 'CLOUD' mode, a real MLEngine training job
creation request will be issued.
:type mode: str
:param labels: a dictionary containing labels for the job; passed to BigQuery
:type labels: Dict[str, str]
"""

template_fields = [
Expand Down Expand Up @@ -990,6 +998,7 @@ def __init__(self, # pylint: disable=too-many-arguments
gcp_conn_id: str = 'google_cloud_default',
delegate_to: Optional[str] = None,
mode: str = 'PRODUCTION',
labels: Optional[Dict[str, str]] = None,
**kwargs) -> None:
super().__init__(**kwargs)
self._project_id = project_id
Expand All @@ -1006,6 +1015,7 @@ def __init__(self, # pylint: disable=too-many-arguments
self._gcp_conn_id = gcp_conn_id
self._delegate_to = delegate_to
self._mode = mode
self._labels = labels

if not self._project_id:
raise AirflowException('Google Cloud project id is required.')
Expand Down Expand Up @@ -1039,6 +1049,8 @@ def execute(self, context):
'args': self._training_args,
}
}
if self._labels:
training_request['labels'] = self._labels

if self._runtime_version:
training_request['trainingInput']['runtimeVersion'] = self._runtime_version
Expand Down
6 changes: 6 additions & 0 deletions tests/providers/google/cloud/operators/test_mlengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class TestMLEngineBatchPredictionOperator(unittest.TestCase):
}
SUCCESS_MESSAGE_MISSING_INPUT = {
'jobId': 'test_prediction',
'labels': {'some': 'labels'},
'predictionOutput': {
'outputPath': 'gs://fake-output-path',
'predictionCount': 5000,
Expand All @@ -74,6 +75,7 @@ class TestMLEngineBatchPredictionOperator(unittest.TestCase):
BATCH_PREDICTION_DEFAULT_ARGS = {
'project_id': 'test-project',
'job_id': 'test_prediction',
'labels': {'some': 'labels'},
'region': 'us-east1',
'data_format': 'TEXT',
'input_paths': ['gs://legal-bucket-dash-Capital/legal-input-path/*'],
Expand Down Expand Up @@ -116,6 +118,7 @@ def test_success_with_model(self, mock_hook):
input_paths=input_with_model['inputPaths'],
output_path=input_with_model['outputPath'],
model_name=input_with_model['modelName'].split('/')[-1],
labels={'some': 'labels'},
dag=self.dag,
task_id='test-prediction')
prediction_output = prediction_task.execute(None)
Expand All @@ -125,6 +128,7 @@ def test_success_with_model(self, mock_hook):
project_id='test-project',
job={
'jobId': 'test_prediction',
'labels': {'some': 'labels'},
'predictionInput': input_with_model
},
use_existing_job_fn=ANY
Expand Down Expand Up @@ -308,11 +312,13 @@ class TestMLEngineTrainingOperator(unittest.TestCase):
'training_args': '--some_arg=\'aaa\'',
'region': 'us-east1',
'scale_tier': 'STANDARD_1',
'labels': {'some': 'labels'},
'task_id': 'test-training',
'start_date': days_ago(1)
}
TRAINING_INPUT = {
'jobId': 'test_training',
'labels': {'some': 'labels'},
'trainingInput': {
'scaleTier': 'STANDARD_1',
'packageUris': ['gs://some-bucket/package1'],
Expand Down

0 comments on commit c295338

Please sign in to comment.