adk-python/src/google/adk/models/anthropic_llm.py
hangfei 9827820143 Agent Development Kit(ADK)
An easy-to-use and powerful framework to build AI agents.
2025-04-08 17:25:47 +00:00

244 lines
7.1 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Anthropic integration for Claude models."""
from __future__ import annotations

import asyncio
from functools import cached_property
import logging
import os
from typing import AsyncGenerator
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional, Union
from typing import TYPE_CHECKING

from anthropic import AnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override

from .base_llm import BaseLlm
from .llm_response import LlmResponse
if TYPE_CHECKING:
  # Needed only for type annotations, so it is not imported at runtime.
  from .llm_request import LlmRequest

# Names exported by `from ... import *`.
__all__ = ["Claude"]

logger = logging.getLogger(__name__)

# Output-token limit (`max_tokens`) used for Anthropic requests made by this
# module.
MAX_TOKEN = 1024
class ClaudeRequest(BaseModel):
  """Pydantic model grouping the pieces of a request to Claude.

  NOTE(review): nothing in the visible part of this module instantiates this
  model -- confirm whether it is still used elsewhere.
  """

  # System prompt text.
  system_instruction: str
  # Conversation turns expressed as Anthropic message params.
  messages: Iterable[anthropic_types.MessageParam]
  # Tool definitions expressed as Anthropic tool params.
  tools: list[anthropic_types.ToolParam]
def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
  """Maps a content role name onto one of the two Anthropic message roles.

  Args:
    role: Role attached to a piece of content; may be None.

  Returns:
    "assistant" when `role` is "model" or "assistant", and "user" for any
    other value (including None).
  """
  speaks_as_assistant = role in ["model", "assistant"]
  return "assistant" if speaks_as_assistant else "user"
def to_google_genai_finish_reason(
    anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
  """Translates an Anthropic stop reason into a finish-reason name.

  Args:
    anthropic_stop_reason: Stop reason reported by Anthropic; may be None.

  Returns:
    "STOP" for "end_turn", "stop_sequence" and "tool_use"; "MAX_TOKENS" for
    "max_tokens"; "FINISH_REASON_UNSPECIFIED" for anything else, including
    None.
  """
  if anthropic_stop_reason == "max_tokens":
    return "MAX_TOKENS"
  normal_stops = ["end_turn", "stop_sequence", "tool_use"]
  stopped_normally = anthropic_stop_reason in normal_stops
  return "STOP" if stopped_normally else "FINISH_REASON_UNSPECIFIED"
def part_to_message_block(
    part: types.Part,
) -> Union[
    anthropic_types.TextBlockParam,
    anthropic_types.ImageBlockParam,
    anthropic_types.ToolUseBlockParam,
    anthropic_types.ToolResultBlockParam,
]:
  """Converts a single genai `Part` into an Anthropic content block param.

  The part's fields are examined in the order text, function_call,
  function_response; the first non-empty one decides the block type.

  Args:
    part: The part to convert.

  Returns:
    A text block for text parts, a tool_use block for function calls, or a
    tool_result block for function responses.

  Raises:
    AssertionError: If the part holds a function call without a name.
    NotImplementedError: If the part carries none of the supported fields.
  """
  if part.text:
    return anthropic_types.TextBlockParam(text=part.text, type="text")

  function_call = part.function_call
  if function_call:
    assert function_call.name
    return anthropic_types.ToolUseBlockParam(
        id=function_call.id or "",
        name=function_call.name,
        # `args` may be unset on the genai side; send an empty object rather
        # than None, since tool_use input is expected to be a JSON object.
        input=function_call.args or {},
        type="tool_use",
    )

  function_response = part.function_response
  if function_response:
    # `response` may be unset; treat that like an empty payload instead of
    # failing on the membership test.
    payload = function_response.response or {}
    result = payload.get("result")
    # The result can be a list of dicts, which ToolResultBlockParam content
    # does not accept; stringify it to avoid anthropic.BadRequestError.
    content = str(result) if result else ""
    return anthropic_types.ToolResultBlockParam(
        tool_use_id=function_response.id or "",
        type="tool_result",
        content=content,
        is_error=False,
    )

  raise NotImplementedError("Not supported yet.")
def content_to_message_param(
    content: types.Content,
) -> anthropic_types.MessageParam:
  """Builds an Anthropic message param from a genai `Content`.

  Args:
    content: The content to convert; its `parts` may be None or empty.

  Returns:
    A dict with the mapped "role" and a "content" list holding one converted
    block per part (empty when there are no parts).
  """
  role = to_claude_role(content.role)
  blocks = []
  for part in content.parts or []:
    blocks.append(part_to_message_block(part))
  return {"role": role, "content": blocks}
def content_block_to_part(
    content_block: anthropic_types.ContentBlock,
) -> types.Part:
  """Converts one Anthropic response content block into a genai `Part`.

  Args:
    content_block: A block taken from an Anthropic message's content.

  Returns:
    A text part for `TextBlock`, or a function-call part (with the tool-use
    id copied onto it) for `ToolUseBlock`.

  Raises:
    AssertionError: If a tool-use block's input is not a dict.
    NotImplementedError: For any other block type.
  """
  if isinstance(content_block, anthropic_types.TextBlock):
    return types.Part.from_text(text=content_block.text)

  if not isinstance(content_block, anthropic_types.ToolUseBlock):
    raise NotImplementedError("Not supported yet.")

  assert isinstance(content_block.input, dict)
  call_part = types.Part.from_function_call(
      name=content_block.name, args=content_block.input
  )
  # Carry the tool-use id over onto the function call.
  call_part.function_call.id = content_block.id
  return call_part
def message_to_generate_content_response(
    message: anthropic_types.Message,
) -> LlmResponse:
  """Wraps an Anthropic message in an `LlmResponse`.

  Args:
    message: The message returned by Anthropic.

  Returns:
    An `LlmResponse` whose content has role "model" and one part per content
    block of `message`.
  """
  parts = []
  for block in message.content:
    parts.append(content_block_to_part(block))
  model_content = types.Content(role="model", parts=parts)
  # TODO: Deal with these later -- the finish reason (see
  # to_google_genai_finish_reason and message.stop_reason) and usage metadata
  # (message.usage.input_tokens / output_tokens) are not populated yet.
  return LlmResponse(content=model_content)
def _lowercase_schema_types(schema: dict) -> None:
  """Lower-cases every `type` value of a dumped schema dict, in place.

  Recurses into `items`, `properties` and `any_of` so nested schemas (array
  elements, object fields) are normalised too, not just the top level.
  """
  if "type" in schema:
    schema["type"] = schema["type"].lower()

  items = schema.get("items")
  if isinstance(items, dict):
    _lowercase_schema_types(items)

  nested_properties = schema.get("properties")
  if isinstance(nested_properties, dict):
    for sub_schema in nested_properties.values():
      if isinstance(sub_schema, dict):
        _lowercase_schema_types(sub_schema)

  for option in schema.get("any_of") or []:
    if isinstance(option, dict):
      _lowercase_schema_types(option)


def function_declaration_to_tool_param(
    function_declaration: types.FunctionDeclaration,
) -> anthropic_types.ToolParam:
  """Converts a genai function declaration into an Anthropic tool param.

  Args:
    function_declaration: The declaration to convert; it must have a name.

  Returns:
    A tool param carrying the declaration's name, its description (empty
    string when unset) and an object input schema built from the declared
    parameters. Type names are lower-cased at every nesting level, and the
    declaration's `required` list is forwarded when it is non-empty.

  Raises:
    AssertionError: If the declaration has no name.
  """
  assert function_declaration.name

  parameters = function_declaration.parameters
  properties = {}
  if parameters and parameters.properties:
    for key, value in parameters.properties.items():
      value_dict = value.model_dump(exclude_none=True)
      _lowercase_schema_types(value_dict)
      properties[key] = value_dict

  input_schema = {
      "type": "object",
      "properties": properties,
  }
  # Forward the required-parameter list so the model knows which arguments
  # it must supply.
  if parameters and parameters.required:
    input_schema["required"] = list(parameters.required)

  return anthropic_types.ToolParam(
      name=function_declaration.name,
      description=function_declaration.description or "",
      input_schema=input_schema,
  )
class Claude(BaseLlm):
  """LLM integration for Claude models accessed through Anthropic on Vertex.

  Attributes:
    model: Name of the Claude model used by default.
    max_tokens: Value sent as `max_tokens` with each request; defaults to
      the module-level MAX_TOKEN.
  """

  model: str = "claude-3-5-sonnet-v2@20241022"
  max_tokens: int = MAX_TOKEN

  @staticmethod
  @override
  def supported_models() -> list[str]:
    """Returns the model-name patterns this class handles."""
    return [r"claude-3-.*"]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Sends the request to Claude and yields the converted response.

    Args:
      llm_request: The request to send. Its contents become the messages;
        the function declarations of the first configured tool become the
        Anthropic tools.
      stream: Currently ignored; exactly one complete response is yielded.

    Yields:
      A single `LlmResponse` built from the Anthropic message.
    """
    config = llm_request.config
    messages = [
        content_to_message_param(content)
        for content in llm_request.contents or []
    ]

    tools = NOT_GIVEN
    if config and config.tools and config.tools[0].function_declarations:
      tools = [
          function_declaration_to_tool_param(declaration)
          for declaration in config.tools[0].function_declarations
      ]

    # Only send a tool choice together with actual tool definitions.
    tool_choice = NOT_GIVEN
    if llm_request.tools_dict and tools is not NOT_GIVEN:
      tool_choice = anthropic_types.ToolChoiceAutoParam(
          type="auto",
          # TODO: allow parallel tool use.
          disable_parallel_tool_use=True,
      )

    # The config (and its system instruction) may be absent; omit the
    # system prompt in that case instead of failing or sending None.
    system = NOT_GIVEN
    if config and config.system_instruction:
      system = config.system_instruction

    # The client is synchronous; run the request in a worker thread so the
    # event loop is not blocked while waiting for the response.
    message = await asyncio.to_thread(
        self._anthropic_client.messages.create,
        model=llm_request.model,
        system=system,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=self.max_tokens,
    )
    logger.info(
        "Claude response: %s",
        message.model_dump_json(indent=2, exclude_none=True),
    )
    yield message_to_generate_content_response(message)

  @cached_property
  def _anthropic_client(self) -> AnthropicVertex:
    """Creates the Vertex client on first access; the result is cached.

    Raises:
      ValueError: If GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION is not
        present in the environment.
    """
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
    ):
      raise ValueError(
          "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
          " Anthropic on Vertex."
      )

    return AnthropicVertex(
        project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
        region=os.environ["GOOGLE_CLOUD_LOCATION"],
    )