# adk-python/src/google/adk/models/anthropic_llm.py

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Anthropic integration for Claude models."""
from __future__ import annotations
from functools import cached_property
import logging
import os
from typing import Any
from typing import AsyncGenerator
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional, Union
from typing import TYPE_CHECKING
from anthropic import AnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override
from .base_llm import BaseLlm
from .llm_response import LlmResponse
if TYPE_CHECKING:
from .llm_request import LlmRequest
__all__ = ["Claude"]
logger = logging.getLogger(__name__)
MAX_TOKEN = 1024


class ClaudeRequest(BaseModel):
  """Intermediate request payload for the Anthropic Messages API."""

  system_instruction: str
  messages: Iterable[anthropic_types.MessageParam]
  tools: list[anthropic_types.ToolParam]


def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
  """Maps a google.genai content role to an Anthropic message role."""
  if role in ["model", "assistant"]:
    return "assistant"
  return "user"


def to_google_genai_finish_reason(
    anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
  """Maps an Anthropic stop reason to a google.genai finish reason."""
  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
    return "STOP"
  if anthropic_stop_reason == "max_tokens":
    return "MAX_TOKENS"
  return "FINISH_REASON_UNSPECIFIED"


def part_to_message_block(
    part: types.Part,
) -> Union[
    anthropic_types.TextBlockParam,
    anthropic_types.ImageBlockParam,
    anthropic_types.ToolUseBlockParam,
    anthropic_types.ToolResultBlockParam,
]:
  """Converts a google.genai Part to an Anthropic content block."""
  if part.text:
    return anthropic_types.TextBlockParam(text=part.text, type="text")
  if part.function_call:
    assert part.function_call.name
    return anthropic_types.ToolUseBlockParam(
        id=part.function_call.id or "",
        name=part.function_call.name,
        input=part.function_call.args,
        type="tool_use",
    )
  if part.function_response:
    content = ""
    if (
        "result" in part.function_response.response
        and part.function_response.response["result"]
    ):
      # Transformation is required because the content is a list of dicts,
      # and ToolResultBlockParam content doesn't support a list of dicts.
      # Converting to str prevents anthropic.BadRequestError from being raised.
      content = str(part.function_response.response["result"])
    return anthropic_types.ToolResultBlockParam(
        tool_use_id=part.function_response.id or "",
        type="tool_result",
        content=content,
        is_error=False,
    )
  raise NotImplementedError("Not supported yet.")
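

# Example of the function_call branch (illustrative values; the id falls back
# to "" when the genai part carries none):
#
#   >>> part = types.Part.from_function_call(
#   ...     name="get_weather", args={"city": "Paris"}
#   ... )
#   >>> part_to_message_block(part)
#   {'id': '', 'name': 'get_weather', 'input': {'city': 'Paris'}, 'type': 'tool_use'}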


def content_to_message_param(
    content: types.Content,
) -> anthropic_types.MessageParam:
  """Converts a google.genai Content to an Anthropic MessageParam."""
  return {
      "role": to_claude_role(content.role),
      "content": [part_to_message_block(part) for part in content.parts or []],
  }
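

# For instance (illustrative), a plain user text Content flattens to:
#
#   >>> content = types.Content(role="user", parts=[types.Part.from_text(text="Hi")])
#   >>> content_to_message_param(content)
#   {'role': 'user', 'content': [{'text': 'Hi', 'type': 'text'}]}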


def content_block_to_part(
    content_block: anthropic_types.ContentBlock,
) -> types.Part:
  """Converts an Anthropic content block back to a google.genai Part."""
  if isinstance(content_block, anthropic_types.TextBlock):
    return types.Part.from_text(text=content_block.text)
  if isinstance(content_block, anthropic_types.ToolUseBlock):
    assert isinstance(content_block.input, dict)
    part = types.Part.from_function_call(
        name=content_block.name, args=content_block.input
    )
    part.function_call.id = content_block.id
    return part
  raise NotImplementedError("Not supported yet.")


def message_to_generate_content_response(
    message: anthropic_types.Message,
) -> LlmResponse:
  """Builds an LlmResponse from an Anthropic Message."""
  return LlmResponse(
      content=types.Content(
          role="model",
          parts=[content_block_to_part(cb) for cb in message.content],
      ),
      # TODO: Deal with these later.
      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
      # usage_metadata=types.GenerateContentResponseUsageMetadata(
      #     prompt_token_count=message.usage.input_tokens,
      #     candidates_token_count=message.usage.output_tokens,
      #     total_token_count=(
      #         message.usage.input_tokens + message.usage.output_tokens
      #     ),
      # ),
  )


def _update_type_string(value_dict: dict[str, Any]):
  """Updates the 'type' field to the expected JSON schema format."""
  if "type" in value_dict:
    value_dict["type"] = value_dict["type"].lower()

  if "items" in value_dict:
    # A 'type' field can exist on items as well; this is the case when the
    # items represent primitive types.
    _update_type_string(value_dict["items"])

    if "properties" in value_dict["items"]:
      # The items can also carry properties, especially when the items are
      # complex objects themselves. We recursively traverse each individual
      # property and fix its "type" value.
      for value in value_dict["items"]["properties"].values():
        _update_type_string(value)
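

# Example: google.genai schemas carry upper-case type enums, while JSON Schema
# (and the Anthropic API) expects lower-case, so the helper rewrites the dict
# in place, recursing through array items:
#
#   >>> d = {"type": "ARRAY", "items": {"type": "STRING"}}
#   >>> _update_type_string(d)
#   >>> d
#   {'type': 'array', 'items': {'type': 'string'}}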


def function_declaration_to_tool_param(
    function_declaration: types.FunctionDeclaration,
) -> anthropic_types.ToolParam:
  """Converts a google.genai FunctionDeclaration to an Anthropic ToolParam."""
  assert function_declaration.name

  properties = {}
  if (
      function_declaration.parameters
      and function_declaration.parameters.properties
  ):
    for key, value in function_declaration.parameters.properties.items():
      value_dict = value.model_dump(exclude_none=True)
      _update_type_string(value_dict)
      properties[key] = value_dict

  return anthropic_types.ToolParam(
      name=function_declaration.name,
      description=function_declaration.description or "",
      input_schema={
          "type": "object",
          "properties": properties,
      },
  )
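

# Example (illustrative): a declaration with a single string parameter yields
# an Anthropic-style JSON schema:
#
#   >>> decl = types.FunctionDeclaration(
#   ...     name="get_weather",
#   ...     description="Returns the weather for a city.",
#   ...     parameters=types.Schema(
#   ...         type=types.Type.OBJECT,
#   ...         properties={"city": types.Schema(type=types.Type.STRING)},
#   ...     ),
#   ... )
#   >>> function_declaration_to_tool_param(decl)["input_schema"]
#   {'type': 'object', 'properties': {'city': {'type': 'string'}}}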


class Claude(BaseLlm):
  """Claude on Vertex AI, exposed through the ADK BaseLlm interface."""

  model: str = "claude-3-5-sonnet-v2@20241022"

  @staticmethod
  @override
  def supported_models() -> list[str]:
    return [r"claude-3-.*"]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Sends the request to Claude on Vertex AI and yields the response.

    Note: streaming is not yet supported; the `stream` flag is ignored and
    the full response is yielded in a single chunk.
    """
    messages = [
        content_to_message_param(content)
        for content in llm_request.contents or []
    ]
    tools = NOT_GIVEN
    if (
        llm_request.config
        and llm_request.config.tools
        and llm_request.config.tools[0].function_declarations
    ):
      tools = [
          function_declaration_to_tool_param(tool)
          for tool in llm_request.config.tools[0].function_declarations
      ]
    tool_choice = (
        anthropic_types.ToolChoiceAutoParam(
            type="auto",
            # TODO: allow parallel tool use.
            disable_parallel_tool_use=True,
        )
        if llm_request.tools_dict
        else NOT_GIVEN
    )
    message = self._anthropic_client.messages.create(
        model=llm_request.model,
        system=llm_request.config.system_instruction,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=MAX_TOKEN,
    )
    logger.info(
        "Claude response: %s",
        message.model_dump_json(indent=2, exclude_none=True),
    )
    yield message_to_generate_content_response(message)

  @cached_property
  def _anthropic_client(self) -> AnthropicVertex:
    """Returns a lazily created AnthropicVertex client built from env vars."""
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
    ):
      raise ValueError(
          "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
          " Anthropic on Vertex."
      )
    return AnthropicVertex(
        project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
        region=os.environ["GOOGLE_CLOUD_LOCATION"],
    )
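

# Example usage (a sketch, not part of this module; assumes ADK's LLMRegistry
# and Agent APIs, a Vertex AI project with Claude enabled, and application
# default credentials):
#
#   import os
#   from google.adk.agents import Agent
#   from google.adk.models.anthropic_llm import Claude
#   from google.adk.models.registry import LLMRegistry
#
#   os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"   # placeholder values
#   os.environ["GOOGLE_CLOUD_LOCATION"] = "us-east5"
#
#   LLMRegistry.register(Claude)
#   agent = Agent(model="claude-3-5-sonnet-v2@20241022", name="claude_agent")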