diff --git a/.gitignore b/.gitignore
index a80f4bd1..a5cf11c4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,4 @@ __pycache__*
.pytest_cache
.ruff_cache
*.bak
+.vscode
\ No newline at end of file
diff --git a/README.md b/README.md
index 262dde51..f4c483a2 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,14 @@
-# Strands Agents
-
@@ -26,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t
## Feature Overview
- **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable
-- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Llama, Ollama, and custom providers
+- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers
- **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support
- **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools
@@ -138,6 +149,7 @@ Built-in providers:
- [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
- [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
- [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
+ - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/)
Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/)
@@ -165,9 +177,9 @@ For detailed guidance & examples, explore our documentation:
- [API Reference](https://strandsagents.com/latest/api-reference/agent/)
- [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/)
-## Contributing
+## Contributing ❤️
-We welcome contributions! See our [Contributing Guide](https://github.com/strands-agents/sdk-python/blob/main/CONTRIBUTING.md) for details on:
+We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on:
- Reporting bugs & features
- Development setup
- Contributing via Pull Requests
@@ -178,6 +190,10 @@ We welcome contributions! See our [Contributing Guide](https://github.com/strand
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
+## Security
+
+See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
+
## ⚠️ Preview Status
Strands Agents is currently in public preview. During this period:
diff --git a/pyproject.toml b/pyproject.toml
index 52394cea..a3b36cab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
[project]
name = "strands-agents"
-version = "0.1.4"
+version = "0.1.5"
description = "A model-driven approach to building AI agents in just a few lines of code"
readme = "README.md"
requires-python = ">=3.10"
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 6a948980..0f912b54 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
- model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
-
- self.trace_span = self.tracer.start_agent_span(
- prompt=prompt,
- model_id=model_id,
- tools=self.tool_names,
- system_prompt=self.system_prompt,
- custom_trace_attributes=self.trace_attributes,
- )
+ self._start_agent_trace_span(prompt)
try:
# Run the event loop and get the result
result = self._run_loop(prompt, kwargs)
- if self.trace_span:
- self.tracer.end_agent_span(span=self.trace_span, response=result)
+ self._end_agent_trace_span(response=result)
return result
except Exception as e:
- if self.trace_span:
- self.tracer.end_agent_span(span=self.trace_span, error=e)
+ self._end_agent_trace_span(error=e)
# Re-raise the exception to preserve original behavior
raise
@@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
yield event["data"]
```
"""
+ self._start_agent_trace_span(prompt)
+
_stop_event = uuid4()
queue = asyncio.Queue[Any]()
@@ -400,8 +392,10 @@ def target_callback() -> None:
nonlocal kwargs
try:
- self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
- except BaseException as e:
+ result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
+ self._end_agent_trace_span(response=result)
+ except Exception as e:
+ self._end_agent_trace_span(error=e)
enqueue(e)
finally:
enqueue(_stop_event)
@@ -414,7 +408,7 @@ def target_callback() -> None:
item = await queue.get()
if item == _stop_event:
break
- if isinstance(item, BaseException):
+ if isinstance(item, Exception):
raise item
yield item
finally:
@@ -457,27 +451,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
Returns:
The result of the event loop cycle.
"""
- kwargs.pop("agent", None)
- kwargs.pop("model", None)
- kwargs.pop("system_prompt", None)
- kwargs.pop("tool_execution_handler", None)
- kwargs.pop("event_loop_metrics", None)
- kwargs.pop("callback_handler", None)
- kwargs.pop("tool_handler", None)
- kwargs.pop("messages", None)
- kwargs.pop("tool_config", None)
+ # Extract parameters with fallbacks to instance values
+ system_prompt = kwargs.pop("system_prompt", self.system_prompt)
+ model = kwargs.pop("model", self.model)
+ tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
+ event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
+ callback_handler_override = kwargs.pop("callback_handler", callback_handler)
+ tool_handler = kwargs.pop("tool_handler", self.tool_handler)
+ messages = kwargs.pop("messages", self.messages)
+ tool_config = kwargs.pop("tool_config", self.tool_config)
+ kwargs.pop("agent", None) # Remove agent to avoid conflicts
try:
# Execute the main event loop cycle
stop_reason, message, metrics, state = event_loop_cycle(
- model=self.model,
- system_prompt=self.system_prompt,
- messages=self.messages, # will be modified by event_loop_cycle
- tool_config=self.tool_config,
- callback_handler=callback_handler,
- tool_handler=self.tool_handler,
- tool_execution_handler=self.thread_pool_wrapper,
- event_loop_metrics=self.event_loop_metrics,
+ model=model,
+ system_prompt=system_prompt,
+ messages=messages, # will be modified by event_loop_cycle
+ tool_config=tool_config,
+ callback_handler=callback_handler_override,
+ tool_handler=tool_handler,
+ tool_execution_handler=tool_execution_handler,
+ event_loop_metrics=event_loop_metrics,
agent=self,
event_loop_parent_span=self.trace_span,
**kwargs,
@@ -488,8 +483,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
- self.conversation_manager.reduce_context(self.messages, e=e)
- return self._execute_event_loop_cycle(callback_handler, kwargs)
+ self.conversation_manager.reduce_context(messages, e=e)
+ return self._execute_event_loop_cycle(callback_handler_override, kwargs)
def _record_tool_execution(
self,
@@ -545,3 +540,43 @@ def _record_tool_execution(
messages.append(tool_use_msg)
messages.append(tool_result_msg)
messages.append(assistant_msg)
+
+ def _start_agent_trace_span(self, prompt: str) -> None:
+ """Starts a trace span for the agent.
+
+ Args:
+ prompt: The natural language prompt from the user.
+ """
+ model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
+
+ self.trace_span = self.tracer.start_agent_span(
+ prompt=prompt,
+ model_id=model_id,
+ tools=self.tool_names,
+ system_prompt=self.system_prompt,
+ custom_trace_attributes=self.trace_attributes,
+ )
+
+ def _end_agent_trace_span(
+ self,
+ response: Optional[AgentResult] = None,
+ error: Optional[Exception] = None,
+ ) -> None:
+ """Ends a trace span for the agent.
+
+ Args:
+ response: Response to record as a trace attribute.
+ error: Error to record as a trace attribute.
+ """
+ if self.trace_span:
+ trace_attributes: Dict[str, Any] = {
+ "span": self.trace_span,
+ }
+
+ if response:
+ trace_attributes["response"] = response
+ if error:
+ trace_attributes["error"] = error
+
+ self.tracer.end_agent_span(**trace_attributes)
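The agent.py changes above consolidate span handling so that `__call__` and `stream_async` share one tracing lifecycle via `_start_agent_trace_span` / `_end_agent_trace_span`. A minimal sketch of the resulting pattern, using stand-in names rather than the SDK's tracer API:

```python
from typing import Any, Optional


class SpanTracker:
    """Stand-in for strands' Tracer; the real helpers appear in the hunk above."""

    def start(self, prompt: str) -> dict:
        return {"prompt": prompt, "ended": False}

    def end(self, span: dict, response: Any = None, error: Optional[Exception] = None) -> None:
        span.update(ended=True, response=response, error=error)


class MiniAgent:
    def __init__(self) -> None:
        self.tracer = SpanTracker()
        self.trace_span: Optional[dict] = None

    def __call__(self, prompt: str) -> str:
        # Mirrors __call__ and stream_async: open the span, then end it with
        # either the response or the raised exception, never both.
        self.trace_span = self.tracer.start(prompt)
        try:
            result = f"echo: {prompt}"
            self.tracer.end(self.trace_span, response=result)
            return result
        except Exception as e:
            self.tracer.end(self.trace_span, error=e)
            raise


print(MiniAgent()("hello"))  # the span is ended exactly once, success or failure
```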
diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
index 4b11e81c..f367b272 100644
--- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -1,12 +1,10 @@
"""Sliding window conversation history management."""
-import json
import logging
-from typing import List, Optional, cast
+from typing import Optional
-from ...types.content import ContentBlock, Message, Messages
+from ...types.content import Message, Messages
from ...types.exceptions import ContextWindowOverflowException
-from ...types.tools import ToolResult
from .conversation_manager import ConversationManager
logger = logging.getLogger(__name__)
@@ -110,8 +108,9 @@ def _remove_dangling_messages(self, messages: Messages) -> None:
def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
"""Trim the oldest messages to reduce the conversation context size.
- The method handles special cases where tool results need to be converted to regular content blocks to maintain
- conversation coherence after trimming.
+ The method handles special cases where trimming the messages leads to:
+ - toolResult with no corresponding toolUse
+ - toolUse with no corresponding toolResult
Args:
messages: The messages to reduce.
@@ -126,52 +125,24 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size
- # Throw if we cannot trim any messages from the conversation
- if trim_index >= len(messages):
- raise ContextWindowOverflowException("Unable to trim conversation context!") from e
-
- # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the
- # limitation of needing ToolUse and ToolResults to be paired.
- if any("toolResult" in content for content in messages[trim_index]["content"]):
- if len(messages[trim_index]["content"]) == 1:
- messages[trim_index]["content"] = self._map_tool_result_content(
- cast(ToolResult, messages[trim_index]["content"][0]["toolResult"])
+ # Find the next valid trim_index
+ while trim_index < len(messages):
+ if (
+ # Oldest message cannot be a toolResult because it needs a toolUse preceding it
+ any("toolResult" in content for content in messages[trim_index]["content"])
+ or (
+ # Oldest message can be a toolUse only if a toolResult immediately follows it.
+ any("toolUse" in content for content in messages[trim_index]["content"])
+ and trim_index + 1 < len(messages)
+ and not any("toolResult" in content for content in messages[trim_index + 1]["content"])
)
-
- # If there is more content than just one ToolResultContent, then we cannot cut at this index.
+ ):
+ trim_index += 1
else:
- raise ContextWindowOverflowException("Unable to trim conversation context!") from e
+ break
+ else:
+ # If we didn't find a valid trim_index, then we throw
+ raise ContextWindowOverflowException("Unable to trim conversation context!") from e
# Overwrite message history
messages[:] = messages[trim_index:]
-
- def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]:
- """Convert a ToolResult to a list of standard ContentBlocks.
-
- This method transforms tool result content into standard content blocks that can be preserved when trimming the
- conversation history.
-
- Args:
- tool_result: The ToolResult to convert.
-
- Returns:
- A list of content blocks representing the tool result.
- """
- contents = []
- text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else ""
-
- for tool_result_content in tool_result["content"]:
- if "text" in tool_result_content:
- text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}"
- elif "json" in tool_result_content:
- text_content = (
- "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}"
- )
- elif "image" in tool_result_content:
- contents.append(ContentBlock(image=tool_result_content["image"]))
- elif "document" in tool_result_content:
- contents.append(ContentBlock(document=tool_result_content["document"]))
- else:
- logger.warning("unsupported content type")
- contents.append(ContentBlock(text=text_content))
- return contents
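With `_map_tool_result_content` removed, `reduce_context` no longer rewrites tool results into text blocks; it instead scans forward for a cut point that leaves no dangling toolUse/toolResult. An illustrative run based on the logic in the hunk above (message shapes match the tests later in this diff):

```python
from strands.agent.conversation_manager import SlidingWindowConversationManager

manager = SlidingWindowConversationManager(window_size=2)
messages = [
    {"role": "user", "content": [{"text": "First message"}]},
    {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "tool1", "input": {}}}]},
    {"role": "user", "content": [{"toolResult": {"toolUseId": "1", "content": [], "status": "success"}}]},
    {"role": "assistant", "content": [{"text": "Done"}]},
]

manager.reduce_context(messages)
# The initial cut index (len - window_size = 2) lands on a toolResult whose
# toolUse would be trimmed away, so the scan advances to index 3 and trims
# there instead, leaving only the final assistant message.
assert messages == [{"role": "assistant", "content": [{"text": "Done"}]}]
```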
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index db5a1b97..23d7bd0f 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -28,6 +28,10 @@
logger = logging.getLogger(__name__)
+MAX_ATTEMPTS = 6
+INITIAL_DELAY = 4
+MAX_DELAY = 240 # 4 minutes
+
def initialize_state(**kwargs: Any) -> Any:
"""Initialize the request state if not present.
@@ -51,7 +55,7 @@ def event_loop_cycle(
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
- callback_handler: Any,
+ callback_handler: Callable[..., Any],
tool_handler: Optional[ToolHandler],
tool_execution_handler: Optional[ParallelToolExecutorInterface] = None,
**kwargs: Any,
@@ -130,13 +134,9 @@ def event_loop_cycle(
stop_reason: StopReason
usage: Any
metrics: Metrics
- max_attempts = 6
- initial_delay = 4
- max_delay = 240 # 4 minutes
- current_delay = initial_delay
# Retry loop for handling throttling exceptions
- for attempt in range(max_attempts):
+ for attempt in range(MAX_ATTEMPTS):
model_id = model.config.get("model_id") if hasattr(model, "config") else None
model_invoke_span = tracer.start_model_invoke_span(
parent_span=cycle_span,
@@ -177,7 +177,7 @@ def event_loop_cycle(
# Handle throttling errors with exponential backoff
should_retry, current_delay = handle_throttling_error(
- e, attempt, max_attempts, current_delay, max_delay, callback_handler, kwargs
+ e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs
)
if should_retry:
continue
@@ -204,80 +204,35 @@ def event_loop_cycle(
# If the model is requesting to use tools
if stop_reason == "tool_use":
- tool_uses: List[ToolUse] = []
- tool_results: List[ToolResult] = []
- invalid_tool_use_ids: List[str] = []
-
- # Extract and validate tools
- validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
-
- # Check if tools are available for execution
- if tool_uses:
- if tool_handler is None:
- raise ValueError("toolUse present but tool handler not set")
- if tool_config is None:
- raise ValueError("toolUse present but tool config not set")
-
- # Create the tool handler process callable
- tool_handler_process: Callable[[ToolUse], ToolResult] = partial(
- tool_handler.process,
- messages=messages,
- model=model,
- system_prompt=system_prompt,
- tool_config=tool_config,
- callback_handler=callback_handler,
- **kwargs,
+ if not tool_handler:
+ raise EventLoopException(
+ Exception("Model requested tool use but no tool handler provided"),
+ kwargs["request_state"],
)
- # Execute tools (parallel or sequential)
- run_tools(
- handler=tool_handler_process,
- tool_uses=tool_uses,
- event_loop_metrics=event_loop_metrics,
- request_state=cast(Any, kwargs["request_state"]),
- invalid_tool_use_ids=invalid_tool_use_ids,
- tool_results=tool_results,
- cycle_trace=cycle_trace,
- parent_span=cycle_span,
- parallel_tool_executor=tool_execution_handler,
+ if tool_config is None:
+ raise EventLoopException(
+ Exception("Model requested tool use but no tool config provided"),
+ kwargs["request_state"],
)
- # Update state for the next cycle
- kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
-
- # Create the tool result message
- tool_result_message: Message = {
- "role": "user",
- "content": [{"toolResult": result} for result in tool_results],
- }
- messages.append(tool_result_message)
- callback_handler(message=tool_result_message)
-
- if cycle_span:
- tracer.end_event_loop_cycle_span(
- span=cycle_span, message=message, tool_result_message=tool_result_message
- )
-
- # Check if we should stop the event loop
- if kwargs["request_state"].get("stop_event_loop"):
- event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
- return (
- stop_reason,
- message,
- event_loop_metrics,
- kwargs["request_state"],
- )
-
- # Recursive call to continue the conversation
- return recurse_event_loop(
- model=model,
- system_prompt=system_prompt,
- messages=messages,
- tool_config=tool_config,
- callback_handler=callback_handler,
- tool_handler=tool_handler,
- **kwargs,
- )
+ # Handle tool execution
+ return _handle_tool_execution(
+ stop_reason,
+ message,
+ model,
+ system_prompt,
+ messages,
+ tool_config,
+ tool_handler,
+ callback_handler,
+ tool_execution_handler,
+ event_loop_metrics,
+ cycle_trace,
+ cycle_span,
+ cycle_start_time,
+ kwargs,
+ )
# End the cycle and return results
event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
@@ -377,3 +332,105 @@ def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetr
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]
return kwargs
+
+
+def _handle_tool_execution(
+ stop_reason: StopReason,
+ message: Message,
+ model: Model,
+ system_prompt: Optional[str],
+ messages: Messages,
+ tool_config: ToolConfig,
+ tool_handler: ToolHandler,
+ callback_handler: Callable[..., Any],
+ tool_execution_handler: Optional[ParallelToolExecutorInterface],
+ event_loop_metrics: EventLoopMetrics,
+ cycle_trace: Trace,
+ cycle_span: Any,
+ cycle_start_time: float,
+ kwargs: Dict[str, Any],
+) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
+    """Handles the execution of tools requested by the model during an event loop cycle.
+
+    Args:
+        stop_reason (StopReason): The reason the model stopped generating.
+        message (Message): The message from the model that may contain tool use requests.
+        model (Model): The model provider instance.
+        system_prompt (Optional[str]): The system prompt instructions for the model.
+        messages (Messages): The conversation history messages.
+        tool_config (ToolConfig): Configuration for available tools.
+        tool_handler (ToolHandler): Handler for tool execution.
+        callback_handler (Callable[..., Any]): Callback for processing events as they happen.
+        tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
+        event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
+        cycle_trace (Trace): Trace object for the current event loop cycle.
+        cycle_span (Any): Span object for tracing the cycle (type may vary).
+        cycle_start_time (float): Start time of the current cycle.
+        kwargs (Dict[str, Any]): Additional keyword arguments, including request state.
+
+    Returns:
+        Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
+            - The stop reason,
+            - The updated message,
+            - The updated event loop metrics,
+            - The updated request state.
+    """
+    tool_uses: List[ToolUse] = []
+    tool_results: List[ToolResult] = []
+    invalid_tool_use_ids: List[str] = []
+
+    validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
+
+ if not tool_uses:
+ return stop_reason, message, event_loop_metrics, kwargs["request_state"]
+
+ tool_handler_process = partial(
+ tool_handler.process,
+ messages=messages,
+ model=model,
+ system_prompt=system_prompt,
+ tool_config=tool_config,
+ callback_handler=callback_handler,
+ **kwargs,
+ )
+
+ run_tools(
+ handler=tool_handler_process,
+ tool_uses=tool_uses,
+ event_loop_metrics=event_loop_metrics,
+ request_state=cast(Any, kwargs["request_state"]),
+ invalid_tool_use_ids=invalid_tool_use_ids,
+ tool_results=tool_results,
+ cycle_trace=cycle_trace,
+ parent_span=cycle_span,
+ parallel_tool_executor=tool_execution_handler,
+ )
+
+ kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
+
+ tool_result_message: Message = {
+ "role": "user",
+ "content": [{"toolResult": result} for result in tool_results],
+ }
+
+ messages.append(tool_result_message)
+ callback_handler(message=tool_result_message)
+
+ if cycle_span:
+ tracer = get_tracer()
+ tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
+
+ if kwargs["request_state"].get("stop_event_loop", False):
+ event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
+ return stop_reason, message, event_loop_metrics, kwargs["request_state"]
+
+ return recurse_event_loop(
+ model=model,
+ system_prompt=system_prompt,
+ messages=messages,
+ tool_config=tool_config,
+ callback_handler=callback_handler,
+ tool_handler=tool_handler,
+ **kwargs,
+ )
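The retry parameters are now module-level constants. For reference, a self-contained sketch of the exponential-backoff pattern they drive; `handle_throttling_error`'s internals are not shown in this diff, so the doubling arithmetic below is an assumption:

```python
import time

MAX_ATTEMPTS = 6
INITIAL_DELAY = 4
MAX_DELAY = 240  # 4 minutes


def call_with_backoff(invoke):
    """Retry `invoke` on throttling, doubling the delay up to MAX_DELAY (assumed schedule)."""
    current_delay = INITIAL_DELAY
    for attempt in range(MAX_ATTEMPTS):
        try:
            return invoke()
        except RuntimeError:  # stand-in for ModelThrottledException
            if attempt + 1 == MAX_ATTEMPTS:
                raise
            time.sleep(current_delay)
            current_delay = min(current_delay * 2, MAX_DELAY)
```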
diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py
index d6d104d8..e46cb326 100644
--- a/src/strands/handlers/callback_handler.py
+++ b/src/strands/handlers/callback_handler.py
@@ -17,15 +17,19 @@ def __call__(self, **kwargs: Any) -> None:
Args:
**kwargs: Callback event data including:
-
- - data (str): Text content to stream.
- - complete (bool): Whether this is the final chunk of a response.
- - current_tool_use (dict): Information about the current tool being used.
+ - reasoningText (Optional[str]): Reasoning text to print if provided.
+ - data (str): Text content to stream.
+ - complete (bool): Whether this is the final chunk of a response.
+ - current_tool_use (dict): Information about the current tool being used.
"""
+ reasoningText = kwargs.get("reasoningText", False)
data = kwargs.get("data", "")
complete = kwargs.get("complete", False)
current_tool_use = kwargs.get("current_tool_use", {})
+ if reasoningText:
+ print(reasoningText, end="")
+
if data:
print(data, end="" if not complete else "\n")
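With the handler change above, reasoning text streams ahead of regular output, both without trailing newlines until the final chunk:

```python
from strands.handlers.callback_handler import PrintingCallbackHandler

handler = PrintingCallbackHandler()
handler(reasoningText="Thinking it through... ")  # printed with end=""
handler(data="4", complete=True)                  # final chunk gets a newline
# Output: "Thinking it through... 4\n"
```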
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
index 6e32c5bd..764cb851 100644
--- a/src/strands/models/openai.py
+++ b/src/strands/models/openai.py
@@ -18,6 +18,7 @@ class Client(Protocol):
"""Protocol defining the OpenAI-compatible interface for the underlying provider client."""
@property
+ # pragma: no cover
def chat(self) -> Any:
"""Chat completions interface."""
...
diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py
index 8f5ffab3..c00a7774 100644
--- a/src/strands/types/models/openai.py
+++ b/src/strands/types/models/openai.py
@@ -206,7 +206,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
case "content_delta":
if event["data_type"] == "tool":
- return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
+ return {
+ "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}}
+ }
return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
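The `or ""` guard matters because OpenAI-style tool-call deltas may arrive with `arguments=None` (commonly on the first chunk), while downstream consumers concatenate the inputs as strings. A small illustration of the accumulation this protects:

```python
# Tool-call argument fragments as they might stream in; None on the first
# chunk is the case the fix above normalizes to "".
fragments = [None, '{"expres', 'sion": "2+2"}']

accumulated = ""
for piece in fragments:
    accumulated += piece or ""  # mirrors `event["data"].function.arguments or ""`

print(accumulated)  # {"expression": "2+2"}
```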
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 5c7d11e4..ea06fb4e 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -9,7 +9,8 @@
import pytest
import strands
-from strands.agent.agent import Agent
+from strands import Agent
+from strands.agent import AgentResult
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
@@ -337,17 +338,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler,
],
]
+ override_system_prompt = "Override system prompt"
+ override_model = unittest.mock.Mock()
+ override_tool_execution_handler = unittest.mock.Mock()
+ override_event_loop_metrics = unittest.mock.Mock()
+ override_callback_handler = unittest.mock.Mock()
+ override_tool_handler = unittest.mock.Mock()
+ override_messages = [{"role": "user", "content": [{"text": "override msg"}]}]
+ override_tool_config = {"test": "config"}
+
def check_kwargs(some_value, **kwargs):
assert some_value == "a_value"
assert kwargs is not None
+ assert kwargs["system_prompt"] == override_system_prompt
+ assert kwargs["model"] == override_model
+ assert kwargs["tool_execution_handler"] == override_tool_execution_handler
+ assert kwargs["event_loop_metrics"] == override_event_loop_metrics
+ assert kwargs["callback_handler"] == override_callback_handler
+ assert kwargs["tool_handler"] == override_tool_handler
+ assert kwargs["messages"] == override_messages
+ assert kwargs["tool_config"] == override_tool_config
+ assert kwargs["agent"] == agent
# Return expected values from event_loop_cycle
return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}
mock_event_loop_cycle.side_effect = check_kwargs
- agent("test message", some_value="a_value")
- assert mock_event_loop_cycle.call_count == 1
+ agent(
+ "test message",
+ some_value="a_value",
+ system_prompt=override_system_prompt,
+ model=override_model,
+ tool_execution_handler=override_tool_execution_handler,
+ event_loop_metrics=override_event_loop_metrics,
+ callback_handler=override_callback_handler,
+ tool_handler=override_tool_handler,
+ messages=override_messages,
+ tool_config=override_tool_config,
+ )
+
+ mock_event_loop_cycle.assert_called_once()
def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):
@@ -428,7 +459,7 @@ def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite
with pytest.raises(ContextWindowOverflowException):
agent("Test!")
- assert conversation_manager_spy.reduce_context.call_count == 251
+ assert conversation_manager_spy.reduce_context.call_count > 0
assert conversation_manager_spy.apply_management.call_count == 1
@@ -657,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler():
@pytest.mark.asyncio
async def test_stream_async_returns_all_events(mock_event_loop_cycle):
- mock_event_loop_cycle.side_effect = ValueError("Test exception")
-
agent = Agent()
# Define the side effect to simulate callback handler being called multiple times
@@ -922,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result)
+@pytest.mark.asyncio
+@unittest.mock.patch("strands.agent.agent.get_tracer")
+async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_model, mock_event_loop_cycle):
+ """Test that stream_async creates and ends a span when the call succeeds."""
+ # Setup mock tracer and span
+ mock_tracer = unittest.mock.MagicMock()
+ mock_span = unittest.mock.MagicMock()
+ mock_tracer.start_agent_span.return_value = mock_span
+ mock_get_tracer.return_value = mock_tracer
+
+ # Define the side effect to simulate callback handler being called multiple times
+ def call_callback_handler(*args, **kwargs):
+ # Extract the callback handler from kwargs
+ callback_handler = kwargs.get("callback_handler")
+ # Call the callback handler with different data values
+ callback_handler(data="First chunk")
+ callback_handler(data="Second chunk")
+ callback_handler(data="Final chunk", complete=True)
+ # Return expected values from event_loop_cycle
+ return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}
+
+ mock_event_loop_cycle.side_effect = call_callback_handler
+
+ # Create agent and make a call
+ agent = Agent(model=mock_model)
+ iterator = agent.stream_async("test prompt")
+ async for _event in iterator:
+ pass # NoOp
+
+ # Verify span was created
+ mock_tracer.start_agent_span.assert_called_once_with(
+ prompt="test prompt",
+ model_id=unittest.mock.ANY,
+ tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
+ )
+
+ expected_response = AgentResult(
+ stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={}
+ )
+
+ # Verify span was ended with the result
+ mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response)
+
+
@unittest.mock.patch("strands.agent.agent.get_tracer")
def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
"""Test that __call__ creates and ends a span when an exception occurs."""
@@ -955,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
+@pytest.mark.asyncio
+@unittest.mock.patch("strands.agent.agent.get_tracer")
+async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
+ """Test that stream_async creates and ends a span when the call succeeds."""
+ # Setup mock tracer and span
+ mock_tracer = unittest.mock.MagicMock()
+ mock_span = unittest.mock.MagicMock()
+ mock_tracer.start_agent_span.return_value = mock_span
+ mock_get_tracer.return_value = mock_tracer
+
+    # Define the side effect to simulate the model raising an exception
+ test_exception = ValueError("Test exception")
+ mock_model.mock_converse.side_effect = test_exception
+
+ # Create agent and make a call
+ agent = Agent(model=mock_model)
+
+ # Call the agent and catch the exception
+ with pytest.raises(ValueError):
+ iterator = agent.stream_async("test prompt")
+ async for _event in iterator:
+ pass # NoOp
+
+ # Verify span was created
+ mock_tracer.start_agent_span.assert_called_once_with(
+ prompt="test prompt",
+ model_id=unittest.mock.ANY,
+ tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
+ )
+
+ # Verify span was ended with the exception
+ mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
+
+
@unittest.mock.patch("strands.agent.agent.get_tracer")
def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model):
"""Test that event_loop_cycle is called with the parent span."""
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py
index 2f6ee77d..b6132f1d 100644
--- a/tests/strands/agent/test_conversation_manager.py
+++ b/tests/strands/agent/test_conversation_manager.py
@@ -111,41 +111,7 @@ def conversation_manager(request):
{"role": "assistant", "content": [{"text": "Second response"}]},
],
),
- # 7 - Message count above max window size - Remove dangling tool uses and tool results
- (
- {"window_size": 1},
- [
- {"role": "user", "content": [{"text": "First message"}]},
- {"role": "assistant", "content": [{"toolUse": {"toolUseId": "321", "name": "tool1", "input": {}}}]},
- {
- "role": "user",
- "content": [
- {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}}
- ],
- },
- ],
- [
- {
- "role": "user",
- "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}],
- },
- ],
- ),
- # 8 - Message count above max window size - Remove multiple tool use/tool result pairs
- (
- {"window_size": 1},
- [
- {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]},
- {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]},
- {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
- {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]},
- {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]},
- ],
- [
- {"role": "user", "content": [{"text": "Tool Result Status: success"}]},
- ],
- ),
- # 9 - Message count above max window size - Preserve tool use/tool result pairs
+ # 7 - Message count above max window size - Preserve tool use/tool result pairs
(
{"window_size": 2},
[
@@ -158,7 +124,7 @@ def conversation_manager(request):
{"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
],
),
- # 10 - Test sliding window behavior - preserve tool use/result pairs across cut boundary
+ # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary
(
{"window_size": 3},
[
@@ -173,7 +139,7 @@ def conversation_manager(request):
{"role": "assistant", "content": [{"text": "Response after tool use"}]},
],
),
- # 11 - Test sliding window with multiple tool pairs that need preservation
+ # 9 - Test sliding window with multiple tool pairs that need preservation
(
{"window_size": 4},
[
@@ -185,7 +151,6 @@ def conversation_manager(request):
{"role": "assistant", "content": [{"text": "Final response"}]},
],
[
- {"role": "user", "content": [{"text": "Tool Result Status: success"}]},
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]},
{"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
{"role": "assistant", "content": [{"text": "Final response"}]},
@@ -200,6 +165,20 @@ def test_apply_management(conversation_manager, messages, expected_messages):
assert messages == expected_messages
+def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception():
+ manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1)
+ messages = [
+ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]},
+ ]
+ original_messages = messages.copy()
+
+ with pytest.raises(ContextWindowOverflowException):
+ manager.apply_management(messages)
+
+ assert messages == original_messages
+
+
def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception():
"""Test that NullConversationManager doesn't modify messages."""
manager = strands.agent.conversation_manager.NullConversationManager()
diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py
index 20e238cb..6fb2af07 100644
--- a/tests/strands/handlers/test_callback_handler.py
+++ b/tests/strands/handlers/test_callback_handler.py
@@ -30,6 +30,31 @@ def test_call_with_empty_args(handler, mock_print):
mock_print.assert_not_called()
+def test_call_handler_reasoningText(handler, mock_print):
+ """Test calling the handler with reasoningText."""
+ handler(reasoningText="This is reasoning text")
+ # Should print reasoning text without newline
+ mock_print.assert_called_once_with("This is reasoning text", end="")
+
+
+def test_call_without_reasoningText(handler, mock_print):
+ """Test calling the handler without reasoningText argument."""
+ handler(data="Some output")
+ # Should only print data, not reasoningText
+ mock_print.assert_called_once_with("Some output", end="")
+
+
+def test_call_with_reasoningText_and_data(handler, mock_print):
+ """Test calling the handler with both reasoningText and data."""
+ handler(reasoningText="Reasoning", data="Output")
+ # Should print reasoningText and data, both without newline
+ calls = [
+ unittest.mock.call("Reasoning", end=""),
+ unittest.mock.call("Output", end=""),
+ ]
+ mock_print.assert_has_calls(calls)
+
+
def test_call_with_data_incomplete(handler, mock_print):
"""Test calling the handler with data but not complete."""
handler(data="Test output")
diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py
index 97a0882a..2657c334 100644
--- a/tests/strands/types/models/test_openai.py
+++ b/tests/strands/types/models/test_openai.py
@@ -246,12 +246,12 @@ def test_format_request(model, messages, tool_specs, system_prompt):
@pytest.mark.parametrize(
("event", "exp_chunk"),
[
- # Case 1: Message start
+ # Message start
(
{"chunk_type": "message_start"},
{"messageStart": {"role": "assistant"}},
),
- # Case 2: Content Start - Tool Use
+ # Content Start - Tool Use
(
{
"chunk_type": "content_start",
@@ -260,12 +260,12 @@ def test_format_request(model, messages, tool_specs, system_prompt):
},
{"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}},
),
- # Case 3: Content Start - Text
+ # Content Start - Text
(
{"chunk_type": "content_start", "data_type": "text"},
{"contentBlockStart": {"start": {}}},
),
- # Case 4: Content Delta - Tool Use
+ # Content Delta - Tool Use
(
{
"chunk_type": "content_delta",
@@ -274,32 +274,41 @@ def test_format_request(model, messages, tool_specs, system_prompt):
},
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}},
),
- # Case 5: Content Delta - Text
+ # Content Delta - Tool Use - None
+ (
+ {
+ "chunk_type": "content_delta",
+ "data_type": "tool",
+ "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)),
+ },
+ {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}},
+ ),
+ # Content Delta - Text
(
{"chunk_type": "content_delta", "data_type": "text", "data": "hello"},
{"contentBlockDelta": {"delta": {"text": "hello"}}},
),
- # Case 6: Content Stop
+ # Content Stop
(
{"chunk_type": "content_stop"},
{"contentBlockStop": {}},
),
- # Case 7: Message Stop - Tool Use
+ # Message Stop - Tool Use
(
{"chunk_type": "message_stop", "data": "tool_calls"},
{"messageStop": {"stopReason": "tool_use"}},
),
- # Case 8: Message Stop - Max Tokens
+ # Message Stop - Max Tokens
(
{"chunk_type": "message_stop", "data": "length"},
{"messageStop": {"stopReason": "max_tokens"}},
),
- # Case 9: Message Stop - End Turn
+ # Message Stop - End Turn
(
{"chunk_type": "message_stop", "data": "stop"},
{"messageStop": {"stopReason": "end_turn"}},
),
- # Case 10: Metadata
+ # Metadata
(
{
"chunk_type": "metadata",