mirror of
https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-13 07:04:51 -06:00
244 lines
7.1 KiB
Python
244 lines
7.1 KiB
Python
# Copyright 2025 Google LLC
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Anthropic integration for Claude models."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
from functools import cached_property
import logging
import os
from typing import AsyncGenerator
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional, Union
from typing import TYPE_CHECKING
|
|
|
|
from anthropic import AnthropicVertex
|
|
from anthropic import NOT_GIVEN
|
|
from anthropic import types as anthropic_types
|
|
from google.genai import types
|
|
from pydantic import BaseModel
|
|
from typing_extensions import override
|
|
|
|
from .base_llm import BaseLlm
|
|
from .llm_response import LlmResponse
|
|
|
|
if TYPE_CHECKING:
|
|
from .llm_request import LlmRequest
|
|
|
|
# Only the Claude model class is part of this module's public API.
__all__ = ["Claude"]

logger = logging.getLogger(__name__)

# Value passed as ``max_tokens`` to Anthropic's ``messages.create``, i.e. the
# cap on tokens Claude may generate for a single response.
MAX_TOKEN = 1024
|
|
|
|
|
|
class ClaudeRequest(BaseModel):
  """Pydantic container grouping the pieces of a Claude request.

  NOTE(review): this model is not referenced by the code visible in this
  module; confirm whether it is still used before relying on it.
  """

  # System prompt text (presumably what is sent as ``system``; verify usage).
  system_instruction: str
  # Conversation history, already in Anthropic ``MessageParam`` format.
  messages: Iterable[anthropic_types.MessageParam]
  # Tool definitions in Anthropic ``ToolParam`` format.
  tools: list[anthropic_types.ToolParam]
|
|
|
|
|
|
def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
  """Map a google-genai content role onto an Anthropic message role.

  Args:
    role: The ``role`` of a ``types.Content``; may be None.

  Returns:
    ``"assistant"`` for the genai ``"model"`` role (or an already-Anthropic
    ``"assistant"``); ``"user"`` for every other value, including None.
  """
  is_assistant = role == "model" or role == "assistant"
  return "assistant" if is_assistant else "user"
|
|
|
|
|
|
def to_google_genai_finish_reason(
    anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
  """Translate an Anthropic ``stop_reason`` into a google-genai finish reason.

  The result is returned as the string name of the ``FinishReason`` member.

  Args:
    anthropic_stop_reason: Anthropic's stop reason, or None.

  Returns:
    ``"MAX_TOKENS"`` when generation hit the token limit, ``"STOP"`` for the
    ordinary stops (``end_turn``, ``stop_sequence``, ``tool_use``), and
    ``"FINISH_REASON_UNSPECIFIED"`` for anything else, including None.
  """
  if anthropic_stop_reason == "max_tokens":
    return "MAX_TOKENS"
  ordinary_stops = ("end_turn", "stop_sequence", "tool_use")
  if anthropic_stop_reason in ordinary_stops:
    return "STOP"
  return "FINISH_REASON_UNSPECIFIED"
|
|
|
|
|
|
def part_to_message_block(
    part: types.Part,
) -> Union[
    anthropic_types.TextBlockParam,
    anthropic_types.ImageBlockParam,
    anthropic_types.ToolUseBlockParam,
    anthropic_types.ToolResultBlockParam,
]:
  """Convert one google-genai ``Part`` into an Anthropic content block.

  Supported parts:
    * non-empty text -> ``text`` block;
    * function call -> ``tool_use`` block;
    * function response -> ``tool_result`` block (content sent as a string).

  Args:
    part: The genai part to convert.

  Returns:
    The equivalent Anthropic block param.

  Raises:
    NotImplementedError: If the part carries none of the supported payloads
      (this includes a part whose text is empty).
  """
  if part.text:
    return anthropic_types.TextBlockParam(text=part.text, type="text")

  if part.function_call:
    assert part.function_call.name

    return anthropic_types.ToolUseBlockParam(
        id=part.function_call.id or "",
        name=part.function_call.name,
        # Anthropic requires ``input`` to be an object; a call made without
        # arguments has ``args=None`` on the genai side.
        input=part.function_call.args or {},
        type="tool_use",
    )

  if part.function_response:
    # ``response`` is optional on the genai side; treat a missing one as empty
    # instead of failing on the ``in`` test below.
    response = part.function_response.response or {}
    content = ""
    if "result" in response:
      if response["result"]:
        # Transformation is required because the content is a list of dict.
        # ToolResultBlockParam content doesn't support list of dict. Converting
        # to str to prevent anthropic.BadRequestError from being thrown.
        content = str(response["result"])
    elif response:
      # Tools may return a plain dict without a "result" key (e.g. status /
      # report / error fields). Forward it as text so the model actually sees
      # the tool output instead of an empty tool_result.
      content = str(response)
    return anthropic_types.ToolResultBlockParam(
        tool_use_id=part.function_response.id or "",
        type="tool_result",
        content=content,
        is_error=False,
    )

  raise NotImplementedError("Not supported yet.")
|
|
|
|
|
|
def content_to_message_param(
    content: types.Content,
) -> anthropic_types.MessageParam:
  """Convert a google-genai ``Content`` into an Anthropic ``MessageParam``.

  Args:
    content: A single conversation turn; ``parts`` may be None.

  Returns:
    A dict with ``role`` (mapped by ``to_claude_role``) and ``content`` (one
    Anthropic block per part, in the original order).
  """
  claude_role = to_claude_role(content.role)
  blocks = list(map(part_to_message_block, content.parts or []))
  return {"role": claude_role, "content": blocks}
|
|
|
|
|
|
def content_block_to_part(
    content_block: anthropic_types.ContentBlock,
) -> types.Part:
  """Convert an Anthropic response content block into a google-genai ``Part``.

  Args:
    content_block: A ``TextBlock`` or ``ToolUseBlock`` from a Claude message.

  Returns:
    A text part, or a function-call part carrying the tool_use id.

  Raises:
    NotImplementedError: For any other content block type.
  """
  if isinstance(content_block, anthropic_types.ToolUseBlock):
    tool_args = content_block.input
    assert isinstance(tool_args, dict)
    converted = types.Part.from_function_call(
        name=content_block.name, args=tool_args
    )
    # ``from_function_call`` takes no id; copy Anthropic's tool_use id onto
    # the genai FunctionCall explicitly.
    converted.function_call.id = content_block.id
    return converted

  if isinstance(content_block, anthropic_types.TextBlock):
    return types.Part.from_text(text=content_block.text)

  raise NotImplementedError("Not supported yet.")
|
|
|
|
|
|
def message_to_generate_content_response(
    message: anthropic_types.Message,
) -> LlmResponse:
  """Wrap a Claude ``Message`` as an ADK ``LlmResponse``.

  Only the message's content blocks are translated; the result is a single
  ``"model"``-role ``Content``.

  Args:
    message: The complete (non-streamed) Anthropic message.

  Returns:
    An ``LlmResponse`` whose content holds one part per content block.
  """
  model_parts = [content_block_to_part(block) for block in message.content]
  model_content = types.Content(role="model", parts=model_parts)
  # TODO: Deal with these later: populate finish_reason (via
  # to_google_genai_finish_reason(message.stop_reason)) and usage_metadata
  # (from message.usage.input_tokens / output_tokens).
  return LlmResponse(content=model_content)
|
|
|
|
|
|
def _update_type_string(value_dict: dict) -> None:
  """Lower-case ``type`` values in a dumped schema dict, in place, recursively.

  JSON Schema (which Anthropic's ``input_schema`` uses) expects lower-case
  type names. Nested schemas under ``items`` and ``properties`` need the same
  treatment as the top level.
  """
  schema_type = value_dict.get("type")
  if isinstance(schema_type, str):
    value_dict["type"] = schema_type.lower()

  items = value_dict.get("items")
  if isinstance(items, dict):
    # Array element schema: may itself be a primitive or an object.
    _update_type_string(items)

  nested_properties = value_dict.get("properties")
  if isinstance(nested_properties, dict):
    # Object schema: each property value is a schema of its own. The keys are
    # user-chosen property names and are left untouched.
    for nested_schema in nested_properties.values():
      if isinstance(nested_schema, dict):
        _update_type_string(nested_schema)


def function_declaration_to_tool_param(
    function_declaration: types.FunctionDeclaration,
) -> anthropic_types.ToolParam:
  """Convert a google-genai ``FunctionDeclaration`` into an Anthropic tool.

  Args:
    function_declaration: Declaration with a non-empty ``name``; its
      ``parameters`` schema (if any) becomes the tool's ``input_schema``.

  Returns:
    A ``ToolParam`` whose ``input_schema`` is an object schema with
    ``properties`` (type names lower-cased at every nesting level) and, when
    the declaration lists them, the ``required`` parameter names.
  """
  assert function_declaration.name

  parameters = function_declaration.parameters
  properties = {}
  if parameters and parameters.properties:
    for key, value in parameters.properties.items():
      value_dict = value.model_dump(exclude_none=True)
      _update_type_string(value_dict)
      properties[key] = value_dict

  input_schema = {
      "type": "object",
      "properties": properties,
  }
  # Without "required", Claude treats every parameter as optional and may
  # omit mandatory arguments.
  if parameters and parameters.required:
    input_schema["required"] = list(parameters.required)

  return anthropic_types.ToolParam(
      name=function_declaration.name,
      description=function_declaration.description or "",
      input_schema=input_schema,
  )
|
|
|
|
|
|
class Claude(BaseLlm):
  """Claude models served through Anthropic on Vertex AI.

  Attributes:
    model: Default model name.
    max_tokens: Maximum number of tokens Claude may generate per response.
  """

  model: str = "claude-3-5-sonnet-v2@20241022"
  # Sent as ``max_tokens`` to ``messages.create``. Defaults to the previous
  # hard-coded value so existing behavior is unchanged.
  max_tokens: int = MAX_TOKEN

  @staticmethod
  @override
  def supported_models() -> list[str]:
    """Return regexes matching the model names this class handles."""
    return [r"claude-3-.*"]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Send ``llm_request`` to Claude and yield the converted response.

    Args:
      llm_request: Conversation contents, config (system instruction, tools)
        and target model.
      stream: Ignored; streaming is not implemented, so exactly one complete
        response is yielded.

    Yields:
      A single ``LlmResponse`` converted from Claude's message.
    """
    messages = [
        content_to_message_param(content)
        for content in llm_request.contents or []
    ]

    config = llm_request.config
    # Gather function declarations from every Tool entry. Built-in tools may
    # occupy config.tools[0] without carrying any function declarations.
    config_tools = (config.tools if config else None) or []
    declarations = [
        declaration
        for tool in config_tools
        for declaration in (getattr(tool, "function_declarations", None) or [])
    ]

    tools = (
        [function_declaration_to_tool_param(d) for d in declarations]
        if declarations
        else NOT_GIVEN
    )
    # tool_choice is only valid when tools are actually sent.
    tool_choice = (
        anthropic_types.ToolChoiceAutoParam(
            type="auto",
            # TODO: allow parallel tool use.
            disable_parallel_tool_use=True,
        )
        if llm_request.tools_dict and declarations
        else NOT_GIVEN
    )

    system_instruction = config.system_instruction if config else None

    client = self._anthropic_client
    # messages.create is a blocking HTTP call; run it in a worker thread so
    # the event loop is not stalled while waiting for Claude.
    message = await asyncio.to_thread(
        client.messages.create,
        model=llm_request.model or self.model,
        # Omit ``system`` entirely rather than sending null.
        system=system_instruction or NOT_GIVEN,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=self.max_tokens,
    )
    logger.info(
        "Claude response: %s",
        message.model_dump_json(indent=2, exclude_none=True),
    )
    yield message_to_generate_content_response(message)

  @cached_property
  def _anthropic_client(self) -> AnthropicVertex:
    """Build the Anthropic Vertex client once per instance.

    Raises:
      ValueError: If GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_LOCATION is not set.
    """
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
    ):
      raise ValueError(
          "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
          " Anthropic on Vertex."
      )

    return AnthropicVertex(
        project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
        region=os.environ["GOOGLE_CLOUD_LOCATION"],
    )
|