From b8939e75dc506808ef7b711886c2094821dd2372 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Fri, 30 May 2025 09:36:42 -0400
Subject: [PATCH 01/15] models - unsupported content types (#144)

---
 src/strands/models/anthropic.py | 10 +++++++--
 src/strands/models/litellm.py | 3 +++
 src/strands/models/llamaapi.py | 9 +++++++-
 src/strands/models/ollama.py | 6 +++++-
 src/strands/types/models/openai.py | 9 +++++++-
 tests/strands/models/test_anthropic.py | 26 ++++------------------
 tests/strands/models/test_llamaapi.py | 12 +++++++++++
 tests/strands/models/test_ollama.py | 23 +++++---------------
 tests/strands/types/models/test_openai.py | 15 ++++++-------
 9 files changed, 60 insertions(+), 53 deletions(-)

diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 99e49f81..132093fd 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -4,7 +4,6 @@
 """

 import base64
-import json
 import logging
 import mimetypes
 from typing import Any, Iterable, Optional, TypedDict, cast

@@ -95,6 +94,9 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An

         Returns:
             Anthropic formatted content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to an Anthropic-compatible format.
         """
         if "document" in content:
             mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
@@ -151,7 +153,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
                 "type": "tool_result",
             }

-        return {"text": json.dumps(content), "type": "text"}
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

     def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]:
         """Format an Anthropic messages array.

@@ -192,6 +194,10 @@ def format_request(

         Returns:
             An Anthropic streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to an Anthropic-compatible
+                format.
         """
         return {
             "max_tokens": self.config["max_tokens"],
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 23d2c2ae..5ef9aac5 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -76,6 +76,9 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:

         Returns:
             LiteLLM formatted content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a LiteLLM-compatible format.
         """
         if "reasoningContent" in content:
             return {
diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 307614db..b0825ecb 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -92,6 +92,9 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An

         Returns:
             LlamaAPI formatted content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to a LlamaAPI-compatible format.
         """
         if "image" in content:
             mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
@@ -107,7 +110,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         if "text" in content:
             return {"text": content["text"], "type": "text"}

-        return {"text": json.dumps(content), "type": "text"}
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

     def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
         """Format a Llama API tool call.
@@ -196,6 +199,10 @@ def format_request(

         Returns:
             A Llama API chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to a LlamaAPI-compatible
+                format.
         """
         request = {
             "messages": self._format_request_messages(messages, system_prompt),
diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py
index f632da34..ec1212fb 100644
--- a/src/strands/models/ollama.py
+++ b/src/strands/models/ollama.py
@@ -105,6 +105,10 @@ def format_request(

         Returns:
             An Ollama chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
+                format.
         """

         def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
@@ -153,7 +157,7 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
                 **({"images": result_images} if result_images else {}),
             }

-            return {"role": message["role"], "content": json.dumps(content)}
+            raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

         def format_messages() -> list[dict[str, Any]]:
             return [format_message(message, content) for message in messages for content in message["content"]]
diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py
index 307c0be6..2053c0f2 100644
--- a/src/strands/types/models/openai.py
+++ b/src/strands/types/models/openai.py
@@ -40,6 +40,9 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:

         Returns:
             OpenAI compatible content block.
+
+        Raises:
+            TypeError: If the content block type cannot be converted to an OpenAI-compatible format.
         """
         if "document" in content:
             mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
@@ -67,7 +70,7 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
         if "text" in content:
             return {"text": content["text"], "type": "text"}

-        return {"text": json.dumps(content), "type": "text"}
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

     @staticmethod
     def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
@@ -163,6 +166,10 @@ def format_request(

         Returns:
             An OpenAI compatible chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible
+                format.
         """
         return {
             "messages": self.format_request_messages(messages, system_prompt),
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py
index 2ee344cc..30cf1ec9 100644
--- a/tests/strands/models/test_anthropic.py
+++ b/tests/strands/models/test_anthropic.py
@@ -1,4 +1,3 @@
-import json
 import unittest.mock

 import anthropic
@@ -339,33 +338,16 @@ def test_format_request_with_tool_results(model, model_id, max_tokens):
     assert tru_request == exp_request


-def test_format_request_with_other(model, model_id, max_tokens):
+def test_format_request_with_unsupported_type(model):
     messages = [
         {
             "role": "user",
-            "content": [{"other": {"a": 1}}],
+            "content": [{"unsupported": {}}],
         },
     ]

-    tru_request = model.format_request(messages)
-    exp_request = {
-        "max_tokens": max_tokens,
-        "messages": [
-            {
-                "role": "user",
-                "content": [
-                    {
-                        "text": json.dumps({"other": {"a": 1}}),
-                        "type": "text",
-                    },
-                ],
-            },
-        ],
-        "model": model_id,
-        "tools": [],
-    }
-
-    assert tru_request == exp_request
+    with pytest.raises(TypeError, match="content_type=<unsupported> | unsupported type"):
+        model.format_request(messages)


 def test_format_request_with_cache_point(model, model_id, max_tokens):
diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py
index 5d7cd0f4..2e1a920a 100644
--- a/tests/strands/models/test_llamaapi.py
+++ b/tests/strands/models/test_llamaapi.py
@@ -233,6 +233,18 @@ def test_format_request_with_empty_content(model, model_id):
     assert tru_request == exp_request


+def test_format_request_with_unsupported_type(model):
+    messages = [
+        {
+            "role": "user",
+            "content": [{"unsupported": {}}],
+        },
+    ]
+
+    with pytest.raises(TypeError, match="content_type=<unsupported> | unsupported type"):
+        model.format_request(messages)
+
+
 def test_format_chunk_message_start(model):
     event = {"chunk_type": "message_start"}

diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py
index d3bd8f7f..d87594af 100644
--- a/tests/strands/models/test_ollama.py
+++ b/tests/strands/models/test_ollama.py
@@ -280,29 +280,16 @@ def test_format_request_with_tool_result_other(model, model_id):
     assert tru_request == exp_request


-def test_format_request_with_other(model, model_id):
+def test_format_request_with_unsupported_type(model):
     messages = [
         {
             "role": "user",
-            "content": [{"other": {"a": 1}}],
-        }
+            "content": [{"unsupported": {}}],
+        },
     ]

-    tru_request = model.format_request(messages)
-    exp_request = {
-        "messages": [
-            {
-                "role": "user",
-                "content": json.dumps({"other": {"a": 1}}),
-            }
-        ],
-        "model": model_id,
-        "options": {},
-        "stream": True,
-        "tools": [],
-    }
-
-    assert tru_request == exp_request
+    with pytest.raises(TypeError, match="content_type=<unsupported> | unsupported type"):
+        model.format_request(messages)


 def test_format_request_with_tool_specs(model, messages, model_id):
diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py
index c6a05291..daa9a87e 100644
--- a/tests/strands/types/models/test_openai.py
+++ b/tests/strands/types/models/test_openai.py
@@ -101,14 +101,6 @@ def system_prompt():
         {"text": "hello"},
         {"type": "text", "text": "hello"},
     ),
-    # Other
-    (
-        {"other": {"a": 1}},
-        {
-            "text": json.dumps({"other": {"a": 1}}),
-            "type": "text",
-        },
-    ),
 ],
 )
 def test_format_request_message_content(content, exp_result):
     tru_result = SAOpenAIModel.format_request_message_content(content)
     assert tru_result == exp_result


+def test_format_request_message_content_unsupported_type():
+    content = {"unsupported": {}}
+
+    with pytest.raises(TypeError, match="content_type=<unsupported> | unsupported type"):
+        SAOpenAIModel.format_request_message_content(content)
+
+
 def test_format_request_message_tool_call():
     tool_use = {
         "input": {"expression": "2+2"},
= {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + SAOpenAIModel.format_request_message_content(content) + + def test_format_request_message_tool_call(): tool_use = { "input": {"expression": "2+2"}, From cba7db648eb31e1bfee7df3d5f8432c0bc9de8a8 Mon Sep 17 00:00:00 2001 From: moritalous Date: Sat, 31 May 2025 00:53:46 +0900 Subject: [PATCH 02/15] feat: Add CachePoint type definition to ContentBlock (#142) --- src/strands/types/content.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/strands/types/content.py b/src/strands/types/content.py index c64b6773..790e9094 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -60,10 +60,21 @@ class ReasoningContentBlock(TypedDict, total=False): redactedContent: bytes +class CachePoint(TypedDict): + """A cache point configuration for optimizing conversation history. + + Attributes: + type: The type of cache point, typically "default". + """ + + type: str + + class ContentBlock(TypedDict, total=False): """A block of content for a message that you pass to, or receive from, a model. Attributes: + cachePoint: A cache point configuration to optimize conversation history. document: A document to include in the message. guardContent: Contains the content to assess with the guardrail. image: Image to include in the message. @@ -74,6 +85,7 @@ class ContentBlock(TypedDict, total=False): video: Video to include in the message. """ + cachePoint: CachePoint document: DocumentContent guardContent: GuardContent image: ImageContent From af961b2770ad1a3288b9b830f09227cd1c8ae95f Mon Sep 17 00:00:00 2001 From: Young Han <110819238+seyeong-han@users.noreply.github.com> Date: Fri, 30 May 2025 15:43:24 -0700 Subject: [PATCH 03/15] docs: add meta copyright header (#153) --- src/strands/models/llamaapi.py | 1 + tests-integ/test_model_llamaapi.py | 1 + tests/strands/models/test_llamaapi.py | 1 + 3 files changed, 3 insertions(+) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index b0825ecb..00f7742d 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates """Llama API model provider. - Docs: https://llama.developer.meta.com/ diff --git a/tests-integ/test_model_llamaapi.py b/tests-integ/test_model_llamaapi.py index 5cddc1a0..dad6919e 100644 --- a/tests-integ/test_model_llamaapi.py +++ b/tests-integ/test_model_llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates import os import pytest diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 2e1a920a..9a69c4c1 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. 
From af961b2770ad1a3288b9b830f09227cd1c8ae95f Mon Sep 17 00:00:00 2001
From: Young Han <110819238+seyeong-han@users.noreply.github.com>
Date: Fri, 30 May 2025 15:43:24 -0700
Subject: [PATCH 03/15] docs: add meta copyright header (#153)

---
 src/strands/models/llamaapi.py | 1 +
 tests-integ/test_model_llamaapi.py | 1 +
 tests/strands/models/test_llamaapi.py | 1 +
 3 files changed, 3 insertions(+)

diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index b0825ecb..00f7742d 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -1,3 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
 """Llama API model provider.

 - Docs: https://llama.developer.meta.com/
diff --git a/tests-integ/test_model_llamaapi.py b/tests-integ/test_model_llamaapi.py
index 5cddc1a0..dad6919e 100644
--- a/tests-integ/test_model_llamaapi.py
+++ b/tests-integ/test_model_llamaapi.py
@@ -1,3 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
 import os

 import pytest
diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py
index 2e1a920a..9a69c4c1 100644
--- a/tests/strands/models/test_llamaapi.py
+++ b/tests/strands/models/test_llamaapi.py
@@ -1,3 +1,4 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
 import json
 import unittest.mock

From af25f98e7bb53dc6b874de4f6d9f13e23b858d32 Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Mon, 2 Jun 2025 11:30:21 -0400
Subject: [PATCH 04/15] refactor: Update conversation manager interface (#161)

---
 src/strands/agent/agent.py | 6 +++---
 .../conversation_manager.py | 15 ++++++-------
 .../null_conversation_manager.py | 14 +++++++------
 .../sliding_window_conversation_manager.py | 21 ++++++++++++-------
 tests/strands/agent/test_agent.py | 4 ++--
 .../agent/test_conversation_manager.py | 13 ++++++++----
 6 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 0f912b54..bfa83fe2 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -165,7 +165,7 @@ def caller(**kwargs: Any) -> Any:
                 self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages)

                 # Apply window management
-                self._agent.conversation_manager.apply_management(self._agent.messages)
+                self._agent.conversation_manager.apply_management(self._agent)

                 return tool_result

@@ -439,7 +439,7 @@ def _run_loop(
             return self._execute_event_loop_cycle(invocation_callback_handler, kwargs)

         finally:
-            self.conversation_manager.apply_management(self.messages)
+            self.conversation_manager.apply_management(self)

     def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult:
         """Execute the event loop cycle with retry logic for context window limits.
@@ -483,7 +483,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
         except ContextWindowOverflowException as e:
             # Try reducing the context size and retrying
-            self.conversation_manager.reduce_context(messages, e=e)
+            self.conversation_manager.reduce_context(self, e=e)
             return self._execute_event_loop_cycle(callback_handler_override, kwargs)

     def _record_tool_execution(
diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py
index d18ae69a..dbccf941 100644
--- a/src/strands/agent/conversation_manager/conversation_manager.py
+++ b/src/strands/agent/conversation_manager/conversation_manager.py
@@ -1,9 +1,10 @@
 """Abstract interface for conversation history management."""

 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

-from ...types.content import Messages
+if TYPE_CHECKING:
+    from ...agent.agent import Agent


 class ConversationManager(ABC):
@@ -19,22 +20,22 @@ class ConversationManager(ABC):

     @abstractmethod
     # pragma: no cover
-    def apply_management(self, messages: Messages) -> None:
-        """Applies management strategy to the provided list of messages.
+    def apply_management(self, agent: "Agent") -> None:
+        """Applies management strategy to the provided agent.

         Processes the conversation history to maintain appropriate size by modifying the messages list in-place.
         Implementations should handle message pruning, summarization, or other size management techniques to keep the
         conversation context within desired bounds.

         Args:
-            messages: The conversation history to manage.
+            agent: The agent whose conversation history will be managed.
                 This list is modified in-place.
         """
         pass

     @abstractmethod
     # pragma: no cover
-    def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
+    def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
         """Called when the model's context window is exceeded.

         This method should implement the specific strategy for reducing the window size when a context overflow occurs.

@@ -48,7 +49,7 @@
         - Maintaining critical conversation markers

         Args:
-            messages: The conversation history to reduce.
+            agent: The agent whose conversation history will be reduced.
                 This list is modified in-place.
             e: The exception that triggered the context reduction, if any.
         """
diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py
index 2066c08b..4af4eb78 100644
--- a/src/strands/agent/conversation_manager/null_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/null_conversation_manager.py
@@ -1,8 +1,10 @@
 """Null implementation of conversation management."""

-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from ...agent.agent import Agent

-from ...types.content import Messages
 from ...types.exceptions import ContextWindowOverflowException
 from .conversation_manager import ConversationManager

@@ -17,19 +19,19 @@ class NullConversationManager(ConversationManager):
     - Situations where the full conversation history should be preserved
     """

-    def apply_management(self, messages: Messages) -> None:
+    def apply_management(self, _agent: "Agent") -> None:
         """Does nothing to the conversation history.

         Args:
-            messages: The conversation history that will remain unmodified.
+            agent: The agent whose conversation history will remain unmodified.
         """
         pass

-    def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
+    def reduce_context(self, _agent: "Agent", e: Optional[Exception] = None) -> None:
         """Does not reduce context and raises an exception.

         Args:
-            messages: The conversation history that will remain unmodified.
+            agent: The agent whose conversation history will remain unmodified.
             e: The exception that triggered the context reduction, if any.

         Raises:
diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
index f367b272..3381247c 100644
--- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -1,7 +1,10 @@
 """Sliding window conversation history management."""

 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from ...agent.agent import Agent

 from ...types.content import Message, Messages
 from ...types.exceptions import ContextWindowOverflowException
@@ -45,13 +48,13 @@ def __init__(self, window_size: int = 40):
         """Initialize the sliding window conversation manager.

         Args:
-            window_size: Maximum number of messages to keep in history.
+            window_size: Maximum number of messages to keep in the agent's history.
                 Defaults to 40 messages.
         """
         self.window_size = window_size

-    def apply_management(self, messages: Messages) -> None:
-        """Apply the sliding window to the messages array to maintain a manageable history size.
+    def apply_management(self, agent: "Agent") -> None:
+        """Apply the sliding window to the agent's messages array to maintain a manageable history size.

         This method is called after every event loop cycle, as the messages array may have been modified with tool
         results and assistant responses. It first removes any dangling messages that might create an invalid
         conversation state.
@@ -62,9 +65,10 @@ def apply_management(self, messages: Messages) -> None:
             blocks to maintain conversation coherence.

         Args:
-            messages: The messages to manage.
+            agent: The agent whose messages will be managed.
                 This list is modified in-place.
         """
+        messages = agent.messages
         self._remove_dangling_messages(messages)

         if len(messages) <= self.window_size:
@@ -72,7 +76,7 @@
                 "window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size
             )
             return
-        self.reduce_context(messages)
+        self.reduce_context(agent)

     def _remove_dangling_messages(self, messages: Messages) -> None:
         """Remove dangling messages that would create an invalid conversation state.
@@ -105,7 +109,7 @@ def _remove_dangling_messages(self, messages: Messages) -> None:
             if not any("toolResult" in content for content in messages[-1]["content"]):
                 messages.pop()

-    def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
+    def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None:
         """Trim the oldest messages to reduce the conversation context size.

         The method handles special cases where trimming the messages leads to:
@@ -113,7 +117,7 @@
         - toolUse with no corresponding toolResult

         Args:
-            messages: The messages to reduce.
+            agent: The agent whose messages will be reduced.
                 This list is modified in-place.
             e: The exception that triggered the context reduction, if any.

         Raises:
             ContextWindowOverflowException: If the context cannot be reduced further.
                 Such as when the conversation is already minimal or when tool result messages cannot be properly converted.
@@ -122,6 +126,7 @@
         """
+        messages = agent.messages
         # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
         trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size

diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index ea06fb4e..4a63fa31 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -318,7 +318,7 @@ def test_agent__call__(
     )

     callback_handler.assert_called()
-    conversation_manager_spy.apply_management.assert_called_with(agent.messages)
+    conversation_manager_spy.apply_management.assert_called_with(agent)


 def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, agent, tool, mock_event_loop_cycle):
@@ -583,7 +583,7 @@ def test_agent_tool(mock_randint, agent):
     }

     assert tru_result == exp_result
-    conversation_manager_spy.apply_management.assert_called_with(agent.messages)
+    conversation_manager_spy.apply_management.assert_called_with(agent)


 def test_agent_tool_user_message_override(agent):
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py
index b6132f1d..bbec3cd1 100644
--- a/tests/strands/agent/test_conversation_manager.py
+++ b/tests/strands/agent/test_conversation_manager.py
@@ -1,6 +1,7 @@
 import pytest

 import strands
+from strands.agent.agent import Agent
 from strands.types.exceptions import ContextWindowOverflowException


@@ -160,7 +161,8 @@ def conversation_manager(request):
     indirect=["conversation_manager"],
 )
 def test_apply_management(conversation_manager, messages, expected_messages):
-    conversation_manager.apply_management(messages)
+    test_agent = Agent(messages=messages)
+    conversation_manager.apply_management(test_agent)

     assert messages == expected_messages

@@ -172,9 +174,10 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con
         {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]},
     ]
     original_messages = messages.copy()
+    test_agent = Agent(messages=messages)

     with pytest.raises(ContextWindowOverflowException):
-        manager.apply_management(messages)
+        manager.apply_management(test_agent)

     assert messages == original_messages

@@ -187,8 +190,9 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow
         {"role": "assistant", "content": [{"text": "Hi there"}]},
     ]
     original_messages = messages.copy()
+    test_agent = Agent(messages=messages)

-    manager.apply_management(messages)
+    manager.apply_management(test_agent)

     with pytest.raises(ContextWindowOverflowException):
         manager.reduce_context(messages)

@@ -204,8 +208,9 @@ def test_null_conversation_manager_reduce_context_with_exception_raises_same_exc
         {"role": "assistant", "content": [{"text": "Hi there"}]},
     ]
     original_messages = messages.copy()
+    test_agent = Agent(messages=messages)

-    manager.apply_management(messages)
+    manager.apply_management(test_agent)

     with pytest.raises(RuntimeError):
         manager.reduce_context(messages, RuntimeError("test"))
From 76cd7bafda7c82eab8064d186335038be100a1b6 Mon Sep 17 00:00:00 2001
From: Patrick Gray
Date: Mon, 2 Jun 2025 11:33:44 -0400
Subject: [PATCH 05/15] models - correct tool result content (#154)

---
 src/strands/models/anthropic.py | 7 +-
 src/strands/models/litellm.py | 6 +-
 src/strands/models/llamaapi.py | 26 ++-
 src/strands/models/ollama.py | 213 ++++++++++++----------
 src/strands/types/models/openai.py | 29 +--
 tests/strands/models/test_anthropic.py | 5 +
 tests/strands/models/test_llamaapi.py | 10 +-
 tests/strands/models/test_ollama.py | 133 +++-----------
 tests/strands/types/models/test_openai.py | 19 +-
 9 files changed, 194 insertions(+), 254 deletions(-)

diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 132093fd..57394e2c 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -4,6 +4,7 @@
 """

 import base64
+import json
 import logging
 import mimetypes
 from typing import Any, Iterable, Optional, TypedDict, cast
@@ -145,7 +146,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
         if "toolResult" in content:
             return {
                 "content": [
-                    self._format_request_message_content(cast(ContentBlock, tool_result_content))
+                    self._format_request_message_content(
+                        {"text": json.dumps(tool_result_content["json"])}
+                        if "json" in tool_result_content
+                        else cast(ContentBlock, tool_result_content)
+                    )
                     for tool_result_content in content["toolResult"]["content"]
                 ],
                 "is_error": content["toolResult"]["status"] == "error",
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index 5ef9aac5..62f16d31 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -67,8 +67,8 @@ def get_config(self) -> LiteLLMConfig:
         return cast(LiteLLMModel.LiteLLMConfig, self.config)

     @override
-    @staticmethod
-    def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
+    @classmethod
+    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
         """Format a LiteLLM content block.

         Args:
@@ -96,4 +96,4 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
                 },
             }

-        return OpenAIModel.format_request_message_content(content)
+        return super().format_request_message_content(content)
diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py
index 00f7742d..583db2f2 100644
--- a/src/strands/models/llamaapi.py
+++ b/src/strands/models/llamaapi.py
@@ -8,7 +8,7 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Iterable, Optional
+from typing import Any, Iterable, Optional, cast

 import llama_api_client
 from llama_api_client import LlamaAPIClient
@@ -139,18 +139,30 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any
         Returns:
             Llama API formatted tool message.
         """
+        contents = cast(
+            list[ContentBlock],
+            [
+                {"text": json.dumps(content["json"])} if "json" in content else content
+                for content in tool_result["content"]
+            ],
+        )
+
         return {
             "role": "tool",
             "tool_call_id": tool_result["toolUseId"],
-            "content": json.dumps(
-                {
-                    "content": tool_result["content"],
-                    "status": tool_result["status"],
-                }
-            ),
+            "content": [self._format_request_message_content(content) for content in contents],
         }

     def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format a LlamaAPI compatible messages array.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            A LlamaAPI compatible messages array.
+        """
         formatted_messages: list[dict[str, Any]]
         formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []

diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py
index ec1212fb..7ed12216 100644
--- a/src/strands/models/ollama.py
+++ b/src/strands/models/ollama.py
@@ -5,13 +5,12 @@

 import json
 import logging
-from typing import Any, Iterable, Optional, Union
+from typing import Any, Iterable, Optional, cast

 from ollama import Client as OllamaClient
 from typing_extensions import TypedDict, Unpack, override

-from ..types.content import ContentBlock, Message, Messages
-from ..types.media import DocumentContent, ImageContent
+from ..types.content import ContentBlock, Messages
 from ..types.models import Model
 from ..types.streaming import StopReason, StreamEvent
 from ..types.tools import ToolSpec
@@ -92,35 +91,31 @@ def get_config(self) -> OllamaConfig:
         """
         return self.config

-    @override
-    def format_request(
-        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
-    ) -> dict[str, Any]:
-        """Format an Ollama chat streaming request.
+    def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]:
+        """Format Ollama compatible message contents.
+
+        Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks.

         Args:
-            messages: List of message objects to be processed by the model.
-            tool_specs: List of tool specifications to make available to the model.
-            system_prompt: System prompt to provide context to the model.
+            role: E.g., user.
+            content: Content block to format.

         Returns:
-            An Ollama chat streaming request.
+            Ollama formatted message contents.

         Raises:
-            TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
-                format.
+            TypeError: If the content block type cannot be converted to an Ollama-compatible format.
         """
+        if "text" in content:
+            return [{"role": role, "content": content["text"]}]

-        def format_message(message: Message, content: ContentBlock) -> dict[str, Any]:
-            if "text" in content:
-                return {"role": message["role"], "content": content["text"]}
+        if "image" in content:
+            return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]

-            if "image" in content:
-                return {"role": message["role"], "images": [content["image"]["source"]["bytes"]]}
-
-            if "toolUse" in content:
-                return {
-                    "role": "assistant",
+        if "toolUse" in content:
+            return [
+                {
+                    "role": role,
                     "tool_calls": [
                         {
                             "function": {
@@ -130,45 +125,63 @@
                         }
                     ],
                 }
+            ]
+
+        if "toolResult" in content:
+            return [
+                formatted_tool_result_content
+                for tool_result_content in content["toolResult"]["content"]
+                for formatted_tool_result_content in self._format_request_message_contents(
+                    "tool",
+                    (
+                        {"text": json.dumps(tool_result_content["json"])}
+                        if "json" in tool_result_content
+                        else cast(ContentBlock, tool_result_content)
+                    ),
+                )
+            ]

-            if "toolResult" in content:
-                result_content: Union[str, ImageContent, DocumentContent, Any] = None
-                result_images = []
-                for tool_result_content in content["toolResult"]["content"]:
-                    if "text" in tool_result_content:
-                        result_content = tool_result_content["text"]
-                    elif "json" in tool_result_content:
-                        result_content = tool_result_content["json"]
-                    elif "image" in tool_result_content:
-                        result_content = "see images"
-                        result_images.append(tool_result_content["image"]["source"]["bytes"])
-                    else:
-                        result_content = content["toolResult"]["content"]
+        raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

-                return {
-                    "role": "tool",
-                    "content": json.dumps(
-                        {
-                            "name": content["toolResult"]["toolUseId"],
-                            "result": result_content,
-                            "status": content["toolResult"]["status"],
-                        }
-                    ),
-                    **({"images": result_images} if result_images else {}),
-                }
+    def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+        """Format an Ollama compatible messages array.

-            raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")
+        Args:
+            messages: List of message objects to be processed by the model.
+            system_prompt: System prompt to provide context to the model.

-        def format_messages() -> list[dict[str, Any]]:
-            return [format_message(message, content) for message in messages for content in message["content"]]
+        Returns:
+            An Ollama compatible messages array.
+        """
+        system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []

-        formatted_messages = format_messages()
+        return system_message + [
+            formatted_message
+            for message in messages
+            for content in message["content"]
+            for formatted_message in self._format_request_message_contents(message["role"], content)
+        ]

+    @override
+    def format_request(
+        self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+    ) -> dict[str, Any]:
+        """Format an Ollama chat streaming request.
+
+        Args:
+            messages: List of message objects to be processed by the model.
+            tool_specs: List of tool specifications to make available to the model.
+            system_prompt: System prompt to provide context to the model.
+
+        Returns:
+            An Ollama chat streaming request.
+
+        Raises:
+            TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible
+                format.
+        """
         return {
-            "messages": [
-                *([{"role": "system", "content": system_prompt}] if system_prompt else []),
-                *formatted_messages,
-            ],
+            "messages": self._format_request_messages(messages, system_prompt),
             "model": self.config["model_id"],
             "options": {
                 **(self.config.get("options") or {}),
@@ -217,52 +230,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
             RuntimeError: If chunk_type is not recognized.
                 This error should never be encountered as we control chunk_type in the stream method.
         """
-        if event["chunk_type"] == "message_start":
-            return {"messageStart": {"role": "assistant"}}
-
-        if event["chunk_type"] == "content_start":
-            if event["data_type"] == "text":
-                return {"contentBlockStart": {"start": {}}}
-
-            tool_name = event["data"].function.name
-            return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}
-
-        if event["chunk_type"] == "content_delta":
-            if event["data_type"] == "text":
-                return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
-
-            tool_arguments = event["data"].function.arguments
-            return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}
-
-        if event["chunk_type"] == "content_stop":
-            return {"contentBlockStop": {}}
-
-        if event["chunk_type"] == "message_stop":
-            reason: StopReason
-            if event["data"] == "tool_use":
-                reason = "tool_use"
-            elif event["data"] == "length":
-                reason = "max_tokens"
-            else:
-                reason = "end_turn"
-
-            return {"messageStop": {"stopReason": reason}}
-
-        if event["chunk_type"] == "metadata":
-            return {
-                "metadata": {
-                    "usage": {
-                        "inputTokens": event["data"].eval_count,
-                        "outputTokens": event["data"].prompt_eval_count,
-                        "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count,
-                    },
-                    "metrics": {
-                        "latencyMs": event["data"].total_duration / 1e6,
-                    },
-                },
-            }
+        match event["chunk_type"]:
+            case "message_start":
+                return {"messageStart": {"role": "assistant"}}
+
+            case "content_start":
+                if event["data_type"] == "text":
+                    return {"contentBlockStart": {"start": {}}}
+
+                tool_name = event["data"].function.name
+                return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}
+
+            case "content_delta":
+                if event["data_type"] == "text":
+                    return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+                tool_arguments = event["data"].function.arguments
+                return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}}
+
+            case "content_stop":
+                return {"contentBlockStop": {}}
+
+            case "message_stop":
+                reason: StopReason
+                if event["data"] == "tool_use":
+                    reason = "tool_use"
+                elif event["data"] == "length":
+                    reason = "max_tokens"
+                else:
+                    reason = "end_turn"
+
+                return {"messageStop": {"stopReason": reason}}
+
+            case "metadata":
+                return {
+                    "metadata": {
+                        "usage": {
+                            "inputTokens": event["data"].eval_count,
+                            "outputTokens": event["data"].prompt_eval_count,
+                            "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count,
+                        },
+                        "metrics": {
+                            "latencyMs": event["data"].total_duration / 1e6,
+                        },
+                    },
+                }

-        raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
+            case _:
+                raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")

     @override
     def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py
index 2053c0f2..96f758d5 100644
--- a/src/strands/types/models/openai.py
+++ b/src/strands/types/models/openai.py
@@ -11,7 +11,7 @@
 import json
 import logging
 import mimetypes
-from typing import Any, Optional
+from typing import Any, Optional, cast

 from typing_extensions import override

@@ -31,8 +31,8 @@ class OpenAIModel(Model, abc.ABC):

     config: dict[str, Any]

-    @staticmethod
-    def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
+    @classmethod
+    def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]:
         """Format an OpenAI compatible content block.

         Args:
@@ -72,8 +72,8 @@
         raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type")

-    @staticmethod
-    def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
+    @classmethod
+    def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]:
         """Format an OpenAI compatible tool call.

         Args:
@@ -91,8 +91,8 @@
             "type": "function",
         }

-    @staticmethod
-    def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]:
+    @classmethod
+    def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]:
         """Format an OpenAI compatible tool message.

         Args:
@@ -101,15 +101,18 @@
         Returns:
             OpenAI compatible tool message.
         """
+        contents = cast(
+            list[ContentBlock],
+            [
+                {"text": json.dumps(content["json"])} if "json" in content else content
+                for content in tool_result["content"]
+            ],
+        )
+
         return {
             "role": "tool",
             "tool_call_id": tool_result["toolUseId"],
-            "content": json.dumps(
-                {
-                    "content": tool_result["content"],
-                    "status": tool_result["status"],
-                }
-            ),
+            "content": [cls.format_request_message_content(content) for content in contents],
         }

     @classmethod
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py
index 30cf1ec9..9421650e 100644
--- a/tests/strands/models/test_anthropic.py
+++ b/tests/strands/models/test_anthropic.py
@@ -289,6 +289,7 @@ def test_format_request_with_tool_results(model, model_id, max_tokens):
                     "status": "success",
                     "content": [
                         {"text": "see image"},
+                        {"json": ["see image"]},
                         {
                             "image": {
                                 "format": "jpg",
@@ -315,6 +316,10 @@
                             "text": "see image",
                             "type": "text",
                         },
+                        {
+                            "text": '["see image"]',
+                            "type": "text",
+                        },
                         {
                             "source": {
                                 "data": "YmFzZTY0ZW5jb2RlZGltYWdl",
diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py
index 9a69c4c1..309dac2e 100644
--- a/tests/strands/models/test_llamaapi.py
+++ b/tests/strands/models/test_llamaapi.py
@@ -1,5 +1,4 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-import json
 import unittest.mock

 import pytest
@@ -145,7 +144,7 @@ def test_format_request_with_tool_result(model, model_id):
                     "toolResult": {
                         "toolUseId": "c1",
                         "status": "success",
-                        "content": [{"value": 4}],
+                        "content": [{"text": "4"}, {"json": ["4"]}],
                     }
                 }
             ],
@@ -156,12 +155,7 @@
     exp_request = {
         "messages": [
             {
-                "content": json.dumps(
-                    {
-                        "content": [{"value": 4}],
-                        "status": "success",
-                    }
-                ),
+                "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}],
                 "role": "tool",
                 "tool_call_id": "c1",
             },
diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py
index d87594af..fe590dff 100644
--- a/tests/strands/models/test_ollama.py
+++ b/tests/strands/models/test_ollama.py
@@ -148,81 +148,24 @@ def test_format_request_with_tool_use(model, model_id):
 def test_format_request_with_tool_result(model, model_id):
     messages: Messages = [
         {
-            "role": "tool",
-            "content": [{"toolResult": {"toolUseId": "calculator", "status": "success", "content": [{"text": "4"}]}}],
-        }
-    ]
-
-    tru_request = model.format_request(messages)
-    exp_request = {
-        "messages": [
-            {
-                "role": "tool",
-                "content": json.dumps(
-                    {
-                        "name": "calculator",
-                        "result": "4",
-                        "status": "success",
-                    }
-                ),
-            }
-        ],
-        "model": model_id,
-        "options": {},
-        "stream": True,
-        "tools": [],
-    }
-
-    assert tru_request == exp_request
-
-
-def test_format_request_with_tool_result_json(model, model_id):
-    messages: Messages = [
-        {
-            "role": "tool",
-            "content": [
-                {"toolResult": {"toolUseId": "calculator", "status": "success", "content": [{"json": {"result": 4}}]}}
-            ],
-        }
-    ]
-
-    tru_request = model.format_request(messages)
-    exp_request = {
-        "messages": [
-            {
-                "role": "tool",
-                "content": json.dumps(
-                    {
-                        "name": "calculator",
-                        "result": {"result": 4},
-                        "status": "success",
-                    }
-                ),
-            }
-        ],
-        "model": model_id,
-        "options": {},
-        "stream": True,
-        "tools": [],
-    }
-
-    assert tru_request == exp_request
-
-
-def test_format_request_with_tool_result_image(model, model_id):
-    messages: Messages = [
-        {
-            "role": "tool",
+            "role": "user",
             "content": [
                 {
                     "toolResult": {
-                        "toolUseId": "image_generator",
+                        "toolUseId": "calculator",
                         "status": "success",
-                        "content": [{"image": {"source": {"bytes": "base64encodedimage"}}}],
-                    }
-                }
+                        "content": [
+                            {"text": "4"},
+                            {"image": {"source": {"bytes": b"image"}}},
+                            {"json": ["4"]},
+                        ],
+                    },
+                },
+                {
+                    "text": "see results",
+                },
             ],
-        }
+        },
     ]

     tru_request = model.format_request(messages)
@@ -230,46 +173,20 @@
         "messages": [
             {
                 "role": "tool",
-                "content": json.dumps(
-                    {
-                        "name": "image_generator",
-                        "result": "see images",
-                        "status": "success",
-                    }
-                ),
-                "images": ["base64encodedimage"],
-            }
-        ],
-        "model": model_id,
-        "options": {},
-        "stream": True,
-        "tools": [],
-    }
-
-    assert tru_request == exp_request
-
-
-def test_format_request_with_tool_result_other(model, model_id):
-    messages = [
-        {
-            "role": "tool",
-            "content": [{"toolResult": {"toolUseId": "other", "status": "success", "content": {"other": {"a": 1}}}}],
-        }
-    ]
-
-    tru_request = model.format_request(messages)
-    exp_request = {
-        "messages": [
+                "content": "4",
+            },
             {
                 "role": "tool",
-                "content": json.dumps(
-                    {
-                        "name": "other",
-                        "result": {"other": {"a": 1}},
-                        "status": "success",
-                    }
-                ),
-            }
+                "images": [b"image"],
+            },
+            {
+                "role": "tool",
+                "content": '["4"]',
+            },
+            {
+                "role": "user",
+                "content": "see results",
+            },
         ],
         "model": model_id,
         "options": {},
diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py
index daa9a87e..9db08bc9 100644
--- a/tests/strands/types/models/test_openai.py
+++ b/tests/strands/types/models/test_openai.py
@@ -1,4 +1,3 @@
-import json
 import unittest.mock

 import pytest
@@ -136,19 +135,14 @@ def test_format_request_message_tool_call():
 def test_format_request_tool_message():
     tool_result = {
-        "content": [{"value": 4}],
+        "content": [{"text": "4"}, {"json": ["4"]}],
         "status": "success",
         "toolUseId": "c1",
     }

     tru_result = SAOpenAIModel.format_request_tool_message(tool_result)
     exp_result = {
-        "content": json.dumps(
-            {
-                "content": [{"value": 4}],
-                "status": "success",
-            }
-        ),
+        "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}],
         "role": "tool",
         "tool_call_id": "c1",
     }
@@ -179,7 +173,7 @@ def test_format_request_messages(system_prompt):
             "role": "assistant",
         },
         {
-            "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}],
+            "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}],
             "role": "user",
         },
     ]
@@ -209,12 +203,7 @@
             ],
         },
         {
-            "content": json.dumps(
-                {
-                    "content": [{"value": 4}],
-                    "status": "success",
-                }
-            ),
+            "content": [{"text": "4", "type": "text"}],
             "role": "tool",
             "tool_call_id": "c1",
         },
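Because the tool-message formatter is now a classmethod on the OpenAI-compatible base type, the change is easy to exercise directly: json tool-result content is serialized into a text block, rather than the whole result being json.dumps'd into a single string. A quick sketch (import path assumed from the file layout above):

    from strands.types.models.openai import OpenAIModel

    tool_result = {
        "toolUseId": "c1",
        "status": "success",
        "content": [{"text": "4"}, {"json": ["4"]}],
    }

    print(OpenAIModel.format_request_tool_message(tool_result))
    # {'role': 'tool', 'tool_call_id': 'c1',
    #  'content': [{'text': '4', 'type': 'text'}, {'text': '["4"]', 'type': 'text'}]}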
"model": model_id, "options": {}, diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index daa9a87e..9db08bc9 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -1,4 +1,3 @@ -import json import unittest.mock import pytest @@ -136,19 +135,14 @@ def test_format_request_message_tool_call(): def test_format_request_tool_message(): tool_result = { - "content": [{"value": 4}], + "content": [{"text": "4"}, {"json": ["4"]}], "status": "success", "toolUseId": "c1", } tru_result = SAOpenAIModel.format_request_tool_message(tool_result) exp_result = { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), + "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], "role": "tool", "tool_call_id": "c1", } @@ -179,7 +173,7 @@ def test_format_request_messages(system_prompt): "role": "assistant", }, { - "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], "role": "user", }, ] @@ -209,12 +203,7 @@ def test_format_request_messages(system_prompt): ], }, { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), + "content": [{"text": "4", "type": "text"}], "role": "tool", "tool_call_id": "c1", }, From da55dc859025f68bd2efbcbe2d01e5c9ddf96ba2 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Tue, 3 Jun 2025 11:04:29 -0400 Subject: [PATCH 06/15] test: set OTEL_ env vars correctly for tests (#169) --- tests/conftest.py | 2 ++ tests/strands/tools/test_executor.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4f0b5b21..cd18b698 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,8 @@ def moto_env(monkeypatch): monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test") monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) @pytest.fixture diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index ced2bd7f..a6ea45c3 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -1,6 +1,7 @@ import concurrent import functools import unittest.mock +import uuid import pytest @@ -9,6 +10,11 @@ from strands.types.content import Message +@pytest.fixture(autouse=True) +def moto_autouse(moto_env): + _ = moto_env + + @pytest.fixture def tool_handler(request): def handler(tool_use): @@ -52,10 +58,10 @@ def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] -@unittest.mock.patch.object(strands.telemetry.metrics, "uuid4", return_value="trace1") @pytest.fixture def cycle_trace(): - return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") + with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"): + return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") @pytest.fixture From 9ce8f3dadbffa7e8582d7e0a7a5f89b36a412d53 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Tue, 3 Jun 2025 12:44:19 -0400 Subject: [PATCH 07/15] Fix agent default callback handler (#170) Co-authored-by: Sourabh Sarupria 
From 9ce8f3dadbffa7e8582d7e0a7a5f89b36a412d53 Mon Sep 17 00:00:00 2001
From: Arron <139703460+awsarron@users.noreply.github.com>
Date: Tue, 3 Jun 2025 12:44:19 -0400
Subject: [PATCH 07/15] Fix agent default callback handler (#170)

Co-authored-by: Sourabh Sarupria
Co-authored-by: Sourabh Sarupria
---
 src/strands/agent/agent.py | 35 +++++++++++++++++++++++++------
 tests/strands/agent/test_agent.py | 31 +++++++++++++++++++++++++++
 2 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index bfa83fe2..0651d452 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -44,6 +44,16 @@
 logger = logging.getLogger(__name__)


+# Sentinel class and object to distinguish between explicit None and default parameter value
+class _DefaultCallbackHandlerSentinel:
+    """Sentinel class to distinguish between explicit None and default parameter value."""
+
+    pass
+
+
+_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel()
+
+
 class Agent:
     """Core Agent interface.

@@ -70,7 +80,7 @@ def __init__(self, agent: "Agent") -> None:
             # agent tools and thus break their execution.
             self._agent = agent

-    def __getattr__(self, name: str) -> Callable:
+    def __getattr__(self, name: str) -> Callable[..., Any]:
         """Call tool as a function.

         This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
@@ -177,7 +187,9 @@ def __init__(
         messages: Optional[Messages] = None,
         tools: Optional[List[Union[str, Dict[str, str], Any]]] = None,
         system_prompt: Optional[str] = None,
-        callback_handler: Optional[Callable] = PrintingCallbackHandler(),
+        callback_handler: Optional[
+            Union[Callable[..., Any], _DefaultCallbackHandlerSentinel]
+        ] = _DEFAULT_CALLBACK_HANDLER,
         conversation_manager: Optional[ConversationManager] = None,
         max_parallel_tools: int = os.cpu_count() or 1,
         record_direct_tool_call: bool = True,
@@ -204,7 +216,8 @@ def __init__(
             system_prompt: System prompt to guide model behavior.
                 If None, the model will behave according to its default settings.
             callback_handler: Callback for processing events as they happen during agent execution.
-                Defaults to strands.handlers.PrintingCallbackHandler if None.
+                If not provided (using the default), a new PrintingCallbackHandler instance is created.
+                If explicitly set to None, null_callback_handler is used.
             conversation_manager: Manager for conversation history and context window.
                 Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None.
             max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls.
@@ -222,7 +235,17 @@ def __init__(
         self.messages = messages if messages is not None else []

         self.system_prompt = system_prompt
-        self.callback_handler = callback_handler or null_callback_handler
+
+        # If not provided, create a new PrintingCallbackHandler instance
+        # If explicitly set to None, use null_callback_handler
+        # Otherwise use the passed callback_handler
+        self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler]
+        if isinstance(callback_handler, _DefaultCallbackHandlerSentinel):
+            self.callback_handler = PrintingCallbackHandler()
+        elif callback_handler is None:
+            self.callback_handler = null_callback_handler
+        else:
+            self.callback_handler = callback_handler

         self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager()

@@ -415,7 +438,7 @@ def target_callback() -> None:
         thread.join()

     def _run_loop(
-        self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None
+        self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None
     ) -> AgentResult:
         """Execute the agent's event loop with the given prompt and parameters."""
         try:
@@ -441,7 +464,7 @@
         finally:
             self.conversation_manager.apply_management(self)

-    def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult:
+    def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult:
         """Execute the event loop cycle with retry logic for context window limits.

         This internal method handles the execution of the event loop cycle and implements
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 4a63fa31..0ea20b64 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -686,6 +686,37 @@ def test_agent_with_callback_handler_none_uses_null_handler():
     assert agent.callback_handler == null_callback_handler


+def test_agent_callback_handler_not_provided_creates_new_instances():
+    """Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created."""
+    # Create two agents without providing callback_handler
+    agent1 = Agent()
+    agent2 = Agent()
+
+    # Both should have PrintingCallbackHandler instances
+    assert isinstance(agent1.callback_handler, PrintingCallbackHandler)
+    assert isinstance(agent2.callback_handler, PrintingCallbackHandler)
+
+    # But they should be different object instances
+    assert agent1.callback_handler is not agent2.callback_handler
+
+
+def test_agent_callback_handler_explicit_none_uses_null_handler():
+    """Test that when callback_handler is explicitly set to None, null_callback_handler is used."""
+    agent = Agent(callback_handler=None)
+
+    # Should use null_callback_handler
+    assert agent.callback_handler is null_callback_handler
+
+
+def test_agent_callback_handler_custom_handler_used():
+    """Test that when a custom callback_handler is provided, it is used."""
+    custom_handler = unittest.mock.Mock()
+    agent = Agent(callback_handler=custom_handler)
+
+    # Should use the provided custom handler
+    assert agent.callback_handler is custom_handler
+
+
 @pytest.mark.asyncio
 async def test_stream_async_returns_all_events(mock_event_loop_cycle):
     agent = Agent()
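The resulting construction-time behavior, mirroring the tests added above (import paths are assumed from the test modules, and Agent() assumes the default model can be constructed in your environment):

    from strands import Agent
    from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler

    # Omitted: each agent gets its own PrintingCallbackHandler instance.
    a1, a2 = Agent(), Agent()
    assert isinstance(a1.callback_handler, PrintingCallbackHandler)
    assert a1.callback_handler is not a2.callback_handler

    # Explicit None: events are dropped via null_callback_handler.
    assert Agent(callback_handler=None).callback_handler is null_callback_handler

    # Any custom callable is used as-is.
    def collect(**kwargs):
        pass

    assert Agent(callback_handler=collect).callback_handler is collect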
pypi-publish-on-release.yml --- .github/workflows/pr-and-push.yml | 4 +++- .github/workflows/pypi-publish-on-release.yml | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 38e88691..2b2d026f 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -13,5 +13,7 @@ concurrency: jobs: call-test-lint: uses: ./.github/workflows/test-lint.yml + permissions: + contents: read with: - ref: ${{ github.event.pull_request.head.sha }} \ No newline at end of file + ref: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 4047f596..d7038e82 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -13,6 +13,8 @@ jobs: build: name: Build distribution 📦 + permissions: + contents: read needs: - call-test-lint runs-on: ubuntu-latest @@ -55,6 +57,8 @@ jobs: deploy: name: Upload release to PyPI + permissions: + contents: read needs: - build runs-on: ubuntu-latest @@ -75,4 +79,4 @@ jobs: name: python-package-distributions path: dist/ - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 From a8059193e61c357a34c22bf199c3e2c18ac9fc73 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Wed, 4 Jun 2025 10:00:04 -0400 Subject: [PATCH 09/15] Remove redundant permissions block (#172) --- .github/workflows/pypi-publish-on-release.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index d7038e82..0e26a1db 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -57,8 +57,6 @@ jobs: deploy: name: Upload release to PyPI - permissions: - contents: read needs: - build runs-on: ubuntu-latest From 8ffe24b44b90c015282845876c219b5b375baef7 Mon Sep 17 00:00:00 2001 From: Luke Harris Date: Thu, 5 Jun 2025 23:18:45 +1000 Subject: [PATCH 10/15] Fix: Preserve deeply nested schemas (#133) * Preserve deeply nested schemas * Refactor schema tests to assert entire structure --------- Authored-by: Luke Harris --- src/strands/tools/tools.py | 74 ++++---- tests/strands/tools/test_tools.py | 291 ++++++++++++++++++++++++------ 2 files changed, 273 insertions(+), 92 deletions(-) diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index b595c3d6..a449c74e 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -63,50 +63,54 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) +def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: + """Normalize a single property definition. + + Args: + prop_name: The name of the property. + prop_def: The property definition to normalize. + + Returns: + The normalized property definition. 
+ """ + if not isinstance(prop_def, dict): + return {"type": "string", "description": f"Property {prop_name}"} + + if prop_def.get("type") == "object" and "properties" in prop_def: + return normalize_schema(prop_def) # Recursive call + + # Copy existing property, ensuring defaults + normalized_prop = prop_def.copy() + normalized_prop.setdefault("type", "string") + normalized_prop.setdefault("description", f"Property {prop_name}") + return normalized_prop + + def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Normalize a JSON schema to match expectations. + This function recursively processes nested objects to preserve the complete schema structure. + Uses a copy-then-normalize approach to preserve all original schema properties. + Args: schema: The schema to normalize. Returns: The normalized schema. """ - normalized = {"type": schema.get("type", "object"), "properties": {}} - - # Handle properties - if "properties" in schema: - for prop_name, prop_def in schema["properties"].items(): - if isinstance(prop_def, dict): - normalized_prop = { - "type": prop_def.get("type", "string"), - "description": prop_def.get("description", f"Property {prop_name}"), - } - - # Handle enum values correctly - if "enum" in prop_def: - normalized_prop["enum"] = prop_def["enum"] - - # Handle numeric constraints - if prop_def.get("type") in ["number", "integer"]: - if "minimum" in prop_def: - normalized_prop["minimum"] = prop_def["minimum"] - if "maximum" in prop_def: - normalized_prop["maximum"] = prop_def["maximum"] - - normalized["properties"][prop_name] = normalized_prop - else: - # Handle non-dict property definitions (like simple strings) - normalized["properties"][prop_name] = { - "type": "string", - "description": f"Property {prop_name}", - } - - # Required fields - if "required" in schema: - normalized["required"] = schema["required"] - else: - normalized["required"] = [] + # Start with a complete copy to preserve all existing properties + normalized = schema.copy() + + # Ensure essential structure exists + normalized.setdefault("type", "object") + normalized.setdefault("properties", {}) + normalized.setdefault("required", []) + + # Process properties recursively + if "properties" in normalized: + properties = normalized["properties"] + for prop_name, prop_def in properties.items(): + normalized["properties"][prop_name] = _normalize_property(prop_name, prop_def) return normalized diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index f24cc22d..1b65156b 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -50,11 +50,10 @@ def test_validate_tool_use(): def test_normalize_schema_basic(): schema = {"type": "object"} normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert normalized["properties"] == {} - assert "required" in normalized - assert normalized["required"] == [] + + expected = {"type": "object", "properties": {}, "required": []} + + assert normalized == expected def test_normalize_schema_with_properties(): @@ -66,14 +65,17 @@ def test_normalize_schema_with_properties(): }, } normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "User name" - assert "age" in normalized["properties"] - assert 
normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["description"] == "User age" + + expected = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_removed(): @@ -82,27 +84,40 @@ def test_normalize_schema_with_property_removed(): "properties": {"name": "invalid"}, } normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_defaults(): schema = {"properties": {"name": {}}} normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_enum(): schema = {"properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}} normalized = normalize_schema(schema) - assert "color" in normalized["properties"] - assert normalized["properties"]["color"]["type"] == "string" - assert normalized["properties"]["color"]["description"] == "color" - assert "enum" in normalized["properties"]["color"] - assert normalized["properties"]["color"]["enum"] == ["red", "green", "blue"] + + expected = { + "type": "object", + "properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_numeric_constraints(): @@ -113,21 +128,170 @@ def test_normalize_schema_with_property_numeric_constraints(): } } normalized = normalize_schema(schema) - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["minimum"] == 0 - assert normalized["properties"]["age"]["maximum"] == 120 - assert "score" in normalized["properties"] - assert normalized["properties"]["score"]["type"] == "number" - assert normalized["properties"]["score"]["minimum"] == 0.0 - assert normalized["properties"]["score"]["maximum"] == 100.0 + + expected = { + "type": "object", + "properties": { + "age": {"type": "integer", "description": "age", "minimum": 0, "maximum": 120}, + "score": {"type": "number", "description": "score", "minimum": 0.0, "maximum": 100.0}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_required(): schema = {"type": "object", "required": ["name", "email"]} normalized = normalize_schema(schema) - assert "required" in normalized - assert normalized["required"] == ["name", "email"] + + expected = {"type": "object", "properties": {}, "required": ["name", "email"]} + + assert normalized == expected + + +def test_normalize_schema_with_nested_object(): + """Test normalization of schemas with nested objects.""" + schema = { + "type": "object", + "properties": { + "user": { + "type": 
"object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + assert normalized == expected + + +def test_normalize_schema_with_deeply_nested_objects(): + """Test normalization of deeply nested object structures.""" + schema = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": {"type": "object", "properties": {"value": {"type": "string", "const": "fixed"}}} + }, + } + }, + } + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": { + "type": "object", + "properties": { + "value": {"type": "string", "description": "Property value", "const": "fixed"} + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_const_constraint(): + """Test that const constraints are preserved.""" + schema = { + "type": "object", + "properties": { + "status": {"type": "string", "const": "ACTIVE"}, + "config": {"type": "object", "properties": {"mode": {"type": "string", "const": "PRODUCTION"}}}, + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "status": {"type": "string", "description": "Property status", "const": "ACTIVE"}, + "config": { + "type": "object", + "properties": {"mode": {"type": "string", "description": "Property mode", "const": "PRODUCTION"}}, + "required": [], + }, + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_additional_properties(): + """Test that additionalProperties constraint is preserved.""" + schema = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": {"type": "object", "properties": {"id": {"type": "string"}}, "additionalProperties": False} + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": { + "type": "object", + "additionalProperties": False, + "properties": {"id": {"type": "string", "description": "Property id"}}, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected def test_normalize_tool_spec_with_json_schema(): @@ -137,14 +301,20 @@ def test_normalize_tool_spec_with_json_schema(): "inputSchema": {"json": {"type": "object", "properties": {"query": {}}, "required": ["query"]}}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["properties"]["query"]["type"] == "string" - assert 
normalized["inputSchema"]["json"]["properties"]["query"]["description"] == "Property query" + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_with_direct_schema(): @@ -154,22 +324,29 @@ def test_normalize_tool_spec_with_direct_schema(): "inputSchema": {"type": "object", "properties": {"query": {}}, "required": ["query"]}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["required"] == ["query"] + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_without_input_schema(): tool_spec = {"name": "test_tool", "description": "A test tool"} normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not present - assert "inputSchema" not in normalized + + expected = {"name": "test_tool", "description": "A test tool"} + + assert normalized == expected def test_normalize_tool_spec_empty_input_schema(): @@ -179,10 +356,10 @@ def test_normalize_tool_spec_empty_input_schema(): "inputSchema": "", } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not a dict - assert normalized["inputSchema"] == "" + + expected = {"name": "test_tool", "description": "A test tool", "inputSchema": ""} + + assert normalized == expected def test_validate_tool_use_with_valid_input(): From c5eb3ee713b14d4df24f879e5c0a3cbb551dc55f Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 6 Jun 2025 07:20:34 -0400 Subject: [PATCH 11/15] fix: ignore mypy error from latest OpenTelemetrySDK update (#180) See open-telemetry/opentelemetry-python#4615 but it looks like an OpenTelemetrySDK update caused type errors --------- Co-authored-by: Mackenzie Zastrow --- src/strands/telemetry/tracer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 34eb7bed..9f731996 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -13,7 +13,9 @@ from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.sdk.resources import Resource + +# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore +from opentelemetry.sdk.resources import Resource # type: ignore[attr-defined] from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import StatusCode From 
From c5eb3ee713b14d4df24f879e5c0a3cbb551dc55f Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Fri, 6 Jun 2025 07:20:34 -0400
Subject: [PATCH 11/15] fix: ignore mypy error from latest OpenTelemetry SDK update (#180)

See open-telemetry/opentelemetry-python#4615; it looks like an OpenTelemetry SDK update caused type errors.

---------

Co-authored-by: Mackenzie Zastrow
---
 src/strands/telemetry/tracer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index 34eb7bed..9f731996 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -13,7 +13,9 @@
 
 from opentelemetry import trace
 from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
-from opentelemetry.sdk.resources import Resource
+
+# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore
+from opentelemetry.sdk.resources import Resource  # type: ignore[attr-defined]
 from opentelemetry.sdk.trace import TracerProvider
 from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor
 from opentelemetry.trace import StatusCode

From 903a260d6d723e9078b8660ec3fef10780f78c2b Mon Sep 17 00:00:00 2001
From: Nick Clegg
Date: Fri, 6 Jun 2025 11:06:55 -0400
Subject: [PATCH 12/15] Add permissions block to call-test-lint job (#186)

---
 .github/workflows/pypi-publish-on-release.yml | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml
index 0e26a1db..8967c552 100644
--- a/.github/workflows/pypi-publish-on-release.yml
+++ b/.github/workflows/pypi-publish-on-release.yml
@@ -8,6 +8,8 @@ jobs:
   call-test-lint:
     uses: ./.github/workflows/test-lint.yml
+    permissions:
+      contents: read
     with:
       ref: ${{ github.event.release.target_commitish }}

From a64e80dfa5f5f6053a3fe53767f59fc3c5c1af95 Mon Sep 17 00:00:00 2001
From: mrityunjay shukla
Date: Sat, 7 Jun 2025 01:07:22 +0530
Subject: [PATCH 13/15] fix: Handle empty choices in OpenAI model provider (#185)

Co-authored-by: Mrityunjay Shukla
---
 src/strands/models/openai.py        |  3 +++
 tests/strands/models/test_openai.py | 41 +++++++++++++++++++++++++++++
 2 files changed, 44 insertions(+)

diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 764cb851..6cbef664 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -94,6 +94,9 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
         tool_calls: dict[int, list[Any]] = {}
 
         for event in response:
+            # Defensive: skip events with empty or missing choices
+            if not getattr(event, "choices", None):
+                continue
             choice = event.choices[0]
 
             if choice.delta.content:
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
index 89aa591f..4c1f8528 100644
--- a/tests/strands/models/test_openai.py
+++ b/tests/strands/models/test_openai.py
@@ -132,3 +132,44 @@ def test_stream_empty(openai_client, model):
 
     assert tru_events == exp_events
     openai_client.chat.completions.create.assert_called_once_with(**request)
+
+
+def test_stream_with_empty_choices(openai_client, model):
+    mock_delta = unittest.mock.Mock(content="content", tool_calls=None)
+    mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
+
+    # Event with no choices attribute
+    mock_event_1 = unittest.mock.Mock(spec=[])
+
+    # Event with empty choices list
+    mock_event_2 = unittest.mock.Mock(choices=[])
+
+    # Valid event with content
+    mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+
+    # Event with finish reason
+    mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+
+    # Final event with usage info
+    mock_event_5 = unittest.mock.Mock(usage=mock_usage)
+
+    openai_client.chat.completions.create.return_value = iter(
+        [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5]
+    )
+
+    request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]}
+    response = model.stream(request)
+
+    tru_events = list(response)
+    exp_events = [
+        {"chunk_type": "message_start"},
+        {"chunk_type": "content_start", "data_type": "text"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "content"},
+        {"chunk_type": "content_delta", "data_type": "text", "data": "content"},
+        {"chunk_type": "content_stop", "data_type": "text"},
+        {"chunk_type": "message_stop", "data": "stop"},
+        {"chunk_type": "metadata", "data": mock_usage},
+    ]
+
+    assert tru_events == exp_events
+    openai_client.chat.completions.create.assert_called_once_with(**request)
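
The failure mode PATCH 13/15 guards against is easy to reproduce in isolation. A small sketch using only the standard library; the usage-only chunk shape mirrors the mocks in the test above:

    from types import SimpleNamespace

    # Some OpenAI-compatible backends emit a final usage-only chunk with no choices.
    event = SimpleNamespace(choices=[], usage=SimpleNamespace(total_tokens=30))

    try:
        choice = event.choices[0]  # pre-patch behavior: unconditional indexing
    except IndexError:
        print("usage-only chunk: no choices to index")

    # The guard added above skips such events instead of crashing. It also
    # covers events that lack the attribute entirely:
    if not getattr(event, "choices", None):
        print("skipping event with empty or missing choices")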
From 2d0e441a06977b39e207d6fc2d0b2b5327117fa0 Mon Sep 17 00:00:00 2001
From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com>
Date: Mon, 9 Jun 2025 09:56:12 -0400
Subject: [PATCH 14/15] Remove codeowners (#181)

To avoid the code owners being added to every PR

Co-authored-by: Mackenzie Zastrow
---
 .github/CODEOWNERS | 5 -----
 1 file changed, 5 deletions(-)
 delete mode 100644 .github/CODEOWNERS

diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
deleted file mode 100644
index 7a4f8317..00000000
--- a/.github/CODEOWNERS
+++ /dev/null
@@ -1,5 +0,0 @@
-# These owners will be the default owners for everything in
-# the repo. Unless a later match takes precedence,
-# @strands-agents/contributors will be requested for
-# review when someone opens a pull request.
-* @strands-agents/maintainers
\ No newline at end of file

From 900610589d88fc8089e4e46d83415cf8f2c69876 Mon Sep 17 00:00:00 2001
From: Dean Schmigelski
Date: Mon, 9 Jun 2025 17:51:12 +0300
Subject: [PATCH 15/15] chore: enhance MCP error message for use outside context manager (#175)

---
 src/strands/tools/mcp/mcp_client.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py
index c6890945..a2298813 100644
--- a/src/strands/tools/mcp/mcp_client.py
+++ b/src/strands/tools/mcp/mcp_client.py
@@ -41,6 +41,12 @@
     "image/webp": "webp",
 }
 
+CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = (
+    "the client session is not running. Ensure the agent is used within "
+    "the MCP client context manager. For more information see: "
+    "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror"
+)
+
 
 class MCPClient:
     """Represents a connection to a Model Context Protocol (MCP) server.
@@ -145,7 +151,7 @@ def list_tools_sync(self) -> List[MCPAgentTool]:
         """
         self._log_debug_with_thread("listing MCP tools synchronously")
         if not self._is_session_active():
-            raise MCPClientInitializationError("the client session is not running")
+            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
 
         async def _list_tools_async() -> ListToolsResult:
             return await self._background_thread_session.list_tools()
@@ -180,7 +186,7 @@ def call_tool_sync(
         """
         self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id)
         if not self._is_session_active():
-            raise MCPClientInitializationError("the client session is not running")
+            raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE)
 
         async def _call_tool_async() -> MCPCallToolResult:
             return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds)
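
For anyone hitting the error message PATCH 15/15 improves, a usage sketch of the pattern it points to. The transport factory, server command, and prompt are illustrative assumptions, not part of this series:

    from mcp import StdioServerParameters, stdio_client

    from strands import Agent
    from strands.tools.mcp import MCPClient

    # Hypothetical stdio transport; any MCP transport factory works here.
    mcp_client = MCPClient(lambda: stdio_client(StdioServerParameters(command="my-mcp-server")))

    # Correct: the background session only runs inside the context manager,
    # so the agent must use the tools within the `with` block.
    with mcp_client:
        tools = mcp_client.list_tools_sync()
        agent = Agent(tools=tools)
        agent("What tools are available?")

    # Calling mcp_client.list_tools_sync() out here, after the block exits,
    # raises MCPClientInitializationError with the message added above.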