# Copyright 2025 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Anthropic integration for Claude models.""" from __future__ import annotations from functools import cached_property import logging import os from typing import AsyncGenerator from typing import Generator from typing import Iterable from typing import Literal from typing import Optional, Union from typing import TYPE_CHECKING from anthropic import AnthropicVertex from anthropic import NOT_GIVEN from anthropic import types as anthropic_types from google.genai import types from pydantic import BaseModel from typing_extensions import override from .base_llm import BaseLlm from .llm_response import LlmResponse if TYPE_CHECKING: from .llm_request import LlmRequest __all__ = ["Claude"] logger = logging.getLogger(__name__) MAX_TOKEN = 1024 class ClaudeRequest(BaseModel): system_instruction: str messages: Iterable[anthropic_types.MessageParam] tools: list[anthropic_types.ToolParam] def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]: if role in ["model", "assistant"]: return "assistant" return "user" def to_google_genai_finish_reason( anthropic_stop_reason: Optional[str], ) -> types.FinishReason: if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]: return "STOP" if anthropic_stop_reason == "max_tokens": return "MAX_TOKENS" return "FINISH_REASON_UNSPECIFIED" def part_to_message_block( part: types.Part, ) -> Union[ anthropic_types.TextBlockParam, anthropic_types.ImageBlockParam, anthropic_types.ToolUseBlockParam, anthropic_types.ToolResultBlockParam, ]: if part.text: return anthropic_types.TextBlockParam(text=part.text, type="text") if part.function_call: assert part.function_call.name return anthropic_types.ToolUseBlockParam( id=part.function_call.id or "", name=part.function_call.name, input=part.function_call.args, type="tool_use", ) if part.function_response: content = "" if ( "result" in part.function_response.response and part.function_response.response["result"] ): # Transformation is required because the content is a list of dict. # ToolResultBlockParam content doesn't support list of dict. Converting # to str to prevent anthropic.BadRequestError from being thrown. content = str(part.function_response.response["result"]) return anthropic_types.ToolResultBlockParam( tool_use_id=part.function_response.id or "", type="tool_result", content=content, is_error=False, ) raise NotImplementedError("Not supported yet.") def content_to_message_param( content: types.Content, ) -> anthropic_types.MessageParam: return { "role": to_claude_role(content.role), "content": [part_to_message_block(part) for part in content.parts or []], } def content_block_to_part( content_block: anthropic_types.ContentBlock, ) -> types.Part: if isinstance(content_block, anthropic_types.TextBlock): return types.Part.from_text(text=content_block.text) if isinstance(content_block, anthropic_types.ToolUseBlock): assert isinstance(content_block.input, dict) part = types.Part.from_function_call( name=content_block.name, args=content_block.input ) part.function_call.id = content_block.id return part raise NotImplementedError("Not supported yet.") def message_to_generate_content_response( message: anthropic_types.Message, ) -> LlmResponse: return LlmResponse( content=types.Content( role="model", parts=[content_block_to_part(cb) for cb in message.content], ), # TODO: Deal with these later. # finish_reason=to_google_genai_finish_reason(message.stop_reason), # usage_metadata=types.GenerateContentResponseUsageMetadata( # prompt_token_count=message.usage.input_tokens, # candidates_token_count=message.usage.output_tokens, # total_token_count=( # message.usage.input_tokens + message.usage.output_tokens # ), # ), ) def function_declaration_to_tool_param( function_declaration: types.FunctionDeclaration, ) -> anthropic_types.ToolParam: assert function_declaration.name properties = {} if ( function_declaration.parameters and function_declaration.parameters.properties ): for key, value in function_declaration.parameters.properties.items(): value_dict = value.model_dump(exclude_none=True) if "type" in value_dict: value_dict["type"] = value_dict["type"].lower() properties[key] = value_dict return anthropic_types.ToolParam( name=function_declaration.name, description=function_declaration.description or "", input_schema={ "type": "object", "properties": properties, }, ) class Claude(BaseLlm): model: str = "claude-3-5-sonnet-v2@20241022" @staticmethod @override def supported_models() -> list[str]: return [r"claude-3-.*"] @override async def generate_content_async( self, llm_request: LlmRequest, stream: bool = False ) -> AsyncGenerator[LlmResponse, None]: messages = [ content_to_message_param(content) for content in llm_request.contents or [] ] tools = NOT_GIVEN if ( llm_request.config and llm_request.config.tools and llm_request.config.tools[0].function_declarations ): tools = [ function_declaration_to_tool_param(tool) for tool in llm_request.config.tools[0].function_declarations ] tool_choice = ( anthropic_types.ToolChoiceAutoParam( type="auto", # TODO: allow parallel tool use. disable_parallel_tool_use=True, ) if llm_request.tools_dict else NOT_GIVEN ) message = self._anthropic_client.messages.create( model=llm_request.model, system=llm_request.config.system_instruction, messages=messages, tools=tools, tool_choice=tool_choice, max_tokens=MAX_TOKEN, ) logger.info( "Claude response: %s", message.model_dump_json(indent=2, exclude_none=True), ) yield message_to_generate_content_response(message) @cached_property def _anthropic_client(self) -> AnthropicVertex: if ( "GOOGLE_CLOUD_PROJECT" not in os.environ or "GOOGLE_CLOUD_LOCATION" not in os.environ ): raise ValueError( "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using" " Anthropic on Vertex." ) return AnthropicVertex( project_id=os.environ["GOOGLE_CLOUD_PROJECT"], region=os.environ["GOOGLE_CLOUD_LOCATION"], )