Skip to content

Commit

Permalink
Fix Vertex AI Custom Job training issue (#25367)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak committed Jul 28, 2022
1 parent 4dc1778 commit a8e4519
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 28 deletions.
35 changes: 23 additions & 12 deletions airflow/providers/google/cloud/hooks/vertex_ai/custom_job.py
Expand Up @@ -246,6 +246,11 @@ def extract_model_id(obj: Dict) -> str:
"""Returns unique id of the Model."""
return obj["name"].rpartition("/")[-1]

@staticmethod
def extract_training_id(resource_name: str) -> str:
    """
    Return the unique id of the Training pipeline.

    The id is the portion of ``resource_name`` that follows its last ``/``.
    A name that contains no ``/`` is returned unchanged, and a name that
    ends with ``/`` yields an empty string.

    :param resource_name: Resource name of the training pipeline.
    :return: The trailing ``/``-separated segment of ``resource_name``.
    """
    # Split once from the right: the final piece is everything after the last "/".
    *_, training_id = resource_name.rsplit("/", 1)
    return training_id

def wait_for_operation(self, operation: Operation, timeout: Optional[float] = None):
"""Waits for long-lasting operation to complete."""
try:
Expand Down Expand Up @@ -299,7 +304,7 @@ def _run_job(
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
) -> models.Model:
) -> Tuple[Optional[models.Model], str]:
"""Run Job for training pipeline"""
model = job.run(
dataset=dataset,
Expand Down Expand Up @@ -329,11 +334,17 @@ def _run_job(
tensorboard=tensorboard,
sync=sync,
)
training_id = self.extract_training_id(job.resource_name)
if model:
model.wait()
return model
else:
raise AirflowException("Training did not produce a Managed Model returning None.")
self.log.warning(
"Training did not produce a Managed Model returning None. Training Pipeline is not "
"configured to upload a Model. Create the Training Pipeline with "
"model_serving_container_image_uri and model_display_name passed in. "
"Ensure that your training script saves to model to os.environ['AIP_MODEL_DIR']."
)
return model, training_id

@GoogleBaseHook.fallback_to_default_project_id
def cancel_pipeline_job(
Expand Down Expand Up @@ -618,7 +629,7 @@ def create_custom_container_training_job(
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
) -> models.Model:
) -> Tuple[Optional[models.Model], str]:
"""
Create Custom Container Training Job
Expand Down Expand Up @@ -890,7 +901,7 @@ def create_custom_container_training_job(
if not self._job:
raise AirflowException("CustomJob was not created")

model = self._run_job(
model, training_id = self._run_job(
job=self._job,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
Expand Down Expand Up @@ -920,7 +931,7 @@ def create_custom_container_training_job(
sync=sync,
)

return model
return model, training_id

@GoogleBaseHook.fallback_to_default_project_id
def create_custom_python_package_training_job(
Expand Down Expand Up @@ -980,7 +991,7 @@ def create_custom_python_package_training_job(
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
) -> models.Model:
) -> Tuple[Optional[models.Model], str]:
"""
Create Custom Python Package Training Job
Expand Down Expand Up @@ -1252,7 +1263,7 @@ def create_custom_python_package_training_job(
if not self._job:
raise AirflowException("CustomJob was not created")

model = self._run_job(
model, training_id = self._run_job(
job=self._job,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
Expand Down Expand Up @@ -1282,7 +1293,7 @@ def create_custom_python_package_training_job(
sync=sync,
)

return model
return model, training_id

@GoogleBaseHook.fallback_to_default_project_id
def create_custom_training_job(
Expand Down Expand Up @@ -1342,7 +1353,7 @@ def create_custom_training_job(
timestamp_split_column_name: Optional[str] = None,
tensorboard: Optional[str] = None,
sync=True,
) -> models.Model:
) -> Tuple[Optional[models.Model], str]:
"""
Create Custom Training Job
Expand Down Expand Up @@ -1614,7 +1625,7 @@ def create_custom_training_job(
if not self._job:
raise AirflowException("CustomJob was not created")

model = self._run_job(
model, training_id = self._run_job(
job=self._job,
dataset=dataset,
annotation_schema_uri=annotation_schema_uri,
Expand Down Expand Up @@ -1644,7 +1655,7 @@ def create_custom_training_job(
sync=sync,
)

return model
return model, training_id

@GoogleBaseHook.fallback_to_default_project_id
def delete_pipeline_job(
Expand Down
48 changes: 32 additions & 16 deletions airflow/providers/google/cloud/operators/vertex_ai/custom_job.py
Expand Up @@ -29,7 +29,11 @@

from airflow.models import BaseOperator
from airflow.providers.google.cloud.hooks.vertex_ai.custom_job import CustomJobHook
from airflow.providers.google.cloud.links.vertex_ai import VertexAIModelLink, VertexAITrainingPipelinesLink
from airflow.providers.google.cloud.links.vertex_ai import (
VertexAIModelLink,
VertexAITrainingLink,
VertexAITrainingPipelinesLink,
)

if TYPE_CHECKING:
from airflow.utils.context import Context
Expand Down Expand Up @@ -411,7 +415,7 @@ class CreateCustomContainerTrainingJobOperator(CustomTrainingJobBaseOperator):
'command',
'impersonation_chain',
]
operator_extra_links = (VertexAIModelLink(),)
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())

def __init__(
self,
Expand All @@ -428,7 +432,7 @@ def execute(self, context: "Context"):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
model = self.hook.create_custom_container_training_job(
model, training_id = self.hook.create_custom_container_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
Expand Down Expand Up @@ -478,9 +482,13 @@ def execute(self, context: "Context"):
sync=True,
)

result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
else:
result = model # type: ignore
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
return result

def on_kill(self) -> None:
Expand Down Expand Up @@ -755,7 +763,7 @@ class CreateCustomPythonPackageTrainingJobOperator(CustomTrainingJobBaseOperator
'region',
'impersonation_chain',
]
operator_extra_links = (VertexAIModelLink(),)
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())

def __init__(
self,
Expand All @@ -774,7 +782,7 @@ def execute(self, context: "Context"):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
model = self.hook.create_custom_python_package_training_job(
model, training_id = self.hook.create_custom_python_package_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
Expand Down Expand Up @@ -825,9 +833,13 @@ def execute(self, context: "Context"):
sync=True,
)

result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
else:
result = model # type: ignore
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
return result

def on_kill(self) -> None:
Expand Down Expand Up @@ -1104,7 +1116,7 @@ class CreateCustomTrainingJobOperator(CustomTrainingJobBaseOperator):
'requirements',
'impersonation_chain',
]
operator_extra_links = (VertexAIModelLink(),)
operator_extra_links = (VertexAIModelLink(), VertexAITrainingLink())

def __init__(
self,
Expand All @@ -1123,7 +1135,7 @@ def execute(self, context: "Context"):
delegate_to=self.delegate_to,
impersonation_chain=self.impersonation_chain,
)
model = self.hook.create_custom_training_job(
model, training_id = self.hook.create_custom_training_job(
project_id=self.project_id,
region=self.region,
display_name=self.display_name,
Expand Down Expand Up @@ -1174,9 +1186,13 @@ def execute(self, context: "Context"):
sync=True,
)

result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
if model:
result = Model.to_dict(model)
model_id = self.hook.extract_model_id(result)
VertexAIModelLink.persist(context=context, task_instance=self, model_id=model_id)
else:
result = model # type: ignore
VertexAITrainingLink.persist(context=context, task_instance=self, training_id=training_id)
return result

def on_kill(self) -> None:
Expand Down
3 changes: 3 additions & 0 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Expand Up @@ -170,6 +170,7 @@
class TestVertexAICreateCustomContainerTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
mock_hook.return_value.create_custom_container_training_job.return_value = (None, 'training_id')
op = CreateCustomContainerTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -250,6 +251,7 @@ def test_execute(self, mock_hook):
class TestVertexAICreateCustomPythonPackageTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
mock_hook.return_value.create_custom_python_package_training_job.return_value = (None, 'training_id')
op = CreateCustomPythonPackageTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
Expand Down Expand Up @@ -332,6 +334,7 @@ def test_execute(self, mock_hook):
class TestVertexAICreateCustomTrainingJobOperator:
@mock.patch(VERTEX_AI_PATH.format("custom_job.CustomJobHook"))
def test_execute(self, mock_hook):
mock_hook.return_value.create_custom_training_job.return_value = (None, 'training_id')
op = CreateCustomTrainingJobOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
Expand Down

0 comments on commit a8e4519

Please sign in to comment.