diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS deleted file mode 100644 index 7a4f8317..00000000 --- a/.github/CODEOWNERS +++ /dev/null @@ -1,5 +0,0 @@ -# These owners will be the default owners for everything in -# the repo. Unless a later match takes precedence, -# @strands-agents/contributors will be requested for -# review when someone opens a pull request. -* @strands-agents/maintainers \ No newline at end of file diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml index 38e88691..2b2d026f 100644 --- a/.github/workflows/pr-and-push.yml +++ b/.github/workflows/pr-and-push.yml @@ -13,5 +13,7 @@ concurrency: jobs: call-test-lint: uses: ./.github/workflows/test-lint.yml + permissions: + contents: read with: - ref: ${{ github.event.pull_request.head.sha }} \ No newline at end of file + ref: ${{ github.event.pull_request.head.sha }} diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 4047f596..8967c552 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -8,11 +8,15 @@ on: jobs: call-test-lint: uses: ./.github/workflows/test-lint.yml + permissions: + contents: read with: ref: ${{ github.event.release.target_commitish }} build: name: Build distribution 📦 + permissions: + contents: read needs: - call-test-lint runs-on: ubuntu-latest @@ -75,4 +79,4 @@ jobs: name: python-package-distributions path: dist/ - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 0f912b54..0651d452 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -44,6 +44,16 @@ logger = logging.getLogger(__name__) +# Sentinel class and object to distinguish between explicit None and default parameter value +class _DefaultCallbackHandlerSentinel: + """Sentinel class to distinguish between explicit None and default parameter value.""" + + pass + + +_DEFAULT_CALLBACK_HANDLER = _DefaultCallbackHandlerSentinel() + + class Agent: """Core Agent interface. @@ -70,7 +80,7 @@ def __init__(self, agent: "Agent") -> None: # agent tools and thus break their execution. self._agent = agent - def __getattr__(self, name: str) -> Callable: + def __getattr__(self, name: str) -> Callable[..., Any]: """Call tool as a function. This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`). @@ -165,7 +175,7 @@ def caller(**kwargs: Any) -> Any: self._agent._record_tool_execution(tool_use, tool_result, user_message_override, messages) # Apply window management - self._agent.conversation_manager.apply_management(self._agent.messages) + self._agent.conversation_manager.apply_management(self._agent) return tool_result @@ -177,7 +187,9 @@ def __init__( messages: Optional[Messages] = None, tools: Optional[List[Union[str, Dict[str, str], Any]]] = None, system_prompt: Optional[str] = None, - callback_handler: Optional[Callable] = PrintingCallbackHandler(), + callback_handler: Optional[ + Union[Callable[..., Any], _DefaultCallbackHandlerSentinel] + ] = _DEFAULT_CALLBACK_HANDLER, conversation_manager: Optional[ConversationManager] = None, max_parallel_tools: int = os.cpu_count() or 1, record_direct_tool_call: bool = True, @@ -204,7 +216,8 @@ def __init__( system_prompt: System prompt to guide model behavior. If None, the model will behave according to its default settings. 
callback_handler: Callback for processing events as they happen during agent execution. - Defaults to strands.handlers.PrintingCallbackHandler if None. + If not provided (using the default), a new PrintingCallbackHandler instance is created. + If explicitly set to None, null_callback_handler is used. conversation_manager: Manager for conversation history and context window. Defaults to strands.agent.conversation_manager.SlidingWindowConversationManager if None. max_parallel_tools: Maximum number of tools to run in parallel when the model returns multiple tool calls. @@ -222,7 +235,17 @@ def __init__( self.messages = messages if messages is not None else [] self.system_prompt = system_prompt - self.callback_handler = callback_handler or null_callback_handler + + # If not provided, create a new PrintingCallbackHandler instance + # If explicitly set to None, use null_callback_handler + # Otherwise use the passed callback_handler + self.callback_handler: Union[Callable[..., Any], PrintingCallbackHandler] + if isinstance(callback_handler, _DefaultCallbackHandlerSentinel): + self.callback_handler = PrintingCallbackHandler() + elif callback_handler is None: + self.callback_handler = null_callback_handler + else: + self.callback_handler = callback_handler self.conversation_manager = conversation_manager if conversation_manager else SlidingWindowConversationManager() @@ -415,7 +438,7 @@ def target_callback() -> None: thread.join() def _run_loop( - self, prompt: str, kwargs: Any, supplementary_callback_handler: Optional[Callable] = None + self, prompt: str, kwargs: Dict[str, Any], supplementary_callback_handler: Optional[Callable[..., Any]] = None ) -> AgentResult: """Execute the agent's event loop with the given prompt and parameters.""" try: @@ -439,9 +462,9 @@ def _run_loop( return self._execute_event_loop_cycle(invocation_callback_handler, kwargs) finally: - self.conversation_manager.apply_management(self.messages) + self.conversation_manager.apply_management(self) - def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str, Any]) -> AgentResult: + def _execute_event_loop_cycle(self, callback_handler: Callable[..., Any], kwargs: Dict[str, Any]) -> AgentResult: """Execute the event loop cycle with retry logic for context window limits. 
This internal method handles the execution of the event loop cycle and implements @@ -483,7 +506,7 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(messages, e=e) + self.conversation_manager.reduce_context(self, e=e) return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index d18ae69a..dbccf941 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,9 +1,10 @@ """Abstract interface for conversation history management.""" from abc import ABC, abstractmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional -from ...types.content import Messages +if TYPE_CHECKING: + from ...agent.agent import Agent class ConversationManager(ABC): @@ -19,22 +20,22 @@ class ConversationManager(ABC): @abstractmethod # pragma: no cover - def apply_management(self, messages: Messages) -> None: - """Applies management strategy to the provided list of messages. + def apply_management(self, agent: "Agent") -> None: + """Applies management strategy to the provided agent. Processes the conversation history to maintain appropriate size by modifying the messages list in-place. Implementations should handle message pruning, summarization, or other size management techniques to keep the conversation context within desired bounds. Args: - messages: The conversation history to manage. + agent: The agent whose conversation history will be managed. This list is modified in-place. """ pass @abstractmethod # pragma: no cover - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: """Called when the model's context window is exceeded. This method should implement the specific strategy for reducing the window size when a context overflow occurs. @@ -48,7 +49,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N - Maintaining critical conversation markers Args: - messages: The conversation history to reduce. + agent: The agent whose conversation history will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. 
""" diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 2066c08b..4af4eb78 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -1,8 +1,10 @@ """Null implementation of conversation management.""" -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ...agent.agent import Agent -from ...types.content import Messages from ...types.exceptions import ContextWindowOverflowException from .conversation_manager import ConversationManager @@ -17,19 +19,19 @@ class NullConversationManager(ConversationManager): - Situations where the full conversation history should be preserved """ - def apply_management(self, messages: Messages) -> None: + def apply_management(self, _agent: "Agent") -> None: """Does nothing to the conversation history. Args: - messages: The conversation history that will remain unmodified. + agent: The agent whose conversation history will remain unmodified. """ pass - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, _agent: "Agent", e: Optional[Exception] = None) -> None: """Does not reduce context and raises an exception. Args: - messages: The conversation history that will remain unmodified. + agent: The agent whose conversation history will remain unmodified. e: The exception that triggered the context reduction, if any. Raises: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index f367b272..3381247c 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,7 +1,10 @@ """Sliding window conversation history management.""" import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from ...agent.agent import Agent from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException @@ -45,13 +48,13 @@ def __init__(self, window_size: int = 40): """Initialize the sliding window conversation manager. Args: - window_size: Maximum number of messages to keep in history. + window_size: Maximum number of messages to keep in the agent's history. Defaults to 40 messages. """ self.window_size = window_size - def apply_management(self, messages: Messages) -> None: - """Apply the sliding window to the messages array to maintain a manageable history size. + def apply_management(self, agent: "Agent") -> None: + """Apply the sliding window to the agent's messages array to maintain a manageable history size. This method is called after every event loop cycle, as the messages array may have been modified with tool results and assistant responses. It first removes any dangling messages that might create an invalid @@ -62,9 +65,10 @@ def apply_management(self, messages: Messages) -> None: blocks to maintain conversation coherence. Args: - messages: The messages to manage. + agent: The agent whose messages will be managed. This list is modified in-place. 
""" + messages = agent.messages self._remove_dangling_messages(messages) if len(messages) <= self.window_size: @@ -72,7 +76,7 @@ def apply_management(self, messages: Messages) -> None: "window_size=<%s>, message_count=<%s> | skipping context reduction", len(messages), self.window_size ) return - self.reduce_context(messages) + self.reduce_context(agent) def _remove_dangling_messages(self, messages: Messages) -> None: """Remove dangling messages that would create an invalid conversation state. @@ -105,7 +109,7 @@ def _remove_dangling_messages(self, messages: Messages) -> None: if not any("toolResult" in content for content in messages[-1]["content"]): messages.pop() - def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: + def reduce_context(self, agent: "Agent", e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. The method handles special cases where trimming the messages leads to: @@ -113,7 +117,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N - toolUse with no corresponding toolResult Args: - messages: The messages to reduce. + agent: The agent whose messages will be reduced. This list is modified in-place. e: The exception that triggered the context reduction, if any. @@ -122,6 +126,7 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N Such as when the conversation is already minimal or when tool result messages cannot be properly converted. """ + messages = agent.messages # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 99e49f81..57394e2c 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -95,6 +95,9 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: Anthropic formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to an Anthropic-compatible format. """ if "document" in content: mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") @@ -143,7 +146,11 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "toolResult" in content: return { "content": [ - self._format_request_message_content(cast(ContentBlock, tool_result_content)) + self._format_request_message_content( + {"text": json.dumps(tool_result_content["json"])} + if "json" in tool_result_content + else cast(ContentBlock, tool_result_content) + ) for tool_result_content in content["toolResult"]["content"] ], "is_error": content["toolResult"]["status"] == "error", @@ -151,7 +158,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "tool_result", } - return {"text": json.dumps(content), "type": "text"} + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") def _format_request_messages(self, messages: Messages) -> list[dict[str, Any]]: """Format an Anthropic messages array. @@ -192,6 +199,10 @@ def format_request( Returns: An Anthropic streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an Anthropic-compatible + format. 
""" return { "max_tokens": self.config["max_tokens"], diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 23d2c2ae..62f16d31 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -67,8 +67,8 @@ def get_config(self) -> LiteLLMConfig: return cast(LiteLLMModel.LiteLLMConfig, self.config) @override - @staticmethod - def format_request_message_content(content: ContentBlock) -> dict[str, Any]: + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format a LiteLLM content block. Args: @@ -76,6 +76,9 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]: Returns: LiteLLM formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a LiteLLM-compatible format. """ if "reasoningContent" in content: return { @@ -93,4 +96,4 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]: }, } - return OpenAIModel.format_request_message_content(content) + return super().format_request_message_content(content) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index 307614db..583db2f2 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates """Llama API model provider. - Docs: https://llama.developer.meta.com/ @@ -7,7 +8,7 @@ import json import logging import mimetypes -from typing import Any, Iterable, Optional +from typing import Any, Iterable, Optional, cast import llama_api_client from llama_api_client import LlamaAPIClient @@ -92,6 +93,9 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: LllamaAPI formatted content block. + + Raises: + TypeError: If the content block type cannot be converted to a LlamaAPI-compatible format. """ if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") @@ -107,7 +111,7 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An if "text" in content: return {"text": content["text"], "type": "text"} - return {"text": json.dumps(content), "type": "text"} + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: """Format a Llama API tool call. @@ -135,18 +139,30 @@ def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any Returns: Llama API formatted tool message. """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), + "content": [self._format_request_message_content(content) for content in contents], } def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format a LlamaAPI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + A LlamaAPI compatible messages array. 
+ """ formatted_messages: list[dict[str, Any]] formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] @@ -196,6 +212,10 @@ def format_request( Returns: An Llama API chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to a LlamaAPI-compatible + format. """ request = { "messages": self._format_request_messages(messages, system_prompt), diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index f632da34..7ed12216 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -5,13 +5,12 @@ import json import logging -from typing import Any, Iterable, Optional, Union +from typing import Any, Iterable, Optional, cast from ollama import Client as OllamaClient from typing_extensions import TypedDict, Unpack, override -from ..types.content import ContentBlock, Message, Messages -from ..types.media import DocumentContent, ImageContent +from ..types.content import ContentBlock, Messages from ..types.models import Model from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolSpec @@ -92,31 +91,31 @@ def get_config(self) -> OllamaConfig: """ return self.config - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format an Ollama chat streaming request. + def _format_request_message_contents(self, role: str, content: ContentBlock) -> list[dict[str, Any]]: + """Format Ollama compatible message contents. + + Ollama doesn't support an array of contents, so we must flatten everything into separate message blocks. Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. + role: E.g., user. + content: Content block to format. Returns: - An Ollama chat streaming request. - """ + Ollama formatted message contents. - def format_message(message: Message, content: ContentBlock) -> dict[str, Any]: - if "text" in content: - return {"role": message["role"], "content": content["text"]} + Raises: + TypeError: If the content block type cannot be converted to an Ollama-compatible format. 
+ """ + if "text" in content: + return [{"role": role, "content": content["text"]}] - if "image" in content: - return {"role": message["role"], "images": [content["image"]["source"]["bytes"]]} + if "image" in content: + return [{"role": role, "images": [content["image"]["source"]["bytes"]]}] - if "toolUse" in content: - return { - "role": "assistant", + if "toolUse" in content: + return [ + { + "role": role, "tool_calls": [ { "function": { @@ -126,45 +125,63 @@ def format_message(message: Message, content: ContentBlock) -> dict[str, Any]: } ], } + ] + + if "toolResult" in content: + return [ + formatted_tool_result_content + for tool_result_content in content["toolResult"]["content"] + for formatted_tool_result_content in self._format_request_message_contents( + "tool", + ( + {"text": json.dumps(tool_result_content["json"])} + if "json" in tool_result_content + else cast(ContentBlock, tool_result_content) + ), + ) + ] - if "toolResult" in content: - result_content: Union[str, ImageContent, DocumentContent, Any] = None - result_images = [] - for tool_result_content in content["toolResult"]["content"]: - if "text" in tool_result_content: - result_content = tool_result_content["text"] - elif "json" in tool_result_content: - result_content = tool_result_content["json"] - elif "image" in tool_result_content: - result_content = "see images" - result_images.append(tool_result_content["image"]["source"]["bytes"]) - else: - result_content = content["toolResult"]["content"] + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - return { - "role": "tool", - "content": json.dumps( - { - "name": content["toolResult"]["toolUseId"], - "result": result_content, - "status": content["toolResult"]["status"], - } - ), - **({"images": result_images} if result_images else {}), - } + def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format an Ollama compatible messages array. - return {"role": message["role"], "content": json.dumps(content)} + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. - def format_messages() -> list[dict[str, Any]]: - return [format_message(message, content) for message in messages for content in message["content"]] + Returns: + An Ollama compatible messages array. + """ + system_message = [{"role": "system", "content": system_prompt}] if system_prompt else [] - formatted_messages = format_messages() + return system_message + [ + formatted_message + for message in messages + for content in message["content"] + for formatted_message in self._format_request_message_contents(message["role"], content) + ] + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an Ollama chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An Ollama chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an Ollama-compatible + format. 
+ """ return { - "messages": [ - *([{"role": "system", "content": system_prompt}] if system_prompt else []), - *formatted_messages, - ], + "messages": self._format_request_messages(messages, system_prompt), "model": self.config["model_id"], "options": { **(self.config.get("options") or {}), @@ -213,52 +230,54 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: RuntimeError: If chunk_type is not recognized. This error should never be encountered as we control chunk_type in the stream method. """ - if event["chunk_type"] == "message_start": - return {"messageStart": {"role": "assistant"}} - - if event["chunk_type"] == "content_start": - if event["data_type"] == "text": - return {"contentBlockStart": {"start": {}}} - - tool_name = event["data"].function.name - return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} - - if event["chunk_type"] == "content_delta": - if event["data_type"] == "text": - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - tool_arguments = event["data"].function.arguments - return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}} - - if event["chunk_type"] == "content_stop": - return {"contentBlockStop": {}} - - if event["chunk_type"] == "message_stop": - reason: StopReason - if event["data"] == "tool_use": - reason = "tool_use" - elif event["data"] == "length": - reason = "max_tokens" - else: - reason = "end_turn" - - return {"messageStop": {"stopReason": reason}} - - if event["chunk_type"] == "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].eval_count, - "outputTokens": event["data"].prompt_eval_count, - "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, - }, - "metrics": { - "latencyMs": event["data"].total_duration / 1e6, + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "text": + return {"contentBlockStart": {"start": {}}} + + tool_name = event["data"].function.name + return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}} + + case "content_delta": + if event["data_type"] == "text": + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + tool_arguments = event["data"].function.arguments + return {"contentBlockDelta": {"delta": {"toolUse": {"input": json.dumps(tool_arguments)}}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + reason: StopReason + if event["data"] == "tool_use": + reason = "tool_use" + elif event["data"] == "length": + reason = "max_tokens" + else: + reason = "end_turn" + + return {"messageStop": {"stopReason": reason}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].eval_count, + "outputTokens": event["data"].prompt_eval_count, + "totalTokens": event["data"].eval_count + event["data"].prompt_eval_count, + }, + "metrics": { + "latencyMs": event["data"].total_duration / 1e6, + }, }, - }, - } + } - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") @override def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 764cb851..6cbef664 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -94,6 +94,9 @@ def stream(self, request: 
dict[str, Any]) -> Iterable[dict[str, Any]]: tool_calls: dict[int, list[Any]] = {} for event in response: + # Defensive: skip events with empty or missing choices + if not getattr(event, "choices", None): + continue choice = event.choices[0] if choice.delta.content: diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 34eb7bed..9f731996 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -13,7 +13,9 @@ from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.sdk.resources import Resource + +# See https://github.com/open-telemetry/opentelemetry-python/issues/4615 for the type ignore +from opentelemetry.sdk.resources import Resource # type: ignore[attr-defined] from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SimpleSpanProcessor from opentelemetry.trace import StatusCode diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index c6890945..a2298813 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -41,6 +41,12 @@ "image/webp": "webp", } +CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE = ( + "the client session is not running. Ensure the agent is used within " + "the MCP client context manager. For more information see: " + "https://strandsagents.com/latest/user-guide/concepts/tools/mcp-tools/#mcpclientinitializationerror" +) + class MCPClient: """Represents a connection to a Model Context Protocol (MCP) server. @@ -145,7 +151,7 @@ def list_tools_sync(self) -> List[MCPAgentTool]: """ self._log_debug_with_thread("listing MCP tools synchronously") if not self._is_session_active(): - raise MCPClientInitializationError("the client session is not running") + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _list_tools_async() -> ListToolsResult: return await self._background_thread_session.list_tools() @@ -180,7 +186,7 @@ def call_tool_sync( """ self._log_debug_with_thread("calling MCP tool '%s' synchronously with tool_use_id=%s", name, tool_use_id) if not self._is_session_active(): - raise MCPClientInitializationError("the client session is not running") + raise MCPClientInitializationError(CLIENT_SESSION_NOT_RUNNING_ERROR_MESSAGE) async def _call_tool_async() -> MCPCallToolResult: return await self._background_thread_session.call_tool(name, arguments, read_timeout_seconds) diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index b595c3d6..a449c74e 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -63,50 +63,54 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) +def _normalize_property(prop_name: str, prop_def: Any) -> Dict[str, Any]: + """Normalize a single property definition. + + Args: + prop_name: The name of the property. + prop_def: The property definition to normalize. + + Returns: + The normalized property definition. 
+ """ + if not isinstance(prop_def, dict): + return {"type": "string", "description": f"Property {prop_name}"} + + if prop_def.get("type") == "object" and "properties" in prop_def: + return normalize_schema(prop_def) # Recursive call + + # Copy existing property, ensuring defaults + normalized_prop = prop_def.copy() + normalized_prop.setdefault("type", "string") + normalized_prop.setdefault("description", f"Property {prop_name}") + return normalized_prop + + def normalize_schema(schema: Dict[str, Any]) -> Dict[str, Any]: """Normalize a JSON schema to match expectations. + This function recursively processes nested objects to preserve the complete schema structure. + Uses a copy-then-normalize approach to preserve all original schema properties. + Args: schema: The schema to normalize. Returns: The normalized schema. """ - normalized = {"type": schema.get("type", "object"), "properties": {}} - - # Handle properties - if "properties" in schema: - for prop_name, prop_def in schema["properties"].items(): - if isinstance(prop_def, dict): - normalized_prop = { - "type": prop_def.get("type", "string"), - "description": prop_def.get("description", f"Property {prop_name}"), - } - - # Handle enum values correctly - if "enum" in prop_def: - normalized_prop["enum"] = prop_def["enum"] - - # Handle numeric constraints - if prop_def.get("type") in ["number", "integer"]: - if "minimum" in prop_def: - normalized_prop["minimum"] = prop_def["minimum"] - if "maximum" in prop_def: - normalized_prop["maximum"] = prop_def["maximum"] - - normalized["properties"][prop_name] = normalized_prop - else: - # Handle non-dict property definitions (like simple strings) - normalized["properties"][prop_name] = { - "type": "string", - "description": f"Property {prop_name}", - } - - # Required fields - if "required" in schema: - normalized["required"] = schema["required"] - else: - normalized["required"] = [] + # Start with a complete copy to preserve all existing properties + normalized = schema.copy() + + # Ensure essential structure exists + normalized.setdefault("type", "object") + normalized.setdefault("properties", {}) + normalized.setdefault("required", []) + + # Process properties recursively + if "properties" in normalized: + properties = normalized["properties"] + for prop_name, prop_def in properties.items(): + normalized["properties"][prop_name] = _normalize_property(prop_name, prop_def) return normalized diff --git a/src/strands/types/content.py b/src/strands/types/content.py index c64b6773..790e9094 100644 --- a/src/strands/types/content.py +++ b/src/strands/types/content.py @@ -60,10 +60,21 @@ class ReasoningContentBlock(TypedDict, total=False): redactedContent: bytes +class CachePoint(TypedDict): + """A cache point configuration for optimizing conversation history. + + Attributes: + type: The type of cache point, typically "default". + """ + + type: str + + class ContentBlock(TypedDict, total=False): """A block of content for a message that you pass to, or receive from, a model. Attributes: + cachePoint: A cache point configuration to optimize conversation history. document: A document to include in the message. guardContent: Contains the content to assess with the guardrail. image: Image to include in the message. @@ -74,6 +85,7 @@ class ContentBlock(TypedDict, total=False): video: Video to include in the message. 
""" + cachePoint: CachePoint document: DocumentContent guardContent: GuardContent image: ImageContent diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 307c0be6..96f758d5 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -11,7 +11,7 @@ import json import logging import mimetypes -from typing import Any, Optional +from typing import Any, Optional, cast from typing_extensions import override @@ -31,8 +31,8 @@ class OpenAIModel(Model, abc.ABC): config: dict[str, Any] - @staticmethod - def format_request_message_content(content: ContentBlock) -> dict[str, Any]: + @classmethod + def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: """Format an OpenAI compatible content block. Args: @@ -40,6 +40,9 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]: Returns: OpenAI compatible content block. + + Raises: + TypeError: If the content block type cannot be converted to an OpenAI-compatible format. """ if "document" in content: mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") @@ -67,10 +70,10 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]: if "text" in content: return {"text": content["text"], "type": "text"} - return {"text": json.dumps(content), "type": "text"} + raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") - @staticmethod - def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: + @classmethod + def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: """Format an OpenAI compatible tool call. Args: @@ -88,8 +91,8 @@ def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: "type": "function", } - @staticmethod - def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: + @classmethod + def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: """Format an OpenAI compatible tool message. Args: @@ -98,15 +101,18 @@ def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: Returns: OpenAI compatible tool message. """ + contents = cast( + list[ContentBlock], + [ + {"text": json.dumps(content["json"])} if "json" in content else content + for content in tool_result["content"] + ], + ) + return { "role": "tool", "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), + "content": [cls.format_request_message_content(content) for content in contents], } @classmethod @@ -163,6 +169,10 @@ def format_request( Returns: An OpenAI compatible chat streaming request. + + Raises: + TypeError: If a message contains a content block type that cannot be converted to an OpenAI-compatible + format. """ return { "messages": self.format_request_messages(messages, system_prompt), diff --git a/tests-integ/test_model_llamaapi.py b/tests-integ/test_model_llamaapi.py index 5cddc1a0..dad6919e 100644 --- a/tests-integ/test_model_llamaapi.py +++ b/tests-integ/test_model_llamaapi.py @@ -1,3 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates import os import pytest diff --git a/tests/conftest.py b/tests/conftest.py index 4f0b5b21..cd18b698 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,8 @@ def moto_env(monkeypatch): monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "test") monkeypatch.setenv("AWS_SECURITY_TOKEN", "test") monkeypatch.setenv("AWS_DEFAULT_REGION", "us-west-2") + monkeypatch.delenv("OTEL_EXPORTER_OTLP_ENDPOINT", raising=False) + monkeypatch.delenv("OTEL_EXPORTER_OTLP_HEADERS", raising=False) @pytest.fixture diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ea06fb4e..0ea20b64 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -318,7 +318,7 @@ def test_agent__call__( ) callback_handler.assert_called() - conversation_manager_spy.apply_management.assert_called_with(agent.messages) + conversation_manager_spy.apply_management.assert_called_with(agent) def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, agent, tool, mock_event_loop_cycle): @@ -583,7 +583,7 @@ def test_agent_tool(mock_randint, agent): } assert tru_result == exp_result - conversation_manager_spy.apply_management.assert_called_with(agent.messages) + conversation_manager_spy.apply_management.assert_called_with(agent) def test_agent_tool_user_message_override(agent): @@ -686,6 +686,37 @@ def test_agent_with_callback_handler_none_uses_null_handler(): assert agent.callback_handler == null_callback_handler +def test_agent_callback_handler_not_provided_creates_new_instances(): + """Test that when callback_handler is not provided, new PrintingCallbackHandler instances are created.""" + # Create two agents without providing callback_handler + agent1 = Agent() + agent2 = Agent() + + # Both should have PrintingCallbackHandler instances + assert isinstance(agent1.callback_handler, PrintingCallbackHandler) + assert isinstance(agent2.callback_handler, PrintingCallbackHandler) + + # But they should be different object instances + assert agent1.callback_handler is not agent2.callback_handler + + +def test_agent_callback_handler_explicit_none_uses_null_handler(): + """Test that when callback_handler is explicitly set to None, null_callback_handler is used.""" + agent = Agent(callback_handler=None) + + # Should use null_callback_handler + assert agent.callback_handler is null_callback_handler + + +def test_agent_callback_handler_custom_handler_used(): + """Test that when a custom callback_handler is provided, it is used.""" + custom_handler = unittest.mock.Mock() + agent = Agent(callback_handler=custom_handler) + + # Should use the provided custom handler + assert agent.callback_handler is custom_handler + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): agent = Agent() diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index b6132f1d..bbec3cd1 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -1,6 +1,7 @@ import pytest import strands +from strands.agent.agent import Agent from strands.types.exceptions import ContextWindowOverflowException @@ -160,7 +161,8 @@ def conversation_manager(request): indirect=["conversation_manager"], ) def test_apply_management(conversation_manager, messages, expected_messages): - conversation_manager.apply_management(messages) + test_agent = Agent(messages=messages) + conversation_manager.apply_management(test_agent) assert messages == 
expected_messages @@ -172,9 +174,10 @@ def test_sliding_window_conversation_manager_with_untrimmable_history_raises_con {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) with pytest.raises(ContextWindowOverflowException): - manager.apply_management(messages) + manager.apply_management(test_agent) assert messages == original_messages @@ -187,8 +190,9 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow {"role": "assistant", "content": [{"text": "Hi there"}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) - manager.apply_management(messages) + manager.apply_management(test_agent) with pytest.raises(ContextWindowOverflowException): manager.reduce_context(messages) @@ -204,8 +208,9 @@ def test_null_conversation_manager_reduce_context_with_exception_raises_same_exc {"role": "assistant", "content": [{"text": "Hi there"}]}, ] original_messages = messages.copy() + test_agent = Agent(messages=messages) - manager.apply_management(messages) + manager.apply_management(test_agent) with pytest.raises(RuntimeError): manager.reduce_context(messages, RuntimeError("test")) diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 2ee344cc..9421650e 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -1,4 +1,3 @@ -import json import unittest.mock import anthropic @@ -290,6 +289,7 @@ def test_format_request_with_tool_results(model, model_id, max_tokens): "status": "success", "content": [ {"text": "see image"}, + {"json": ["see image"]}, { "image": { "format": "jpg", @@ -316,6 +316,10 @@ def test_format_request_with_tool_results(model, model_id, max_tokens): "text": "see image", "type": "text", }, + { + "text": '["see image"]', + "type": "text", + }, { "source": { "data": "YmFzZTY0ZW5jb2RlZGltYWdl", @@ -339,33 +343,16 @@ def test_format_request_with_tool_results(model, model_id, max_tokens): assert tru_request == exp_request -def test_format_request_with_other(model, model_id, max_tokens): +def test_format_request_with_unsupported_type(model): messages = [ { "role": "user", - "content": [{"other": {"a": 1}}], + "content": [{"unsupported": {}}], }, ] - tru_request = model.format_request(messages) - exp_request = { - "max_tokens": max_tokens, - "messages": [ - { - "role": "user", - "content": [ - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ], - }, - ], - "model": model_id, - "tools": [], - } - - assert tru_request == exp_request + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model.format_request(messages) def test_format_request_with_cache_point(model, model_id, max_tokens): diff --git a/tests/strands/models/test_llamaapi.py b/tests/strands/models/test_llamaapi.py index 5d7cd0f4..309dac2e 100644 --- a/tests/strands/models/test_llamaapi.py +++ b/tests/strands/models/test_llamaapi.py @@ -1,4 +1,4 @@ -import json +# Copyright (c) Meta Platforms, Inc. 
and affiliates import unittest.mock import pytest @@ -144,7 +144,7 @@ def test_format_request_with_tool_result(model, model_id): "toolResult": { "toolUseId": "c1", "status": "success", - "content": [{"value": 4}], + "content": [{"text": "4"}, {"json": ["4"]}], } } ], @@ -155,12 +155,7 @@ def test_format_request_with_tool_result(model, model_id): exp_request = { "messages": [ { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), + "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], "role": "tool", "tool_call_id": "c1", }, @@ -233,6 +228,18 @@ def test_format_request_with_empty_content(model, model_id): assert tru_request == exp_request +def test_format_request_with_unsupported_type(model): + messages = [ + { + "role": "user", + "content": [{"unsupported": {}}], + }, + ] + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model.format_request(messages) + + def test_format_chunk_message_start(model): event = {"chunk_type": "message_start"} diff --git a/tests/strands/models/test_ollama.py b/tests/strands/models/test_ollama.py index d3bd8f7f..fe590dff 100644 --- a/tests/strands/models/test_ollama.py +++ b/tests/strands/models/test_ollama.py @@ -148,81 +148,24 @@ def test_format_request_with_tool_use(model, model_id): def test_format_request_with_tool_result(model, model_id): messages: Messages = [ { - "role": "tool", - "content": [{"toolResult": {"toolUseId": "calculator", "status": "success", "content": [{"text": "4"}]}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "tool", - "content": json.dumps( - { - "name": "calculator", - "result": "4", - "status": "success", - } - ), - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result_json(model, model_id): - messages: Messages = [ - { - "role": "tool", - "content": [ - {"toolResult": {"toolUseId": "calculator", "status": "success", "content": [{"json": {"result": 4}}]}} - ], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "tool", - "content": json.dumps( - { - "name": "calculator", - "result": {"result": 4}, - "status": "success", - } - ), - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result_image(model, model_id): - messages: Messages = [ - { - "role": "tool", + "role": "user", "content": [ { "toolResult": { - "toolUseId": "image_generator", + "toolUseId": "calculator", "status": "success", - "content": [{"image": {"source": {"bytes": "base64encodedimage"}}}], - } - } + "content": [ + {"text": "4"}, + {"image": {"source": {"bytes": b"image"}}}, + {"json": ["4"]}, + ], + }, + }, + { + "text": "see results", + }, ], - } + }, ] tru_request = model.format_request(messages) @@ -230,46 +173,20 @@ def test_format_request_with_tool_result_image(model, model_id): "messages": [ { "role": "tool", - "content": json.dumps( - { - "name": "image_generator", - "result": "see images", - "status": "success", - } - ), - "images": ["base64encodedimage"], - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result_other(model, model_id): - messages = [ - { - "role": "tool", - "content": [{"toolResult": {"toolUseId": "other", 
"status": "success", "content": {"other": {"a": 1}}}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ + "content": "4", + }, { "role": "tool", - "content": json.dumps( - { - "name": "other", - "result": {"other": {"a": 1}}, - "status": "success", - } - ), - } + "images": [b"image"], + }, + { + "role": "tool", + "content": '["4"]', + }, + { + "role": "user", + "content": "see results", + }, ], "model": model_id, "options": {}, @@ -280,29 +197,16 @@ def test_format_request_with_tool_result_other(model, model_id): assert tru_request == exp_request -def test_format_request_with_other(model, model_id): +def test_format_request_with_unsupported_type(model): messages = [ { "role": "user", - "content": [{"other": {"a": 1}}], - } + "content": [{"unsupported": {}}], + }, ] - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": json.dumps({"other": {"a": 1}}), - } - ], - "model": model_id, - "options": {}, - "stream": True, - "tools": [], - } - - assert tru_request == exp_request + with pytest.raises(TypeError, match="content_type= | unsupported type"): + model.format_request(messages) def test_format_request_with_tool_specs(model, messages, model_id): diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index 89aa591f..4c1f8528 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -132,3 +132,44 @@ def test_stream_empty(openai_client, model): assert tru_events == exp_events openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_with_empty_choices(openai_client, model): + mock_delta = unittest.mock.Mock(content="content", tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=10, completion_tokens=20, total_tokens=30) + + # Event with no choices attribute + mock_event_1 = unittest.mock.Mock(spec=[]) + + # Event with empty choices list + mock_event_2 = unittest.mock.Mock(choices=[]) + + # Valid event with content + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + + # Event with finish reason + mock_event_4 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + + # Final event with usage info + mock_event_5 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter( + [mock_event_1, mock_event_2, mock_event_3, mock_event_4, mock_event_5] + ) + + request = {"model": "m1", "messages": [{"role": "user", "content": ["test"]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "content"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) diff --git a/tests/strands/tools/test_executor.py b/tests/strands/tools/test_executor.py index ced2bd7f..a6ea45c3 100644 --- a/tests/strands/tools/test_executor.py +++ b/tests/strands/tools/test_executor.py @@ -1,6 +1,7 @@ import concurrent import functools import unittest.mock +import uuid import pytest @@ -9,6 +10,11 @@ from 
strands.types.content import Message +@pytest.fixture(autouse=True) +def moto_autouse(moto_env): + _ = moto_env + + @pytest.fixture def tool_handler(request): def handler(tool_use): @@ -52,10 +58,10 @@ def invalid_tool_use_ids(request): return request.param if hasattr(request, "param") else [] -@unittest.mock.patch.object(strands.telemetry.metrics, "uuid4", return_value="trace1") @pytest.fixture def cycle_trace(): - return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") + with unittest.mock.patch.object(uuid, "uuid4", return_value="trace1"): + return strands.telemetry.metrics.Trace(name="test trace", raw_name="raw_name") @pytest.fixture diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index f24cc22d..1b65156b 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -50,11 +50,10 @@ def test_validate_tool_use(): def test_normalize_schema_basic(): schema = {"type": "object"} normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert normalized["properties"] == {} - assert "required" in normalized - assert normalized["required"] == [] + + expected = {"type": "object", "properties": {}, "required": []} + + assert normalized == expected def test_normalize_schema_with_properties(): @@ -66,14 +65,17 @@ def test_normalize_schema_with_properties(): }, } normalized = normalize_schema(schema) - assert normalized["type"] == "object" - assert "properties" in normalized - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "User name" - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["description"] == "User age" + + expected = { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_removed(): @@ -82,27 +84,40 @@ def test_normalize_schema_with_property_removed(): "properties": {"name": "invalid"}, } normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_defaults(): schema = {"properties": {"name": {}}} normalized = normalize_schema(schema) - assert "name" in normalized["properties"] - assert normalized["properties"]["name"]["type"] == "string" - assert normalized["properties"]["name"]["description"] == "Property name" + + expected = { + "type": "object", + "properties": {"name": {"type": "string", "description": "Property name"}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_enum(): schema = {"properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}} normalized = normalize_schema(schema) - assert "color" in normalized["properties"] - assert normalized["properties"]["color"]["type"] == "string" - assert normalized["properties"]["color"]["description"] == "color" - assert "enum" in 
normalized["properties"]["color"] - assert normalized["properties"]["color"]["enum"] == ["red", "green", "blue"] + + expected = { + "type": "object", + "properties": {"color": {"type": "string", "description": "color", "enum": ["red", "green", "blue"]}}, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_property_numeric_constraints(): @@ -113,21 +128,170 @@ def test_normalize_schema_with_property_numeric_constraints(): } } normalized = normalize_schema(schema) - assert "age" in normalized["properties"] - assert normalized["properties"]["age"]["type"] == "integer" - assert normalized["properties"]["age"]["minimum"] == 0 - assert normalized["properties"]["age"]["maximum"] == 120 - assert "score" in normalized["properties"] - assert normalized["properties"]["score"]["type"] == "number" - assert normalized["properties"]["score"]["minimum"] == 0.0 - assert normalized["properties"]["score"]["maximum"] == 100.0 + + expected = { + "type": "object", + "properties": { + "age": {"type": "integer", "description": "age", "minimum": 0, "maximum": 120}, + "score": {"type": "number", "description": "score", "minimum": 0.0, "maximum": 100.0}, + }, + "required": [], + } + + assert normalized == expected def test_normalize_schema_with_required(): schema = {"type": "object", "required": ["name", "email"]} normalized = normalize_schema(schema) - assert "required" in normalized - assert normalized["required"] == ["name", "email"] + + expected = {"type": "object", "properties": {}, "required": ["name", "email"]} + + assert normalized == expected + + +def test_normalize_schema_with_nested_object(): + """Test normalization of schemas with nested objects.""" + schema = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "User name"}, + "age": {"type": "integer", "description": "User age"}, + }, + "required": ["name"], + } + }, + "required": ["user"], + } + + assert normalized == expected + + +def test_normalize_schema_with_deeply_nested_objects(): + """Test normalization of deeply nested object structures.""" + schema = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": {"type": "object", "properties": {"value": {"type": "string", "const": "fixed"}}} + }, + } + }, + } + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "level1": { + "type": "object", + "properties": { + "level2": { + "type": "object", + "properties": { + "level3": { + "type": "object", + "properties": { + "value": {"type": "string", "description": "Property value", "const": "fixed"} + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_const_constraint(): + """Test that const constraints are preserved.""" + schema = { + "type": "object", + "properties": { + "status": {"type": "string", "const": "ACTIVE"}, + "config": {"type": "object", "properties": {"mode": {"type": "string", "const": "PRODUCTION"}}}, + }, + } + + normalized = 
normalize_schema(schema) + + expected = { + "type": "object", + "properties": { + "status": {"type": "string", "description": "Property status", "const": "ACTIVE"}, + "config": { + "type": "object", + "properties": {"mode": {"type": "string", "description": "Property mode", "const": "PRODUCTION"}}, + "required": [], + }, + }, + "required": [], + } + + assert normalized == expected + + +def test_normalize_schema_with_additional_properties(): + """Test that additionalProperties constraint is preserved.""" + schema = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": {"type": "object", "properties": {"id": {"type": "string"}}, "additionalProperties": False} + }, + } + + normalized = normalize_schema(schema) + + expected = { + "type": "object", + "additionalProperties": False, + "properties": { + "data": { + "type": "object", + "additionalProperties": False, + "properties": {"id": {"type": "string", "description": "Property id"}}, + "required": [], + } + }, + "required": [], + } + + assert normalized == expected def test_normalize_tool_spec_with_json_schema(): @@ -137,14 +301,20 @@ def test_normalize_tool_spec_with_json_schema(): "inputSchema": {"json": {"type": "object", "properties": {"query": {}}, "required": ["query"]}}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["properties"]["query"]["type"] == "string" - assert normalized["inputSchema"]["json"]["properties"]["query"]["description"] == "Property query" + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_with_direct_schema(): @@ -154,22 +324,29 @@ def test_normalize_tool_spec_with_direct_schema(): "inputSchema": {"type": "object", "properties": {"query": {}}, "required": ["query"]}, } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - assert "inputSchema" in normalized - assert "json" in normalized["inputSchema"] - assert normalized["inputSchema"]["json"]["type"] == "object" - assert "query" in normalized["inputSchema"]["json"]["properties"] - assert normalized["inputSchema"]["json"]["required"] == ["query"] + + expected = { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": {"query": {"type": "string", "description": "Property query"}}, + "required": ["query"], + } + }, + } + + assert normalized == expected def test_normalize_tool_spec_without_input_schema(): tool_spec = {"name": "test_tool", "description": "A test tool"} normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not present - assert "inputSchema" not in normalized + + expected = {"name": "test_tool", "description": "A test tool"} + + assert normalized == expected def test_normalize_tool_spec_empty_input_schema(): @@ -179,10 +356,10 @@ def 
test_normalize_tool_spec_empty_input_schema(): "inputSchema": "", } normalized = normalize_tool_spec(tool_spec) - assert normalized["name"] == "test_tool" - assert normalized["description"] == "A test tool" - # Should not modify the spec if inputSchema is not a dict - assert normalized["inputSchema"] == "" + + expected = {"name": "test_tool", "description": "A test tool", "inputSchema": ""} + + assert normalized == expected def test_validate_tool_use_with_valid_input(): diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index c6a05291..9db08bc9 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -1,4 +1,3 @@ -import json import unittest.mock import pytest @@ -101,14 +100,6 @@ def system_prompt(): {"text": "hello"}, {"type": "text", "text": "hello"}, ), - # Other - ( - {"other": {"a": 1}}, - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ), ], ) def test_format_request_message_content(content, exp_result): @@ -116,6 +107,13 @@ def test_format_request_message_content(content, exp_result): assert tru_result == exp_result +def test_format_request_message_content_unsupported_type(): + content = {"unsupported": {}} + + with pytest.raises(TypeError, match="content_type= | unsupported type"): + SAOpenAIModel.format_request_message_content(content) + + def test_format_request_message_tool_call(): tool_use = { "input": {"expression": "2+2"}, @@ -137,19 +135,14 @@ def test_format_request_message_tool_call(): def test_format_request_tool_message(): tool_result = { - "content": [{"value": 4}], + "content": [{"text": "4"}, {"json": ["4"]}], "status": "success", "toolUseId": "c1", } tru_result = SAOpenAIModel.format_request_tool_message(tool_result) exp_result = { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), + "content": [{"text": "4", "type": "text"}, {"text": '["4"]', "type": "text"}], "role": "tool", "tool_call_id": "c1", } @@ -180,7 +173,7 @@ def test_format_request_messages(system_prompt): "role": "assistant", }, { - "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"text": "4"}]}}], "role": "user", }, ] @@ -210,12 +203,7 @@ def test_format_request_messages(system_prompt): ], }, { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), + "content": [{"text": "4", "type": "text"}], "role": "tool", "tool_call_id": "c1", },