From 0ec2df5c8e27da1245d5cc0b7ac1ec10911676f4 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 23 May 2025 15:20:01 -0400 Subject: [PATCH] Release - v0.1.4 (#91) --- .github/workflows/test-lint-pr.yml | 74 +-- CONTRIBUTING.md | 14 +- STYLE_GUIDE.md | 2 +- pyproject.toml | 26 +- src/strands/event_loop/error_handler.py | 2 +- src/strands/handlers/tool_handler.py | 1 + src/strands/models/bedrock.py | 18 +- src/strands/models/litellm.py | 263 +-------- src/strands/models/openai.py | 123 ++++ src/strands/telemetry/tracer.py | 30 +- src/strands/tools/tools.py | 2 +- src/strands/types/guardrails.py | 4 +- src/strands/types/media.py | 2 +- src/strands/types/models/__init__.py | 6 + .../types/{models.py => models/model.py} | 6 +- src/strands/types/models/openai.py | 240 ++++++++ tests-integ/test_model_openai.py | 46 ++ tests-integ/test_stream_agent.py | 2 +- tests/strands/models/test_litellm.py | 550 ++---------------- tests/{ => strands/models}/test_llamaapi.py | 0 tests/strands/models/test_openai.py | 134 +++++ tests/strands/telemetry/test_tracer.py | 48 +- tests/strands/types/models/__init__.py | 0 tests/strands/types/models/test_model.py | 81 +++ tests/strands/types/models/test_openai.py | 332 +++++++++++ 25 files changed, 1154 insertions(+), 852 deletions(-) create mode 100644 src/strands/models/openai.py create mode 100644 src/strands/types/models/__init__.py rename src/strands/types/{models.py => models/model.py} (97%) create mode 100644 src/strands/types/models/openai.py create mode 100644 tests-integ/test_model_openai.py rename tests/{ => strands/models}/test_llamaapi.py (100%) create mode 100644 tests/strands/models/test_openai.py create mode 100644 tests/strands/types/models/__init__.py create mode 100644 tests/strands/types/models/test_model.py create mode 100644 tests/strands/types/models/test_openai.py diff --git a/.github/workflows/test-lint-pr.yml b/.github/workflows/test-lint-pr.yml index 15fbebcb..5ba62427 100644 --- a/.github/workflows/test-lint-pr.yml +++ b/.github/workflows/test-lint-pr.yml @@ -6,88 +6,49 @@ on: types: [opened, synchronize, reopened, ready_for_review, review_requested, review_request_removed] push: branches: [ main ] # Also run on direct pushes to main +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true jobs: - check-approval: - name: Check if PR has contributor approval - runs-on: ubuntu-latest - permissions: - pull-requests: read - # Skip this check for direct pushes to main - if: github.event_name == 'pull_request' - outputs: - approved: ${{ steps.check-approval.outputs.approved }} - steps: - - name: Check if PR has been approved by a contributor - id: check-approval - uses: actions/github-script@v7 - with: - script: | - const APPROVED_ASSOCIATION = ['COLLABORATOR', 'CONTRIBUTOR', 'MEMBER', 'OWNER'] - const PR_AUTHOR_ASSOCIATION = context.payload.pull_request.author_association; - const { data: reviews } = await github.rest.pulls.listReviews({ - owner: context.repo.owner, - repo: context.repo.repo, - pull_number: context.issue.number, - }); - - const isApprovedContributor = APPROVED_ASSOCIATION.includes(PR_AUTHOR_ASSOCIATION); - - // Check if any contributor has approved - const isApproved = reviews.some(review => - review.state === 'APPROVED' && APPROVED_ASSOCIATION.includes(review.author_association) - ) || isApprovedContributor; - - core.setOutput('approved', isApproved); - - if (!isApproved) { - core.notice('This PR does not have approval from a Contributor. 
Workflow will not run test jobs.'); - return false; - } - - return true; - unit-test: name: Unit Tests - Python ${{ matrix.python-version }} - ${{ matrix.os-name }} - needs: check-approval permissions: contents: read - # Only run if PR is approved or this is a direct push to main - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' strategy: matrix: include: # Linux - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.10" - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.11" - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.12" - os: ubuntu-latest - os-name: linux + os-name: 'linux' python-version: "3.13" # Windows - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.10" - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.11" - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.12" - os: windows-latest - os-name: windows + os-name: 'windows' python-version: "3.13" - # MacOS - latest only; not enough runners for MacOS + # MacOS - latest only; not enough runners for macOS - os: macos-latest - os-name: macos - python-version: "3.13" - fail-fast: false + os-name: 'macOS' + python-version: "3.13" + fail-fast: true runs-on: ${{ matrix.os }} env: LOG_LEVEL: DEBUG @@ -95,7 +56,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 with: - ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to checkout + ref: ${{ github.event.pull_request.head.sha }} # Explicitly define which commit to check out persist-credentials: false # Don't persist credentials for subsequent actions - name: Set up Python uses: actions/setup-python@v5 @@ -108,14 +69,11 @@ jobs: id: tests run: hatch test tests --cover continue-on-error: false - lint: name: Lint runs-on: ubuntu-latest - needs: check-approval permissions: contents: read - if: github.event_name == 'push' || needs.check-approval.outputs.approved == 'true' steps: - name: Checkout code uses: actions/checkout@v4 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 18087852..fa724cdd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -31,16 +31,22 @@ This project uses [hatchling](https://hatch.pypa.io/latest/build/#hatchling) as ### Setting Up Your Development Environment -1. Install development dependencies: +1. Entering virtual environment using `hatch` (recommended), then launch your IDE in the new shell. ```bash - pip install -e ".[dev]" && pip install -e ".[litellm] + hatch shell dev ``` + Alternatively, install development dependencies in a manually created virtual environment: + ```bash + pip install -e ".[dev]" && pip install -e ".[litellm]" + ``` + + 2. Set up pre-commit hooks: ```bash pre-commit install -t pre-commit -t commit-msg ``` - This will automatically run formatters and convention commit checks on your code before each commit. + This will automatically run formatters and conventional commit checks on your code before each commit. 3. Run code formatters manually: ```bash @@ -117,7 +123,7 @@ To send us a pull request, please: ## Finding contributions to work on -Looking at the existing issues is a great way to find something to contribute on. +Looking at the existing issues is a great way to find something to contribute to. 
You can check: - Our known bugs list in [Bug Reports](../../issues?q=is%3Aissue%20state%3Aopen%20label%3Abug) for issues that need fixing diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md index a50c571b..51dc0a73 100644 --- a/STYLE_GUIDE.md +++ b/STYLE_GUIDE.md @@ -26,7 +26,7 @@ logger.debug("field1=<%s>, field2=<%s>, ... | human readable message", field1, f - This is an optimization to skip string interpolation when the log level is not enabled 1. **Messages**: - - Add human readable messages at the end of the log + - Add human-readable messages at the end of the log - Use lowercase for consistency - Avoid punctuation (periods, exclamation points, etc.) to reduce clutter - Keep messages concise and focused on a single statement diff --git a/pyproject.toml b/pyproject.toml index 43130ac9..52394cea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "strands-agents" -version = "0.1.3" +version = "0.1.4" description = "A model-driven approach to building AI agents in just a few lines of code" readme = "README.md" requires-python = ">=3.10" @@ -33,9 +33,9 @@ dependencies = [ "pydantic>=2.0.0,<3.0.0", "typing-extensions>=4.13.2,<5.0.0", "watchdog>=6.0.0,<7.0.0", - "opentelemetry-api>=1.33.0,<2.0.0", - "opentelemetry-sdk>=1.33.0,<2.0.0", - "opentelemetry-exporter-otlp-proto-http>=1.33.0,<2.0.0", + "opentelemetry-api>=1.30.0,<2.0.0", + "opentelemetry-sdk>=1.30.0,<2.0.0", + "opentelemetry-exporter-otlp-proto-http>=1.30.0,<2.0.0", ] [project.urls] @@ -54,7 +54,7 @@ dev = [ "commitizen>=4.4.0,<5.0.0", "hatch>=1.0.0,<2.0.0", "moto>=5.1.0,<6.0.0", - "mypy>=0.981,<1.0.0", + "mypy>=1.15.0,<2.0.0", "pre-commit>=3.2.0,<4.2.0", "pytest>=8.0.0,<9.0.0", "pytest-asyncio>=0.26.0,<0.27.0", @@ -69,15 +69,18 @@ docs = [ litellm = [ "litellm>=1.69.0,<2.0.0", ] +llamaapi = [ + "llama-api-client>=0.1.0,<1.0.0", +] ollama = [ "ollama>=0.4.8,<1.0.0", ] -llamaapi = [ - "llama-api-client>=0.1.0,<1.0.0", +openai = [ + "openai>=1.68.0,<2.0.0", ] [tool.hatch.envs.hatch-static-analysis] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] dependencies = [ "mypy>=1.15.0,<2.0.0", "ruff>=0.11.6,<0.12.0", @@ -100,7 +103,7 @@ lint-fix = [ ] [tool.hatch.envs.hatch-test] -features = ["anthropic", "litellm", "llamaapi", "ollama"] +features = ["anthropic", "litellm", "llamaapi", "ollama", "openai"] extra-dependencies = [ "moto>=5.1.0,<6.0.0", "pytest>=8.0.0,<9.0.0", @@ -114,6 +117,11 @@ extra-args = [ "-vv", ] +[tool.hatch.envs.dev] +dev-mode = true +features = ["dev", "docs", "anthropic", "litellm", "llamaapi", "ollama"] + + [[tool.hatch.envs.hatch-test.matrix]] python = ["3.13", "3.12", "3.11", "3.10"] diff --git a/src/strands/event_loop/error_handler.py b/src/strands/event_loop/error_handler.py index 5a78bb3f..a5c85668 100644 --- a/src/strands/event_loop/error_handler.py +++ b/src/strands/event_loop/error_handler.py @@ -74,7 +74,7 @@ def handle_input_too_long_error( """Handle 'Input is too long' errors by truncating tool results. When a context window overflow exception occurs (input too long for the model), this function attempts to recover - by finding and truncating the most recent tool results in the conversation history. If trunction is successful, the + by finding and truncating the most recent tool results in the conversation history. If truncation is successful, the function will make a call to the event loop. 
Args: diff --git a/src/strands/handlers/tool_handler.py b/src/strands/handlers/tool_handler.py index 0803eca5..bc4ec1ce 100644 --- a/src/strands/handlers/tool_handler.py +++ b/src/strands/handlers/tool_handler.py @@ -46,6 +46,7 @@ def preprocess( def process( self, tool: Any, + *, model: Model, system_prompt: Optional[str], messages: List[Any], diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index 4c02156d..05d89923 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -112,7 +112,21 @@ def __init__( session = boto_session or boto3.Session( region_name=region_name or os.getenv("AWS_REGION") or "us-west-2", ) - client_config = boto_client_config or BotocoreConfig(user_agent_extra="strands-agents") + + # Add strands-agents to the request user agent + if boto_client_config: + existing_user_agent = getattr(boto_client_config, "user_agent_extra", None) + + # Append 'strands-agents' to existing user_agent_extra or set it if not present + if existing_user_agent: + new_user_agent = f"{existing_user_agent} strands-agents" + else: + new_user_agent = "strands-agents" + + client_config = boto_client_config.merge(BotocoreConfig(user_agent_extra=new_user_agent)) + else: + client_config = BotocoreConfig(user_agent_extra="strands-agents") + self.client = session.client( service_name="bedrock-runtime", config=client_config, @@ -132,7 +146,7 @@ def get_config(self) -> BedrockConfig: """Get the current Bedrock Model configuration. Returns: - The Bedrok model configuration. + The Bedrock model configuration. """ return self.config diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index a7563133..23d2c2ae 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -3,23 +3,19 @@ - Docs: https://docs.litellm.ai/ """ -import json import logging -import mimetypes -from typing import Any, Iterable, Optional, TypedDict +from typing import Any, Optional, TypedDict, cast import litellm from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages -from ..types.models import Model -from ..types.streaming import StreamEvent -from ..types.tools import ToolResult, ToolSpec, ToolUse +from ..types.content import ContentBlock +from .openai import OpenAIModel logger = logging.getLogger(__name__) -class LiteLLMModel(Model): +class LiteLLMModel(OpenAIModel): """LiteLLM model provider implementation.""" class LiteLLMConfig(TypedDict, total=False): @@ -45,7 +41,7 @@ def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: https://github.com/BerriAI/litellm/blob/main/litellm/main.py. **model_config: Configuration options for the LiteLLM model. """ - self.config = LiteLLMModel.LiteLLMConfig(**model_config) + self.config = dict(model_config) logger.debug("config=<%s> | initializing", self.config) @@ -68,9 +64,11 @@ def get_config(self) -> LiteLLMConfig: Returns: The LiteLLM model configuration. """ - return self.config + return cast(LiteLLMModel.LiteLLMConfig, self.config) - def _format_request_message_content(self, content: ContentBlock) -> dict[str, Any]: + @override + @staticmethod + def format_request_message_content(content: ContentBlock) -> dict[str, Any]: """Format a LiteLLM content block. Args: @@ -79,18 +77,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An Returns: LiteLLM formatted content block. 
""" - if "image" in content: - mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") - image_data = content["image"]["source"]["bytes"].decode("utf-8") - return { - "image_url": { - "detail": "auto", - "format": mime_type, - "url": f"data:{mime_type};base64,{image_data}", - }, - "type": "image_url", - } - if "reasoningContent" in content: return { "signature": content["reasoningContent"]["reasoningText"]["signature"], @@ -98,9 +84,6 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An "type": "thinking", } - if "text" in content: - return {"text": content["text"], "type": "text"} - if "video" in content: return { "type": "video_url", @@ -110,230 +93,4 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An }, } - return {"text": json.dumps(content), "type": "text"} - - def _format_request_message_tool_call(self, tool_use: ToolUse) -> dict[str, Any]: - """Format a LiteLLM tool call. - - Args: - tool_use: Tool use requested by the model. - - Returns: - LiteLLM formatted tool call. - """ - return { - "function": { - "arguments": json.dumps(tool_use["input"]), - "name": tool_use["name"], - }, - "id": tool_use["toolUseId"], - "type": "function", - } - - def _format_request_tool_message(self, tool_result: ToolResult) -> dict[str, Any]: - """Format a LiteLLM tool message. - - Args: - tool_result: Tool result collected from a tool execution. - - Returns: - LiteLLM formatted tool message. - """ - return { - "role": "tool", - "tool_call_id": tool_result["toolUseId"], - "content": json.dumps( - { - "content": tool_result["content"], - "status": tool_result["status"], - } - ), - } - - def _format_request_messages(self, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format a LiteLLM messages array. - - Args: - messages: List of message objects to be processed by the model. - system_prompt: System prompt to provide context to the model. - - Returns: - A LiteLLM messages array. - """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] - - for message in messages: - contents = message["content"] - - formatted_contents = [ - self._format_request_message_content(content) - for content in contents - if not any(block_type in content for block_type in ["toolResult", "toolUse"]) - ] - formatted_tool_calls = [ - self._format_request_message_tool_call(content["toolUse"]) - for content in contents - if "toolUse" in content - ] - formatted_tool_messages = [ - self._format_request_tool_message(content["toolResult"]) - for content in contents - if "toolResult" in content - ] - - formatted_message = { - "role": message["role"], - "content": formatted_contents, - **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), - } - formatted_messages.append(formatted_message) - formatted_messages.extend(formatted_tool_messages) - - return [message for message in formatted_messages if message["content"] or "tool_calls" in message] - - @override - def format_request( - self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None - ) -> dict[str, Any]: - """Format a LiteLLM chat streaming request. - - Args: - messages: List of message objects to be processed by the model. - tool_specs: List of tool specifications to make available to the model. - system_prompt: System prompt to provide context to the model. 
- - Returns: - A LiteLLM chat streaming request. - """ - return { - "messages": self._format_request_messages(messages, system_prompt), - "model": self.config["model_id"], - "stream": True, - "stream_options": {"include_usage": True}, - "tools": [ - { - "type": "function", - "function": { - "name": tool_spec["name"], - "description": tool_spec["description"], - "parameters": tool_spec["inputSchema"]["json"], - }, - } - for tool_spec in tool_specs or [] - ], - **(self.config.get("params") or {}), - } - - @override - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: - """Format the LiteLLM response events into standardized message chunks. - - Args: - event: A response event from the LiteLLM model. - - Returns: - The formatted chunk. - - Raises: - RuntimeError: If chunk_type is not recognized. - This error should never be encountered as we control chunk_type in the stream method. - """ - match event["chunk_type"]: - case "message_start": - return {"messageStart": {"role": "assistant"}} - - case "content_start": - if event["data_type"] == "tool": - return { - "contentBlockStart": { - "start": { - "toolUse": { - "name": event["data"].function.name, - "toolUseId": event["data"].id, - } - } - } - } - - return {"contentBlockStart": {"start": {}}} - - case "content_delta": - if event["data_type"] == "tool": - return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} - - return {"contentBlockDelta": {"delta": {"text": event["data"]}}} - - case "content_stop": - return {"contentBlockStop": {}} - - case "message_stop": - match event["data"]: - case "tool_calls": - return {"messageStop": {"stopReason": "tool_use"}} - case "length": - return {"messageStop": {"stopReason": "max_tokens"}} - case _: - return {"messageStop": {"stopReason": "end_turn"}} - - case "metadata": - return { - "metadata": { - "usage": { - "inputTokens": event["data"].prompt_tokens, - "outputTokens": event["data"].completion_tokens, - "totalTokens": event["data"].total_tokens, - }, - "metrics": { - "latencyMs": 0, # TODO - }, - }, - } - - case _: - raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") - - @override - def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: - """Send the request to the LiteLLM model and get the streaming response. - - Args: - request: The formatted request to send to the LiteLLM model. - - Returns: - An iterable of response events from the LiteLLM model. 
- """ - response = self.client.chat.completions.create(**request) - - yield {"chunk_type": "message_start"} - yield {"chunk_type": "content_start", "data_type": "text"} - - tool_calls: dict[int, list[Any]] = {} - - for event in response: - choice = event.choices[0] - if choice.finish_reason: - break - - if choice.delta.content: - yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} - - for tool_call in choice.delta.tool_calls or []: - tool_calls.setdefault(tool_call.index, []).append(tool_call) - - yield {"chunk_type": "content_stop", "data_type": "text"} - - for tool_deltas in tool_calls.values(): - tool_start, tool_deltas = tool_deltas[0], tool_deltas[1:] - yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_start} - - for tool_delta in tool_deltas: - yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} - - yield {"chunk_type": "content_stop", "data_type": "tool"} - - yield {"chunk_type": "message_stop", "data": choice.finish_reason} - - event = next(response) - if hasattr(event, "usage"): - yield {"chunk_type": "metadata", "data": event.usage} + return OpenAIModel.format_request_message_content(content) diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py new file mode 100644 index 00000000..6e32c5bd --- /dev/null +++ b/src/strands/models/openai.py @@ -0,0 +1,123 @@ +"""OpenAI model provider. + +- Docs: https://platform.openai.com/docs/overview +""" + +import logging +from typing import Any, Iterable, Optional, Protocol, TypedDict, cast + +import openai +from typing_extensions import Unpack, override + +from ..types.models import OpenAIModel as SAOpenAIModel + +logger = logging.getLogger(__name__) + + +class Client(Protocol): + """Protocol defining the OpenAI-compatible interface for the underlying provider client.""" + + @property + def chat(self) -> Any: + """Chat completions interface.""" + ... + + +class OpenAIModel(SAOpenAIModel): + """OpenAI model provider implementation.""" + + client: Client + + class OpenAIConfig(TypedDict, total=False): + """Configuration options for OpenAI models. + + Attributes: + model_id: Model ID (e.g., "gpt-4o"). + For a complete list of supported models, see https://platform.openai.com/docs/models. + params: Model parameters (e.g., max_tokens). + For a complete list of supported parameters, see + https://platform.openai.com/docs/api-reference/chat/create. + """ + + model_id: str + params: Optional[dict[str, Any]] + + def __init__(self, client_args: Optional[dict[str, Any]] = None, **model_config: Unpack[OpenAIConfig]) -> None: + """Initialize provider instance. + + Args: + client_args: Arguments for the OpenAI client. + For a complete list of supported arguments, see https://pypi.org/project/openai/. + **model_config: Configuration options for the OpenAI model. + """ + self.config = dict(model_config) + + logger.debug("config=<%s> | initializing", self.config) + + client_args = client_args or {} + self.client = openai.OpenAI(**client_args) + + @override + def update_config(self, **model_config: Unpack[OpenAIConfig]) -> None: # type: ignore[override] + """Update the OpenAI model configuration with the provided arguments. + + Args: + **model_config: Configuration overrides. + """ + self.config.update(model_config) + + @override + def get_config(self) -> OpenAIConfig: + """Get the OpenAI model configuration. + + Returns: + The OpenAI model configuration. 
+ """ + return cast(OpenAIModel.OpenAIConfig, self.config) + + @override + def stream(self, request: dict[str, Any]) -> Iterable[dict[str, Any]]: + """Send the request to the OpenAI model and get the streaming response. + + Args: + request: The formatted request to send to the OpenAI model. + + Returns: + An iterable of response events from the OpenAI model. + """ + response = self.client.chat.completions.create(**request) + + yield {"chunk_type": "message_start"} + yield {"chunk_type": "content_start", "data_type": "text"} + + tool_calls: dict[int, list[Any]] = {} + + for event in response: + choice = event.choices[0] + + if choice.delta.content: + yield {"chunk_type": "content_delta", "data_type": "text", "data": choice.delta.content} + + for tool_call in choice.delta.tool_calls or []: + tool_calls.setdefault(tool_call.index, []).append(tool_call) + + if choice.finish_reason: + break + + yield {"chunk_type": "content_stop", "data_type": "text"} + + for tool_deltas in tool_calls.values(): + yield {"chunk_type": "content_start", "data_type": "tool", "data": tool_deltas[0]} + + for tool_delta in tool_deltas: + yield {"chunk_type": "content_delta", "data_type": "tool", "data": tool_delta} + + yield {"chunk_type": "content_stop", "data_type": "tool"} + + yield {"chunk_type": "message_stop", "data": choice.finish_reason} + + # Skip remaining events as we don't have use for anything except the final usage payload + for event in response: + _ = event + + yield {"chunk_type": "metadata", "data": event.usage} diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index b3709a1f..3ec663ce 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -315,7 +315,7 @@ def start_model_invoke_span( "gen_ai.system": "strands-agents", "agent.name": agent_name, "gen_ai.agent.name": agent_name, - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), } if model_id: @@ -338,7 +338,7 @@ def end_model_invoke_span( error: Optional exception if the model call failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), "gen_ai.usage.prompt_tokens": usage["inputTokens"], "gen_ai.usage.completion_tokens": usage["outputTokens"], "gen_ai.usage.total_tokens": usage["totalTokens"], @@ -360,10 +360,10 @@ def start_tool_call_span( The created span, or None if tracing is not enabled. 
""" attributes: Dict[str, AttributeValue] = { - "gen_ai.prompt": json.dumps(tool, cls=JSONEncoder), + "gen_ai.prompt": serialize(tool), "tool.name": tool["name"], "tool.id": tool["toolUseId"], - "tool.parameters": json.dumps(tool["input"], cls=JSONEncoder), + "tool.parameters": serialize(tool["input"]), } # Add additional kwargs as attributes @@ -387,7 +387,7 @@ def end_tool_call_span( status = tool_result.get("status") status_str = str(status) if status is not None else "" - tool_result_content_json = json.dumps(tool_result.get("content"), cls=JSONEncoder) + tool_result_content_json = serialize(tool_result.get("content")) attributes.update( { "tool.result": tool_result_content_json, @@ -420,7 +420,7 @@ def start_event_loop_cycle_span( parent_span = parent_span if parent_span else event_loop_kwargs.get("event_loop_parent_span") attributes: Dict[str, AttributeValue] = { - "gen_ai.prompt": json.dumps(messages, cls=JSONEncoder), + "gen_ai.prompt": serialize(messages), "event_loop.cycle_id": event_loop_cycle_id, } @@ -449,11 +449,11 @@ def end_event_loop_cycle_span( error: Optional exception if the cycle failed. """ attributes: Dict[str, AttributeValue] = { - "gen_ai.completion": json.dumps(message["content"], cls=JSONEncoder), + "gen_ai.completion": serialize(message["content"]), } if tool_result_message: - attributes["tool.result"] = json.dumps(tool_result_message["content"], cls=JSONEncoder) + attributes["tool.result"] = serialize(tool_result_message["content"]) self._end_span(span, attributes, error) @@ -490,7 +490,7 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = json.dumps(tools, cls=JSONEncoder) + tools_json = serialize(tools) attributes["agent.tools"] = tools_json attributes["gen_ai.agent.tools"] = tools_json @@ -571,3 +571,15 @@ def get_tracer( ) return _tracer_instance + + +def serialize(obj: Any) -> str: + """Serialize an object to JSON with consistent settings. + + Args: + obj: The object to serialize + + Returns: + JSON string representation of the object + """ + return json.dumps(obj, ensure_ascii=False, cls=JSONEncoder) diff --git a/src/strands/tools/tools.py b/src/strands/tools/tools.py index 40565a24..7d43125b 100644 --- a/src/strands/tools/tools.py +++ b/src/strands/tools/tools.py @@ -40,7 +40,7 @@ def validate_tool_use_name(tool: ToolUse) -> None: Raises: InvalidToolUseNameException: If the tool name is invalid. """ - # We need to fix some typing here, because we dont actually expect a ToolUse, but dict[str, Any] + # We need to fix some typing here, because we don't actually expect a ToolUse, but dict[str, Any] if "name" not in tool: message = "tool name missing" # type: ignore[unreachable] logger.warning(message) diff --git a/src/strands/types/guardrails.py b/src/strands/types/guardrails.py index 6055b9ab..c15ba1be 100644 --- a/src/strands/types/guardrails.py +++ b/src/strands/types/guardrails.py @@ -16,7 +16,7 @@ class GuardrailConfig(TypedDict, total=False): Attributes: guardrailIdentifier: Unique identifier for the guardrail. guardrailVersion: Version of the guardrail to apply. - streamProcessingMode: Procesing mode. + streamProcessingMode: Processing mode. trace: The trace behavior for the guardrail. """ @@ -219,7 +219,7 @@ class GuardrailAssessment(TypedDict): contentPolicy: The content policy. contextualGroundingPolicy: The contextual grounding policy used for the guardrail assessment. sensitiveInformationPolicy: The sensitive information policy. - topicPolic: The topic policy. + topicPolicy: The topic policy. 
wordPolicy: The word policy. """ diff --git a/src/strands/types/media.py b/src/strands/types/media.py index 058a09ea..29b89e5c 100644 --- a/src/strands/types/media.py +++ b/src/strands/types/media.py @@ -68,7 +68,7 @@ class ImageContent(TypedDict): class VideoSource(TypedDict): - """Contains the content of a vidoe. + """Contains the content of a video. Attributes: bytes: The binary content of the video. diff --git a/src/strands/types/models/__init__.py b/src/strands/types/models/__init__.py new file mode 100644 index 00000000..5ce0a498 --- /dev/null +++ b/src/strands/types/models/__init__.py @@ -0,0 +1,6 @@ +"""Model-related type definitions for the SDK.""" + +from .model import Model +from .openai import OpenAIModel + +__all__ = ["Model", "OpenAIModel"] diff --git a/src/strands/types/models.py b/src/strands/types/models/model.py similarity index 97% rename from src/strands/types/models.py rename to src/strands/types/models/model.py index e3d96e29..23e74602 100644 --- a/src/strands/types/models.py +++ b/src/strands/types/models/model.py @@ -4,9 +4,9 @@ import logging from typing import Any, Iterable, Optional -from .content import Messages -from .streaming import StreamEvent -from .tools import ToolSpec +from ..content import Messages +from ..streaming import StreamEvent +from ..tools import ToolSpec logger = logging.getLogger(__name__) diff --git a/src/strands/types/models/openai.py b/src/strands/types/models/openai.py new file mode 100644 index 00000000..8f5ffab3 --- /dev/null +++ b/src/strands/types/models/openai.py @@ -0,0 +1,240 @@ +"""Base OpenAI model provider. + +This module provides the base OpenAI model provider class which implements shared logic for formatting requests and +responses to and from the OpenAI specification. + +- Docs: https://pypi.org/project/openai +""" + +import abc +import json +import logging +import mimetypes +from typing import Any, Optional + +from typing_extensions import override + +from ..content import ContentBlock, Messages +from ..streaming import StreamEvent +from ..tools import ToolResult, ToolSpec, ToolUse +from .model import Model + +logger = logging.getLogger(__name__) + + +class OpenAIModel(Model, abc.ABC): + """Base OpenAI model provider implementation. + + Implements shared logic for formatting requests and responses to and from the OpenAI specification. + """ + + config: dict[str, Any] + + @staticmethod + def format_request_message_content(content: ContentBlock) -> dict[str, Any]: + """Format an OpenAI compatible content block. + + Args: + content: Message content. + + Returns: + OpenAI compatible content block. + """ + if "image" in content: + mime_type = mimetypes.types_map.get(f".{content['image']['format']}", "application/octet-stream") + image_data = content["image"]["source"]["bytes"].decode("utf-8") + return { + "image_url": { + "detail": "auto", + "format": mime_type, + "url": f"data:{mime_type};base64,{image_data}", + }, + "type": "image_url", + } + + if "text" in content: + return {"text": content["text"], "type": "text"} + + return {"text": json.dumps(content), "type": "text"} + + @staticmethod + def format_request_message_tool_call(tool_use: ToolUse) -> dict[str, Any]: + """Format an OpenAI compatible tool call. + + Args: + tool_use: Tool use requested by the model. + + Returns: + OpenAI compatible tool call. 
+ """ + return { + "function": { + "arguments": json.dumps(tool_use["input"]), + "name": tool_use["name"], + }, + "id": tool_use["toolUseId"], + "type": "function", + } + + @staticmethod + def format_request_tool_message(tool_result: ToolResult) -> dict[str, Any]: + """Format an OpenAI compatible tool message. + + Args: + tool_result: Tool result collected from a tool execution. + + Returns: + OpenAI compatible tool message. + """ + return { + "role": "tool", + "tool_call_id": tool_result["toolUseId"], + "content": json.dumps( + { + "content": tool_result["content"], + "status": tool_result["status"], + } + ), + } + + @classmethod + def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages: list[dict[str, Any]] + formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + + for message in messages: + contents = message["content"] + + formatted_contents = [ + cls.format_request_message_content(content) + for content in contents + if not any(block_type in content for block_type in ["toolResult", "toolUse"]) + ] + formatted_tool_calls = [ + cls.format_request_message_tool_call(content["toolUse"]) for content in contents if "toolUse" in content + ] + formatted_tool_messages = [ + cls.format_request_tool_message(content["toolResult"]) + for content in contents + if "toolResult" in content + ] + + formatted_message = { + "role": message["role"], + "content": formatted_contents, + **({"tool_calls": formatted_tool_calls} if formatted_tool_calls else {}), + } + formatted_messages.append(formatted_message) + formatted_messages.extend(formatted_tool_messages) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_request( + self, messages: Messages, tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None + ) -> dict[str, Any]: + """Format an OpenAI compatible chat streaming request. + + Args: + messages: List of message objects to be processed by the model. + tool_specs: List of tool specifications to make available to the model. + system_prompt: System prompt to provide context to the model. + + Returns: + An OpenAI compatible chat streaming request. + """ + return { + "messages": self.format_request_messages(messages, system_prompt), + "model": self.config["model_id"], + "stream": True, + "stream_options": {"include_usage": True}, + "tools": [ + { + "type": "function", + "function": { + "name": tool_spec["name"], + "description": tool_spec["description"], + "parameters": tool_spec["inputSchema"]["json"], + }, + } + for tool_spec in tool_specs or [] + ], + **(self.config.get("params") or {}), + } + + @override + def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + """Format an OpenAI response event into a standardized message chunk. + + Args: + event: A response event from the OpenAI compatible model. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + This error should never be encountered as chunk_type is controlled in the stream method. 
+ """ + match event["chunk_type"]: + case "message_start": + return {"messageStart": {"role": "assistant"}} + + case "content_start": + if event["data_type"] == "tool": + return { + "contentBlockStart": { + "start": { + "toolUse": { + "name": event["data"].function.name, + "toolUseId": event["data"].id, + } + } + } + } + + return {"contentBlockStart": {"start": {}}} + + case "content_delta": + if event["data_type"] == "tool": + return {"contentBlockDelta": {"delta": {"toolUse": {"input": event["data"].function.arguments}}}} + + return {"contentBlockDelta": {"delta": {"text": event["data"]}}} + + case "content_stop": + return {"contentBlockStop": {}} + + case "message_stop": + match event["data"]: + case "tool_calls": + return {"messageStop": {"stopReason": "tool_use"}} + case "length": + return {"messageStop": {"stopReason": "max_tokens"}} + case _: + return {"messageStop": {"stopReason": "end_turn"}} + + case "metadata": + return { + "metadata": { + "usage": { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + }, + "metrics": { + "latencyMs": 0, # TODO + }, + }, + } + + case _: + raise RuntimeError(f"chunk_type=<{event['chunk_type']} | unknown type") diff --git a/tests-integ/test_model_openai.py b/tests-integ/test_model_openai.py new file mode 100644 index 00000000..c9046ad5 --- /dev/null +++ b/tests-integ/test_model_openai.py @@ -0,0 +1,46 @@ +import os + +import pytest + +import strands +from strands import Agent +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def model(): + return OpenAIModel( + model_id="gpt-4o", + client_args={ + "api_key": os.getenv("OPENAI_API_KEY"), + }, + ) + + +@pytest.fixture +def tools(): + @strands.tool + def tool_time() -> str: + return "12:00" + + @strands.tool + def tool_weather() -> str: + return "sunny" + + return [tool_time, tool_weather] + + +@pytest.fixture +def agent(model, tools): + return Agent(model=model, tools=tools) + + +@pytest.mark.skipif( + "OPENAI_API_KEY" not in os.environ, + reason="OPENAI_API_KEY environment variable missing", +) +def test_agent(agent): + result = agent("What is the time and weather in New York?") + text = result.message["content"][0]["text"].lower() + + assert all(string in text for string in ["12:00", "sunny"]) diff --git a/tests-integ/test_stream_agent.py b/tests-integ/test_stream_agent.py index 4c97db6b..01f20339 100644 --- a/tests-integ/test_stream_agent.py +++ b/tests-integ/test_stream_agent.py @@ -1,5 +1,5 @@ """ -Test script for Strands's custom callback handler functionality. +Test script for Strands' custom callback handler functionality. Demonstrates different patterns of callback handling and processing. 
""" diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 5d4d9b40..528d1498 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -1,4 +1,3 @@ -import json import unittest.mock import pytest @@ -8,9 +7,14 @@ @pytest.fixture -def litellm_client(): +def litellm_client_cls(): with unittest.mock.patch.object(strands.models.litellm.litellm, "LiteLLM") as mock_client_cls: - yield mock_client_cls.return_value + yield mock_client_cls + + +@pytest.fixture +def litellm_client(litellm_client_cls): + return litellm_client_cls.return_value @pytest.fixture @@ -35,15 +39,15 @@ def system_prompt(): return "s1" -def test__init__model_configs(litellm_client, model_id): - _ = litellm_client +def test__init__(litellm_client_cls, model_id): + model = LiteLLMModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) - model = LiteLLMModel(model_id=model_id, params={"max_tokens": 1}) + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} - tru_max_tokens = model.get_config().get("params") - exp_max_tokens = {"max_tokens": 1} + assert tru_config == exp_config - assert tru_max_tokens == exp_max_tokens + litellm_client_cls.assert_called_once_with(api_key="k1") def test_update_config(model, model_id): @@ -55,513 +59,47 @@ def test_update_config(model, model_id): assert tru_model_id == exp_model_id -def test_format_request_default(model, messages, model_id): - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_params(model, messages, model_id): - model.update_config(params={"max_tokens": 1}) - - tru_request = model.format_request(messages) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - "max_tokens": 1, - } - - assert tru_request == exp_request - - -def test_format_request_with_system_prompt(model, messages, model_id, system_prompt): - tru_request = model.format_request(messages, system_prompt=system_prompt) - exp_request = { - "messages": [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": [{"type": "text", "text": "test"}]}, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_image(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "image": { - "format": "jpg", - "source": {"bytes": b"base64encodedimage"}, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "image_url": { - "detail": "auto", - "format": "image/jpeg", - "url": "data:image/jpeg;base64,base64encodedimage", - }, - "type": "image_url", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_reasoning(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "reasoningContent": { - "reasoningText": { - "signature": "reasoning_signature", - "text": 
"reasoning_text", - }, - }, - }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ +@pytest.mark.parametrize( + "content, exp_result", + [ + # Case 1: Thinking + ( { - "role": "user", - "content": [ - { + "reasoningContent": { + "reasoningText": { "signature": "reasoning_signature", - "thinking": "reasoning_text", - "type": "thinking", - }, - ], - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_video(model, model_id): - messages = [ - { - "role": "user", - "content": [ - { - "video": { - "source": {"bytes": "base64encodedvideo"}, + "text": "reasoning_text", }, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "role": "user", - "content": [ - { - "type": "video_url", - "video_url": { - "detail": "auto", - "url": "base64encodedvideo", - }, - }, - ], }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_other(model, model_id): - messages = [ - { - "role": "user", - "content": [{"other": {"a": 1}}], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ { - "role": "user", - "content": [ - { - "text": json.dumps({"other": {"a": 1}}), - "type": "text", - }, - ], + "signature": "reasoning_signature", + "thinking": "reasoning_text", + "type": "thinking", }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_result(model, model_id): - messages = [ - { - "role": "user", - "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}], - } - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ + ), + # Case 2: Video + ( { - "content": json.dumps( - { - "content": [{"value": 4}], - "status": "success", - } - ), - "role": "tool", - "tool_call_id": "c1", - }, - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_use(model, model_id): - messages = [ - { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "c1", - "name": "calculator", - "input": {"expression": "2+2"}, - }, + "video": { + "source": {"bytes": "base64encodedvideo"}, }, - ], - }, - ] - - tru_request = model.format_request(messages) - exp_request = { - "messages": [ - { - "content": [], - "role": "assistant", - "tool_calls": [ - { - "function": { - "name": "calculator", - "arguments": '{"expression": "2+2"}', - }, - "id": "c1", - "type": "function", - } - ], - } - ], - "model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [], - } - - assert tru_request == exp_request - - -def test_format_request_with_tool_specs(model, messages, model_id): - tool_specs = [ - { - "name": "calculator", - "description": "Calculate mathematical expressions", - "inputSchema": { - "json": {"type": "object", "properties": {"expression": {"type": "string"}}, "required": ["expression"]} }, - } - ] - - tru_request = model.format_request(messages, tool_specs) - exp_request = { - "messages": [{"role": "user", "content": [{"type": "text", "text": "test"}]}], - 
"model": model_id, - "stream": True, - "stream_options": { - "include_usage": True, - }, - "tools": [ { - "type": "function", - "function": { - "name": "calculator", - "description": "Calculate mathematical expressions", - "parameters": { - "type": "object", - "properties": {"expression": {"type": "string"}}, - "required": ["expression"], - }, + "type": "video_url", + "video_url": { + "detail": "auto", + "url": "base64encodedvideo", }, - } - ], - } - - assert tru_request == exp_request - - -def test_format_chunk_message_start(model): - event = {"chunk_type": "message_start"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStart": {"role": "assistant"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_text(model): - event = {"chunk_type": "content_start", "data_type": "text"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_start_tool(model): - mock_tool_use = unittest.mock.Mock() - mock_tool_use.function.name = "calculator" - mock_tool_use.id = "c1" - - event = {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_use} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_text(model): - event = {"chunk_type": "content_delta", "data_type": "text", "data": "Hello"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_delta_tool(model): - event = { - "chunk_type": "content_delta", - "data_type": "tool", - "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_content_stop(model): - event = {"chunk_type": "content_stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"contentBlockStop": {}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_end_turn(model): - event = {"chunk_type": "message_stop", "data": "stop"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "end_turn"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_tool_use(model): - event = {"chunk_type": "message_stop", "data": "tool_calls"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "tool_use"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_message_stop_max_tokens(model): - event = {"chunk_type": "message_stop", "data": "length"} - - tru_chunk = model.format_chunk(event) - exp_chunk = {"messageStop": {"stopReason": "max_tokens"}} - - assert tru_chunk == exp_chunk - - -def test_format_chunk_metadata(model): - event = { - "chunk_type": "metadata", - "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150), - } - - tru_chunk = model.format_chunk(event) - exp_chunk = { - "metadata": { - "usage": { - "inputTokens": 100, - "outputTokens": 50, - "totalTokens": 150, - }, - "metrics": { - "latencyMs": 0, }, - }, - } - - assert tru_chunk == exp_chunk - - -def test_format_chunk_other(model): - event = {"chunk_type": "other"} - - with pytest.raises(RuntimeError, 
match="chunk_type= | unknown type"): - model.format_chunk(event) - - -def test_stream(litellm_client, model): - mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) - mock_delta_1 = unittest.mock.Mock( - content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] - ) - - mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) - mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) - mock_delta_2 = unittest.mock.Mock( - content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] - ) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) - mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls")]) - mock_event_4 = unittest.mock.Mock() - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) - - request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, - {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, - {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, - {"chunk_type": "content_stop", "data_type": "tool"}, - {"chunk_type": "message_stop", "data": "tool_calls"}, - {"chunk_type": "metadata", "data": mock_event_4.usage}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) - - -def test_stream_empty(litellm_client, model): - mock_delta = unittest.mock.Mock(content=None, tool_calls=None) - - mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) - mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop")]) - mock_event_3 = unittest.mock.Mock(spec=[]) - - litellm_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3]) - - request = {"model": "m1", "messages": [{"role": "user", "content": []}]} - response = model.stream(request) - - tru_events = list(response) - exp_events = [ - {"chunk_type": "message_start"}, - {"chunk_type": "content_start", "data_type": "text"}, - {"chunk_type": "content_stop", "data_type": "text"}, - {"chunk_type": "message_stop", "data": "stop"}, - ] - - assert tru_events == exp_events - litellm_client.chat.completions.create.assert_called_once_with(**request) + ), + # Case 3: Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = LiteLLMModel.format_request_message_content(content) + assert tru_result == exp_result diff --git a/tests/test_llamaapi.py b/tests/strands/models/test_llamaapi.py similarity 
index 100% rename from tests/test_llamaapi.py rename to tests/strands/models/test_llamaapi.py diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py new file mode 100644 index 00000000..89aa591f --- /dev/null +++ b/tests/strands/models/test_openai.py @@ -0,0 +1,134 @@ +import unittest.mock + +import pytest + +import strands +from strands.models.openai import OpenAIModel + + +@pytest.fixture +def openai_client_cls(): + with unittest.mock.patch.object(strands.models.openai.openai, "OpenAI") as mock_client_cls: + yield mock_client_cls + + +@pytest.fixture +def openai_client(openai_client_cls): + return openai_client_cls.return_value + + +@pytest.fixture +def model_id(): + return "m1" + + +@pytest.fixture +def model(openai_client, model_id): + _ = openai_client + + return OpenAIModel(model_id=model_id) + + +@pytest.fixture +def messages(): + return [{"role": "user", "content": [{"text": "test"}]}] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test__init__(openai_client_cls, model_id): + model = OpenAIModel({"api_key": "k1"}, model_id=model_id, params={"max_tokens": 1}) + + tru_config = model.get_config() + exp_config = {"model_id": "m1", "params": {"max_tokens": 1}} + + assert tru_config == exp_config + + openai_client_cls.assert_called_once_with(api_key="k1") + + +def test_update_config(model, model_id): + model.update_config(model_id=model_id) + + tru_model_id = model.get_config().get("model_id") + exp_model_id = model_id + + assert tru_model_id == exp_model_id + + +def test_stream(openai_client, model): + mock_tool_call_1_part_1 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_1 = unittest.mock.Mock(index=1) + mock_delta_1 = unittest.mock.Mock( + content="I'll calculate", tool_calls=[mock_tool_call_1_part_1, mock_tool_call_2_part_1] + ) + + mock_tool_call_1_part_2 = unittest.mock.Mock(index=0) + mock_tool_call_2_part_2 = unittest.mock.Mock(index=1) + mock_delta_2 = unittest.mock.Mock( + content="that for you", tool_calls=[mock_tool_call_1_part_2, mock_tool_call_2_part_2] + ) + + mock_delta_3 = unittest.mock.Mock(content="", tool_calls=None) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_1)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_2)]) + mock_event_3 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_3)]) + mock_event_4 = unittest.mock.Mock() + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": [{"type": "text", "text": "calculate 2+2"}]}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "I'll calculate"}, + {"chunk_type": "content_delta", "data_type": "text", "data": "that for you"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_1_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "content_start", "data_type": "tool", "data": mock_tool_call_2_part_1}, + 
{"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_1}, + {"chunk_type": "content_delta", "data_type": "tool", "data": mock_tool_call_2_part_2}, + {"chunk_type": "content_stop", "data_type": "tool"}, + {"chunk_type": "message_stop", "data": "tool_calls"}, + {"chunk_type": "metadata", "data": mock_event_4.usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) + + +def test_stream_empty(openai_client, model): + mock_delta = unittest.mock.Mock(content=None, tool_calls=None) + mock_usage = unittest.mock.Mock(prompt_tokens=0, completion_tokens=0, total_tokens=0) + + mock_event_1 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta)]) + mock_event_2 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="stop", delta=mock_delta)]) + mock_event_3 = unittest.mock.Mock() + mock_event_4 = unittest.mock.Mock(usage=mock_usage) + + openai_client.chat.completions.create.return_value = iter([mock_event_1, mock_event_2, mock_event_3, mock_event_4]) + + request = {"model": "m1", "messages": [{"role": "user", "content": []}]} + response = model.stream(request) + + tru_events = list(response) + exp_events = [ + {"chunk_type": "message_start"}, + {"chunk_type": "content_start", "data_type": "text"}, + {"chunk_type": "content_stop", "data_type": "text"}, + {"chunk_type": "message_stop", "data": "stop"}, + {"chunk_type": "metadata", "data": mock_usage}, + ] + + assert tru_events == exp_events + openai_client.chat.completions.create.assert_called_once_with(**request) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index fbbfa89b..128b4f94 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -6,7 +6,7 @@ import pytest from opentelemetry.trace import StatusCode # type: ignore -from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer +from strands.telemetry.tracer import JSONEncoder, Tracer, get_tracer, serialize from strands.types.streaming import Usage @@ -635,3 +635,49 @@ def test_json_encoder_value_error(): # Test just the value result = json.loads(encoder.encode(huge_number)) assert result == "" + + +def test_serialize_non_ascii_characters(): + """Test that non-ASCII characters are preserved in JSON serialization.""" + + # Test with Japanese text + japanese_text = "こんにちは世界" + result = serialize({"text": japanese_text}) + assert japanese_text in result + assert "\\u" not in result + + # Test with emoji + emoji_text = "Hello 🌍" + result = serialize({"text": emoji_text}) + assert emoji_text in result + assert "\\u" not in result + + # Test with Chinese characters + chinese_text = "你好,世界" + result = serialize({"text": chinese_text}) + assert chinese_text in result + assert "\\u" not in result + + # Test with mixed content + mixed_text = {"ja": "こんにちは", "emoji": "😊", "zh": "你好", "en": "hello"} + result = serialize(mixed_text) + assert "こんにちは" in result + assert "😊" in result + assert "你好" in result + assert "\\u" not in result + + +def test_serialize_vs_json_dumps(): + """Test that serialize behaves differently from default json.dumps for non-ASCII characters.""" + + # Test with Japanese text + japanese_text = "こんにちは世界" + + # Default json.dumps should escape non-ASCII characters + default_result = json.dumps({"text": japanese_text}) + assert "\\u" in default_result + + # Our serialize function should preserve non-ASCII characters + custom_result = serialize({"text": 
japanese_text}) + assert japanese_text in custom_result + assert "\\u" not in custom_result diff --git a/tests/strands/types/models/__init__.py b/tests/strands/types/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/strands/types/models/test_model.py b/tests/strands/types/models/test_model.py new file mode 100644 index 00000000..f2797fe5 --- /dev/null +++ b/tests/strands/types/models/test_model.py @@ -0,0 +1,81 @@ +import pytest + +from strands.types.models import Model as SAModel + + +class TestModel(SAModel): + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def format_request(self, messages, tool_specs, system_prompt): + return { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + } + + def format_chunk(self, event): + return {"event": event} + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +def test_converse(model, messages, tool_specs, system_prompt): + response = model.converse(messages, tool_specs, system_prompt) + + tru_events = list(response) + exp_events = [ + { + "event": { + "request": { + "messages": messages, + "tool_specs": tool_specs, + "system_prompt": system_prompt, + }, + }, + }, + ] + assert tru_events == exp_events diff --git a/tests/strands/types/models/test_openai.py b/tests/strands/types/models/test_openai.py new file mode 100644 index 00000000..97a0882a --- /dev/null +++ b/tests/strands/types/models/test_openai.py @@ -0,0 +1,332 @@ +import json +import unittest.mock + +import pytest + +from strands.types.models import OpenAIModel as SAOpenAIModel + + +class TestOpenAIModel(SAOpenAIModel): + def __init__(self): + self.config = {"model_id": "m1", "params": {"max_tokens": 1}} + + def update_config(self, **model_config): + return model_config + + def get_config(self): + return + + def stream(self, request): + yield {"request": request} + + +@pytest.fixture +def model(): + return TestOpenAIModel() + + +@pytest.fixture +def messages(): + return [ + { + "role": "user", + "content": [{"text": "hello"}], + }, + ] + + +@pytest.fixture +def tool_specs(): + return [ + { + "name": "test_tool", + "description": "A test tool", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "input": {"type": "string"}, + }, + "required": ["input"], + }, + }, + }, + ] + + +@pytest.fixture +def system_prompt(): + return "s1" + + +@pytest.mark.parametrize( + "content, exp_result", + [ + # Case 1: Image + ( + { + "image": { + "format": "jpg", + "source": {"bytes": b"image"}, + }, + }, + { + "image_url": { + "detail": "auto", + "format": "image/jpeg", + "url": "data:image/jpeg;base64,image", + }, + "type": "image_url", + }, + ), + # Case 2: Text + ( + {"text": "hello"}, + {"type": "text", "text": "hello"}, + ), + # Case 3: Other + ( + {"other": {"a": 1}}, + { + "text": json.dumps({"other": {"a": 1}}), + "type": "text", + }, + ), + ], +) +def test_format_request_message_content(content, exp_result): + tru_result = SAOpenAIModel.format_request_message_content(content) 
+    assert tru_result == exp_result
+
+
+def test_format_request_message_tool_call():
+    tool_use = {
+        "input": {"expression": "2+2"},
+        "name": "calculator",
+        "toolUseId": "c1",
+    }
+
+    tru_result = SAOpenAIModel.format_request_message_tool_call(tool_use)
+    exp_result = {
+        "function": {
+            "arguments": '{"expression": "2+2"}',
+            "name": "calculator",
+        },
+        "id": "c1",
+        "type": "function",
+    }
+    assert tru_result == exp_result
+
+
+def test_format_request_tool_message():
+    tool_result = {
+        "content": [{"value": 4}],
+        "status": "success",
+        "toolUseId": "c1",
+    }
+
+    tru_result = SAOpenAIModel.format_request_tool_message(tool_result)
+    exp_result = {
+        "content": json.dumps(
+            {
+                "content": [{"value": 4}],
+                "status": "success",
+            }
+        ),
+        "role": "tool",
+        "tool_call_id": "c1",
+    }
+    assert tru_result == exp_result
+
+
+def test_format_request_messages(system_prompt):
+    messages = [
+        {
+            "content": [],
+            "role": "user",
+        },
+        {
+            "content": [{"text": "hello"}],
+            "role": "user",
+        },
+        {
+            "content": [
+                {"text": "call tool"},
+                {
+                    "toolUse": {
+                        "input": {"expression": "2+2"},
+                        "name": "calculator",
+                        "toolUseId": "c1",
+                    },
+                },
+            ],
+            "role": "assistant",
+        },
+        {
+            "content": [{"toolResult": {"toolUseId": "c1", "status": "success", "content": [{"value": 4}]}}],
+            "role": "user",
+        },
+    ]
+
+    tru_result = SAOpenAIModel.format_request_messages(messages, system_prompt)
+    exp_result = [
+        {
+            "content": system_prompt,
+            "role": "system",
+        },
+        {
+            "content": [{"text": "hello", "type": "text"}],
+            "role": "user",
+        },
+        {
+            "content": [{"text": "call tool", "type": "text"}],
+            "role": "assistant",
+            "tool_calls": [
+                {
+                    "function": {
+                        "name": "calculator",
+                        "arguments": '{"expression": "2+2"}',
+                    },
+                    "id": "c1",
+                    "type": "function",
+                }
+            ],
+        },
+        {
+            "content": json.dumps(
+                {
+                    "content": [{"value": 4}],
+                    "status": "success",
+                }
+            ),
+            "role": "tool",
+            "tool_call_id": "c1",
+        },
+    ]
+    assert tru_result == exp_result
+
+
+def test_format_request(model, messages, tool_specs, system_prompt):
+    tru_request = model.format_request(messages, tool_specs, system_prompt)
+    exp_request = {
+        "messages": [
+            {
+                "content": system_prompt,
+                "role": "system",
+            },
+            {
+                "content": [{"text": "hello", "type": "text"}],
+                "role": "user",
+            },
+        ],
+        "model": "m1",
+        "stream": True,
+        "stream_options": {"include_usage": True},
+        "tools": [
+            {
+                "function": {
+                    "description": "A test tool",
+                    "name": "test_tool",
+                    "parameters": {
+                        "properties": {
+                            "input": {"type": "string"},
+                        },
+                        "required": ["input"],
+                        "type": "object",
+                    },
+                },
+                "type": "function",
+            },
+        ],
+        "max_tokens": 1,
+    }
+    assert tru_request == exp_request
+
+
+@pytest.mark.parametrize(
+    ("event", "exp_chunk"),
+    [
+        # Case 1: Message start
+        (
+            {"chunk_type": "message_start"},
+            {"messageStart": {"role": "assistant"}},
+        ),
+        # Case 2: Content Start - Tool Use
+        (
+            {
+                "chunk_type": "content_start",
+                "data_type": "tool",
+                "data": unittest.mock.Mock(**{"function.name": "calculator", "id": "c1"}),
+            },
+            {"contentBlockStart": {"start": {"toolUse": {"name": "calculator", "toolUseId": "c1"}}}},
+        ),
+        # Case 3: Content Start - Text
+        (
+            {"chunk_type": "content_start", "data_type": "text"},
+            {"contentBlockStart": {"start": {}}},
+        ),
+        # Case 4: Content Delta - Tool Use
+        (
+            {
+                "chunk_type": "content_delta",
+                "data_type": "tool",
+                "data": unittest.mock.Mock(function=unittest.mock.Mock(arguments='{"expression": "2+2"}')),
+            },
+            {"contentBlockDelta": {"delta": {"toolUse": {"input": '{"expression": "2+2"}'}}}},
+        ),
+        # Case 5: Content Delta - Text
+        (
+            {"chunk_type": "content_delta", "data_type": "text", "data": "hello"},
+            {"contentBlockDelta": {"delta": {"text": "hello"}}},
+        ),
+        # Case 6: Content Stop
+        (
+            {"chunk_type": "content_stop"},
+            {"contentBlockStop": {}},
+        ),
+        # Case 7: Message Stop - Tool Use
+        (
+            {"chunk_type": "message_stop", "data": "tool_calls"},
+            {"messageStop": {"stopReason": "tool_use"}},
+        ),
+        # Case 8: Message Stop - Max Tokens
+        (
+            {"chunk_type": "message_stop", "data": "length"},
+            {"messageStop": {"stopReason": "max_tokens"}},
+        ),
+        # Case 9: Message Stop - End Turn
+        (
+            {"chunk_type": "message_stop", "data": "stop"},
+            {"messageStop": {"stopReason": "end_turn"}},
+        ),
+        # Case 10: Metadata
+        (
+            {
+                "chunk_type": "metadata",
+                "data": unittest.mock.Mock(prompt_tokens=100, completion_tokens=50, total_tokens=150),
+            },
+            {
+                "metadata": {
+                    "usage": {
+                        "inputTokens": 100,
+                        "outputTokens": 50,
+                        "totalTokens": 150,
+                    },
+                    "metrics": {
+                        "latencyMs": 0,
+                    },
+                },
+            },
+        ),
+    ],
+)
+def test_format_chunk(event, exp_chunk, model):
+    tru_chunk = model.format_chunk(event)
+    assert tru_chunk == exp_chunk
+
+
+def test_format_chunk_unknown_type(model):
+    event = {"chunk_type": "unknown"}
+
+    with pytest.raises(RuntimeError, match="chunk_type=<unknown> | unknown type"):
+        model.format_chunk(event)