Skip to content

Commit

Permalink
Add endpoint_id arg to `google.cloud.operators.vertex_ai.CreateEndp…
Browse files Browse the repository at this point in the history
…ointOperator`
  • Loading branch information
mai-nakagawa authored and potiuk committed Apr 25, 2022
1 parent b45240a commit 48abf57
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def create_endpoint(
project_id: str,
region: str,
endpoint: Union[Endpoint, Dict],
endpoint_id: Optional[str] = None,
retry: Union[Retry, _MethodDefault] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
Expand All @@ -91,6 +92,7 @@ def create_endpoint(
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param endpoint: Required. The Endpoint to create.
:param endpoint_id: The ID of Endpoint. If not provided, Vertex AI will generate a value for this ID.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
Expand All @@ -102,6 +104,7 @@ def create_endpoint(
request={
'parent': parent,
'endpoint': endpoint,
'endpoint_id': endpoint_id,
},
retry=retry,
timeout=timeout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def __init__(
region: str,
project_id: str,
endpoint: Union[Endpoint, Dict],
endpoint_id: Optional[str] = None,
retry: Union[Retry, _MethodDefault] = DEFAULT,
timeout: Optional[float] = None,
metadata: Sequence[Tuple[str, str]] = (),
Expand All @@ -93,6 +94,7 @@ def __init__(
self.region = region
self.project_id = project_id
self.endpoint = endpoint
self.endpoint_id = endpoint_id
self.retry = retry
self.timeout = timeout
self.metadata = metadata
Expand All @@ -112,6 +114,7 @@ def execute(self, context: 'Context'):
project_id=self.project_id,
region=self.region,
endpoint=self.endpoint,
endpoint_id=self.endpoint_id,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TEST_REGION: str = "test-region"
TEST_PROJECT_ID: str = "test-project-id"
TEST_ENDPOINT: dict = {}
TEST_ENDPOINT_ID: str = "test_endpoint_id"
TEST_ENDPOINT_NAME: str = "test_endpoint_name"
TEST_DEPLOYED_MODEL: dict = {}
TEST_DEPLOYED_MODEL_ID: str = "test-deployed-model-id"
Expand All @@ -54,12 +55,14 @@ def test_create_endpoint(self, mock_client) -> None:
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
endpoint=TEST_ENDPOINT,
endpoint_id=TEST_ENDPOINT_ID,
)
mock_client.assert_called_once_with(TEST_REGION)
mock_client.return_value.create_endpoint.assert_called_once_with(
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
endpoint=TEST_ENDPOINT,
endpoint_id=TEST_ENDPOINT_ID,
),
metadata=(),
retry=DEFAULT,
Expand Down Expand Up @@ -223,12 +226,14 @@ def test_create_endpoint(self, mock_client) -> None:
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
endpoint=TEST_ENDPOINT,
endpoint_id=TEST_ENDPOINT_ID,
)
mock_client.assert_called_once_with(TEST_REGION)
mock_client.return_value.create_endpoint.assert_called_once_with(
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
endpoint=TEST_ENDPOINT,
endpoint_id=TEST_ENDPOINT_ID,
),
metadata=(),
retry=DEFAULT,
Expand Down
2 changes: 2 additions & 0 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,6 +1137,7 @@ def test_execute(self, mock_hook, to_dict_mock):
region=GCP_LOCATION,
project_id=GCP_PROJECT,
endpoint=TEST_ENDPOINT,
endpoint_id=TEST_ENDPOINT_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
Expand All @@ -1149,6 +1150,7 @@ def test_execute(self, mock_hook, to_dict_mock):
region=GCP_LOCATION,
project_id=GCP_PROJECT,
endpoint=TEST_ENDPOINT,
endpoint_id=TEST_ENDPOINT_ID,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
Expand Down

0 comments on commit 48abf57

Please sign in to comment.