diff --git a/.github/workflows/dispatch-docs.yml b/.github/workflows/dispatch-docs.yml deleted file mode 100644 index fda63413..00000000 --- a/.github/workflows/dispatch-docs.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: Dispatch Docs -on: - push: - branches: - - main - -jobs: - trigger: - runs-on: ubuntu-latest - permissions: - contents: read - steps: - - name: Dispatch - run: gh api repos/${{ github.repository_owner }}/docs/dispatches -F event_type=sdk-push - env: - GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }} diff --git a/.github/workflows/pr-and-push.yml b/.github/workflows/pr-and-push.yml new file mode 100644 index 00000000..38e88691 --- /dev/null +++ b/.github/workflows/pr-and-push.yml @@ -0,0 +1,17 @@ +name: Pull Request and Push Action + +on: + pull_request: # Safer than pull_request_target for untrusted code + branches: [ main ] + types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] + push: + branches: [ main ] # Also run on direct pushes to main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + call-test-lint: + uses: ./.github/workflows/test-lint.yml + with: + ref: ${{ github.event.pull_request.head.sha }} \ No newline at end of file diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml new file mode 100644 index 00000000..4047f596 --- /dev/null +++ b/.github/workflows/pypi-publish-on-release.yml @@ -0,0 +1,78 @@ +name: Publish Python Package + +on: + release: + types: + - published + +jobs: + call-test-lint: + uses: ./.github/workflows/test-lint.yml + with: + ref: ${{ github.event.release.target_commitish }} + + build: + name: Build distribution πŸ“¦ + needs: + - call-test-lint + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install hatch twine + + - name: Validate version + run: | + version=$(hatch version) + if [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "Valid version format" + exit 0 + else + echo "Invalid version format" + exit 1 + fi + + - name: Build + run: | + hatch build + + - name: Store the distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + deploy: + name: Upload release to PyPI + needs: + - build + runs-on: ubuntu-latest + + # environment is used by PyPI Trusted Publisher and is strongly encouraged + # https://docs.pypi.org/trusted-publishers/adding-a-publisher/ + environment: + name: pypi + url: https://pypi.org/p/strands-agents + permissions: + # IMPORTANT: this permission is mandatory for Trusted Publishing + id-token: write + + steps: + - name: Download all the dists + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + - name: Publish distribution πŸ“¦ to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml deleted file mode 100644 index 15fbebcb..00000000 --- a/.github/workflows/test-lint-pr.yml +++ /dev/null @@ -1,139 +0,0 @@ -name: Test and Lint - -on: - pull_request: # Safer than pull_request_target for untrusted code - branches: [ main ] - types: [opened, synchronize, reopened, ready_for_review, review_requested, 
review_request_removed] - push: - branches: [ main ] # Also run on direct pushes to main - -jobs: - check-approval: - name: Check if PR has contributor approval - runs-on: ubuntu-latest - permissions: - pull-requests: read - # Skip this check for direct pushes to main - if: github.event_name == 'pull_request' - outputs: - approved: ${{ steps.check-approval.outputs.approved }} - steps: - - name: Check if PR has been approved by a contributor - id: check-approval - uses: actions/github-script@v7 - with: - script: | - const APPROVED_ASSOCIATION = ['COLLABORATOR', 'CONTRIBUTOR', 'MEMBER', 'OWNER'] - const PR_AUTHOR_ASSOCIATION = context.payload.pull_request.author_association; - const { data: reviews } = await github.rest.pulls.listReviews({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.issue.number, - }); - - const isApprovedContributor = APPROVED_ASSOCIATION.includes(PR_AUTHOR_ASSOCIATION); - - // Check if any contributor has approved - const isApproved = reviews.some(review => - review.state === 'APPROVED' && APPROVED_ASSOCIATION.includes(review.author_association) - ) || isApprovedContributor; - - core.setOutput('approved', isApproved); - - if (!isApproved) { - core.notice('This PR does not have approval from a Contributor. Workflow will not run test jobs.'); - return false; - } - - return true; - - unit-test: - name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} - needs: check-approval - permissions: - contents: read - # Only run if PR is approved or this is a direct push to main - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' - strategy: - matrix: - include: - # Linux - - os: ubuntu-latest - os-name: linux - python-version: "3.10" - - os: ubuntu-latest - os-name: linux - python-version: "3.11" - - os: ubuntu-latest - os-name: linux - python-version: "3.12" - - os: ubuntu-latest - os-name: linux - python-version: "3.13" - # Windows - - os: windows-latest - os-name: windows - python-version: "3.10" - - os: windows-latest - os-name: windows - python-version: "3.11" - - os: windows-latest - os-name: windows - python-version: "3.12" - - os: windows-latest - os-name: windows - python-version: "3.13" - # MacOS - latest only; not enough runners for MacOS - - os: macos-latest - os-name: macos - python-version: "3.13" - fail-fast: false - runs-on: ${{ matrix.os }} - env: - LOG_LEVEL: DEBUG - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to checkout - persist-credentials: false # Don't persist credentials for subsequent actions - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - pip install --no-cache-dir hatch - - name: Run Unit tests - id: tests - run: hatch test tests --cover - continue-on-error: false - - lint: - name: Lint - runs-on: ubuntu-latest - needs: check-approval - permissions: - contents: read - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' - steps: - - name: Checkout code - uses: actions/checkout@v4 - with: - ref: ${{ github.event.pull_request.head.sha }} - persist-credentials: false - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - cache: 'pip' - - - name: Install dependencies - run: | - pip install --no-cache-dir hatch - - - name: Run lint - id: lint - run: hatch run test-lint - continue-on-error: false diff 
--git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml new file mode 100644 index 00000000..35e0f584 --- /dev/null +++ b/.github/workflows/test-lint.yml @@ -0,0 +1,94 @@ +name: Test and Lint + +on: + workflow_call: + inputs: + ref: + required: true + type: string + +jobs: + unit-test: + name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} + permissions: + contents: read + strategy: + matrix: + include: + # Linux + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.10" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.11" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.12" + - os: ubuntu-latest + os-name: 'linux' + python-version: "3.13" + # Windows + - os: windows-latest + os-name: 'windows' + python-version: "3.10" + - os: windows-latest + os-name: 'windows' + python-version: "3.11" + - os: windows-latest + os-name: 'windows' + python-version: "3.12" + - os: windows-latest + os-name: 'windows' + python-version: "3.13" + # MacOS - latest only; not enough runners for macOS + - os: macos-latest + os-name: 'macOS' + python-version: "3.13" + fail-fast: true + runs-on: ${{ matrix.os }} + env: + LOG_LEVEL: DEBUG + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.ref }} # Explicitly define which commit to check out + persist-credentials: false # Don't persist credentials for subsequent actions + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + - name: Run Unit tests + id: tests + run: hatch test tests --cover + continue-on-error: false + lint: + name: Lint + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + ref: ${{ inputs.ref }} + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + cache: 'pip' + + - name: Install dependencies + run: | + pip install --no-cache-dir hatch + + - name: Run lint + id: lint + run: hatch run test-lint + continue-on-error: false diff --git a/.gitignore b/.gitignore index a80f4bd1..5cdc43db 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ __pycache__* .pytest_cache .ruff_cache *.bak +.vscode +dist \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e1ddbb89..fa724cdd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,16 +31,22 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ### Setting Up Your Development Environment -1. Install development dependencies: +1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell. ```bash - pip install -e ".[dev]" && pip install -e ".[litellm] + hatch shell dev ``` + Alternatively, install development dependencies in a manually created virtual environment: + ```bash + pip install -e ".[dev]" && pip install -e ".[litellm]" + ``` + + 2. Set up pre-commit hooks: ```bash pre-commit install -t pre-commit -t commit-msg ``` - This will automatically run formatters and convention commit checks on your code before each commit. + This will automatically run formatters and conventional commit checks on your code before each commit. 3. Run code formatters manually: ```bash @@ -94,6 +100,8 @@ hatch fmt --linter If you're using an IDE like VS Code or PyCharm, consider configuring it to use these tools automatically. 
+For additional details on styling, please see our dedicated [Style Guide](./STYLE_GUIDE.md). + ## Contributing via Pull Requests Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: @@ -115,7 +123,7 @@ To send us a pull request, please: ## Finding contributions to work on -Looking at the existing issues is a great way to find something to contribute on. +Looking at the existing issues is a great way to find something to contribute to. You can check: - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing diff --git a/README.md b/README.md index 337acd83..ed98d001 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,35 @@ -# Strands Agents +
-[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
-[![Python](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/)
-Docs β—† Samples β—† Tools β—† Agent Builder
+<!-- Centered README header; the HTML/image markup was lost in extraction. Recoverable content:
+     the Strands Agents logo and title,
+     the tagline "A model-driven approach to building AI agents in just a few lines of code.",
+     badges for GitHub commit activity, GitHub open issues, GitHub open pull requests, License, PyPI version, and Python versions,
+     and the link row: Documentation β—† Samples β—† Python SDK β—† Tools β—† Agent Builder β—† MCP Server -->
Strands Agents is a simple yet powerful SDK that takes a model-driven approach to building and running AI agents. From simple conversational assistants to complex autonomous workflows, from local development to production deployment, Strands Agents scales with your needs. @@ -19,7 +37,7 @@ Strands Agents is a simple yet powerful SDK that takes a model-driven approach t ## Feature Overview - **Lightweight & Flexible**: Simple agent loop that just works and is fully customizable -- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, Ollama, and custom providers +- **Model Agnostic**: Support for Amazon Bedrock, Anthropic, LiteLLM, Llama, Ollama, OpenAI, and custom providers - **Advanced Capabilities**: Multi-agent systems, autonomous agents, and streaming support - **Built-in MCP**: Native support for Model Context Protocol (MCP) servers, enabling access to thousands of pre-built tools @@ -105,16 +123,17 @@ from strands.models.llamaapi import LlamaAPIModel bedrock_model = BedrockModel( model_id="us.amazon.nova-pro-v1:0", temperature=0.3, + streaming=True, # Enable/disable streaming ) agent = Agent(model=bedrock_model) agent("Tell me about Agentic AI") # Ollama -ollama_modal = OllamaModel( +ollama_model = OllamaModel( host="http://localhost:11434", model_id="llama3" ) -agent = Agent(model=ollama_modal) +agent = Agent(model=ollama_model) agent("Tell me about Agentic AI") # Llama API @@ -129,7 +148,9 @@ Built-in providers: - [Amazon Bedrock](https://strandsagents.com/latest/user-guide/concepts/model-providers/amazon-bedrock/) - [Anthropic](https://strandsagents.com/latest/user-guide/concepts/model-providers/anthropic/) - [LiteLLM](https://strandsagents.com/latest/user-guide/concepts/model-providers/litellm/) + - [LlamaAPI](https://strandsagents.com/latest/user-guide/concepts/model-providers/llamaapi/) - [Ollama](https://strandsagents.com/latest/user-guide/concepts/model-providers/ollama/) + - [OpenAI](https://strandsagents.com/latest/user-guide/concepts/model-providers/openai/) Custom providers can be implemented using [Custom Providers](https://strandsagents.com/latest/user-guide/concepts/model-providers/custom_model_provider/) @@ -144,7 +165,7 @@ agent = Agent(tools=[calculator]) agent("What is the square root of 1764") ``` -It's also available on GitHub via [strands-agents-tools](https://github.com/strands-agents/strands-agents-tools). +It's also available on GitHub via [strands-agents/tools](https://github.com/strands-agents/tools). ## Documentation @@ -157,9 +178,9 @@ For detailed guidance & examples, explore our documentation: - [API Reference](https://strandsagents.com/latest/api-reference/agent/) - [Production & Deployment Guide](https://strandsagents.com/latest/user-guide/deploy/operating-agents-in-production/) -## Contributing +## Contributing ❀️ -We welcome contributions! See our [Contributing Guide](https://github.com/strands-agents/sdk-python/blob/main/CONTRIBUTING.md) for details on: +We welcome contributions! See our [Contributing Guide](CONTRIBUTING.md) for details on: - Reporting bugs & features - Development setup - Contributing via Pull Requests @@ -170,6 +191,10 @@ We welcome contributions! See our [Contributing Guide](https://github.com/strand This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. +## Security + +See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. + ## ⚠️ Preview Status Strands Agents is currently in public preview. 
During this period: diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md new file mode 100644 index 00000000..51dc0a73 --- /dev/null +++ b/STYLE_GUIDE.md @@ -0,0 +1,59 @@ +# Style Guide + +## Overview + +The Strands Agents style guide aims to establish consistent formatting, naming conventions, and structure across all code in the repository. We strive to make our code clean, readable, and maintainable. + +Where possible, we will codify these style guidelines into our linting rules and pre-commit hooks to automate enforcement and reduce the manual review burden. + +## Log Formatting + +The format for Strands Agents logs is as follows: + +```python +logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, field2, ...) +``` + +### Guidelines + +1. **Context**: + - Add context as `=` pairs at the beginning of the log + - Many log services (CloudWatch, Splunk, etc.) look for these patterns to extract fields for searching + - Use `,`'s to separate pairs + - Enclose values in `<>` for readability + - This is particularly helpful in displaying empty values (`field=` vs `field=<>`) + - Use `%s` for string interpolation as recommended by Python logging + - This is an optimization to skip string interpolation when the log level is not enabled + +1. **Messages**: + - Add human-readable messages at the end of the log + - Use lowercase for consistency + - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter + - Keep messages concise and focused on a single statement + - If multiple statements are needed, separate them with the pipe character (`|`) + - Example: `"processing request | starting validation"` + +### Examples + +#### Good + +```python +logger.debug("user_id=<%s>, action=<%s> | user performed action", user_id, action) +logger.info("request_id=<%s>, duration_ms=<%d> | request completed", request_id, duration) +logger.warning("attempt=<%d>, max_attempts=<%d> | retry limit approaching", attempt, max_attempts) +``` + +#### Poor + +```python +# Avoid: No structured fields, direct variable interpolation in message +logger.debug(f"User {user_id} performed action {action}") + +# Avoid: Inconsistent formatting, punctuation +logger.info("Request completed in %d ms.", duration) + +# Avoid: No separation between fields and message +logger.warning("Retry limit approaching! attempt=%d max_attempts=%d", attempt, max_attempts) +``` + +By following these log formatting guidelines, we ensure that logs are both human-readable and machine-parseable, making debugging and monitoring more efficient. 
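To make the "machine-parseable" claim in the new style guide concrete, here is a minimal sketch (not part of this change set; the regex and helper name are illustrative) showing how the `field=<value>` pairs separate cleanly from the human-readable message:

```python
import re

# Illustrative only: extract the structured "field=<value>" prefix that the
# Strands log format guarantees, then return the trailing message.
FIELD_PATTERN = re.compile(r"(\w+)=<([^>]*)>")

def parse_log_line(line: str) -> tuple[dict[str, str], str]:
    """Split a Strands-style log line into its fields and its message."""
    fields_part, _, message = line.partition(" | ")
    fields = dict(FIELD_PATTERN.findall(fields_part))
    return fields, message

fields, message = parse_log_line("user_id=<123>, action=<login> | user performed action")
assert fields == {"user_id": "123", "action": "login"}
assert message == "user performed action"
```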
diff --git a/pyproject.toml b/pyproject.toml index 6f9b78d6..bd309732 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,11 @@ [build-system] -requires = ["hatchling"] +requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.0" -description = "A production-ready framework for building autonomous AI agents" +dynamic = ["version"] +description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" license = {text = "Apache-2.0"} @@ -33,9 +33,9 @@ dependencies = [ "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", - "opentelemetry-api>=1.33.0,<2.0.0", - "opentelemetry-sdk>=1.33.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.33.0,<2.0.0", + "opentelemetry-api>=1.30.0,<2.0.0", + "opentelemetry-sdk>=1.30.0,<2.0.0", + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] @@ -54,7 +54,7 @@ dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", - "mypy>=0.981,<1.0.0", + "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.2.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", @@ -69,15 +69,22 @@ docs = [ litellm = [ "litellm>=1.69.0,<2.0.0", ] +llamaapi = [ + "llama-api-client>=0.1.0,<1.0.0", +] ollama = [ "ollama>=0.4.8,<1.0.0", ] -llamaapi = [ - "llama-api-client>=0.1.0,<1.0.0", +openai = [ + "openai>=1.68.0,<2.0.0", ] +[tool.hatch.version] +# Tells Hatch to use your version control system (git) to determine the version. +source = "vcs" + [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -100,7 +107,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -114,6 +121,11 @@ extra-args = [ "-vv", ] +[tool.hatch.envs.dev] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] + + [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] @@ -185,7 +197,9 @@ select = [ "D", # pydocstyle "E", # pycodestyle "F", # pyflakes + "G", # logging format "I", # isort + "LOG", # logging ] [tool.ruff.lint.per-file-ignores] diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 89653036..0f912b54 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -6,7 +6,7 @@ The Agent interface supports two complementary interaction patterns: 1. Natural language for conversation: `agent("Analyze this data")` -2. Method-style for direct tool access: `agent.tool_name(param1="value")` +2. 
Method-style for direct tool access: `agent.tool.tool_name(param1="value")` """ import asyncio @@ -328,27 +328,17 @@ def __call__(self, prompt: str, **kwargs: Any) -> AgentResult: - metrics: Performance metrics from the event loop - state: The final state of the event loop """ - model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None - - self.trace_span = self.tracer.start_agent_span( - prompt=prompt, - model_id=model_id, - tools=self.tool_names, - system_prompt=self.system_prompt, - custom_trace_attributes=self.trace_attributes, - ) + self._start_agent_trace_span(prompt) try: # Run the event loop and get the result result = self._run_loop(prompt, kwargs) - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, response=result) + self._end_agent_trace_span(response=result) return result except Exception as e: - if self.trace_span: - self.tracer.end_agent_span(span=self.trace_span, error=e) + self._end_agent_trace_span(error=e) # Re-raise the exception to preserve original behavior raise @@ -383,6 +373,8 @@ async def stream_async(self, prompt: str, **kwargs: Any) -> AsyncIterator[Any]: yield event["data"] ``` """ + self._start_agent_trace_span(prompt) + _stop_event = uuid4() queue = asyncio.Queue[Any]() @@ -400,8 +392,10 @@ def target_callback() -> None: nonlocal kwargs try: - self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) - except BaseException as e: + result = self._run_loop(prompt, kwargs, supplementary_callback_handler=queuing_callback_handler) + self._end_agent_trace_span(response=result) + except Exception as e: + self._end_agent_trace_span(error=e) enqueue(e) finally: enqueue(_stop_event) @@ -414,7 +408,7 @@ def target_callback() -> None: item = await queue.get() if item == _stop_event: break - if isinstance(item, BaseException): + if isinstance(item, Exception): raise item yield item finally: @@ -457,27 +451,28 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str Returns: The result of the event loop cycle. 
""" - kwargs.pop("agent", None) - kwargs.pop("model", None) - kwargs.pop("system_prompt", None) - kwargs.pop("tool_execution_handler", None) - kwargs.pop("event_loop_metrics", None) - kwargs.pop("callback_handler", None) - kwargs.pop("tool_handler", None) - kwargs.pop("messages", None) - kwargs.pop("tool_config", None) + # Extract parameters with fallbacks to instance values + system_prompt = kwargs.pop("system_prompt", self.system_prompt) + model = kwargs.pop("model", self.model) + tool_execution_handler = kwargs.pop("tool_execution_handler", self.thread_pool_wrapper) + event_loop_metrics = kwargs.pop("event_loop_metrics", self.event_loop_metrics) + callback_handler_override = kwargs.pop("callback_handler", callback_handler) + tool_handler = kwargs.pop("tool_handler", self.tool_handler) + messages = kwargs.pop("messages", self.messages) + tool_config = kwargs.pop("tool_config", self.tool_config) + kwargs.pop("agent", None) # Remove agent to avoid conflicts try: # Execute the main event loop cycle stop_reason, message, metrics, state = event_loop_cycle( - model=self.model, - system_prompt=self.system_prompt, - messages=self.messages, # will be modified by event_loop_cycle - tool_config=self.tool_config, - callback_handler=callback_handler, - tool_handler=self.tool_handler, - tool_execution_handler=self.thread_pool_wrapper, - event_loop_metrics=self.event_loop_metrics, + model=model, + system_prompt=system_prompt, + messages=messages, # will be modified by event_loop_cycle + tool_config=tool_config, + callback_handler=callback_handler_override, + tool_handler=tool_handler, + tool_execution_handler=tool_execution_handler, + event_loop_metrics=event_loop_metrics, agent=self, event_loop_parent_span=self.trace_span, **kwargs, @@ -488,8 +483,8 @@ def _execute_event_loop_cycle(self, callback_handler: Callable, kwargs: dict[str except ContextWindowOverflowException as e: # Try reducing the context size and retrying - self.conversation_manager.reduce_context(self.messages, e=e) - return self._execute_event_loop_cycle(callback_handler, kwargs) + self.conversation_manager.reduce_context(messages, e=e) + return self._execute_event_loop_cycle(callback_handler_override, kwargs) def _record_tool_execution( self, @@ -515,7 +510,7 @@ def _record_tool_execution( """ # Create user message describing the tool call user_msg_content = [ - {"text": (f"agent.{tool['name']} direct tool call\nInput parameters: {json.dumps(tool['input'])}\n")} + {"text": (f"agent.tool.{tool['name']} direct tool call.\nInput parameters: {json.dumps(tool['input'])}\n")} ] # Add override message if provided @@ -545,3 +540,43 @@ def _record_tool_execution( messages.append(tool_use_msg) messages.append(tool_result_msg) messages.append(assistant_msg) + + def _start_agent_trace_span(self, prompt: str) -> None: + """Starts a trace span for the agent. + + Args: + prompt: The natural language prompt from the user. + """ + model_id = self.model.config.get("model_id") if hasattr(self.model, "config") else None + + self.trace_span = self.tracer.start_agent_span( + prompt=prompt, + model_id=model_id, + tools=self.tool_names, + system_prompt=self.system_prompt, + custom_trace_attributes=self.trace_attributes, + ) + + def _end_agent_trace_span( + self, + response: Optional[AgentResult] = None, + error: Optional[Exception] = None, + ) -> None: + """Ends a trace span for the agent. + + Args: + span: The span to end. + response: Response to record as a trace attribute. + error: Error to record as a trace attribute. 
+ """ + if self.trace_span: + trace_attributes: Dict[str, Any] = { + "span": self.trace_span, + } + + if response: + trace_attributes["response"] = response + if error: + trace_attributes["error"] = error + + self.tracer.end_agent_span(**trace_attributes) diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 4b11e81c..f367b272 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -1,12 +1,10 @@ """Sliding window conversation history management.""" -import json import logging -from typing import List, Optional, cast +from typing import Optional -from ...types.content import ContentBlock, Message, Messages +from ...types.content import Message, Messages from ...types.exceptions import ContextWindowOverflowException -from ...types.tools import ToolResult from .conversation_manager import ConversationManager logger = logging.getLogger(__name__) @@ -110,8 +108,9 @@ def _remove_dangling_messages(self, messages: Messages) -> None: def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> None: """Trim the oldest messages to reduce the conversation context size. - The method handles special cases where tool results need to be converted to regular content blocks to maintain - conversation coherence after trimming. + The method handles special cases where trimming the messages leads to: + - toolResult with no corresponding toolUse + - toolUse with no corresponding toolResult Args: messages: The messages to reduce. @@ -126,52 +125,24 @@ def reduce_context(self, messages: Messages, e: Optional[Exception] = None) -> N # If the number of messages is less than the window_size, then we default to 2, otherwise, trim to window size trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size - # Throw if we cannot trim any messages from the conversation - if trim_index >= len(messages): - raise ContextWindowOverflowException("Unable to trim conversation context!") from e - - # If the message at the cut index has ToolResultContent, then we map that to ContentBlock. This gets around the - # limitation of needing ToolUse and ToolResults to be paired. - if any("toolResult" in content for content in messages[trim_index]["content"]): - if len(messages[trim_index]["content"]) == 1: - messages[trim_index]["content"] = self._map_tool_result_content( - cast(ToolResult, messages[trim_index]["content"][0]["toolResult"]) + # Find the next valid trim_index + while trim_index < len(messages): + if ( + # Oldest message cannot be a toolResult because it needs a toolUse preceding it + any("toolResult" in content for content in messages[trim_index]["content"]) + or ( + # Oldest message can be a toolUse only if a toolResult immediately follows it. + any("toolUse" in content for content in messages[trim_index]["content"]) + and trim_index + 1 < len(messages) + and not any("toolResult" in content for content in messages[trim_index + 1]["content"]) ) - - # If there is more content than just one ToolResultContent, then we cannot cut at this index. 
+ ): + trim_index += 1 else: - raise ContextWindowOverflowException("Unable to trim conversation context!") from e + break + else: + # If we didn't find a valid trim_index, then we throw + raise ContextWindowOverflowException("Unable to trim conversation context!") from e # Overwrite message history messages[:] = messages[trim_index:] - - def _map_tool_result_content(self, tool_result: ToolResult) -> List[ContentBlock]: - """Convert a ToolResult to a list of standard ContentBlocks. - - This method transforms tool result content into standard content blocks that can be preserved when trimming the - conversation history. - - Args: - tool_result: The ToolResult to convert. - - Returns: - A list of content blocks representing the tool result. - """ - contents = [] - text_content = "Tool Result Status: " + tool_result["status"] if tool_result["status"] else "" - - for tool_result_content in tool_result["content"]: - if "text" in tool_result_content: - text_content = "\nTool Result Text Content: " + tool_result_content["text"] + f"\n{text_content}" - elif "json" in tool_result_content: - text_content = ( - "\nTool Result JSON Content: " + json.dumps(tool_result_content["json"]) + f"\n{text_content}" - ) - elif "image" in tool_result_content: - contents.append(ContentBlock(image=tool_result_content["image"])) - elif "document" in tool_result_content: - contents.append(ContentBlock(document=tool_result_content["document"])) - else: - logger.warning("unsupported content type") - contents.append(ContentBlock(text=text_content)) - return contents diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index 5a78bb3f..a5c85668 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -74,7 +74,7 @@ def handle_input_too_long_error( """Handle 'Input is too long' errors by truncating tool results. When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If trunction is successful, the + by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the function will make a call to the event loop. Args: diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index db5a1b97..23d7bd0f 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -28,6 +28,10 @@ logger = logging.getLogger(__name__) +MAX_ATTEMPTS = 6 +INITIAL_DELAY = 4 +MAX_DELAY = 240 # 4 minutes + def initialize_state(**kwargs: Any) -> Any: """Initialize the request state if not present. 
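The reworked `reduce_context` loop in `SlidingWindowConversationManager` above encodes a pairing rule: the oldest surviving message must not be a dangling `toolResult`, and may contain a `toolUse` only if the very next message carries the matching `toolResult`. A minimal standalone sketch of that trim-index search (plain dict messages and the function name are illustrative, not the SDK's types):

```python
def find_trim_index(messages: list[dict], start: int) -> int:
    """Return the first index >= start at which the history can be safely cut."""
    trim_index = start
    while trim_index < len(messages):
        content = messages[trim_index]["content"]
        has_tool_result = any("toolResult" in block for block in content)
        has_tool_use = any("toolUse" in block for block in content)
        next_lacks_result = trim_index + 1 < len(messages) and not any(
            "toolResult" in block for block in messages[trim_index + 1]["content"]
        )
        if has_tool_result or (has_tool_use and next_lacks_result):
            trim_index += 1  # cutting here would orphan a toolUse/toolResult pair
        else:
            return trim_index
    raise RuntimeError("Unable to trim conversation context!")

history = [
    {"role": "user", "content": [{"text": "hi"}]},
    {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1"}}]},
    {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]},
    {"role": "assistant", "content": [{"text": "done"}]},
]
assert find_trim_index(history, 1) == 1  # toolUse kept: its toolResult follows
assert find_trim_index(history, 2) == 3  # index 2 is a dangling toolResult, skip it
```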
@@ -51,7 +55,7 @@ def event_loop_cycle( system_prompt: Optional[str], messages: Messages, tool_config: Optional[ToolConfig], - callback_handler: Any, + callback_handler: Callable[..., Any], tool_handler: Optional[ToolHandler], tool_execution_handler: Optional[ParallelToolExecutorInterface] = None, **kwargs: Any, @@ -130,13 +134,9 @@ def event_loop_cycle( stop_reason: StopReason usage: Any metrics: Metrics - max_attempts = 6 - initial_delay = 4 - max_delay = 240 # 4 minutes - current_delay = initial_delay # Retry loop for handling throttling exceptions - for attempt in range(max_attempts): + for attempt in range(MAX_ATTEMPTS): model_id = model.config.get("model_id") if hasattr(model, "config") else None model_invoke_span = tracer.start_model_invoke_span( parent_span=cycle_span, @@ -177,7 +177,7 @@ def event_loop_cycle( # Handle throttling errors with exponential backoff should_retry, current_delay = handle_throttling_error( - e, attempt, max_attempts, current_delay, max_delay, callback_handler, kwargs + e, attempt, MAX_ATTEMPTS, INITIAL_DELAY, MAX_DELAY, callback_handler, kwargs ) if should_retry: continue @@ -204,80 +204,35 @@ def event_loop_cycle( # If the model is requesting to use tools if stop_reason == "tool_use": - tool_uses: List[ToolUse] = [] - tool_results: List[ToolResult] = [] - invalid_tool_use_ids: List[str] = [] - - # Extract and validate tools - validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) - - # Check if tools are available for execution - if tool_uses: - if tool_handler is None: - raise ValueError("toolUse present but tool handler not set") - if tool_config is None: - raise ValueError("toolUse present but tool config not set") - - # Create the tool handler process callable - tool_handler_process: Callable[[ToolUse], ToolResult] = partial( - tool_handler.process, - messages=messages, - model=model, - system_prompt=system_prompt, - tool_config=tool_config, - callback_handler=callback_handler, - **kwargs, + if not tool_handler: + raise EventLoopException( + Exception("Model requested tool use but no tool handler provided"), + kwargs["request_state"], ) - # Execute tools (parallel or sequential) - run_tools( - handler=tool_handler_process, - tool_uses=tool_uses, - event_loop_metrics=event_loop_metrics, - request_state=cast(Any, kwargs["request_state"]), - invalid_tool_use_ids=invalid_tool_use_ids, - tool_results=tool_results, - cycle_trace=cycle_trace, - parent_span=cycle_span, - parallel_tool_executor=tool_execution_handler, + if tool_config is None: + raise EventLoopException( + Exception("Model requested tool use but no tool config provided"), + kwargs["request_state"], ) - # Update state for the next cycle - kwargs = prepare_next_cycle(kwargs, event_loop_metrics) - - # Create the tool result message - tool_result_message: Message = { - "role": "user", - "content": [{"toolResult": result} for result in tool_results], - } - messages.append(tool_result_message) - callback_handler(message=tool_result_message) - - if cycle_span: - tracer.end_event_loop_cycle_span( - span=cycle_span, message=message, tool_result_message=tool_result_message - ) - - # Check if we should stop the event loop - if kwargs["request_state"].get("stop_event_loop"): - event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) - return ( - stop_reason, - message, - event_loop_metrics, - kwargs["request_state"], - ) - - # Recursive call to continue the conversation - return recurse_event_loop( - model=model, - system_prompt=system_prompt, - messages=messages, - 
tool_config=tool_config, - callback_handler=callback_handler, - tool_handler=tool_handler, - **kwargs, - ) + # Handle tool execution + return _handle_tool_execution( + stop_reason, + message, + model, + system_prompt, + messages, + tool_config, + tool_handler, + callback_handler, + tool_execution_handler, + event_loop_metrics, + cycle_trace, + cycle_span, + cycle_start_time, + kwargs, + ) # End the cycle and return results event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) @@ -377,3 +332,105 @@ def prepare_next_cycle(kwargs: Dict[str, Any], event_loop_metrics: EventLoopMetr kwargs["event_loop_parent_cycle_id"] = kwargs["event_loop_cycle_id"] return kwargs + + +def _handle_tool_execution( + stop_reason: StopReason, + message: Message, + model: Model, + system_prompt: Optional[str], + messages: Messages, + tool_config: ToolConfig, + tool_handler: ToolHandler, + callback_handler: Callable[..., Any], + tool_execution_handler: Optional[ParallelToolExecutorInterface], + event_loop_metrics: EventLoopMetrics, + cycle_trace: Trace, + cycle_span: Any, + cycle_start_time: float, + kwargs: Dict[str, Any], +) -> Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + tool_uses: List[ToolUse] = [] + tool_results: List[ToolResult] = [] + invalid_tool_use_ids: List[str] = [] + + """ + Handles the execution of tools requested by the model during an event loop cycle. + + Args: + stop_reason (StopReason): The reason the model stopped generating. + message (Message): The message from the model that may contain tool use requests. + model (Model): The model provider instance. + system_prompt (Optional[str]): The system prompt instructions for the model. + messages (Messages): The conversation history messages. + tool_config (ToolConfig): Configuration for available tools. + tool_handler (ToolHandler): Handler for tool execution. + callback_handler (Callable[..., Any]): Callback for processing events as they happen. + tool_execution_handler (Optional[ParallelToolExecutorInterface]): Optional handler for parallel tool execution. + event_loop_metrics (EventLoopMetrics): Metrics tracking object for the event loop. + cycle_trace (Trace): Trace object for the current event loop cycle. + cycle_span (Any): Span object for tracing the cycle (type may vary). + cycle_start_time (float): Start time of the current cycle. + kwargs (Dict[str, Any]): Additional keyword arguments, including request state. + + Returns: + Tuple[StopReason, Message, EventLoopMetrics, Dict[str, Any]]: + - The stop reason, + - The updated message, + - The updated event loop metrics, + - The updated request state. 
+ """ + validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids) + + if not tool_uses: + return stop_reason, message, event_loop_metrics, kwargs["request_state"] + + tool_handler_process = partial( + tool_handler.process, + messages=messages, + model=model, + system_prompt=system_prompt, + tool_config=tool_config, + callback_handler=callback_handler, + **kwargs, + ) + + run_tools( + handler=tool_handler_process, + tool_uses=tool_uses, + event_loop_metrics=event_loop_metrics, + request_state=cast(Any, kwargs["request_state"]), + invalid_tool_use_ids=invalid_tool_use_ids, + tool_results=tool_results, + cycle_trace=cycle_trace, + parent_span=cycle_span, + parallel_tool_executor=tool_execution_handler, + ) + + kwargs = prepare_next_cycle(kwargs, event_loop_metrics) + + tool_result_message: Message = { + "role": "user", + "content": [{"toolResult": result} for result in tool_results], + } + + messages.append(tool_result_message) + callback_handler(message=tool_result_message) + + if cycle_span: + tracer = get_tracer() + tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message) + + if kwargs["request_state"].get("stop_event_loop", False): + event_loop_metrics.end_cycle(cycle_start_time, cycle_trace) + return stop_reason, message, event_loop_metrics, kwargs["request_state"] + + return recurse_event_loop( + model=model, + system_prompt=system_prompt, + messages=messages, + tool_config=tool_config, + callback_handler=callback_handler, + tool_handler=tool_handler, + **kwargs, + ) diff --git a/src/strands/handlers/callback_handler.py b/src/strands/handlers/callback_handler.py index d6d104d8..4b794b4f 100644 --- a/src/strands/handlers/callback_handler.py +++ b/src/strands/handlers/callback_handler.py @@ -17,15 +17,19 @@ def __call__(self, **kwargs: Any) -> None: Args: **kwargs: Callback event data including: - + - reasoningText (Optional[str]): Reasoning text to print if provided. - data (str): Text content to stream. - complete (bool): Whether this is the final chunk of a response. - current_tool_use (dict): Information about the current tool being used. """ + reasoningText = kwargs.get("reasoningText", False) data = kwargs.get("data", "") complete = kwargs.get("complete", False) current_tool_use = kwargs.get("current_tool_use", {}) + if reasoningText: + print(reasoningText, end="") + if data: print(data, end="" if not complete else "\n") diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py index 0803eca5..bc4ec1ce 100644 --- a/src/strands/handlers/tool_handler.py +++ b/src/strands/handlers/tool_handler.py @@ -46,6 +46,7 @@ def preprocess( def process( self, tool: Any, + *, model: Model, system_prompt: Optional[str], messages: List[Any], diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 704114eb..99e49f81 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -97,13 +97,16 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Anthropic formatted content block. 
""" if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") return { "source": { - "data": base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8"), - "media_type": mimetypes.types_map.get( - f".{content['document']['format']}", "application/octet-stream" + "data": ( + content["document"]["source"]["bytes"].decode("utf-8") + if mime_type == "text/plain" + else base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") ), - "type": "base64", + "media_type": mime_type, + "type": "text" if mime_type == "text/plain" else "base64", }, "title": content["document"]["name"], "type": "document", diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index d996aaae..9bbcca7d 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -3,12 +3,14 @@ - Docs: https://aws.amazon.com/bedrock/ """ +import json import logging -from typing import Any, Iterable, Literal, Optional, cast +import os +from typing import Any, Iterable, List, Literal, Optional, cast import boto3 from botocore.config import Config as BotocoreConfig -from botocore.exceptions import ClientError, EventStreamError +from botocore.exceptions import ClientError from typing_extensions import TypedDict, Unpack, override from ..types.content import Messages @@ -60,6 +62,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: Maximum number of tokens to generate in the response model_id: The Bedrock model ID (e.g., "us.anthropic.claude-3-7-sonnet-20250219-v1:0") stop_sequences: List of sequences that will stop generation when encountered + streaming: Flag to enable/disable streaming. Defaults to True. temperature: Controls randomness in generation (higher = more random) top_p: Controls diversity via nucleus sampling (alternative to temperature) """ @@ -80,6 +83,7 @@ class BedrockConfig(TypedDict, total=False): max_tokens: Optional[int] model_id: str stop_sequences: Optional[list[str]] + streaming: Optional[bool] temperature: Optional[float] top_p: Optional[float] @@ -96,7 +100,8 @@ def __init__( Args: boto_session: Boto Session to use when calling the Bedrock Model. boto_client_config: Configuration to use when creating the Bedrock-Runtime Boto Client. - region_name: AWS region to use for the Bedrock service. Defaults to "us-west-2". + region_name: AWS region to use for the Bedrock service. + Defaults to the AWS_REGION environment variable if set, or "us-west-2" if not set. **model_config: Configuration options for the Bedrock model. 
""" if region_name and boto_session: @@ -108,11 +113,26 @@ def __init__( logger.debug("config=<%s> | initializing", self.config) session = boto_session or boto3.Session( - region_name=region_name or "us-west-2", + region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", ) + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + self.client = session.client( service_name="bedrock-runtime", - config=boto_client_config, + config=client_config, ) @override @@ -129,7 +149,7 @@ def get_config(self) -> BedrockConfig: """Get the current Bedrock Model configuration. Returns: - The Bedrok model configuration. + The Bedrock model configuration. """ return self.config @@ -229,11 +249,68 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent: """ return cast(StreamEvent, event) + def _has_blocked_guardrail(self, guardrail_data: dict[str, Any]) -> bool: + """Check if guardrail data contains any blocked policies. + + Args: + guardrail_data: Guardrail data from trace information. + + Returns: + True if any blocked guardrail is detected, False otherwise. + """ + input_assessment = guardrail_data.get("inputAssessment", {}) + output_assessments = guardrail_data.get("outputAssessments", {}) + + # Check input assessments + if any(self._find_detected_and_blocked_policy(assessment) for assessment in input_assessment.values()): + return True + + # Check output assessments + if any(self._find_detected_and_blocked_policy(assessment) for assessment in output_assessments.values()): + return True + + return False + + def _generate_redaction_events(self) -> list[StreamEvent]: + """Generate redaction events based on configuration. + + Returns: + List of redaction events to yield. + """ + events: List[StreamEvent] = [] + + if self.config.get("guardrail_redact_input", True): + logger.debug("Redacting user input due to guardrail.") + events.append( + { + "redactContent": { + "redactUserContentMessage": self.config.get( + "guardrail_redact_input_message", "[User input redacted.]" + ) + } + } + ) + + if self.config.get("guardrail_redact_output", False): + logger.debug("Redacting assistant output due to guardrail.") + events.append( + { + "redactContent": { + "redactAssistantContentMessage": self.config.get( + "guardrail_redact_output_message", "[Assistant output redacted.]" + ) + } + } + ) + + return events + @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the Bedrock model and get the streaming response. + def stream(self, request: dict[str, Any]) -> Iterable[StreamEvent]: + """Send the request to the Bedrock model and get the response. - This method calls the Bedrock converse_stream API and returns the stream of response events. + This method calls either the Bedrock converse_stream API or the converse API + based on the streaming parameter in the configuration. 
Args: request: The formatted request to send to the Bedrock model @@ -243,63 +320,132 @@ def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: Raises: ContextWindowOverflowException: If the input exceeds the model's context window. - EventStreamError: For all other Bedrock API errors. + ModelThrottledException: If the model service is throttling requests. """ + streaming = self.config.get("streaming", True) + try: - response = self.client.converse_stream(**request) - for chunk in response["stream"]: - if self.config.get("guardrail_redact_input", True) or self.config.get("guardrail_redact_output", False): + if streaming: + # Streaming implementation + response = self.client.converse_stream(**request) + for chunk in response["stream"]: if ( "metadata" in chunk and "trace" in chunk["metadata"] and "guardrail" in chunk["metadata"]["trace"] ): - inputAssessment = chunk["metadata"]["trace"]["guardrail"].get("inputAssessment", {}) - outputAssessments = chunk["metadata"]["trace"]["guardrail"].get("outputAssessments", {}) - - # Check if an input or output guardrail was triggered - if any( - self._find_detected_and_blocked_policy(assessment) - for assessment in inputAssessment.values() - ) or any( - self._find_detected_and_blocked_policy(assessment) - for assessment in outputAssessments.values() - ): - if self.config.get("guardrail_redact_input", True): - logger.debug("Found blocked input guardrail. Redacting input.") - yield { - "redactContent": { - "redactUserContentMessage": self.config.get( - "guardrail_redact_input_message", "[User input redacted.]" - ) - } - } - if self.config.get("guardrail_redact_output", False): - logger.debug("Found blocked output guardrail. Redacting output.") - yield { - "redactContent": { - "redactAssistantContentMessage": self.config.get( - "guardrail_redact_output_message", "[Assistant output redacted.]" - ) - } - } + guardrail_data = chunk["metadata"]["trace"]["guardrail"] + if self._has_blocked_guardrail(guardrail_data): + yield from self._generate_redaction_events() + yield chunk + else: + # Non-streaming implementation + response = self.client.converse(**request) + + # Convert and yield from the response + yield from self._convert_non_streaming_to_streaming(response) + + # Check for guardrail triggers after yielding any events (same as streaming path) + if ( + "trace" in response + and "guardrail" in response["trace"] + and self._has_blocked_guardrail(response["trace"]["guardrail"]) + ): + yield from self._generate_redaction_events() - yield chunk - except EventStreamError as e: - # Handle throttling that occurs mid-stream? 
- if "ThrottlingException" in str(e) and "ConverseStream" in str(e): - raise ModelThrottledException(str(e)) from e + except ClientError as e: + error_message = str(e) - if any(overflow_message in str(e) for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): + # Handle throttling error + if e.response["Error"]["Code"] == "ThrottlingException": + raise ModelThrottledException(error_message) from e + + # Handle context window overflow + if any(overflow_message in error_message for overflow_message in BEDROCK_CONTEXT_WINDOW_OVERFLOW_MESSAGES): logger.warning("bedrock threw context window overflow error") raise ContextWindowOverflowException(e) from e + + # Otherwise raise the error raise e - except ClientError as e: - # Handle throttling that occurs at the beginning of the call - if e.response["Error"]["Code"] == "ThrottlingException": - raise ModelThrottledException(str(e)) from e - raise + def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]: + """Convert a non-streaming response to the streaming format. + + Args: + response: The non-streaming response from the Bedrock model. + + Returns: + An iterable of response events in the streaming format. + """ + # Yield messageStart event + yield {"messageStart": {"role": response["output"]["message"]["role"]}} + + # Process content blocks + for content in response["output"]["message"]["content"]: + # Yield contentBlockStart event if needed + if "toolUse" in content: + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "toolUseId": content["toolUse"]["toolUseId"], + "name": content["toolUse"]["name"], + } + }, + } + } + + # For tool use, we need to yield the input as a delta + input_value = json.dumps(content["toolUse"]["input"]) + + yield {"contentBlockDelta": {"delta": {"toolUse": {"input": input_value}}}} + elif "text" in content: + # Then yield the text as a delta + yield { + "contentBlockDelta": { + "delta": {"text": content["text"]}, + } + } + elif "reasoningContent" in content: + # Then yield the reasoning content as a delta + yield { + "contentBlockDelta": { + "delta": {"reasoningContent": {"text": content["reasoningContent"]["reasoningText"]["text"]}} + } + } + + if "signature" in content["reasoningContent"]["reasoningText"]: + yield { + "contentBlockDelta": { + "delta": { + "reasoningContent": { + "signature": content["reasoningContent"]["reasoningText"]["signature"] + } + } + } + } + + # Yield contentBlockStop event + yield {"contentBlockStop": {}} + + # Yield messageStop event + yield { + "messageStop": { + "stopReason": response["stopReason"], + "additionalModelResponseFields": response.get("additionalModelResponseFields"), + } + } + + # Yield metadata event + if "usage" in response or "metrics" in response or "trace" in response: + metadata: StreamEvent = {"metadata": {}} + if "usage" in response: + metadata["metadata"]["usage"] = response["usage"] + if "metrics" in response: + metadata["metadata"]["metrics"] = response["metrics"] + if "trace" in response: + metadata["metadata"]["trace"] = response["trace"] + yield metadata def _find_detected_and_blocked_policy(self, input: Any) -> bool: """Recursively checks if the assessment contains a detected and blocked guardrail. 
diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index a7563133..23d2c2ae 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,23 +3,19 @@ - Docs: https://docs.litellm.ai/ """ -import json import logging -import mimetypes -from typing import Any, Iterable, Optional, TypedDict +from typing import Any, Optional, TypedDict, cast import litellm from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages -from ..types.models import Model -from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.content import ContentBlock +from .openai import OpenAIModel logger = logging.getLogger(__name__) -class LiteLLMModel(Model): +class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" class LiteLLMConfig(TypedDict, total=False): @@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: https://github.com/BerriAI/litellm/blob/main/litellm/main.py. **model_config: Configuration options for the LiteLLM model. """ - self.config = LiteLLMModel.LiteLLMConfig(**model_config) + self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig: Returns: The LiteLLM model configuration. """ - return self.config + return cast(LiteLLMModel.LiteLLMConfig, self.config) - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + @override + @staticmethod + def format_request_message_content(content: ContentBlock) -> dict[str, Any]: """Format a LiteLLM content block. Args: @@ -79,18 +77,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: LiteLLM formatted content block. """ - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") - return { - "image_url": { - "detail": "auto", - "format": mime_type, - "url": f"data:{mime_type};base64,{image_data}", - }, - "type": "image_url", - } - if "reasoningContent" in content: return { "signature": content["reasoningContent"]["reasoningText"]["signature"], @@ -98,9 +84,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "thinking", } - if "text" in content: - return {"text": content["text"], "type": "text"} - if "video" in content: return { "type": "video_url", @@ -110,230 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An }, } - return {"text": json.dumps(content), "type": "text"} - - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: - """Format a LiteLLM tool call. - - Args: - tool_use: Tool use requested by the model. - - Returns: - LiteLLM formatted tool call. - """ - return { - "function": { - "arguments": json.dumps(tool_use["input"]), - "name": tool_use["name"], - }, - "id": tool_use["toolUseId"], - "type": "function", - } - - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: - """Format a LiteLLM tool message. - - Args: - tool_result: Tool result collected from a tool execution. - - Returns: - LiteLLM formatted tool message. 
- """ - return { - "role": "tool", - "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), - } - - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format a LiteLLM messages array. - - Args: - messages: List of message objects to be processed by the model. - system_prompt: System prompt to provide context to the model. - - Returns: - A LiteLLM messages array. - """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] - - for message in messages: - contents = message["content"] - - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - formatted_tool_calls = [ - self._format_request_message_tool_call(content["toolUse"]) - for content in contents - if "toolUse" in content - ] - formatted_tool_messages = [ - self._format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - formatted_message = { - "role": message["role"], - "content": formatted_contents, - **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), - } - formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) - - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] - - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format a LiteLLM chat streaming request. - - Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. - - Returns: - A LiteLLM chat streaming request. - """ - return { - "messages": self._format_request_messages(messages, system_prompt), - "model": self.config["model_id"], - "stream": True, - "stream_options": {"include_usage": True}, - "tools": [ - { - "type": "function", - "function": { - "name": tool_spec["name"], - "description": tool_spec["description"], - "parameters": tool_spec["inputSchema"]["json"], - }, - } - for tool_spec in tool_specs or [] - ], - **(self.config.get("params") or {}), - } - - @override - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format the LiteLLM response events into standardized message chunks. - - Args: - event: A response event from the LiteLLM model. - - Returns: - The formatted chunk. - - Raises: - RuntimeError: If chunk_type is not recognized. - This error should never be encountered as we control chunk_type in the stream method. 
- """ - match event["chunk_type"]: - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_start": - if event["data_type"] == "tool": - return { - "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function.name, - "toolUseId": event["data"].id, - } - } - } - } - - return {"contentBlockStart": {"start": {}}} - - case "content_delta": - if event["data_type"] == "tool": - return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} - - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - case "content_stop": - return {"contentBlockStop": {}} - - case "message_stop": - match event["data"]: - case "tool_calls": - return {"messageStop": {"stopReason": "tool_use"}} - case "length": - return {"messageStop": {"stopReason": "max_tokens"}} - case _: - return {"messageStop": {"stopReason": "end_turn"}} - - case "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, - "metrics": { - "latencyMs": 0, # TODO - }, - }, - } - - case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") - - @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the LiteLLM model and get the streaming response. - - Args: - request: The formatted request to send to the LiteLLM model. - - Returns: - An iterable of response events from the LiteLLM model. - """ - response = self.client.chat.completions.create(**request) - - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} - - tool_calls: dict[int, list[Any]] = {} - - for event in response: - choice = event.choices[0] - if choice.finish_reason: - break - - if choice.delta.content: - yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - yield {"chunk_type": "content_stop", "data_type": "text"} - - for tool_deltas in tool_calls.values(): - tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:] - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start} - - for tool_delta in tool_deltas: - yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} - - yield {"chunk_type": "content_stop", "data_type": "tool"} - - yield {"chunk_type": "message_stop", "data": choice.finish_reason} - - event = next(response) - if hasattr(event, "usage"): - yield {"chunk_type": "metadata", "data": event.usage} + return OpenAIModel.format_request_message_content(content) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py new file mode 100644 index 00000000..764cb851 --- /dev/null +++ b/src/strands/models/openai.py @@ -0,0 +1,124 @@ +"""OpenAI model provider. + +- Docs: https://platform.openai.com/docs/overview +""" + +import logging +from typing import Any, Iterable, Optional, Protocol, TypedDict, cast + +import openai +from typing_extensions import Unpack, override + +from ..types.models import OpenAIModel as SAOpenAIModel + +logger = logging.getLogger(__name__) + + +class Client(Protocol): + """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" + + @property + # pragma: no cover + def chat(self) -> Any: + """Chat completions interface.""" + ... 
+ + +class OpenAIModel(SAOpenAIModel): + """OpenAI model provider implementation.""" + + client: Client + + class OpenAIConfig(TypedDict, total=False): + """Configuration options for OpenAI models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_tokens). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/chat/create. + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI model. + """ + self.config = dict(model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + self.client = openai.OpenAI(**client_args) + + @override + def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] + """Update the OpenAI model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIConfig: + """Get the OpenAI model configuration. + + Returns: + The OpenAI model configuration. + """ + return cast(OpenAIModel.OpenAIConfig, self.config) + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the OpenAI model and get the streaming response. + + Args: + request: The formatted request to send to the OpenAI model. + + Returns: + An iterable of response events from the OpenAI model. 
+ """ + response = self.client.chat.completions.create(**request) + + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + + tool_calls: dict[int, list[Any]] = {} + + for event in response: + choice = event.choices[0] + + if choice.delta.content: + yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + break + + yield {"chunk_type": "content_stop", "data_type": "text"} + + for tool_deltas in tool_calls.values(): + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + + for tool_delta in tool_deltas: + yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + + yield {"chunk_type": "content_stop", "data_type": "tool"} + + yield {"chunk_type": "message_stop", "data": choice.finish_reason} + + # Skip remaining events as we don't have use for anything except the final usage payload + for event in response: + _ = event + + yield {"chunk_type": "metadata", "data": event.usage} diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index ad30a445..34eb7bed 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -7,7 +7,7 @@ import json import logging import os -from datetime import datetime, timezone +from datetime import date, datetime, timezone from importlib.metadata import version from typing import Any, Dict, Mapping, Optional @@ -30,21 +30,49 @@ class JSONEncoder(json.JSONEncoder): """Custom JSON encoder that handles non-serializable types.""" - def default(self, obj: Any) -> Any: - """Handle non-serializable types. + def encode(self, obj: Any) -> str: + """Recursively encode objects, preserving structure and only replacing unserializable values. Args: - obj: The object to serialize + obj: The object to encode Returns: - A JSON serializable version of the object + JSON string representation of the object """ - value = "" - try: - value = super().default(obj) - except TypeError: - value = "" - return value + # Process the object to handle non-serializable values + processed_obj = self._process_value(obj) + # Use the parent class to encode the processed object + return super().encode(processed_obj) + + def _process_value(self, value: Any) -> Any: + """Process any value, handling containers recursively. + + Args: + value: The value to process + + Returns: + Processed value with unserializable parts replaced + """ + # Handle datetime objects directly + if isinstance(value, (datetime, date)): + return value.isoformat() + + # Handle dictionaries + elif isinstance(value, dict): + return {k: self._process_value(v) for k, v in value.items()} + + # Handle lists + elif isinstance(value, list): + return [self._process_value(item) for item in value] + + # Handle all other values + else: + try: + # Test if the value is JSON serializable + json.dumps(value) + return value + except (TypeError, OverflowError, ValueError): + return "" class Tracer: @@ -65,7 +93,7 @@ def __init__( service_name: str = "strands-agents", otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None, - enable_console_export: bool = False, + enable_console_export: Optional[bool] = None, ): """Initialize the tracer. 
@@ -77,13 +105,17 @@ def __init__( """ # Check environment variables first env_endpoint = os.environ.get("OTEL_EXPORTER_OTLP_ENDPOINT") - env_console_export = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT", "").lower() in ("true", "1", "yes") + env_console_export_str = os.environ.get("STRANDS_OTEL_ENABLE_CONSOLE_EXPORT") + + # Constructor parameters take precedence over environment variables + self.otlp_endpoint = otlp_endpoint or env_endpoint - # Environment variables take precedence over constructor parameters - if env_endpoint: - otlp_endpoint = env_endpoint - if env_console_export: - enable_console_export = True + if enable_console_export is not None: + self.enable_console_export = enable_console_export + elif env_console_export_str: + self.enable_console_export = env_console_export_str.lower() in ("true", "1", "yes") + else: + self.enable_console_export = False # Parse headers from environment if available env_headers = os.environ.get("OTEL_EXPORTER_OTLP_HEADERS") @@ -97,17 +129,14 @@ def __init__( headers_dict[key.strip()] = value.strip() otlp_headers = headers_dict except Exception as e: - logger.warning(f"error=<{e}> | failed to parse OTEL_EXPORTER_OTLP_HEADERS") + logger.warning("error=<%s> | failed to parse OTEL_EXPORTER_OTLP_HEADERS", e) self.service_name = service_name - self.otlp_endpoint = otlp_endpoint self.otlp_headers = otlp_headers or {} - self.enable_console_export = enable_console_export - self.tracer_provider: Optional[TracerProvider] = None self.tracer: Optional[trace.Tracer] = None - if otlp_endpoint or enable_console_export: + if self.otlp_endpoint or self.enable_console_export: self._initialize_tracer() def _initialize_tracer(self) -> None: @@ -156,9 +185,9 @@ def _initialize_tracer(self) -> None: batch_processor = BatchSpanProcessor(otlp_exporter) self.tracer_provider.add_span_processor(batch_processor) - logger.info(f"endpoint=<{endpoint}> | OTLP exporter configured with endpoint") + logger.info("endpoint=<%s> | OTLP exporter configured with endpoint", endpoint) except Exception as e: - logger.error(f"error=<{e}> | Failed to configure OTLP exporter", exc_info=True) + logger.exception("error=<%s> | Failed to configure OTLP exporter", e) # Set as global tracer provider trace.set_tracer_provider(self.tracer_provider) @@ -239,7 +268,7 @@ def _end_span( else: span.set_status(StatusCode.OK) except Exception as e: - logger.warning(f"error=<{e}> | error while ending span", exc_info=True) + logger.warning("error=<%s> | error while ending span", e, exc_info=True) finally: span.end() # Force flush to ensure spans are exported @@ -247,7 +276,7 @@ def _end_span( try: self.tracer_provider.force_flush() except Exception as e: - logger.warning(f"error=<{e}> | failed to force flush tracer provider") + logger.warning("error=<%s> | failed to force flush tracer provider", e) def end_span_with_error(self, span: trace.Span, error_message: str, exception: Optional[Exception] = None) -> None: """End a span with error status. @@ -287,7 +316,7 @@ def start_model_invoke_span( "gen_ai.system": "strands-agents", "agent.name": agent_name, "gen_ai.agent.name": agent_name, - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), } if model_id: @@ -310,7 +339,7 @@ def end_model_invoke_span( error: Optional exception if the model call failed. 
""" attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], @@ -332,9 +361,10 @@ def start_tool_call_span( The created span, or None if tracing is not enabled. """ attributes: Dict[str, AttributeValue] = { + "gen_ai.prompt": serialize(tool), "tool.name": tool["name"], "tool.id": tool["toolUseId"], - "tool.parameters": json.dumps(tool["input"], cls=JSONEncoder), + "tool.parameters": serialize(tool["input"]), } # Add additional kwargs as attributes @@ -358,10 +388,11 @@ def end_tool_call_span( status = tool_result.get("status") status_str = str(status) if status is not None else "" + tool_result_content_json = serialize(tool_result.get("content")) attributes.update( { - "tool.result": json.dumps(tool_result.get("content"), cls=JSONEncoder), - "gen_ai.completion": json.dumps(tool_result.get("content"), cls=JSONEncoder), + "tool.result": tool_result_content_json, + "gen_ai.completion": tool_result_content_json, "tool.status": status_str, } ) @@ -390,7 +421,7 @@ def start_event_loop_cycle_span( parent_span = parent_span if parent_span else event_loop_kwargs.get("event_loop_parent_span") attributes: Dict[str, AttributeValue] = { - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), "event_loop.cycle_id": event_loop_cycle_id, } @@ -419,11 +450,11 @@ def end_event_loop_cycle_span( error: Optional exception if the cycle failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), } if tool_result_message: - attributes["tool.result"] = json.dumps(tool_result_message["content"], cls=JSONEncoder) + attributes["tool.result"] = serialize(tool_result_message["content"]) self._end_span(span, attributes, error) @@ -460,7 +491,7 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = json.dumps(tools, cls=JSONEncoder) + tools_json = serialize(tools) attributes["agent.tools"] = tools_json attributes["gen_ai.agent.tools"] = tools_json @@ -492,7 +523,7 @@ def end_agent_span( if response: attributes.update( { - "gen_ai.completion": json.dumps(response, cls=JSONEncoder), + "gen_ai.completion": str(response), } ) @@ -517,7 +548,7 @@ def get_tracer( service_name: str = "strands-agents", otlp_endpoint: Optional[str] = None, otlp_headers: Optional[Dict[str, str]] = None, - enable_console_export: bool = False, + enable_console_export: Optional[bool] = None, ) -> Tracer: """Get or create the global tracer. @@ -541,3 +572,15 @@ def get_tracer( ) return _tracer_instance + + +def serialize(obj: Any) -> str: + """Serialize an object to JSON with consistent settings. + + Args: + obj: The object to serialize + + Returns: + JSON string representation of the object + """ + return json.dumps(obj, ensure_ascii=False, cls=JSONEncoder) diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 40565a24..b595c3d6 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -40,14 +40,14 @@ def validate_tool_use_name(tool: ToolUse) -> None: Raises: InvalidToolUseNameException: If the tool name is invalid. 
""" - # We need to fix some typing here, because we dont actually expect a ToolUse, but dict[str, Any] + # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] if "name" not in tool: message = "tool name missing" # type: ignore[unreachable] logger.warning(message) raise InvalidToolUseNameException(message) tool_name = tool["name"] - tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_]*$" + tool_name_pattern = r"^[a-zA-Z][a-zA-Z0-9_\-]*$" tool_name_max_length = 64 valid_name_pattern = bool(re.match(tool_name_pattern, tool_name)) tool_name_len = len(tool_name) diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index 6055b9ab..c15ba1be 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -16,7 +16,7 @@ class GuardrailConfig(TypedDict, total=False): Attributes: guardrailIdentifier: Unique identifier for the guardrail. guardrailVersion: Version of the guardrail to apply. - streamProcessingMode: Procesing mode. + streamProcessingMode: Processing mode. trace: The trace behavior for the guardrail. """ @@ -219,7 +219,7 @@ class GuardrailAssessment(TypedDict): contentPolicy: The content policy. contextualGroundingPolicy: The contextual grounding policy used for the guardrail assessment. sensitiveInformationPolicy: The sensitive information policy. - topicPolic: The topic policy. + topicPolicy: The topic policy. wordPolicy: The word policy. """ diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 058a09ea..29b89e5c 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -68,7 +68,7 @@ class ImageContent(TypedDict): class VideoSource(TypedDict): - """Contains the content of a vidoe. + """Contains the content of a video. Attributes: bytes: The binary content of the video. diff --git a/src/strands/types/models/__init__.py b/src/strands/types/models/__init__.py new file mode 100644 index 00000000..5ce0a498 --- /dev/null +++ b/src/strands/types/models/__init__.py @@ -0,0 +1,6 @@ +"""Model-related type definitions for the SDK.""" + +from .model import Model +from .openai import OpenAIModel + +__all__ = ["Model", "OpenAIModel"] diff --git a/src/strands/types/models.py b/src/strands/types/models/model.py similarity index 97% rename from src/strands/types/models.py rename to src/strands/types/models/model.py index e3d96e29..23e74602 100644 --- a/src/strands/types/models.py +++ b/src/strands/types/models/model.py @@ -4,9 +4,9 @@ import logging from typing import Any, Iterable, Optional -from .content import Messages -from .streaming import StreamEvent -from .tools import ToolSpec +from ..content import Messages +from ..streaming import StreamEvent +from ..tools import ToolSpec logger = logging.getLogger(__name__) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py new file mode 100644 index 00000000..307c0be6 --- /dev/null +++ b/src/strands/types/models/openai.py @@ -0,0 +1,254 @@ +"""Base OpenAI model provider. + +This module provides the base OpenAI model provider class which implements shared logic for formatting requests and +responses to and from the OpenAI specification. 
+ +- Docs: https://pypi.org/project/openai +""" + +import abc +import base64 +import json +import logging +import mimetypes +from typing import Any, Optional + +from typing_extensions import override + +from ..content import ContentBlock, Messages +from ..streaming import StreamEvent +from ..tools import ToolResult, ToolSpec, ToolUse +from .model import Model + +logger = logging.getLogger(__name__) + + +class OpenAIModel(Model, abc.ABC): + """Base OpenAI model provider implementation. + + Implements shared logic for formatting requests and responses to and from the OpenAI specification. + """ + + config: dict[str, Any] + + @staticmethod + def format_request_message_content(content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + """ + if "document" in content: + mime_type = mimetypes.types_map.get(f".{content['document']['format']}", "application/octet-stream") + file_data = base64.b64encode(content["document"]["source"]["bytes"]).decode("utf-8") + return { + "file": { + "file_data": f"data:{mime_type};base64,{file_data}", + "filename": content["document"]["name"], + }, + "type": "file", + } + + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = content["image"]["source"]["bytes"].decode("utf-8") + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + return {"text": json.dumps(content), "type": "text"} + + @staticmethod + def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. + """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @staticmethod + def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + """ + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": json.dumps( + { + "content": tool_result["content"], + "status": tool_result["status"], + } + ), + } + + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible messages array. 
+ """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an OpenAI compatible chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible chat streaming request. + """ + return { + "messages": self.format_request_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + **(self.config.get("params") or {}), + } + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return { + "contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments or ""}}} + } + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") diff --git a/src/strands/types/streaming.py b/src/strands/types/streaming.py index db600f15..9c99b210 100644 --- a/src/strands/types/streaming.py +++ b/src/strands/types/streaming.py @@ -157,7 +157,7 @@ class ModelStreamErrorEvent(ExceptionEvent): originalStatusCode: int -class RedactContentEvent(TypedDict): +class RedactContentEvent(TypedDict, total=False): """Event for redacting content. Attributes: diff --git a/tests-integ/test_model_bedrock.py b/tests-integ/test_model_bedrock.py new file mode 100644 index 00000000..a6a29aa9 --- /dev/null +++ b/tests-integ/test_model_bedrock.py @@ -0,0 +1,120 @@ +import pytest + +import strands +from strands import Agent +from strands.models import BedrockModel + + +@pytest.fixture +def system_prompt(): + return "You are an AI assistant that uses & instead of ." 
+ + +@pytest.fixture +def streaming_model(): + return BedrockModel( + model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0", + streaming=True, + ) + + +@pytest.fixture +def non_streaming_model(): + return BedrockModel( + model_id="us.meta.llama3-2-90b-instruct-v1:0", + streaming=False, + ) + + +@pytest.fixture +def streaming_agent(streaming_model, system_prompt): + return Agent(model=streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +@pytest.fixture +def non_streaming_agent(non_streaming_model, system_prompt): + return Agent(model=non_streaming_model, system_prompt=system_prompt, load_tools_from_directory=False) + + +def test_streaming_agent(streaming_agent): + """Test agent with streaming model.""" + result = streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_non_streaming_agent(non_streaming_agent): + """Test agent with non-streaming model.""" + result = non_streaming_agent("Hello!") + + assert len(str(result)) > 0 + + +def test_streaming_model_events(streaming_model): + """Test streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = list(streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_non_streaming_model_events(non_streaming_model): + """Test non-streaming model events.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + # Call converse and collect events + events = list(non_streaming_model.converse(messages)) + + # Verify basic structure of events + assert any("messageStart" in event for event in events) + assert any("contentBlockDelta" in event for event in events) + assert any("messageStop" in event for event in events) + + +def test_tool_use_streaming(streaming_model): + """Test tool use with streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=streaming_model, tools=[calculator], load_tools_from_directory=False) + result = agent("What is 123 + 456?") + + # Print the full message content for debugging + print("\nFull message content:") + import json + + print(json.dumps(result.message["content"], indent=2)) + + assert tool_was_called + + +def test_tool_use_non_streaming(non_streaming_model): + """Test tool use with non-streaming model.""" + + tool_was_called = False + + @strands.tool + def calculator(expression: str) -> float: + """Calculate the result of a mathematical expression.""" + + nonlocal tool_was_called + tool_was_called = True + return eval(expression) + + agent = Agent(model=non_streaming_model, tools=[calculator], load_tools_from_directory=False) + agent("What is 123 + 456?") + + assert tool_was_called diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py new file mode 100644 index 00000000..c9046ad5 --- /dev/null +++ b/tests-integ/test_model_openai.py @@ -0,0 +1,46 @@ +import os + +import pytest + +import strands +from strands import Agent +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def model(): + return OpenAIModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) + + +@pytest.fixture +def 
tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) diff --git a/tests-integ/test_stream_agent.py b/tests-integ/test_stream_agent.py index 4c97db6b..01f20339 100644 --- a/tests-integ/test_stream_agent.py +++ b/tests-integ/test_stream_agent.py @@ -1,5 +1,5 @@ """ -Test script for Strands's custom callback handler functionality. +Test script for Strands' custom callback handler functionality. Demonstrates different patterns of callback handling and processing. """ diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5d2ffb23..ea06fb4e 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -9,7 +9,8 @@ import pytest import strands -from strands.agent.agent import Agent +from strands import Agent +from strands.agent import AgentResult from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler @@ -337,17 +338,47 @@ def test_agent__call__passes_kwargs(mock_model, system_prompt, callback_handler, ], ] + override_system_prompt = "Override system prompt" + override_model = unittest.mock.Mock() + override_tool_execution_handler = unittest.mock.Mock() + override_event_loop_metrics = unittest.mock.Mock() + override_callback_handler = unittest.mock.Mock() + override_tool_handler = unittest.mock.Mock() + override_messages = [{"role": "user", "content": [{"text": "override msg"}]}] + override_tool_config = {"test": "config"} + def check_kwargs(some_value, **kwargs): assert some_value == "a_value" assert kwargs is not None + assert kwargs["system_prompt"] == override_system_prompt + assert kwargs["model"] == override_model + assert kwargs["tool_execution_handler"] == override_tool_execution_handler + assert kwargs["event_loop_metrics"] == override_event_loop_metrics + assert kwargs["callback_handler"] == override_callback_handler + assert kwargs["tool_handler"] == override_tool_handler + assert kwargs["messages"] == override_messages + assert kwargs["tool_config"] == override_tool_config + assert kwargs["agent"] == agent # Return expected values from event_loop_cycle return "stop", {"role": "assistant", "content": [{"text": "Response"}]}, {}, {} mock_event_loop_cycle.side_effect = check_kwargs - agent("test message", some_value="a_value") - assert mock_event_loop_cycle.call_count == 1 + agent( + "test message", + some_value="a_value", + system_prompt=override_system_prompt, + model=override_model, + tool_execution_handler=override_tool_execution_handler, + event_loop_metrics=override_event_loop_metrics, + callback_handler=override_callback_handler, + tool_handler=override_tool_handler, + messages=override_messages, + tool_config=override_tool_config, + ) + + mock_event_loop_cycle.assert_called_once() def test_agent__call__retry_with_reduced_context(mock_model, agent, 
tool): @@ -428,7 +459,7 @@ def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite with pytest.raises(ContextWindowOverflowException): agent("Test!") - assert conversation_manager_spy.reduce_context.call_count == 251 + assert conversation_manager_spy.reduce_context.call_count > 0 assert conversation_manager_spy.apply_management.call_count == 1 @@ -566,7 +597,7 @@ def test_agent_tool_user_message_override(agent): }, { "text": ( - "agent.tool_decorated direct tool call\n" + "agent.tool.tool_decorated direct tool call.\n" "Input parameters: " '{"random_string": "abcdEfghI123", "user_message_override": "test override"}\n' ), @@ -657,8 +688,6 @@ def test_agent_with_callback_handler_none_uses_null_handler(): @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle): - mock_event_loop_cycle.side_effect = ValueError("Test exception") - agent = Agent() # Define the side effect to simulate callback handler being called multiple times @@ -922,6 +951,52 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=result) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_success(mock_get_tracer, mock_event_loop_cycle): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler being called multiple times + def call_callback_handler(*args, **kwargs): + # Extract the callback handler from kwargs + callback_handler = kwargs.get("callback_handler") + # Call the callback handler with different data values + callback_handler(data="First chunk") + callback_handler(data="Second chunk") + callback_handler(data="Final chunk", complete=True) + # Return expected values from event_loop_cycle + return "stop", {"role": "assistant", "content": [{"text": "Agent Response"}]}, {}, {} + + mock_event_loop_cycle.side_effect = call_callback_handler + + # Create agent and make a call + agent = Agent(model=mock_model) + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + expected_response = AgentResult( + stop_reason="stop", message={"role": "assistant", "content": [{"text": "Agent Response"}]}, metrics={}, state={} + ) + + # Verify span was ended with the result + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, response=expected_response) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): """Test that __call__ creates and ends a span when an exception occurs.""" @@ -955,6 +1030,42 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def 
test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tracer, mock_model): + """Test that stream_async creates and ends a span when the call succeeds.""" + # Setup mock tracer and span + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + # Define the side effect to simulate callback handler raising an Exception + test_exception = ValueError("Test exception") + mock_model.mock_converse.side_effect = test_exception + + # Create agent and make a call + agent = Agent(model=mock_model) + + # Call the agent and catch the exception + with pytest.raises(ValueError): + iterator = agent.stream_async("test prompt") + async for _event in iterator: + pass # NoOp + + # Verify span was created + mock_tracer.start_agent_span.assert_called_once_with( + prompt="test prompt", + model_id=unittest.mock.ANY, + tools=agent.tool_names, + system_prompt=agent.system_prompt, + custom_trace_attributes=agent.trace_attributes, + ) + + # Verify span was ended with the exception + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + @unittest.mock.patch("strands.agent.agent.get_tracer") def test_event_loop_cycle_includes_parent_span(mock_get_tracer, mock_event_loop_cycle, mock_model): """Test that event_loop_cycle is called with the parent span.""" diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 2f6ee77d..b6132f1d 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -111,41 +111,7 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Second response"}]}, ], ), - # 7 - Message count above max window size - Remove dangling tool uses and tool results - ( - {"window_size": 1}, - [ - {"role": "user", "content": [{"text": "First message"}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "321", "name": "tool1", "input": {}}}]}, - { - "role": "user", - "content": [ - {"toolResult": {"toolUseId": "123", "content": [{"text": "Hello!"}], "status": "success"}} - ], - }, - ], - [ - { - "role": "user", - "content": [{"text": "\nTool Result Text Content: Hello!\nTool Result Status: success"}], - }, - ], - ), - # 8 - Message count above max window size - Remove multiple tool use/tool result pairs - ( - {"window_size": 1}, - [ - {"role": "user", "content": [{"toolResult": {"toolUseId": "123", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "123", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, - {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, - {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, - ], - [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, - ], - ), - # 9 - Message count above max window size - Preserve tool use/tool result pairs + # 7 - Message count above max window size - Preserve tool use/tool result pairs ( {"window_size": 2}, [ @@ -158,7 +124,7 @@ def conversation_manager(request): {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, ], ), - # 10 - Test sliding window behavior - preserve tool use/result pairs across cut boundary + # 
8 - Test sliding window behavior - preserve tool use/result pairs across cut boundary ( {"window_size": 3}, [ @@ -173,7 +139,7 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Response after tool use"}]}, ], ), - # 11 - Test sliding window with multiple tool pairs that need preservation + # 9 - Test sliding window with multiple tool pairs that need preservation ( {"window_size": 4}, [ @@ -185,7 +151,6 @@ def conversation_manager(request): {"role": "assistant", "content": [{"text": "Final response"}]}, ], [ - {"role": "user", "content": [{"text": "Tool Result Status: success"}]}, {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool2", "input": {}}}]}, {"role": "user", "content": [{"toolResult": {"toolUseId": "456", "content": [], "status": "success"}}]}, {"role": "assistant", "content": [{"text": "Final response"}]}, @@ -200,6 +165,20 @@ def test_apply_management(conversation_manager, messages, expected_messages): assert messages == expected_messages +def test_sliding_window_conversation_manager_with_untrimmable_history_raises_context_window_overflow_exception(): + manager = strands.agent.conversation_manager.SlidingWindowConversationManager(1) + messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "456", "name": "tool1", "input": {}}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "789", "content": [], "status": "success"}}]}, + ] + original_messages = messages.copy() + + with pytest.raises(ContextWindowOverflowException): + manager.apply_management(messages) + + assert messages == original_messages + + def test_null_conversation_manager_reduce_context_raises_context_window_overflow_exception(): """Test that NullConversationManager doesn't modify messages.""" manager = strands.agent.conversation_manager.NullConversationManager() diff --git a/tests/strands/event_loop/test_message_processor.py b/tests/strands/event_loop/test_message_processor.py new file mode 100644 index 00000000..395c71a1 --- /dev/null +++ b/tests/strands/event_loop/test_message_processor.py @@ -0,0 +1,128 @@ +import copy + +import pytest + +from strands.event_loop import message_processor + + +@pytest.mark.parametrize( + "messages,expected,expected_messages", + [ + # Orphaned toolUse with empty input, no toolResult + ( + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]}, + ], + True, + [ + {"role": "assistant", "content": [{"text": "[Attempted to use foo, but operation was canceled]"}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "2"}}]}, + ], + ), + # toolUse with input, has matching toolResult + ( + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + ], + False, + [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "input": {"a": 1}, "name": "foo"}}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + ], + ), + # No messages + ( + [], + False, + [], + ), + ], +) +def test_clean_orphaned_empty_tool_uses(messages, expected, expected_messages): + test_messages = copy.deepcopy(messages) + result = message_processor.clean_orphaned_empty_tool_uses(test_messages) + assert result == expected + assert test_messages == expected_messages + + +@pytest.mark.parametrize( + "messages,expected_idx", + [ + ( + [ + {"role": "user", "content": 
[{"text": "hi"}]}, + {"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ], + 1, + ), + ( + [ + {"role": "user", "content": [{"text": "hi"}]}, + {"role": "assistant", "content": [{"text": "ok"}]}, + ], + None, + ), + ( + [], + None, + ), + ], +) +def test_find_last_message_with_tool_results(messages, expected_idx): + idx = message_processor.find_last_message_with_tool_results(messages) + assert idx == expected_idx + + +@pytest.mark.parametrize( + "messages,msg_idx,expected_changed,expected_content", + [ + ( + [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "status": "ok", "content": [{"text": "big"}]}}], + } + ], + 0, + True, + [ + { + "toolResult": { + "toolUseId": "1", + "status": "error", + "content": [{"text": "The tool result was too large!"}], + } + } + ], + ), + ( + [{"role": "user", "content": [{"text": "no tool result"}]}], + 0, + False, + [{"text": "no tool result"}], + ), + ( + [], + 0, + False, + [], + ), + ( + [{"role": "user", "content": [{"toolResult": {"toolUseId": "1"}}]}], + 2, + False, + [{"toolResult": {"toolUseId": "1"}}], + ), + ], +) +def test_truncate_tool_results(messages, msg_idx, expected_changed, expected_content): + test_messages = copy.deepcopy(messages) + changed = message_processor.truncate_tool_results(test_messages, msg_idx) + assert changed == expected_changed + if 0 <= msg_idx < len(test_messages): + assert test_messages[msg_idx]["content"] == expected_content + else: + assert test_messages == messages diff --git a/tests/strands/handlers/test_callback_handler.py b/tests/strands/handlers/test_callback_handler.py index 20e238cb..6fb2af07 100644 --- a/tests/strands/handlers/test_callback_handler.py +++ b/tests/strands/handlers/test_callback_handler.py @@ -30,6 +30,31 @@ def test_call_with_empty_args(handler, mock_print): mock_print.assert_not_called() +def test_call_handler_reasoningText(handler, mock_print): + """Test calling the handler with reasoningText.""" + handler(reasoningText="This is reasoning text") + # Should print reasoning text without newline + mock_print.assert_called_once_with("This is reasoning text", end="") + + +def test_call_without_reasoningText(handler, mock_print): + """Test calling the handler without reasoningText argument.""" + handler(data="Some output") + # Should only print data, not reasoningText + mock_print.assert_called_once_with("Some output", end="") + + +def test_call_with_reasoningText_and_data(handler, mock_print): + """Test calling the handler with both reasoningText and data.""" + handler(reasoningText="Reasoning", data="Output") + # Should print reasoningText and data, both without newline + calls = [ + unittest.mock.call("Reasoning", end=""), + unittest.mock.call("Output", end=""), + ] + mock_print.assert_has_calls(calls) + + def test_call_with_data_incomplete(handler, mock_print): """Test calling the handler with data but not complete.""" handler(data="Test output") diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index 48a1da37..2ee344cc 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -102,19 +102,46 @@ def test_format_request_with_system_prompt(model, messages, model_id, max_tokens assert tru_request == exp_request -def test_format_request_with_document(model, model_id, max_tokens): +@pytest.mark.parametrize( + ("content", "formatted_content"), + [ + # PDF + ( + { + "document": {"format": "pdf", "name": "test doc", "source": 
{"bytes": b"pdf"}}, + }, + { + "source": { + "data": "cGRm", + "media_type": "application/pdf", + "type": "base64", + }, + "title": "test doc", + "type": "document", + }, + ), + # Plain text + ( + { + "document": {"format": "txt", "name": "test doc", "source": {"bytes": b"txt"}}, + }, + { + "source": { + "data": "txt", + "media_type": "text/plain", + "type": "text", + }, + "title": "test doc", + "type": "document", + }, + ), + ], +) +def test_format_request_with_document(content, formatted_content, model, model_id, max_tokens): messages = [ { "role": "user", - "content": [ - { - "document": { - "format": "pdf", - "name": "test-doc", - "source": {"bytes": b"base64encodeddoc"}, - }, - }, - ], + "content": [content], }, ] @@ -124,17 +151,7 @@ def test_format_request_with_document(model, model_id, max_tokens): "messages": [ { "role": "user", - "content": [ - { - "source": { - "data": "YmFzZTY0ZW5jb2RlZGRvYw==", - "media_type": "application/pdf", - "type": "base64", - }, - "title": "test-doc", - "type": "document", - }, - ], + "content": [formatted_content], }, ], "model": model_id, diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 566671ce..b326eee7 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,7 +1,9 @@ +import os import unittest.mock import boto3 import pytest +from botocore.config import Config as BotocoreConfig from botocore.exceptions import ClientError, EventStreamError import strands @@ -99,12 +101,70 @@ def test__init__with_custom_region(bedrock_client): mock_session_cls.assert_called_once_with(region_name=custom_region) +def test__init__with_environment_variable_region(bedrock_client): + """Test that BedrockModel uses the provided region.""" + _ = bedrock_client + os.environ["AWS_REGION"] = "eu-west-1" + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + _ = BedrockModel() + mock_session_cls.assert_called_once_with(region_name="eu-west-1") + + def test__init__with_region_and_session_raises_value_error(): """Test that BedrockModel raises ValueError when both region and session are provided.""" with pytest.raises(ValueError): _ = BedrockModel(region_name="us-east-1", boto_session=boto3.Session(region_name="us-east-1")) +def test__init__default_user_agent(bedrock_client): + """Set user agent when no boto_client_config is provided.""" + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel() + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "strands-agents" + + +def test__init__with_custom_boto_client_config_no_user_agent(bedrock_client): + """Set user agent when boto_client_config is provided without user_agent_extra.""" + custom_config = BotocoreConfig(read_timeout=900) + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) 
+ assert kwargs["config"].user_agent_extra == "strands-agents" + assert kwargs["config"].read_timeout == 900 + + +def test__init__with_custom_boto_client_config_with_user_agent(bedrock_client): + """Append to existing user agent when boto_client_config is provided with user_agent_extra.""" + custom_config = BotocoreConfig(user_agent_extra="existing-agent", read_timeout=900) + + with unittest.mock.patch("strands.models.bedrock.boto3.Session") as mock_session_cls: + mock_session = mock_session_cls.return_value + _ = BedrockModel(boto_client_config=custom_config) + + # Verify the client was created with the correct config + mock_session.client.assert_called_once() + args, kwargs = mock_session.client.call_args + assert kwargs["service_name"] == "bedrock-runtime" + assert isinstance(kwargs["config"], BotocoreConfig) + assert kwargs["config"].user_agent_extra == "existing-agent strands-agents" + assert kwargs["config"].read_timeout == 900 + + def test__init__model_config(bedrock_client): _ = bedrock_client @@ -294,9 +354,9 @@ def test_stream(bedrock_client, model): def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model): - error_message = "ThrottlingException - Rate exceeded" + error_message = "Rate exceeded" bedrock_client.converse_stream.side_effect = EventStreamError( - {"Error": {"Message": error_message}}, "ConverseStream" + {"Error": {"Message": error_message, "Code": "ThrottlingException"}}, "ConverseStream" ) request = {"a": 1} @@ -361,7 +421,9 @@ def test_converse(bedrock_client, model, messages, tool_spec, model_id, addition bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +def test_converse_stream_input_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +): metadata_event = { "metadata": { "usage": {"inputTokens": 0, "outputTokens": 0, "totalTokens": 0}, @@ -370,7 +432,15 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m "guardrail": { "inputAssessment": { "3e59qlue4hag": { - "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + "wordPolicy": { + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] + } } } } @@ -395,13 +465,18 @@ def test_converse_input_guardrails(bedrock_client, model, messages, tool_spec, m chunks = model.converse(messages, [tool_spec]) tru_chunks = list(chunks) - exp_chunks = [{"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, metadata_event] + exp_chunks = [ + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + metadata_event, + ] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) -def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, model_id, additional_request_fields): +def test_converse_stream_output_guardrails( + bedrock_client, model, messages, tool_spec, model_id, additional_request_fields +): model.update_config(guardrail_redact_input=False, guardrail_redact_output=True) metadata_event = { "metadata": { @@ -413,7 +488,13 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ 
-440,7 +521,10 @@ def test_converse_output_guardrails(bedrock_client, model, messages, tool_spec, chunks = model.converse(messages, [tool_spec]) tru_chunks = list(chunks) - exp_chunks = [{"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, metadata_event] + exp_chunks = [ + {"redactContent": {"redactAssistantContentMessage": "[Assistant output redacted.]"}}, + metadata_event, + ] assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) @@ -460,7 +544,13 @@ def test_converse_output_guardrails_redacts_input_and_output( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -510,7 +600,13 @@ def test_converse_output_no_blocked_guardrails_doesnt_redact( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "NONE", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "NONE", + "detected": True, + } + ] }, } ] @@ -556,7 +652,13 @@ def test_converse_output_no_guardrail_redact( "3e59qlue4hag": [ { "wordPolicy": { - "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + "customWords": [ + { + "match": "CACTUS", + "action": "BLOCKED", + "detected": True, + } + ] }, } ] @@ -580,7 +682,9 @@ def test_converse_output_no_guardrail_redact( } model.update_config( - additional_request_fields=additional_request_fields, guardrail_redact_output=False, guardrail_redact_input=False + additional_request_fields=additional_request_fields, + guardrail_redact_output=False, + guardrail_redact_input=False, ) chunks = model.converse(messages, [tool_spec]) @@ -589,3 +693,322 @@ def test_converse_output_no_guardrail_redact( assert tru_chunks == exp_chunks bedrock_client.converse_stream.assert_called_once_with(**request) + + +def test_stream_with_streaming_false(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "stopReason": "end_turn", + } + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_and_tool_use(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [{"toolUse": {"toolUseId": "123", "name": "dummyTool", "input": {"hello": "world!"}}}], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockStart": {"start": {"toolUse": {"toolUseId": "123", "name": "dummyTool"}}}}, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"hello": "world!"}'}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", 
streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_and_reasoning(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard....", "signature": "123"}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"signature": "123"}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_and_reasoning_no_signature(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "reasoningContent": { + "reasoningText": {"text": "Thinking really hard...."}, + } + } + ], + } + }, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "Thinking really hard...."}}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_stream_with_streaming_false_with_metrics_and_usage(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + "stopReason": "tool_use", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "tool_use", "additionalModelResponseFields": None}}, + { + "metadata": { + "usage": {"inputTokens": 1234, "outputTokens": 1234, "totalTokens": 2468}, + "metrics": {"latencyMs": 1234}, + } + }, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + # Verify converse was called + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_input_guardrails(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value 
= { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]} + } + } + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "inputAssessment": { + "3e59qlue4hag": { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + # Create model and call stream + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_output_guardrails(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() + + +def test_converse_output_guardrails_redacts_output(bedrock_client): + """Test stream method with streaming=False.""" + bedrock_client.converse.return_value = { + "output": {"message": {"role": "assistant", "content": [{"text": "test"}]}}, + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": {"customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}]}, + } + ] + }, + } + }, + "stopReason": "end_turn", + } + + expected_events = [ + {"messageStart": {"role": "assistant"}}, + {"contentBlockDelta": {"delta": {"text": "test"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn", "additionalModelResponseFields": None}}, + { + "metadata": { + "trace": { + "guardrail": { + "outputAssessments": { + "3e59qlue4hag": [ + { + "wordPolicy": { + "customWords": [{"match": "CACTUS", "action": "BLOCKED", "detected": True}] + } + } + ] + } + } + } + } + }, + {"redactContent": {"redactUserContentMessage": "[User input redacted.]"}}, + ] + + model = BedrockModel(model_id="test-model", streaming=False) + request = {"modelId": "test-model"} + 
events = list(model.stream(request)) + + assert expected_events == events + + bedrock_client.converse.assert_called_once() + bedrock_client.converse_stream.assert_not_called() diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 5d4d9b40..528d1498 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,4 +1,3 @@ -import json import unittest.mock import pytest @@ -8,9 +7,14 @@ @pytest.fixture -def litellm_client(): +def litellm_client_cls(): with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls: - yield mock_client_cls.return_value + yield mock_client_cls + + +@pytest.fixture +def litellm_client(litellm_client_cls): + return litellm_client_cls.return_value @pytest.fixture @@ -35,15 +39,15 @@ def system_prompt(): return "s1" -def test__init__model_configs(litellm_client, model_id): - _ = litellm_client +def test__init__(litellm_client_cls, model_id): + model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) - model = LiteLLMModel(model_id=model_id, params={"max_tokens": 1}) + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} - tru_max_tokens = model.get_config().get("params") - exp_max_tokens = {"max_tokens": 1} + assert tru_config == exp_config - assert tru_max_tokens == exp_max_tokens + litellm_client_cls.assert_called_once_with(api_key="k1") def test_update_config(model, model_id): @@ -55,513 +59,47 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_format_request_default(model, messages, model_id): - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_params(model, messages, model_id): - model.update_config(params={"max_tokens": 1}) - - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - "max_tokens": 1, - } - - assert tru_request == exp_request - - -def test_format_request_with_system_prompt(model, messages, model_id, system_prompt): - tru_request = model.format_request(messages, system_prompt=system_prompt) - exp_request = { - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": [{"type": "text", "text": "test"}]}, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_image(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "image": { - "format": "jpg", - "source": {"bytes": b"base64encodedimage"}, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "image_url": { - "detail": "auto", - "format": "image/jpeg", - "url": "data:image/jpeg;base64,base64encodedimage", - }, - "type": "image_url", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def 
test_format_request_with_reasoning(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "reasoningContent": { - "reasoningText": { - "signature": "reasoning_signature", - "text": "reasoning_text", - }, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ +@pytest.mark.parametrize( + "content, exp_result", + [ + # Case 1: Thinking + ( { - "role": "user", - "content": [ - { + "reasoningContent": { + "reasoningText": { "signature": "reasoning_signature", - "thinking": "reasoning_text", - "type": "thinking", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_video(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "video": { - "source": {"bytes": "base64encodedvideo"}, + "text": "reasoning_text", }, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "detail": "auto", - "url": "base64encodedvideo", - }, - }, - ], }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_other(model, model_id): - messages = [ - { - "role": "user", - "content": [{"other": {"a": 1}}], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ { - "role": "user", - "content": [ - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ], + "signature": "reasoning_signature", + "thinking": "reasoning_text", + "type": "thinking", }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result(model, model_id): - messages = [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ + ), + # Case 2: Video + ( { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), - "role": "tool", - "tool_call_id": "c1", - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_use(model, model_id): - messages = [ - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "c1", - "name": "calculator", - "input": {"expression": "2+2"}, - }, + "video": { + "source": {"bytes": "base64encodedvideo"}, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "content": [], - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "calculator", - "arguments": '{"expression": "2+2"}', - }, - "id": "c1", - "type": "function", - } - ], - } - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_specs(model, messages, model_id): - tool_specs = [ - { - "name": "calculator", - "description": "Calculate mathematical expressions", - "inputSchema": { - "json": {"type": "object", "properties": {"expression": {"type": 
"string"}}, "required": ["expression"]} }, - } - ] - - tru_request = model.format_request(messages, tool_specs) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [ { - "type": "function", - "function": { - "name": "calculator", - "description": "Calculate mathematical expressions", - "parameters": { - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, + "type": "video_url", + "video_url": { + "detail": "auto", + "url": "base64encodedvideo", }, - } - ], - } - - assert tru_request == exp_request - - -def test_format_chunk_message_start(model): - event = {"chunk_type": "message_start"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStart": {"role": "assistant"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_text(model): - event = {"chunk_type": "content_start", "data_type": "text"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_tool(model): - mock_tool_use = unittest.mock.Mock() - mock_tool_use.function.name = "calculator" - mock_tool_use.id = "c1" - - event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_use} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_text(model): - event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_tool(model): - event = { - "chunk_type": "content_delta", - "data_type": "tool", - "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_stop(model): - event = {"chunk_type": "content_stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStop": {}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_end_turn(model): - event = {"chunk_type": "message_stop", "data": "stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "end_turn"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_tool_use(model): - event = {"chunk_type": "message_stop", "data": "tool_calls"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "tool_use"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_max_tokens(model): - event = {"chunk_type": "message_stop", "data": "length"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "max_tokens"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_metadata(model): - event = { - "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = { - "metadata": { - "usage": { - "inputTokens": 100, - "outputTokens": 50, - "totalTokens": 
150, - }, - "metrics": { - "latencyMs": 0, }, - }, - } - - assert tru_chunk == exp_chunk - - -def test_format_chunk_other(model): - event = {"chunk_type": "other"} - - with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): - model.format_chunk(event) - - -def test_stream(litellm_client, model): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) - mock_delta_1 = unittest.mock.Mock( - content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] - ) - - mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_2 = unittest.mock.Mock( - content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] - ) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls")]) - mock_event_4 = unittest.mock.Mock() - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) - - request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "message_stop", "data": "tool_calls"}, - {"chunk_type": "metadata", "data": mock_event_4.usage}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) - - -def test_stream_empty(litellm_client, model): - mock_delta = unittest.mock.Mock(content=None, tool_calls=None) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop")]) - mock_event_3 = unittest.mock.Mock(spec=[]) - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) - - request = {"model": "m1", "messages": [{"role": "user", "content": []}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) + ), + # Case 3: Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, 
exp_result): + tru_result = LiteLLMModel.format_request_message_content(content) + assert tru_result == exp_result diff --git a/tests/test_llamaapi.py b/tests/strands/models/test_llamaapi.py similarity index 100% rename from tests/test_llamaapi.py rename to tests/strands/models/test_llamaapi.py diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py new file mode 100644 index 00000000..89aa591f --- /dev/null +++ b/tests/strands/models/test_openai.py @@ -0,0 +1,134 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def openai_client_cls(): + with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def openai_client(openai_client_cls): + return openai_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "m1" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + + return OpenAIModel(model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test__init__(openai_client_cls, model_id): + model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) + + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1") + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_stream(openai_client, model): + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_1 = unittest.mock.Mock( + content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] + ) + + mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) + mock_delta_2 = unittest.mock.Mock( + content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] + ) + + mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) + mock_event_4 = unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", 
"data_type": "tool", "data": mock_tool_call_1_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"chunk_type": "metadata", "data": mock_event_4.usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_empty(openai_client, model): + mock_delta = unittest.mock.Mock(content=None, tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + mock_event_4 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": []}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 55018c5e..32a4ac0a 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1,11 +1,12 @@ import json import os +from datetime import date, datetime, timezone from unittest import mock import pytest from opentelemetry.trace import StatusCode # type: ignore -from strands.telemetry.tracer import Tracer, get_tracer +from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -59,6 +60,50 @@ def mock_resource(): yield mock_resource +@pytest.fixture +def clean_env(): + """Fixture to provide a clean environment for each test.""" + with mock.patch.dict(os.environ, {}, clear=True): + yield + + +@pytest.fixture +def env_with_otlp(): + """Fixture with OTLP environment variables.""" + with mock.patch.dict( + os.environ, + { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint", + }, + ): + yield + + +@pytest.fixture +def env_with_console(): + """Fixture with console export environment variables.""" + with mock.patch.dict( + os.environ, + { + "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true", + }, + ): + yield + + +@pytest.fixture +def env_with_both(): + """Fixture with both OTLP and console export environment variables.""" + with mock.patch.dict( + os.environ, + { + "OTEL_EXPORTER_OTLP_ENDPOINT": "http://env-endpoint", + "STRANDS_OTEL_ENABLE_CONSOLE_EXPORT": "true", + }, + ): + yield + + def test_init_default(): """Test initializing the Tracer with default parameters.""" tracer = Tracer() @@ -268,6 +313,9 @@ def test_start_tool_call_span(mock_tracer): mock_tracer.start_span.assert_called_once() assert 
mock_tracer.start_span.call_args[1]["name"] == "Tool: test-tool" + mock_span.set_attribute.assert_any_call( + "gen_ai.prompt", json.dumps({"name": "test-tool", "toolUseId": "123", "input": {"param": "value"}}) + ) mock_span.set_attribute.assert_any_call("tool.name", "test-tool") mock_span.set_attribute.assert_any_call("tool.id", "123") mock_span.set_attribute.assert_any_call("tool.parameters", json.dumps({"param": "value"})) @@ -369,7 +417,7 @@ def test_end_agent_span(mock_span): tracer.end_agent_span(mock_span, mock_response) - mock_span.set_attribute.assert_any_call("gen_ai.completion", '""') + mock_span.set_attribute.assert_any_call("gen_ai.completion", "Agent response") mock_span.set_attribute.assert_any_call("gen_ai.usage.prompt_tokens", 50) mock_span.set_attribute.assert_any_call("gen_ai.usage.completion_tokens", 100) mock_span.set_attribute.assert_any_call("gen_ai.usage.total_tokens", 150) @@ -497,3 +545,230 @@ def test_start_model_invoke_span_with_parent(mock_tracer): # Verify span was returned assert span is mock_span + + +@pytest.mark.parametrize( + "input_data, expected_result", + [ + ("test string", '"test string"'), + (1234, "1234"), + (13.37, "13.37"), + (False, "false"), + (None, "null"), + ], +) +def test_json_encoder_serializable(input_data, expected_result): + """Test encoding of serializable values.""" + encoder = JSONEncoder() + + result = encoder.encode(input_data) + assert result == expected_result + + +def test_json_encoder_datetime(): + """Test encoding datetime and date objects.""" + encoder = JSONEncoder() + + dt = datetime(2025, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + result = encoder.encode(dt) + assert result == f'"{dt.isoformat()}"' + + d = date(2025, 1, 1) + result = encoder.encode(d) + assert result == f'"{d.isoformat()}"' + + +def test_json_encoder_list(): + """Test encoding a list with mixed content.""" + encoder = JSONEncoder() + + non_serializable = lambda x: x # noqa: E731 + + data = ["value", 42, 13.37, non_serializable, None, {"key": True}, ["value here"]] + + result = json.loads(encoder.encode(data)) + assert result == ["value", 42, 13.37, "", None, {"key": True}, ["value here"]] + + +def test_json_encoder_dict(): + """Test encoding a dict with mixed content.""" + encoder = JSONEncoder() + + class UnserializableClass: + def __str__(self): + return "Unserializable Object" + + non_serializable = lambda x: x # noqa: E731 + + now = datetime.now(timezone.utc) + + data = { + "metadata": { + "timestamp": now, + "version": "1.0", + "debug_info": {"object": non_serializable, "callable": lambda x: x + 1}, # noqa: E731 + }, + "content": [ + {"type": "text", "value": "Hello world"}, + {"type": "binary", "value": non_serializable}, + {"type": "mixed", "values": [1, "text", non_serializable, {"nested": non_serializable}]}, + ], + "statistics": { + "processed": 100, + "failed": 5, + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": non_serializable}], + }, + "list": [ + non_serializable, + 1234, + 13.37, + True, + None, + "string here", + ], + } + + expected = { + "metadata": { + "timestamp": now.isoformat(), + "version": "1.0", + "debug_info": {"object": "", "callable": ""}, + }, + "content": [ + {"type": "text", "value": "Hello world"}, + {"type": "binary", "value": ""}, + {"type": "mixed", "values": [1, "text", "", {"nested": ""}]}, + ], + "statistics": { + "processed": 100, + "failed": 5, + "details": [{"id": 1, "status": "ok"}, {"id": 2, "status": "error", "error_obj": ""}], + }, + "list": [ + "", + 1234, + 13.37, + True, + None, 
+ "string here", + ], + } + + result = json.loads(encoder.encode(data)) + + assert result == expected + + +def test_json_encoder_value_error(): + """Test encoding values that cause ValueError.""" + encoder = JSONEncoder() + + # A very large integer that exceeds JSON limits and throws ValueError + huge_number = 2**100000 + + # Test in a dictionary + dict_data = {"normal": 42, "huge": huge_number} + result = json.loads(encoder.encode(dict_data)) + assert result == {"normal": 42, "huge": ""} + + # Test in a list + list_data = [42, huge_number] + result = json.loads(encoder.encode(list_data)) + assert result == [42, ""] + + # Test just the value + result = json.loads(encoder.encode(huge_number)) + assert result == "" + + +def test_serialize_non_ascii_characters(): + """Test that non-ASCII characters are preserved in JSON serialization.""" + + # Test with Japanese text + japanese_text = "γ“γ‚“γ«γ‘γ―δΈ–η•Œ" + result = serialize({"text": japanese_text}) + assert japanese_text in result + assert "\\u" not in result + + # Test with emoji + emoji_text = "Hello 🌍" + result = serialize({"text": emoji_text}) + assert emoji_text in result + assert "\\u" not in result + + # Test with Chinese characters + chinese_text = "δ½ ε₯½οΌŒδΈ–η•Œ" + result = serialize({"text": chinese_text}) + assert chinese_text in result + assert "\\u" not in result + + # Test with mixed content + mixed_text = {"ja": "こんにけは", "emoji": "😊", "zh": "δ½ ε₯½", "en": "hello"} + result = serialize(mixed_text) + assert "こんにけは" in result + assert "😊" in result + assert "δ½ ε₯½" in result + assert "\\u" not in result + + +def test_serialize_vs_json_dumps(): + """Test that serialize behaves differently from default json.dumps for non-ASCII characters.""" + + # Test with Japanese text + japanese_text = "γ“γ‚“γ«γ‘γ―δΈ–η•Œ" + + # Default json.dumps should escape non-ASCII characters + default_result = json.dumps({"text": japanese_text}) + assert "\\u" in default_result + + # Our serialize function should preserve non-ASCII characters + custom_result = serialize({"text": japanese_text}) + assert japanese_text in custom_result + assert "\\u" not in custom_result + + +def test_init_with_no_env_or_param(clean_env): + """Test initializing with neither environment variable nor constructor parameter.""" + tracer = Tracer() + assert tracer.otlp_endpoint is None + assert tracer.enable_console_export is False + + tracer = Tracer(otlp_endpoint="http://param-endpoint") + assert tracer.otlp_endpoint == "http://param-endpoint" + + tracer = Tracer(enable_console_export=True) + assert tracer.enable_console_export is True + + +def test_constructor_params_with_otlp_env(env_with_otlp): + """Test constructor parameters precedence over OTLP environment variable.""" + # Constructor parameter should take precedence + tracer = Tracer(otlp_endpoint="http://constructor-endpoint") + assert tracer.otlp_endpoint == "http://constructor-endpoint" + + # Without constructor parameter, should use env var + tracer = Tracer() + assert tracer.otlp_endpoint == "http://env-endpoint" + + +def test_constructor_params_with_console_env(env_with_console): + """Test constructor parameters precedence over console environment variable.""" + # Constructor parameter should take precedence + tracer = Tracer(enable_console_export=False) + assert tracer.enable_console_export is False + + # Without explicit constructor parameter, should use env var + tracer = Tracer() + assert tracer.enable_console_export is True + + +def test_fallback_to_env_vars(env_with_both): + """Test fallback to 
environment variables when no constructor parameters.""" + tracer = Tracer() + assert tracer.otlp_endpoint == "http://env-endpoint" + assert tracer.enable_console_export is True + + # Constructor parameters should still take precedence + tracer = Tracer(otlp_endpoint="http://constructor-endpoint", enable_console_export=False) + assert tracer.otlp_endpoint == "http://constructor-endpoint" + assert tracer.enable_console_export is False diff --git a/tests/strands/tools/test_tools.py b/tests/strands/tools/test_tools.py index 8a6b406f..f24cc22d 100644 --- a/tests/strands/tools/test_tools.py +++ b/tests/strands/tools/test_tools.py @@ -14,9 +14,13 @@ def test_validate_tool_use_name_valid(): - tool = {"name": "valid_tool_name", "toolUseId": "123"} + tool1 = {"name": "valid_tool_name", "toolUseId": "123"} + # Should not raise an exception + validate_tool_use_name(tool1) + + tool2 = {"name": "valid-name", "toolUseId": "123"} # Should not raise an exception - validate_tool_use_name(tool) + validate_tool_use_name(tool2) def test_validate_tool_use_name_missing(): diff --git a/tests/strands/types/models/__init__.py b/tests/strands/types/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py new file mode 100644 index 00000000..f2797fe5 --- /dev/null +++ b/tests/strands/types/models/test_model.py @@ -0,0 +1,81 @@ +import pytest + +from strands.types.models import Model as SAModel + + +class TestModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def format_request(self, messages, tool_specs, system_prompt): + return { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + } + + def format_chunk(self, event): + return {"event": event} + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test_converse(model, messages, tool_specs, system_prompt): + response = model.converse(messages, tool_specs, system_prompt) + + tru_events = list(response) + exp_events = [ + { + "event": { + "request": { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + }, + }, + }, + ] + assert tru_events == exp_events diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py new file mode 100644 index 00000000..c6a05291 --- /dev/null +++ b/tests/strands/types/models/test_openai.py @@ -0,0 +1,358 @@ +import json +import unittest.mock + +import pytest + +from strands.types.models import OpenAIModel as SAOpenAIModel + + +class TestOpenAIModel(SAOpenAIModel): + def __init__(self): + self.config = {"model_id": "m1", "params": {"max_tokens": 1}} + + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestOpenAIModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], 
+ }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Document + ( + { + "document": { + "format": "pdf", + "name": "test doc", + "source": {"bytes": b"document"}, + }, + }, + { + "file": { + "file_data": "data:application/pdf;base64,ZG9jdW1lbnQ=", + "filename": "test doc", + }, + "type": "file", + }, + ), + # Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,image", + }, + "type": "image_url", + }, + ), + # Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + # Other + ( + {"other": {"a": 1}}, + { + "text": json.dumps({"other": {"a": 1}}), + "type": "text", + }, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = SAOpenAIModel.format_request_message_content(content) + assert tru_result == exp_result + + +def test_format_request_message_tool_call(): + tool_use = { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + } + + tru_result = SAOpenAIModel.format_request_message_tool_call(tool_use) + exp_result = { + "function": { + "arguments": '{"expression": "2+2"}', + "name": "calculator", + }, + "id": "c1", + "type": "function", + } + assert tru_result == exp_result + + +def test_format_request_tool_message(): + tool_result = { + "content": [{"value": 4}], + "status": "success", + "toolUseId": "c1", + } + + tru_result = SAOpenAIModel.format_request_tool_message(tool_result) + exp_result = { + "content": json.dumps( + { + "content": [{"value": 4}], + "status": "success", + } + ), + "role": "tool", + "tool_call_id": "c1", + } + assert tru_result == exp_result + + +def test_format_request_messages(system_prompt): + messages = [ + { + "content": [], + "role": "user", + }, + { + "content": [{"text": "hello"}], + "role": "user", + }, + { + "content": [ + {"text": "call tool"}, + { + "toolUse": { + "input": {"expression": "2+2"}, + "name": "calculator", + "toolUseId": "c1", + }, + }, + ], + "role": "assistant", + }, + { + "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], + "role": "user", + }, + ] + + tru_result = SAOpenAIModel.format_request_messages(messages, system_prompt) + exp_result = [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + { + "content": [{"text": "call tool", "type": "text"}], + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "calculator", + "arguments": '{"expression": "2+2"}', + }, + "id": "c1", + "type": "function", + } + ], + }, + { + "content": json.dumps( + { + "content": [{"value": 4}], + "status": "success", + } + ), + "role": "tool", + "tool_call_id": "c1", + }, + ] + assert tru_result == exp_result + + +def test_format_request(model, messages, tool_specs, system_prompt): + tru_request = model.format_request(messages, tool_specs, system_prompt) + exp_request = { + "messages": [ + { + "content": system_prompt, + "role": "system", + }, + { + "content": [{"text": "hello", "type": "text"}], + "role": "user", + }, + ], + "model": "m1", + "stream": True, + "stream_options": 
{"include_usage": True}, + "tools": [ + { + "function": { + "description": "A test tool", + "name": "test_tool", + "parameters": { + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + "type": "object", + }, + }, + "type": "function", + }, + ], + "max_tokens": 1, + } + assert tru_request == exp_request + + +@pytest.mark.parametrize( + ("event", "exp_chunk"), + [ + # Message start + ( + {"chunk_type": "message_start"}, + {"messageStart": {"role": "assistant"}}, + ), + # Content Start - Tool Use + ( + { + "chunk_type": "content_start", + "data_type": "tool", + "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}), + }, + {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}}, + ), + # Content Start - Text + ( + {"chunk_type": "content_start", "data_type": "text"}, + {"contentBlockStart": {"start": {}}}, + ), + # Content Delta - Tool Use + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}}, + ), + # Content Delta - Tool Use - None + ( + { + "chunk_type": "content_delta", + "data_type": "tool", + "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments=None)), + }, + {"contentBlockDelta": {"delta": {"toolUse": {"input": ""}}}}, + ), + # Content Delta - Text + ( + {"chunk_type": "content_delta", "data_type": "text", "data": "hello"}, + {"contentBlockDelta": {"delta": {"text": "hello"}}}, + ), + # Content Stop + ( + {"chunk_type": "content_stop"}, + {"contentBlockStop": {}}, + ), + # Message Stop - Tool Use + ( + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"messageStop": {"stopReason": "tool_use"}}, + ), + # Message Stop - Max Tokens + ( + {"chunk_type": "message_stop", "data": "length"}, + {"messageStop": {"stopReason": "max_tokens"}}, + ), + # Message Stop - End Turn + ( + {"chunk_type": "message_stop", "data": "stop"}, + {"messageStop": {"stopReason": "end_turn"}}, + ), + # Metadata + ( + { + "chunk_type": "metadata", + "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), + }, + { + "metadata": { + "usage": { + "inputTokens": 100, + "outputTokens": 50, + "totalTokens": 150, + }, + "metrics": { + "latencyMs": 0, + }, + }, + }, + ), + ], +) +def test_format_chunk(event, exp_chunk, model): + tru_chunk = model.format_chunk(event) + assert tru_chunk == exp_chunk + + +def test_format_chunk_unknown_type(model): + event = {"chunk_type": "unknown"} + + with pytest.raises(RuntimeError, match="chunk_type= | unknown type"): + model.format_chunk(event)