Support for Amazon SageMaker AI endpoints as Model Provider #176
base: main
Conversation
This is an updated version of PR #30. Please review and merge if appropriate 😄
I hope this gets merged soon. It opens up access to AWS Marketplace models deployed as Amazon SageMaker endpoints, making them accessible via Strands.
Do we have an expected date for this PR? It is needed for customer workshops in the coming weeks.
Do we have an expected date for this PR? This will help with your SageMaker AI GTM motions.
Hi all, thanks for the interest in and implementation of the model provider! The team will review the pull request this week and start the feedback process (leaving comments and questions, if any).
@dgallitelli Please correct me if I'm wrong here. This handles only JSON content. Will it be able to handle multimedia content types?
That's correct! It currently does not support multi-modal models.
```python
        boto_session: Boto Session to use when calling the SageMaker Runtime.
        boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
    """
    if model_config.get("stream", "") == "":
```
For this stream configuration check, could we use Python's `setdefault` method?

```python
model_config.setdefault("stream", True)
```
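For comparison, a minimal sketch of the two approaches (assuming `stream` should simply default to `True`):

```python
# Current: sentinel check, then manual assignment
if model_config.get("stream", "") == "":
    model_config["stream"] = True

# Suggested: setdefault inserts the default only when the key is missing
model_config.setdefault("stream", True)
```

One subtle difference: `setdefault` does not overwrite a key that was explicitly set to `""`, whereas the sentinel check above does.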
```python
payload.pop("tools")
payload.pop("tool_choice", None)

# TODO: this should be a @override of format_request_message
```
Curious whether this was meant for a follow-up commit? I see the TODO comment here.
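For reference, one shape the follow-up could take. This is only a sketch: the parent class, hook name, and message handling below are assumptions pieced together from the TODO and the surrounding excerpt, not the actual Strands interface:

```python
import json

from typing_extensions import override  # stdlib typing.override on Python 3.12+

from strands.models.openai import OpenAIModel  # parent class assumed for illustration


class SageMakerAIModel(OpenAIModel):
    @override
    def format_request_message(self, message: dict) -> dict:
        """Unpack tool-result content that arrives as a JSON-encoded string."""
        if message.get("role") == "tool" and isinstance(message.get("content"), str):
            message = {**message, "content": json.loads(message["content"])["content"]}
        return message
```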
```python
elif message.get("role", "") == "tool":
    logger.debug("message content:<%s> | streaming message content", message["content"])
    logger.debug("message content type:<%s> | streaming message content type", type(message["content"]))
    if type(message["content"]) == str:
```
Running with `hatch fmt --linter` exposed a linting error here:

```
E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
209 |     logger.debug("message content:<%s> | streaming message content", message["content"])
210 |     logger.debug("message content type:<%s> | streaming message content type", type(message["content"]))
211 |     if type(message["content"]) == str:
    |        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
212 |         message["content"] = json.loads(message["content"])["content"]
213 |         message["content"] = message["content"][0]["text"]
```

This is the only linting error I found. However, just to be sure, would you be able to run the commands below to confirm that all the checks pass:

```
hatch fmt --formatter
hatch fmt --linter
hatch test --all
```
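The fix here is mechanical; a sketch of the `isinstance` form the linter asks for:

```python
if isinstance(message["content"], str):
    message["content"] = json.loads(message["content"])["content"]
    message["content"] = message["content"][0]["text"]
```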
"EndpointName": self.config["endpoint_name"], | ||
"Body": json.dumps(payload), | ||
"ContentType": "application/json", | ||
"Accept": "application/json", | ||
} | ||
|
||
# Add InferenceComponentName if provided | ||
if self.config.get("inference_component_name"): | ||
request["InferenceComponentName"] = self.config["inference_component_name"] |
Could we add some of the other parameters that are part of the request, based on the boto documentation: `TargetModel` and `InferenceId`?
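Both are existing `invoke_endpoint` parameters, so they could follow the same pattern as `InferenceComponentName` above. A sketch; the config key names here are assumptions for illustration:

```python
# Route to a specific model behind a multi-model endpoint, if configured
if self.config.get("target_model"):  # config key assumed
    request["TargetModel"] = self.config["target_model"]

# Attach a caller-supplied identifier for tracing this invocation, if configured
if self.config.get("inference_id"):  # config key assumed
    request["InferenceId"] = self.config["inference_id"]
```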
```python
if self.config.get("stream", True):
    response = self.client.invoke_endpoint_with_response_stream(**request)

    # Message start
    yield {"chunk_type": "message_start"}

    yield {"chunk_type": "content_start", "data_type": "text"}

    # Parse the content
    finish_reason = ""
    partial_content = ""
    tool_calls: dict[int, list[Any]] = {}
    for event in response["Body"]:
        chunk = event["PayloadPart"]["Bytes"].decode("utf-8")
        partial_content += chunk  # Some messages are randomly split and not JSON decodable - not sure why
        try:
            content = json.loads(partial_content)
            partial_content = ""
            choice = content["choices"][0]

            # Start yielding message chunks
            if choice["delta"].get("content", None):
                yield {"chunk_type": "content_delta", "data_type": "text", "data": choice["delta"]["content"]}
            for tool_call in choice["delta"].get("tool_calls", []):
                tool_calls.setdefault(tool_call["index"], []).append(tool_call)
            if choice["finish_reason"] is not None:
                finish_reason = choice["finish_reason"]
                break

        except json.JSONDecodeError:
            # Continue accumulating content until we have valid JSON
            continue

    yield {"chunk_type": "content_stop", "data_type": "text"}

    # Handle tool calling
    for tool_deltas in tool_calls.values():
        yield {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])}
        for tool_delta in tool_deltas:
            yield {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_delta)}
        yield {"chunk_type": "content_stop", "data_type": "tool"}

    # Message close
    yield {"chunk_type": "message_stop", "data": finish_reason}
    # Handle usage metadata - TODO: not supported in current Response Schema!
    # Ref: https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html#response-schema
    # yield {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}

else:
    # Not all SageMaker AI models support streaming!
    response = self.client.invoke_endpoint(**request)
    final_response_json = json.loads(response["Body"].read().decode("utf-8"))

    # Obtain the key elements from the response
    message = final_response_json["choices"][0]["message"]
    message_stop_reason = final_response_json["choices"][0]["finish_reason"]

    # Message start
    yield {"chunk_type": "message_start"}

    # Handle text
    yield {"chunk_type": "content_start", "data_type": "text"}
    yield {"chunk_type": "content_delta", "data_type": "text", "data": message["content"] or ""}
    yield {"chunk_type": "content_stop", "data_type": "text"}

    # Handle the tool calling, if any
    if message_stop_reason == "tool_calls":
        for tool_call in message["tool_calls"] or []:
            yield {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)}
            yield {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)}
            yield {"chunk_type": "content_stop", "data_type": "tool", "data": ToolCall(**tool_call)}

    # Message close
    yield {"chunk_type": "message_stop", "data": message_stop_reason}
    # Handle usage metadata
    yield {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json["usage"])}
```
Thinking it may be beneficial to wrap this in a try/except block to handle some SageMaker errors, such as `ValidationError` or `ModelNotReadyException`, and gracefully give the user a useful message.
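One possible shape for that, assuming these errors surface as `botocore` `ClientError`s (the exception mapping and messages below are illustrative, not a confirmed design):

```python
from botocore.exceptions import ClientError

try:
    if self.config.get("stream", True):
        response = self.client.invoke_endpoint_with_response_stream(**request)
    else:
        response = self.client.invoke_endpoint(**request)
except ClientError as e:
    error_code = e.response["Error"]["Code"]
    if error_code == "ValidationError":
        # e.g. unknown endpoint name or malformed payload
        raise ValueError(f"Invalid SageMaker request: {e}") from e
    if error_code == "ModelNotReadyException":
        # The endpoint exists but the model is still loading; a retry may succeed
        raise RuntimeError(f"SageMaker model not ready, retry shortly: {e}") from e
    raise
```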
Running the coverage report on this test file, I see that it's covered at 72%:

```
src/strands/models/sagemaker.py    131    28    36    5    72%   135, 178-229, 255->276, 264->266, 267, 280-283
```

While some of the lines missing coverage are less crucial, I think it would be valuable to add a test for lines 178-229, since that covers the `format_request()` method.
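A rough sketch of such a test; the class name, constructor shape, and message format below are assumptions pieced together from the excerpts above, not the actual test suite:

```python
from strands.models.sagemaker import SageMakerAIModel  # class name assumed


def test_format_request_builds_invoke_payload():
    # Constructor arguments are hypothetical; adapt to the provider's real config
    model = SageMakerAIModel(model_config={"endpoint_name": "my-endpoint", "stream": False})

    request = model.format_request(messages=[{"role": "user", "content": [{"text": "Hello"}]}])

    assert request["EndpointName"] == "my-endpoint"
    assert request["ContentType"] == "application/json"
    assert request["Accept"] == "application/json"
```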
Hi all, is there an ETA on this PR merge?
Description
Support for Amazon SageMaker AI endpoints as Model Provider
Related Issues
Issue #16
Documentation PR
[Link to related associated PR in the agent-docs repo]
Type of Change
New feature
Testing
Yes
Checklist
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.