Support for Amazon SageMaker AI endpoints as Model Provider #176


Open
wants to merge 19 commits into main

Conversation


@dgallitelli dgallitelli commented Jun 4, 2025

Description

Support for Amazon SageMaker AI endpoints as Model Provider

Related Issues

Issue #16

Documentation PR

[Link to related associated PR in the agent-docs repo]

Type of Change

New feature

Testing

Yes

Checklist

  • I have read the CONTRIBUTING document
  • I have added tests that prove my fix is effective or my feature works
  • I have updated the documentation accordingly
  • I have added an appropriate example to the documentation to outline the feature
  • My changes generate no new warnings
  • Any dependent changes have been merged and published

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@dgallitelli dgallitelli requested a review from a team as a code owner June 4, 2025 13:38
@dgallitelli
Author

This is an updated version of PR #30. Please review and merge if appropriate 😄

@swami87aws

I hope this gets merged soon. It opens up access to AWS Marketplace models deployed as Amazon SageMaker endpoints, making them accessible via Strands.

@brunopistone

Do we have an expected date for this PR? It is needed for customer workshops in the coming weeks.

@rvvittal

Do we have an expected date for this PR? This will help with your SageMaker AI GTM motions.

@dbschmigelski dbschmigelski requested review from mehtarac and pgrayy June 11, 2025 15:19
@mehtarac
Member

Hi all, thanks for the interest in and implementation of the model provider! The team will review the pull request this week and start the feedback process (leaving comments and questions, if any).

@swami87aws

@dgallitelli Please correct me if I'm wrong here. This handles only JSON-type content. Will it be able to handle multimedia content types?

@dgallitelli
Author

@dgallitelli Please correct me if I'm wrong here. This handles only JSON-type content. Will it be able to handle multimedia content types?

That's correct! It currently does not support multi-modal models.

boto_session: Boto Session to use when calling the SageMaker Runtime.
boto_client_config: Configuration to use when creating the SageMaker-Runtime Boto Client.
"""
if model_config.get("stream", "") == "":
Member

For this stream configuration check, could we use Python's setdefault method:
model_config.setdefault("stream", True)
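For illustration, a minimal sketch of the difference, using a hypothetical model_config dict (not taken from the PR):

# Hypothetical config; only the "stream" key matters here.
model_config = {"endpoint_name": "my-endpoint"}

# Current pattern: assign a default when the value is missing or empty.
if model_config.get("stream", "") == "":
    model_config["stream"] = True

# setdefault assigns only when the key is absent; unlike the check above, it
# would not overwrite an explicit empty string, which should not occur for
# this boolean flag.
model_config.setdefault("stream", True)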

payload.pop("tools")
payload.pop("tool_choice", None)

# TODO: this should be a @override of format_request_message
Member

Curious if this was meant for a follow-up commit? I see the TODO comment here.

elif message.get("role", "") == "tool":
logger.debug("message content:<%s> | streaming message content", message["content"])
logger.debug("message content type:<%s> | streaming message content type", type(message["content"]))
if type(message["content"]) == str:
Member

Running hatch fmt --linter exposed a linting error here:

E721 Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks
    |
209 |                 logger.debug("message content:<%s> | streaming message content", message["content"])
210 |                 logger.debug("message content type:<%s> | streaming message content type", type(message["content"]))
211 |                 if type(message["content"]) == str:
    |                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ E721
212 |                     message["content"] = json.loads(message["content"])["content"]
213 |                 message["content"] = message["content"][0]["text"]

This is the only linting error I found. However, just to be sure, would you be able to run the commands below to confirm that all the checks pass:

hatch fmt --formatter
hatch fmt --linter
hatch test --all
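For reference, a small sketch of how the flagged comparison could be rewritten to satisfy E721, keeping the surrounding logic from the excerpt above (which assumes message["content"] is either a JSON string or a list of content blocks):

import json

# isinstance() satisfies E721 and also handles str subclasses.
if isinstance(message["content"], str):
    message["content"] = json.loads(message["content"])["content"]
message["content"] = message["content"][0]["text"]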

Comment on lines +216 to +224
"EndpointName": self.config["endpoint_name"],
"Body": json.dumps(payload),
"ContentType": "application/json",
"Accept": "application/json",
}

# Add InferenceComponentName if provided
if self.config.get("inference_component_name"):
request["InferenceComponentName"] = self.config["inference_component_name"]
Member

Could we add some of the other parameters that are part of the request, based on the boto3 documentation: TargetModel and InferenceId?
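A possible sketch of how those could be added alongside the existing InferenceComponentName handling, assuming they are surfaced through the same config dict (the target_model and inference_id keys are hypothetical, not part of this PR):

# Optional parameters accepted by the boto3 sagemaker-runtime invoke APIs;
# the config keys below are illustrative. Note that TargetModel is only
# accepted by the non-streaming invoke_endpoint call.
if self.config.get("target_model"):
    request["TargetModel"] = self.config["target_model"]
if self.config.get("inference_id"):
    request["InferenceId"] = self.config["inference_id"]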

Comment on lines +239 to +314
if self.config.get("stream", True):
response = self.client.invoke_endpoint_with_response_stream(**request)

# Message start
yield {"chunk_type": "message_start"}

yield {"chunk_type": "content_start", "data_type": "text"}

# Parse the content
finish_reason = ""
partial_content = ""
tool_calls: dict[int, list[Any]] = {}
for event in response["Body"]:
chunk = event["PayloadPart"]["Bytes"].decode("utf-8")
partial_content += chunk # Some messages are randomly split and not JSON decodable- not sure why
try:
content = json.loads(partial_content)
partial_content = ""
choice = content["choices"][0]

# Start yielding message chunks
if choice["delta"].get("content", None):
yield {"chunk_type": "content_delta", "data_type": "text", "data": choice["delta"]["content"]}
for tool_call in choice["delta"].get("tool_calls", []):
tool_calls.setdefault(tool_call["index"], []).append(tool_call)
if choice["finish_reason"] is not None:
finish_reason = choice["finish_reason"]
break

except json.JSONDecodeError:
# Continue accumulating content until we have valid JSON
continue

yield {"chunk_type": "content_stop", "data_type": "text"}

# Handle tool calling
for tool_deltas in tool_calls.values():
yield {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_deltas[0])}
for tool_delta in tool_deltas:
yield {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_delta)}
yield {"chunk_type": "content_stop", "data_type": "tool"}

# Message close
yield {"chunk_type": "message_stop", "data": finish_reason}
# Handle usage metadata - TODO: not supported in current Response Schema!
# Ref: https://docs.djl.ai/master/docs/serving/serving/docs/lmi/user_guides/chat_input_output_schema.html#response-schema
# yield {"chunk_type": "metadata", "data": UsageMetadata(**choice["usage"])}

else:
# Not all SageMaker AI models support streaming!
response = self.client.invoke_endpoint(**request)
final_response_json = json.loads(response["Body"].read().decode("utf-8"))

# Obtain the key elements from the response
message = final_response_json["choices"][0]["message"]
message_stop_reason = final_response_json["choices"][0]["finish_reason"]

# Message start
yield {"chunk_type": "message_start"}

# Handle text
yield {"chunk_type": "content_start", "data_type": "text"}
yield {"chunk_type": "content_delta", "data_type": "text", "data": message["content"] or ""}
yield {"chunk_type": "content_stop", "data_type": "text"}

# Handle the tool calling, if any
if message_stop_reason == "tool_calls":
for tool_call in message["tool_calls"] or []:
yield {"chunk_type": "content_start", "data_type": "tool", "data": ToolCall(**tool_call)}
yield {"chunk_type": "content_delta", "data_type": "tool", "data": ToolCall(**tool_call)}
yield {"chunk_type": "content_stop", "data_type": "tool", "data": ToolCall(**tool_call)}

# Message close
yield {"chunk_type": "message_stop", "data": message_stop_reason}
# Handle usage metadata
yield {"chunk_type": "metadata", "data": UsageMetadata(**final_response_json["usage"])}
Member

I think it may be beneficial to wrap this in a try/except block to handle SageMaker errors such as ValidationError or ModelNotReadyException, so the user gets a useful message instead of a raw failure.
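A minimal sketch of what that could look like, assuming the modeled exceptions on the sagemaker-runtime client are used (the raised messages are illustrative):

from botocore.exceptions import ClientError

try:
    response = self.client.invoke_endpoint_with_response_stream(**request)
except self.client.exceptions.ModelNotReadyException as e:
    # The endpoint exists but the model is still loading; suggest retrying.
    raise RuntimeError("SageMaker endpoint is not ready yet; retry shortly.") from e
except ClientError as e:
    if e.response["Error"]["Code"] == "ValidationError":
        # Usually a wrong endpoint name or a malformed request payload.
        raise ValueError(f"Invalid SageMaker request: {e}") from e
    raise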

Member

Running the coverage report on this test file, I see that it's at 72% coverage:

src/strands/models/sagemaker.py                                                   131     28     36      5    72%   135, 178-229, 255->276, 264->266, 267, 280-283

While some of the lines missing coverage are less crucial, I think it would be valuable to add a test for lines 178-229, since they exercise the format_request() method.
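For example, a rough sketch of such a test; the class name, constructor arguments, and format_request signature below are assumptions about this PR, so the assertions would need to be adjusted to the actual request schema:

import json

from strands.models.sagemaker import SageMakerAIModel  # assumed module/class name

def test_format_request_basic():
    model = SageMakerAIModel(endpoint_name="my-endpoint")  # assumed constructor
    messages = [{"role": "user", "content": [{"text": "hello"}]}]

    request = model.format_request(messages)  # assumed signature
    body = json.loads(request["Body"])

    assert request["EndpointName"] == "my-endpoint"
    assert request["ContentType"] == "application/json"
    assert body["messages"][0]["role"] == "user"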

@mehtarac mehtarac self-assigned this Jun 27, 2025
@swami87aws

Hi all, is there an ETA on this PR merge?
