diff --git a/.github/workflows/dispatch-docs.yml b/.github/workflows/dispatch-docs.yml
deleted file mode 100644
index fda63413..00000000
--- a/.github/workflows/dispatch-docs.yml
+++ /dev/null
@@ -1,16 +0,0 @@
-name: Dispatch Docs
-on:
- push:
- branches:
- - main
-
-jobs:
- trigger:
- runs-on: ubuntu-latest
- permissions:
- contents: read
- steps:
- - name: Dispatch
- run: gh api repos/${{ github.repository_owner }}/docs/dispatches -F event_type=sdk-push
- env:
- GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }}
diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml
new file mode 100644
index 00000000..38e88691
--- /dev/null
+++ b/.github/workflows/pr-and-push.yml
@@ -0,0 +1,17 @@
+name: Pull Request and Push Action
+
+on:
+ pull_request: # Safer than pull_request_target for untrusted code
+ branches: [ main ]
+ types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed]
+ push:
+ branches: [ main ] # Also run on direct pushes to main
+concurrency:
+ group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+ cancel-in-progress: true
+
+jobs:
+ call-test-lint:
+ uses: ./.github/workflows/test-lint.yml
+ with:
+ ref: ${{ github.event.pull_request.head.sha }}
\ No newline at end of file
diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml
new file mode 100644
index 00000000..4047f596
--- /dev/null
+++ b/.github/workflows/pypi-publish-on-release.yml
@@ -0,0 +1,78 @@
+name: Publish Python Package
+
+on:
+ release:
+ types:
+ - published
+
+jobs:
+ call-test-lint:
+ uses: ./.github/workflows/test-lint.yml
+ with:
+ ref: ${{ github.event.release.target_commitish }}
+
+ build:
+ name: Build distribution π¦
+ needs:
+ - call-test-lint
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ persist-credentials: false
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install hatch twine
+
+ - name: Validate version
+ run: |
+ version=$(hatch version)
+ if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+ echo "Valid version format"
+ exit 0
+ else
+ echo "Invalid version format"
+ exit 1
+ fi
+
+ - name: Build
+ run: |
+ hatch build
+
+ - name: Store the distribution packages
+ uses: actions/upload-artifact@v4
+ with:
+ name: python-package-distributions
+ path: dist/
+
+ deploy:
+ name: Upload release to PyPI
+ needs:
+ - build
+ runs-on: ubuntu-latest
+
+ # environment is used by PyPI Trusted Publisher and is strongly encouraged
+ # https://docs.pypi.org/trusted-publishers/adding-a-publisher/
+ environment:
+ name: pypi
+ url: https://pypi.org/p/strands-agents
+ permissions:
+ # IMPORTANT: this permission is mandatory for Trusted Publishing
+ id-token: write
+
+ steps:
+ - name: Download all the dists
+ uses: actions/download-artifact@v4
+ with:
+ name: python-package-distributions
+ path: dist/
+ - name: Publish distribution π¦ to PyPI
+ uses: pypa/gh-action-pypi-publish@release/v1
\ No newline at end of file
diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml
deleted file mode 100644
index 15fbebcb..00000000
--- a/.github/workflows/test-lint-pr.yml
+++ /dev/null
@@ -1,139 +0,0 @@
-name: Test and Lint
-
-on:
- pull_request: # Safer than pull_request_target for untrusted code
- branches: [ main ]
- types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed]
- push:
- branches: [ main ] # Also run on direct pushes to main
-
-jobs:
- check-approval:
- name: Check if PR has contributor approval
- runs-on: ubuntu-latest
- permissions:
- pull-requests: read
- # Skip this check for direct pushes to main
- if: github.event_name == 'pull_request'
- outputs:
- approved: ${{ steps.check-approval.outputs.approved }}
- steps:
- - name: Check if PR has been approved by a contributor
- id: check-approval
- uses: actions/github-script@v7
- with:
- script: |
- const APPROVED_ASSOCIATION = ['COLLABORATOR', 'CONTRIBUTOR', 'MEMBER', 'OWNER']
- const PR_AUTHOR_ASSOCIATION = context.payload.pull_request.author_association;
- const { data: reviews } = await github.rest.pulls.listReviews({
- owner: context.repo.owner,
- repo: context.repo.repo,
- pull_number: context.issue.number,
- });
-
- const isApprovedContributor = APPROVED_ASSOCIATION.includes(PR_AUTHOR_ASSOCIATION);
-
- // Check if any contributor has approved
- const isApproved = reviews.some(review =>
- review.state === 'APPROVED' && APPROVED_ASSOCIATION.includes(review.author_association)
- ) || isApprovedContributor;
-
- core.setOutput('approved', isApproved);
-
- if (!isApproved) {
- core.notice('This PR does not have approval from a Contributor. Workflow will not run test jobs.');
- return false;
- }
-
- return true;
-
- unit-test:
- name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }}
- needs: check-approval
- permissions:
- contents: read
- # Only run if PR is approved or this is a direct push to main
- if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true'
- strategy:
- matrix:
- include:
- # Linux
- - os: ubuntu-latest
- os-name: linux
- python-version: "3.10"
- - os: ubuntu-latest
- os-name: linux
- python-version: "3.11"
- - os: ubuntu-latest
- os-name: linux
- python-version: "3.12"
- - os: ubuntu-latest
- os-name: linux
- python-version: "3.13"
- # Windows
- - os: windows-latest
- os-name: windows
- python-version: "3.10"
- - os: windows-latest
- os-name: windows
- python-version: "3.11"
- - os: windows-latest
- os-name: windows
- python-version: "3.12"
- - os: windows-latest
- os-name: windows
- python-version: "3.13"
- # MacOS - latest only; not enough runners for MacOS
- - os: macos-latest
- os-name: macos
- python-version: "3.13"
- fail-fast: false
- runs-on: ${{ matrix.os }}
- env:
- LOG_LEVEL: DEBUG
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to checkout
- persist-credentials: false # Don't persist credentials for subsequent actions
- - name: Set up Python
- uses: actions/setup-python@v5
- with:
- python-version: ${{ matrix.python-version }}
- - name: Install dependencies
- run: |
- pip install --no-cache-dir hatch
- - name: Run Unit tests
- id: tests
- run: hatch test tests --cover
- continue-on-error: false
-
- lint:
- name: Lint
- runs-on: ubuntu-latest
- needs: check-approval
- permissions:
- contents: read
- if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true'
- steps:
- - name: Checkout code
- uses: actions/checkout@v4
- with:
- ref: ${{ github.event.pull_request.head.sha }}
- persist-credentials: false
-
- - name: Set up Python
- uses: actions/setup-python@v5
- with:
- python-version: '3.10'
- cache: 'pip'
-
- - name: Install dependencies
- run: |
- pip install --no-cache-dir hatch
-
- - name: Run lint
- id: lint
- run: hatch run test-lint
- continue-on-error: false
diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml
new file mode 100644
index 00000000..35e0f584
--- /dev/null
+++ b/.github/workflows/test-lint.yml
@@ -0,0 +1,94 @@
+name: Test and Lint
+
+on:
+ workflow_call:
+ inputs:
+ ref:
+ required: true
+ type: string
+
+jobs:
+ unit-test:
+ name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }}
+ permissions:
+ contents: read
+ strategy:
+ matrix:
+ include:
+ # Linux
+ - os: ubuntu-latest
+ os-name: 'linux'
+ python-version: "3.10"
+ - os: ubuntu-latest
+ os-name: 'linux'
+ python-version: "3.11"
+ - os: ubuntu-latest
+ os-name: 'linux'
+ python-version: "3.12"
+ - os: ubuntu-latest
+ os-name: 'linux'
+ python-version: "3.13"
+ # Windows
+ - os: windows-latest
+ os-name: 'windows'
+ python-version: "3.10"
+ - os: windows-latest
+ os-name: 'windows'
+ python-version: "3.11"
+ - os: windows-latest
+ os-name: 'windows'
+ python-version: "3.12"
+ - os: windows-latest
+ os-name: 'windows'
+ python-version: "3.13"
+ # MacOS - latest only; not enough runners for macOS
+ - os: macos-latest
+ os-name: 'macOS'
+ python-version: "3.13"
+ fail-fast: true
+ runs-on: ${{ matrix.os }}
+ env:
+ LOG_LEVEL: DEBUG
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ inputs.ref }} # Explicitly define which commit to check out
+ persist-credentials: false # Don't persist credentials for subsequent actions
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ pip install --no-cache-dir hatch
+ - name: Run Unit tests
+ id: tests
+ run: hatch test tests --cover
+ continue-on-error: false
+ lint:
+ name: Lint
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v4
+ with:
+ ref: ${{ inputs.ref }}
+ persist-credentials: false
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+ cache: 'pip'
+
+ - name: Install dependencies
+ run: |
+ pip install --no-cache-dir hatch
+
+ - name: Run lint
+ id: lint
+ run: hatch run test-lint
+ continue-on-error: false
diff --git a/.gitignore b/.gitignore
index a80f4bd1..5cdc43db 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@ __pycache__*
.pytest_cache
.ruff_cache
*.bak
+.vscode
+dist
\ No newline at end of file
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index e1ddbb89..fa724cdd 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -31,16 +31,22 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as
### Setting Up Your Development Environment
-1. Install development dependencies:
+1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell.
```bash
- pip install -e ".[dev]" && pip install -e ".[litellm]
+ hatch shell dev
```
+ Alternatively, install development dependencies in a manually created virtual environment:
+ ```bash
+ pip install -e ".[dev]" && pip install -e ".[litellm]"
+ ```
+
+
2. Set up pre-commit hooks:
```bash
pre-commit install -t pre-commit -t commit-msg
```
- This will automatically run formatters and convention commit checks on your code before each commit.
+ This will automatically run formatters and conventional commit checks on your code before each commit.
3. Run code formatters manually:
```bash
@@ -94,6 +100,8 @@ hatch fmt --linter
If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically.
+For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md).
+
## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
@@ -115,7 +123,7 @@ To send us a pull request, please:
## Finding contributions to work on
-Looking at the existing issues is a great way to find something to contribute on.
+Looking at the existing issues is a great way to find something to contribute to.
You can check:
- Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing
diff --git a/README.md b/README.md
index 337acd83..ed98d001 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,35 @@
-# Strands Agents
+
+
-[](LICENSE)
-[](https://www.python.org/downloads/)
+
+ Strands Agents
+
-
Strands Agents is a simple yet powerful SDK that takes a model-driven approach to building and running AI agents. From simple conversational assistants to complex autonomous workflows, from local development to production deployment, Strands Agents scales with your needs.
@@ -19,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t
## Feature Overview
- **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable
-- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Ollama, and custom providers
+- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers
- **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support
- **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools
@@ -105,16 +123,17 @@ from strands.models.llamaapi import LlamaAPIModel
bedrock_model = BedrockModel(
model_id="us.amazon.nova-pro-v1:0",
temperature=0.3,
+ streaming=True, # Enable/disable streaming
)
agent = Agent(model=bedrock_model)
agent("Tell me about Agentic AI")
# Ollama
-ollama_modal = OllamaModel(
+ollama_model = OllamaModel(
host="http://localhost:11434",
model_id="llama3"
)
-agent = Agent(model=ollama_modal)
+agent = Agent(model=ollama_model)
agent("Tell me about Agentic AI")
# Llama API
@@ -129,7 +148,9 @@ Built-in providers:
- [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/)
- [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/)
- [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/)
+ - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/)
- [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/)
+ - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/)
Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/)
@@ -144,7 +165,7 @@ agent = Agent(tools=[calculator])
agent("What is the square root of 1764")
```
-It's also available on GitHub via [strands-agents-tools](https://github.com/strands-agents/strands-agents-tools).
+It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools).
## Documentation
@@ -157,9 +178,9 @@ For detailed guidance & examples, explore our documentation:
- [API Reference](https://strandsagents.com/latest/api-reference/agent/)
- [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/)
-## Contributing
+## Contributing β€οΈ
-We welcome contributions! See our [Contributing Guide](https://github.com/strands-agents/sdk-python/blob/main/CONTRIBUTING.md) for details on:
+We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on:
- Reporting bugs & features
- Development setup
- Contributing via Pull Requests
@@ -170,6 +191,10 @@ We welcome contributions! See our [Contributing Guide](https://github.com/strand
This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details.
+## Security
+
+See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
+
## β οΈ Preview Status
Strands Agents is currently in public preview. During this period:
diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md
new file mode 100644
index 00000000..51dc0a73
--- /dev/null
+++ b/STYLE_GUIDE.md
@@ -0,0 +1,59 @@
+# Style Guide
+
+## Overview
+
+The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable.
+
+Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden.
+
+## Log Formatting
+
+The format for Strands Agents logs is as follows:
+
+```python
+logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...)
+```
+
+### Guidelines
+
+1. **Context**:
+ - Add context as `
=` pairs at the beginning of the log
+ - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching
+ - Use `,`'s to separate pairs
+ - Enclose values in `<>` for readability
+ - This is particularly helpful in displaying empty values (`field=` vs `field=<>`)
+ - Use `%s` for string interpolation as recommended by Python logging
+ - This is an optimization to skip string interpolation when the log level is not enabled
+
+1. **Messages**:
+ - Add human-readable messages at the end of the log
+ - Use lowercase for consistency
+ - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter
+ - Keep messages concise and focused on a single statement
+ - If multiple statements are needed, separate them with the pipe character (`|`)
+ - Example: `"processing request | starting validation"`
+
+### Examples
+
+#### Good
+
+```python
+logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action)
+logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration)
+logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts)
+```
+
+#### Poor
+
+```python
+# Avoid: No structured fields, direct variable interpolation in message
+logger.debug(f"User {user_id} performed action {action}")
+
+# Avoid: Inconsistent formatting, punctuation
+logger.info("Request completed in %d ms.", duration)
+
+# Avoid: No separation between fields and message
+logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts)
+```
+
+By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient.
diff --git a/pyproject.toml b/pyproject.toml
index 6f9b78d6..bd309732 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,11 +1,11 @@
[build-system]
-requires = ["hatchling"]
+requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
[project]
name = "strands-agents"
-version = "0.1.0"
-description = "A production-ready framework for building autonomous AI agents"
+dynamic = ["version"]
+description = "A model-driven approach to building AI agents in just a few lines of code"
readme = "README.md"
requires-python = ">=3.10"
license = {text = "Apache-2.0"}
@@ -33,9 +33,9 @@ dependencies = [
"pydantic>=2.0.0,<3.0.0",
"typing-extensions>=4.13.2,<5.0.0",
"watchdog>=6.0.0,<7.0.0",
- "opentelemetry-api>=1.33.0,<2.0.0",
- "opentelemetry-sdk>=1.33.0,<2.0.0",
- "opentelemetry-exporter-otlp-proto-http>=1.33.0,<2.0.0",
+ "opentelemetry-api>=1.30.0,<2.0.0",
+ "opentelemetry-sdk>=1.30.0,<2.0.0",
+ "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0",
]
[project.urls]
@@ -54,7 +54,7 @@ dev = [
"commitizen>=4.4.0,<5.0.0",
"hatch>=1.0.0,<2.0.0",
"moto>=5.1.0,<6.0.0",
- "mypy>=0.981,<1.0.0",
+ "mypy>=1.15.0,<2.0.0",
"pre-commit>=3.2.0,<4.2.0",
"pytest>=8.0.0,<9.0.0",
"pytest-asyncio>=0.26.0,<0.27.0",
@@ -69,15 +69,22 @@ docs = [
litellm = [
"litellm>=1.69.0,<2.0.0",
]
+llamaapi = [
+ "llama-api-client>=0.1.0,<1.0.0",
+]
ollama = [
"ollama>=0.4.8,<1.0.0",
]
-llamaapi = [
- "llama-api-client>=0.1.0,<1.0.0",
+openai = [
+ "openai>=1.68.0,<2.0.0",
]
+[tool.hatch.version]
+# Tells Hatch to use your version control system (git) to determine the version.
+source = "vcs"
+
[tool.hatch.envs.hatch-static-analysis]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
dependencies = [
"mypy>=1.15.0,<2.0.0",
"ruff>=0.11.6,<0.12.0",
@@ -100,7 +107,7 @@ lint-fix = [
]
[tool.hatch.envs.hatch-test]
-features = ["anthropic", "litellm", "llamaapi", "ollama"]
+features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"]
extra-dependencies = [
"moto>=5.1.0,<6.0.0",
"pytest>=8.0.0,<9.0.0",
@@ -114,6 +121,11 @@ extra-args = [
"-vv",
]
+[tool.hatch.envs.dev]
+dev-mode = true
+features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"]
+
+
[[tool.hatch.envs.hatch-test.matrix]]
python = ["3.13", "3.12", "3.11", "3.10"]
@@ -185,7 +197,9 @@ select = [
"D", # pydocstyle
"E", # pycodestyle
"F", # pyflakes
+ "G", # logging format
"I", # isort
+ "LOG", # logging
]
[tool.ruff.lint.per-file-ignores]
diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py
index 89653036..0f912b54 100644
--- a/src/strands/agent/agent.py
+++ b/src/strands/agent/agent.py
@@ -6,7 +6,7 @@
The Agent interface supports two complementary interaction patterns:
1. Natural language for conversation: `agent("Analyze this data")`
-2. Method-style for direct tool access: `agent.tool_name(param1="value")`
+2. Method-style for direct tool access: `agent.tool.tool_name(param1="value")`
"""
import asyncio
@@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult:
- metrics: Performance metrics from the event loop
- state: The final state of the event loop
"""
- model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
-
- self.trace_span = self.tracer.start_agent_span(
- prompt=prompt,
- model_id=model_id,
- tools=self.tool_names,
- system_prompt=self.system_prompt,
- custom_trace_attributes=self.trace_attributes,
- )
+ self._start_agent_trace_span(prompt)
try:
# Run the event loop and get the result
result = self._run_loop(prompt, kwargs)
- if self.trace_span:
- self.tracer.end_agent_span(span=self.trace_span, response=result)
+ self._end_agent_trace_span(response=result)
return result
except Exception as e:
- if self.trace_span:
- self.tracer.end_agent_span(span=self.trace_span, error=e)
+ self._end_agent_trace_span(error=e)
# Re-raise the exception to preserve original behavior
raise
@@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]:
yield event["data"]
```
"""
+ self._start_agent_trace_span(prompt)
+
_stop_event = uuid4()
queue = asyncio.Queue[Any]()
@@ -400,8 +392,10 @@ def target_callback() -> None:
nonlocal kwargs
try:
- self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
- except BaseException as e:
+ result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler)
+ self._end_agent_trace_span(response=result)
+ except Exception as e:
+ self._end_agent_trace_span(error=e)
enqueue(e)
finally:
enqueue(_stop_event)
@@ -414,7 +408,7 @@ def target_callback() -> None:
item = await queue.get()
if item == _stop_event:
break
- if isinstance(item, BaseException):
+ if isinstance(item, Exception):
raise item
yield item
finally:
@@ -457,27 +451,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
Returns:
The result of the event loop cycle.
"""
- kwargs.pop("agent", None)
- kwargs.pop("model", None)
- kwargs.pop("system_prompt", None)
- kwargs.pop("tool_execution_handler", None)
- kwargs.pop("event_loop_metrics", None)
- kwargs.pop("callback_handler", None)
- kwargs.pop("tool_handler", None)
- kwargs.pop("messages", None)
- kwargs.pop("tool_config", None)
+ # Extract parameters with fallbacks to instance values
+ system_prompt = kwargs.pop("system_prompt", self.system_prompt)
+ model = kwargs.pop("model", self.model)
+ tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper)
+ event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics)
+ callback_handler_override = kwargs.pop("callback_handler", callback_handler)
+ tool_handler = kwargs.pop("tool_handler", self.tool_handler)
+ messages = kwargs.pop("messages", self.messages)
+ tool_config = kwargs.pop("tool_config", self.tool_config)
+ kwargs.pop("agent", None) # Remove agent to avoid conflicts
try:
# Execute the main event loop cycle
stop_reason, message, metrics, state = event_loop_cycle(
- model=self.model,
- system_prompt=self.system_prompt,
- messages=self.messages, # will be modified by event_loop_cycle
- tool_config=self.tool_config,
- callback_handler=callback_handler,
- tool_handler=self.tool_handler,
- tool_execution_handler=self.thread_pool_wrapper,
- event_loop_metrics=self.event_loop_metrics,
+ model=model,
+ system_prompt=system_prompt,
+ messages=messages, # will be modified by event_loop_cycle
+ tool_config=tool_config,
+ callback_handler=callback_handler_override,
+ tool_handler=tool_handler,
+ tool_execution_handler=tool_execution_handler,
+ event_loop_metrics=event_loop_metrics,
agent=self,
event_loop_parent_span=self.trace_span,
**kwargs,
@@ -488,8 +483,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str
except ContextWindowOverflowException as e:
# Try reducing the context size and retrying
- self.conversation_manager.reduce_context(self.messages, e=e)
- return self._execute_event_loop_cycle(callback_handler, kwargs)
+ self.conversation_manager.reduce_context(messages, e=e)
+ return self._execute_event_loop_cycle(callback_handler_override, kwargs)
def _record_tool_execution(
self,
@@ -515,7 +510,7 @@ def _record_tool_execution(
"""
# Create user message describing the tool call
user_msg_content = [
- {"text": (f"agent.{tool['name']} direct tool call\nInput parameters: {json.dumps(tool['input'])}\n")}
+ {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")}
]
# Add override message if provided
@@ -545,3 +540,43 @@ def _record_tool_execution(
messages.append(tool_use_msg)
messages.append(tool_result_msg)
messages.append(assistant_msg)
+
+ def _start_agent_trace_span(self, prompt: str) -> None:
+ """Starts a trace span for the agent.
+
+ Args:
+ prompt: The natural language prompt from the user.
+ """
+ model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None
+
+ self.trace_span = self.tracer.start_agent_span(
+ prompt=prompt,
+ model_id=model_id,
+ tools=self.tool_names,
+ system_prompt=self.system_prompt,
+ custom_trace_attributes=self.trace_attributes,
+ )
+
+ def _end_agent_trace_span(
+ self,
+ response: Optional[AgentResult] = None,
+ error: Optional[Exception] = None,
+ ) -> None:
+ """Ends a trace span for the agent.
+
+ Args:
+ span: The span to end.
+ response: Response to record as a trace attribute.
+ error: Error to record as a trace attribute.
+ """
+ if self.trace_span:
+ trace_attributes: Dict[str, Any] = {
+ "span": self.trace_span,
+ }
+
+ if response:
+ trace_attributes["response"] = response
+ if error:
+ trace_attributes["error"] = error
+
+ self.tracer.end_agent_span(**trace_attributes)
diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
index 4b11e81c..f367b272 100644
--- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
+++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py
@@ -1,12 +1,10 @@
"""Sliding window conversation history management."""
-import json
import logging
-from typing import List, Optional, cast
+from typing import Optional
-from ...types.content import ContentBlock, Message, Messages
+from ...types.content import Message, Messages
from ...types.exceptions import ContextWindowOverflowException
-from ...types.tools import ToolResult
from .conversation_manager import ConversationManager
logger = logging.getLogger(__name__)
@@ -110,8 +108,9 @@ def _remove_dangling_messages(self, messages: Messages) -> None:
def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None:
"""Trim the oldest messages to reduce the conversation context size.
- The method handles special cases where tool results need to be converted to regular content blocks to maintain
- conversation coherence after trimming.
+ The method handles special cases where trimming the messages leads to:
+ - toolResult with no corresponding toolUse
+ - toolUse with no corresponding toolResult
Args:
messages: The messages to reduce.
@@ -126,52 +125,24 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N
# If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size
- # Throw if we cannot trim any messages from the conversation
- if trim_index >= len(messages):
- raise ContextWindowOverflowException("Unable to trim conversation context!") from e
-
- # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the
- # limitation of needing ToolUse and ToolResults to be paired.
- if any("toolResult" in content for content in messages[trim_index]["content"]):
- if len(messages[trim_index]["content"]) == 1:
- messages[trim_index]["content"] = self._map_tool_result_content(
- cast(ToolResult, messages[trim_index]["content"][0]["toolResult"])
+ # Find the next valid trim_index
+ while trim_index < len(messages):
+ if (
+ # Oldest message cannot be a toolResult because it needs a toolUse preceding it
+ any("toolResult" in content for content in messages[trim_index]["content"])
+ or (
+ # Oldest message can be a toolUse only if a toolResult immediately follows it.
+ any("toolUse" in content for content in messages[trim_index]["content"])
+ and trim_index + 1 < len(messages)
+ and not any("toolResult" in content for content in messages[trim_index + 1]["content"])
)
-
- # If there is more content than just one ToolResultContent, then we cannot cut at this index.
+ ):
+ trim_index += 1
else:
- raise ContextWindowOverflowException("Unable to trim conversation context!") from e
+ break
+ else:
+ # If we didn't find a valid trim_index, then we throw
+ raise ContextWindowOverflowException("Unable to trim conversation context!") from e
# Overwrite message history
messages[:] = messages[trim_index:]
-
- def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]:
- """Convert a ToolResult to a list of standard ContentBlocks.
-
- This method transforms tool result content into standard content blocks that can be preserved when trimming the
- conversation history.
-
- Args:
- tool_result: The ToolResult to convert.
-
- Returns:
- A list of content blocks representing the tool result.
- """
- contents = []
- text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else ""
-
- for tool_result_content in tool_result["content"]:
- if "text" in tool_result_content:
- text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}"
- elif "json" in tool_result_content:
- text_content = (
- "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}"
- )
- elif "image" in tool_result_content:
- contents.append(ContentBlock(image=tool_result_content["image"]))
- elif "document" in tool_result_content:
- contents.append(ContentBlock(document=tool_result_content["document"]))
- else:
- logger.warning("unsupported content type")
- contents.append(ContentBlock(text=text_content))
- return contents
diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py
index 5a78bb3f..a5c85668 100644
--- a/src/strands/event_loop/error_handler.py
+++ b/src/strands/event_loop/error_handler.py
@@ -74,7 +74,7 @@ def handle_input_too_long_error(
"""Handle 'Input is too long' errors by truncating tool results.
When a context window overflow exception occurs (input too long for the model), this function attempts to recover
- by finding and truncating the most recent tool results in the conversation history. If trunction is successful, the
+ by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the
function will make a call to the event loop.
Args:
diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py
index db5a1b97..23d7bd0f 100644
--- a/src/strands/event_loop/event_loop.py
+++ b/src/strands/event_loop/event_loop.py
@@ -28,6 +28,10 @@
logger = logging.getLogger(__name__)
+MAX_ATTEMPTS = 6
+INITIAL_DELAY = 4
+MAX_DELAY = 240 # 4 minutes
+
def initialize_state(**kwargs: Any) -> Any:
"""Initialize the request state if not present.
@@ -51,7 +55,7 @@ def event_loop_cycle(
system_prompt: Optional[str],
messages: Messages,
tool_config: Optional[ToolConfig],
- callback_handler: Any,
+ callback_handler: Callable[..., Any],
tool_handler: Optional[ToolHandler],
tool_execution_handler: Optional[ParallelToolExecutorInterface] = None,
**kwargs: Any,
@@ -130,13 +134,9 @@ def event_loop_cycle(
stop_reason: StopReason
usage: Any
metrics: Metrics
- max_attempts = 6
- initial_delay = 4
- max_delay = 240 # 4 minutes
- current_delay = initial_delay
# Retry loop for handling throttling exceptions
- for attempt in range(max_attempts):
+ for attempt in range(MAX_ATTEMPTS):
model_id = model.config.get("model_id") if hasattr(model, "config") else None
model_invoke_span = tracer.start_model_invoke_span(
parent_span=cycle_span,
@@ -177,7 +177,7 @@ def event_loop_cycle(
# Handle throttling errors with exponential backoff
should_retry, current_delay = handle_throttling_error(
- e, attempt, max_attempts, current_delay, max_delay, callback_handler, kwargs
+ e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs
)
if should_retry:
continue
@@ -204,80 +204,35 @@ def event_loop_cycle(
# If the model is requesting to use tools
if stop_reason == "tool_use":
- tool_uses: List[ToolUse] = []
- tool_results: List[ToolResult] = []
- invalid_tool_use_ids: List[str] = []
-
- # Extract and validate tools
- validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
-
- # Check if tools are available for execution
- if tool_uses:
- if tool_handler is None:
- raise ValueError("toolUse present but tool handler not set")
- if tool_config is None:
- raise ValueError("toolUse present but tool config not set")
-
- # Create the tool handler process callable
- tool_handler_process: Callable[[ToolUse], ToolResult] = partial(
- tool_handler.process,
- messages=messages,
- model=model,
- system_prompt=system_prompt,
- tool_config=tool_config,
- callback_handler=callback_handler,
- **kwargs,
+ if not tool_handler:
+ raise EventLoopException(
+ Exception("Model requested tool use but no tool handler provided"),
+ kwargs["request_state"],
)
- # Execute tools (parallel or sequential)
- run_tools(
- handler=tool_handler_process,
- tool_uses=tool_uses,
- event_loop_metrics=event_loop_metrics,
- request_state=cast(Any, kwargs["request_state"]),
- invalid_tool_use_ids=invalid_tool_use_ids,
- tool_results=tool_results,
- cycle_trace=cycle_trace,
- parent_span=cycle_span,
- parallel_tool_executor=tool_execution_handler,
+ if tool_config is None:
+ raise EventLoopException(
+ Exception("Model requested tool use but no tool config provided"),
+ kwargs["request_state"],
)
- # Update state for the next cycle
- kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
-
- # Create the tool result message
- tool_result_message: Message = {
- "role": "user",
- "content": [{"toolResult": result} for result in tool_results],
- }
- messages.append(tool_result_message)
- callback_handler(message=tool_result_message)
-
- if cycle_span:
- tracer.end_event_loop_cycle_span(
- span=cycle_span, message=message, tool_result_message=tool_result_message
- )
-
- # Check if we should stop the event loop
- if kwargs["request_state"].get("stop_event_loop"):
- event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
- return (
- stop_reason,
- message,
- event_loop_metrics,
- kwargs["request_state"],
- )
-
- # Recursive call to continue the conversation
- return recurse_event_loop(
- model=model,
- system_prompt=system_prompt,
- messages=messages,
- tool_config=tool_config,
- callback_handler=callback_handler,
- tool_handler=tool_handler,
- **kwargs,
- )
+ # Handle tool execution
+ return _handle_tool_execution(
+ stop_reason,
+ message,
+ model,
+ system_prompt,
+ messages,
+ tool_config,
+ tool_handler,
+ callback_handler,
+ tool_execution_handler,
+ event_loop_metrics,
+ cycle_trace,
+ cycle_span,
+ cycle_start_time,
+ kwargs,
+ )
# End the cycle and return results
event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
@@ -377,3 +332,105 @@ def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetr
kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"]
return kwargs
+
+
+def _handle_tool_execution(
+ stop_reason: StopReason,
+ message: Message,
+ model: Model,
+ system_prompt: Optional[str],
+ messages: Messages,
+ tool_config: ToolConfig,
+ tool_handler: ToolHandler,
+ callback_handler: Callable[..., Any],
+ tool_execution_handler: Optional[ParallelToolExecutorInterface],
+ event_loop_metrics: EventLoopMetrics,
+ cycle_trace: Trace,
+ cycle_span: Any,
+ cycle_start_time: float,
+ kwargs: Dict[str, Any],
+) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
+ tool_uses: List[ToolUse] = []
+ tool_results: List[ToolResult] = []
+ invalid_tool_use_ids: List[str] = []
+
+ """
+ Handles the execution of tools requested by the model during an event loop cycle.
+
+ Args:
+ stop_reason (StopReason): The reason the model stopped generating.
+ message (Message): The message from the model that may contain tool use requests.
+ model (Model): The model provider instance.
+ system_prompt (Optional[str]): The system prompt instructions for the model.
+ messages (Messages): The conversation history messages.
+ tool_config (ToolConfig): Configuration for available tools.
+ tool_handler (ToolHandler): Handler for tool execution.
+ callback_handler (Callable[..., Any]): Callback for processing events as they happen.
+ tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution.
+ event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop.
+ cycle_trace (Trace): Trace object for the current event loop cycle.
+ cycle_span (Any): Span object for tracing the cycle (type may vary).
+ cycle_start_time (float): Start time of the current cycle.
+ kwargs (Dict[str, Any]): Additional keyword arguments, including request state.
+
+ Returns:
+ Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]:
+ - The stop reason,
+ - The updated message,
+ - The updated event loop metrics,
+ - The updated request state.
+ """
+ validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
+
+ if not tool_uses:
+ return stop_reason, message, event_loop_metrics, kwargs["request_state"]
+
+ tool_handler_process = partial(
+ tool_handler.process,
+ messages=messages,
+ model=model,
+ system_prompt=system_prompt,
+ tool_config=tool_config,
+ callback_handler=callback_handler,
+ **kwargs,
+ )
+
+ run_tools(
+ handler=tool_handler_process,
+ tool_uses=tool_uses,
+ event_loop_metrics=event_loop_metrics,
+ request_state=cast(Any, kwargs["request_state"]),
+ invalid_tool_use_ids=invalid_tool_use_ids,
+ tool_results=tool_results,
+ cycle_trace=cycle_trace,
+ parent_span=cycle_span,
+ parallel_tool_executor=tool_execution_handler,
+ )
+
+ kwargs = prepare_next_cycle(kwargs, event_loop_metrics)
+
+ tool_result_message: Message = {
+ "role": "user",
+ "content": [{"toolResult": result} for result in tool_results],
+ }
+
+ messages.append(tool_result_message)
+ callback_handler(message=tool_result_message)
+
+ if cycle_span:
+ tracer = get_tracer()
+ tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)
+
+ if kwargs["request_state"].get("stop_event_loop", False):
+ event_loop_metrics.end_cycle(cycle_start_time, cycle_trace)
+ return stop_reason, message, event_loop_metrics, kwargs["request_state"]
+
+ return recurse_event_loop(
+ model=model,
+ system_prompt=system_prompt,
+ messages=messages,
+ tool_config=tool_config,
+ callback_handler=callback_handler,
+ tool_handler=tool_handler,
+ **kwargs,
+ )
diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py
index d6d104d8..4b794b4f 100644
--- a/src/strands/handlers/callback_handler.py
+++ b/src/strands/handlers/callback_handler.py
@@ -17,15 +17,19 @@ def __call__(self, **kwargs: Any) -> None:
Args:
**kwargs: Callback event data including:
-
+ - reasoningText (Optional[str]): Reasoning text to print if provided.
- data (str): Text content to stream.
- complete (bool): Whether this is the final chunk of a response.
- current_tool_use (dict): Information about the current tool being used.
"""
+ reasoningText = kwargs.get("reasoningText", False)
data = kwargs.get("data", "")
complete = kwargs.get("complete", False)
current_tool_use = kwargs.get("current_tool_use", {})
+ if reasoningText:
+ print(reasoningText, end="")
+
if data:
print(data, end="" if not complete else "\n")
diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py
index 0803eca5..bc4ec1ce 100644
--- a/src/strands/handlers/tool_handler.py
+++ b/src/strands/handlers/tool_handler.py
@@ -46,6 +46,7 @@ def preprocess(
def process(
self,
tool: Any,
+ *,
model: Model,
system_prompt: Optional[str],
messages: List[Any],
diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py
index 704114eb..99e49f81 100644
--- a/src/strands/models/anthropic.py
+++ b/src/strands/models/anthropic.py
@@ -97,13 +97,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
Anthropic formatted content block.
"""
if "document" in content:
+ mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
return {
"source": {
- "data": base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8"),
- "media_type": mimetypes.types_map.get(
- f".{content['document']['format']}", "application/octet-stream"
+ "data": (
+ content["document"]["source"]["bytes"].decode("utf-8")
+ if mime_type == "text/plain"
+ else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8")
),
- "type": "base64",
+ "media_type": mime_type,
+ "type": "text" if mime_type == "text/plain" else "base64",
},
"title": content["document"]["name"],
"type": "document",
diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index d996aaae..9bbcca7d 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -3,12 +3,14 @@
- Docs: https://aws.amazon.com/bedrock/
"""
+import json
import logging
-from typing import Any, Iterable, Literal, Optional, cast
+import os
+from typing import Any, Iterable, List, Literal, Optional, cast
import boto3
from botocore.config import Config as BotocoreConfig
-from botocore.exceptions import ClientError, EventStreamError
+from botocore.exceptions import ClientError
from typing_extensions import TypedDict, Unpack, override
from ..types.content import Messages
@@ -60,6 +62,7 @@ class BedrockConfig(TypedDict, total=False):
max_tokens: Maximum number of tokens to generate in the response
model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0")
stop_sequences: List of sequences that will stop generation when encountered
+ streaming: Flag to enable/disable streaming. Defaults to True.
temperature: Controls randomness in generation (higher = more random)
top_p: Controls diversity via nucleus sampling (alternative to temperature)
"""
@@ -80,6 +83,7 @@ class BedrockConfig(TypedDict, total=False):
max_tokens: Optional[int]
model_id: str
stop_sequences: Optional[list[str]]
+ streaming: Optional[bool]
temperature: Optional[float]
top_p: Optional[float]
@@ -96,7 +100,8 @@ def __init__(
Args:
boto_session: Boto Session to use when calling the Bedrock Model.
boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client.
- region_name: AWS region to use for the Bedrock service. Defaults to "us-west-2".
+ region_name: AWS region to use for the Bedrock service.
+ Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set.
**model_config: Configuration options for the Bedrock model.
"""
if region_name and boto_session:
@@ -108,11 +113,26 @@ def __init__(
logger.debug("config=<%s> | initializing", self.config)
session = boto_session or boto3.Session(
- region_name=region_name or "us-west-2",
+ region_name=region_name or os.getenv("AWS_REGION") or "us-west-2",
)
+
+ # Add strands-agents to the request user agent
+ if boto_client_config:
+ existing_user_agent = getattr(boto_client_config, "user_agent_extra", None)
+
+ # Append 'strands-agents' to existing user_agent_extra or set it if not present
+ if existing_user_agent:
+ new_user_agent = f"{existing_user_agent} strands-agents"
+ else:
+ new_user_agent = "strands-agents"
+
+ client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent))
+ else:
+ client_config = BotocoreConfig(user_agent_extra="strands-agents")
+
self.client = session.client(
service_name="bedrock-runtime",
- config=boto_client_config,
+ config=client_config,
)
@override
@@ -129,7 +149,7 @@ def get_config(self) -> BedrockConfig:
"""Get the current Bedrock Model configuration.
Returns:
- The Bedrok model configuration.
+ The Bedrock model configuration.
"""
return self.config
@@ -229,11 +249,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
"""
return cast(StreamEvent, event)
+ def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool:
+ """Check if guardrail data contains any blocked policies.
+
+ Args:
+ guardrail_data: Guardrail data from trace information.
+
+ Returns:
+ True if any blocked guardrail is detected, False otherwise.
+ """
+ input_assessment = guardrail_data.get("inputAssessment", {})
+ output_assessments = guardrail_data.get("outputAssessments", {})
+
+ # Check input assessments
+ if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()):
+ return True
+
+ # Check output assessments
+ if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()):
+ return True
+
+ return False
+
+ def _generate_redaction_events(self) -> list[StreamEvent]:
+ """Generate redaction events based on configuration.
+
+ Returns:
+ List of redaction events to yield.
+ """
+ events: List[StreamEvent] = []
+
+ if self.config.get("guardrail_redact_input", True):
+ logger.debug("Redacting user input due to guardrail.")
+ events.append(
+ {
+ "redactContent": {
+ "redactUserContentMessage": self.config.get(
+ "guardrail_redact_input_message", "[User input redacted.]"
+ )
+ }
+ }
+ )
+
+ if self.config.get("guardrail_redact_output", False):
+ logger.debug("Redacting assistant output due to guardrail.")
+ events.append(
+ {
+ "redactContent": {
+ "redactAssistantContentMessage": self.config.get(
+ "guardrail_redact_output_message", "[Assistant output redacted.]"
+ )
+ }
+ }
+ )
+
+ return events
+
@override
- def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
- """Send the request to the Bedrock model and get the streaming response.
+ def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]:
+ """Send the request to the Bedrock model and get the response.
- This method calls the Bedrock converse_stream API and returns the stream of response events.
+ This method calls either the Bedrock converse_stream API or the converse API
+ based on the streaming parameter in the configuration.
Args:
request: The formatted request to send to the Bedrock model
@@ -243,63 +320,132 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
Raises:
ContextWindowOverflowException: If the input exceeds the model's context window.
- EventStreamError: For all other Bedrock API errors.
+ ModelThrottledException: If the model service is throttling requests.
"""
+ streaming = self.config.get("streaming", True)
+
try:
- response = self.client.converse_stream(**request)
- for chunk in response["stream"]:
- if self.config.get("guardrail_redact_input", True) or self.config.get("guardrail_redact_output", False):
+ if streaming:
+ # Streaming implementation
+ response = self.client.converse_stream(**request)
+ for chunk in response["stream"]:
if (
"metadata" in chunk
and "trace" in chunk["metadata"]
and "guardrail" in chunk["metadata"]["trace"]
):
- inputAssessment = chunk["metadata"]["trace"]["guardrail"].get("inputAssessment", {})
- outputAssessments = chunk["metadata"]["trace"]["guardrail"].get("outputAssessments", {})
-
- # Check if an input or output guardrail was triggered
- if any(
- self._find_detected_and_blocked_policy(assessment)
- for assessment in inputAssessment.values()
- ) or any(
- self._find_detected_and_blocked_policy(assessment)
- for assessment in outputAssessments.values()
- ):
- if self.config.get("guardrail_redact_input", True):
- logger.debug("Found blocked input guardrail. Redacting input.")
- yield {
- "redactContent": {
- "redactUserContentMessage": self.config.get(
- "guardrail_redact_input_message", "[User input redacted.]"
- )
- }
- }
- if self.config.get("guardrail_redact_output", False):
- logger.debug("Found blocked output guardrail. Redacting output.")
- yield {
- "redactContent": {
- "redactAssistantContentMessage": self.config.get(
- "guardrail_redact_output_message", "[Assistant output redacted.]"
- )
- }
- }
+ guardrail_data = chunk["metadata"]["trace"]["guardrail"]
+ if self._has_blocked_guardrail(guardrail_data):
+ yield from self._generate_redaction_events()
+ yield chunk
+ else:
+ # Non-streaming implementation
+ response = self.client.converse(**request)
+
+ # Convert and yield from the response
+ yield from self._convert_non_streaming_to_streaming(response)
+
+ # Check for guardrail triggers after yielding any events (same as streaming path)
+ if (
+ "trace" in response
+ and "guardrail" in response["trace"]
+ and self._has_blocked_guardrail(response["trace"]["guardrail"])
+ ):
+ yield from self._generate_redaction_events()
- yield chunk
- except EventStreamError as e:
- # Handle throttling that occurs mid-stream?
- if "ThrottlingException" in str(e) and "ConverseStream" in str(e):
- raise ModelThrottledException(str(e)) from e
+ except ClientError as e:
+ error_message = str(e)
- if any(overflow_message in str(e) for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+ # Handle throttling error
+ if e.response["Error"]["Code"] == "ThrottlingException":
+ raise ModelThrottledException(error_message) from e
+
+ # Handle context window overflow
+ if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
logger.warning("bedrock threw context window overflow error")
raise ContextWindowOverflowException(e) from e
+
+ # Otherwise raise the error
raise e
- except ClientError as e:
- # Handle throttling that occurs at the beginning of the call
- if e.response["Error"]["Code"] == "ThrottlingException":
- raise ModelThrottledException(str(e)) from e
- raise
+ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
+ """Convert a non-streaming response to the streaming format.
+
+ Args:
+ response: The non-streaming response from the Bedrock model.
+
+ Returns:
+ An iterable of response events in the streaming format.
+ """
+ # Yield messageStart event
+ yield {"messageStart": {"role": response["output"]["message"]["role"]}}
+
+ # Process content blocks
+ for content in response["output"]["message"]["content"]:
+ # Yield contentBlockStart event if needed
+ if "toolUse" in content:
+ yield {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {
+ "toolUseId": content["toolUse"]["toolUseId"],
+ "name": content["toolUse"]["name"],
+ }
+ },
+ }
+ }
+
+ # For tool use, we need to yield the input as a delta
+ input_value = json.dumps(content["toolUse"]["input"])
+
+ yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}}
+ elif "text" in content:
+ # Then yield the text as a delta
+ yield {
+ "contentBlockDelta": {
+ "delta": {"text": content["text"]},
+ }
+ }
+ elif "reasoningContent" in content:
+ # Then yield the reasoning content as a delta
+ yield {
+ "contentBlockDelta": {
+ "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}}
+ }
+ }
+
+ if "signature" in content["reasoningContent"]["reasoningText"]:
+ yield {
+ "contentBlockDelta": {
+ "delta": {
+ "reasoningContent": {
+ "signature": content["reasoningContent"]["reasoningText"]["signature"]
+ }
+ }
+ }
+ }
+
+ # Yield contentBlockStop event
+ yield {"contentBlockStop": {}}
+
+ # Yield messageStop event
+ yield {
+ "messageStop": {
+ "stopReason": response["stopReason"],
+ "additionalModelResponseFields": response.get("additionalModelResponseFields"),
+ }
+ }
+
+ # Yield metadata event
+ if "usage" in response or "metrics" in response or "trace" in response:
+ metadata: StreamEvent = {"metadata": {}}
+ if "usage" in response:
+ metadata["metadata"]["usage"] = response["usage"]
+ if "metrics" in response:
+ metadata["metadata"]["metrics"] = response["metrics"]
+ if "trace" in response:
+ metadata["metadata"]["trace"] = response["trace"]
+ yield metadata
def _find_detected_and_blocked_policy(self, input: Any) -> bool:
"""Recursively checks if the assessment contains a detected and blocked guardrail.
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py
index a7563133..23d2c2ae 100644
--- a/src/strands/models/litellm.py
+++ b/src/strands/models/litellm.py
@@ -3,23 +3,19 @@
- Docs: https://docs.litellm.ai/
"""
-import json
import logging
-import mimetypes
-from typing import Any, Iterable, Optional, TypedDict
+from typing import Any, Optional, TypedDict, cast
import litellm
from typing_extensions import Unpack, override
-from ..types.content import ContentBlock, Messages
-from ..types.models import Model
-from ..types.streaming import StreamEvent
-from ..types.tools import ToolResult, ToolSpec, ToolUse
+from ..types.content import ContentBlock
+from .openai import OpenAIModel
logger = logging.getLogger(__name__)
-class LiteLLMModel(Model):
+class LiteLLMModel(OpenAIModel):
"""LiteLLM model provider implementation."""
class LiteLLMConfig(TypedDict, total=False):
@@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config:
https://github.com/BerriAI/litellm/blob/main/litellm/main.py.
**model_config: Configuration options for the LiteLLM model.
"""
- self.config = LiteLLMModel.LiteLLMConfig(**model_config)
+ self.config = dict(model_config)
logger.debug("config=<%s> | initializing", self.config)
@@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig:
Returns:
The LiteLLM model configuration.
"""
- return self.config
+ return cast(LiteLLMModel.LiteLLMConfig, self.config)
- def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]:
+ @override
+ @staticmethod
+ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
"""Format a LiteLLM content block.
Args:
@@ -79,18 +77,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
Returns:
LiteLLM formatted content block.
"""
- if "image" in content:
- mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
- image_data = content["image"]["source"]["bytes"].decode("utf-8")
- return {
- "image_url": {
- "detail": "auto",
- "format": mime_type,
- "url": f"data:{mime_type};base64,{image_data}",
- },
- "type": "image_url",
- }
-
if "reasoningContent" in content:
return {
"signature": content["reasoningContent"]["reasoningText"]["signature"],
@@ -98,9 +84,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
"type": "thinking",
}
- if "text" in content:
- return {"text": content["text"], "type": "text"}
-
if "video" in content:
return {
"type": "video_url",
@@ -110,230 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
},
}
- return {"text": json.dumps(content), "type": "text"}
-
- def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]:
- """Format a LiteLLM tool call.
-
- Args:
- tool_use: Tool use requested by the model.
-
- Returns:
- LiteLLM formatted tool call.
- """
- return {
- "function": {
- "arguments": json.dumps(tool_use["input"]),
- "name": tool_use["name"],
- },
- "id": tool_use["toolUseId"],
- "type": "function",
- }
-
- def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]:
- """Format a LiteLLM tool message.
-
- Args:
- tool_result: Tool result collected from a tool execution.
-
- Returns:
- LiteLLM formatted tool message.
- """
- return {
- "role": "tool",
- "tool_call_id": tool_result["toolUseId"],
- "content": json.dumps(
- {
- "content": tool_result["content"],
- "status": tool_result["status"],
- }
- ),
- }
-
- def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
- """Format a LiteLLM messages array.
-
- Args:
- messages: List of message objects to be processed by the model.
- system_prompt: System prompt to provide context to the model.
-
- Returns:
- A LiteLLM messages array.
- """
- formatted_messages: list[dict[str, Any]]
- formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
-
- for message in messages:
- contents = message["content"]
-
- formatted_contents = [
- self._format_request_message_content(content)
- for content in contents
- if not any(block_type in content for block_type in ["toolResult", "toolUse"])
- ]
- formatted_tool_calls = [
- self._format_request_message_tool_call(content["toolUse"])
- for content in contents
- if "toolUse" in content
- ]
- formatted_tool_messages = [
- self._format_request_tool_message(content["toolResult"])
- for content in contents
- if "toolResult" in content
- ]
-
- formatted_message = {
- "role": message["role"],
- "content": formatted_contents,
- **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
- }
- formatted_messages.append(formatted_message)
- formatted_messages.extend(formatted_tool_messages)
-
- return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
-
- @override
- def format_request(
- self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
- ) -> dict[str, Any]:
- """Format a LiteLLM chat streaming request.
-
- Args:
- messages: List of message objects to be processed by the model.
- tool_specs: List of tool specifications to make available to the model.
- system_prompt: System prompt to provide context to the model.
-
- Returns:
- A LiteLLM chat streaming request.
- """
- return {
- "messages": self._format_request_messages(messages, system_prompt),
- "model": self.config["model_id"],
- "stream": True,
- "stream_options": {"include_usage": True},
- "tools": [
- {
- "type": "function",
- "function": {
- "name": tool_spec["name"],
- "description": tool_spec["description"],
- "parameters": tool_spec["inputSchema"]["json"],
- },
- }
- for tool_spec in tool_specs or []
- ],
- **(self.config.get("params") or {}),
- }
-
- @override
- def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
- """Format the LiteLLM response events into standardized message chunks.
-
- Args:
- event: A response event from the LiteLLM model.
-
- Returns:
- The formatted chunk.
-
- Raises:
- RuntimeError: If chunk_type is not recognized.
- This error should never be encountered as we control chunk_type in the stream method.
- """
- match event["chunk_type"]:
- case "message_start":
- return {"messageStart": {"role": "assistant"}}
-
- case "content_start":
- if event["data_type"] == "tool":
- return {
- "contentBlockStart": {
- "start": {
- "toolUse": {
- "name": event["data"].function.name,
- "toolUseId": event["data"].id,
- }
- }
- }
- }
-
- return {"contentBlockStart": {"start": {}}}
-
- case "content_delta":
- if event["data_type"] == "tool":
- return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}}
-
- return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
-
- case "content_stop":
- return {"contentBlockStop": {}}
-
- case "message_stop":
- match event["data"]:
- case "tool_calls":
- return {"messageStop": {"stopReason": "tool_use"}}
- case "length":
- return {"messageStop": {"stopReason": "max_tokens"}}
- case _:
- return {"messageStop": {"stopReason": "end_turn"}}
-
- case "metadata":
- return {
- "metadata": {
- "usage": {
- "inputTokens": event["data"].prompt_tokens,
- "outputTokens": event["data"].completion_tokens,
- "totalTokens": event["data"].total_tokens,
- },
- "metrics": {
- "latencyMs": 0, # TODO
- },
- },
- }
-
- case _:
- raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
-
- @override
- def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
- """Send the request to the LiteLLM model and get the streaming response.
-
- Args:
- request: The formatted request to send to the LiteLLM model.
-
- Returns:
- An iterable of response events from the LiteLLM model.
- """
- response = self.client.chat.completions.create(**request)
-
- yield {"chunk_type": "message_start"}
- yield {"chunk_type": "content_start", "data_type": "text"}
-
- tool_calls: dict[int, list[Any]] = {}
-
- for event in response:
- choice = event.choices[0]
- if choice.finish_reason:
- break
-
- if choice.delta.content:
- yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
-
- for tool_call in choice.delta.tool_calls or []:
- tool_calls.setdefault(tool_call.index, []).append(tool_call)
-
- yield {"chunk_type": "content_stop", "data_type": "text"}
-
- for tool_deltas in tool_calls.values():
- tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:]
- yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start}
-
- for tool_delta in tool_deltas:
- yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
-
- yield {"chunk_type": "content_stop", "data_type": "tool"}
-
- yield {"chunk_type": "message_stop", "data": choice.finish_reason}
-
- event = next(response)
- if hasattr(event, "usage"):
- yield {"chunk_type": "metadata", "data": event.usage}
+ return OpenAIModel.format_request_message_content(content)
diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py
new file mode 100644
index 00000000..764cb851
--- /dev/null
+++ b/src/strands/models/openai.py
@@ -0,0 +1,124 @@
+"""OpenAI model provider.
+
+- Docs: https://platform.openai.com/docs/overview
+"""
+
+import logging
+from typing import Any, Iterable, Optional, Protocol, TypedDict, cast
+
+import openai
+from typing_extensions import Unpack, override
+
+from ..types.models import OpenAIModel as SAOpenAIModel
+
+logger = logging.getLogger(__name__)
+
+
+class Client(Protocol):
+ """Protocol defining the OpenAI-compatible interface for the underlying provider client."""
+
+ @property
+ # pragma: no cover
+ def chat(self) -> Any:
+ """Chat completions interface."""
+ ...
+
+
+class OpenAIModel(SAOpenAIModel):
+ """OpenAI model provider implementation."""
+
+ client: Client
+
+ class OpenAIConfig(TypedDict, total=False):
+ """Configuration options for OpenAI models.
+
+ Attributes:
+ model_id: Model ID (e.g., "gpt-4o").
+ For a complete list of supported models, see https://platform.openai.com/docs/models.
+ params: Model parameters (e.g., max_tokens).
+ For a complete list of supported parameters, see
+ https://platform.openai.com/docs/api-reference/chat/create.
+ """
+
+ model_id: str
+ params: Optional[dict[str, Any]]
+
+ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None:
+ """Initialize provider instance.
+
+ Args:
+ client_args: Arguments for the OpenAI client.
+ For a complete list of supported arguments, see https://pypi.org/project/openai/.
+ **model_config: Configuration options for the OpenAI model.
+ """
+ self.config = dict(model_config)
+
+ logger.debug("config=<%s> | initializing", self.config)
+
+ client_args = client_args or {}
+ self.client = openai.OpenAI(**client_args)
+
+ @override
+ def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override]
+ """Update the OpenAI model configuration with the provided arguments.
+
+ Args:
+ **model_config: Configuration overrides.
+ """
+ self.config.update(model_config)
+
+ @override
+ def get_config(self) -> OpenAIConfig:
+ """Get the OpenAI model configuration.
+
+ Returns:
+ The OpenAI model configuration.
+ """
+ return cast(OpenAIModel.OpenAIConfig, self.config)
+
+ @override
+ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]:
+ """Send the request to the OpenAI model and get the streaming response.
+
+ Args:
+ request: The formatted request to send to the OpenAI model.
+
+ Returns:
+ An iterable of response events from the OpenAI model.
+ """
+ response = self.client.chat.completions.create(**request)
+
+ yield {"chunk_type": "message_start"}
+ yield {"chunk_type": "content_start", "data_type": "text"}
+
+ tool_calls: dict[int, list[Any]] = {}
+
+ for event in response:
+ choice = event.choices[0]
+
+ if choice.delta.content:
+ yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content}
+
+ for tool_call in choice.delta.tool_calls or []:
+ tool_calls.setdefault(tool_call.index, []).append(tool_call)
+
+ if choice.finish_reason:
+ break
+
+ yield {"chunk_type": "content_stop", "data_type": "text"}
+
+ for tool_deltas in tool_calls.values():
+ yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]}
+
+ for tool_delta in tool_deltas:
+ yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta}
+
+ yield {"chunk_type": "content_stop", "data_type": "tool"}
+
+ yield {"chunk_type": "message_stop", "data": choice.finish_reason}
+
+ # Skip remaining events as we don't have use for anything except the final usage payload
+ for event in response:
+ _ = event
+
+ yield {"chunk_type": "metadata", "data": event.usage}
diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py
index ad30a445..34eb7bed 100644
--- a/src/strands/telemetry/tracer.py
+++ b/src/strands/telemetry/tracer.py
@@ -7,7 +7,7 @@
import json
import logging
import os
-from datetime import datetime, timezone
+from datetime import date, datetime, timezone
from importlib.metadata import version
from typing import Any, Dict, Mapping, Optional
@@ -30,21 +30,49 @@
class JSONEncoder(json.JSONEncoder):
"""Custom JSON encoder that handles non-serializable types."""
- def default(self, obj: Any) -> Any:
- """Handle non-serializable types.
+ def encode(self, obj: Any) -> str:
+ """Recursively encode objects, preserving structure and only replacing unserializable values.
Args:
- obj: The object to serialize
+ obj: The object to encode
Returns:
- A JSON serializable version of the object
+ JSON string representation of the object
"""
- value = ""
- try:
- value = super().default(obj)
- except TypeError:
- value = ""
- return value
+ # Process the object to handle non-serializable values
+ processed_obj = self._process_value(obj)
+ # Use the parent class to encode the processed object
+ return super().encode(processed_obj)
+
+ def _process_value(self, value: Any) -> Any:
+ """Process any value, handling containers recursively.
+
+ Args:
+ value: The value to process
+
+ Returns:
+ Processed value with unserializable parts replaced
+ """
+ # Handle datetime objects directly
+ if isinstance(value, (datetime, date)):
+ return value.isoformat()
+
+ # Handle dictionaries
+ elif isinstance(value, dict):
+ return {k: self._process_value(v) for k, v in value.items()}
+
+ # Handle lists
+ elif isinstance(value, list):
+ return [self._process_value(item) for item in value]
+
+ # Handle all other values
+ else:
+ try:
+ # Test if the value is JSON serializable
+ json.dumps(value)
+ return value
+ except (TypeError, OverflowError, ValueError):
+ return ""
class Tracer:
@@ -65,7 +93,7 @@ def __init__(
service_name: str = "strands-agents",
otlp_endpoint: Optional[str] = None,
otlp_headers: Optional[Dict[str, str]] = None,
- enable_console_export: bool = False,
+ enable_console_export: Optional[bool] = None,
):
"""Initialize the tracer.
@@ -77,13 +105,17 @@ def __init__(
"""
# Check environment variables first
env_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT")
- env_console_export = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT", "").lower() in ("true", "1", "yes")
+ env_console_export_str = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT")
+
+ # Constructor parameters take precedence over environment variables
+ self.otlp_endpoint = otlp_endpoint or env_endpoint
- # Environment variables take precedence over constructor parameters
- if env_endpoint:
- otlp_endpoint = env_endpoint
- if env_console_export:
- enable_console_export = True
+ if enable_console_export is not None:
+ self.enable_console_export = enable_console_export
+ elif env_console_export_str:
+ self.enable_console_export = env_console_export_str.lower() in ("true", "1", "yes")
+ else:
+ self.enable_console_export = False
# Parse headers from environment if available
env_headers = os.environ.get("OTEL_EXPORTER_OTLP_HEADERS")
@@ -97,17 +129,14 @@ def __init__(
headers_dict[key.strip()] = value.strip()
otlp_headers = headers_dict
except Exception as e:
- logger.warning(f"error=<{e}> | failed to parse OTEL_EXPORTER_OTLP_HEADERS")
+ logger.warning("error=<%s> | failed to parse OTEL_EXPORTER_OTLP_HEADERS", e)
self.service_name = service_name
- self.otlp_endpoint = otlp_endpoint
self.otlp_headers = otlp_headers or {}
- self.enable_console_export = enable_console_export
-
self.tracer_provider: Optional[TracerProvider] = None
self.tracer: Optional[trace.Tracer] = None
- if otlp_endpoint or enable_console_export:
+ if self.otlp_endpoint or self.enable_console_export:
self._initialize_tracer()
def _initialize_tracer(self) -> None:
@@ -156,9 +185,9 @@ def _initialize_tracer(self) -> None:
batch_processor = BatchSpanProcessor(otlp_exporter)
self.tracer_provider.add_span_processor(batch_processor)
- logger.info(f"endpoint=<{endpoint}> | OTLP exporter configured with endpoint")
+ logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint)
except Exception as e:
- logger.error(f"error=<{e}> | Failed to configure OTLP exporter", exc_info=True)
+ logger.exception("error=<%s> | Failed to configure OTLP exporter", e)
# Set as global tracer provider
trace.set_tracer_provider(self.tracer_provider)
@@ -239,7 +268,7 @@ def _end_span(
else:
span.set_status(StatusCode.OK)
except Exception as e:
- logger.warning(f"error=<{e}> | error while ending span", exc_info=True)
+ logger.warning("error=<%s> | error while ending span", e, exc_info=True)
finally:
span.end()
# Force flush to ensure spans are exported
@@ -247,7 +276,7 @@ def _end_span(
try:
self.tracer_provider.force_flush()
except Exception as e:
- logger.warning(f"error=<{e}> | failed to force flush tracer provider")
+ logger.warning("error=<%s> | failed to force flush tracer provider", e)
def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None:
"""End a span with error status.
@@ -287,7 +316,7 @@ def start_model_invoke_span(
"gen_ai.system": "strands-agents",
"agent.name": agent_name,
"gen_ai.agent.name": agent_name,
- "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder),
+ "gen_ai.prompt": serialize(messages),
}
if model_id:
@@ -310,7 +339,7 @@ def end_model_invoke_span(
error: Optional exception if the model call failed.
"""
attributes: Dict[str, AttributeValue] = {
- "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder),
+ "gen_ai.completion": serialize(message["content"]),
"gen_ai.usage.prompt_tokens": usage["inputTokens"],
"gen_ai.usage.completion_tokens": usage["outputTokens"],
"gen_ai.usage.total_tokens": usage["totalTokens"],
@@ -332,9 +361,10 @@ def start_tool_call_span(
The created span, or None if tracing is not enabled.
"""
attributes: Dict[str, AttributeValue] = {
+ "gen_ai.prompt": serialize(tool),
"tool.name": tool["name"],
"tool.id": tool["toolUseId"],
- "tool.parameters": json.dumps(tool["input"], cls=JSONEncoder),
+ "tool.parameters": serialize(tool["input"]),
}
# Add additional kwargs as attributes
@@ -358,10 +388,11 @@ def end_tool_call_span(
status = tool_result.get("status")
status_str = str(status) if status is not None else ""
+ tool_result_content_json = serialize(tool_result.get("content"))
attributes.update(
{
- "tool.result": json.dumps(tool_result.get("content"), cls=JSONEncoder),
- "gen_ai.completion": json.dumps(tool_result.get("content"), cls=JSONEncoder),
+ "tool.result": tool_result_content_json,
+ "gen_ai.completion": tool_result_content_json,
"tool.status": status_str,
}
)
@@ -390,7 +421,7 @@ def start_event_loop_cycle_span(
parent_span = parent_span if parent_span else event_loop_kwargs.get("event_loop_parent_span")
attributes: Dict[str, AttributeValue] = {
- "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder),
+ "gen_ai.prompt": serialize(messages),
"event_loop.cycle_id": event_loop_cycle_id,
}
@@ -419,11 +450,11 @@ def end_event_loop_cycle_span(
error: Optional exception if the cycle failed.
"""
attributes: Dict[str, AttributeValue] = {
- "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder),
+ "gen_ai.completion": serialize(message["content"]),
}
if tool_result_message:
- attributes["tool.result"] = json.dumps(tool_result_message["content"], cls=JSONEncoder)
+ attributes["tool.result"] = serialize(tool_result_message["content"])
self._end_span(span, attributes, error)
@@ -460,7 +491,7 @@ def start_agent_span(
attributes["gen_ai.request.model"] = model_id
if tools:
- tools_json = json.dumps(tools, cls=JSONEncoder)
+ tools_json = serialize(tools)
attributes["agent.tools"] = tools_json
attributes["gen_ai.agent.tools"] = tools_json
@@ -492,7 +523,7 @@ def end_agent_span(
if response:
attributes.update(
{
- "gen_ai.completion": json.dumps(response, cls=JSONEncoder),
+ "gen_ai.completion": str(response),
}
)
@@ -517,7 +548,7 @@ def get_tracer(
service_name: str = "strands-agents",
otlp_endpoint: Optional[str] = None,
otlp_headers: Optional[Dict[str, str]] = None,
- enable_console_export: bool = False,
+ enable_console_export: Optional[bool] = None,
) -> Tracer:
"""Get or create the global tracer.
@@ -541,3 +572,15 @@ def get_tracer(
)
return _tracer_instance
+
+
+def serialize(obj: Any) -> str:
+ """Serialize an object to JSON with consistent settings.
+
+ Args:
+ obj: The object to serialize
+
+ Returns:
+ JSON string representation of the object
+ """
+ return json.dumps(obj, ensure_ascii=False, cls=JSONEncoder)
diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py
index 40565a24..b595c3d6 100644
--- a/src/strands/tools/tools.py
+++ b/src/strands/tools/tools.py
@@ -40,14 +40,14 @@ def validate_tool_use_name(tool: ToolUse) -> None:
Raises:
InvalidToolUseNameException: If the tool name is invalid.
"""
- # We need to fix some typing here, because we dont actually expect a ToolUse, but dict[str, Any]
+ # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any]
if "name" not in tool:
message = "tool name missing" # type: ignore[unreachable]
logger.warning(message)
raise InvalidToolUseNameException(message)
tool_name = tool["name"]
- tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_]*$"
+ tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_\-]*$"
tool_name_max_length = 64
valid_name_pattern = bool(re.match(tool_name_pattern, tool_name))
tool_name_len = len(tool_name)
diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py
index 6055b9ab..c15ba1be 100644
--- a/src/strands/types/guardrails.py
+++ b/src/strands/types/guardrails.py
@@ -16,7 +16,7 @@ class GuardrailConfig(TypedDict, total=False):
Attributes:
guardrailIdentifier: Unique identifier for the guardrail.
guardrailVersion: Version of the guardrail to apply.
- streamProcessingMode: Procesing mode.
+ streamProcessingMode: Processing mode.
trace: The trace behavior for the guardrail.
"""
@@ -219,7 +219,7 @@ class GuardrailAssessment(TypedDict):
contentPolicy: The content policy.
contextualGroundingPolicy: The contextual grounding policy used for the guardrail assessment.
sensitiveInformationPolicy: The sensitive information policy.
- topicPolic: The topic policy.
+ topicPolicy: The topic policy.
wordPolicy: The word policy.
"""
diff --git a/src/strands/types/media.py b/src/strands/types/media.py
index 058a09ea..29b89e5c 100644
--- a/src/strands/types/media.py
+++ b/src/strands/types/media.py
@@ -68,7 +68,7 @@ class ImageContent(TypedDict):
class VideoSource(TypedDict):
- """Contains the content of a vidoe.
+ """Contains the content of a video.
Attributes:
bytes: The binary content of the video.
diff --git a/src/strands/types/models/__init__.py b/src/strands/types/models/__init__.py
new file mode 100644
index 00000000..5ce0a498
--- /dev/null
+++ b/src/strands/types/models/__init__.py
@@ -0,0 +1,6 @@
+"""Model-related type definitions for the SDK."""
+
+from .model import Model
+from .openai import OpenAIModel
+
+__all__ = ["Model", "OpenAIModel"]
diff --git a/src/strands/types/models.py b/src/strands/types/models/model.py
similarity index 97%
rename from src/strands/types/models.py
rename to src/strands/types/models/model.py
index e3d96e29..23e74602 100644
--- a/src/strands/types/models.py
+++ b/src/strands/types/models/model.py
@@ -4,9 +4,9 @@
import logging
from typing import Any, Iterable, Optional
-from .content import Messages
-from .streaming import StreamEvent
-from .tools import ToolSpec
+from ..content import Messages
+from ..streaming import StreamEvent
+from ..tools import ToolSpec
logger = logging.getLogger(__name__)
diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py
new file mode 100644
index 00000000..307c0be6
--- /dev/null
+++ b/src/strands/types/models/openai.py
@@ -0,0 +1,254 @@
+"""Base OpenAI model provider.
+
+This module provides the base OpenAI model provider class which implements shared logic for formatting requests and
+responses to and from the OpenAI specification.
+
+- Docs: https://pypi.org/project/openai
+"""
+
+import abc
+import base64
+import json
+import logging
+import mimetypes
+from typing import Any, Optional
+
+from typing_extensions import override
+
+from ..content import ContentBlock, Messages
+from ..streaming import StreamEvent
+from ..tools import ToolResult, ToolSpec, ToolUse
+from .model import Model
+
+logger = logging.getLogger(__name__)
+
+
+class OpenAIModel(Model, abc.ABC):
+ """Base OpenAI model provider implementation.
+
+ Implements shared logic for formatting requests and responses to and from the OpenAI specification.
+ """
+
+ config: dict[str, Any]
+
+ @staticmethod
+ def format_request_message_content(content: ContentBlock) -> dict[str, Any]:
+ """Format an OpenAI compatible content block.
+
+ Args:
+ content: Message content.
+
+ Returns:
+ OpenAI compatible content block.
+ """
+ if "document" in content:
+ mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream")
+ file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8")
+ return {
+ "file": {
+ "file_data": f"data:{mime_type};base64,{file_data}",
+ "filename": content["document"]["name"],
+ },
+ "type": "file",
+ }
+
+ if "image" in content:
+ mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream")
+ image_data = content["image"]["source"]["bytes"].decode("utf-8")
+ return {
+ "image_url": {
+ "detail": "auto",
+ "format": mime_type,
+ "url": f"data:{mime_type};base64,{image_data}",
+ },
+ "type": "image_url",
+ }
+
+ if "text" in content:
+ return {"text": content["text"], "type": "text"}
+
+ return {"text": json.dumps(content), "type": "text"}
+
+ @staticmethod
+ def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]:
+ """Format an OpenAI compatible tool call.
+
+ Args:
+ tool_use: Tool use requested by the model.
+
+ Returns:
+ OpenAI compatible tool call.
+ """
+ return {
+ "function": {
+ "arguments": json.dumps(tool_use["input"]),
+ "name": tool_use["name"],
+ },
+ "id": tool_use["toolUseId"],
+ "type": "function",
+ }
+
+ @staticmethod
+ def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]:
+ """Format an OpenAI compatible tool message.
+
+ Args:
+ tool_result: Tool result collected from a tool execution.
+
+ Returns:
+ OpenAI compatible tool message.
+ """
+ return {
+ "role": "tool",
+ "tool_call_id": tool_result["toolUseId"],
+ "content": json.dumps(
+ {
+ "content": tool_result["content"],
+ "status": tool_result["status"],
+ }
+ ),
+ }
+
+ @classmethod
+ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]:
+ """Format an OpenAI compatible messages array.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ system_prompt: System prompt to provide context to the model.
+
+ Returns:
+ An OpenAI compatible messages array.
+ """
+ formatted_messages: list[dict[str, Any]]
+ formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else []
+
+ for message in messages:
+ contents = message["content"]
+
+ formatted_contents = [
+ cls.format_request_message_content(content)
+ for content in contents
+ if not any(block_type in content for block_type in ["toolResult", "toolUse"])
+ ]
+ formatted_tool_calls = [
+ cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content
+ ]
+ formatted_tool_messages = [
+ cls.format_request_tool_message(content["toolResult"])
+ for content in contents
+ if "toolResult" in content
+ ]
+
+ formatted_message = {
+ "role": message["role"],
+ "content": formatted_contents,
+ **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}),
+ }
+ formatted_messages.append(formatted_message)
+ formatted_messages.extend(formatted_tool_messages)
+
+ return [message for message in formatted_messages if message["content"] or "tool_calls" in message]
+
+ @override
+ def format_request(
+ self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None
+ ) -> dict[str, Any]:
+ """Format an OpenAI compatible chat streaming request.
+
+ Args:
+ messages: List of message objects to be processed by the model.
+ tool_specs: List of tool specifications to make available to the model.
+ system_prompt: System prompt to provide context to the model.
+
+ Returns:
+ An OpenAI compatible chat streaming request.
+ """
+ return {
+ "messages": self.format_request_messages(messages, system_prompt),
+ "model": self.config["model_id"],
+ "stream": True,
+ "stream_options": {"include_usage": True},
+ "tools": [
+ {
+ "type": "function",
+ "function": {
+ "name": tool_spec["name"],
+ "description": tool_spec["description"],
+ "parameters": tool_spec["inputSchema"]["json"],
+ },
+ }
+ for tool_spec in tool_specs or []
+ ],
+ **(self.config.get("params") or {}),
+ }
+
+ @override
+ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
+ """Format an OpenAI response event into a standardized message chunk.
+
+ Args:
+ event: A response event from the OpenAI compatible model.
+
+ Returns:
+ The formatted chunk.
+
+ Raises:
+ RuntimeError: If chunk_type is not recognized.
+ This error should never be encountered as chunk_type is controlled in the stream method.
+ """
+ match event["chunk_type"]:
+ case "message_start":
+ return {"messageStart": {"role": "assistant"}}
+
+ case "content_start":
+ if event["data_type"] == "tool":
+ return {
+ "contentBlockStart": {
+ "start": {
+ "toolUse": {
+ "name": event["data"].function.name,
+ "toolUseId": event["data"].id,
+ }
+ }
+ }
+ }
+
+ return {"contentBlockStart": {"start": {}}}
+
+ case "content_delta":
+ if event["data_type"] == "tool":
+ return {
+ "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}}
+ }
+
+ return {"contentBlockDelta": {"delta": {"text": event["data"]}}}
+
+ case "content_stop":
+ return {"contentBlockStop": {}}
+
+ case "message_stop":
+ match event["data"]:
+ case "tool_calls":
+ return {"messageStop": {"stopReason": "tool_use"}}
+ case "length":
+ return {"messageStop": {"stopReason": "max_tokens"}}
+ case _:
+ return {"messageStop": {"stopReason": "end_turn"}}
+
+ case "metadata":
+ return {
+ "metadata": {
+ "usage": {
+ "inputTokens": event["data"].prompt_tokens,
+ "outputTokens": event["data"].completion_tokens,
+ "totalTokens": event["data"].total_tokens,
+ },
+ "metrics": {
+ "latencyMs": 0, # TODO
+ },
+ },
+ }
+
+ case _:
+ raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type")
diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py
index db600f15..9c99b210 100644
--- a/src/strands/types/streaming.py
+++ b/src/strands/types/streaming.py
@@ -157,7 +157,7 @@ class ModelStreamErrorEvent(ExceptionEvent):
originalStatusCode: int
-class RedactContentEvent(TypedDict):
+class RedactContentEvent(TypedDict, total=False):
"""Event for redacting content.
Attributes:
diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py
new file mode 100644
index 00000000..a6a29aa9
--- /dev/null
+++ b/tests-integ/test_model_bedrock.py
@@ -0,0 +1,120 @@
+import pytest
+
+import strands
+from strands import Agent
+from strands.models import BedrockModel
+
+
+@pytest.fixture
+def system_prompt():
+ return "You are an AI assistant that uses & instead of ."
+
+
+@pytest.fixture
+def streaming_model():
+ return BedrockModel(
+ model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+ streaming=True,
+ )
+
+
+@pytest.fixture
+def non_streaming_model():
+ return BedrockModel(
+ model_id="us.meta.llama3-2-90b-instruct-v1:0",
+ streaming=False,
+ )
+
+
+@pytest.fixture
+def streaming_agent(streaming_model, system_prompt):
+ return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)
+
+
+@pytest.fixture
+def non_streaming_agent(non_streaming_model, system_prompt):
+ return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False)
+
+
+def test_streaming_agent(streaming_agent):
+ """Test agent with streaming model."""
+ result = streaming_agent("Hello!")
+
+ assert len(str(result)) > 0
+
+
+def test_non_streaming_agent(non_streaming_agent):
+ """Test agent with non-streaming model."""
+ result = non_streaming_agent("Hello!")
+
+ assert len(str(result)) > 0
+
+
+def test_streaming_model_events(streaming_model):
+ """Test streaming model events."""
+ messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+
+ # Call converse and collect events
+ events = list(streaming_model.converse(messages))
+
+ # Verify basic structure of events
+ assert any("messageStart" in event for event in events)
+ assert any("contentBlockDelta" in event for event in events)
+ assert any("messageStop" in event for event in events)
+
+
+def test_non_streaming_model_events(non_streaming_model):
+ """Test non-streaming model events."""
+ messages = [{"role": "user", "content": [{"text": "Hello"}]}]
+
+ # Call converse and collect events
+ events = list(non_streaming_model.converse(messages))
+
+ # Verify basic structure of events
+ assert any("messageStart" in event for event in events)
+ assert any("contentBlockDelta" in event for event in events)
+ assert any("messageStop" in event for event in events)
+
+
+def test_tool_use_streaming(streaming_model):
+ """Test tool use with streaming model."""
+
+ tool_was_called = False
+
+ @strands.tool
+ def calculator(expression: str) -> float:
+ """Calculate the result of a mathematical expression."""
+
+ nonlocal tool_was_called
+ tool_was_called = True
+ return eval(expression)
+
+ agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False)
+ result = agent("What is 123 + 456?")
+
+ # Print the full message content for debugging
+ print("\nFull message content:")
+ import json
+
+ print(json.dumps(result.message["content"], indent=2))
+
+ assert tool_was_called
+
+
+def test_tool_use_non_streaming(non_streaming_model):
+ """Test tool use with non-streaming model."""
+
+ tool_was_called = False
+
+ @strands.tool
+ def calculator(expression: str) -> float:
+ """Calculate the result of a mathematical expression."""
+
+ nonlocal tool_was_called
+ tool_was_called = True
+ return eval(expression)
+
+ agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False)
+ agent("What is 123 + 456?")
+
+ assert tool_was_called
diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py
new file mode 100644
index 00000000..c9046ad5
--- /dev/null
+++ b/tests-integ/test_model_openai.py
@@ -0,0 +1,46 @@
+import os
+
+import pytest
+
+import strands
+from strands import Agent
+from strands.models.openai import OpenAIModel
+
+
+@pytest.fixture
+def model():
+ return OpenAIModel(
+ model_id="gpt-4o",
+ client_args={
+ "api_key": os.getenv("OPENAI_API_KEY"),
+ },
+ )
+
+
+@pytest.fixture
+def tools():
+ @strands.tool
+ def tool_time() -> str:
+ return "12:00"
+
+ @strands.tool
+ def tool_weather() -> str:
+ return "sunny"
+
+ return [tool_time, tool_weather]
+
+
+@pytest.fixture
+def agent(model, tools):
+ return Agent(model=model, tools=tools)
+
+
+@pytest.mark.skipif(
+ "OPENAI_API_KEY" not in os.environ,
+ reason="OPENAI_API_KEY environment variable missing",
+)
+def test_agent(agent):
+ result = agent("What is the time and weather in New York?")
+ text = result.message["content"][0]["text"].lower()
+
+ assert all(string in text for string in ["12:00", "sunny"])
diff --git a/tests-integ/test_stream_agent.py b/tests-integ/test_stream_agent.py
index 4c97db6b..01f20339 100644
--- a/tests-integ/test_stream_agent.py
+++ b/tests-integ/test_stream_agent.py
@@ -1,5 +1,5 @@
"""
-Test script for Strands's custom callback handler functionality.
+Test script for Strands' custom callback handler functionality.
Demonstrates different patterns of callback handling and processing.
"""
diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py
index 5d2ffb23..ea06fb4e 100644
--- a/tests/strands/agent/test_agent.py
+++ b/tests/strands/agent/test_agent.py
@@ -9,7 +9,8 @@
import pytest
import strands
-from strands.agent.agent import Agent
+from strands import Agent
+from strands.agent import AgentResult
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler
@@ -337,17 +338,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler,
],
]
+ override_system_prompt = "Override system prompt"
+ override_model = unittest.mock.Mock()
+ override_tool_execution_handler = unittest.mock.Mock()
+ override_event_loop_metrics = unittest.mock.Mock()
+ override_callback_handler = unittest.mock.Mock()
+ override_tool_handler = unittest.mock.Mock()
+ override_messages = [{"role": "user", "content": [{"text": "override msg"}]}]
+ override_tool_config = {"test": "config"}
+
def check_kwargs(some_value, **kwargs):
assert some_value == "a_value"
assert kwargs is not None
+ assert kwargs["system_prompt"] == override_system_prompt
+ assert kwargs["model"] == override_model
+ assert kwargs["tool_execution_handler"] == override_tool_execution_handler
+ assert kwargs["event_loop_metrics"] == override_event_loop_metrics
+ assert kwargs["callback_handler"] == override_callback_handler
+ assert kwargs["tool_handler"] == override_tool_handler
+ assert kwargs["messages"] == override_messages
+ assert kwargs["tool_config"] == override_tool_config
+ assert kwargs["agent"] == agent
# Return expected values from event_loop_cycle
return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {}
mock_event_loop_cycle.side_effect = check_kwargs
- agent("test message", some_value="a_value")
- assert mock_event_loop_cycle.call_count == 1
+ agent(
+ "test message",
+ some_value="a_value",
+ system_prompt=override_system_prompt,
+ model=override_model,
+ tool_execution_handler=override_tool_execution_handler,
+ event_loop_metrics=override_event_loop_metrics,
+ callback_handler=override_callback_handler,
+ tool_handler=override_tool_handler,
+ messages=override_messages,
+ tool_config=override_tool_config,
+ )
+
+ mock_event_loop_cycle.assert_called_once()
def test_agent__call__retry_with_reduced_context(mock_model, agent, tool):
@@ -428,7 +459,7 @@ def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite
with pytest.raises(ContextWindowOverflowException):
agent("Test!")
- assert conversation_manager_spy.reduce_context.call_count == 251
+ assert conversation_manager_spy.reduce_context.call_count > 0
assert conversation_manager_spy.apply_management.call_count == 1
@@ -566,7 +597,7 @@ def test_agent_tool_user_message_override(agent):
},
{
"text": (
- "agent.tool_decorated direct tool call\n"
+ "agent.tool.tool_decorated direct tool call.\n"
"Input parameters: "
'{"random_string": "abcdEfghI123", "user_message_override": "test override"}\n'
),
@@ -657,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler():
@pytest.mark.asyncio
async def test_stream_async_returns_all_events(mock_event_loop_cycle):
- mock_event_loop_cycle.side_effect = ValueError("Test exception")
-
agent = Agent()
# Define the side effect to simulate callback handler being called multiple times
@@ -922,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result)
+@pytest.mark.asyncio
+@unittest.mock.patch("strands.agent.agent.get_tracer")
+async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle):
+ """Test that stream_async creates and ends a span when the call succeeds."""
+ # Setup mock tracer and span
+ mock_tracer = unittest.mock.MagicMock()
+ mock_span = unittest.mock.MagicMock()
+ mock_tracer.start_agent_span.return_value = mock_span
+ mock_get_tracer.return_value = mock_tracer
+
+ # Define the side effect to simulate callback handler being called multiple times
+ def call_callback_handler(*args, **kwargs):
+ # Extract the callback handler from kwargs
+ callback_handler = kwargs.get("callback_handler")
+ # Call the callback handler with different data values
+ callback_handler(data="First chunk")
+ callback_handler(data="Second chunk")
+ callback_handler(data="Final chunk", complete=True)
+ # Return expected values from event_loop_cycle
+ return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {}
+
+ mock_event_loop_cycle.side_effect = call_callback_handler
+
+ # Create agent and make a call
+ agent = Agent(model=mock_model)
+ iterator = agent.stream_async("test prompt")
+ async for _event in iterator:
+ pass # NoOp
+
+ # Verify span was created
+ mock_tracer.start_agent_span.assert_called_once_with(
+ prompt="test prompt",
+ model_id=unittest.mock.ANY,
+ tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
+ )
+
+ expected_response = AgentResult(
+ stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={}
+ )
+
+ # Verify span was ended with the result
+ mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response)
+
+
@unittest.mock.patch("strands.agent.agent.get_tracer")
def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
"""Test that __call__ creates and ends a span when an exception occurs."""
@@ -955,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod
mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
+@pytest.mark.asyncio
+@unittest.mock.patch("strands.agent.agent.get_tracer")
+async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model):
+ """Test that stream_async creates and ends a span when the call succeeds."""
+ # Setup mock tracer and span
+ mock_tracer = unittest.mock.MagicMock()
+ mock_span = unittest.mock.MagicMock()
+ mock_tracer.start_agent_span.return_value = mock_span
+ mock_get_tracer.return_value = mock_tracer
+
+ # Define the side effect to simulate callback handler raising an Exception
+ test_exception = ValueError("Test exception")
+ mock_model.mock_converse.side_effect = test_exception
+
+ # Create agent and make a call
+ agent = Agent(model=mock_model)
+
+ # Call the agent and catch the exception
+ with pytest.raises(ValueError):
+ iterator = agent.stream_async("test prompt")
+ async for _event in iterator:
+ pass # NoOp
+
+ # Verify span was created
+ mock_tracer.start_agent_span.assert_called_once_with(
+ prompt="test prompt",
+ model_id=unittest.mock.ANY,
+ tools=agent.tool_names,
+ system_prompt=agent.system_prompt,
+ custom_trace_attributes=agent.trace_attributes,
+ )
+
+ # Verify span was ended with the exception
+ mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception)
+
+
@unittest.mock.patch("strands.agent.agent.get_tracer")
def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model):
"""Test that event_loop_cycle is called with the parent span."""
diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py
index 2f6ee77d..b6132f1d 100644
--- a/tests/strands/agent/test_conversation_manager.py
+++ b/tests/strands/agent/test_conversation_manager.py
@@ -111,41 +111,7 @@ def conversation_manager(request):
{"role": "assistant", "content": [{"text": "Second response"}]},
],
),
- # 7 - Message count above max window size - Remove dangling tool uses and tool results
- (
- {"window_size": 1},
- [
- {"role": "user", "content": [{"text": "First message"}]},
- {"role": "assistant", "content": [{"toolUse": {"toolUseId": "321", "name": "tool1", "input": {}}}]},
- {
- "role": "user",
- "content": [
- {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}}
- ],
- },
- ],
- [
- {
- "role": "user",
- "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}],
- },
- ],
- ),
- # 8 - Message count above max window size - Remove multiple tool use/tool result pairs
- (
- {"window_size": 1},
- [
- {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]},
- {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]},
- {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
- {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]},
- {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]},
- ],
- [
- {"role": "user", "content": [{"text": "Tool Result Status: success"}]},
- ],
- ),
- # 9 - Message count above max window size - Preserve tool use/tool result pairs
+ # 7 - Message count above max window size - Preserve tool use/tool result pairs
(
{"window_size": 2},
[
@@ -158,7 +124,7 @@ def conversation_manager(request):
{"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
],
),
- # 10 - Test sliding window behavior - preserve tool use/result pairs across cut boundary
+ # 8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary
(
{"window_size": 3},
[
@@ -173,7 +139,7 @@ def conversation_manager(request):
{"role": "assistant", "content": [{"text": "Response after tool use"}]},
],
),
- # 11 - Test sliding window with multiple tool pairs that need preservation
+ # 9 - Test sliding window with multiple tool pairs that need preservation
(
{"window_size": 4},
[
@@ -185,7 +151,6 @@ def conversation_manager(request):
{"role": "assistant", "content": [{"text": "Final response"}]},
],
[
- {"role": "user", "content": [{"text": "Tool Result Status: success"}]},
{"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]},
{"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]},
{"role": "assistant", "content": [{"text": "Final response"}]},
@@ -200,6 +165,20 @@ def test_apply_management(conversation_manager, messages, expected_messages):
assert messages == expected_messages
+def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception():
+ manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1)
+ messages = [
+ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]},
+ ]
+ original_messages = messages.copy()
+
+ with pytest.raises(ContextWindowOverflowException):
+ manager.apply_management(messages)
+
+ assert messages == original_messages
+
+
def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception():
"""Test that NullConversationManager doesn't modify messages."""
manager = strands.agent.conversation_manager.NullConversationManager()
diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py
new file mode 100644
index 00000000..395c71a1
--- /dev/null
+++ b/tests/strands/event_loop/test_message_processor.py
@@ -0,0 +1,128 @@
+import copy
+
+import pytest
+
+from strands.event_loop import message_processor
+
+
+@pytest.mark.parametrize(
+ "messages,expected,expected_messages",
+ [
+ # Orphaned toolUse with empty input, no toolResult
+ (
+ [
+ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {}, "name": "foo"}}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]},
+ ],
+ True,
+ [
+ {"role": "assistant", "content": [{"text": "[Attempted to use foo, but operation was canceled]"}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]},
+ ],
+ ),
+ # toolUse with input, has matching toolResult
+ (
+ [
+ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
+ ],
+ False,
+ [
+ {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
+ ],
+ ),
+ # No messages
+ (
+ [],
+ False,
+ [],
+ ),
+ ],
+)
+def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages):
+ test_messages = copy.deepcopy(messages)
+ result = message_processor.clean_orphaned_empty_tool_uses(test_messages)
+ assert result == expected
+ assert test_messages == expected_messages
+
+
+@pytest.mark.parametrize(
+ "messages,expected_idx",
+ [
+ (
+ [
+ {"role": "user", "content": [{"text": "hi"}]},
+ {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
+ {"role": "assistant", "content": [{"text": "ok"}]},
+ ],
+ 1,
+ ),
+ (
+ [
+ {"role": "user", "content": [{"text": "hi"}]},
+ {"role": "assistant", "content": [{"text": "ok"}]},
+ ],
+ None,
+ ),
+ (
+ [],
+ None,
+ ),
+ ],
+)
+def test_find_last_message_with_tool_results(messages, expected_idx):
+ idx = message_processor.find_last_message_with_tool_results(messages)
+ assert idx == expected_idx
+
+
+@pytest.mark.parametrize(
+ "messages,msg_idx,expected_changed,expected_content",
+ [
+ (
+ [
+ {
+ "role": "user",
+ "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}],
+ }
+ ],
+ 0,
+ True,
+ [
+ {
+ "toolResult": {
+ "toolUseId": "1",
+ "status": "error",
+ "content": [{"text": "The tool result was too large!"}],
+ }
+ }
+ ],
+ ),
+ (
+ [{"role": "user", "content": [{"text": "no tool result"}]}],
+ 0,
+ False,
+ [{"text": "no tool result"}],
+ ),
+ (
+ [],
+ 0,
+ False,
+ [],
+ ),
+ (
+ [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}],
+ 2,
+ False,
+ [{"toolResult": {"toolUseId": "1"}}],
+ ),
+ ],
+)
+def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content):
+ test_messages = copy.deepcopy(messages)
+ changed = message_processor.truncate_tool_results(test_messages, msg_idx)
+ assert changed == expected_changed
+ if 0 <= msg_idx < len(test_messages):
+ assert test_messages[msg_idx]["content"] == expected_content
+ else:
+ assert test_messages == messages
diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py
index 20e238cb..6fb2af07 100644
--- a/tests/strands/handlers/test_callback_handler.py
+++ b/tests/strands/handlers/test_callback_handler.py
@@ -30,6 +30,31 @@ def test_call_with_empty_args(handler, mock_print):
mock_print.assert_not_called()
+def test_call_handler_reasoningText(handler, mock_print):
+ """Test calling the handler with reasoningText."""
+ handler(reasoningText="This is reasoning text")
+ # Should print reasoning text without newline
+ mock_print.assert_called_once_with("This is reasoning text", end="")
+
+
+def test_call_without_reasoningText(handler, mock_print):
+ """Test calling the handler without reasoningText argument."""
+ handler(data="Some output")
+ # Should only print data, not reasoningText
+ mock_print.assert_called_once_with("Some output", end="")
+
+
+def test_call_with_reasoningText_and_data(handler, mock_print):
+ """Test calling the handler with both reasoningText and data."""
+ handler(reasoningText="Reasoning", data="Output")
+ # Should print reasoningText and data, both without newline
+ calls = [
+ unittest.mock.call("Reasoning", end=""),
+ unittest.mock.call("Output", end=""),
+ ]
+ mock_print.assert_has_calls(calls)
+
+
def test_call_with_data_incomplete(handler, mock_print):
"""Test calling the handler with data but not complete."""
handler(data="Test output")
diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py
index 48a1da37..2ee344cc 100644
--- a/tests/strands/models/test_anthropic.py
+++ b/tests/strands/models/test_anthropic.py
@@ -102,19 +102,46 @@ def test_format_request_with_system_prompt(model, messages, model_id, max_tokens
assert tru_request == exp_request
-def test_format_request_with_document(model, model_id, max_tokens):
+@pytest.mark.parametrize(
+ ("content", "formatted_content"),
+ [
+ # PDF
+ (
+ {
+ "document": {"format": "pdf", "name": "test doc", "source": {"bytes": b"pdf"}},
+ },
+ {
+ "source": {
+ "data": "cGRm",
+ "media_type": "application/pdf",
+ "type": "base64",
+ },
+ "title": "test doc",
+ "type": "document",
+ },
+ ),
+ # Plain text
+ (
+ {
+ "document": {"format": "txt", "name": "test doc", "source": {"bytes": b"txt"}},
+ },
+ {
+ "source": {
+ "data": "txt",
+ "media_type": "text/plain",
+ "type": "text",
+ },
+ "title": "test doc",
+ "type": "document",
+ },
+ ),
+ ],
+)
+def test_format_request_with_document(content, formatted_content, model, model_id, max_tokens):
messages = [
{
"role": "user",
- "content": [
- {
- "document": {
- "format": "pdf",
- "name": "test-doc",
- "source": {"bytes": b"base64encodeddoc"},
- },
- },
- ],
+ "content": [content],
},
]
@@ -124,17 +151,7 @@ def test_format_request_with_document(model, model_id, max_tokens):
"messages": [
{
"role": "user",
- "content": [
- {
- "source": {
- "data": "YmFzZTY0ZW5jb2RlZGRvYw==",
- "media_type": "application/pdf",
- "type": "base64",
- },
- "title": "test-doc",
- "type": "document",
- },
- ],
+ "content": [formatted_content],
},
],
"model": model_id,
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 566671ce..b326eee7 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -1,7 +1,9 @@
+import os
import unittest.mock
import boto3
import pytest
+from botocore.config import Config as BotocoreConfig
from botocore.exceptions import ClientError, EventStreamError
import strands
@@ -99,12 +101,70 @@ def test__init__with_custom_region(bedrock_client):
mock_session_cls.assert_called_once_with(region_name=custom_region)
+def test__init__with_environment_variable_region(bedrock_client):
+ """Test that BedrockModel uses the provided region."""
+ _ = bedrock_client
+ os.environ["AWS_REGION"] = "eu-west-1"
+
+ with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
+ _ = BedrockModel()
+ mock_session_cls.assert_called_once_with(region_name="eu-west-1")
+
+
def test__init__with_region_and_session_raises_value_error():
"""Test that BedrockModel raises ValueError when both region and session are provided."""
with pytest.raises(ValueError):
_ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1"))
+def test__init__default_user_agent(bedrock_client):
+ """Set user agent when no boto_client_config is provided."""
+ with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
+ mock_session = mock_session_cls.return_value
+ _ = BedrockModel()
+
+ # Verify the client was created with the correct config
+ mock_session.client.assert_called_once()
+ args, kwargs = mock_session.client.call_args
+ assert kwargs["service_name"] == "bedrock-runtime"
+ assert isinstance(kwargs["config"], BotocoreConfig)
+ assert kwargs["config"].user_agent_extra == "strands-agents"
+
+
+def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client):
+ """Set user agent when boto_client_config is provided without user_agent_extra."""
+ custom_config = BotocoreConfig(read_timeout=900)
+
+ with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
+ mock_session = mock_session_cls.return_value
+ _ = BedrockModel(boto_client_config=custom_config)
+
+ # Verify the client was created with the correct config
+ mock_session.client.assert_called_once()
+ args, kwargs = mock_session.client.call_args
+ assert kwargs["service_name"] == "bedrock-runtime"
+ assert isinstance(kwargs["config"], BotocoreConfig)
+ assert kwargs["config"].user_agent_extra == "strands-agents"
+ assert kwargs["config"].read_timeout == 900
+
+
+def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client):
+ """Append to existing user agent when boto_client_config is provided with user_agent_extra."""
+ custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900)
+
+ with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls:
+ mock_session = mock_session_cls.return_value
+ _ = BedrockModel(boto_client_config=custom_config)
+
+ # Verify the client was created with the correct config
+ mock_session.client.assert_called_once()
+ args, kwargs = mock_session.client.call_args
+ assert kwargs["service_name"] == "bedrock-runtime"
+ assert isinstance(kwargs["config"], BotocoreConfig)
+ assert kwargs["config"].user_agent_extra == "existing-agent strands-agents"
+ assert kwargs["config"].read_timeout == 900
+
+
def test__init__model_config(bedrock_client):
_ = bedrock_client
@@ -294,9 +354,9 @@ def test_stream(bedrock_client, model):
def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model):
- error_message = "ThrottlingException - Rate exceeded"
+ error_message = "Rate exceeded"
bedrock_client.converse_stream.side_effect = EventStreamError(
- {"Error": {"Message": error_message}}, "ConverseStream"
+ {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream"
)
request = {"a": 1}
@@ -361,7 +421,9 @@ def test_converse(bedrock_client, model, messages, tool_spec, model_id, addition
bedrock_client.converse_stream.assert_called_once_with(**request)
-def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields):
+def test_converse_stream_input_guardrails(
+ bedrock_client, model, messages, tool_spec, model_id, additional_request_fields
+):
metadata_event = {
"metadata": {
"usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0},
@@ -370,7 +432,15 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m
"guardrail": {
"inputAssessment": {
"3e59qlue4hag": {
- "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}
+ "wordPolicy": {
+ "customWords": [
+ {
+ "match": "CACTUS",
+ "action": "BLOCKED",
+ "detected": True,
+ }
+ ]
+ }
}
}
}
@@ -395,13 +465,18 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m
chunks = model.converse(messages, [tool_spec])
tru_chunks = list(chunks)
- exp_chunks = [{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, metadata_event]
+ exp_chunks = [
+ {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
+ metadata_event,
+ ]
assert tru_chunks == exp_chunks
bedrock_client.converse_stream.assert_called_once_with(**request)
-def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields):
+def test_converse_stream_output_guardrails(
+ bedrock_client, model, messages, tool_spec, model_id, additional_request_fields
+):
model.update_config(guardrail_redact_input=False, guardrail_redact_output=True)
metadata_event = {
"metadata": {
@@ -413,7 +488,13 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec,
"3e59qlue4hag": [
{
"wordPolicy": {
- "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]
+ "customWords": [
+ {
+ "match": "CACTUS",
+ "action": "BLOCKED",
+ "detected": True,
+ }
+ ]
},
}
]
@@ -440,7 +521,10 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec,
chunks = model.converse(messages, [tool_spec])
tru_chunks = list(chunks)
- exp_chunks = [{"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, metadata_event]
+ exp_chunks = [
+ {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}},
+ metadata_event,
+ ]
assert tru_chunks == exp_chunks
bedrock_client.converse_stream.assert_called_once_with(**request)
@@ -460,7 +544,13 @@ def test_converse_output_guardrails_redacts_input_and_output(
"3e59qlue4hag": [
{
"wordPolicy": {
- "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]
+ "customWords": [
+ {
+ "match": "CACTUS",
+ "action": "BLOCKED",
+ "detected": True,
+ }
+ ]
},
}
]
@@ -510,7 +600,13 @@ def test_converse_output_no_blocked_guardrails_doesnt_redact(
"3e59qlue4hag": [
{
"wordPolicy": {
- "customWords": [{"match": "CACTUS", "action": "NONE", "detected": True}]
+ "customWords": [
+ {
+ "match": "CACTUS",
+ "action": "NONE",
+ "detected": True,
+ }
+ ]
},
}
]
@@ -556,7 +652,13 @@ def test_converse_output_no_guardrail_redact(
"3e59qlue4hag": [
{
"wordPolicy": {
- "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]
+ "customWords": [
+ {
+ "match": "CACTUS",
+ "action": "BLOCKED",
+ "detected": True,
+ }
+ ]
},
}
]
@@ -580,7 +682,9 @@ def test_converse_output_no_guardrail_redact(
}
model.update_config(
- additional_request_fields=additional_request_fields, guardrail_redact_output=False, guardrail_redact_input=False
+ additional_request_fields=additional_request_fields,
+ guardrail_redact_output=False,
+ guardrail_redact_input=False,
)
chunks = model.converse(messages, [tool_spec])
@@ -589,3 +693,322 @@ def test_converse_output_no_guardrail_redact(
assert tru_chunks == exp_chunks
bedrock_client.converse_stream.assert_called_once_with(**request)
+
+
+def test_stream_with_streaming_false(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
+ "stopReason": "end_turn",
+ }
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"text": "test"}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}},
+ ]
+
+ # Create model and call stream
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_stream_with_streaming_false_and_tool_use(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}],
+ }
+ },
+ "stopReason": "tool_use",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}},
+ {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}},
+ ]
+
+ # Create model and call stream
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_stream_with_streaming_false_and_reasoning(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [
+ {
+ "reasoningContent": {
+ "reasoningText": {"text": "Thinking really hard....", "signature": "123"},
+ }
+ }
+ ],
+ }
+ },
+ "stopReason": "tool_use",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}},
+ {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}},
+ ]
+
+ # Create model and call stream
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ # Verify converse was called
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_converse_and_reasoning_no_signature(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {
+ "message": {
+ "role": "assistant",
+ "content": [
+ {
+ "reasoningContent": {
+ "reasoningText": {"text": "Thinking really hard...."},
+ }
+ }
+ ],
+ }
+ },
+ "stopReason": "tool_use",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}},
+ ]
+
+ # Create model and call stream
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
+ "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468},
+ "metrics": {"latencyMs": 1234},
+ "stopReason": "tool_use",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"text": "test"}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}},
+ {
+ "metadata": {
+ "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468},
+ "metrics": {"latencyMs": 1234},
+ }
+ },
+ ]
+
+ # Create model and call stream
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ # Verify converse was called
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_converse_input_guardrails(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
+ "trace": {
+ "guardrail": {
+ "inputAssessment": {
+ "3e59qlue4hag": {
+ "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}
+ }
+ }
+ }
+ },
+ "stopReason": "end_turn",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"text": "test"}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}},
+ {
+ "metadata": {
+ "trace": {
+ "guardrail": {
+ "inputAssessment": {
+ "3e59qlue4hag": {
+ "wordPolicy": {
+ "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]
+ }
+ }
+ }
+ }
+ }
+ }
+ },
+ {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
+ ]
+
+ # Create model and call stream
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_converse_output_guardrails(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
+ "trace": {
+ "guardrail": {
+ "outputAssessments": {
+ "3e59qlue4hag": [
+ {
+ "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]},
+ }
+ ]
+ },
+ }
+ },
+ "stopReason": "end_turn",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"text": "test"}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}},
+ {
+ "metadata": {
+ "trace": {
+ "guardrail": {
+ "outputAssessments": {
+ "3e59qlue4hag": [
+ {
+ "wordPolicy": {
+ "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]
+ }
+ }
+ ]
+ }
+ }
+ }
+ }
+ },
+ {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
+ ]
+
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
+
+
+def test_converse_output_guardrails_redacts_output(bedrock_client):
+ """Test stream method with streaming=False."""
+ bedrock_client.converse.return_value = {
+ "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}},
+ "trace": {
+ "guardrail": {
+ "outputAssessments": {
+ "3e59qlue4hag": [
+ {
+ "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]},
+ }
+ ]
+ },
+ }
+ },
+ "stopReason": "end_turn",
+ }
+
+ expected_events = [
+ {"messageStart": {"role": "assistant"}},
+ {"contentBlockDelta": {"delta": {"text": "test"}}},
+ {"contentBlockStop": {}},
+ {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}},
+ {
+ "metadata": {
+ "trace": {
+ "guardrail": {
+ "outputAssessments": {
+ "3e59qlue4hag": [
+ {
+ "wordPolicy": {
+ "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]
+ }
+ }
+ ]
+ }
+ }
+ }
+ }
+ },
+ {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}},
+ ]
+
+ model = BedrockModel(model_id="test-model", streaming=False)
+ request = {"modelId": "test-model"}
+ events = list(model.stream(request))
+
+ assert expected_events == events
+
+ bedrock_client.converse.assert_called_once()
+ bedrock_client.converse_stream.assert_not_called()
diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py
index 5d4d9b40..528d1498 100644
--- a/tests/strands/models/test_litellm.py
+++ b/tests/strands/models/test_litellm.py
@@ -1,4 +1,3 @@
-import json
import unittest.mock
import pytest
@@ -8,9 +7,14 @@
@pytest.fixture
-def litellm_client():
+def litellm_client_cls():
with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls:
- yield mock_client_cls.return_value
+ yield mock_client_cls
+
+
+@pytest.fixture
+def litellm_client(litellm_client_cls):
+ return litellm_client_cls.return_value
@pytest.fixture
@@ -35,15 +39,15 @@ def system_prompt():
return "s1"
-def test__init__model_configs(litellm_client, model_id):
- _ = litellm_client
+def test__init__(litellm_client_cls, model_id):
+ model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
- model = LiteLLMModel(model_id=model_id, params={"max_tokens": 1})
+ tru_config = model.get_config()
+ exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
- tru_max_tokens = model.get_config().get("params")
- exp_max_tokens = {"max_tokens": 1}
+ assert tru_config == exp_config
- assert tru_max_tokens == exp_max_tokens
+ litellm_client_cls.assert_called_once_with(api_key="k1")
def test_update_config(model, model_id):
@@ -55,513 +59,47 @@ def test_update_config(model, model_id):
assert tru_model_id == exp_model_id
-def test_format_request_default(model, messages, model_id):
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_params(model, messages, model_id):
- model.update_config(params={"max_tokens": 1})
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- "max_tokens": 1,
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_system_prompt(model, messages, model_id, system_prompt):
- tru_request = model.format_request(messages, system_prompt=system_prompt)
- exp_request = {
- "messages": [
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": [{"type": "text", "text": "test"}]},
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_image(model, model_id):
- messages = [
- {
- "role": "user",
- "content": [
- {
- "image": {
- "format": "jpg",
- "source": {"bytes": b"base64encodedimage"},
- },
- },
- ],
- },
- ]
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [
- {
- "role": "user",
- "content": [
- {
- "image_url": {
- "detail": "auto",
- "format": "image/jpeg",
- "url": "data:image/jpeg;base64,base64encodedimage",
- },
- "type": "image_url",
- },
- ],
- },
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_reasoning(model, model_id):
- messages = [
- {
- "role": "user",
- "content": [
- {
- "reasoningContent": {
- "reasoningText": {
- "signature": "reasoning_signature",
- "text": "reasoning_text",
- },
- },
- },
- ],
- },
- ]
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [
+@pytest.mark.parametrize(
+ "content, exp_result",
+ [
+ # Case 1: Thinking
+ (
{
- "role": "user",
- "content": [
- {
+ "reasoningContent": {
+ "reasoningText": {
"signature": "reasoning_signature",
- "thinking": "reasoning_text",
- "type": "thinking",
- },
- ],
- },
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_video(model, model_id):
- messages = [
- {
- "role": "user",
- "content": [
- {
- "video": {
- "source": {"bytes": "base64encodedvideo"},
+ "text": "reasoning_text",
},
},
- ],
- },
- ]
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [
- {
- "role": "user",
- "content": [
- {
- "type": "video_url",
- "video_url": {
- "detail": "auto",
- "url": "base64encodedvideo",
- },
- },
- ],
},
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_other(model, model_id):
- messages = [
- {
- "role": "user",
- "content": [{"other": {"a": 1}}],
- },
- ]
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [
{
- "role": "user",
- "content": [
- {
- "text": json.dumps({"other": {"a": 1}}),
- "type": "text",
- },
- ],
+ "signature": "reasoning_signature",
+ "thinking": "reasoning_text",
+ "type": "thinking",
},
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_tool_result(model, model_id):
- messages = [
- {
- "role": "user",
- "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}],
- }
- ]
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [
+ ),
+ # Case 2: Video
+ (
{
- "content": json.dumps(
- {
- "content": [{"value": 4}],
- "status": "success",
- }
- ),
- "role": "tool",
- "tool_call_id": "c1",
- },
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_tool_use(model, model_id):
- messages = [
- {
- "role": "assistant",
- "content": [
- {
- "toolUse": {
- "toolUseId": "c1",
- "name": "calculator",
- "input": {"expression": "2+2"},
- },
+ "video": {
+ "source": {"bytes": "base64encodedvideo"},
},
- ],
- },
- ]
-
- tru_request = model.format_request(messages)
- exp_request = {
- "messages": [
- {
- "content": [],
- "role": "assistant",
- "tool_calls": [
- {
- "function": {
- "name": "calculator",
- "arguments": '{"expression": "2+2"}',
- },
- "id": "c1",
- "type": "function",
- }
- ],
- }
- ],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_request_with_tool_specs(model, messages, model_id):
- tool_specs = [
- {
- "name": "calculator",
- "description": "Calculate mathematical expressions",
- "inputSchema": {
- "json": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]}
},
- }
- ]
-
- tru_request = model.format_request(messages, tool_specs)
- exp_request = {
- "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}],
- "model": model_id,
- "stream": True,
- "stream_options": {
- "include_usage": True,
- },
- "tools": [
{
- "type": "function",
- "function": {
- "name": "calculator",
- "description": "Calculate mathematical expressions",
- "parameters": {
- "type": "object",
- "properties": {"expression": {"type": "string"}},
- "required": ["expression"],
- },
+ "type": "video_url",
+ "video_url": {
+ "detail": "auto",
+ "url": "base64encodedvideo",
},
- }
- ],
- }
-
- assert tru_request == exp_request
-
-
-def test_format_chunk_message_start(model):
- event = {"chunk_type": "message_start"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"messageStart": {"role": "assistant"}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_content_start_text(model):
- event = {"chunk_type": "content_start", "data_type": "text"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"contentBlockStart": {"start": {}}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_content_start_tool(model):
- mock_tool_use = unittest.mock.Mock()
- mock_tool_use.function.name = "calculator"
- mock_tool_use.id = "c1"
-
- event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_use}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_content_delta_text(model):
- event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_content_delta_tool(model):
- event = {
- "chunk_type": "content_delta",
- "data_type": "tool",
- "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')),
- }
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_content_stop(model):
- event = {"chunk_type": "content_stop"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"contentBlockStop": {}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_message_stop_end_turn(model):
- event = {"chunk_type": "message_stop", "data": "stop"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"messageStop": {"stopReason": "end_turn"}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_message_stop_tool_use(model):
- event = {"chunk_type": "message_stop", "data": "tool_calls"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"messageStop": {"stopReason": "tool_use"}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_message_stop_max_tokens(model):
- event = {"chunk_type": "message_stop", "data": "length"}
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {"messageStop": {"stopReason": "max_tokens"}}
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_metadata(model):
- event = {
- "chunk_type": "metadata",
- "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150),
- }
-
- tru_chunk = model.format_chunk(event)
- exp_chunk = {
- "metadata": {
- "usage": {
- "inputTokens": 100,
- "outputTokens": 50,
- "totalTokens": 150,
- },
- "metrics": {
- "latencyMs": 0,
},
- },
- }
-
- assert tru_chunk == exp_chunk
-
-
-def test_format_chunk_other(model):
- event = {"chunk_type": "other"}
-
- with pytest.raises(RuntimeError, match="chunk_type= | unknown type"):
- model.format_chunk(event)
-
-
-def test_stream(litellm_client, model):
- mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
- mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
- mock_delta_1 = unittest.mock.Mock(
- content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1]
- )
-
- mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
- mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
- mock_delta_2 = unittest.mock.Mock(
- content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2]
- )
-
- mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
- mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
- mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls")])
- mock_event_4 = unittest.mock.Mock()
-
- litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
-
- request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
- response = model.stream(request)
-
- tru_events = list(response)
- exp_events = [
- {"chunk_type": "message_start"},
- {"chunk_type": "content_start", "data_type": "text"},
- {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
- {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
- {"chunk_type": "content_stop", "data_type": "text"},
- {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1},
- {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2},
- {"chunk_type": "content_stop", "data_type": "tool"},
- {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1},
- {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
- {"chunk_type": "content_stop", "data_type": "tool"},
- {"chunk_type": "message_stop", "data": "tool_calls"},
- {"chunk_type": "metadata", "data": mock_event_4.usage},
- ]
-
- assert tru_events == exp_events
- litellm_client.chat.completions.create.assert_called_once_with(**request)
-
-
-def test_stream_empty(litellm_client, model):
- mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
-
- mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
- mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop")])
- mock_event_3 = unittest.mock.Mock(spec=[])
-
- litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3])
-
- request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
- response = model.stream(request)
-
- tru_events = list(response)
- exp_events = [
- {"chunk_type": "message_start"},
- {"chunk_type": "content_start", "data_type": "text"},
- {"chunk_type": "content_stop", "data_type": "text"},
- {"chunk_type": "message_stop", "data": "stop"},
- ]
-
- assert tru_events == exp_events
- litellm_client.chat.completions.create.assert_called_once_with(**request)
+ ),
+ # Case 3: Text
+ (
+ {"text": "hello"},
+ {"type": "text", "text": "hello"},
+ ),
+ ],
+)
+def test_format_request_message_content(content, exp_result):
+ tru_result = LiteLLMModel.format_request_message_content(content)
+ assert tru_result == exp_result
diff --git a/tests/test_llamaapi.py b/tests/strands/models/test_llamaapi.py
similarity index 100%
rename from tests/test_llamaapi.py
rename to tests/strands/models/test_llamaapi.py
diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py
new file mode 100644
index 00000000..89aa591f
--- /dev/null
+++ b/tests/strands/models/test_openai.py
@@ -0,0 +1,134 @@
+import unittest.mock
+
+import pytest
+
+import strands
+from strands.models.openai import OpenAIModel
+
+
+@pytest.fixture
+def openai_client_cls():
+ with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls:
+ yield mock_client_cls
+
+
+@pytest.fixture
+def openai_client(openai_client_cls):
+ return openai_client_cls.return_value
+
+
+@pytest.fixture
+def model_id():
+ return "m1"
+
+
+@pytest.fixture
+def model(openai_client, model_id):
+ _ = openai_client
+
+ return OpenAIModel(model_id=model_id)
+
+
+@pytest.fixture
+def messages():
+ return [{"role": "user", "content": [{"text": "test"}]}]
+
+
+@pytest.fixture
+def system_prompt():
+ return "s1"
+
+
+def test__init__(openai_client_cls, model_id):
+ model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1})
+
+ tru_config = model.get_config()
+ exp_config = {"model_id": "m1", "params": {"max_tokens": 1}}
+
+ assert tru_config == exp_config
+
+ openai_client_cls.assert_called_once_with(api_key="k1")
+
+
+def test_update_config(model, model_id):
+ model.update_config(model_id=model_id)
+
+ tru_model_id = model.get_config().get("model_id")
+ exp_model_id = model_id
+
+ assert tru_model_id == exp_model_id
+
+
+def test_stream(openai_client, model):
+ mock_tool_call_1_part_1 = unittest.mock.Mock(index=0)
+ mock_tool_call_2_part_1 = unittest.mock.Mock(index=1)
+ mock_delta_1 = unittest.mock.Mock(
+ content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1]
+ )
+
+ mock_tool_call_1_part_2 = unittest.mock.Mock(index=0)
+ mock_tool_call_2_part_2 = unittest.mock.Mock(index=1)
+ mock_delta_2 = unittest.mock.Mock(
+ content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2]
+ )
+
+ mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None)
+
+ mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)])
+ mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)])
+ mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)])
+ mock_event_4 = unittest.mock.Mock()
+
+ openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
+
+ request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]}
+ response = model.stream(request)
+
+ tru_events = list(response)
+ exp_events = [
+ {"chunk_type": "message_start"},
+ {"chunk_type": "content_start", "data_type": "text"},
+ {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"},
+ {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"},
+ {"chunk_type": "content_stop", "data_type": "text"},
+ {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1},
+ {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1},
+ {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2},
+ {"chunk_type": "content_stop", "data_type": "tool"},
+ {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1},
+ {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1},
+ {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2},
+ {"chunk_type": "content_stop", "data_type": "tool"},
+ {"chunk_type": "message_stop", "data": "tool_calls"},
+ {"chunk_type": "metadata", "data": mock_event_4.usage},
+ ]
+
+ assert tru_events == exp_events
+ openai_client.chat.completions.create.assert_called_once_with(**request)
+
+
+def test_stream_empty(openai_client, model):
+ mock_delta = unittest.mock.Mock(content=None, tool_calls=None)
+ mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0)
+
+ mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)])
+ mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)])
+ mock_event_3 = unittest.mock.Mock()
+ mock_event_4 = unittest.mock.Mock(usage=mock_usage)
+
+ openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4])
+
+ request = {"model": "m1", "messages": [{"role": "user", "content": []}]}
+ response = model.stream(request)
+
+ tru_events = list(response)
+ exp_events = [
+ {"chunk_type": "message_start"},
+ {"chunk_type": "content_start", "data_type": "text"},
+ {"chunk_type": "content_stop", "data_type": "text"},
+ {"chunk_type": "message_stop", "data": "stop"},
+ {"chunk_type": "metadata", "data": mock_usage},
+ ]
+
+ assert tru_events == exp_events
+ openai_client.chat.completions.create.assert_called_once_with(**request)
diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py
index 55018c5e..32a4ac0a 100644
--- a/tests/strands/telemetry/test_tracer.py
+++ b/tests/strands/telemetry/test_tracer.py
@@ -1,11 +1,12 @@
import json
import os
+from datetime import date, datetime, timezone
from unittest import mock
import pytest
from opentelemetry.trace import StatusCode # type: ignore
-from strands.telemetry.tracer import Tracer, get_tracer
+from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize
from strands.types.streaming import Usage
@@ -59,6 +60,50 @@ def mock_resource():
yield mock_resource
+@pytest.fixture
+def clean_env():
+ """Fixture to provide a clean environment for each test."""
+ with mock.patch.dict(os.environ, {}, clear=True):
+ yield
+
+
+@pytest.fixture
+def env_with_otlp():
+ """Fixture with OTLP environment variables."""
+ with mock.patch.dict(
+ os.environ,
+ {
+ "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint",
+ },
+ ):
+ yield
+
+
+@pytest.fixture
+def env_with_console():
+ """Fixture with console export environment variables."""
+ with mock.patch.dict(
+ os.environ,
+ {
+ "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true",
+ },
+ ):
+ yield
+
+
+@pytest.fixture
+def env_with_both():
+ """Fixture with both OTLP and console export environment variables."""
+ with mock.patch.dict(
+ os.environ,
+ {
+ "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint",
+ "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true",
+ },
+ ):
+ yield
+
+
def test_init_default():
"""Test initializing the Tracer with default parameters."""
tracer = Tracer()
@@ -268,6 +313,9 @@ def test_start_tool_call_span(mock_tracer):
mock_tracer.start_span.assert_called_once()
assert mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool"
+ mock_span.set_attribute.assert_any_call(
+ "gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}})
+ )
mock_span.set_attribute.assert_any_call("tool.name", "test-tool")
mock_span.set_attribute.assert_any_call("tool.id", "123")
mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"}))
@@ -369,7 +417,7 @@ def test_end_agent_span(mock_span):
tracer.end_agent_span(mock_span, mock_response)
- mock_span.set_attribute.assert_any_call("gen_ai.completion", '""')
+ mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response")
mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50)
mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100)
mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150)
@@ -497,3 +545,230 @@ def test_start_model_invoke_span_with_parent(mock_tracer):
# Verify span was returned
assert span is mock_span
+
+
+@pytest.mark.parametrize(
+ "input_data, expected_result",
+ [
+ ("test string", '"test string"'),
+ (1234, "1234"),
+ (13.37, "13.37"),
+ (False, "false"),
+ (None, "null"),
+ ],
+)
+def test_json_encoder_serializable(input_data, expected_result):
+ """Test encoding of serializable values."""
+ encoder = JSONEncoder()
+
+ result = encoder.encode(input_data)
+ assert result == expected_result
+
+
+def test_json_encoder_datetime():
+ """Test encoding datetime and date objects."""
+ encoder = JSONEncoder()
+
+ dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
+ result = encoder.encode(dt)
+ assert result == f'"{dt.isoformat()}"'
+
+ d = date(2025, 1, 1)
+ result = encoder.encode(d)
+ assert result == f'"{d.isoformat()}"'
+
+
+def test_json_encoder_list():
+ """Test encoding a list with mixed content."""
+ encoder = JSONEncoder()
+
+ non_serializable = lambda x: x # noqa: E731
+
+ data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]]
+
+ result = json.loads(encoder.encode(data))
+ assert result == ["value", 42, 13.37, "", None, {"key": True}, ["value here"]]
+
+
+def test_json_encoder_dict():
+ """Test encoding a dict with mixed content."""
+ encoder = JSONEncoder()
+
+ class UnserializableClass:
+ def __str__(self):
+ return "Unserializable Object"
+
+ non_serializable = lambda x: x # noqa: E731
+
+ now = datetime.now(timezone.utc)
+
+ data = {
+ "metadata": {
+ "timestamp": now,
+ "version": "1.0",
+ "debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731
+ },
+ "content": [
+ {"type": "text", "value": "Hello world"},
+ {"type": "binary", "value": non_serializable},
+ {"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]},
+ ],
+ "statistics": {
+ "processed": 100,
+ "failed": 5,
+ "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}],
+ },
+ "list": [
+ non_serializable,
+ 1234,
+ 13.37,
+ True,
+ None,
+ "string here",
+ ],
+ }
+
+ expected = {
+ "metadata": {
+ "timestamp": now.isoformat(),
+ "version": "1.0",
+ "debug_info": {"object": "", "callable": ""},
+ },
+ "content": [
+ {"type": "text", "value": "Hello world"},
+ {"type": "binary", "value": ""},
+ {"type": "mixed", "values": [1, "text", "", {"nested": ""}]},
+ ],
+ "statistics": {
+ "processed": 100,
+ "failed": 5,
+ "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": ""}],
+ },
+ "list": [
+ "",
+ 1234,
+ 13.37,
+ True,
+ None,
+ "string here",
+ ],
+ }
+
+ result = json.loads(encoder.encode(data))
+
+ assert result == expected
+
+
+def test_json_encoder_value_error():
+ """Test encoding values that cause ValueError."""
+ encoder = JSONEncoder()
+
+ # A very large integer that exceeds JSON limits and throws ValueError
+ huge_number = 2**100000
+
+ # Test in a dictionary
+ dict_data = {"normal": 42, "huge": huge_number}
+ result = json.loads(encoder.encode(dict_data))
+ assert result == {"normal": 42, "huge": ""}
+
+ # Test in a list
+ list_data = [42, huge_number]
+ result = json.loads(encoder.encode(list_data))
+ assert result == [42, ""]
+
+ # Test just the value
+ result = json.loads(encoder.encode(huge_number))
+ assert result == ""
+
+
+def test_serialize_non_ascii_characters():
+ """Test that non-ASCII characters are preserved in JSON serialization."""
+
+ # Test with Japanese text
+ japanese_text = "γγγ«γ‘γ―δΈη"
+ result = serialize({"text": japanese_text})
+ assert japanese_text in result
+ assert "\\u" not in result
+
+ # Test with emoji
+ emoji_text = "Hello π"
+ result = serialize({"text": emoji_text})
+ assert emoji_text in result
+ assert "\\u" not in result
+
+ # Test with Chinese characters
+ chinese_text = "δ½ ε₯½οΌδΈη"
+ result = serialize({"text": chinese_text})
+ assert chinese_text in result
+ assert "\\u" not in result
+
+ # Test with mixed content
+ mixed_text = {"ja": "γγγ«γ‘γ―", "emoji": "π", "zh": "δ½ ε₯½", "en": "hello"}
+ result = serialize(mixed_text)
+ assert "γγγ«γ‘γ―" in result
+ assert "π" in result
+ assert "δ½ ε₯½" in result
+ assert "\\u" not in result
+
+
+def test_serialize_vs_json_dumps():
+ """Test that serialize behaves differently from default json.dumps for non-ASCII characters."""
+
+ # Test with Japanese text
+ japanese_text = "γγγ«γ‘γ―δΈη"
+
+ # Default json.dumps should escape non-ASCII characters
+ default_result = json.dumps({"text": japanese_text})
+ assert "\\u" in default_result
+
+ # Our serialize function should preserve non-ASCII characters
+ custom_result = serialize({"text": japanese_text})
+ assert japanese_text in custom_result
+ assert "\\u" not in custom_result
+
+
+def test_init_with_no_env_or_param(clean_env):
+ """Test initializing with neither environment variable nor constructor parameter."""
+ tracer = Tracer()
+ assert tracer.otlp_endpoint is None
+ assert tracer.enable_console_export is False
+
+ tracer = Tracer(otlp_endpoint="http://param-endpoint")
+ assert tracer.otlp_endpoint == "http://param-endpoint"
+
+ tracer = Tracer(enable_console_export=True)
+ assert tracer.enable_console_export is True
+
+
+def test_constructor_params_with_otlp_env(env_with_otlp):
+ """Test constructor parameters precedence over OTLP environment variable."""
+ # Constructor parameter should take precedence
+ tracer = Tracer(otlp_endpoint="http://constructor-endpoint")
+ assert tracer.otlp_endpoint == "http://constructor-endpoint"
+
+ # Without constructor parameter, should use env var
+ tracer = Tracer()
+ assert tracer.otlp_endpoint == "http://env-endpoint"
+
+
+def test_constructor_params_with_console_env(env_with_console):
+ """Test constructor parameters precedence over console environment variable."""
+ # Constructor parameter should take precedence
+ tracer = Tracer(enable_console_export=False)
+ assert tracer.enable_console_export is False
+
+ # Without explicit constructor parameter, should use env var
+ tracer = Tracer()
+ assert tracer.enable_console_export is True
+
+
+def test_fallback_to_env_vars(env_with_both):
+ """Test fallback to environment variables when no constructor parameters."""
+ tracer = Tracer()
+ assert tracer.otlp_endpoint == "http://env-endpoint"
+ assert tracer.enable_console_export is True
+
+ # Constructor parameters should still take precedence
+ tracer = Tracer(otlp_endpoint="http://constructor-endpoint", enable_console_export=False)
+ assert tracer.otlp_endpoint == "http://constructor-endpoint"
+ assert tracer.enable_console_export is False
diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py
index 8a6b406f..f24cc22d 100644
--- a/tests/strands/tools/test_tools.py
+++ b/tests/strands/tools/test_tools.py
@@ -14,9 +14,13 @@
def test_validate_tool_use_name_valid():
- tool = {"name": "valid_tool_name", "toolUseId": "123"}
+ tool1 = {"name": "valid_tool_name", "toolUseId": "123"}
+ # Should not raise an exception
+ validate_tool_use_name(tool1)
+
+ tool2 = {"name": "valid-name", "toolUseId": "123"}
# Should not raise an exception
- validate_tool_use_name(tool)
+ validate_tool_use_name(tool2)
def test_validate_tool_use_name_missing():
diff --git a/tests/strands/types/models/__init__.py b/tests/strands/types/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py
new file mode 100644
index 00000000..f2797fe5
--- /dev/null
+++ b/tests/strands/types/models/test_model.py
@@ -0,0 +1,81 @@
+import pytest
+
+from strands.types.models import Model as SAModel
+
+
+class TestModel(SAModel):
+ def update_config(self, **model_config):
+ return model_config
+
+ def get_config(self):
+ return
+
+ def format_request(self, messages, tool_specs, system_prompt):
+ return {
+ "messages": messages,
+ "tool_specs": tool_specs,
+ "system_prompt": system_prompt,
+ }
+
+ def format_chunk(self, event):
+ return {"event": event}
+
+ def stream(self, request):
+ yield {"request": request}
+
+
+@pytest.fixture
+def model():
+ return TestModel()
+
+
+@pytest.fixture
+def messages():
+ return [
+ {
+ "role": "user",
+ "content": [{"text": "hello"}],
+ },
+ ]
+
+
+@pytest.fixture
+def tool_specs():
+ return [
+ {
+ "name": "test_tool",
+ "description": "A test tool",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {
+ "input": {"type": "string"},
+ },
+ "required": ["input"],
+ },
+ },
+ },
+ ]
+
+
+@pytest.fixture
+def system_prompt():
+ return "s1"
+
+
+def test_converse(model, messages, tool_specs, system_prompt):
+ response = model.converse(messages, tool_specs, system_prompt)
+
+ tru_events = list(response)
+ exp_events = [
+ {
+ "event": {
+ "request": {
+ "messages": messages,
+ "tool_specs": tool_specs,
+ "system_prompt": system_prompt,
+ },
+ },
+ },
+ ]
+ assert tru_events == exp_events
diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py
new file mode 100644
index 00000000..c6a05291
--- /dev/null
+++ b/tests/strands/types/models/test_openai.py
@@ -0,0 +1,358 @@
+import json
+import unittest.mock
+
+import pytest
+
+from strands.types.models import OpenAIModel as SAOpenAIModel
+
+
+class TestOpenAIModel(SAOpenAIModel):
+ def __init__(self):
+ self.config = {"model_id": "m1", "params": {"max_tokens": 1}}
+
+ def update_config(self, **model_config):
+ return model_config
+
+ def get_config(self):
+ return
+
+ def stream(self, request):
+ yield {"request": request}
+
+
+@pytest.fixture
+def model():
+ return TestOpenAIModel()
+
+
+@pytest.fixture
+def messages():
+ return [
+ {
+ "role": "user",
+ "content": [{"text": "hello"}],
+ },
+ ]
+
+
+@pytest.fixture
+def tool_specs():
+ return [
+ {
+ "name": "test_tool",
+ "description": "A test tool",
+ "inputSchema": {
+ "json": {
+ "type": "object",
+ "properties": {
+ "input": {"type": "string"},
+ },
+ "required": ["input"],
+ },
+ },
+ },
+ ]
+
+
+@pytest.fixture
+def system_prompt():
+ return "s1"
+
+
+@pytest.mark.parametrize(
+ "content, exp_result",
+ [
+ # Document
+ (
+ {
+ "document": {
+ "format": "pdf",
+ "name": "test doc",
+ "source": {"bytes": b"document"},
+ },
+ },
+ {
+ "file": {
+ "file_data": "data:application/pdf;base64,ZG9jdW1lbnQ=",
+ "filename": "test doc",
+ },
+ "type": "file",
+ },
+ ),
+ # Image
+ (
+ {
+ "image": {
+ "format": "jpg",
+ "source": {"bytes": b"image"},
+ },
+ },
+ {
+ "image_url": {
+ "detail": "auto",
+ "format": "image/jpeg",
+ "url": "data:image/jpeg;base64,image",
+ },
+ "type": "image_url",
+ },
+ ),
+ # Text
+ (
+ {"text": "hello"},
+ {"type": "text", "text": "hello"},
+ ),
+ # Other
+ (
+ {"other": {"a": 1}},
+ {
+ "text": json.dumps({"other": {"a": 1}}),
+ "type": "text",
+ },
+ ),
+ ],
+)
+def test_format_request_message_content(content, exp_result):
+ tru_result = SAOpenAIModel.format_request_message_content(content)
+ assert tru_result == exp_result
+
+
+def test_format_request_message_tool_call():
+ tool_use = {
+ "input": {"expression": "2+2"},
+ "name": "calculator",
+ "toolUseId": "c1",
+ }
+
+ tru_result = SAOpenAIModel.format_request_message_tool_call(tool_use)
+ exp_result = {
+ "function": {
+ "arguments": '{"expression": "2+2"}',
+ "name": "calculator",
+ },
+ "id": "c1",
+ "type": "function",
+ }
+ assert tru_result == exp_result
+
+
+def test_format_request_tool_message():
+ tool_result = {
+ "content": [{"value": 4}],
+ "status": "success",
+ "toolUseId": "c1",
+ }
+
+ tru_result = SAOpenAIModel.format_request_tool_message(tool_result)
+ exp_result = {
+ "content": json.dumps(
+ {
+ "content": [{"value": 4}],
+ "status": "success",
+ }
+ ),
+ "role": "tool",
+ "tool_call_id": "c1",
+ }
+ assert tru_result == exp_result
+
+
+def test_format_request_messages(system_prompt):
+ messages = [
+ {
+ "content": [],
+ "role": "user",
+ },
+ {
+ "content": [{"text": "hello"}],
+ "role": "user",
+ },
+ {
+ "content": [
+ {"text": "call tool"},
+ {
+ "toolUse": {
+ "input": {"expression": "2+2"},
+ "name": "calculator",
+ "toolUseId": "c1",
+ },
+ },
+ ],
+ "role": "assistant",
+ },
+ {
+ "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}],
+ "role": "user",
+ },
+ ]
+
+ tru_result = SAOpenAIModel.format_request_messages(messages, system_prompt)
+ exp_result = [
+ {
+ "content": system_prompt,
+ "role": "system",
+ },
+ {
+ "content": [{"text": "hello", "type": "text"}],
+ "role": "user",
+ },
+ {
+ "content": [{"text": "call tool", "type": "text"}],
+ "role": "assistant",
+ "tool_calls": [
+ {
+ "function": {
+ "name": "calculator",
+ "arguments": '{"expression": "2+2"}',
+ },
+ "id": "c1",
+ "type": "function",
+ }
+ ],
+ },
+ {
+ "content": json.dumps(
+ {
+ "content": [{"value": 4}],
+ "status": "success",
+ }
+ ),
+ "role": "tool",
+ "tool_call_id": "c1",
+ },
+ ]
+ assert tru_result == exp_result
+
+
+def test_format_request(model, messages, tool_specs, system_prompt):
+ tru_request = model.format_request(messages, tool_specs, system_prompt)
+ exp_request = {
+ "messages": [
+ {
+ "content": system_prompt,
+ "role": "system",
+ },
+ {
+ "content": [{"text": "hello", "type": "text"}],
+ "role": "user",
+ },
+ ],
+ "model": "m1",
+ "stream": True,
+ "stream_options": {"include_usage": True},
+ "tools": [
+ {
+ "function": {
+ "description": "A test tool",
+ "name": "test_tool",
+ "parameters": {
+ "properties": {
+ "input": {"type": "string"},
+ },
+ "required": ["input"],
+ "type": "object",
+ },
+ },
+ "type": "function",
+ },
+ ],
+ "max_tokens": 1,
+ }
+ assert tru_request == exp_request
+
+
+@pytest.mark.parametrize(
+ ("event", "exp_chunk"),
+ [
+ # Message start
+ (
+ {"chunk_type": "message_start"},
+ {"messageStart": {"role": "assistant"}},
+ ),
+ # Content Start - Tool Use
+ (
+ {
+ "chunk_type": "content_start",
+ "data_type": "tool",
+ "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}),
+ },
+ {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}},
+ ),
+ # Content Start - Text
+ (
+ {"chunk_type": "content_start", "data_type": "text"},
+ {"contentBlockStart": {"start": {}}},
+ ),
+ # Content Delta - Tool Use
+ (
+ {
+ "chunk_type": "content_delta",
+ "data_type": "tool",
+ "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')),
+ },
+ {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}},
+ ),
+ # Content Delta - Tool Use - None
+ (
+ {
+ "chunk_type": "content_delta",
+ "data_type": "tool",
+ "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)),
+ },
+ {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}},
+ ),
+ # Content Delta - Text
+ (
+ {"chunk_type": "content_delta", "data_type": "text", "data": "hello"},
+ {"contentBlockDelta": {"delta": {"text": "hello"}}},
+ ),
+ # Content Stop
+ (
+ {"chunk_type": "content_stop"},
+ {"contentBlockStop": {}},
+ ),
+ # Message Stop - Tool Use
+ (
+ {"chunk_type": "message_stop", "data": "tool_calls"},
+ {"messageStop": {"stopReason": "tool_use"}},
+ ),
+ # Message Stop - Max Tokens
+ (
+ {"chunk_type": "message_stop", "data": "length"},
+ {"messageStop": {"stopReason": "max_tokens"}},
+ ),
+ # Message Stop - End Turn
+ (
+ {"chunk_type": "message_stop", "data": "stop"},
+ {"messageStop": {"stopReason": "end_turn"}},
+ ),
+ # Metadata
+ (
+ {
+ "chunk_type": "metadata",
+ "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150),
+ },
+ {
+ "metadata": {
+ "usage": {
+ "inputTokens": 100,
+ "outputTokens": 50,
+ "totalTokens": 150,
+ },
+ "metrics": {
+ "latencyMs": 0,
+ },
+ },
+ },
+ ),
+ ],
+)
+def test_format_chunk(event, exp_chunk, model):
+ tru_chunk = model.format_chunk(event)
+ assert tru_chunk == exp_chunk
+
+
+def test_format_chunk_unknown_type(model):
+ event = {"chunk_type": "unknown"}
+
+ with pytest.raises(RuntimeError, match="chunk_type= | unknown type"):
+ model.format_chunk(event)