diff --git a/.gitignore b/.gitignore
index a80f4bd1..a5cf11c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ __pycache__*
 .pytest_cache
 .ruff_cache
 *.bak
+.vscode
\ No newline at end of file
diff --git a/README.md b/README.md
index 262dde51..f4c483a2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,14 @@
-# Strands Agents
-
+[Centered HTML header added here: a linked "Strands Agents" logo image and a "Strands Agents" heading — the HTML markup itself was lost in extraction.]
A model-driven approach to building AI agents in just a few lines of code.

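As a concrete anchor for the tagline above, here is a minimal usage sketch. It assumes only the public `strands.Agent` entry point (the same import the updated tests in this diff switch to); the default provider behavior and the prompt text are illustrative, not taken from the README hunk itself.

```python
# Minimal sketch of "an agent in a few lines of code".
# `strands.Agent` is the public entry point; defaults are assumptions.
from strands import Agent

agent = Agent()                             # default model provider, no extra tools
result = agent("Tell me about agentic AI")  # __call__ runs one full agent loop
```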
@@ -8,16 +16,19 @@
 [Badge row: GitHub commit activity, GitHub open issues, License, PyPI version, Python versions — badge markup lost in extraction]
+[Badge added: GitHub open pull requests — badge markup lost in extraction]

-[Link row: Docs ◆ Samples — link markup lost in extraction]
+[Link row: Documentation ◆ Samples ◆ Python SDK ◆ Tools ◆ Agent Builder ◆ MCP Server — link markup lost in extraction]

@@ -26,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t
 ## Feature Overview

 - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable
-- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Llama, Ollama, and custom providers
+- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers
 - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support
 - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools
@@ -138,6 +149,7 @@ Built-in providers:
   - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
   - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
   - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
+  - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/)

 Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/)
@@ -165,9 +177,9 @@ For detailed guidance & examples, explore our documentation:
 - [API Reference](https://strandsagents.com/latest/api-reference/agent/)
 - [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/)

-## Contributing
+## Contributing ❤️

-We welcome contributions! See our [Contributing Guide](https://github.com/strands-agents/sdk-python/blob/main/CONTRIBUTING.md) for details on:
+We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on:
 - Reporting bugs & features
 - Development setup
 - Contributing via Pull Requests
@@ -178,6 +190,10 @@ We welcome contributions! See our [Contributing Guide](https://github.com/strand

 This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.

+## Security
+
+See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
+
 ## ⚠️ Preview Status

 Strands Agents is currently in public preview.
During this period:
diff --git a/pyproject.toml b/pyproject.toml
index 52394cea..a3b36cab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"

 [project]
 name = "strands-agents"
-version = "0.1.4"
+version = "0.1.5"
 description = "A model-driven approach to building AI agents in just a few lines of code"
 readme = "README.md"
 requires-python = ">=3.10"
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 6a948980..0f912b54 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
                 - metrics: Performance metrics from the event loop
                 - state: The final state of the event loop
         """
-        model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
-
-        self.trace_span = self.tracer.start_agent_span(
-            prompt=prompt,
-            model_id=model_id,
-            tools=self.tool_names,
-            system_prompt=self.system_prompt,
-            custom_trace_attributes=self.trace_attributes,
-        )
+        self._start_agent_trace_span(prompt)

         try:
             # Run the event loop and get the result
             result = self._run_loop(prompt, kwargs)

-            if self.trace_span:
-                self.tracer.end_agent_span(span=self.trace_span, response=result)
+            self._end_agent_trace_span(response=result)

             return result
         except Exception as e:
-            if self.trace_span:
-                self.tracer.end_agent_span(span=self.trace_span, error=e)
+            self._end_agent_trace_span(error=e)

             # Re-raise the exception to preserve original behavior
             raise
@@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
                 yield event["data"]
             ```
         """
+        self._start_agent_trace_span(prompt)
+
         _stop_event = uuid4()

         queue = asyncio.Queue[Any]()
@@ -400,8 +392,10 @@ def target_callback() -> None:
             nonlocal kwargs

             try:
-                self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
-            except BaseException as e:
+                result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
+                self._end_agent_trace_span(response=result)
+            except Exception as e:
+                self._end_agent_trace_span(error=e)
                 enqueue(e)
             finally:
                 enqueue(_stop_event)
@@ -414,7 +408,7 @@ def target_callback() -> None:
             item = await queue.get()
             if item == _stop_event:
                 break
-            if isinstance(item, BaseException):
+            if isinstance(item, Exception):
                 raise item
             yield item
         finally:
@@ -457,27 +451,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str

         Returns:
             The result of the event loop cycle.
""" - kwargs.pop("agent", None) - kwargs.pop("model", None) - kwargs.pop("system_prompt", None) - kwargs.pop("tool_execution_handler", None) - kwargs.pop("event_loop_metrics", None) - kwargs.pop("callback_handler", None) - kwargs.pop("tool_handler", None) - kwargs.pop("messages", None) - kwargs.pop("tool_config", None) + # Extract parameters with fallbacks to instance values + system_prompt = kwargs.pop("system_prompt", self.system_prompt) + model = kwargs.pop("model", self.model) + tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper) + event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics) + callback_handler_override = kwargs.pop("callback_handler", callback_handler) + tool_handler = kwargs.pop("tool_handler", self.tool_handler) + messages = kwargs.pop("messages", self.messages) + tool_config = kwargs.pop("tool_config", self.tool_config) + kwargs.pop("agent", None) # Remove agent to avoid conflicts try: # Execute the main event loop cycle stop_reason, message, metrics, state = event_loop_cycle( - model=self.model, - system_prompt=self.system_prompt, - messages=self.messages, # will be modified by event_loop_cycle - tool_config=self.tool_config, - callback_handler=callback_handler, - tool_handler=self.tool_handler, - tool_execution_handler=self.thread_pool_wrapper, - event_loop_metrics=self.event_loop_metrics, + model=model, + system_prompt=system_prompt, + messages=messages, # will be modified by event_loop_cycle + tool_config=tool_config, + callback_handler=callback_handler_override, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + event_loop_metrics=event_loop_metrics, agent=self, event_loop_parent_span=self.trace_span, **kwargs, @@ -488,8 +483,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self.messages, e=e) - return self._execute_event_loop_cycle(callback_handler, kwargs) + self.conversation_manager.reduce_context(messages, e=e) + return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( self, @@ -545,3 +540,43 @@ def _record_tool_execution( messages.append(tool_use_msg) messages.append(tool_result_msg) messages.append(assistant_msg) + + def _start_agent_trace_span(self, prompt: str) -> None: + """Starts a trace span for the agent. + + Args: + prompt: The natural language prompt from the user. + """ + model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None + + self.trace_span = self.tracer.start_agent_span( + prompt=prompt, + model_id=model_id, + tools=self.tool_names, + system_prompt=self.system_prompt, + custom_trace_attributes=self.trace_attributes, + ) + + def _end_agent_trace_span( + self, + response: Optional[AgentResult] = None, + error: Optional[Exception] = None, + ) -> None: + """Ends a trace span for the agent. + + Args: + span: The span to end. + response: Response to record as a trace attribute. + error: Error to record as a trace attribute. 
+ """ + if self.trace_span: + trace_attributes: Dict[str, Any] = { + "span": self.trace_span, + } + + if response: + trace_attributes["response"] = response + if error: + trace_attributes["error"] = error + + self.tracer.end_agent_span(**trace_attributes) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 4b11e81c..f367b272 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,12 +1,10 @@ """Sliding window conversation history management.""" -import json import logging -from typing import List, Optional, cast +from typing import Optional -from ...types.content import ContentBlock, Message, Messages +from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException -from ...types.tools import ToolResult from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) @@ -110,8 +108,9 @@ def _remove_dangling_messages(self, messages: Messages) -> None: def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. - The method handles special cases where tool results need to be converted to regular content blocks to maintain - conversation coherence after trimming. + The method handles special cases where trimming the messages leads to: + - toolResult with no corresponding toolUse + - toolUse with no corresponding toolResult Args: messages: The messages to reduce. @@ -126,52 +125,24 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size - # Throw if we cannot trim any messages from the conversation - if trim_index >= len(messages): - raise ContextWindowOverflowException("Unable to trim conversation context!") from e - - # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the - # limitation of needing ToolUse and ToolResults to be paired. - if any("toolResult" in content for content in messages[trim_index]["content"]): - if len(messages[trim_index]["content"]) == 1: - messages[trim_index]["content"] = self._map_tool_result_content( - cast(ToolResult, messages[trim_index]["content"][0]["toolResult"]) + # Find the next valid trim_index + while trim_index < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[trim_index]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[trim_index]["content"]) + and trim_index + 1 < len(messages) + and not any("toolResult" in content for content in messages[trim_index + 1]["content"]) ) - - # If there is more content than just one ToolResultContent, then we cannot cut at this index. 
+ ): + trim_index += 1 else: - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + break + else: + # If we didn't find a valid trim_index, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") from e # Overwrite message history messages[:] = messages[trim_index:] - - def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]: - """Convert a ToolResult to a list of standard ContentBlocks. - - This method transforms tool result content into standard content blocks that can be preserved when trimming the - conversation history. - - Args: - tool_result: The ToolResult to convert. - - Returns: - A list of content blocks representing the tool result. - """ - contents = [] - text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else "" - - for tool_result_content in tool_result["content"]: - if "text" in tool_result_content: - text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}" - elif "json" in tool_result_content: - text_content = ( - "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}" - ) - elif "image" in tool_result_content: - contents.append(ContentBlock(image=tool_result_content["image"])) - elif "document" in tool_result_content: - contents.append(ContentBlock(document=tool_result_content["document"])) - else: - logger.warning("unsupported content type") - contents.append(ContentBlock(text=text_content)) - return contents diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index db5a1b97..23d7bd0f 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) +MAX_ATTEMPTS = 6 +INITIAL_DELAY = 4 +MAX_DELAY = 240 # 4 minutes + def initialize_state(**kwargs: Any) -> Any: """Initialize the request state if not present. 
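The new `reduce_context` control flow above is easiest to read as a standalone scan. This is a sketch under stated assumptions: the message shape matches `strands.types.content.Messages` (dicts with a "role" and a "content" list of blocks), and the function name and the plain `RuntimeError` stand in for the SDK's own `ContextWindowOverflowException`:

```python
from typing import Any

Message = dict[str, Any]  # {"role": str, "content": [{"text" | "toolUse" | "toolResult": ...}]}


def find_trim_index(messages: list[Message], start: int) -> int:
    """Advance the cut point until it no longer orphans a toolUse/toolResult pair."""
    trim_index = start
    while trim_index < len(messages):
        content = messages[trim_index]["content"]
        # The oldest surviving message cannot be a toolResult (it needs a preceding toolUse) ...
        dangling_result = any("toolResult" in block for block in content)
        # ... and can be a toolUse only if its toolResult immediately follows.
        dangling_use = (
            any("toolUse" in block for block in content)
            and trim_index + 1 < len(messages)
            and not any("toolResult" in block for block in messages[trim_index + 1]["content"])
        )
        if dangling_result or dangling_use:
            trim_index += 1
        else:
            return trim_index
    raise RuntimeError("Unable to trim conversation context!")  # ContextWindowOverflowException in the SDK
```

Note the boundary behavior this preserves: a toolResult can never become the oldest message, and a toolUse survives as the cut point only when its paired toolResult is the very next message.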
@@ -51,7 +55,7 @@ def event_loop_cycle(
     system_prompt: Optional[str],
     messages: Messages,
     tool_config: Optional[ToolConfig],
-    callback_handler: Any,
+    callback_handler: Callable[..., Any],
     tool_handler: Optional[ToolHandler],
     tool_execution_handler: Optional[ParallelToolExecutorInterface] = None,
     **kwargs: Any,
@@ -130,13 +134,9 @@ def event_loop_cycle(
     stop_reason: StopReason
     usage: Any
     metrics: Metrics
-    max_attempts = 6
-    initial_delay = 4
-    max_delay = 240  # 4 minutes
-    current_delay = initial_delay

     # Retry loop for handling throttling exceptions
-    for attempt in range(max_attempts):
+    for attempt in range(MAX_ATTEMPTS):
         model_id = model.config.get("model_id") if hasattr(model, "config") else None
         model_invoke_span = tracer.start_model_invoke_span(
             parent_span=cycle_span,
@@ -177,7 +177,7 @@

             # Handle throttling errors with exponential backoff
             should_retry, current_delay = handle_throttling_error(
-                e, attempt, max_attempts, current_delay, max_delay, callback_handler, kwargs
+                e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs
             )
             if should_retry:
                 continue
@@ -204,80 +204,35 @@

     # If the model is requesting to use tools
     if stop_reason == "tool_use":
-        tool_uses: List[ToolUse] = []
-        tool_results: List[ToolResult] = []
-        invalid_tool_use_ids: List[str] = []
-
-        # Extract and validate tools
-        validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
-
-        # Check if tools are available for execution
-        if tool_uses:
-            if tool_handler is None:
-                raise ValueError("toolUse present but tool handler not set")
-            if tool_config is None:
-                raise ValueError("toolUse present but tool config not set")
-
-            # Create the tool handler process callable
-            tool_handler_process: Callable[[ToolUse], ToolResult] = partial(
-                tool_handler.process,
-                messages=messages,
-                model=model,
-                system_prompt=system_prompt,
-                tool_config=tool_config,
-                callback_handler=callback_handler,
-                **kwargs,
+        if not tool_handler:
+            raise EventLoopException(
+                Exception("Model requested tool use but no tool handler provided"),
+                kwargs["request_state"],
             )

-            # Execute tools (parallel or sequential)
-            run_tools(
-                handler=tool_handler_process,
-                tool_uses=tool_uses,
-                event_loop_metrics=event_loop_metrics,
-                request_state=cast(Any, kwargs["request_state"]),
-                invalid_tool_use_ids=invalid_tool_use_ids,
-                tool_results=tool_results,
-                cycle_trace=cycle_trace,
-                parent_span=cycle_span,
-                parallel_tool_executor=tool_execution_handler,
+        if tool_config is None:
+            raise EventLoopException(
+                Exception("Model requested tool use but no tool config provided"),
+                kwargs["request_state"],
             )

-            # Update state for the next cycle
-            kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
-
-            # Create the tool result message
-            tool_result_message: Message = {
-                "role": "user",
-                "content": [{"toolResult": result} for result in tool_results],
-            }
-            messages.append(tool_result_message)
-            callback_handler(message=tool_result_message)
-
-            if cycle_span:
-                tracer.end_event_loop_cycle_span(
-                    span=cycle_span, message=message, tool_result_message=tool_result_message
-                )
-
-            # Check if we should stop the event loop
-            if kwargs["request_state"].get("stop_event_loop"):
-                event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
-                return (
-                    stop_reason,
-                    message,
-                    event_loop_metrics,
-                    kwargs["request_state"],
-                )
-
-            # Recursive call to continue the conversation
-            return recurse_event_loop(
-                model=model,
-                system_prompt=system_prompt,
-                messages=messages,
-                tool_config=tool_config,
-                callback_handler=callback_handler,
-                tool_handler=tool_handler,
-                **kwargs,
-            )
+        # Handle tool execution
+        return _handle_tool_execution(
+            stop_reason,
+            message,
+            model,
+            system_prompt,
+            messages,
+            tool_config,
+            tool_handler,
+            callback_handler,
+            tool_execution_handler,
+            event_loop_metrics,
+            cycle_trace,
+            cycle_span,
+            cycle_start_time,
+            kwargs,
+        )

     # End the cycle and return results
     event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
@@ -377,3 +332,105 @@ def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetr
     kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]

     return kwargs
+
+
+def _handle_tool_execution(
+    stop_reason: StopReason,
+    message: Message,
+    model: Model,
+    system_prompt: Optional[str],
+    messages: Messages,
+    tool_config: ToolConfig,
+    tool_handler: ToolHandler,
+    callback_handler: Callable[..., Any],
+    tool_execution_handler: Optional[ParallelToolExecutorInterface],
+    event_loop_metrics: EventLoopMetrics,
+    cycle_trace: Trace,
+    cycle_span: Any,
+    cycle_start_time: float,
+    kwargs: Dict[str, Any],
+) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
+    """
+    Handles the execution of tools requested by the model during an event loop cycle.
+
+    Args:
+        stop_reason (StopReason): The reason the model stopped generating.
+        message (Message): The message from the model that may contain tool use requests.
+        model (Model): The model provider instance.
+        system_prompt (Optional[str]): The system prompt instructions for the model.
+        messages (Messages): The conversation history messages.
+        tool_config (ToolConfig): Configuration for available tools.
+        tool_handler (ToolHandler): Handler for tool execution.
+        callback_handler (Callable[..., Any]): Callback for processing events as they happen.
+        tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
+        event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
+        cycle_trace (Trace): Trace object for the current event loop cycle.
+        cycle_span (Any): Span object for tracing the cycle (type may vary).
+        cycle_start_time (float): Start time of the current cycle.
+        kwargs (Dict[str, Any]): Additional keyword arguments, including request state.
+
+    Returns:
+        Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
+            - The stop reason,
+            - The updated message,
+            - The updated event loop metrics,
+            - The updated request state.
+    """
+    tool_uses: List[ToolUse] = []
+    tool_results: List[ToolResult] = []
+    invalid_tool_use_ids: List[str] = []
+
+    validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
+
+    if not tool_uses:
+        return stop_reason, message, event_loop_metrics, kwargs["request_state"]
+
+    tool_handler_process = partial(
+        tool_handler.process,
+        messages=messages,
+        model=model,
+        system_prompt=system_prompt,
+        tool_config=tool_config,
+        callback_handler=callback_handler,
+        **kwargs,
+    )
+
+    run_tools(
+        handler=tool_handler_process,
+        tool_uses=tool_uses,
+        event_loop_metrics=event_loop_metrics,
+        request_state=cast(Any, kwargs["request_state"]),
+        invalid_tool_use_ids=invalid_tool_use_ids,
+        tool_results=tool_results,
+        cycle_trace=cycle_trace,
+        parent_span=cycle_span,
+        parallel_tool_executor=tool_execution_handler,
+    )
+
+    kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
+
+    tool_result_message: Message = {
+        "role": "user",
+        "content": [{"toolResult": result} for result in tool_results],
+    }
+
+    messages.append(tool_result_message)
+    callback_handler(message=tool_result_message)
+
+    if cycle_span:
+        tracer = get_tracer()
+        tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
+
+    if kwargs["request_state"].get("stop_event_loop", False):
+        event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
+        return stop_reason, message, event_loop_metrics, kwargs["request_state"]
+
+    return recurse_event_loop(
+        model=model,
+        system_prompt=system_prompt,
+        messages=messages,
+        tool_config=tool_config,
+        callback_handler=callback_handler,
+        tool_handler=tool_handler,
+        **kwargs,
+    )
diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py
index d6d104d8..e46cb326 100644
--- a/src/strands/handlers/callback_handler.py
+++ b/src/strands/handlers/callback_handler.py
@@ -17,15 +17,19 @@ def __call__(self, **kwargs: Any) -> None:
         Args:
             **kwargs: Callback event data including:
-
-            - data (str): Text content to stream.
-            - complete (bool): Whether this is the final chunk of a response.
-            - current_tool_use (dict): Information about the current tool being used.
+            - reasoningText (Optional[str]): Reasoning text to print if provided.
+            - data (str): Text content to stream.
+            - complete (bool): Whether this is the final chunk of a response.
+            - current_tool_use (dict): Information about the current tool being used.
         """
+        reasoningText = kwargs.get("reasoningText", False)
         data = kwargs.get("data", "")
         complete = kwargs.get("complete", False)
         current_tool_use = kwargs.get("current_tool_use", {})

+        if reasoningText:
+            print(reasoningText, end="")
+
         if data:
             print(data, end="" if not complete else "\n")
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 6e32c5bd..764cb851 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -18,6 +18,7 @@ class Client(Protocol):
     """Protocol defining the OpenAI-compatible interface for the underlying provider client."""

     @property
+    # pragma: no cover
     def chat(self) -> Any:
         """Chat completions interface."""
         ...
diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py
index 8f5ffab3..c00a7774 100644
--- a/src/strands/types/models/openai.py
+++ b/src/strands/types/models/openai.py
@@ -206,7 +206,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:

             case "content_delta":
                 if event["data_type"] == "tool":
-                    return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
+                    return {
+                        "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}}
+                    }

                 return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 5c7d11e4..ea06fb4e 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -9,7 +9,8 @@
 import pytest

 import strands
-from strands.agent.agent import Agent
+from strands import Agent
+from strands.agent import AgentResult
 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
 from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
 from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
@@ -337,17 +338,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler,
         ],
     ]

+    override_system_prompt = "Override system prompt"
+    override_model = unittest.mock.Mock()
+    override_tool_execution_handler = unittest.mock.Mock()
+    override_event_loop_metrics = unittest.mock.Mock()
+    override_callback_handler = unittest.mock.Mock()
+    override_tool_handler = unittest.mock.Mock()
+    override_messages = [{"role": "user", "content": [{"text": "override msg"}]}]
+    override_tool_config = {"test": "config"}
+
     def check_kwargs(some_value, **kwargs):
         assert some_value == "a_value"
         assert kwargs is not None
+        assert kwargs["system_prompt"] == override_system_prompt
+        assert kwargs["model"] == override_model
+        assert kwargs["tool_execution_handler"] == override_tool_execution_handler
+        assert kwargs["event_loop_metrics"] == override_event_loop_metrics
+        assert kwargs["callback_handler"] == override_callback_handler
+        assert kwargs["tool_handler"] == override_tool_handler
+        assert kwargs["messages"] == override_messages
+        assert kwargs["tool_config"] == override_tool_config
+        assert kwargs["agent"] == agent

         # Return expected values from event_loop_cycle
         return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}

     mock_event_loop_cycle.side_effect = check_kwargs

-    agent("test message", some_value="a_value")
-    assert mock_event_loop_cycle.call_count == 1
+    agent(
+        "test message",
+        some_value="a_value",
+        system_prompt=override_system_prompt,
+        model=override_model,
+        tool_execution_handler=override_tool_execution_handler,
+        event_loop_metrics=override_event_loop_metrics,
+        callback_handler=override_callback_handler,
+        tool_handler=override_tool_handler,
+        messages=override_messages,
+        tool_config=override_tool_config,
+    )
+
+    mock_event_loop_cycle.assert_called_once()


 def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):
@@ -428,7 +459,7 @@ def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite
     with pytest.raises(ContextWindowOverflowException):
         agent("Test!")

-    assert conversation_manager_spy.reduce_context.call_count == 251
+    assert conversation_manager_spy.reduce_context.call_count > 0
     assert conversation_manager_spy.apply_management.call_count == 1


@@ -657,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler():
 @pytest.mark.asyncio
 async def test_stream_async_returns_all_events(mock_event_loop_cycle):
-    mock_event_loop_cycle.side_effect = ValueError("Test exception")
-
     agent = Agent()

     # Define the side effect to simulate callback handler being called multiple times
@@ -922,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
     mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result)


+@pytest.mark.asyncio
+@unittest.mock.patch("strands.agent.agent.get_tracer")
+async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle, mock_model):
+    """Test that stream_async creates and ends a span when the call succeeds."""
+    # Setup mock tracer and span
+    mock_tracer = unittest.mock.MagicMock()
+    mock_span = unittest.mock.MagicMock()
+    mock_tracer.start_agent_span.return_value = mock_span
+    mock_get_tracer.return_value = mock_tracer
+
+    # Define the side effect to simulate callback handler being called multiple times
+    def call_callback_handler(*args, **kwargs):
+        # Extract the callback handler from kwargs
+        callback_handler = kwargs.get("callback_handler")
+        # Call the callback handler with different data values
+        callback_handler(data="First chunk")
+        callback_handler(data="Second chunk")
+        callback_handler(data="Final chunk", complete=True)
+        # Return expected values from event_loop_cycle
+        return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}
+
+    mock_event_loop_cycle.side_effect = call_callback_handler
+
+    # Create agent and make a call
+    agent = Agent(model=mock_model)
+    iterator = agent.stream_async("test prompt")
+    async for _event in iterator:
+        pass  # NoOp
+
+    # Verify span was created
+    mock_tracer.start_agent_span.assert_called_once_with(
+        prompt="test prompt",
+        model_id=unittest.mock.ANY,
+        tools=agent.tool_names,
+        system_prompt=agent.system_prompt,
+        custom_trace_attributes=agent.trace_attributes,
+    )
+
+    expected_response = AgentResult(
+        stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={}
+    )
+
+    # Verify span was ended with the result
+    mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response)
+
+
 @unittest.mock.patch("strands.agent.agent.get_tracer")
 def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
     """Test that __call__ creates and ends a span when an exception occurs."""
@@ -955,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
     mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)


+@pytest.mark.asyncio
+@unittest.mock.patch("strands.agent.agent.get_tracer")
+async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
+    """Test that stream_async creates and ends a span when an exception occurs."""
+    # Setup mock tracer and span
+    mock_tracer = unittest.mock.MagicMock()
+    mock_span = unittest.mock.MagicMock()
+    mock_tracer.start_agent_span.return_value = mock_span
+    mock_get_tracer.return_value = mock_tracer
+
+    # Define the side effect to simulate callback handler raising an Exception
+    test_exception = ValueError("Test exception")
+    mock_model.mock_converse.side_effect = test_exception
+
+    # Create agent and make a call
+    agent = Agent(model=mock_model)
+
+    # Call the agent and catch the exception
+    with pytest.raises(ValueError):
+        iterator = agent.stream_async("test prompt")
+        async for _event in iterator:
+            pass  # NoOp
+
+    # Verify span was created
+    mock_tracer.start_agent_span.assert_called_once_with(
+        prompt="test prompt",
+        model_id=unittest.mock.ANY,
+        tools=agent.tool_names,
+        system_prompt=agent.system_prompt,
+        custom_trace_attributes=agent.trace_attributes,
+    )
+
+    # Verify span was ended with the exception
+    mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
+
+
 @unittest.mock.patch("strands.agent.agent.get_tracer")
 def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model):
     """Test that event_loop_cycle is called with the parent span."""
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py
index 2f6ee77d..b6132f1d 100644
--- a/tests/strands/agent/test_conversation_manager.py
+++ b/tests/strands/agent/test_conversation_manager.py
@@ -111,41 +111,7 @@ def conversation_manager(request):
             {"role": "assistant", "content": [{"text": "Second response"}]},
         ],
     ),
-    # 7 - Message count above max window size - Remove dangling tool uses and tool results
-    (
-        {"window_size": 1},
-        [
-            {"role": "user", "content": [{"text": "First message"}]},
-            {"role": "assistant", "content": [{"toolUse": {"toolUseId": "321", "name": "tool1", "input": {}}}]},
-            {
-                "role": "user",
-                "content": [
-                    {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}}
-                ],
-            },
-        ],
-        [
-            {
-                "role": "user",
-                "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}],
-            },
-        ],
-    ),
-    # 8 - Message count above max window size - Remove multiple tool use/tool result pairs
-    (
-        {"window_size": 1},
-        [
-            {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]},
-            {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]},
-            {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
-            {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]},
-            {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]},
-        ],
-        [
-            {"role": "user", "content": [{"text": "Tool Result Status: success"}]},
-        ],
-    ),
-    # 9 - Message count above max window size - Preserve tool use/tool result pairs
+    # 7 - Message count above max window size - Preserve tool use/tool result pairs
     (
         {"window_size": 2},
         [
@@ -158,7 +124,7 @@ def conversation_manager(request):
             {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
         ],
     ),
-    # 10 - Test sliding window behavior - preserve tool use/result pairs across cut boundary
+    # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary
    (
         {"window_size": 3},
         [
@@ -173,7 +139,7 @@ def conversation_manager(request):
             {"role": "assistant", "content": [{"text": "Response after tool use"}]},
         ],
     ),
-    # 11 - Test sliding window with multiple tool pairs that need preservation
+    # 9 - Test sliding window with multiple tool pairs that need preservation
     (
         {"window_size": 4},
         [
@@ -185,7 +151,6 @@ def conversation_manager(request):
             {"role": "assistant", "content": [{"text": "Final response"}]},
         ],
         [
-            {"role": "user", "content": [{"text": "Tool Result Status: success"}]},
             {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]},
{"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, @@ -200,6 +165,20 @@ def test_apply_management(conversation_manager, messages, expected_messages): assert messages == expected_messages +def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, + ] + original_messages = messages.copy() + + with pytest.raises(ContextWindowOverflowException): + manager.apply_management(messages) + + assert messages == original_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py index 20e238cb..6fb2af07 100644 --- a/tests/strands/handlers/test_callback_handler.py +++ b/tests/strands/handlers/test_callback_handler.py @@ -30,6 +30,31 @@ def test_call_with_empty_args(handler, mock_print): mock_print.assert_not_called() +def test_call_handler_reasoningText(handler, mock_print): + """Test calling the handler with reasoningText.""" + handler(reasoningText="This is reasoning text") + # Should print reasoning text without newline + mock_print.assert_called_once_with("This is reasoning text", end="") + + +def test_call_without_reasoningText(handler, mock_print): + """Test calling the handler without reasoningText argument.""" + handler(data="Some output") + # Should only print data, not reasoningText + mock_print.assert_called_once_with("Some output", end="") + + +def test_call_with_reasoningText_and_data(handler, mock_print): + """Test calling the handler with both reasoningText and data.""" + handler(reasoningText="Reasoning", data="Output") + # Should print reasoningText and data, both without newline + calls = [ + unittest.mock.call("Reasoning", end=""), + unittest.mock.call("Output", end=""), + ] + mock_print.assert_has_calls(calls) + + def test_call_with_data_incomplete(handler, mock_print): """Test calling the handler with data but not complete.""" handler(data="Test output") diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 97a0882a..2657c334 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -246,12 +246,12 @@ def test_format_request(model, messages, tool_specs, system_prompt): @pytest.mark.parametrize( ("event", "exp_chunk"), [ - # Case 1: Message start + # Message start ( {"chunk_type": "message_start"}, {"messageStart": {"role": "assistant"}}, ), - # Case 2: Content Start - Tool Use + # Content Start - Tool Use ( { "chunk_type": "content_start", @@ -260,12 +260,12 @@ def test_format_request(model, messages, tool_specs, system_prompt): }, {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, ), - # Case 3: Content Start - Text + # Content Start - Text ( {"chunk_type": "content_start", "data_type": "text"}, {"contentBlockStart": {"start": {}}}, ), - # Case 4: Content Delta - Tool Use + # Content 
Delta - Tool Use ( { "chunk_type": "content_delta", @@ -274,32 +274,41 @@ def test_format_request(model, messages, tool_specs, system_prompt): }, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, ), - # Case 5: Content Delta - Text + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Text ( {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, {"contentBlockDelta": {"delta": {"text": "hello"}}}, ), - # Case 6: Content Stop + # Content Stop ( {"chunk_type": "content_stop"}, {"contentBlockStop": {}}, ), - # Case 7: Message Stop - Tool Use + # Message Stop - Tool Use ( {"chunk_type": "message_stop", "data": "tool_calls"}, {"messageStop": {"stopReason": "tool_use"}}, ), - # Case 8: Message Stop - Max Tokens + # Message Stop - Max Tokens ( {"chunk_type": "message_stop", "data": "length"}, {"messageStop": {"stopReason": "max_tokens"}}, ), - # Case 9: Message Stop - End Turn + # Message Stop - End Turn ( {"chunk_type": "message_stop", "data": "stop"}, {"messageStop": {"stopReason": "end_turn"}}, ), - # Case 10: Metadata + # Metadata ( { "chunk_type": "metadata",