From f431f152fc2c106932a7abc10e5a263a10ab3a7a Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 16 May 2025 14:42:39 -0400 Subject: [PATCH 01/49] fix: Update the PyPI package description (#15) To match the GitHub repository title --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6f9b78d6..9e084da5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" version = "0.1.0" -description = "A production-ready framework for building autonomous AI agents" +description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" license = {text = "Apache-2.0"} From 6c4b6d3adab105bfe574f3be32e522852f7f00d4 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 16 May 2025 13:37:03 -0700 Subject: [PATCH 02/49] README: add link to llama user guide (#18) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 337acd83..4460643a 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,7 @@ Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) + - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) From 6decbd859789150123380ccaa274ccdd6115c3c7 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 16 May 2025 16:46:47 -0400 Subject: [PATCH 03/49] fix: Update readme to include badges (#17) Add common badges for python projects Co-authored-by: Mackenzie Zastrow --- README.md | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 4460643a..772ebe35 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,24 @@ # Strands Agents -[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE) -[![Python](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/) -

A model-driven approach to building AI agents in just a few lines of code.

- Docs - ◆ Samples - ◆ Tools - ◆ Agent Builder +
+ GitHub commit activity + GitHub open issues + License + PyPI version + Python versions +
+ +

+ Docs + ◆ Samples + ◆ Tools + ◆ Agent Builder +

Strands Agents is a simple yet powerful SDK that takes a model-driven approach to building and running AI agents. From simple conversational assistants to complex autonomous workflows, from local development to production deployment, Strands Agents scales with your needs. From 5a54c28e24487c4a8557cd5073ab91cc02ed6651 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 16 May 2025 17:00:34 -0400 Subject: [PATCH 04/49] actions: fix docs dispatch (#19) --- .github/workflows/dispatch-docs.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/dispatch-docs.yml b/.github/workflows/dispatch-docs.yml index fda63413..b3802abe 100644 --- a/.github/workflows/dispatch-docs.yml +++ b/.github/workflows/dispatch-docs.yml @@ -12,5 +12,3 @@ jobs: steps: - name: Dispatch run: gh api repos/${{ github.repository_owner }}/docs/dispatches -F event_type=sdk-push - env: - GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }} From d1b512d7ad69e2ebe8cc8d3ddad6953a6cfd680d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 16 May 2025 17:59:01 -0400 Subject: [PATCH 05/49] actions: remove dispatch docs (#22) --- .github/workflows/dispatch-docs.yml | 14 -------------- 1 file changed, 14 deletions(-) delete mode 100644 .github/workflows/dispatch-docs.yml diff --git a/.github/workflows/dispatch-docs.yml b/.github/workflows/dispatch-docs.yml deleted file mode 100644 index b3802abe..00000000 --- a/.github/workflows/dispatch-docs.yml +++ /dev/null @@ -1,14 +0,0 @@ -name: Dispatch Docs -on: - push: - branches: - - main - -jobs: - trigger: - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - name: Dispatch - run: gh api repos/${{ github.repository_owner }}/docs/dispatches -F event_type=sdk-push From e76945048a0a1f0136fb9b47f6f29a16bea6daef Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Fri, 16 May 2025 15:18:53 -0700 Subject: [PATCH 06/49] fix: set user-agent for Bedrock API calls (#23) --- src/strands/models/bedrock.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index d996aaae..b1b18b3e 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -110,9 +110,10 @@ def __init__( session = boto_session or boto3.Session( region_name=region_name or "us-west-2", ) + client_config = boto_client_config or BotocoreConfig(user_agent_extra="strands-agents") self.client = session.client( service_name="bedrock-runtime", - config=boto_client_config, + config=client_config, ) @override From e830c0512f8610c80e0030175becc921ff6796df Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 16 May 2025 21:05:14 -0400 Subject: [PATCH 07/49] v0.1.1 release (#26) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9e084da5..e3e3f372 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.0" +version = "0.1.1" description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" From a03b74cabf4e026eb6cc26f2066efcc2f7356cd5 Mon Sep 17 00:00:00 2001 From: Ryan Coleman Date: Fri, 16 May 2025 20:20:11 -0700 Subject: [PATCH 08/49] Update README.md mention of tools repo (#29) Typo in the examples tools header referencing the wrong repo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 772ebe35..7a16324d 100644 --- a/README.md +++ b/README.md @@ -152,7 +152,7 @@ agent = Agent(tools=[calculator]) agent("What is the square root of 1764") ``` -It's also available on GitHub via [strands-agents-tools](https://github.com/strands-agents/strands-agents-tools). +It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). ## Documentation From 6088173ad7021c149e7ada3d1d4f8e3f3a7f6a72 Mon Sep 17 00:00:00 2001 From: Ryan Coleman Date: Sat, 17 May 2025 05:25:11 -0700 Subject: [PATCH 09/49] Update README to mention Meta Llama API as a supported model provider (#21) Co-authored-by: Ryan Coleman --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7a16324d..08d6bff0 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t ## Feature Overview - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable -- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Ollama, and custom providers +- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Llama, Ollama, and custom providers - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools From 2056888df9671aafaa733be4df52a8a02880584b Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sat, 17 May 2025 18:58:28 -0400 Subject: [PATCH 10/49] fix: tracing of non-serializable values, e.g. bytes (#34) --- src/strands/telemetry/tracer.py | 58 +++++++--- tests/strands/telemetry/test_tracer.py | 142 ++++++++++++++++++++++++- 2 files changed, 184 insertions(+), 16 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index ad30a445..809dbd46 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,7 +7,7 @@ import json import logging import os -from datetime import datetime, timezone +from datetime import date, datetime, timezone from importlib.metadata import version from typing import Any, Dict, Mapping, Optional @@ -30,21 +30,49 @@ class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" - def default(self, obj: Any) -> Any: - """Handle non-serializable types. + def encode(self, obj: Any) -> str: + """Recursively encode objects, preserving structure and only replacing unserializable values. Args: - obj: The object to serialize + obj: The object to encode Returns: - A JSON serializable version of the object + JSON string representation of the object """ - value = "" - try: - value = super().default(obj) - except TypeError: - value = "" - return value + # Process the object to handle non-serializable values + processed_obj = self._process_value(obj) + # Use the parent class to encode the processed object + return super().encode(processed_obj) + + def _process_value(self, value: Any) -> Any: + """Process any value, handling containers recursively. + + Args: + value: The value to process + + Returns: + Processed value with unserializable parts replaced + """ + # Handle datetime objects directly + if isinstance(value, (datetime, date)): + return value.isoformat() + + # Handle dictionaries + elif isinstance(value, dict): + return {k: self._process_value(v) for k, v in value.items()} + + # Handle lists + elif isinstance(value, list): + return [self._process_value(item) for item in value] + + # Handle all other values + else: + try: + # Test if the value is JSON serializable + json.dumps(value) + return value + except (TypeError, OverflowError, ValueError): + return "" class Tracer: @@ -332,6 +360,7 @@ def start_tool_call_span( The created span, or None if tracing is not enabled. """ attributes: Dict[str, AttributeValue] = { + "gen_ai.prompt": json.dumps(tool, cls=JSONEncoder), "tool.name": tool["name"], "tool.id": tool["toolUseId"], "tool.parameters": json.dumps(tool["input"], cls=JSONEncoder), @@ -358,10 +387,11 @@ def end_tool_call_span( status = tool_result.get("status") status_str = str(status) if status is not None else "" + tool_result_content_json = json.dumps(tool_result.get("content"), cls=JSONEncoder) attributes.update( { - "tool.result": json.dumps(tool_result.get("content"), cls=JSONEncoder), - "gen_ai.completion": json.dumps(tool_result.get("content"), cls=JSONEncoder), + "tool.result": tool_result_content_json, + "gen_ai.completion": tool_result_content_json, "tool.status": status_str, } ) @@ -492,7 +522,7 @@ def end_agent_span( if response: attributes.update( { - "gen_ai.completion": json.dumps(response, cls=JSONEncoder), + "gen_ai.completion": str(response), } ) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 55018c5e..fbbfa89b 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1,11 +1,12 @@ import json import os +from datetime import date, datetime, timezone from unittest import mock import pytest from opentelemetry.trace import StatusCode # type: ignore -from strands.telemetry.tracer import Tracer, get_tracer +from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer from strands.types.streaming import Usage @@ -268,6 +269,9 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool" + mock_span.set_attribute.assert_any_call( + "gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}}) + ) mock_span.set_attribute.assert_any_call("tool.name", "test-tool") mock_span.set_attribute.assert_any_call("tool.id", "123") mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"})) @@ -369,7 +373,7 @@ def test_end_agent_span(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.completion", '""') + mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response") mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) @@ -497,3 +501,137 @@ def test_start_model_invoke_span_with_parent(mock_tracer): # Verify span was returned assert span is mock_span + + +@pytest.mark.parametrize( + "input_data, expected_result", + [ + ("test string", '"test string"'), + (1234, "1234"), + (13.37, "13.37"), + (False, "false"), + (None, "null"), + ], +) +def test_json_encoder_serializable(input_data, expected_result): + """Test encoding of serializable values.""" + encoder = JSONEncoder() + + result = encoder.encode(input_data) + assert result == expected_result + + +def test_json_encoder_datetime(): + """Test encoding datetime and date objects.""" + encoder = JSONEncoder() + + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = encoder.encode(dt) + assert result == f'"{dt.isoformat()}"' + + d = date(2025, 1, 1) + result = encoder.encode(d) + assert result == f'"{d.isoformat()}"' + + +def test_json_encoder_list(): + """Test encoding a list with mixed content.""" + encoder = JSONEncoder() + + non_serializable = lambda x: x # noqa: E731 + + data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]] + + result = json.loads(encoder.encode(data)) + assert result == ["value", 42, 13.37, "", None, {"key": True}, ["value here"]] + + +def test_json_encoder_dict(): + """Test encoding a dict with mixed content.""" + encoder = JSONEncoder() + + class UnserializableClass: + def __str__(self): + return "Unserializable Object" + + non_serializable = lambda x: x # noqa: E731 + + now = datetime.now(timezone.utc) + + data = { + "metadata": { + "timestamp": now, + "version": "1.0", + "debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731 + }, + "content": [ + {"type": "text", "value": "Hello world"}, + {"type": "binary", "value": non_serializable}, + {"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]}, + ], + "statistics": { + "processed": 100, + "failed": 5, + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}], + }, + "list": [ + non_serializable, + 1234, + 13.37, + True, + None, + "string here", + ], + } + + expected = { + "metadata": { + "timestamp": now.isoformat(), + "version": "1.0", + "debug_info": {"object": "", "callable": ""}, + }, + "content": [ + {"type": "text", "value": "Hello world"}, + {"type": "binary", "value": ""}, + {"type": "mixed", "values": [1, "text", "", {"nested": ""}]}, + ], + "statistics": { + "processed": 100, + "failed": 5, + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": ""}], + }, + "list": [ + "", + 1234, + 13.37, + True, + None, + "string here", + ], + } + + result = json.loads(encoder.encode(data)) + + assert result == expected + + +def test_json_encoder_value_error(): + """Test encoding values that cause ValueError.""" + encoder = JSONEncoder() + + # A very large integer that exceeds JSON limits and throws ValueError + huge_number = 2**100000 + + # Test in a dictionary + dict_data = {"normal": 42, "huge": huge_number} + result = json.loads(encoder.encode(dict_data)) + assert result == {"normal": 42, "huge": ""} + + # Test in a list + list_data = [42, huge_number] + result = json.loads(encoder.encode(list_data)) + assert result == [42, ""] + + # Test just the value + result = json.loads(encoder.encode(huge_number)) + assert result == "" From 4d29560c64a3f95c1eb7bb030ce5101a1756a112 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sun, 18 May 2025 17:55:59 -0400 Subject: [PATCH 11/49] fix(bedrock): use the AWS_REGION environment variable for the Bedrock model provider region if set and boto_session is not passed (#39) --- src/strands/models/bedrock.py | 6 ++++-- tests/strands/models/test_bedrock.py | 11 +++++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index b1b18b3e..4c02156d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -4,6 +4,7 @@ """ import logging +import os from typing import Any, Iterable, Literal, Optional, cast import boto3 @@ -96,7 +97,8 @@ def __init__( Args: boto_session: Boto Session to use when calling the Bedrock Model. boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client. - region_name: AWS region to use for the Bedrock service. Defaults to "us-west-2". + region_name: AWS region to use for the Bedrock service. + Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set. **model_config: Configuration options for the Bedrock model. """ if region_name and boto_session: @@ -108,7 +110,7 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) session = boto_session or boto3.Session( - region_name=region_name or "us-west-2", + region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", ) client_config = boto_client_config or BotocoreConfig(user_agent_extra="strands-agents") self.client = session.client( diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 566671ce..0844c8cd 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,3 +1,4 @@ +import os import unittest.mock import boto3 @@ -99,6 +100,16 @@ def test__init__with_custom_region(bedrock_client): mock_session_cls.assert_called_once_with(region_name=custom_region) +def test__init__with_environment_variable_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + os.environ["AWS_REGION"] = "eu-west-1" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name="eu-west-1") + + def test__init__with_region_and_session_raises_value_error(): """Test that BedrockModel raises ValueError when both region and session are provided.""" with pytest.raises(ValueError): From 912e1104fd52e3da99767a644b90b22b7eb606de Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sun, 18 May 2025 18:27:44 -0400 Subject: [PATCH 12/49] v0.1.2 (#41) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e3e3f372..6582bddd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.1" +version = "0.1.2" description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" From 95c83134293550361f6ad979400fa0f0a6dfdde6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Tue, 20 May 2025 07:57:05 -0400 Subject: [PATCH 13/49] fix: update direct tool call references to use agent.tool.tool_name format (#56) --- src/strands/agent/agent.py | 4 ++-- tests/strands/agent/test_agent.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 89653036..6a948980 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -6,7 +6,7 @@ The Agent interface supports two complementary interaction patterns: 1. Natural language for conversation: `agent("Analyze this data")` -2. Method-style for direct tool access: `agent.tool_name(param1="value")` +2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ import asyncio @@ -515,7 +515,7 @@ def _record_tool_execution( """ # Create user message describing the tool call user_msg_content = [ - {"text": (f"agent.{tool['name']} direct tool call\nInput parameters: {json.dumps(tool['input'])}\n")} + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")} ] # Add override message if provided diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5d2ffb23..5c7d11e4 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -566,7 +566,7 @@ def test_agent_tool_user_message_override(agent): }, { "text": ( - "agent.tool_decorated direct tool call\n" + "agent.tool.tool_decorated direct tool call.\n" "Input parameters: " '{"random_string": "abcdEfghI123", "user_message_override": "test override"}\n' ), From 6c4a165d5e6151b0beba32294a613f38fc02f317 Mon Sep 17 00:00:00 2001 From: Bryan Samis <33967838+samisb@users.noreply.github.com> Date: Tue, 20 May 2025 10:51:57 -0400 Subject: [PATCH 14/49] Update README.md - corrected spelling of "model" (#59) Corrected spelling of model (s/modal/model) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 08d6bff0..262dde51 100644 --- a/README.md +++ b/README.md @@ -117,11 +117,11 @@ agent = Agent(model=bedrock_model) agent("Tell me about Agentic AI") # Ollama -ollama_modal = OllamaModel( +ollama_model = OllamaModel( host="http://localhost:11434", model_id="llama3" ) -agent = Agent(model=ollama_modal) +agent = Agent(model=ollama_model) agent("Tell me about Agentic AI") # Llama API From 2d213c191e7060871929f7b51b0a20c77eba9519 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 20 May 2025 17:15:10 -0400 Subject: [PATCH 15/49] style guide (#49) * style guide * lint: logging --- CONTRIBUTING.md | 2 ++ STYLE_GUIDE.md | 59 +++++++++++++++++++++++++++++++++ pyproject.toml | 2 ++ src/strands/telemetry/tracer.py | 10 +++--- 4 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 STYLE_GUIDE.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e1ddbb89..18087852 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -94,6 +94,8 @@ hatch fmt --linter If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. +For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). + ## Contributing via Pull Requests Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md new file mode 100644 index 00000000..a50c571b --- /dev/null +++ b/STYLE_GUIDE.md @@ -0,0 +1,59 @@ +# Style Guide + +## Overview + +The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable. + +Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden. + +## Log Formatting + +The format for Strands Agents logs is as follows: + +```python +logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...) +``` + +### Guidelines + +1. **Context**: + - Add context as `=` pairs at the beginning of the log + - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching + - Use `,`'s to separate pairs + - Enclose values in `<>` for readability + - This is particularly helpful in displaying empty values (`field=` vs `field=<>`) + - Use `%s` for string interpolation as recommended by Python logging + - This is an optimization to skip string interpolation when the log level is not enabled + +1. **Messages**: + - Add human readable messages at the end of the log + - Use lowercase for consistency + - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter + - Keep messages concise and focused on a single statement + - If multiple statements are needed, separate them with the pipe character (`|`) + - Example: `"processing request | starting validation"` + +### Examples + +#### Good + +```python +logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) +``` + +#### Poor + +```python +# Avoid: No structured fields, direct variable interpolation in message +logger.debug(f"User {user_id} performed action {action}") + +# Avoid: Inconsistent formatting, punctuation +logger.info("Request completed in %d ms.", duration) + +# Avoid: No separation between fields and message +logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts) +``` + +By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. diff --git a/pyproject.toml b/pyproject.toml index 6582bddd..bd633390 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -185,7 +185,9 @@ select = [ "D", # pydocstyle "E", # pycodestyle "F", # pyflakes + "G", # logging format "I", # isort + "LOG", # logging ] [tool.ruff.lint.per-file-ignores] diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 809dbd46..b3709a1f 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -125,7 +125,7 @@ def __init__( headers_dict[key.strip()] = value.strip() otlp_headers = headers_dict except Exception as e: - logger.warning(f"error=<{e}> | failed to parse OTEL_EXPORTER_OTLP_HEADERS") + logger.warning("error=<%s> | failed to parse OTEL_EXPORTER_OTLP_HEADERS", e) self.service_name = service_name self.otlp_endpoint = otlp_endpoint @@ -184,9 +184,9 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) - logger.info(f"endpoint=<{endpoint}> | OTLP exporter configured with endpoint") + logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) except Exception as e: - logger.error(f"error=<{e}> | Failed to configure OTLP exporter", exc_info=True) + logger.exception("error=<%s> | Failed to configure OTLP exporter", e) # Set as global tracer provider trace.set_tracer_provider(self.tracer_provider) @@ -267,7 +267,7 @@ def _end_span( else: span.set_status(StatusCode.OK) except Exception as e: - logger.warning(f"error=<{e}> | error while ending span", exc_info=True) + logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() # Force flush to ensure spans are exported @@ -275,7 +275,7 @@ def _end_span( try: self.tracer_provider.force_flush() except Exception as e: - logger.warning(f"error=<{e}> | failed to force flush tracer provider") + logger.warning("error=<%s> | failed to force flush tracer provider", e) def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. From 5572438528d534687529a7cda8f1109cda4e61d1 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 20 May 2025 22:32:09 -0400 Subject: [PATCH 16/49] Update version to 0.1.3 (#63) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bd633390..43130ac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.2" +version = "0.1.3" description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" From 43aad3cc31a65923f09f79714196177e3ccea273 Mon Sep 17 00:00:00 2001 From: Jonathan Segev Date: Tue, 20 May 2025 22:32:41 -0400 Subject: [PATCH 17/49] fix: Updated GitHub Action to relay only on GitHub native approvals to run workflows (#67) --- .github/workflows/test-lint-pr.yml | 72 +++++++----------------------- 1 file changed, 15 insertions(+), 57 deletions(-) diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml index 15fbebcb..b196a473 100644 --- a/.github/workflows/test-lint-pr.yml +++ b/.github/workflows/test-lint-pr.yml @@ -6,88 +6,49 @@ on: types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] push: branches: [ main ] # Also run on direct pushes to main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: - check-approval: - name: Check if PR has contributor approval - runs-on: ubuntu-latest - permissions: - pull-requests: read - # Skip this check for direct pushes to main - if: github.event_name == 'pull_request' - outputs: - approved: ${{ steps.check-approval.outputs.approved }} - steps: - - name: Check if PR has been approved by a contributor - id: check-approval - uses: actions/github-script@v7 - with: - script: | - const APPROVED_ASSOCIATION = ['COLLABORATOR', 'CONTRIBUTOR', 'MEMBER', 'OWNER'] - const PR_AUTHOR_ASSOCIATION = context.payload.pull_request.author_association; - const { data: reviews } = await github.rest.pulls.listReviews({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.issue.number, - }); - - const isApprovedContributor = APPROVED_ASSOCIATION.includes(PR_AUTHOR_ASSOCIATION); - - // Check if any contributor has approved - const isApproved = reviews.some(review => - review.state === 'APPROVED' && APPROVED_ASSOCIATION.includes(review.author_association) - ) || isApprovedContributor; - - core.setOutput('approved', isApproved); - - if (!isApproved) { - core.notice('This PR does not have approval from a Contributor. Workflow will not run test jobs.'); - return false; - } - - return true; - unit-test: name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} - needs: check-approval permissions: contents: read - # Only run if PR is approved or this is a direct push to main - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' strategy: matrix: include: # Linux - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.10" - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.11" - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.12" - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.13" # Windows - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.10" - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.11" - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.12" - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.13" - # MacOS - latest only; not enough runners for MacOS + # MacOS - latest only; not enough runners for macOS - os: macos-latest - os-name: macos - python-version: "3.13" - fail-fast: false + os-name: 'macOS' + python-version: "3.13" + fail-fast: true runs-on: ${{ matrix.os }} env: LOG_LEVEL: DEBUG @@ -108,14 +69,11 @@ jobs: id: tests run: hatch test tests --cover continue-on-error: false - lint: name: Lint runs-on: ubuntu-latest - needs: check-approval permissions: contents: read - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' steps: - name: Checkout code uses: actions/checkout@v4 From d0a57a9c46fcf026cbdd677345e94c2a879751ff Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 21 May 2025 12:08:35 -0400 Subject: [PATCH 18/49] models - litellm - capture usage (#73) --- src/strands/models/litellm.py | 8 +++++--- tests/strands/models/test_litellm.py | 7 +++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index a7563133..3d198e1a 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -334,6 +334,8 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: yield {"chunk_type": "message_stop", "data": choice.finish_reason} - event = next(response) - if hasattr(event, "usage"): - yield {"chunk_type": "metadata", "data": event.usage} + # Skip remaining events as we don't have use for anything except the final usage payload + for event in response: + _ = event + + yield {"chunk_type": "metadata", "data": event.usage} diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 5d4d9b40..6789d441 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -545,12 +545,14 @@ def test_stream(litellm_client, model): def test_stream_empty(litellm_client, model): mock_delta = unittest.mock.Mock(content=None, tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop")]) - mock_event_3 = unittest.mock.Mock(spec=[]) + mock_event_3 = unittest.mock.Mock() + mock_event_4 = unittest.mock.Mock(usage=mock_usage) - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) + litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) request = {"model": "m1", "messages": [{"role": "user", "content": []}]} response = model.stream(request) @@ -561,6 +563,7 @@ def test_stream_empty(litellm_client, model): {"chunk_type": "content_start", "data_type": "text"}, {"chunk_type": "content_stop", "data_type": "text"}, {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, ] assert tru_events == exp_events From 221d00455c295345b77bfc0d1e9220da6ee63290 Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Wed, 21 May 2025 18:32:09 +0200 Subject: [PATCH 19/49] fixing various typos in markdowns and scripts (#74) * fixing various typos. * reverting 1 change after review --- CONTRIBUTING.md | 2 +- src/strands/event_loop/error_handler.py | 2 +- src/strands/models/bedrock.py | 2 +- src/strands/types/guardrails.py | 4 ++-- src/strands/types/media.py | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 18087852..b50957ef 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -117,7 +117,7 @@ To send us a pull request, please: ## Finding contributions to work on -Looking at the existing issues is a great way to find something to contribute on. +Looking at the existing issues is a great way to find something to contribute to. You can check: - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index 5a78bb3f..a5c85668 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -74,7 +74,7 @@ def handle_input_too_long_error( """Handle 'Input is too long' errors by truncating tool results. When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If trunction is successful, the + by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the function will make a call to the event loop. Args: diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4c02156d..23583d4d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -132,7 +132,7 @@ def get_config(self) -> BedrockConfig: """Get the current Bedrock Model configuration. Returns: - The Bedrok model configuration. + The Bedrock model configuration. """ return self.config diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index 6055b9ab..c15ba1be 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -16,7 +16,7 @@ class GuardrailConfig(TypedDict, total=False): Attributes: guardrailIdentifier: Unique identifier for the guardrail. guardrailVersion: Version of the guardrail to apply. - streamProcessingMode: Procesing mode. + streamProcessingMode: Processing mode. trace: The trace behavior for the guardrail. """ @@ -219,7 +219,7 @@ class GuardrailAssessment(TypedDict): contentPolicy: The content policy. contextualGroundingPolicy: The contextual grounding policy used for the guardrail assessment. sensitiveInformationPolicy: The sensitive information policy. - topicPolic: The topic policy. + topicPolicy: The topic policy. wordPolicy: The word policy. """ diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 058a09ea..29b89e5c 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -68,7 +68,7 @@ class ImageContent(TypedDict): class VideoSource(TypedDict): - """Contains the content of a vidoe. + """Contains the content of a video. Attributes: bytes: The binary content of the video. From b21ea4770e322647906cef0657b2704223f56bff Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Thu, 22 May 2025 11:03:50 -0400 Subject: [PATCH 20/49] fix(docs): add missing quotation marks in pip install commands (#80) --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b50957ef..63464a5d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -33,7 +33,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as 1. Install development dependencies: ```bash - pip install -e ".[dev]" && pip install -e ".[litellm] + pip install -e ".[dev]" && pip install -e ".[litellm]" ``` 2. Set up pre-commit hooks: From 5f4b68adb3dbb39837fc99363d8ca16dd8b61b73 Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Thu, 22 May 2025 08:04:17 -0700 Subject: [PATCH 21/49] fix: Merge strands-agents user agent into existing botocore config (#76) --- src/strands/models/bedrock.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 23583d4d..05d89923 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -112,7 +112,21 @@ def __init__( session = boto_session or boto3.Session( region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", ) - client_config = boto_client_config or BotocoreConfig(user_agent_extra="strands-agents") + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + self.client = session.client( service_name="bedrock-runtime", config=client_config, From 6dda2d8fb2a97f6d78c5cb3dce60f9549d2ebb4b Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 22 May 2025 13:18:31 -0400 Subject: [PATCH 22/49] feature: models - openai (#65) --- pyproject.toml | 13 +- src/strands/models/litellm.py | 265 +-------- src/strands/models/openai.py | 120 ++++ src/strands/types/models/__init__.py | 6 + .../types/{models.py => models/model.py} | 6 +- src/strands/types/models/openai.py | 240 ++++++++ tests-integ/test_model_openai.py | 46 ++ tests/strands/models/test_litellm.py | 553 ++---------------- tests/{ => strands/models}/test_llamaapi.py | 0 tests/strands/models/test_openai.py | 134 +++++ tests/strands/types/models/__init__.py | 0 tests/strands/types/models/test_model.py | 81 +++ tests/strands/types/models/test_openai.py | 332 +++++++++++ 13 files changed, 1024 insertions(+), 772 deletions(-) create mode 100644 src/strands/models/openai.py create mode 100644 src/strands/types/models/__init__.py rename src/strands/types/{models.py => models/model.py} (97%) create mode 100644 src/strands/types/models/openai.py create mode 100644 tests-integ/test_model_openai.py rename tests/{ => strands/models}/test_llamaapi.py (100%) create mode 100644 tests/strands/models/test_openai.py create mode 100644 tests/strands/types/models/__init__.py create mode 100644 tests/strands/types/models/test_model.py create mode 100644 tests/strands/types/models/test_openai.py diff --git a/pyproject.toml b/pyproject.toml index 43130ac9..6800e6ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", - "mypy>=0.981,<1.0.0", + "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.2.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", @@ -69,15 +69,18 @@ docs = [ litellm = [ "litellm>=1.69.0,<2.0.0", ] +llamaapi = [ + "llama-api-client>=0.1.0,<1.0.0", +] ollama = [ "ollama>=0.4.8,<1.0.0", ] -llamaapi = [ - "llama-api-client>=0.1.0,<1.0.0", +openai = [ + "openai>=1.68.0,<2.0.0", ] [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -100,7 +103,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 3d198e1a..23d2c2ae 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,23 +3,19 @@ - Docs: https://docs.litellm.ai/ """ -import json import logging -import mimetypes -from typing import Any, Iterable, Optional, TypedDict +from typing import Any, Optional, TypedDict, cast import litellm from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages -from ..types.models import Model -from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.content import ContentBlock +from .openai import OpenAIModel logger = logging.getLogger(__name__) -class LiteLLMModel(Model): +class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" class LiteLLMConfig(TypedDict, total=False): @@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: https://github.com/BerriAI/litellm/blob/main/litellm/main.py. **model_config: Configuration options for the LiteLLM model. """ - self.config = LiteLLMModel.LiteLLMConfig(**model_config) + self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig: Returns: The LiteLLM model configuration. """ - return self.config + return cast(LiteLLMModel.LiteLLMConfig, self.config) - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + @override + @staticmethod + def format_request_message_content(content: ContentBlock) -> dict[str, Any]: """Format a LiteLLM content block. Args: @@ -79,18 +77,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: LiteLLM formatted content block. """ - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") - return { - "image_url": { - "detail": "auto", - "format": mime_type, - "url": f"data:{mime_type};base64,{image_data}", - }, - "type": "image_url", - } - if "reasoningContent" in content: return { "signature": content["reasoningContent"]["reasoningText"]["signature"], @@ -98,9 +84,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "thinking", } - if "text" in content: - return {"text": content["text"], "type": "text"} - if "video" in content: return { "type": "video_url", @@ -110,232 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An }, } - return {"text": json.dumps(content), "type": "text"} - - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: - """Format a LiteLLM tool call. - - Args: - tool_use: Tool use requested by the model. - - Returns: - LiteLLM formatted tool call. - """ - return { - "function": { - "arguments": json.dumps(tool_use["input"]), - "name": tool_use["name"], - }, - "id": tool_use["toolUseId"], - "type": "function", - } - - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: - """Format a LiteLLM tool message. - - Args: - tool_result: Tool result collected from a tool execution. - - Returns: - LiteLLM formatted tool message. - """ - return { - "role": "tool", - "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), - } - - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format a LiteLLM messages array. - - Args: - messages: List of message objects to be processed by the model. - system_prompt: System prompt to provide context to the model. - - Returns: - A LiteLLM messages array. - """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] - - for message in messages: - contents = message["content"] - - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - formatted_tool_calls = [ - self._format_request_message_tool_call(content["toolUse"]) - for content in contents - if "toolUse" in content - ] - formatted_tool_messages = [ - self._format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - formatted_message = { - "role": message["role"], - "content": formatted_contents, - **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), - } - formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) - - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] - - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format a LiteLLM chat streaming request. - - Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. - - Returns: - A LiteLLM chat streaming request. - """ - return { - "messages": self._format_request_messages(messages, system_prompt), - "model": self.config["model_id"], - "stream": True, - "stream_options": {"include_usage": True}, - "tools": [ - { - "type": "function", - "function": { - "name": tool_spec["name"], - "description": tool_spec["description"], - "parameters": tool_spec["inputSchema"]["json"], - }, - } - for tool_spec in tool_specs or [] - ], - **(self.config.get("params") or {}), - } - - @override - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format the LiteLLM response events into standardized message chunks. - - Args: - event: A response event from the LiteLLM model. - - Returns: - The formatted chunk. - - Raises: - RuntimeError: If chunk_type is not recognized. - This error should never be encountered as we control chunk_type in the stream method. - """ - match event["chunk_type"]: - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_start": - if event["data_type"] == "tool": - return { - "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function.name, - "toolUseId": event["data"].id, - } - } - } - } - - return {"contentBlockStart": {"start": {}}} - - case "content_delta": - if event["data_type"] == "tool": - return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} - - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - case "content_stop": - return {"contentBlockStop": {}} - - case "message_stop": - match event["data"]: - case "tool_calls": - return {"messageStop": {"stopReason": "tool_use"}} - case "length": - return {"messageStop": {"stopReason": "max_tokens"}} - case _: - return {"messageStop": {"stopReason": "end_turn"}} - - case "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, - "metrics": { - "latencyMs": 0, # TODO - }, - }, - } - - case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") - - @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the LiteLLM model and get the streaming response. - - Args: - request: The formatted request to send to the LiteLLM model. - - Returns: - An iterable of response events from the LiteLLM model. - """ - response = self.client.chat.completions.create(**request) - - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} - - tool_calls: dict[int, list[Any]] = {} - - for event in response: - choice = event.choices[0] - if choice.finish_reason: - break - - if choice.delta.content: - yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - yield {"chunk_type": "content_stop", "data_type": "text"} - - for tool_deltas in tool_calls.values(): - tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:] - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start} - - for tool_delta in tool_deltas: - yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} - - yield {"chunk_type": "content_stop", "data_type": "tool"} - - yield {"chunk_type": "message_stop", "data": choice.finish_reason} - - # Skip remaining events as we don't have use for anything except the final usage payload - for event in response: - _ = event - - yield {"chunk_type": "metadata", "data": event.usage} + return OpenAIModel.format_request_message_content(content) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py new file mode 100644 index 00000000..f76ae018 --- /dev/null +++ b/src/strands/models/openai.py @@ -0,0 +1,120 @@ +"""OpenAI model provider. + +- Docs: https://platform.openai.com/docs/overview +""" + +import logging +from typing import Any, Iterable, Optional, Protocol, TypedDict, cast + +import openai +from typing_extensions import Unpack, override + +from ..types.models import OpenAIModel as SAOpenAIModel + +logger = logging.getLogger(__name__) + + +class Client(Protocol): + """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" + + chat: Any + + +class OpenAIModel(SAOpenAIModel): + """OpenAI model provider implementation.""" + + client: Client + + class OpenAIConfig(TypedDict, total=False): + """Configuration options for OpenAI models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_tokens). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/chat/create. + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI model. + """ + self.config = dict(model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + self.client = openai.OpenAI(**client_args) + + @override + def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] + """Update the OpenAI model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIConfig: + """Get the OpenAI model configuration. + + Returns: + The OpenAI model configuration. + """ + return cast(OpenAIModel.OpenAIConfig, self.config) + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the OpenAI model and get the streaming response. + + Args: + request: The formatted request to send to the OpenAI model. + + Returns: + An iterable of response events from the OpenAI model. + """ + response = self.client.chat.completions.create(**request) + + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + + tool_calls: dict[int, list[Any]] = {} + + for event in response: + choice = event.choices[0] + + if choice.delta.content: + yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + break + + yield {"chunk_type": "content_stop", "data_type": "text"} + + for tool_deltas in tool_calls.values(): + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + + for tool_delta in tool_deltas: + yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + + yield {"chunk_type": "content_stop", "data_type": "tool"} + + yield {"chunk_type": "message_stop", "data": choice.finish_reason} + + # Skip remaining events as we don't have use for anything except the final usage payload + for event in response: + _ = event + + yield {"chunk_type": "metadata", "data": event.usage} diff --git a/src/strands/types/models/__init__.py b/src/strands/types/models/__init__.py new file mode 100644 index 00000000..5ce0a498 --- /dev/null +++ b/src/strands/types/models/__init__.py @@ -0,0 +1,6 @@ +"""Model-related type definitions for the SDK.""" + +from .model import Model +from .openai import OpenAIModel + +__all__ = ["Model", "OpenAIModel"] diff --git a/src/strands/types/models.py b/src/strands/types/models/model.py similarity index 97% rename from src/strands/types/models.py rename to src/strands/types/models/model.py index e3d96e29..23e74602 100644 --- a/src/strands/types/models.py +++ b/src/strands/types/models/model.py @@ -4,9 +4,9 @@ import logging from typing import Any, Iterable, Optional -from .content import Messages -from .streaming import StreamEvent -from .tools import ToolSpec +from ..content import Messages +from ..streaming import StreamEvent +from ..tools import ToolSpec logger = logging.getLogger(__name__) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py new file mode 100644 index 00000000..8f5ffab3 --- /dev/null +++ b/src/strands/types/models/openai.py @@ -0,0 +1,240 @@ +"""Base OpenAI model provider. + +This module provides the base OpenAI model provider class which implements shared logic for formatting requests and +responses to and from the OpenAI specification. + +- Docs: https://pypi.org/project/openai +""" + +import abc +import json +import logging +import mimetypes +from typing import Any, Optional + +from typing_extensions import override + +from ..content import ContentBlock, Messages +from ..streaming import StreamEvent +from ..tools import ToolResult, ToolSpec, ToolUse +from .model import Model + +logger = logging.getLogger(__name__) + + +class OpenAIModel(Model, abc.ABC): + """Base OpenAI model provider implementation. + + Implements shared logic for formatting requests and responses to and from the OpenAI specification. + """ + + config: dict[str, Any] + + @staticmethod + def format_request_message_content(content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + """ + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = content["image"]["source"]["bytes"].decode("utf-8") + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + return {"text": json.dumps(content), "type": "text"} + + @staticmethod + def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @staticmethod + def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + """ + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": json.dumps( + { + "content": tool_result["content"], + "status": tool_result["status"], + } + ), + } + + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an OpenAI compatible chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible chat streaming request. + """ + return { + "messages": self.format_request_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + **(self.config.get("params") or {}), + } + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. + """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py new file mode 100644 index 00000000..c9046ad5 --- /dev/null +++ b/tests-integ/test_model_openai.py @@ -0,0 +1,46 @@ +import os + +import pytest + +import strands +from strands import Agent +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def model(): + return OpenAIModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 6789d441..528d1498 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,4 +1,3 @@ -import json import unittest.mock import pytest @@ -8,9 +7,14 @@ @pytest.fixture -def litellm_client(): +def litellm_client_cls(): with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls: - yield mock_client_cls.return_value + yield mock_client_cls + + +@pytest.fixture +def litellm_client(litellm_client_cls): + return litellm_client_cls.return_value @pytest.fixture @@ -35,15 +39,15 @@ def system_prompt(): return "s1" -def test__init__model_configs(litellm_client, model_id): - _ = litellm_client +def test__init__(litellm_client_cls, model_id): + model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) - model = LiteLLMModel(model_id=model_id, params={"max_tokens": 1}) + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} - tru_max_tokens = model.get_config().get("params") - exp_max_tokens = {"max_tokens": 1} + assert tru_config == exp_config - assert tru_max_tokens == exp_max_tokens + litellm_client_cls.assert_called_once_with(api_key="k1") def test_update_config(model, model_id): @@ -55,516 +59,47 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_format_request_default(model, messages, model_id): - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_params(model, messages, model_id): - model.update_config(params={"max_tokens": 1}) - - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - "max_tokens": 1, - } - - assert tru_request == exp_request - - -def test_format_request_with_system_prompt(model, messages, model_id, system_prompt): - tru_request = model.format_request(messages, system_prompt=system_prompt) - exp_request = { - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": [{"type": "text", "text": "test"}]}, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_image(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "image": { - "format": "jpg", - "source": {"bytes": b"base64encodedimage"}, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "image_url": { - "detail": "auto", - "format": "image/jpeg", - "url": "", - }, - "type": "image_url", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_reasoning(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "reasoningContent": { - "reasoningText": { - "signature": "reasoning_signature", - "text": "reasoning_text", - }, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ +@pytest.mark.parametrize( + "content, exp_result", + [ + # Case 1: Thinking + ( { - "role": "user", - "content": [ - { + "reasoningContent": { + "reasoningText": { "signature": "reasoning_signature", - "thinking": "reasoning_text", - "type": "thinking", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_video(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "video": { - "source": {"bytes": "base64encodedvideo"}, + "text": "reasoning_text", }, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "detail": "auto", - "url": "base64encodedvideo", - }, - }, - ], }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_other(model, model_id): - messages = [ - { - "role": "user", - "content": [{"other": {"a": 1}}], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ { - "role": "user", - "content": [ - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ], + "signature": "reasoning_signature", + "thinking": "reasoning_text", + "type": "thinking", }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result(model, model_id): - messages = [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ + ), + # Case 2: Video + ( { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), - "role": "tool", - "tool_call_id": "c1", - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_use(model, model_id): - messages = [ - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "c1", - "name": "calculator", - "input": {"expression": "2+2"}, - }, + "video": { + "source": {"bytes": "base64encodedvideo"}, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "content": [], - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "calculator", - "arguments": '{"expression": "2+2"}', - }, - "id": "c1", - "type": "function", - } - ], - } - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_specs(model, messages, model_id): - tool_specs = [ - { - "name": "calculator", - "description": "Calculate mathematical expressions", - "inputSchema": { - "json": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]} }, - } - ] - - tru_request = model.format_request(messages, tool_specs) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [ { - "type": "function", - "function": { - "name": "calculator", - "description": "Calculate mathematical expressions", - "parameters": { - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, + "type": "video_url", + "video_url": { + "detail": "auto", + "url": "base64encodedvideo", }, - } - ], - } - - assert tru_request == exp_request - - -def test_format_chunk_message_start(model): - event = {"chunk_type": "message_start"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStart": {"role": "assistant"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_text(model): - event = {"chunk_type": "content_start", "data_type": "text"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_tool(model): - mock_tool_use = unittest.mock.Mock() - mock_tool_use.function.name = "calculator" - mock_tool_use.id = "c1" - - event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_use} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_text(model): - event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_tool(model): - event = { - "chunk_type": "content_delta", - "data_type": "tool", - "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_stop(model): - event = {"chunk_type": "content_stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStop": {}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_end_turn(model): - event = {"chunk_type": "message_stop", "data": "stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "end_turn"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_tool_use(model): - event = {"chunk_type": "message_stop", "data": "tool_calls"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "tool_use"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_max_tokens(model): - event = {"chunk_type": "message_stop", "data": "length"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "max_tokens"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_metadata(model): - event = { - "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = { - "metadata": { - "usage": { - "inputTokens": 100, - "outputTokens": 50, - "totalTokens": 150, - }, - "metrics": { - "latencyMs": 0, }, - }, - } - - assert tru_chunk == exp_chunk - - -def test_format_chunk_other(model): - event = {"chunk_type": "other"} - - with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): - model.format_chunk(event) - - -def test_stream(litellm_client, model): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) - mock_delta_1 = unittest.mock.Mock( - content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] - ) - - mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_2 = unittest.mock.Mock( - content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] - ) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls")]) - mock_event_4 = unittest.mock.Mock() - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) - - request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "message_stop", "data": "tool_calls"}, - {"chunk_type": "metadata", "data": mock_event_4.usage}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) - - -def test_stream_empty(litellm_client, model): - mock_delta = unittest.mock.Mock(content=None, tool_calls=None) - mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop")]) - mock_event_3 = unittest.mock.Mock() - mock_event_4 = unittest.mock.Mock(usage=mock_usage) - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) - - request = {"model": "m1", "messages": [{"role": "user", "content": []}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - {"chunk_type": "metadata", "data": mock_usage}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) + ), + # Case 3: Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = LiteLLMModel.format_request_message_content(content) + assert tru_result == exp_result diff --git a/tests/test_llamaapi.py b/tests/strands/models/test_llamaapi.py similarity index 100% rename from tests/test_llamaapi.py rename to tests/strands/models/test_llamaapi.py diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py new file mode 100644 index 00000000..89aa591f --- /dev/null +++ b/tests/strands/models/test_openai.py @@ -0,0 +1,134 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def openai_client_cls(): + with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def openai_client(openai_client_cls): + return openai_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "m1" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + + return OpenAIModel(model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test__init__(openai_client_cls, model_id): + model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) + + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1") + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_stream(openai_client, model): + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_1 = unittest.mock.Mock( + content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] + ) + + mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) + mock_delta_2 = unittest.mock.Mock( + content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] + ) + + mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) + mock_event_4 = unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"chunk_type": "metadata", "data": mock_event_4.usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_empty(openai_client, model): + mock_delta = unittest.mock.Mock(content=None, tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + mock_event_4 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": []}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) diff --git a/tests/strands/types/models/__init__.py b/tests/strands/types/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py new file mode 100644 index 00000000..f2797fe5 --- /dev/null +++ b/tests/strands/types/models/test_model.py @@ -0,0 +1,81 @@ +import pytest + +from strands.types.models import Model as SAModel + + +class TestModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def format_request(self, messages, tool_specs, system_prompt): + return { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + } + + def format_chunk(self, event): + return {"event": event} + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test_converse(model, messages, tool_specs, system_prompt): + response = model.converse(messages, tool_specs, system_prompt) + + tru_events = list(response) + exp_events = [ + { + "event": { + "request": { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + }, + }, + }, + ] + assert tru_events == exp_events diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py new file mode 100644 index 00000000..97a0882a --- /dev/null +++ b/tests/strands/types/models/test_openai.py @@ -0,0 +1,332 @@ +import json +import unittest.mock + +import pytest + +from strands.types.models import OpenAIModel as SAOpenAIModel + + +class TestOpenAIModel(SAOpenAIModel): + def __init__(self): + self.config = {"model_id": "m1", "params": {"max_tokens": 1}} + + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestOpenAIModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Case 1: Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "", + }, + "type": "image_url", + }, + ), + # Case 2: Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + # Case 3: Other + ( + {"other": {"a": 1}}, + { + "text": json.dumps({"other": {"a": 1}}), + "type": "text", + }, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = SAOpenAIModel.format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = SAOpenAIModel.format_request_message_tool_call(tool_use) + exp_result = { + "function": { + "arguments": '{"expression": "2+2"}', + "name": "calculator", + }, + "id": "c1", + "type": "function", + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"value": 4}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = SAOpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": json.dumps( + { + "content": [{"value": 4}], + "status": "success", + } + ), + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], + "role": "user", + }, + ] + + tru_result = SAOpenAIModel.format_request_messages(messages, system_prompt) + exp_result = [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + { + "content": [{"text": "call tool", "type": "text"}], + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + "id": "c1", + "type": "function", + } + ], + }, + { + "content": json.dumps( + { + "content": [{"value": 4}], + "status": "success", + } + ), + "role": "tool", + "tool_call_id": "c1", + }, + ] + assert tru_result == exp_result + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model.format_request(messages, tool_specs, system_prompt) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "max_tokens": 1, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Case 1: Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Case 2: Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Case 3: Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Case 4: Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Case 5: Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Case 6: Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Case 7: Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Case 8: Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Case 9: Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Case 10: Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model.format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model.format_chunk(event) From cb4da3ce7ac8881386c14208e8ca3980d8ea5e13 Mon Sep 17 00:00:00 2001 From: Didier Durand Date: Thu, 22 May 2025 22:10:17 +0200 Subject: [PATCH 23/49] fixing typos in .py and .md (#78) --- .github/workflows/test-lint-pr.yml | 2 +- CONTRIBUTING.md | 2 +- STYLE_GUIDE.md | 2 +- src/strands/tools/tools.py | 2 +- tests-integ/test_stream_agent.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml index b196a473..5ba62427 100644 --- a/.github/workflows/test-lint-pr.yml +++ b/.github/workflows/test-lint-pr.yml @@ -56,7 +56,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to checkout + ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to check out persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python uses: actions/setup-python@v5 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 63464a5d..c54c5ec3 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,7 +40,7 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ```bash pre-commit install -t pre-commit -t commit-msg ``` - This will automatically run formatters and convention commit checks on your code before each commit. + This will automatically run formatters and conventional commit checks on your code before each commit. 3. Run code formatters manually: ```bash diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index a50c571b..51dc0a73 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -26,7 +26,7 @@ logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, f - This is an optimization to skip string interpolation when the log level is not enabled 1. **Messages**: - - Add human readable messages at the end of the log + - Add human-readable messages at the end of the log - Use lowercase for consistency - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter - Keep messages concise and focused on a single statement diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 40565a24..7d43125b 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -40,7 +40,7 @@ def validate_tool_use_name(tool: ToolUse) -> None: Raises: InvalidToolUseNameException: If the tool name is invalid. """ - # We need to fix some typing here, because we dont actually expect a ToolUse, but dict[str, Any] + # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] if "name" not in tool: message = "tool name missing" # type: ignore[unreachable] logger.warning(message) diff --git a/tests-integ/test_stream_agent.py b/tests-integ/test_stream_agent.py index 4c97db6b..01f20339 100644 --- a/tests-integ/test_stream_agent.py +++ b/tests-integ/test_stream_agent.py @@ -1,5 +1,5 @@ """ -Test script for Strands's custom callback handler functionality. +Test script for Strands' custom callback handler functionality. Demonstrates different patterns of callback handling and processing. """ From 77f5fa73f62913b19014d7246dd2ef63c34f2551 Mon Sep 17 00:00:00 2001 From: wzxxing <169175349+wzxxing@users.noreply.github.com> Date: Thu, 22 May 2025 23:22:01 +0200 Subject: [PATCH 24/49] docs: update contributing guide to manage python env with hatch shell (#46) --- CONTRIBUTING.md | 8 +++++++- pyproject.toml | 5 +++++ src/strands/handlers/tool_handler.py | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c54c5ec3..fa724cdd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,11 +31,17 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ### Setting Up Your Development Environment -1. Install development dependencies: +1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell. + ```bash + hatch shell dev + ``` + + Alternatively, install development dependencies in a manually created virtual environment: ```bash pip install -e ".[dev]" && pip install -e ".[litellm]" ``` + 2. Set up pre-commit hooks: ```bash pre-commit install -t pre-commit -t commit-msg diff --git a/pyproject.toml b/pyproject.toml index 6800e6ef..7be9c164 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,11 @@ extra-args = [ "-vv", ] +[tool.hatch.envs.dev] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] + + [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py index 0803eca5..bc4ec1ce 100644 --- a/src/strands/handlers/tool_handler.py +++ b/src/strands/handlers/tool_handler.py @@ -46,6 +46,7 @@ def preprocess( def process( self, tool: Any, + *, model: Model, system_prompt: Optional[str], messages: List[Any], From 1831bdb13df0c267c3a04d692d93688c71b975eb Mon Sep 17 00:00:00 2001 From: moritalous Date: Fri, 23 May 2025 11:59:11 +0900 Subject: [PATCH 25/49] Add ensure_ascii=False to json.dumps() calls in telemetry tracer (#37) --- src/strands/telemetry/tracer.py | 30 +++++++++++----- tests/strands/telemetry/test_tracer.py | 48 +++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 10 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index b3709a1f..3ec663ce 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -315,7 +315,7 @@ def start_model_invoke_span( "gen_ai.system": "strands-agents", "agent.name": agent_name, "gen_ai.agent.name": agent_name, - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), } if model_id: @@ -338,7 +338,7 @@ def end_model_invoke_span( error: Optional exception if the model call failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], @@ -360,10 +360,10 @@ def start_tool_call_span( The created span, or None if tracing is not enabled. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.prompt": json.dumps(tool, cls=JSONEncoder), + "gen_ai.prompt": serialize(tool), "tool.name": tool["name"], "tool.id": tool["toolUseId"], - "tool.parameters": json.dumps(tool["input"], cls=JSONEncoder), + "tool.parameters": serialize(tool["input"]), } # Add additional kwargs as attributes @@ -387,7 +387,7 @@ def end_tool_call_span( status = tool_result.get("status") status_str = str(status) if status is not None else "" - tool_result_content_json = json.dumps(tool_result.get("content"), cls=JSONEncoder) + tool_result_content_json = serialize(tool_result.get("content")) attributes.update( { "tool.result": tool_result_content_json, @@ -420,7 +420,7 @@ def start_event_loop_cycle_span( parent_span = parent_span if parent_span else event_loop_kwargs.get("event_loop_parent_span") attributes: Dict[str, AttributeValue] = { - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), "event_loop.cycle_id": event_loop_cycle_id, } @@ -449,11 +449,11 @@ def end_event_loop_cycle_span( error: Optional exception if the cycle failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), } if tool_result_message: - attributes["tool.result"] = json.dumps(tool_result_message["content"], cls=JSONEncoder) + attributes["tool.result"] = serialize(tool_result_message["content"]) self._end_span(span, attributes, error) @@ -490,7 +490,7 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = json.dumps(tools, cls=JSONEncoder) + tools_json = serialize(tools) attributes["agent.tools"] = tools_json attributes["gen_ai.agent.tools"] = tools_json @@ -571,3 +571,15 @@ def get_tracer( ) return _tracer_instance + + +def serialize(obj: Any) -> str: + """Serialize an object to JSON with consistent settings. + + Args: + obj: The object to serialize + + Returns: + JSON string representation of the object + """ + return json.dumps(obj, ensure_ascii=False, cls=JSONEncoder) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index fbbfa89b..128b4f94 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -6,7 +6,7 @@ import pytest from opentelemetry.trace import StatusCode # type: ignore -from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer +from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -635,3 +635,49 @@ def test_json_encoder_value_error(): # Test just the value result = json.loads(encoder.encode(huge_number)) assert result == "" + + +def test_serialize_non_ascii_characters(): + """Test that non-ASCII characters are preserved in JSON serialization.""" + + # Test with Japanese text + japanese_text = "こんにちは世界" + result = serialize({"text": japanese_text}) + assert japanese_text in result + assert "\\u" not in result + + # Test with emoji + emoji_text = "Hello 🌍" + result = serialize({"text": emoji_text}) + assert emoji_text in result + assert "\\u" not in result + + # Test with Chinese characters + chinese_text = "你好,世界" + result = serialize({"text": chinese_text}) + assert chinese_text in result + assert "\\u" not in result + + # Test with mixed content + mixed_text = {"ja": "こんにちは", "emoji": "😊", "zh": "你好", "en": "hello"} + result = serialize(mixed_text) + assert "こんにちは" in result + assert "😊" in result + assert "你好" in result + assert "\\u" not in result + + +def test_serialize_vs_json_dumps(): + """Test that serialize behaves differently from default json.dumps for non-ASCII characters.""" + + # Test with Japanese text + japanese_text = "こんにちは世界" + + # Default json.dumps should escape non-ASCII characters + default_result = json.dumps({"text": japanese_text}) + assert "\\u" in default_result + + # Our serialize function should preserve non-ASCII characters + custom_result = serialize({"text": japanese_text}) + assert japanese_text in custom_result + assert "\\u" not in custom_result From 0bb2b64640e457a18374206eefa3b070406f6373 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 23 May 2025 12:07:13 -0400 Subject: [PATCH 26/49] lint - openai client protocol (#87) --- src/strands/models/openai.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index f76ae018..6e32c5bd 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -17,7 +17,10 @@ class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" - chat: Any + @property + def chat(self) -> Any: + """Chat completions interface.""" + ... class OpenAIModel(SAOpenAIModel): From f5ac46da024d00ca03006ee568b7e9059b3b1ef8 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 23 May 2025 14:47:06 -0400 Subject: [PATCH 27/49] fix: Lower OpenTelemetry minimum version (#89) This fixes https://github.com/strands-agents/sdk-python/issues/88 --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7be9c164..8bb55bd6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,9 +33,9 @@ dependencies = [ "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", - "opentelemetry-api>=1.33.0,<2.0.0", - "opentelemetry-sdk>=1.33.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.33.0,<2.0.0", + "opentelemetry-api>=1.30.0,<2.0.0", + "opentelemetry-sdk>=1.30.0,<2.0.0", + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] From 58dc3ab8c0c37e6915648cfa60816cbc2874202b Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 23 May 2025 16:03:49 -0400 Subject: [PATCH 28/49] version - 0.1.4 (#93) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8bb55bd6..52394cea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.3" +version = "0.1.4" description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" From 61c3bf794f81c46aef4cd40ded6a68b7b48acb20 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 23 May 2025 19:31:02 -0400 Subject: [PATCH 29/49] models - openai - argument none (#97) --- src/strands/models/openai.py | 1 + src/strands/types/models/openai.py | 4 +++- tests/strands/types/models/test_openai.py | 29 +++++++++++++++-------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 6e32c5bd..764cb851 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -18,6 +18,7 @@ class Client(Protocol): """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" @property + # pragma: no cover def chat(self) -> Any: """Chat completions interface.""" ... diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index 8f5ffab3..c00a7774 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -206,7 +206,9 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: case "content_delta": if event["data_type"] == "tool": - return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } return {"contentBlockDelta": {"delta": {"text": event["data"]}}} diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 97a0882a..2657c334 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -246,12 +246,12 @@ def test_format_request(model, messages, tool_specs, system_prompt): @pytest.mark.parametrize( ("event", "exp_chunk"), [ - # Case 1: Message start + # Message start ( {"chunk_type": "message_start"}, {"messageStart": {"role": "assistant"}}, ), - # Case 2: Content Start - Tool Use + # Content Start - Tool Use ( { "chunk_type": "content_start", @@ -260,12 +260,12 @@ def test_format_request(model, messages, tool_specs, system_prompt): }, {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, ), - # Case 3: Content Start - Text + # Content Start - Text ( {"chunk_type": "content_start", "data_type": "text"}, {"contentBlockStart": {"start": {}}}, ), - # Case 4: Content Delta - Tool Use + # Content Delta - Tool Use ( { "chunk_type": "content_delta", @@ -274,32 +274,41 @@ def test_format_request(model, messages, tool_specs, system_prompt): }, {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, ), - # Case 5: Content Delta - Text + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Text ( {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, {"contentBlockDelta": {"delta": {"text": "hello"}}}, ), - # Case 6: Content Stop + # Content Stop ( {"chunk_type": "content_stop"}, {"contentBlockStop": {}}, ), - # Case 7: Message Stop - Tool Use + # Message Stop - Tool Use ( {"chunk_type": "message_stop", "data": "tool_calls"}, {"messageStop": {"stopReason": "tool_use"}}, ), - # Case 8: Message Stop - Max Tokens + # Message Stop - Max Tokens ( {"chunk_type": "message_stop", "data": "length"}, {"messageStop": {"stopReason": "max_tokens"}}, ), - # Case 9: Message Stop - End Turn + # Message Stop - End Turn ( {"chunk_type": "message_stop", "data": "stop"}, {"messageStop": {"stopReason": "end_turn"}}, ), - # Case 10: Metadata + # Metadata ( { "chunk_type": "metadata", From a248d543544e8590ea42bb08cde0139792ba8f41 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sat, 24 May 2025 16:40:21 -0400 Subject: [PATCH 30/49] docs(readme): add open PRs badge + link to samples repo + change 'Docs' to 'Documentation' (#100) --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 262dde51..802af5bf 100644 --- a/README.md +++ b/README.md @@ -8,14 +8,15 @@
GitHub commit activity GitHub open issues + GitHub open pull requests License PyPI version Python versions

- Docs - ◆ Samples + Documentation + ◆ SamplesToolsAgent Builder

From 173078d6290e46126c948b1aa84edb40e7b07668 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sat, 24 May 2025 16:52:37 -0400 Subject: [PATCH 31/49] docs(readme): add logo (#101) --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 802af5bf..2a4c4a7f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,14 @@ -# Strands Agents -
+
+ + Strands Agents + +
+ +

+ Strands Agents +

+

A model-driven approach to building AI agents in just a few lines of code.

From aa95e4107c6ed3392f5d486388b77324b1e0d122 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sat, 24 May 2025 18:05:15 -0400 Subject: [PATCH 32/49] docs(readme): add logo, title, badges, links to other repos, standardize headings (#102) --- README.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2a4c4a7f..9a74250d 100644 --- a/README.md +++ b/README.md @@ -25,8 +25,10 @@

DocumentationSamples + ◆ Python SDKToolsAgent Builder + ◆ MCP Server

@@ -35,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t ## Feature Overview - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable -- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Llama, Ollama, and custom providers +- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools @@ -147,6 +149,7 @@ Built-in providers: - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) + - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) @@ -174,9 +177,9 @@ For detailed guidance & examples, explore our documentation: - [API Reference](https://strandsagents.com/latest/api-reference/agent/) - [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/) -## Contributing +## Contributing ❤️ -We welcome contributions! See our [Contributing Guide](https://github.com/strands-agents/sdk-python/blob/main/CONTRIBUTING.md) for details on: +We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on: - Reporting bugs & features - Development setup - Contributing via Pull Requests @@ -187,6 +190,10 @@ We welcome contributions! See our [Contributing Guide](https://github.com/strand This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. +## Security + +See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. + ## ⚠️ Preview Status Strands Agents is currently in public preview. During this period: From 9c46f42774c18fb0298ddba03c788db6e824fdea Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sat, 24 May 2025 22:10:07 -0400 Subject: [PATCH 33/49] style(readme): use dark logo for clearer visibility when system is using light color scheme (#104) --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 9a74250d..5fa90351 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,11 @@
From 63cef21ec16e41733d405ea4578d5f8845914b0f Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Sat, 24 May 2025 22:54:16 -0400 Subject: [PATCH 34/49] fix(readme): use logo that changes color automatically depending on user's color preference scheme (#105) --- README.md | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/README.md b/README.md index 5fa90351..f4c483a2 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,7 @@
From 7042af189540b9b7b95b671752ad5dcfd72634bc Mon Sep 17 00:00:00 2001 From: "Gokhan (Joe) Gultekin" Date: Sun, 25 May 2025 18:21:48 +0200 Subject: [PATCH 35/49] feat(handlers): add reasoning text to callback handler and related tests (#109) * feat(handlers): add reasoning text to callback handler and related tests * feat(handlers): removed redundant comment in .gitignore file * feat(handlers): Updated reasoningText type as (Optional[str] --- .gitignore | 1 + src/strands/handlers/callback_handler.py | 12 ++++++--- .../strands/handlers/test_callback_handler.py | 25 +++++++++++++++++++ 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a80f4bd1..a5cf11c4 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ __pycache__* .pytest_cache .ruff_cache *.bak +.vscode \ No newline at end of file diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py index d6d104d8..e46cb326 100644 --- a/src/strands/handlers/callback_handler.py +++ b/src/strands/handlers/callback_handler.py @@ -17,15 +17,19 @@ def __call__(self, **kwargs: Any) -> None: Args: **kwargs: Callback event data including: - - - data (str): Text content to stream. - - complete (bool): Whether this is the final chunk of a response. - - current_tool_use (dict): Information about the current tool being used. + - reasoningText (Optional[str]): Reasoning text to print if provided. + - data (str): Text content to stream. + - complete (bool): Whether this is the final chunk of a response. + - current_tool_use (dict): Information about the current tool being used. """ + reasoningText = kwargs.get("reasoningText", False) data = kwargs.get("data", "") complete = kwargs.get("complete", False) current_tool_use = kwargs.get("current_tool_use", {}) + if reasoningText: + print(reasoningText, end="") + if data: print(data, end="" if not complete else "\n") diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py index 20e238cb..6fb2af07 100644 --- a/tests/strands/handlers/test_callback_handler.py +++ b/tests/strands/handlers/test_callback_handler.py @@ -30,6 +30,31 @@ def test_call_with_empty_args(handler, mock_print): mock_print.assert_not_called() +def test_call_handler_reasoningText(handler, mock_print): + """Test calling the handler with reasoningText.""" + handler(reasoningText="This is reasoning text") + # Should print reasoning text without newline + mock_print.assert_called_once_with("This is reasoning text", end="") + + +def test_call_without_reasoningText(handler, mock_print): + """Test calling the handler without reasoningText argument.""" + handler(data="Some output") + # Should only print data, not reasoningText + mock_print.assert_called_once_with("Some output", end="") + + +def test_call_with_reasoningText_and_data(handler, mock_print): + """Test calling the handler with both reasoningText and data.""" + handler(reasoningText="Reasoning", data="Output") + # Should print reasoningText and data, both without newline + calls = [ + unittest.mock.call("Reasoning", end=""), + unittest.mock.call("Output", end=""), + ] + mock_print.assert_has_calls(calls) + + def test_call_with_data_incomplete(handler, mock_print): """Test calling the handler with data but not complete.""" handler(data="Test output") From f2d2cb60947a6ac30fd94507d793004762bbc54a Mon Sep 17 00:00:00 2001 From: Shubham Raut Date: Mon, 26 May 2025 01:49:30 +0530 Subject: [PATCH 36/49] feat: Add dynamic system prompt override functionality (#108) --- src/strands/agent/agent.py | 39 ++++++++++++++++--------------- tests/strands/agent/test_agent.py | 34 +++++++++++++++++++++++++-- 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 6a948980..bed9e52e 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -457,27 +457,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str Returns: The result of the event loop cycle. """ - kwargs.pop("agent", None) - kwargs.pop("model", None) - kwargs.pop("system_prompt", None) - kwargs.pop("tool_execution_handler", None) - kwargs.pop("event_loop_metrics", None) - kwargs.pop("callback_handler", None) - kwargs.pop("tool_handler", None) - kwargs.pop("messages", None) - kwargs.pop("tool_config", None) + # Extract parameters with fallbacks to instance values + system_prompt = kwargs.pop("system_prompt", self.system_prompt) + model = kwargs.pop("model", self.model) + tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper) + event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics) + callback_handler_override = kwargs.pop("callback_handler", callback_handler) + tool_handler = kwargs.pop("tool_handler", self.tool_handler) + messages = kwargs.pop("messages", self.messages) + tool_config = kwargs.pop("tool_config", self.tool_config) + kwargs.pop("agent", None) # Remove agent to avoid conflicts try: # Execute the main event loop cycle stop_reason, message, metrics, state = event_loop_cycle( - model=self.model, - system_prompt=self.system_prompt, - messages=self.messages, # will be modified by event_loop_cycle - tool_config=self.tool_config, - callback_handler=callback_handler, - tool_handler=self.tool_handler, - tool_execution_handler=self.thread_pool_wrapper, - event_loop_metrics=self.event_loop_metrics, + model=model, + system_prompt=system_prompt, + messages=messages, # will be modified by event_loop_cycle + tool_config=tool_config, + callback_handler=callback_handler_override, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + event_loop_metrics=event_loop_metrics, agent=self, event_loop_parent_span=self.trace_span, **kwargs, @@ -488,8 +489,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self.messages, e=e) - return self._execute_event_loop_cycle(callback_handler, kwargs) + self.conversation_manager.reduce_context(messages, e=e) + return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( self, diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5c7d11e4..ff70089b 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -337,17 +337,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, ], ] + override_system_prompt = "Override system prompt" + override_model = unittest.mock.Mock() + override_tool_execution_handler = unittest.mock.Mock() + override_event_loop_metrics = unittest.mock.Mock() + override_callback_handler = unittest.mock.Mock() + override_tool_handler = unittest.mock.Mock() + override_messages = [{"role": "user", "content": [{"text": "override msg"}]}] + override_tool_config = {"test": "config"} + def check_kwargs(some_value, **kwargs): assert some_value == "a_value" assert kwargs is not None + assert kwargs["system_prompt"] == override_system_prompt + assert kwargs["model"] == override_model + assert kwargs["tool_execution_handler"] == override_tool_execution_handler + assert kwargs["event_loop_metrics"] == override_event_loop_metrics + assert kwargs["callback_handler"] == override_callback_handler + assert kwargs["tool_handler"] == override_tool_handler + assert kwargs["messages"] == override_messages + assert kwargs["tool_config"] == override_tool_config + assert kwargs["agent"] == agent # Return expected values from event_loop_cycle return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {} mock_event_loop_cycle.side_effect = check_kwargs - agent("test message", some_value="a_value") - assert mock_event_loop_cycle.call_count == 1 + agent( + "test message", + some_value="a_value", + system_prompt=override_system_prompt, + model=override_model, + tool_execution_handler=override_tool_execution_handler, + event_loop_metrics=override_event_loop_metrics, + callback_handler=override_callback_handler, + tool_handler=override_tool_handler, + messages=override_messages, + tool_config=override_tool_config, + ) + + mock_event_loop_cycle.assert_called_once() def test_agent__call__retry_with_reduced_context(mock_model, agent, tool): From a331e63271caff8d274c5d441695039e2df063b4 Mon Sep 17 00:00:00 2001 From: fede-dash Date: Sun, 25 May 2025 21:16:03 -0400 Subject: [PATCH 37/49] Modularizing Event Loop (#106) --- src/strands/event_loop/event_loop.py | 211 +++++++++++++++++---------- 1 file changed, 134 insertions(+), 77 deletions(-) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index db5a1b97..23d7bd0f 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) +MAX_ATTEMPTS = 6 +INITIAL_DELAY = 4 +MAX_DELAY = 240 # 4 minutes + def initialize_state(**kwargs: Any) -> Any: """Initialize the request state if not present. @@ -51,7 +55,7 @@ def event_loop_cycle( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, + callback_handler: Callable[..., Any], tool_handler: Optional[ToolHandler], tool_execution_handler: Optional[ParallelToolExecutorInterface] = None, **kwargs: Any, @@ -130,13 +134,9 @@ def event_loop_cycle( stop_reason: StopReason usage: Any metrics: Metrics - max_attempts = 6 - initial_delay = 4 - max_delay = 240 # 4 minutes - current_delay = initial_delay # Retry loop for handling throttling exceptions - for attempt in range(max_attempts): + for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( parent_span=cycle_span, @@ -177,7 +177,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, max_attempts, current_delay, max_delay, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue @@ -204,80 +204,35 @@ def event_loop_cycle( # If the model is requesting to use tools if stop_reason == "tool_use": - tool_uses: List[ToolUse] = [] - tool_results: List[ToolResult] = [] - invalid_tool_use_ids: List[str] = [] - - # Extract and validate tools - validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) - - # Check if tools are available for execution - if tool_uses: - if tool_handler is None: - raise ValueError("toolUse present but tool handler not set") - if tool_config is None: - raise ValueError("toolUse present but tool config not set") - - # Create the tool handler process callable - tool_handler_process: Callable[[ToolUse], ToolResult] = partial( - tool_handler.process, - messages=messages, - model=model, - system_prompt=system_prompt, - tool_config=tool_config, - callback_handler=callback_handler, - **kwargs, + if not tool_handler: + raise EventLoopException( + Exception("Model requested tool use but no tool handler provided"), + kwargs["request_state"], ) - # Execute tools (parallel or sequential) - run_tools( - handler=tool_handler_process, - tool_uses=tool_uses, - event_loop_metrics=event_loop_metrics, - request_state=cast(Any, kwargs["request_state"]), - invalid_tool_use_ids=invalid_tool_use_ids, - tool_results=tool_results, - cycle_trace=cycle_trace, - parent_span=cycle_span, - parallel_tool_executor=tool_execution_handler, + if tool_config is None: + raise EventLoopException( + Exception("Model requested tool use but no tool config provided"), + kwargs["request_state"], ) - # Update state for the next cycle - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) - - # Create the tool result message - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], - } - messages.append(tool_result_message) - callback_handler(message=tool_result_message) - - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, message=message, tool_result_message=tool_result_message - ) - - # Check if we should stop the event loop - if kwargs["request_state"].get("stop_event_loop"): - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - return ( - stop_reason, - message, - event_loop_metrics, - kwargs["request_state"], - ) - - # Recursive call to continue the conversation - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) + # Handle tool execution + return _handle_tool_execution( + stop_reason, + message, + model, + system_prompt, + messages, + tool_config, + tool_handler, + callback_handler, + tool_execution_handler, + event_loop_metrics, + cycle_trace, + cycle_span, + cycle_start_time, + kwargs, + ) # End the cycle and return results event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) @@ -377,3 +332,105 @@ def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetr kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] return kwargs + + +def _handle_tool_execution( + stop_reason: StopReason, + message: Message, + model: Model, + system_prompt: Optional[str], + messages: Messages, + tool_config: ToolConfig, + tool_handler: ToolHandler, + callback_handler: Callable[..., Any], + tool_execution_handler: Optional[ParallelToolExecutorInterface], + event_loop_metrics: EventLoopMetrics, + cycle_trace: Trace, + cycle_span: Any, + cycle_start_time: float, + kwargs: Dict[str, Any], +) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + tool_uses: List[ToolUse] = [] + tool_results: List[ToolResult] = [] + invalid_tool_use_ids: List[str] = [] + + """ + Handles the execution of tools requested by the model during an event loop cycle. + + Args: + stop_reason (StopReason): The reason the model stopped generating. + message (Message): The message from the model that may contain tool use requests. + model (Model): The model provider instance. + system_prompt (Optional[str]): The system prompt instructions for the model. + messages (Messages): The conversation history messages. + tool_config (ToolConfig): Configuration for available tools. + tool_handler (ToolHandler): Handler for tool execution. + callback_handler (Callable[..., Any]): Callback for processing events as they happen. + tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution. + event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop. + cycle_trace (Trace): Trace object for the current event loop cycle. + cycle_span (Any): Span object for tracing the cycle (type may vary). + cycle_start_time (float): Start time of the current cycle. + kwargs (Dict[str, Any]): Additional keyword arguments, including request state. + + Returns: + Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + - The stop reason, + - The updated message, + - The updated event loop metrics, + - The updated request state. + """ + validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + + if not tool_uses: + return stop_reason, message, event_loop_metrics, kwargs["request_state"] + + tool_handler_process = partial( + tool_handler.process, + messages=messages, + model=model, + system_prompt=system_prompt, + tool_config=tool_config, + callback_handler=callback_handler, + **kwargs, + ) + + run_tools( + handler=tool_handler_process, + tool_uses=tool_uses, + event_loop_metrics=event_loop_metrics, + request_state=cast(Any, kwargs["request_state"]), + invalid_tool_use_ids=invalid_tool_use_ids, + tool_results=tool_results, + cycle_trace=cycle_trace, + parent_span=cycle_span, + parallel_tool_executor=tool_execution_handler, + ) + + kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + messages.append(tool_result_message) + callback_handler(message=tool_result_message) + + if cycle_span: + tracer = get_tracer() + tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) + + if kwargs["request_state"].get("stop_event_loop", False): + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + return stop_reason, message, event_loop_metrics, kwargs["request_state"] + + return recurse_event_loop( + model=model, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + **kwargs, + ) From bd60f90e53d11757b30195c5024575ab05ccfe04 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Mon, 26 May 2025 13:06:20 -0400 Subject: [PATCH 38/49] fix(telemetry): fix agent span start and end when using Agent.stream_async() (#119) --- src/strands/agent/agent.py | 66 +++++++++++++++++------ tests/strands/agent/test_agent.py | 87 +++++++++++++++++++++++++++++-- 2 files changed, 134 insertions(+), 19 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index bed9e52e..0f912b54 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None - - self.trace_span = self.tracer.start_agent_span( - prompt=prompt, - model_id=model_id, - tools=self.tool_names, - system_prompt=self.system_prompt, - custom_trace_attributes=self.trace_attributes, - ) + self._start_agent_trace_span(prompt) try: # Run the event loop and get the result result = self._run_loop(prompt, kwargs) - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, response=result) + self._end_agent_trace_span(response=result) return result except Exception as e: - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, error=e) + self._end_agent_trace_span(error=e) # Re-raise the exception to preserve original behavior raise @@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: yield event["data"] ``` """ + self._start_agent_trace_span(prompt) + _stop_event = uuid4() queue = asyncio.Queue[Any]() @@ -400,8 +392,10 @@ def target_callback() -> None: nonlocal kwargs try: - self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) - except BaseException as e: + result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) + self._end_agent_trace_span(response=result) + except Exception as e: + self._end_agent_trace_span(error=e) enqueue(e) finally: enqueue(_stop_event) @@ -414,7 +408,7 @@ def target_callback() -> None: item = await queue.get() if item == _stop_event: break - if isinstance(item, BaseException): + if isinstance(item, Exception): raise item yield item finally: @@ -546,3 +540,43 @@ def _record_tool_execution( messages.append(tool_use_msg) messages.append(tool_result_msg) messages.append(assistant_msg) + + def _start_agent_trace_span(self, prompt: str) -> None: + """Starts a trace span for the agent. + + Args: + prompt: The natural language prompt from the user. + """ + model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None + + self.trace_span = self.tracer.start_agent_span( + prompt=prompt, + model_id=model_id, + tools=self.tool_names, + system_prompt=self.system_prompt, + custom_trace_attributes=self.trace_attributes, + ) + + def _end_agent_trace_span( + self, + response: Optional[AgentResult] = None, + error: Optional[Exception] = None, + ) -> None: + """Ends a trace span for the agent. + + Args: + span: The span to end. + response: Response to record as a trace attribute. + error: Error to record as a trace attribute. + """ + if self.trace_span: + trace_attributes: Dict[str, Any] = { + "span": self.trace_span, + } + + if response: + trace_attributes["response"] = response + if error: + trace_attributes["error"] = error + + self.tracer.end_agent_span(**trace_attributes) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index ff70089b..828ae8f9 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -9,7 +9,8 @@ import pytest import strands -from strands.agent.agent import Agent +from strands import Agent +from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler @@ -687,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler(): @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): - mock_event_loop_cycle.side_effect = ValueError("Test exception") - agent = Agent() # Define the side effect to simulate callback handler being called multiple times @@ -952,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler being called multiple times + def call_callback_handler(*args, **kwargs): + # Extract the callback handler from kwargs + callback_handler = kwargs.get("callback_handler") + # Call the callback handler with different data values + callback_handler(data="First chunk") + callback_handler(data="Second chunk") + callback_handler(data="Final chunk", complete=True) + # Return expected values from event_loop_cycle + return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {} + + mock_event_loop_cycle.side_effect = call_callback_handler + + # Create agent and make a call + agent = Agent(model=mock_model) + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + expected_response = AgentResult( + stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={} + ) + + # Verify span was ended with the result + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): """Test that __call__ creates and ends a span when an exception occurs.""" @@ -985,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler raising an Exception + test_exception = ValueError("Test exception") + mock_model.mock_converse.side_effect = test_exception + + # Create agent and make a call + agent = Agent(model=mock_model) + + # Call the agent and catch the exception + with pytest.raises(ValueError): + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + # Verify span was ended with the exception + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model): """Test that event_loop_cycle is called with the parent span.""" From f6c8d9d928cc3b7829ce713187a0f2c6b77d47ae Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Mon, 26 May 2025 14:35:03 -0400 Subject: [PATCH 39/49] feat: Update SlidingWindowConversationManager (#120) --- .../sliding_window_conversation_manager.py | 71 ++++++------------- tests/strands/agent/test_agent.py | 2 +- .../agent/test_conversation_manager.py | 55 +++++--------- 3 files changed, 39 insertions(+), 89 deletions(-) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 4b11e81c..f367b272 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,12 +1,10 @@ """Sliding window conversation history management.""" -import json import logging -from typing import List, Optional, cast +from typing import Optional -from ...types.content import ContentBlock, Message, Messages +from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException -from ...types.tools import ToolResult from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) @@ -110,8 +108,9 @@ def _remove_dangling_messages(self, messages: Messages) -> None: def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. - The method handles special cases where tool results need to be converted to regular content blocks to maintain - conversation coherence after trimming. + The method handles special cases where trimming the messages leads to: + - toolResult with no corresponding toolUse + - toolUse with no corresponding toolResult Args: messages: The messages to reduce. @@ -126,52 +125,24 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size - # Throw if we cannot trim any messages from the conversation - if trim_index >= len(messages): - raise ContextWindowOverflowException("Unable to trim conversation context!") from e - - # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the - # limitation of needing ToolUse and ToolResults to be paired. - if any("toolResult" in content for content in messages[trim_index]["content"]): - if len(messages[trim_index]["content"]) == 1: - messages[trim_index]["content"] = self._map_tool_result_content( - cast(ToolResult, messages[trim_index]["content"][0]["toolResult"]) + # Find the next valid trim_index + while trim_index < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[trim_index]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[trim_index]["content"]) + and trim_index + 1 < len(messages) + and not any("toolResult" in content for content in messages[trim_index + 1]["content"]) ) - - # If there is more content than just one ToolResultContent, then we cannot cut at this index. + ): + trim_index += 1 else: - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + break + else: + # If we didn't find a valid trim_index, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") from e # Overwrite message history messages[:] = messages[trim_index:] - - def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]: - """Convert a ToolResult to a list of standard ContentBlocks. - - This method transforms tool result content into standard content blocks that can be preserved when trimming the - conversation history. - - Args: - tool_result: The ToolResult to convert. - - Returns: - A list of content blocks representing the tool result. - """ - contents = [] - text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else "" - - for tool_result_content in tool_result["content"]: - if "text" in tool_result_content: - text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}" - elif "json" in tool_result_content: - text_content = ( - "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}" - ) - elif "image" in tool_result_content: - contents.append(ContentBlock(image=tool_result_content["image"])) - elif "document" in tool_result_content: - contents.append(ContentBlock(document=tool_result_content["document"])) - else: - logger.warning("unsupported content type") - contents.append(ContentBlock(text=text_content)) - return contents diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 828ae8f9..ea06fb4e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -459,7 +459,7 @@ def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite with pytest.raises(ContextWindowOverflowException): agent("Test!") - assert conversation_manager_spy.reduce_context.call_count == 251 + assert conversation_manager_spy.reduce_context.call_count > 0 assert conversation_manager_spy.apply_management.call_count == 1 diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 2f6ee77d..b6132f1d 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -111,41 +111,7 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Second response"}]}, ], ), - # 7 - Message count above max window size - Remove dangling tool uses and tool results - ( - {"window_size": 1}, - [ - {"role": "user", "content": [{"text": "First message"}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "321", "name": "tool1", "input": {}}}]}, - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}} - ], - }, - ], - [ - { - "role": "user", - "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}], - }, - ], - ), - # 8 - Message count above max window size - Remove multiple tool use/tool result pairs - ( - {"window_size": 1}, - [ - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, - ], - [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, - ], - ), - # 9 - Message count above max window size - Preserve tool use/tool result pairs + # 7 - Message count above max window size - Preserve tool use/tool result pairs ( {"window_size": 2}, [ @@ -158,7 +124,7 @@ def conversation_manager(request): {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, ], ), - # 10 - Test sliding window behavior - preserve tool use/result pairs across cut boundary + # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary ( {"window_size": 3}, [ @@ -173,7 +139,7 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Response after tool use"}]}, ], ), - # 11 - Test sliding window with multiple tool pairs that need preservation + # 9 - Test sliding window with multiple tool pairs that need preservation ( {"window_size": 4}, [ @@ -185,7 +151,6 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Final response"}]}, ], [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, @@ -200,6 +165,20 @@ def test_apply_management(conversation_manager, messages, expected_messages): assert messages == expected_messages +def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, + ] + original_messages = messages.copy() + + with pytest.raises(ContextWindowOverflowException): + manager.apply_management(messages) + + assert messages == original_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() From 8a29ae520446e68f690dc38938ef47ab86911e89 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Mon, 26 May 2025 15:37:50 -0400 Subject: [PATCH 40/49] v0.1.5 (#121) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 52394cea..a3b36cab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.4" +version = "0.1.5" description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" From b04c4bebc71dbdeb22cd5be2303144ae05f4d840 Mon Sep 17 00:00:00 2001 From: Arron <139703460+awsarron@users.noreply.github.com> Date: Tue, 27 May 2025 13:13:02 -0400 Subject: [PATCH 41/49] style(callback_handler): fix docstring for PrintingCallbackHandler.__call__ (#126) --- src/strands/handlers/callback_handler.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py index e46cb326..4b794b4f 100644 --- a/src/strands/handlers/callback_handler.py +++ b/src/strands/handlers/callback_handler.py @@ -17,10 +17,10 @@ def __call__(self, **kwargs: Any) -> None: Args: **kwargs: Callback event data including: - - reasoningText (Optional[str]): Reasoning text to print if provided. - - data (str): Text content to stream. - - complete (bool): Whether this is the final chunk of a response. - - current_tool_use (dict): Information about the current tool being used. + - reasoningText (Optional[str]): Reasoning text to print if provided. + - data (str): Text content to stream. + - complete (bool): Whether this is the final chunk of a response. + - current_tool_use (dict): Information about the current tool being used. """ reasoningText = kwargs.get("reasoningText", False) data = kwargs.get("data", "") From e907ac6e97041d3620cae98a631c0fb115b28bcc Mon Sep 17 00:00:00 2001 From: Clare Liguori Date: Tue, 27 May 2025 12:19:46 -0700 Subject: [PATCH 42/49] chore(tests): Add unit tests for user agent changes (#125) --- tests/strands/models/test_bedrock.py | 105 +++++++++++++++++++++++++-- 1 file changed, 97 insertions(+), 8 deletions(-) diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 0844c8cd..2f5fc4ad 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -3,6 +3,7 @@ import boto3 import pytest +from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError, EventStreamError import strands @@ -116,6 +117,54 @@ def test__init__with_region_and_session_raises_value_error(): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) +def test__init__default_user_agent(bedrock_client): + """Set user agent when no boto_client_config is provided.""" + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel() + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + + +def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): + """Set user agent when boto_client_config is provided without user_agent_extra.""" + custom_config = BotocoreConfig(read_timeout=900) + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == 900 + + +def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): + """Append to existing user agent when boto_client_config is provided with user_agent_extra.""" + custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" + assert kwargs["config"].read_timeout == 900 + + def test__init__model_config(bedrock_client): _ = bedrock_client @@ -381,7 +430,15 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m "guardrail": { "inputAssessment": { "3e59qlue4hag": { - "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + } } } } @@ -406,7 +463,10 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m chunks = model.converse(messages, [tool_spec]) tru_chunks = list(chunks) - exp_chunks = [{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, metadata_event] + exp_chunks = [ + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + metadata_event, + ] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) @@ -424,7 +484,13 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -451,7 +517,10 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, chunks = model.converse(messages, [tool_spec]) tru_chunks = list(chunks) - exp_chunks = [{"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, metadata_event] + exp_chunks = [ + {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, + metadata_event, + ] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) @@ -471,7 +540,13 @@ def test_converse_output_guardrails_redacts_input_and_output( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -521,7 +596,13 @@ def test_converse_output_no_blocked_guardrails_doesnt_redact( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "NONE", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "NONE", + "detected": True, + } + ] }, } ] @@ -567,7 +648,13 @@ def test_converse_output_no_guardrail_redact( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -591,7 +678,9 @@ def test_converse_output_no_guardrail_redact( } model.update_config( - additional_request_fields=additional_request_fields, guardrail_redact_output=False, guardrail_redact_input=False + additional_request_fields=additional_request_fields, + guardrail_redact_output=False, + guardrail_redact_input=False, ) chunks = model.converse(messages, [tool_spec]) From c3895d486dec0de019937365e35e2c6423b36d20 Mon Sep 17 00:00:00 2001 From: fede-dash Date: Tue, 27 May 2025 19:12:04 -0400 Subject: [PATCH 43/49] Increasing Coverage Message Processor (#115) --- .../event_loop/test_message_processor.py | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 tests/strands/event_loop/test_message_processor.py diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py new file mode 100644 index 00000000..395c71a1 --- /dev/null +++ b/tests/strands/event_loop/test_message_processor.py @@ -0,0 +1,128 @@ +import copy + +import pytest + +from strands.event_loop import message_processor + + +@pytest.mark.parametrize( + "messages,expected,expected_messages", + [ + # Orphaned toolUse with empty input, no toolResult + ( + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]}, + ], + True, + [ + {"role": "assistant", "content": [{"text": "[Attempted to use foo, but operation was canceled]"}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]}, + ], + ), + # toolUse with input, has matching toolResult + ( + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + ], + False, + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + ], + ), + # No messages + ( + [], + False, + [], + ), + ], +) +def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): + test_messages = copy.deepcopy(messages) + result = message_processor.clean_orphaned_empty_tool_uses(test_messages) + assert result == expected + assert test_messages == expected_messages + + +@pytest.mark.parametrize( + "messages,expected_idx", + [ + ( + [ + {"role": "user", "content": [{"text": "hi"}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ], + 1, + ), + ( + [ + {"role": "user", "content": [{"text": "hi"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ], + None, + ), + ( + [], + None, + ), + ], +) +def test_find_last_message_with_tool_results(messages, expected_idx): + idx = message_processor.find_last_message_with_tool_results(messages) + assert idx == expected_idx + + +@pytest.mark.parametrize( + "messages,msg_idx,expected_changed,expected_content", + [ + ( + [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}], + } + ], + 0, + True, + [ + { + "toolResult": { + "toolUseId": "1", + "status": "error", + "content": [{"text": "The tool result was too large!"}], + } + } + ], + ), + ( + [{"role": "user", "content": [{"text": "no tool result"}]}], + 0, + False, + [{"text": "no tool result"}], + ), + ( + [], + 0, + False, + [], + ), + ( + [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], + 2, + False, + [{"toolResult": {"toolUseId": "1"}}], + ), + ], +) +def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): + test_messages = copy.deepcopy(messages) + changed = message_processor.truncate_tool_results(test_messages, msg_idx) + assert changed == expected_changed + if 0 <= msg_idx < len(test_messages): + assert test_messages[msg_idx]["content"] == expected_content + else: + assert test_messages == messages From 3100ea0dc8b694ff8db673000fc914e26ee32abd Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Tue, 27 May 2025 19:30:23 -0400 Subject: [PATCH 44/49] feat: Add non-streaming support to BedrockModel (#75) * feat: Add non-streaming support to BedrockModel * fix: Add more test coverage * fix: Update with pr comments --- README.md | 1 + src/strands/models/bedrock.py | 227 ++++++++++++++---- src/strands/types/streaming.py | 2 +- tests-integ/test_model_bedrock.py | 120 ++++++++++ tests/strands/models/test_bedrock.py | 331 ++++++++++++++++++++++++++- 5 files changed, 627 insertions(+), 54 deletions(-) create mode 100644 tests-integ/test_model_bedrock.py diff --git a/README.md b/README.md index f4c483a2..ed98d001 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ from strands.models.llamaapi import LlamaAPIModel bedrock_model = BedrockModel( model_id="us.amazon.nova-pro-v1:0", temperature=0.3, + streaming=True, # Enable/disable streaming ) agent = Agent(model=bedrock_model) agent("Tell me about Agentic AI") diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 05d89923..9bbcca7d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -3,13 +3,14 @@ - Docs: https://aws.amazon.com/bedrock/ """ +import json import logging import os -from typing import Any, Iterable, Literal, Optional, cast +from typing import Any, Iterable, List, Literal, Optional, cast import boto3 from botocore.config import Config as BotocoreConfig -from botocore.exceptions import ClientError, EventStreamError +from botocore.exceptions import ClientError from typing_extensions import TypedDict, Unpack, override from ..types.content import Messages @@ -61,6 +62,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: Maximum number of tokens to generate in the response model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0") stop_sequences: List of sequences that will stop generation when encountered + streaming: Flag to enable/disable streaming. Defaults to True. temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) """ @@ -81,6 +83,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: Optional[int] model_id: str stop_sequences: Optional[list[str]] + streaming: Optional[bool] temperature: Optional[float] top_p: Optional[float] @@ -246,11 +249,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: """ return cast(StreamEvent, event) + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: + """Check if guardrail data contains any blocked policies. + + Args: + guardrail_data: Guardrail data from trace information. + + Returns: + True if any blocked guardrail is detected, False otherwise. + """ + input_assessment = guardrail_data.get("inputAssessment", {}) + output_assessments = guardrail_data.get("outputAssessments", {}) + + # Check input assessments + if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()): + return True + + # Check output assessments + if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()): + return True + + return False + + def _generate_redaction_events(self) -> list[StreamEvent]: + """Generate redaction events based on configuration. + + Returns: + List of redaction events to yield. + """ + events: List[StreamEvent] = [] + + if self.config.get("guardrail_redact_input", True): + logger.debug("Redacting user input due to guardrail.") + events.append( + { + "redactContent": { + "redactUserContentMessage": self.config.get( + "guardrail_redact_input_message", "[User input redacted.]" + ) + } + } + ) + + if self.config.get("guardrail_redact_output", False): + logger.debug("Redacting assistant output due to guardrail.") + events.append( + { + "redactContent": { + "redactAssistantContentMessage": self.config.get( + "guardrail_redact_output_message", "[Assistant output redacted.]" + ) + } + } + ) + + return events + @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the Bedrock model and get the streaming response. + def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: + """Send the request to the Bedrock model and get the response. - This method calls the Bedrock converse_stream API and returns the stream of response events. + This method calls either the Bedrock converse_stream API or the converse API + based on the streaming parameter in the configuration. Args: request: The formatted request to send to the Bedrock model @@ -260,63 +320,132 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: Raises: ContextWindowOverflowException: If the input exceeds the model's context window. - EventStreamError: For all other Bedrock API errors. + ModelThrottledException: If the model service is throttling requests. """ + streaming = self.config.get("streaming", True) + try: - response = self.client.converse_stream(**request) - for chunk in response["stream"]: - if self.config.get("guardrail_redact_input", True) or self.config.get("guardrail_redact_output", False): + if streaming: + # Streaming implementation + response = self.client.converse_stream(**request) + for chunk in response["stream"]: if ( "metadata" in chunk and "trace" in chunk["metadata"] and "guardrail" in chunk["metadata"]["trace"] ): - inputAssessment = chunk["metadata"]["trace"]["guardrail"].get("inputAssessment", {}) - outputAssessments = chunk["metadata"]["trace"]["guardrail"].get("outputAssessments", {}) - - # Check if an input or output guardrail was triggered - if any( - self._find_detected_and_blocked_policy(assessment) - for assessment in inputAssessment.values() - ) or any( - self._find_detected_and_blocked_policy(assessment) - for assessment in outputAssessments.values() - ): - if self.config.get("guardrail_redact_input", True): - logger.debug("Found blocked input guardrail. Redacting input.") - yield { - "redactContent": { - "redactUserContentMessage": self.config.get( - "guardrail_redact_input_message", "[User input redacted.]" - ) - } - } - if self.config.get("guardrail_redact_output", False): - logger.debug("Found blocked output guardrail. Redacting output.") - yield { - "redactContent": { - "redactAssistantContentMessage": self.config.get( - "guardrail_redact_output_message", "[Assistant output redacted.]" - ) - } - } + guardrail_data = chunk["metadata"]["trace"]["guardrail"] + if self._has_blocked_guardrail(guardrail_data): + yield from self._generate_redaction_events() + yield chunk + else: + # Non-streaming implementation + response = self.client.converse(**request) + + # Convert and yield from the response + yield from self._convert_non_streaming_to_streaming(response) - yield chunk - except EventStreamError as e: - # Handle throttling that occurs mid-stream? - if "ThrottlingException" in str(e) and "ConverseStream" in str(e): - raise ModelThrottledException(str(e)) from e + # Check for guardrail triggers after yielding any events (same as streaming path) + if ( + "trace" in response + and "guardrail" in response["trace"] + and self._has_blocked_guardrail(response["trace"]["guardrail"]) + ): + yield from self._generate_redaction_events() - if any(overflow_message in str(e) for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): + except ClientError as e: + error_message = str(e) + + # Handle throttling error + if e.response["Error"]["Code"] == "ThrottlingException": + raise ModelThrottledException(error_message) from e + + # Handle context window overflow + if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): logger.warning("bedrock threw context window overflow error") raise ContextWindowOverflowException(e) from e + + # Otherwise raise the error raise e - except ClientError as e: - # Handle throttling that occurs at the beginning of the call - if e.response["Error"]["Code"] == "ThrottlingException": - raise ModelThrottledException(str(e)) from e - raise + def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: + """Convert a non-streaming response to the streaming format. + + Args: + response: The non-streaming response from the Bedrock model. + + Returns: + An iterable of response events in the streaming format. + """ + # Yield messageStart event + yield {"messageStart": {"role": response["output"]["message"]["role"]}} + + # Process content blocks + for content in response["output"]["message"]["content"]: + # Yield contentBlockStart event if needed + if "toolUse" in content: + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": content["toolUse"]["toolUseId"], + "name": content["toolUse"]["name"], + } + }, + } + } + + # For tool use, we need to yield the input as a delta + input_value = json.dumps(content["toolUse"]["input"]) + + yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} + elif "text" in content: + # Then yield the text as a delta + yield { + "contentBlockDelta": { + "delta": {"text": content["text"]}, + } + } + elif "reasoningContent" in content: + # Then yield the reasoning content as a delta + yield { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}} + } + } + + if "signature" in content["reasoningContent"]["reasoningText"]: + yield { + "contentBlockDelta": { + "delta": { + "reasoningContent": { + "signature": content["reasoningContent"]["reasoningText"]["signature"] + } + } + } + } + + # Yield contentBlockStop event + yield {"contentBlockStop": {}} + + # Yield messageStop event + yield { + "messageStop": { + "stopReason": response["stopReason"], + "additionalModelResponseFields": response.get("additionalModelResponseFields"), + } + } + + # Yield metadata event + if "usage" in response or "metrics" in response or "trace" in response: + metadata: StreamEvent = {"metadata": {}} + if "usage" in response: + metadata["metadata"]["usage"] = response["usage"] + if "metrics" in response: + metadata["metadata"]["metrics"] = response["metrics"] + if "trace" in response: + metadata["metadata"]["trace"] = response["trace"] + yield metadata def _find_detected_and_blocked_policy(self, input: Any) -> bool: """Recursively checks if the assessment contains a detected and blocked guardrail. diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index db600f15..9c99b210 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -157,7 +157,7 @@ class ModelStreamErrorEvent(ExceptionEvent): originalStatusCode: int -class RedactContentEvent(TypedDict): +class RedactContentEvent(TypedDict, total=False): """Event for redacting content. Attributes: diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py new file mode 100644 index 00000000..a6a29aa9 --- /dev/null +++ b/tests-integ/test_model_bedrock.py @@ -0,0 +1,120 @@ +import pytest + +import strands +from strands import Agent +from strands.models import BedrockModel + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that uses & instead of ." + + +@pytest.fixture +def streaming_model(): + return BedrockModel( + model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", + streaming=True, + ) + + +@pytest.fixture +def non_streaming_model(): + return BedrockModel( + model_id="us.meta.llama3-2-90b-instruct-v1:0", + streaming=False, + ) + + +@pytest.fixture +def streaming_agent(streaming_model, system_prompt): + return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +@pytest.fixture +def non_streaming_agent(non_streaming_model, system_prompt): + return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +def test_streaming_agent(streaming_agent): + """Test agent with streaming model.""" + result = streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_non_streaming_agent(non_streaming_agent): + """Test agent with non-streaming model.""" + result = non_streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_streaming_model_events(streaming_model): + """Test streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = list(streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_non_streaming_model_events(non_streaming_model): + """Test non-streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = list(non_streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_tool_use_streaming(streaming_model): + """Test tool use with streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False) + result = agent("What is 123 + 456?") + + # Print the full message content for debugging + print("\nFull message content:") + import json + + print(json.dumps(result.message["content"], indent=2)) + + assert tool_was_called + + +def test_tool_use_non_streaming(non_streaming_model): + """Test tool use with non-streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) + agent("What is 123 + 456?") + + assert tool_was_called diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2f5fc4ad..b326eee7 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -354,9 +354,9 @@ def test_stream(bedrock_client, model): def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model): - error_message = "ThrottlingException - Rate exceeded" + error_message = "Rate exceeded" bedrock_client.converse_stream.side_effect = EventStreamError( - {"Error": {"Message": error_message}}, "ConverseStream" + {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" ) request = {"a": 1} @@ -421,7 +421,9 @@ def test_converse(bedrock_client, model, messages, tool_spec, model_id, addition bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +def test_converse_stream_input_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +): metadata_event = { "metadata": { "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, @@ -472,7 +474,9 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +def test_converse_stream_output_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +): model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) metadata_event = { "metadata": { @@ -689,3 +693,322 @@ def test_converse_output_no_guardrail_redact( assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) + + +def test_stream_with_streaming_false(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "stopReason": "end_turn", + } + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_and_tool_use(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_and_reasoning(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_and_reasoning_no_signature(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard...."}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + { + "metadata": { + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + } + }, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_input_guardrails(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + } + } + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_output_guardrails(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_output_guardrails_redacts_output(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() From a2bb8814c32178f5c7269a186eae3474097dfbba Mon Sep 17 00:00:00 2001 From: AI Ape Wisdom Date: Thu, 29 May 2025 08:37:04 +0800 Subject: [PATCH 45/49] fix: Added hyphen to allowed characters in tool name validation (#55) * Added hyphen to allowed characters in tool name validation --- src/strands/tools/tools.py | 2 +- tests/strands/tools/test_tools.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 7d43125b..b595c3d6 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -47,7 +47,7 @@ def validate_tool_use_name(tool: ToolUse) -> None: raise InvalidToolUseNameException(message) tool_name = tool["name"] - tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_]*$" + tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_\-]*$" tool_name_max_length = 64 valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) tool_name_len = len(tool_name) diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 8a6b406f..f24cc22d 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -14,9 +14,13 @@ def test_validate_tool_use_name_valid(): - tool = {"name": "valid_tool_name", "toolUseId": "123"} + tool1 = {"name": "valid_tool_name", "toolUseId": "123"} + # Should not raise an exception + validate_tool_use_name(tool1) + + tool2 = {"name": "valid-name", "toolUseId": "123"} # Should not raise an exception - validate_tool_use_name(tool) + validate_tool_use_name(tool2) def test_validate_tool_use_name_missing(): From 58bfecb433f7f0d8b16184b93c79354e1da0fd62 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Wed, 28 May 2025 23:11:42 -0400 Subject: [PATCH 46/49] models - content - documents (#138) --- src/strands/types/models/openai.py | 12 ++++++++++++ tests/strands/types/models/test_openai.py | 23 ++++++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py index c00a7774..307c0be6 100644 --- a/src/strands/types/models/openai.py +++ b/src/strands/types/models/openai.py @@ -7,6 +7,7 @@ """ import abc +import base64 import json import logging import mimetypes @@ -40,6 +41,17 @@ def format_request_message_content(content: ContentBlock) -> dict[str, Any]: Returns: OpenAI compatible content block. """ + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + if "image" in content: mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") image_data = content["image"]["source"]["bytes"].decode("utf-8") diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py index 2657c334..c6a05291 100644 --- a/tests/strands/types/models/test_openai.py +++ b/tests/strands/types/models/test_openai.py @@ -62,7 +62,24 @@ def system_prompt(): @pytest.mark.parametrize( "content, exp_result", [ - # Case 1: Image + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "file": { + "file_data": "data:application/pdf;base64,ZG9jdW1lbnQ=", + "filename": "test doc", + }, + "type": "file", + }, + ), + # Image ( { "image": { @@ -79,12 +96,12 @@ def system_prompt(): "type": "image_url", }, ), - # Case 2: Text + # Text ( {"text": "hello"}, {"type": "text", "text": "hello"}, ), - # Case 3: Other + # Other ( {"other": {"a": 1}}, { From 5d785a1f1cbbfcc3ba7b4c6b681dff6061f5395d Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 29 May 2025 10:05:51 -0400 Subject: [PATCH 47/49] models - anthropic - document - plain text (#141) --- src/strands/models/anthropic.py | 11 +++-- tests/strands/models/test_anthropic.py | 59 +++++++++++++++++--------- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 704114eb..99e49f81 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -97,13 +97,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Anthropic formatted content block. """ if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") return { "source": { - "data": base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8"), - "media_type": mimetypes.types_map.get( - f".{content['document']['format']}", "application/octet-stream" + "data": ( + content["document"]["source"]["bytes"].decode("utf-8") + if mime_type == "text/plain" + else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") ), - "type": "base64", + "media_type": mime_type, + "type": "text" if mime_type == "text/plain" else "base64", }, "title": content["document"]["name"], "type": "document", diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 48a1da37..2ee344cc 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -102,19 +102,46 @@ def test_format_request_with_system_prompt(model, messages, model_id, max_tokens assert tru_request == exp_request -def test_format_request_with_document(model, model_id, max_tokens): +@pytest.mark.parametrize( + ("content", "formatted_content"), + [ + # PDF + ( + { + "document": {"format": "pdf", "name": "test doc", "source": {"bytes": b"pdf"}}, + }, + { + "source": { + "data": "cGRm", + "media_type": "application/pdf", + "type": "base64", + }, + "title": "test doc", + "type": "document", + }, + ), + # Plain text + ( + { + "document": {"format": "txt", "name": "test doc", "source": {"bytes": b"txt"}}, + }, + { + "source": { + "data": "txt", + "media_type": "text/plain", + "type": "text", + }, + "title": "test doc", + "type": "document", + }, + ), + ], +) +def test_format_request_with_document(content, formatted_content, model, model_id, max_tokens): messages = [ { "role": "user", - "content": [ - { - "document": { - "format": "pdf", - "name": "test-doc", - "source": {"bytes": b"base64encodeddoc"}, - }, - }, - ], + "content": [content], }, ] @@ -124,17 +151,7 @@ def test_format_request_with_document(model, model_id, max_tokens): "messages": [ { "role": "user", - "content": [ - { - "source": { - "data": "YmFzZTY0ZW5jb2RlZGRvYw==", - "media_type": "application/pdf", - "type": "base64", - }, - "title": "test-doc", - "type": "document", - }, - ], + "content": [formatted_content], }, ], "model": model_id, From 73dae72545c98606afb7cb987496d83cbca04e35 Mon Sep 17 00:00:00 2001 From: Jack Yuan <94985218+JackYPCOnline@users.noreply.github.com> Date: Thu, 29 May 2025 11:31:21 -0400 Subject: [PATCH 48/49] fix(telemetry): correct environment variable precedence for OTEL config (#86) * fix(telemetry): correct environment variable precedence for OTEL configuration * fix(telemetry): update get_tracer function to adapt the changes * fix(telemetry): fix tracer initialization --------- Co-authored-by: Jack Yuan --- src/strands/telemetry/tracer.py | 25 +++---- tests/strands/telemetry/test_tracer.py | 91 ++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 12 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 3ec663ce..34eb7bed 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -93,7 +93,7 @@ def __init__( service_name: str = "strands-agents", otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None, - enable_console_export: bool = False, + enable_console_export: Optional[bool] = None, ): """Initialize the tracer. @@ -105,13 +105,17 @@ def __init__( """ # Check environment variables first env_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") - env_console_export = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT", "").lower() in ("true", "1", "yes") + env_console_export_str = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT") - # Environment variables take precedence over constructor parameters - if env_endpoint: - otlp_endpoint = env_endpoint - if env_console_export: - enable_console_export = True + # Constructor parameters take precedence over environment variables + self.otlp_endpoint = otlp_endpoint or env_endpoint + + if enable_console_export is not None: + self.enable_console_export = enable_console_export + elif env_console_export_str: + self.enable_console_export = env_console_export_str.lower() in ("true", "1", "yes") + else: + self.enable_console_export = False # Parse headers from environment if available env_headers = os.environ.get("OTEL_EXPORTER_OTLP_HEADERS") @@ -128,14 +132,11 @@ def __init__( logger.warning("error=<%s> | failed to parse OTEL_EXPORTER_OTLP_HEADERS", e) self.service_name = service_name - self.otlp_endpoint = otlp_endpoint self.otlp_headers = otlp_headers or {} - self.enable_console_export = enable_console_export - self.tracer_provider: Optional[TracerProvider] = None self.tracer: Optional[trace.Tracer] = None - if otlp_endpoint or enable_console_export: + if self.otlp_endpoint or self.enable_console_export: self._initialize_tracer() def _initialize_tracer(self) -> None: @@ -547,7 +548,7 @@ def get_tracer( service_name: str = "strands-agents", otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None, - enable_console_export: bool = False, + enable_console_export: Optional[bool] = None, ) -> Tracer: """Get or create the global tracer. diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 128b4f94..32a4ac0a 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -60,6 +60,50 @@ def mock_resource(): yield mock_resource +@pytest.fixture +def clean_env(): + """Fixture to provide a clean environment for each test.""" + with mock.patch.dict(os.environ, {}, clear=True): + yield + + +@pytest.fixture +def env_with_otlp(): + """Fixture with OTLP environment variables.""" + with mock.patch.dict( + os.environ, + { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint", + }, + ): + yield + + +@pytest.fixture +def env_with_console(): + """Fixture with console export environment variables.""" + with mock.patch.dict( + os.environ, + { + "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true", + }, + ): + yield + + +@pytest.fixture +def env_with_both(): + """Fixture with both OTLP and console export environment variables.""" + with mock.patch.dict( + os.environ, + { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint", + "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true", + }, + ): + yield + + def test_init_default(): """Test initializing the Tracer with default parameters.""" tracer = Tracer() @@ -681,3 +725,50 @@ def test_serialize_vs_json_dumps(): custom_result = serialize({"text": japanese_text}) assert japanese_text in custom_result assert "\\u" not in custom_result + + +def test_init_with_no_env_or_param(clean_env): + """Test initializing with neither environment variable nor constructor parameter.""" + tracer = Tracer() + assert tracer.otlp_endpoint is None + assert tracer.enable_console_export is False + + tracer = Tracer(otlp_endpoint="http://param-endpoint") + assert tracer.otlp_endpoint == "http://param-endpoint" + + tracer = Tracer(enable_console_export=True) + assert tracer.enable_console_export is True + + +def test_constructor_params_with_otlp_env(env_with_otlp): + """Test constructor parameters precedence over OTLP environment variable.""" + # Constructor parameter should take precedence + tracer = Tracer(otlp_endpoint="http://constructor-endpoint") + assert tracer.otlp_endpoint == "http://constructor-endpoint" + + # Without constructor parameter, should use env var + tracer = Tracer() + assert tracer.otlp_endpoint == "http://env-endpoint" + + +def test_constructor_params_with_console_env(env_with_console): + """Test constructor parameters precedence over console environment variable.""" + # Constructor parameter should take precedence + tracer = Tracer(enable_console_export=False) + assert tracer.enable_console_export is False + + # Without explicit constructor parameter, should use env var + tracer = Tracer() + assert tracer.enable_console_export is True + + +def test_fallback_to_env_vars(env_with_both): + """Test fallback to environment variables when no constructor parameters.""" + tracer = Tracer() + assert tracer.otlp_endpoint == "http://env-endpoint" + assert tracer.enable_console_export is True + + # Constructor parameters should still take precedence + tracer = Tracer(otlp_endpoint="http://constructor-endpoint", enable_console_export=False) + assert tracer.otlp_endpoint == "http://constructor-endpoint" + assert tracer.enable_console_export is False From 947f6b6715a491acf662b149ffe2eb8465b38868 Mon Sep 17 00:00:00 2001 From: Nick Clegg Date: Fri, 30 May 2025 08:49:53 -0400 Subject: [PATCH 49/49] Automate deployment to PYPI (#145) * feat: Deploy to pypi on release * feat: get version from tag --- .github/workflows/pr-and-push.yml | 17 ++++ .github/workflows/pypi-publish-on-release.yml | 78 +++++++++++++++++++ .../{test-lint-pr.yml => test-lint.yml} | 17 ++-- .gitignore | 3 +- pyproject.toml | 8 +- 5 files changed, 110 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/pr-and-push.yml create mode 100644 .github/workflows/pypi-publish-on-release.yml rename .github/workflows/{test-lint-pr.yml => test-lint.yml} (80%) diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml new file mode 100644 index 00000000..38e88691 --- /dev/null +++ b/.github/workflows/pr-and-push.yml @@ -0,0 +1,17 @@ +name: Pull Request and Push Action + +on: + pull_request: # Safer than pull_request_target for untrusted code + branches: [ main ] + types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] + push: + branches: [ main ] # Also run on direct pushes to main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + call-test-lint: + uses: ./.github/workflows/test-lint.yml + with: + ref: ${{ github.event.pull_request.head.sha }} \ No newline at end of file diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml new file mode 100644 index 00000000..4047f596 --- /dev/null +++ b/.github/workflows/pypi-publish-on-release.yml @@ -0,0 +1,78 @@ +name: Publish Python Package + +on: + release: + types: + - published + +jobs: + call-test-lint: + uses: ./.github/workflows/test-lint.yml + with: + ref: ${{ github.event.release.target_commitish }} + + build: + name: Build distribution 📦 + needs: + - call-test-lint + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch twine + + - name: Validate version + run: | + version=$(hatch version) + if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "Valid version format" + exit 0 + else + echo "Invalid version format" + exit 1 + fi + + - name: Build + run: | + hatch build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + deploy: + name: Upload release to PyPI + needs: + - build + runs-on: ubuntu-latest + + # environment is used by PyPI Trusted Publisher and is strongly encouraged + # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ + environment: + name: pypi + url: https://pypi.org/p/strands-agents + permissions: + # IMPORTANT: this permission is mandatory for Trusted Publishing + id-token: write + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution 📦 to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint.yml similarity index 80% rename from .github/workflows/test-lint-pr.yml rename to .github/workflows/test-lint.yml index 5ba62427..35e0f584 100644 --- a/.github/workflows/test-lint-pr.yml +++ b/.github/workflows/test-lint.yml @@ -1,14 +1,11 @@ name: Test and Lint on: - pull_request: # Safer than pull_request_target for untrusted code - branches: [ main ] - types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] - push: - branches: [ main ] # Also run on direct pushes to main -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} - cancel-in-progress: true + workflow_call: + inputs: + ref: + required: true + type: string jobs: unit-test: @@ -56,7 +53,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to check out + ref: ${{ inputs.ref }} # Explicitly define which commit to check out persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python uses: actions/setup-python@v5 @@ -78,7 +75,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.sha }} + ref: ${{ inputs.ref }} persist-credentials: false - name: Set up Python diff --git a/.gitignore b/.gitignore index a5cf11c4..5cdc43db 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ __pycache__* .pytest_cache .ruff_cache *.bak -.vscode \ No newline at end of file +.vscode +dist \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index a3b36cab..bd309732 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,10 +1,10 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.5" +dynamic = ["version"] description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" @@ -79,6 +79,10 @@ openai = [ "openai>=1.68.0,<2.0.0", ] +[tool.hatch.version] +# Tells Hatch to use your version control system (git) to determine the version. +source = "vcs" + [tool.hatch.envs.hatch-static-analysis] features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] dependencies = [