Agent Development Kit (ADK)

An easy-to-use and powerful framework to build AI agents.
hangfei
2025-04-08 17:22:09 +00:00
parent f92478bd5c
commit 9827820143
299 changed files with 44398 additions and 2 deletions


@@ -0,0 +1,31 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the interface to support a model."""
from .base_llm import BaseLlm
from .google_llm import Gemini
from .llm_request import LlmRequest
from .llm_response import LlmResponse
from .registry import LLMRegistry
__all__ = [
'BaseLlm',
'Gemini',
'LLMRegistry',
]
# LLMRegistry.register() iterates Gemini.supported_models() itself, so a
# single call registers every supported model-name pattern.
LLMRegistry.register(Gemini)
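
With the registration above in place, callers can resolve a model implementation purely by name. A minimal usage sketch, assuming the package is importable as `google.adk.models`:

```
from google.adk.models import LLMRegistry  # assumed import path

# 'gemini-1.5-flash' full-matches the r'gemini-.*' pattern registered by
# Gemini.supported_models(), so new_llm() builds a Gemini instance.
llm = LLMRegistry.new_llm('gemini-1.5-flash')
print(type(llm).__name__)  # -> Gemini
```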


@@ -0,0 +1,243 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Anthropic integration for Claude models."""
from __future__ import annotations
from functools import cached_property
import logging
import os
from typing import AsyncGenerator
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional
from typing import Union
from typing import TYPE_CHECKING
from anthropic import AnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override
from .base_llm import BaseLlm
from .llm_response import LlmResponse
if TYPE_CHECKING:
from .llm_request import LlmRequest
__all__ = ["Claude"]
logger = logging.getLogger(__name__)
MAX_TOKEN = 1024
class ClaudeRequest(BaseModel):
system_instruction: str
messages: Iterable[anthropic_types.MessageParam]
tools: list[anthropic_types.ToolParam]
def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
if role in ["model", "assistant"]:
return "assistant"
return "user"
def to_google_genai_finish_reason(
anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
return "STOP"
if anthropic_stop_reason == "max_tokens":
return "MAX_TOKENS"
return "FINISH_REASON_UNSPECIFIED"
def part_to_message_block(
part: types.Part,
) -> Union[
anthropic_types.TextBlockParam,
anthropic_types.ImageBlockParam,
anthropic_types.ToolUseBlockParam,
anthropic_types.ToolResultBlockParam,
]:
if part.text:
return anthropic_types.TextBlockParam(text=part.text, type="text")
if part.function_call:
assert part.function_call.name
return anthropic_types.ToolUseBlockParam(
id=part.function_call.id or "",
name=part.function_call.name,
input=part.function_call.args,
type="tool_use",
)
if part.function_response:
content = ""
if (
"result" in part.function_response.response
and part.function_response.response["result"]
):
# The function response content may be a list of dicts, which
# ToolResultBlockParam does not accept. Convert it to a string to prevent
# anthropic.BadRequestError from being thrown.
content = str(part.function_response.response["result"])
return anthropic_types.ToolResultBlockParam(
tool_use_id=part.function_response.id or "",
type="tool_result",
content=content,
is_error=False,
)
raise NotImplementedError("Not supported yet.")
def content_to_message_param(
content: types.Content,
) -> anthropic_types.MessageParam:
return {
"role": to_claude_role(content.role),
"content": [part_to_message_block(part) for part in content.parts or []],
}
def content_block_to_part(
content_block: anthropic_types.ContentBlock,
) -> types.Part:
if isinstance(content_block, anthropic_types.TextBlock):
return types.Part.from_text(text=content_block.text)
if isinstance(content_block, anthropic_types.ToolUseBlock):
assert isinstance(content_block.input, dict)
part = types.Part.from_function_call(
name=content_block.name, args=content_block.input
)
part.function_call.id = content_block.id
return part
raise NotImplementedError("Not supported yet.")
def message_to_generate_content_response(
message: anthropic_types.Message,
) -> LlmResponse:
return LlmResponse(
content=types.Content(
role="model",
parts=[content_block_to_part(cb) for cb in message.content],
),
# TODO: Deal with these later.
# finish_reason=to_google_genai_finish_reason(message.stop_reason),
# usage_metadata=types.GenerateContentResponseUsageMetadata(
# prompt_token_count=message.usage.input_tokens,
# candidates_token_count=message.usage.output_tokens,
# total_token_count=(
# message.usage.input_tokens + message.usage.output_tokens
# ),
# ),
)
def function_declaration_to_tool_param(
function_declaration: types.FunctionDeclaration,
) -> anthropic_types.ToolParam:
assert function_declaration.name
properties = {}
if (
function_declaration.parameters
and function_declaration.parameters.properties
):
for key, value in function_declaration.parameters.properties.items():
value_dict = value.model_dump(exclude_none=True)
if "type" in value_dict:
value_dict["type"] = value_dict["type"].lower()
properties[key] = value_dict
return anthropic_types.ToolParam(
name=function_declaration.name,
description=function_declaration.description or "",
input_schema={
"type": "object",
"properties": properties,
},
)
class Claude(BaseLlm):
model: str = "claude-3-5-sonnet-v2@20241022"
@staticmethod
@override
def supported_models() -> list[str]:
return [r"claude-3-.*"]
@override
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
messages = [
content_to_message_param(content)
for content in llm_request.contents or []
]
tools = NOT_GIVEN
if (
llm_request.config
and llm_request.config.tools
and llm_request.config.tools[0].function_declarations
):
tools = [
function_declaration_to_tool_param(tool)
for tool in llm_request.config.tools[0].function_declarations
]
tool_choice = (
anthropic_types.ToolChoiceAutoParam(
type="auto",
# TODO: allow parallel tool use.
disable_parallel_tool_use=True,
)
if llm_request.tools_dict
else NOT_GIVEN
)
message = self._anthropic_client.messages.create(
model=llm_request.model,
system=llm_request.config.system_instruction,
messages=messages,
tools=tools,
tool_choice=tool_choice,
max_tokens=MAX_TOKEN,
)
logger.info(
"Claude response: %s",
message.model_dump_json(indent=2, exclude_none=True),
)
yield message_to_generate_content_response(message)
@cached_property
def _anthropic_client(self) -> AnthropicVertex:
if (
"GOOGLE_CLOUD_PROJECT" not in os.environ
or "GOOGLE_CLOUD_LOCATION" not in os.environ
):
raise ValueError(
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
" Anthropic on Vertex."
)
return AnthropicVertex(
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
region=os.environ["GOOGLE_CLOUD_LOCATION"],
)
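
The Claude wrapper is not registered by the package `__init__` shown above, so callers would register it themselves before resolving it by name. A hedged sketch; the module paths and environment values are assumptions:

```
import os

from google.adk.models.anthropic_llm import Claude  # assumed module path
from google.adk.models.registry import LLMRegistry  # assumed module path

# AnthropicVertex needs both variables when the client is first constructed.
os.environ.setdefault('GOOGLE_CLOUD_PROJECT', 'my-gcp-project')
os.environ.setdefault('GOOGLE_CLOUD_LOCATION', 'us-east5')

LLMRegistry.register(Claude)
llm = LLMRegistry.new_llm('claude-3-5-sonnet-v2@20241022')
```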


@@ -0,0 +1,87 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from abc import abstractmethod
from typing import AsyncGenerator
from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import ConfigDict
from .base_llm_connection import BaseLlmConnection
if TYPE_CHECKING:
from .llm_request import LlmRequest
from .llm_response import LlmResponse
class BaseLlm(BaseModel):
"""The BaseLLM class.
Attributes:
model: The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001.
model_config: The model config.
"""
model_config = ConfigDict(
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
arbitrary_types_allowed=True,
)
"""The model config."""
model: str
"""The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001."""
@classmethod
def supported_models(cls) -> list[str]:
"""Returns a list of supported models in regex for LlmRegistry."""
return []
@abstractmethod
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
"""Generates one content from the given contents and tools.
Args:
llm_request: LlmRequest, the request to send to the LLM.
stream: bool = False, whether to do streaming call.
Yields:
a generator of types.Content.
For non-streaming call, it will only yield one Content.
For streaming call, it may yield more than one content, but all yielded
contents should be treated as one content by merging the
parts list.
"""
raise NotImplementedError(
f'Async generation is not supported for {self.model}.'
)
yield # AsyncGenerator requires a yield statement in function body.
def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
"""Creates a live connection to the LLM.
Args:
llm_request: LlmRequest, the request to send to the LLM.
Returns:
BaseLlmConnection, the connection to the LLM.
"""
raise NotImplementedError(
f'Live connection is not supported for {self.model}.'
)
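
A concrete model only has to override `supported_models` and `generate_content_async`. Below is a minimal, illustrative sketch of a fake echo model (useful for tests); the class name and import paths are assumptions, not part of the ADK:

```
from typing import AsyncGenerator

from google.genai import types
from typing_extensions import override

from google.adk.models.base_llm import BaseLlm  # assumed import paths
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse


class EchoLlm(BaseLlm):
  """Toy model that echoes the last user message; for tests only."""

  model: str = 'echo-001'

  @classmethod
  def supported_models(cls) -> list[str]:
    return [r'echo-.*']

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    last_text = ''
    if llm_request.contents and llm_request.contents[-1].parts:
      last_text = llm_request.contents[-1].parts[0].text or ''
    yield LlmResponse(
        content=types.Content(
            role='model',
            parts=[types.Part.from_text(text=f'echo: {last_text}')],
        )
    )
```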


@@ -0,0 +1,76 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import abstractmethod
from typing import AsyncGenerator
from google.genai import types
from .llm_response import LlmResponse
class BaseLlmConnection:
"""The base class for a live model connection."""
@abstractmethod
async def send_history(self, history: list[types.Content]):
"""Sends the conversation history to the model.
You call this method right after setting up the model connection.
The model will respond if the last content is from the user; otherwise it
will wait for new user input before responding.
Args:
history: The conversation history to send to the model.
"""
pass
@abstractmethod
async def send_content(self, content: types.Content):
"""Sends a user content to the model.
The model will respond immediately upon receiving the content.
If you send function responses, all parts in the content should be function
responses.
Args:
content: The content to send to the model.
"""
pass
@abstractmethod
async def send_realtime(self, blob: types.Blob):
"""Sends a chunk of audio or a frame of video to the model in realtime.
The model may not respond immediately upon receiving the blob. It will do
voice activity detection and decide when to respond.
Args:
blob: The blob to send to the model.
"""
pass
@abstractmethod
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""Receives the model response using the llm server connection.
Args: None.
Yields:
LlmResponse: The model response.
"""
pass
@abstractmethod
async def close(self):
"""Closes the llm server connection."""
pass
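
A typical consumer sends a turn and then drains `receive()` until the turn completes, keeping only the merged (non-partial) text. A rough sketch, assuming `connection` is any concrete `BaseLlmConnection` that has already been set up:

```
from google.genai import types

from google.adk.models.base_llm_connection import BaseLlmConnection  # assumed path


async def ask(connection: BaseLlmConnection, question: str) -> str:
  """Sends one user turn and collects the merged model text for that turn."""
  await connection.send_content(
      types.Content(role='user', parts=[types.Part.from_text(text=question)])
  )
  answer = ''
  async for llm_response in connection.receive():
    if (
        not llm_response.partial
        and llm_response.content
        and llm_response.content.parts
        and llm_response.content.parts[0].text
    ):
      answer += llm_response.content.parts[0].text
    if llm_response.turn_complete:
      break
  return answer
```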


@@ -0,0 +1,200 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import AsyncGenerator
from google.genai import live
from google.genai import types
from .base_llm_connection import BaseLlmConnection
from .llm_response import LlmResponse
logger = logging.getLogger(__name__)
class GeminiLlmConnection(BaseLlmConnection):
"""The Gemini model connection."""
def __init__(self, gemini_session: live.AsyncSession):
self._gemini_session = gemini_session
async def send_history(self, history: list[types.Content]):
"""Sends the conversation history to the gemini model.
You call this method right after setting up the model connection.
The model will respond if the last content is from the user; otherwise it
will wait for new user input before responding.
Args:
history: The conversation history to send to the model.
"""
# TODO: Remove this filter and translate unary contents to streaming
# contents properly.
# We ignore any audio from the user during the agent transfer phase.
contents = [
content
for content in history
if content.parts and content.parts[0].text
]
if contents:
await self._gemini_session.send(
input=types.LiveClientContent(
turns=contents,
turn_complete=contents[-1].role == 'user',
),
)
else:
logger.info('no content is sent')
async def send_content(self, content: types.Content):
"""Sends a user content to the gemini model.
The model will respond immediately upon receiving the content.
If you send function responses, all parts in the content should be function
responses.
Args:
content: The content to send to the model.
"""
assert content.parts
if content.parts[0].function_response:
# All parts have to be function responses.
function_responses = [part.function_response for part in content.parts]
logger.debug('Sending LLM function response: %s', function_responses)
await self._gemini_session.send(
input=types.LiveClientToolResponse(
function_responses=function_responses
),
)
else:
logger.debug('Sending LLM new content %s', content)
await self._gemini_session.send(
input=types.LiveClientContent(
turns=[content],
turn_complete=True,
)
)
async def send_realtime(self, blob: types.Blob):
"""Sends a chunk of audio or a frame of video to the model in realtime.
Args:
blob: The blob to send to the model.
"""
input_blob = blob.model_dump()
logger.debug('Sending LLM Blob: %s', input_blob)
await self._gemini_session.send(input=input_blob)
def __build_full_text_response(self, text: str):
"""Builds a full text response.
The text should not be partial, and the returned LlmResponse is not
partial.
Args:
text: The text to be included in the response.
Returns:
An LlmResponse containing the full text.
"""
return LlmResponse(
content=types.Content(
role='model',
parts=[types.Part.from_text(text=text)],
),
)
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""Receives the model response using the llm server connection.
Yields:
LlmResponse: The model response.
"""
text = ''
async for message in self._gemini_session.receive():
logger.debug('Got LLM Live message: %s', message)
if message.server_content:
content = message.server_content.model_turn
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
yield self.__build_full_text_response(text)
text = ''
yield llm_response
if (
message.server_content.output_transcription
and message.server_content.output_transcription.text
):
# TODO: Right now, we just support output_transcription without
# changing the interface and data protocol. Later, we can consider
# supporting output_transcription as a separate field in LlmResponse.
# Transcription is always considered a partial event.
# We rely on other control signals to determine when to yield the
# full text response (turn_complete, interrupted, or tool_call).
text += message.server_content.output_transcription.text
parts = [
types.Part.from_text(
text=message.server_content.output_transcription.text
)
]
llm_response = LlmResponse(
content=types.Content(role='model', parts=parts), partial=True
)
yield llm_response
if message.server_content.turn_complete:
if text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(
turn_complete=True, interrupted=message.server_content.interrupted
)
break
# In case of empty content or parts, we still surface it.
# If it's an interrupted message, we merge the previous partial text;
# otherwise we don't merge, because content can be None when the model's
# safety threshold is triggered.
if message.server_content.interrupted and text:
yield self.__build_full_text_response(text)
text = ''
yield LlmResponse(interrupted=message.server_content.interrupted)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
text = ''
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
async def close(self):
"""Closes the llm server connection."""
await self._gemini_session.close()
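
The ordering above is the contract consumers rely on: text arrives as `partial=True` chunks, a merged full-text response follows, and the turn ends with a `turn_complete=True` marker. A hedged illustration of handling that sequence with the `GeminiLlmConnection` defined above:

```
async def print_transcript(connection: GeminiLlmConnection) -> None:
  """Prints streaming chunks as they arrive and the merged text once per turn."""
  async for llm_response in connection.receive():
    if llm_response.partial and llm_response.content:
      # Incremental text chunk; render it but don't persist it.
      print(llm_response.content.parts[0].text or '', end='', flush=True)
    elif llm_response.content and llm_response.content.parts:
      # Merged full-text (or tool-call) response; safe to store as one event.
      print('\n[merged]', llm_response.content.parts[0].text)
    if llm_response.turn_complete:
      break
```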


@@ -0,0 +1,331 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
from functools import cached_property
import logging
import sys
from typing import AsyncGenerator
from typing import cast
from typing import Generator
from typing import TYPE_CHECKING
from google.genai import Client
from google.genai import types
from typing_extensions import override
from .. import version
from .base_llm import BaseLlm
from .base_llm_connection import BaseLlmConnection
from .gemini_llm_connection import GeminiLlmConnection
from .llm_response import LlmResponse
if TYPE_CHECKING:
from .llm_request import LlmRequest
logger = logging.getLogger(__name__)
_NEW_LINE = '\n'
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
class Gemini(BaseLlm):
"""Integration for Gemini models.
Attributes:
model: The name of the Gemini model.
"""
model: str = 'gemini-1.5-flash'
@staticmethod
@override
def supported_models() -> list[str]:
"""Provides the list of supported models.
Returns:
A list of supported models.
"""
return [
r'gemini-.*',
# fine-tuned vertex endpoint pattern
r'projects\/.+\/locations\/.+\/endpoints\/.+',
# vertex gemini long name
r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
]
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
"""Sends a request to the Gemini model.
Args:
llm_request: LlmRequest, the request to send to the Gemini model.
stream: bool = False, whether to do streaming call.
Yields:
LlmResponse: The model response.
"""
self._maybe_append_user_content(llm_request)
logger.info(
'Sending out request, model: %s, backend: %s, stream: %s',
llm_request.model,
self._api_backend,
stream,
)
logger.info(_build_request_log(llm_request))
if stream:
responses = await self.api_client.aio.models.generate_content_stream(
model=llm_request.model,
contents=llm_request.contents,
config=llm_request.config,
)
response = None
text = ''
# For SSE, similar to bidi (see the receive method in
# gemini_llm_connection.py), we need to mark text content as partial and,
# after all partial contents are sent, send an accumulated event that
# contains all the previous partial content. The only difference is that
# bidi relies on the turn_complete flag to detect the end, while SSE
# depends on finish_reason.
async for response in responses:
logger.info(_build_response_log(response))
llm_response = LlmResponse.create(response)
if (
llm_response.content
and llm_response.content.parts
and llm_response.content.parts[0].text
):
text += llm_response.content.parts[0].text
llm_response.partial = True
elif text and (
not llm_response.content
or not llm_response.content.parts
# don't yield the merged text event when receiving audio data
or not llm_response.content.parts[0].inline_data
):
yield LlmResponse(
content=types.ModelContent(
parts=[types.Part.from_text(text=text)],
),
)
text = ''
yield llm_response
if (
text
and response
and response.candidates
and response.candidates[0].finish_reason == types.FinishReason.STOP
):
yield LlmResponse(
content=types.ModelContent(
parts=[types.Part.from_text(text=text)],
),
)
else:
response = await self.api_client.aio.models.generate_content(
model=llm_request.model,
contents=llm_request.contents,
config=llm_request.config,
)
logger.info(_build_response_log(response))
yield LlmResponse.create(response)
@cached_property
def api_client(self) -> Client:
"""Provides the api client.
Returns:
The api client.
"""
return Client(
http_options=types.HttpOptions(headers=self._tracking_headers)
)
@cached_property
def _api_backend(self) -> str:
return 'vertex' if self.api_client.vertexai else 'ml_dev'
@cached_property
def _tracking_headers(self) -> dict[str, str]:
framework_label = f'google-adk/{version.__version__}'
language_label = 'gl-python/' + sys.version.split()[0]
version_header_value = f'{framework_label} {language_label}'
tracking_headers = {
'x-goog-api-client': version_header_value,
'user-agent': version_header_value,
}
return tracking_headers
@cached_property
def _live_api_client(self) -> Client:
if self._api_backend == 'vertex':
# use default api version for vertex
return Client(
http_options=types.HttpOptions(headers=self._tracking_headers)
)
else:
# use v1alpha for ml_dev
api_version = 'v1alpha'
return Client(
http_options=types.HttpOptions(
headers=self._tracking_headers, api_version=api_version
)
)
@contextlib.asynccontextmanager
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
"""Connects to the Gemini model and returns an llm connection.
Args:
llm_request: LlmRequest, the request to send to the Gemini model.
Yields:
BaseLlmConnection, the connection to the Gemini model.
"""
llm_request.live_connect_config.system_instruction = types.Content(
role='system',
parts=[
types.Part.from_text(text=llm_request.config.system_instruction)
],
)
llm_request.live_connect_config.tools = llm_request.config.tools
async with self._live_api_client.aio.live.connect(
model=llm_request.model, config=llm_request.live_connect_config
) as live_session:
yield GeminiLlmConnection(live_session)
def _maybe_append_user_content(self, llm_request: LlmRequest):
"""Appends a user content, so that model can continue to output.
Args:
llm_request: LlmRequest, the request to send to the Gemini model.
"""
# If no content is provided, append a user content to hint model response
# using system instruction.
if not llm_request.contents:
llm_request.contents.append(
types.Content(
role='user',
parts=[
types.Part(
text=(
'Handle the requests as specified in the System'
' Instruction.'
)
)
],
)
)
return
# Append a user content to preserve user intent and to avoid an empty
# model response.
if llm_request.contents[-1].role != 'user':
llm_request.contents.append(
types.Content(
role='user',
parts=[
types.Part(
text=(
'Continue processing previous requests as instructed.'
' Exit or provide a summary if no more outputs are'
' needed.'
)
)
],
)
)
def _build_function_declaration_log(
func_decl: types.FunctionDeclaration,
) -> str:
param_str = '{}'
if func_decl.parameters and func_decl.parameters.properties:
param_str = str({
k: v.model_dump(exclude_none=True)
for k, v in func_decl.parameters.properties.items()
})
return_str = 'None'
if func_decl.response:
return_str = str(func_decl.response.model_dump(exclude_none=True))
return f'{func_decl.name}: {param_str} -> {return_str}'
def _build_request_log(req: LlmRequest) -> str:
function_decls: list[types.FunctionDeclaration] = cast(
list[types.FunctionDeclaration],
req.config.tools[0].function_declarations if req.config.tools else [],
)
function_logs = (
[
_build_function_declaration_log(func_decl)
for func_decl in function_decls
]
if function_decls
else []
)
contents_logs = [
content.model_dump_json(
exclude_none=True,
exclude={
'parts': {
i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
}
},
)
for content in req.contents
]
return f"""
LLM Request:
-----------------------------------------------------------
System Instruction:
{req.config.system_instruction}
-----------------------------------------------------------
Contents:
{_NEW_LINE.join(contents_logs)}
-----------------------------------------------------------
Functions:
{_NEW_LINE.join(function_logs)}
-----------------------------------------------------------
"""
def _build_response_log(resp: types.GenerateContentResponse) -> str:
function_calls_text = []
if function_calls := resp.function_calls:
for func_call in function_calls:
function_calls_text.append(
f'name: {func_call.name}, args: {func_call.args}'
)
return f"""
LLM Response:
-----------------------------------------------------------
Text:
{resp.text}
-----------------------------------------------------------
Function calls:
{_NEW_LINE.join(function_calls_text)}
-----------------------------------------------------------
Raw response:
{resp.model_dump_json(exclude_none=True)}
-----------------------------------------------------------
"""


@@ -0,0 +1,673 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import json
import logging
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union
from google.genai import types
from litellm import acompletion
from litellm import ChatCompletionAssistantMessage
from litellm import ChatCompletionDeveloperMessage
from litellm import ChatCompletionImageUrlObject
from litellm import ChatCompletionMessageToolCall
from litellm import ChatCompletionTextObject
from litellm import ChatCompletionToolMessage
from litellm import ChatCompletionUserMessage
from litellm import ChatCompletionVideoUrlObject
from litellm import completion
from litellm import CustomStreamWrapper
from litellm import Function
from litellm import Message
from litellm import ModelResponse
from litellm import OpenAIMessageContent
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override
from .base_llm import BaseLlm
from .llm_request import LlmRequest
from .llm_response import LlmResponse
logger = logging.getLogger(__name__)
_NEW_LINE = "\n"
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
class FunctionChunk(BaseModel):
id: Optional[str]
name: Optional[str]
args: Optional[str]
class TextChunk(BaseModel):
text: str
class LiteLLMClient:
"""Provides acompletion method (for better testability)."""
async def acompletion(
self, model, messages, tools, **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
"""Asynchronously calls acompletion.
Args:
model: The model name.
messages: The messages to send to the model.
tools: The tools to use for the model.
**kwargs: Additional arguments to pass to acompletion.
Returns:
The model response as a message.
"""
return await acompletion(
model=model,
messages=messages,
tools=tools,
**kwargs,
)
def completion(
self, model, messages, tools, stream=False, **kwargs
) -> Union[ModelResponse, CustomStreamWrapper]:
"""Synchronously calls completion. This is used for streaming only.
Args:
model: The model to use.
messages: The messages to send.
tools: The tools to use for the model.
stream: Whether to stream the response.
**kwargs: Additional arguments to pass to completion.
Returns:
The response from the model.
"""
return completion(
model=model,
messages=messages,
tools=tools,
stream=stream,
**kwargs,
)
def _safe_json_serialize(obj) -> str:
"""Convert any Python object to a JSON-serializable type or string.
Args:
obj: The object to serialize.
Returns:
The JSON-serialized object string or string.
"""
try:
# Try direct JSON serialization first
return json.dumps(obj)
except (TypeError, OverflowError):
return str(obj)
def _content_to_message_param(
content: types.Content,
) -> Message:
"""Converts a types.Content to a litellm Message.
Args:
content: The content to convert.
Returns:
The litellm Message.
"""
if content.parts and content.parts[0].function_response:
return ChatCompletionToolMessage(
role="tool",
tool_call_id=content.parts[0].function_response.id,
content=_safe_json_serialize(
content.parts[0].function_response.response
),
)
role = _to_litellm_role(content.role)
if role == "user":
return ChatCompletionUserMessage(
role="user", content=_get_content(content.parts)
)
else:
tool_calls = [
ChatCompletionMessageToolCall(
type="function",
id=part.function_call.id,
function=Function(
name=part.function_call.name,
arguments=part.function_call.args,
),
)
for part in content.parts
if part.function_call
]
return ChatCompletionAssistantMessage(
role=role,
content=_get_content(content.parts),
tool_calls=tool_calls or None,
)
def _get_content(parts: Iterable[types.Part]) -> OpenAIMessageContent | str:
"""Converts a list of parts to litellm content.
Args:
parts: The parts to convert.
Returns:
The litellm content.
"""
content_objects = []
for part in parts:
if part.text:
if len(parts) == 1:
return part.text
content_objects.append(
ChatCompletionTextObject(
type="text",
text=part.text,
)
)
elif (
part.inline_data
and part.inline_data.data
and part.inline_data.mime_type
):
base64_string = base64.b64encode(part.inline_data.data).decode("utf-8")
data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"
if part.inline_data.mime_type.startswith("image"):
content_objects.append(
ChatCompletionImageUrlObject(
type="image_url",
image_url=data_uri,
)
)
elif part.inline_data.mime_type.startswith("video"):
content_objects.append(
ChatCompletionVideoUrlObject(
type="video_url",
video_url=data_uri,
)
)
else:
raise ValueError("LiteLlm(BaseLlm) does not support this content part.")
return content_objects
def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
"""Converts a types.Content role to a litellm role.
Args:
role: The types.Content role.
Returns:
The litellm role.
"""
if role in ["model", "assistant"]:
return "assistant"
return "user"
TYPE_LABELS = {
"STRING": "string",
"NUMBER": "number",
"BOOLEAN": "boolean",
"OBJECT": "object",
"ARRAY": "array",
"INTEGER": "integer",
}
def _schema_to_dict(schema: types.Schema) -> dict:
"""Recursively converts a types.Schema to a dictionary.
Args:
schema: The schema to convert.
Returns:
The dictionary representation of the schema.
"""
schema_dict = schema.model_dump(exclude_none=True)
if "type" in schema_dict:
schema_dict["type"] = schema_dict["type"].lower()
if "items" in schema_dict:
if isinstance(schema_dict["items"], dict):
schema_dict["items"] = _schema_to_dict(
types.Schema.model_validate(schema_dict["items"])
)
elif isinstance(schema_dict["items"]["type"], types.Type):
schema_dict["items"]["type"] = TYPE_LABELS[
schema_dict["items"]["type"].value
]
if "properties" in schema_dict:
properties = {}
for key, value in schema_dict["properties"].items():
if isinstance(value, types.Schema):
properties[key] = _schema_to_dict(value)
else:
properties[key] = value
if "type" in properties[key]:
properties[key]["type"] = properties[key]["type"].lower()
schema_dict["properties"] = properties
return schema_dict
def _function_declaration_to_tool_param(
function_declaration: types.FunctionDeclaration,
) -> dict:
"""Converts a types.FunctionDeclaration to a openapi spec dictionary.
Args:
function_declaration: The function declaration to convert.
Returns:
The openapi spec dictionary representation of the function declaration.
"""
assert function_declaration.name
properties = {}
if (
function_declaration.parameters
and function_declaration.parameters.properties
):
for key, value in function_declaration.parameters.properties.items():
properties[key] = _schema_to_dict(value)
return {
"type": "function",
"function": {
"name": function_declaration.name,
"description": function_declaration.description or "",
"parameters": {
"type": "object",
"properties": properties,
},
},
}
def _model_response_to_chunk(
response: ModelResponse,
) -> Generator[
Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
]:
"""Converts a litellm message to text or function chunk.
Args:
response: The response from the model.
Yields:
A tuple of text or function chunk and finish reason.
"""
message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
finish_reason = response["choices"][0].get("finish_reason", None)
# check streaming delta
if message is None and response["choices"][0].get("delta", None):
message = response["choices"][0]["delta"]
if message.get("content", None):
yield TextChunk(text=message.get("content")), finish_reason
if message.get("tool_calls", None):
for tool_call in message.get("tool_calls"):
# aggregate tool_call
if tool_call.type == "function":
yield FunctionChunk(
id=tool_call.id,
name=tool_call.function.name,
args=tool_call.function.arguments,
), finish_reason
if finish_reason and not (
message.get("content", None) or message.get("tool_calls", None)
):
yield None, finish_reason
if not message:
yield None, None
def _model_response_to_generate_content_response(
response: ModelResponse,
) -> LlmResponse:
"""Converts a litellm response to LlmResponse.
Args:
response: The model response.
Returns:
The LlmResponse.
"""
message = None
if response.get("choices", None):
message = response["choices"][0].get("message", None)
if not message:
raise ValueError("No message in response")
return _message_to_generate_content_response(message)
def _message_to_generate_content_response(
message: Message, is_partial: bool = False
) -> LlmResponse:
"""Converts a litellm message to LlmResponse.
Args:
message: The message to convert.
is_partial: Whether the message is partial.
Returns:
The LlmResponse.
"""
parts = []
if message.get("content", None):
parts.append(types.Part.from_text(text=message.get("content")))
if message.get("tool_calls", None):
for tool_call in message.get("tool_calls"):
if tool_call.type == "function":
part = types.Part.from_function_call(
name=tool_call.function.name,
args=json.loads(tool_call.function.arguments or "{}"),
)
part.function_call.id = tool_call.id
parts.append(part)
return LlmResponse(
content=types.Content(role="model", parts=parts), partial=is_partial
)
def _get_completion_inputs(
llm_request: LlmRequest,
) -> tuple[Iterable[Message], Iterable[dict]]:
"""Converts an LlmRequest to litellm inputs.
Args:
llm_request: The LlmRequest to convert.
Returns:
The litellm inputs (message list and list of tool dictionaries).
"""
messages = [
_content_to_message_param(content)
for content in llm_request.contents or []
]
if llm_request.config.system_instruction:
messages.insert(
0,
ChatCompletionDeveloperMessage(
role="developer",
content=llm_request.config.system_instruction,
),
)
tools = None
if (
llm_request.config
and llm_request.config.tools
and llm_request.config.tools[0].function_declarations
):
tools = [
_function_declaration_to_tool_param(tool)
for tool in llm_request.config.tools[0].function_declarations
]
return messages, tools
def _build_function_declaration_log(
func_decl: types.FunctionDeclaration,
) -> str:
"""Builds a function declaration log.
Args:
func_decl: The function declaration to convert.
Returns:
The function declaration log.
"""
param_str = "{}"
if func_decl.parameters and func_decl.parameters.properties:
param_str = str({
k: v.model_dump(exclude_none=True)
for k, v in func_decl.parameters.properties.items()
})
return_str = "None"
if func_decl.response:
return_str = str(func_decl.response.model_dump(exclude_none=True))
return f"{func_decl.name}: {param_str} -> {return_str}"
def _build_request_log(req: LlmRequest) -> str:
"""Builds a request log.
Args:
req: The request to convert.
Returns:
The request log.
"""
function_decls: list[types.FunctionDeclaration] = cast(
list[types.FunctionDeclaration],
req.config.tools[0].function_declarations if req.config.tools else [],
)
function_logs = (
[
_build_function_declaration_log(func_decl)
for func_decl in function_decls
]
if function_decls
else []
)
contents_logs = [
content.model_dump_json(
exclude_none=True,
exclude={
"parts": {
i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
}
},
)
for content in req.contents
]
return f"""
LLM Request:
-----------------------------------------------------------
System Instruction:
{req.config.system_instruction}
-----------------------------------------------------------
Contents:
{_NEW_LINE.join(contents_logs)}
-----------------------------------------------------------
Functions:
{_NEW_LINE.join(function_logs)}
-----------------------------------------------------------
"""
class LiteLlm(BaseLlm):
"""Wrapper around litellm.
This wrapper can be used with any of the models supported by litellm. The
environment variable(s) needed for authenticating with the model endpoint must
be set prior to instantiating this class.
Example usage:
```
os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
os.environ["VERTEXAI_LOCATION"] = "your-gcp-location"
agent = Agent(
model=LiteLlm(model="vertex_ai/claude-3-7-sonnet@20250219"),
...
)
```
Attributes:
model: The name of the LiteLlm model.
llm_client: The LLM client to use for the model.
model_config: The model config.
"""
llm_client: LiteLLMClient = Field(default_factory=LiteLLMClient)
"""The LLM client to use for the model."""
_additional_args: Dict[str, Any] = None
def __init__(self, model: str, **kwargs):
"""Initializes the LiteLlm class.
Args:
model: The name of the LiteLlm model.
**kwargs: Additional arguments to pass to the litellm completion api.
"""
super().__init__(model=model, **kwargs)
self._additional_args = kwargs
# Drop kwargs that must not be forwarded to the completion API: llm_client
# is our own field, and messages, tools and stream are managed internally.
self._additional_args.pop("llm_client", None)
self._additional_args.pop("messages", None)
self._additional_args.pop("tools", None)
# The public API called from the runner determines whether to stream.
self._additional_args.pop("stream", None)
async def generate_content_async(
self, llm_request: LlmRequest, stream: bool = False
) -> AsyncGenerator[LlmResponse, None]:
"""Generates content asynchronously.
Args:
llm_request: LlmRequest, the request to send to the LiteLlm model.
stream: bool = False, whether to do streaming call.
Yields:
LlmResponse: The model response.
"""
logger.info(_build_request_log(llm_request))
messages, tools = _get_completion_inputs(llm_request)
completion_args = {
"model": self.model,
"messages": messages,
"tools": tools,
}
completion_args.update(self._additional_args)
if stream:
text = ""
function_name = ""
function_args = ""
function_id = None
completion_args["stream"] = True
for part in self.llm_client.completion(**completion_args):
for chunk, finish_reason in _model_response_to_chunk(part):
if isinstance(chunk, FunctionChunk):
if chunk.name:
function_name += chunk.name
if chunk.args:
function_args += chunk.args
function_id = chunk.id or function_id
elif isinstance(chunk, TextChunk):
text += chunk.text
yield _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content=chunk.text,
),
is_partial=True,
)
if finish_reason == "tool_calls" and function_id:
yield _message_to_generate_content_response(
ChatCompletionAssistantMessage(
role="assistant",
content="",
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
id=function_id,
function=Function(
name=function_name,
arguments=function_args,
),
)
],
)
)
function_name = ""
function_args = ""
function_id = None
elif finish_reason == "stop" and text:
yield _message_to_generate_content_response(
ChatCompletionAssistantMessage(role="assistant", content=text)
)
text = ""
else:
response = await self.llm_client.acompletion(**completion_args)
yield _model_response_to_generate_content_response(response)
@staticmethod
@override
def supported_models() -> list[str]:
"""Provides the list of supported models.
LiteLlm supports all models supported by litellm. We do not keep track of
those models here, so we return an empty list.
Returns:
A list of supported models.
"""
return []
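
To make the declaration-to-tool conversion above concrete, here is a hedged sketch of what `_function_declaration_to_tool_param` produces for a made-up declaration; the exact dict shape may vary slightly across google-genai versions:

```
from google.genai import types

weather_decl = types.FunctionDeclaration(
    name='get_weather',
    description='Returns the current weather for a city.',
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={'city': types.Schema(type=types.Type.STRING)},
    ),
)

tool_param = _function_declaration_to_tool_param(weather_decl)
# Approximately:
# {'type': 'function',
#  'function': {'name': 'get_weather',
#               'description': 'Returns the current weather for a city.',
#               'parameters': {'type': 'object',
#                              'properties': {'city': {'type': 'string'}}}}}
```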


@@ -0,0 +1,98 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from ..tools.base_tool import BaseTool
class LlmRequest(BaseModel):
"""LLM request class that allows passing in tools, output schema and system
instructions to the model.
Attributes:
model: The model name.
contents: The contents to send to the model.
config: Additional config for the generate content request.
tools_dict: The tools dictionary.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
"""The model config."""
model: Optional[str] = None
"""The model name."""
contents: list[types.Content] = Field(default_factory=list)
"""The contents to send to the model."""
config: Optional[types.GenerateContentConfig] = None
live_connect_config: types.LiveConnectConfig = types.LiveConnectConfig()
"""Additional config for the generate content request.
tools in generate_content_config should not be set.
"""
tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
"""The tools dictionary."""
def append_instructions(self, instructions: list[str]) -> None:
"""Appends instructions to the system instruction.
Args:
instructions: The instructions to append.
"""
if self.config.system_instruction:
self.config.system_instruction += '\n\n' + '\n\n'.join(instructions)
else:
self.config.system_instruction = '\n\n'.join(instructions)
def append_tools(self, tools: list[BaseTool]) -> None:
"""Appends tools to the request.
Args:
tools: The tools to append.
"""
if not tools:
return
declarations = []
for tool in tools:
if isinstance(tool, BaseTool):
declaration = tool._get_declaration()
else:
declaration = tool.get_declaration()
if declaration:
declarations.append(declaration)
self.tools_dict[tool.name] = tool
if declarations:
self.config.tools.append(types.Tool(function_declarations=declarations))
def set_output_schema(self, base_model: type[BaseModel]) -> None:
"""Sets the output schema for the request.
Args:
base_model: The pydantic base model to set the output schema to.
"""
self.config.response_schema = base_model
self.config.response_mime_type = 'application/json'
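
A short, hedged sketch of building an `LlmRequest` (the class defined above) and layering on instructions and an output schema; the model name and schema are placeholders:

```
from google.genai import types
from pydantic import BaseModel


class CityWeather(BaseModel):
  city: str
  temperature_c: float


request = LlmRequest(
    model='gemini-1.5-flash',
    config=types.GenerateContentConfig(system_instruction='You are terse.'),
)
request.append_instructions(['Answer with JSON only.'])
request.set_output_schema(CityWeather)
# append_tools([...]) would additionally populate config.tools and
# tools_dict from BaseTool instances.
```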


@@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Optional
from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
class LlmResponse(BaseModel):
"""LLM response class that provides the first candidate response from the
model if available. Otherwise, returns error code and message.
Attributes:
content: The content of the response.
grounding_metadata: The grounding metadata of the response.
partial: Indicates whether the text content is part of an unfinished text
stream. Only used for streaming mode and when the content is plain text.
turn_complete: Indicates whether the response from the model is complete.
Only used for streaming mode.
error_code: Error code if the response is an error. Code varies by model.
error_message: Error message if the response is an error.
interrupted: Flag indicating that the LLM was interrupted while generating
the content. Usually this is due to user interruption during bidi streaming.
"""
model_config = ConfigDict(extra='forbid')
"""The model config."""
content: Optional[types.Content] = None
"""The content of the response."""
grounding_metadata: Optional[types.GroundingMetadata] = None
"""The grounding metadata of the response."""
partial: Optional[bool] = None
"""Indicates whether the text content is part of a unfinished text stream.
Only used for streaming mode and when the content is plain text.
"""
turn_complete: Optional[bool] = None
"""Indicates whether the response from the model is complete.
Only used for streaming mode.
"""
error_code: Optional[str] = None
"""Error code if the response is an error. Code varies by model."""
error_message: Optional[str] = None
"""Error message if the response is an error."""
interrupted: Optional[bool] = None
"""Flag indicating that LLM was interrupted when generating the content.
Usually it's due to user interruption during a bidi streaming.
"""
@staticmethod
def create(
generate_content_response: types.GenerateContentResponse,
) -> 'LlmResponse':
"""Creates an LlmResponse from a GenerateContentResponse.
Args:
generate_content_response: The GenerateContentResponse to create the
LlmResponse from.
Returns:
The LlmResponse.
"""
if generate_content_response.candidates:
candidate = generate_content_response.candidates[0]
if candidate.content and candidate.content.parts:
return LlmResponse(
content=candidate.content,
grounding_metadata=candidate.grounding_metadata,
)
else:
return LlmResponse(
error_code=candidate.finish_reason,
error_message=candidate.finish_message,
)
else:
if generate_content_response.prompt_feedback:
prompt_feedback = generate_content_response.prompt_feedback
return LlmResponse(
error_code=prompt_feedback.block_reason,
error_message=prompt_feedback.block_reason_message,
)
else:
return LlmResponse(
error_code='UNKNOWN_ERROR',
error_message='Unknown error.',
)
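
Since only some fields are populated on any given response, consumers typically branch on them; a small hedged helper illustrating the common cases (the function itself is not part of the ADK):

```
def describe(llm_response: LlmResponse) -> str:
  """Returns a short, human-readable summary of an LlmResponse."""
  if llm_response.error_code:
    return f'error {llm_response.error_code}: {llm_response.error_message}'
  if llm_response.turn_complete:
    return 'turn complete'
  if llm_response.content and llm_response.content.parts:
    kind = 'partial ' if llm_response.partial else ''
    return f'{kind}content with {len(llm_response.content.parts)} part(s)'
  return 'empty response'
```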


@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The registry class for model."""
from __future__ import annotations
from functools import lru_cache
import logging
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .base_llm import BaseLlm
logger = logging.getLogger(__name__)
_llm_registry_dict: dict[str, type[BaseLlm]] = {}
"""Registry for LLMs.
Key is the regex that matches the model name.
Value is the class that implements the model.
"""
class LLMRegistry:
"""Registry for LLMs."""
@staticmethod
def new_llm(model: str) -> BaseLlm:
"""Creates a new LLM instance.
Args:
model: The model name.
Returns:
The LLM instance.
"""
return LLMRegistry.resolve(model)(model=model)
@staticmethod
def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
"""Registers a new LLM class.
Args:
model_name_regex: The regex that matches the model name.
llm_cls: The class that implements the model.
"""
if model_name_regex in _llm_registry_dict:
logger.info(
'Updating LLM class for %s from %s to %s',
model_name_regex,
_llm_registry_dict[model_name_regex],
llm_cls,
)
_llm_registry_dict[model_name_regex] = llm_cls
@staticmethod
def register(llm_cls: type[BaseLlm]):
"""Registers a new LLM class.
Args:
llm_cls: The class that implements the model.
"""
for regex in llm_cls.supported_models():
LLMRegistry._register(regex, llm_cls)
@staticmethod
@lru_cache(maxsize=32)
def resolve(model: str) -> type[BaseLlm]:
"""Resolves the model to a BaseLlm subclass.
Args:
model: The model name.
Returns:
The BaseLlm subclass.
Raises:
ValueError: If the model is not found.
"""
for regex, llm_class in _llm_registry_dict.items():
if re.compile(regex).fullmatch(model):
return llm_class
raise ValueError(f'Model {model} not found.')
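
The registry resolves a name by full-matching each registered regex, so both plain Gemini names and the fine-tuned Vertex endpoint paths listed in `Gemini.supported_models()` resolve to the same class. A hedged sketch, assuming the exports shown in the package `__init__` above:

```
from google.adk.models import Gemini, LLMRegistry  # assumed import path

LLMRegistry.register(Gemini)

assert LLMRegistry.resolve('gemini-1.5-flash') is Gemini
assert (
    LLMRegistry.resolve('projects/p/locations/us-central1/endpoints/123')
    is Gemini
)

# An unregistered name raises ValueError('Model ... not found.').
```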