Mirror of https://github.com/EvolutionAPI/adk-python.git, synced 2025-12-23 21:57:44 -06:00

Agent Development Kit (ADK): an easy-to-use and powerful framework to build AI agents.

New files in this commit:

31  src/google/adk/models/__init__.py  Normal file
@@ -0,0 +1,31 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines the interface to support a model."""

from .base_llm import BaseLlm
from .google_llm import Gemini
from .llm_request import LlmRequest
from .llm_response import LlmResponse
from .registry import LLMRegistry

__all__ = [
    'BaseLlm',
    'Gemini',
    'LLMRegistry',
]


for regex in Gemini.supported_models():
  LLMRegistry.register(Gemini)
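The loop above registers the Gemini class once per supported-model regex. The registry module itself is not part of this commit, so purely as a rough illustration of the regex-keyed registry pattern this implies (the names below are hypothetical, not the ADK's actual LLMRegistry API):

import re
from typing import Type

# Hypothetical illustration only; not the ADK registry implementation.
_llm_registry: dict[str, Type] = {}

def register(llm_cls) -> None:
  # One entry per regex returned by the class, mirroring the loop above.
  for regex in llm_cls.supported_models():
    _llm_registry[regex] = llm_cls

def resolve(model_name: str):
  # First registered pattern that fully matches the model name wins.
  for regex, llm_cls in _llm_registry.items():
    if re.fullmatch(regex, model_name):
      return llm_cls
  raise ValueError(f'No registered LLM class matches: {model_name}')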
243  src/google/adk/models/anthropic_llm.py  Normal file
@@ -0,0 +1,243 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Anthropic integration for Claude models."""

from __future__ import annotations

from functools import cached_property
import logging
import os
from typing import AsyncGenerator
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional, Union
from typing import TYPE_CHECKING

from anthropic import AnthropicVertex
from anthropic import NOT_GIVEN
from anthropic import types as anthropic_types
from google.genai import types
from pydantic import BaseModel
from typing_extensions import override

from .base_llm import BaseLlm
from .llm_response import LlmResponse

if TYPE_CHECKING:
  from .llm_request import LlmRequest

__all__ = ["Claude"]

logger = logging.getLogger(__name__)

MAX_TOKEN = 1024


class ClaudeRequest(BaseModel):
  system_instruction: str
  messages: Iterable[anthropic_types.MessageParam]
  tools: list[anthropic_types.ToolParam]


def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
  if role in ["model", "assistant"]:
    return "assistant"
  return "user"


def to_google_genai_finish_reason(
    anthropic_stop_reason: Optional[str],
) -> types.FinishReason:
  if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
    return "STOP"
  if anthropic_stop_reason == "max_tokens":
    return "MAX_TOKENS"
  return "FINISH_REASON_UNSPECIFIED"


def part_to_message_block(
    part: types.Part,
) -> Union[
    anthropic_types.TextBlockParam,
    anthropic_types.ImageBlockParam,
    anthropic_types.ToolUseBlockParam,
    anthropic_types.ToolResultBlockParam,
]:
  if part.text:
    return anthropic_types.TextBlockParam(text=part.text, type="text")
  if part.function_call:
    assert part.function_call.name

    return anthropic_types.ToolUseBlockParam(
        id=part.function_call.id or "",
        name=part.function_call.name,
        input=part.function_call.args,
        type="tool_use",
    )
  if part.function_response:
    content = ""
    if (
        "result" in part.function_response.response
        and part.function_response.response["result"]
    ):
      # Transformation is required because the content is a list of dict.
      # ToolResultBlockParam content doesn't support list of dict. Converting
      # to str to prevent anthropic.BadRequestError from being thrown.
      content = str(part.function_response.response["result"])
    return anthropic_types.ToolResultBlockParam(
        tool_use_id=part.function_response.id or "",
        type="tool_result",
        content=content,
        is_error=False,
    )
  raise NotImplementedError("Not supported yet.")


def content_to_message_param(
    content: types.Content,
) -> anthropic_types.MessageParam:
  return {
      "role": to_claude_role(content.role),
      "content": [part_to_message_block(part) for part in content.parts or []],
  }


def content_block_to_part(
    content_block: anthropic_types.ContentBlock,
) -> types.Part:
  if isinstance(content_block, anthropic_types.TextBlock):
    return types.Part.from_text(text=content_block.text)
  if isinstance(content_block, anthropic_types.ToolUseBlock):
    assert isinstance(content_block.input, dict)
    part = types.Part.from_function_call(
        name=content_block.name, args=content_block.input
    )
    part.function_call.id = content_block.id
    return part
  raise NotImplementedError("Not supported yet.")


def message_to_generate_content_response(
    message: anthropic_types.Message,
) -> LlmResponse:

  return LlmResponse(
      content=types.Content(
          role="model",
          parts=[content_block_to_part(cb) for cb in message.content],
      ),
      # TODO: Deal with these later.
      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
      # usage_metadata=types.GenerateContentResponseUsageMetadata(
      #     prompt_token_count=message.usage.input_tokens,
      #     candidates_token_count=message.usage.output_tokens,
      #     total_token_count=(
      #         message.usage.input_tokens + message.usage.output_tokens
      #     ),
      # ),
  )


def function_declaration_to_tool_param(
    function_declaration: types.FunctionDeclaration,
) -> anthropic_types.ToolParam:
  assert function_declaration.name

  properties = {}
  if (
      function_declaration.parameters
      and function_declaration.parameters.properties
  ):
    for key, value in function_declaration.parameters.properties.items():
      value_dict = value.model_dump(exclude_none=True)
      if "type" in value_dict:
        value_dict["type"] = value_dict["type"].lower()
      properties[key] = value_dict

  return anthropic_types.ToolParam(
      name=function_declaration.name,
      description=function_declaration.description or "",
      input_schema={
          "type": "object",
          "properties": properties,
      },
  )
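The helpers above translate between google.genai types and Anthropic message/tool blocks. As a hedged usage sketch (assuming the module is importable as google.adk.models.anthropic_llm), converting a plain user turn looks roughly like this:

from google.genai import types
from google.adk.models import anthropic_llm  # assumed import path

content = types.Content(
    role="user",
    parts=[types.Part.from_text(text="What is the weather in Paris?")],
)
msg = anthropic_llm.content_to_message_param(content)
# Expected shape:
# {"role": "user",
#  "content": [{"text": "What is the weather in Paris?", "type": "text"}]}

# Finish-reason mapping is a plain string translation:
assert anthropic_llm.to_google_genai_finish_reason("end_turn") == "STOP"
assert anthropic_llm.to_google_genai_finish_reason("max_tokens") == "MAX_TOKENS"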
class Claude(BaseLlm):
  model: str = "claude-3-5-sonnet-v2@20241022"

  @staticmethod
  @override
  def supported_models() -> list[str]:
    return [r"claude-3-.*"]

  @override
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    messages = [
        content_to_message_param(content)
        for content in llm_request.contents or []
    ]
    tools = NOT_GIVEN
    if (
        llm_request.config
        and llm_request.config.tools
        and llm_request.config.tools[0].function_declarations
    ):
      tools = [
          function_declaration_to_tool_param(tool)
          for tool in llm_request.config.tools[0].function_declarations
      ]
    tool_choice = (
        anthropic_types.ToolChoiceAutoParam(
            type="auto",
            # TODO: allow parallel tool use.
            disable_parallel_tool_use=True,
        )
        if llm_request.tools_dict
        else NOT_GIVEN
    )
    message = self._anthropic_client.messages.create(
        model=llm_request.model,
        system=llm_request.config.system_instruction,
        messages=messages,
        tools=tools,
        tool_choice=tool_choice,
        max_tokens=MAX_TOKEN,
    )
    logger.info(
        "Claude response: %s",
        message.model_dump_json(indent=2, exclude_none=True),
    )
    yield message_to_generate_content_response(message)

  @cached_property
  def _anthropic_client(self) -> AnthropicVertex:
    if (
        "GOOGLE_CLOUD_PROJECT" not in os.environ
        or "GOOGLE_CLOUD_LOCATION" not in os.environ
    ):
      raise ValueError(
          "GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
          " Anthropic on Vertex."
      )

    return AnthropicVertex(
        project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
        region=os.environ["GOOGLE_CLOUD_LOCATION"],
    )
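Claude here is served through AnthropicVertex, so the two environment variables checked in _anthropic_client must be set before the first request. Note that __init__.py only registers Gemini; if you wanted Claude resolvable by model name as well, a registration along the same lines (shown here as an assumption, not something this commit does) could look like:

import os

from google.adk.models.anthropic_llm import Claude   # assumed import path
from google.adk.models.registry import LLMRegistry   # assumed import path

os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"    # required by _anthropic_client
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-east5"     # a region that serves Claude on Vertex AI

LLMRegistry.register(Claude)  # mirrors the Gemini registration in __init__.py
claude = Claude(model="claude-3-5-sonnet-v2@20241022")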
87  src/google/adk/models/base_llm.py  Normal file
@@ -0,0 +1,87 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import abstractmethod
from typing import AsyncGenerator
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import ConfigDict

from .base_llm_connection import BaseLlmConnection

if TYPE_CHECKING:
  from .llm_request import LlmRequest
  from .llm_response import LlmResponse


class BaseLlm(BaseModel):
  """The BaseLLM class.

  Attributes:
    model: The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001.
    model_config: The model config.
  """

  model_config = ConfigDict(
      # This allows us to use arbitrary types in the model. E.g. PIL.Image.
      arbitrary_types_allowed=True,
  )
  """The model config."""

  model: str
  """The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001."""

  @classmethod
  def supported_models(cls) -> list[str]:
    """Returns a list of supported models in regex for LlmRegistry."""
    return []

  @abstractmethod
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Generates one content from the given contents and tools.

    Args:
      llm_request: LlmRequest, the request to send to the LLM.
      stream: bool = False, whether to do streaming call.

    Yields:
      a generator of types.Content.

      For non-streaming call, it will only yield one Content.

      For streaming call, it may yield more than one content, but all yielded
      contents should be treated as one content by merging the parts list.
    """
    raise NotImplementedError(
        f'Async generation is not supported for {self.model}.'
    )
    yield  # AsyncGenerator requires a yield statement in function body.

  def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
    """Creates a live connection to the LLM.

    Args:
      llm_request: LlmRequest, the request to send to the LLM.

    Returns:
      BaseLlmConnection, the connection to the LLM.
    """
    raise NotImplementedError(
        f'Live connection is not supported for {self.model}.'
    )
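BaseLlm only requires a model name plus an async generator generate_content_async; connect is optional and raises by default. A minimal sketch of a custom subclass (purely illustrative, it echoes the last user text back):

from typing import AsyncGenerator

from google.genai import types

from google.adk.models.base_llm import BaseLlm          # assumed import path
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse


class EchoLlm(BaseLlm):
  """Toy LLM that echoes the last user text back; for illustration only."""

  model: str = "echo-1"

  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    last_text = ""
    if llm_request.contents and llm_request.contents[-1].parts:
      last_text = llm_request.contents[-1].parts[0].text or ""
    yield LlmResponse(
        content=types.Content(
            role="model", parts=[types.Part.from_text(text=last_text)]
        )
    )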
76  src/google/adk/models/base_llm_connection.py  Normal file
@@ -0,0 +1,76 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractmethod
from typing import AsyncGenerator
from google.genai import types
from .llm_response import LlmResponse


class BaseLlmConnection:
  """The base class for a live model connection."""

  @abstractmethod
  async def send_history(self, history: list[types.Content]):
    """Sends the conversation history to the model.

    You call this method right after setting up the model connection.
    The model will respond if the last content is from user, otherwise it will
    wait for new user input before responding.

    Args:
      history: The conversation history to send to the model.
    """
    pass

  @abstractmethod
  async def send_content(self, content: types.Content):
    """Sends a user content to the model.

    The model will respond immediately upon receiving the content.
    If you send function responses, all parts in the content should be function
    responses.

    Args:
      content: The content to send to the model.
    """
    pass

  @abstractmethod
  async def send_realtime(self, blob: types.Blob):
    """Sends a chunk of audio or a frame of video to the model in realtime.

    The model may not respond immediately upon receiving the blob. It will do
    voice activity detection and decide when to respond.

    Args:
      blob: The blob to send to the model.
    """
    pass

  @abstractmethod
  async def receive(self) -> AsyncGenerator[LlmResponse, None]:
    """Receives the model response using the llm server connection.

    Args: None.

    Yields:
      LlmResponse: The model response.
    """
    pass

  @abstractmethod
  async def close(self):
    """Closes the llm server connection."""
    pass
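The abstract connection implies a fairly fixed call order: prime with history, optionally send new content, consume the response stream, then close. A hedged consumer-side sketch, with connection standing in for any concrete BaseLlmConnection (for example the GeminiLlmConnection defined next):

from google.genai import types

async def run_turn(connection, history: list[types.Content]) -> str:
  """Illustrative driver: prime with history, then read one model turn."""
  await connection.send_history(history)  # model replies if last content is from user
  collected = []
  async for llm_response in connection.receive():
    if llm_response.content and llm_response.content.parts:
      part = llm_response.content.parts[0]
      if part.text and not llm_response.partial:  # keep only merged full-text events
        collected.append(part.text)
    if llm_response.turn_complete:
      break
  await connection.close()
  return "".join(collected)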
200  src/google/adk/models/gemini_llm_connection.py  Normal file
@@ -0,0 +1,200 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import AsyncGenerator

from google.genai import live
from google.genai import types

from .base_llm_connection import BaseLlmConnection
from .llm_response import LlmResponse

logger = logging.getLogger(__name__)


class GeminiLlmConnection(BaseLlmConnection):
  """The Gemini model connection."""

  def __init__(self, gemini_session: live.AsyncSession):
    self._gemini_session = gemini_session

  async def send_history(self, history: list[types.Content]):
    """Sends the conversation history to the gemini model.

    You call this method right after setting up the model connection.
    The model will respond if the last content is from user, otherwise it will
    wait for new user input before responding.

    Args:
      history: The conversation history to send to the model.
    """

    # TODO: Remove this filter and translate unary contents to streaming
    # contents properly.

    # We ignore any audio from user during the agent transfer phase
    contents = [
        content
        for content in history
        if content.parts and content.parts[0].text
    ]

    if contents:
      await self._gemini_session.send(
          input=types.LiveClientContent(
              turns=contents,
              turn_complete=contents[-1].role == 'user',
          ),
      )
    else:
      logger.info('no content is sent')

  async def send_content(self, content: types.Content):
    """Sends a user content to the gemini model.

    The model will respond immediately upon receiving the content.
    If you send function responses, all parts in the content should be function
    responses.

    Args:
      content: The content to send to the model.
    """

    assert content.parts
    if content.parts[0].function_response:
      # All parts have to be function responses.
      function_responses = [part.function_response for part in content.parts]
      logger.debug('Sending LLM function response: %s', function_responses)
      await self._gemini_session.send(
          input=types.LiveClientToolResponse(
              function_responses=function_responses
          ),
      )
    else:
      logger.debug('Sending LLM new content %s', content)
      await self._gemini_session.send(
          input=types.LiveClientContent(
              turns=[content],
              turn_complete=True,
          )
      )

  async def send_realtime(self, blob: types.Blob):
    """Sends a chunk of audio or a frame of video to the model in realtime.

    Args:
      blob: The blob to send to the model.
    """

    input_blob = blob.model_dump()
    logger.debug('Sending LLM Blob: %s', input_blob)
    await self._gemini_session.send(input=input_blob)

  def __build_full_text_response(self, text: str):
    """Builds a full text response.

    The text should not be partial and the returned LlmResponse is not
    partial.

    Args:
      text: The text to be included in the response.

    Returns:
      An LlmResponse containing the full text.
    """
    return LlmResponse(
        content=types.Content(
            role='model',
            parts=[types.Part.from_text(text=text)],
        ),
    )

  async def receive(self) -> AsyncGenerator[LlmResponse, None]:
    """Receives the model response using the llm server connection.

    Yields:
      LlmResponse: The model response.
    """

    text = ''
    async for message in self._gemini_session.receive():
      logger.debug('Got LLM Live message: %s', message)
      if message.server_content:
        content = message.server_content.model_turn
        if content and content.parts:
          llm_response = LlmResponse(
              content=content, interrupted=message.server_content.interrupted
          )
          if content.parts[0].text:
            text += content.parts[0].text
            llm_response.partial = True
          # don't yield the merged text event when receiving audio data
          elif text and not content.parts[0].inline_data:
            yield self.__build_full_text_response(text)
            text = ''
          yield llm_response

        if (
            message.server_content.output_transcription
            and message.server_content.output_transcription.text
        ):
          # TODO: Right now, we just support output_transcription without
          # changing interface and data protocol. Later, we can consider to
          # support output_transcription as a separate field in LlmResponse.

          # Transcription is always considered as a partial event.
          # We rely on other control signals to determine when to yield the
          # full text response (turn_complete, interrupted, or tool_call).
          text += message.server_content.output_transcription.text
          parts = [
              types.Part.from_text(
                  text=message.server_content.output_transcription.text
              )
          ]
          llm_response = LlmResponse(
              content=types.Content(role='model', parts=parts), partial=True
          )
          yield llm_response

        if message.server_content.turn_complete:
          if text:
            yield self.__build_full_text_response(text)
            text = ''
          yield LlmResponse(
              turn_complete=True, interrupted=message.server_content.interrupted
          )
          break
        # In case of empty content or parts, we still surface it.
        # If it's an interrupted message, we merge the previous partial
        # text; otherwise we don't merge, because content can be None when the
        # model safety threshold is triggered.
        if message.server_content.interrupted and text:
          yield self.__build_full_text_response(text)
          text = ''
        yield LlmResponse(interrupted=message.server_content.interrupted)
      if message.tool_call:
        if text:
          yield self.__build_full_text_response(text)
          text = ''
        parts = [
            types.Part(function_call=function_call)
            for function_call in message.tool_call.function_calls
        ]
        yield LlmResponse(content=types.Content(role='model', parts=parts))

  async def close(self):
    """Closes the llm server connection."""

    await self._gemini_session.close()
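One constraint worth keeping in mind from send_content above: when returning tool results, every part of the content must be a function response. A hedged sketch of building such a content (the tool name and call id below are made up for illustration):

from google.genai import types

# Hypothetical tool result being sent back to the live session.
function_response = types.FunctionResponse(
    id='call_123',                        # id of the function_call being answered
    name='get_weather',
    response={'result': 'sunny, 22C'},
)
tool_reply = types.Content(
    role='user',
    parts=[types.Part(function_response=function_response)],
)
# await connection.send_content(tool_reply)  # routed as a LiveClientToolResponse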
331  src/google/adk/models/google_llm.py  Normal file
@@ -0,0 +1,331 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import contextlib
from functools import cached_property
import logging
import sys
from typing import AsyncGenerator
from typing import cast
from typing import Generator
from typing import TYPE_CHECKING

from google.genai import Client
from google.genai import types
from typing_extensions import override

from .. import version
from .base_llm import BaseLlm
from .base_llm_connection import BaseLlmConnection
from .gemini_llm_connection import GeminiLlmConnection
from .llm_response import LlmResponse

if TYPE_CHECKING:
  from .llm_request import LlmRequest

logger = logging.getLogger(__name__)

_NEW_LINE = '\n'
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}


class Gemini(BaseLlm):
  """Integration for Gemini models.

  Attributes:
    model: The name of the Gemini model.
  """

  model: str = 'gemini-1.5-flash'

  @staticmethod
  @override
  def supported_models() -> list[str]:
    """Provides the list of supported models.

    Returns:
      A list of supported models.
    """

    return [
        r'gemini-.*',
        # fine-tuned vertex endpoint pattern
        r'projects\/.+\/locations\/.+\/endpoints\/.+',
        # vertex gemini long name
        r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
    ]
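As an aside, these patterns are what the registry matches model names against. A quick self-contained check of the three families (the model names below are illustrative, and fullmatch is used only for this illustration):

import re

patterns = [
    r'gemini-.*',
    r'projects\/.+\/locations\/.+\/endpoints\/.+',
    r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
]
names = [
    'gemini-1.5-flash',                                       # short name
    'projects/p/locations/us-central1/endpoints/1234567890',  # fine-tuned endpoint
    'projects/p/locations/us-central1/publishers/google/models/gemini-1.5-pro',
]
for name in names:
  assert any(re.fullmatch(p, name) for p in patterns), name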
  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Sends a request to the Gemini model.

    Args:
      llm_request: LlmRequest, the request to send to the Gemini model.
      stream: bool = False, whether to do streaming call.

    Yields:
      LlmResponse: The model response.
    """

    self._maybe_append_user_content(llm_request)
    logger.info(
        'Sending out request, model: %s, backend: %s, stream: %s',
        llm_request.model,
        self._api_backend,
        stream,
    )
    logger.info(_build_request_log(llm_request))

    if stream:
      responses = await self.api_client.aio.models.generate_content_stream(
          model=llm_request.model,
          contents=llm_request.contents,
          config=llm_request.config,
      )
      response = None
      text = ''
      # For sse, similar to bidi (see the receive method in
      # gemini_llm_connection.py), we need to mark text content as partial and,
      # after all partial contents are sent, send an accumulated event which
      # contains all the previous partial content. The only difference is that
      # bidi relies on the turn_complete flag to detect the end, while sse
      # depends on finish_reason.
      async for response in responses:
        logger.info(_build_response_log(response))
        llm_response = LlmResponse.create(response)
        if (
            llm_response.content
            and llm_response.content.parts
            and llm_response.content.parts[0].text
        ):
          text += llm_response.content.parts[0].text
          llm_response.partial = True
        elif text and (
            not llm_response.content
            or not llm_response.content.parts
            # don't yield the merged text event when receiving audio data
            or not llm_response.content.parts[0].inline_data
        ):
          yield LlmResponse(
              content=types.ModelContent(
                  parts=[types.Part.from_text(text=text)],
              ),
          )
          text = ''
        yield llm_response
      if (
          text
          and response
          and response.candidates
          and response.candidates[0].finish_reason == types.FinishReason.STOP
      ):
        yield LlmResponse(
            content=types.ModelContent(
                parts=[types.Part.from_text(text=text)],
            ),
        )

    else:
      response = await self.api_client.aio.models.generate_content(
          model=llm_request.model,
          contents=llm_request.contents,
          config=llm_request.config,
      )
      logger.info(_build_response_log(response))
      yield LlmResponse.create(response)

  @cached_property
  def api_client(self) -> Client:
    """Provides the api client.

    Returns:
      The api client.
    """
    return Client(
        http_options=types.HttpOptions(headers=self._tracking_headers)
    )

  @cached_property
  def _api_backend(self) -> str:
    return 'vertex' if self.api_client.vertexai else 'ml_dev'

  @cached_property
  def _tracking_headers(self) -> dict[str, str]:
    framework_label = f'google-adk/{version.__version__}'
    language_label = 'gl-python/' + sys.version.split()[0]
    version_header_value = f'{framework_label} {language_label}'
    tracking_headers = {
        'x-goog-api-client': version_header_value,
        'user-agent': version_header_value,
    }
    return tracking_headers

  @cached_property
  def _live_api_client(self) -> Client:
    if self._api_backend == 'vertex':
      # use default api version for vertex
      return Client(
          http_options=types.HttpOptions(headers=self._tracking_headers)
      )
    else:
      # use v1alpha for ml_dev
      api_version = 'v1alpha'
      return Client(
          http_options=types.HttpOptions(
              headers=self._tracking_headers, api_version=api_version
          )
      )

  @contextlib.asynccontextmanager
  async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
    """Connects to the Gemini model and returns an llm connection.

    Args:
      llm_request: LlmRequest, the request to send to the Gemini model.

    Yields:
      BaseLlmConnection, the connection to the Gemini model.
    """

    llm_request.live_connect_config.system_instruction = types.Content(
        role='system',
        parts=[
            types.Part.from_text(text=llm_request.config.system_instruction)
        ],
    )
    llm_request.live_connect_config.tools = llm_request.config.tools
    async with self._live_api_client.aio.live.connect(
        model=llm_request.model, config=llm_request.live_connect_config
    ) as live_session:
      yield GeminiLlmConnection(live_session)

  def _maybe_append_user_content(self, llm_request: LlmRequest):
    """Appends a user content, so that model can continue to output.

    Args:
      llm_request: LlmRequest, the request to send to the Gemini model.
    """
    # If no content is provided, append a user content to hint model response
    # using system instruction.
    if not llm_request.contents:
      llm_request.contents.append(
          types.Content(
              role='user',
              parts=[
                  types.Part(
                      text=(
                          'Handle the requests as specified in the System'
                          ' Instruction.'
                      )
                  )
              ],
          )
      )
      return

    # Insert a user content to preserve user intent and to avoid empty
    # model response.
    if llm_request.contents[-1].role != 'user':
      llm_request.contents.append(
          types.Content(
              role='user',
              parts=[
                  types.Part(
                      text=(
                          'Continue processing previous requests as instructed.'
                          ' Exit or provide a summary if no more outputs are'
                          ' needed.'
                      )
                  )
              ],
          )
      )


def _build_function_declaration_log(
    func_decl: types.FunctionDeclaration,
) -> str:
  param_str = '{}'
  if func_decl.parameters and func_decl.parameters.properties:
    param_str = str({
        k: v.model_dump(exclude_none=True)
        for k, v in func_decl.parameters.properties.items()
    })
  return_str = 'None'
  if func_decl.response:
    return_str = str(func_decl.response.model_dump(exclude_none=True))
  return f'{func_decl.name}: {param_str} -> {return_str}'


def _build_request_log(req: LlmRequest) -> str:
  function_decls: list[types.FunctionDeclaration] = cast(
      list[types.FunctionDeclaration],
      req.config.tools[0].function_declarations if req.config.tools else [],
  )
  function_logs = (
      [
          _build_function_declaration_log(func_decl)
          for func_decl in function_decls
      ]
      if function_decls
      else []
  )
  contents_logs = [
      content.model_dump_json(
          exclude_none=True,
          exclude={
              'parts': {
                  i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
              }
          },
      )
      for content in req.contents
  ]

  return f"""
LLM Request:
-----------------------------------------------------------
System Instruction:
{req.config.system_instruction}
-----------------------------------------------------------
Contents:
{_NEW_LINE.join(contents_logs)}
-----------------------------------------------------------
Functions:
{_NEW_LINE.join(function_logs)}
-----------------------------------------------------------
"""


def _build_response_log(resp: types.GenerateContentResponse) -> str:
  function_calls_text = []
  if function_calls := resp.function_calls:
    for func_call in function_calls:
      function_calls_text.append(
          f'name: {func_call.name}, args: {func_call.args}'
      )
  return f"""
LLM Response:
-----------------------------------------------------------
Text:
{resp.text}
-----------------------------------------------------------
Function calls:
{_NEW_LINE.join(function_calls_text)}
-----------------------------------------------------------
Raw response:
{resp.model_dump_json(exclude_none=True)}
-----------------------------------------------------------
"""
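A minimal end-to-end sketch of driving this class directly in the non-streaming path (it assumes google.genai credentials are already configured through the environment, e.g. an API key or Vertex project settings, which are not shown here):

import asyncio

from google.genai import types

from google.adk.models.google_llm import Gemini          # assumed import path
from google.adk.models.llm_request import LlmRequest


async def main() -> None:
  llm = Gemini(model='gemini-1.5-flash')
  request = LlmRequest(
      model=llm.model,
      contents=[
          types.Content(
              role='user',
              parts=[types.Part.from_text(text='Say hello in French.')],
          )
      ],
      config=types.GenerateContentConfig(
          system_instruction='You are a terse assistant.'
      ),
  )
  async for response in llm.generate_content_async(request, stream=False):
    if response.content and response.content.parts:
      print(response.content.parts[0].text)


if __name__ == '__main__':
  asyncio.run(main())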
673  src/google/adk/models/lite_llm.py  Normal file
@@ -0,0 +1,673 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import base64
import json
import logging
from typing import Any
from typing import AsyncGenerator
from typing import cast
from typing import Dict
from typing import Generator
from typing import Iterable
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import Union

from google.genai import types
from litellm import acompletion
from litellm import ChatCompletionAssistantMessage
from litellm import ChatCompletionDeveloperMessage
from litellm import ChatCompletionImageUrlObject
from litellm import ChatCompletionMessageToolCall
from litellm import ChatCompletionTextObject
from litellm import ChatCompletionToolMessage
from litellm import ChatCompletionUserMessage
from litellm import ChatCompletionVideoUrlObject
from litellm import completion
from litellm import CustomStreamWrapper
from litellm import Function
from litellm import Message
from litellm import ModelResponse
from litellm import OpenAIMessageContent
from pydantic import BaseModel
from pydantic import Field
from typing_extensions import override

from .base_llm import BaseLlm
from .llm_request import LlmRequest
from .llm_response import LlmResponse

logger = logging.getLogger(__name__)

_NEW_LINE = "\n"
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}


class FunctionChunk(BaseModel):
  id: Optional[str]
  name: Optional[str]
  args: Optional[str]


class TextChunk(BaseModel):
  text: str


class LiteLLMClient:
  """Provides acompletion method (for better testability)."""

  async def acompletion(
      self, model, messages, tools, **kwargs
  ) -> Union[ModelResponse, CustomStreamWrapper]:
    """Asynchronously calls acompletion.

    Args:
      model: The model name.
      messages: The messages to send to the model.
      tools: The tools to use for the model.
      **kwargs: Additional arguments to pass to acompletion.

    Returns:
      The model response as a message.
    """

    return await acompletion(
        model=model,
        messages=messages,
        tools=tools,
        **kwargs,
    )

  def completion(
      self, model, messages, tools, stream=False, **kwargs
  ) -> Union[ModelResponse, CustomStreamWrapper]:
    """Synchronously calls completion. This is used for streaming only.

    Args:
      model: The model to use.
      messages: The messages to send.
      tools: The tools to use for the model.
      stream: Whether to stream the response.
      **kwargs: Additional arguments to pass to completion.

    Returns:
      The response from the model.
    """

    return completion(
        model=model,
        messages=messages,
        tools=tools,
        stream=stream,
        **kwargs,
    )


def _safe_json_serialize(obj) -> str:
  """Convert any Python object to a JSON-serializable type or string.

  Args:
    obj: The object to serialize.

  Returns:
    The JSON-serialized object string or string.
  """

  try:
    # Try direct JSON serialization first
    return json.dumps(obj)
  except (TypeError, OverflowError):
    return str(obj)


def _content_to_message_param(
    content: types.Content,
) -> Message:
  """Converts a types.Content to a litellm Message.

  Args:
    content: The content to convert.

  Returns:
    The litellm Message.
  """

  if content.parts and content.parts[0].function_response:
    return ChatCompletionToolMessage(
        role="tool",
        tool_call_id=content.parts[0].function_response.id,
        content=_safe_json_serialize(
            content.parts[0].function_response.response
        ),
    )

  role = _to_litellm_role(content.role)

  if role == "user":
    return ChatCompletionUserMessage(
        role="user", content=_get_content(content.parts)
    )
  else:

    tool_calls = [
        ChatCompletionMessageToolCall(
            type="function",
            id=part.function_call.id,
            function=Function(
                name=part.function_call.name,
                arguments=part.function_call.args,
            ),
        )
        for part in content.parts
        if part.function_call
    ]

    return ChatCompletionAssistantMessage(
        role=role,
        content=_get_content(content.parts),
        tool_calls=tool_calls or None,
    )


def _get_content(parts: Iterable[types.Part]) -> OpenAIMessageContent | str:
  """Converts a list of parts to litellm content.

  Args:
    parts: The parts to convert.

  Returns:
    The litellm content.
  """

  content_objects = []
  for part in parts:
    if part.text:
      if len(parts) == 1:
        return part.text
      content_objects.append(
          ChatCompletionTextObject(
              type="text",
              text=part.text,
          )
      )
    elif (
        part.inline_data
        and part.inline_data.data
        and part.inline_data.mime_type
    ):
      base64_string = base64.b64encode(part.inline_data.data).decode("utf-8")
      data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"

      if part.inline_data.mime_type.startswith("image"):
        content_objects.append(
            ChatCompletionImageUrlObject(
                type="image_url",
                image_url=data_uri,
            )
        )
      elif part.inline_data.mime_type.startswith("video"):
        content_objects.append(
            ChatCompletionVideoUrlObject(
                type="video_url",
                video_url=data_uri,
            )
        )
    else:
      raise ValueError("LiteLlm(BaseLlm) does not support this content part.")

  return content_objects


def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
  """Converts a types.Content role to a litellm role.

  Args:
    role: The types.Content role.

  Returns:
    The litellm role.
  """

  if role in ["model", "assistant"]:
    return "assistant"
  return "user"


TYPE_LABELS = {
    "STRING": "string",
    "NUMBER": "number",
    "BOOLEAN": "boolean",
    "OBJECT": "object",
    "ARRAY": "array",
    "INTEGER": "integer",
}


def _schema_to_dict(schema: types.Schema) -> dict:
  """Recursively converts a types.Schema to a dictionary.

  Args:
    schema: The schema to convert.

  Returns:
    The dictionary representation of the schema.
  """

  schema_dict = schema.model_dump(exclude_none=True)
  if "type" in schema_dict:
    schema_dict["type"] = schema_dict["type"].lower()
  if "items" in schema_dict:
    if isinstance(schema_dict["items"], dict):
      schema_dict["items"] = _schema_to_dict(
          types.Schema.model_validate(schema_dict["items"])
      )
    elif isinstance(schema_dict["items"]["type"], types.Type):
      schema_dict["items"]["type"] = TYPE_LABELS[
          schema_dict["items"]["type"].value
      ]
  if "properties" in schema_dict:
    properties = {}
    for key, value in schema_dict["properties"].items():
      if isinstance(value, types.Schema):
        properties[key] = _schema_to_dict(value)
      else:
        properties[key] = value
        if "type" in properties[key]:
          properties[key]["type"] = properties[key]["type"].lower()
    schema_dict["properties"] = properties
  return schema_dict


def _function_declaration_to_tool_param(
    function_declaration: types.FunctionDeclaration,
) -> dict:
  """Converts a types.FunctionDeclaration to a openapi spec dictionary.

  Args:
    function_declaration: The function declaration to convert.

  Returns:
    The openapi spec dictionary representation of the function declaration.
  """

  assert function_declaration.name

  properties = {}
  if (
      function_declaration.parameters
      and function_declaration.parameters.properties
  ):
    for key, value in function_declaration.parameters.properties.items():
      properties[key] = _schema_to_dict(value)

  return {
      "type": "function",
      "function": {
          "name": function_declaration.name,
          "description": function_declaration.description or "",
          "parameters": {
              "type": "object",
              "properties": properties,
          },
      },
  }
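To make the schema translation concrete, here is roughly what _function_declaration_to_tool_param produces for a one-parameter function; a sketch only, since the exact dict depends on which Schema fields are set:

from google.genai import types

decl = types.FunctionDeclaration(
    name="get_weather",
    description="Returns the current weather for a city.",
    parameters=types.Schema(
        type=types.Type.OBJECT,
        properties={"city": types.Schema(type=types.Type.STRING)},
    ),
)
tool = _function_declaration_to_tool_param(decl)
# Roughly:
# {"type": "function",
#  "function": {"name": "get_weather",
#               "description": "Returns the current weather for a city.",
#               "parameters": {"type": "object",
#                              "properties": {"city": {"type": "string"}}}}}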
def _model_response_to_chunk(
    response: ModelResponse,
) -> Generator[
    Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
]:
  """Converts a litellm message to text or function chunk.

  Args:
    response: The response from the model.

  Yields:
    A tuple of text or function chunk and finish reason.
  """

  message = None
  if response.get("choices", None):
    message = response["choices"][0].get("message", None)
    finish_reason = response["choices"][0].get("finish_reason", None)
    # check streaming delta
    if message is None and response["choices"][0].get("delta", None):
      message = response["choices"][0]["delta"]

    if message.get("content", None):
      yield TextChunk(text=message.get("content")), finish_reason

    if message.get("tool_calls", None):
      for tool_call in message.get("tool_calls"):
        # aggregate tool_call
        if tool_call.type == "function":
          yield FunctionChunk(
              id=tool_call.id,
              name=tool_call.function.name,
              args=tool_call.function.arguments,
          ), finish_reason

    if finish_reason and not (
        message.get("content", None) or message.get("tool_calls", None)
    ):
      yield None, finish_reason

  if not message:
    yield None, None


def _model_response_to_generate_content_response(
    response: ModelResponse,
) -> LlmResponse:
  """Converts a litellm response to LlmResponse.

  Args:
    response: The model response.

  Returns:
    The LlmResponse.
  """

  message = None
  if response.get("choices", None):
    message = response["choices"][0].get("message", None)

  if not message:
    raise ValueError("No message in response")
  return _message_to_generate_content_response(message)


def _message_to_generate_content_response(
    message: Message, is_partial: bool = False
) -> LlmResponse:
  """Converts a litellm message to LlmResponse.

  Args:
    message: The message to convert.
    is_partial: Whether the message is partial.

  Returns:
    The LlmResponse.
  """

  parts = []
  if message.get("content", None):
    parts.append(types.Part.from_text(text=message.get("content")))

  if message.get("tool_calls", None):
    for tool_call in message.get("tool_calls"):
      if tool_call.type == "function":
        part = types.Part.from_function_call(
            name=tool_call.function.name,
            args=json.loads(tool_call.function.arguments or "{}"),
        )
        part.function_call.id = tool_call.id
        parts.append(part)

  return LlmResponse(
      content=types.Content(role="model", parts=parts), partial=is_partial
  )


def _get_completion_inputs(
    llm_request: LlmRequest,
) -> tuple[Iterable[Message], Iterable[dict]]:
  """Converts an LlmRequest to litellm inputs.

  Args:
    llm_request: The LlmRequest to convert.

  Returns:
    The litellm inputs (message list and tool dictionary).
  """
  messages = [
      _content_to_message_param(content)
      for content in llm_request.contents or []
  ]

  if llm_request.config.system_instruction:
    messages.insert(
        0,
        ChatCompletionDeveloperMessage(
            role="developer",
            content=llm_request.config.system_instruction,
        ),
    )

  tools = None
  if (
      llm_request.config
      and llm_request.config.tools
      and llm_request.config.tools[0].function_declarations
  ):
    tools = [
        _function_declaration_to_tool_param(tool)
        for tool in llm_request.config.tools[0].function_declarations
    ]
  return messages, tools


def _build_function_declaration_log(
    func_decl: types.FunctionDeclaration,
) -> str:
  """Builds a function declaration log.

  Args:
    func_decl: The function declaration to convert.

  Returns:
    The function declaration log.
  """

  param_str = "{}"
  if func_decl.parameters and func_decl.parameters.properties:
    param_str = str({
        k: v.model_dump(exclude_none=True)
        for k, v in func_decl.parameters.properties.items()
    })
  return_str = "None"
  if func_decl.response:
    return_str = str(func_decl.response.model_dump(exclude_none=True))
  return f"{func_decl.name}: {param_str} -> {return_str}"


def _build_request_log(req: LlmRequest) -> str:
  """Builds a request log.

  Args:
    req: The request to convert.

  Returns:
    The request log.
  """

  function_decls: list[types.FunctionDeclaration] = cast(
      list[types.FunctionDeclaration],
      req.config.tools[0].function_declarations if req.config.tools else [],
  )
  function_logs = (
      [
          _build_function_declaration_log(func_decl)
          for func_decl in function_decls
      ]
      if function_decls
      else []
  )
  contents_logs = [
      content.model_dump_json(
          exclude_none=True,
          exclude={
              "parts": {
                  i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
              }
          },
      )
      for content in req.contents
  ]

  return f"""
LLM Request:
-----------------------------------------------------------
System Instruction:
{req.config.system_instruction}
-----------------------------------------------------------
Contents:
{_NEW_LINE.join(contents_logs)}
-----------------------------------------------------------
Functions:
{_NEW_LINE.join(function_logs)}
-----------------------------------------------------------
"""


class LiteLlm(BaseLlm):
  """Wrapper around litellm.

  This wrapper can be used with any of the models supported by litellm. The
  environment variable(s) needed for authenticating with the model endpoint must
  be set prior to instantiating this class.

  Example usage:
  ```
  os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
  os.environ["VERTEXAI_LOCATION"] = "your-gcp-location"

  agent = Agent(
      model=LiteLlm(model="vertex_ai/claude-3-7-sonnet@20250219"),
      ...
  )
  ```

  Attributes:
    model: The name of the LiteLlm model.
    llm_client: The LLM client to use for the model.
    model_config: The model config.
  """

  llm_client: LiteLLMClient = Field(default_factory=LiteLLMClient)
  """The LLM client to use for the model."""

  _additional_args: Dict[str, Any] = None

  def __init__(self, model: str, **kwargs):
    """Initializes the LiteLlm class.

    Args:
      model: The name of the LiteLlm model.
      **kwargs: Additional arguments to pass to the litellm completion api.
    """
    super().__init__(model=model, **kwargs)
    self._additional_args = kwargs
    # preventing generation call with llm_client
    # and overriding messages, tools and stream which are managed internally
    self._additional_args.pop("llm_client", None)
    self._additional_args.pop("messages", None)
    self._additional_args.pop("tools", None)
    # public api called from runner determines to stream or not
    self._additional_args.pop("stream", None)

  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Generates content asynchronously.

    Args:
      llm_request: LlmRequest, the request to send to the LiteLlm model.
      stream: bool = False, whether to do streaming call.

    Yields:
      LlmResponse: The model response.
    """

    logger.info(_build_request_log(llm_request))

    messages, tools = _get_completion_inputs(llm_request)

    completion_args = {
        "model": self.model,
        "messages": messages,
        "tools": tools,
    }
    completion_args.update(self._additional_args)

    if stream:
      text = ""
      function_name = ""
      function_args = ""
      function_id = None
      completion_args["stream"] = True
      for part in self.llm_client.completion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
            if chunk.name:
              function_name += chunk.name
            if chunk.args:
              function_args += chunk.args
            function_id = chunk.id or function_id
          elif isinstance(chunk, TextChunk):
            text += chunk.text
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(
                    role="assistant",
                    content=chunk.text,
                ),
                is_partial=True,
            )
          if finish_reason == "tool_calls" and function_id:
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(
                    role="assistant",
                    content="",
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            type="function",
                            id=function_id,
                            function=Function(
                                name=function_name,
                                arguments=function_args,
                            ),
                        )
                    ],
                )
            )
            function_name = ""
            function_args = ""
            function_id = None
          elif finish_reason == "stop" and text:
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(role="assistant", content=text)
            )
            text = ""

    else:
      response = await self.llm_client.acompletion(**completion_args)
      yield _model_response_to_generate_content_response(response)

  @staticmethod
  @override
  def supported_models() -> list[str]:
    """Provides the list of supported models.

    LiteLlm supports all models supported by litellm. We do not keep track of
    these models here. So we return an empty list.

    Returns:
      A list of supported models.
    """

    return []
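Because __init__ stores extra keyword arguments and forwards them to every completion call, per-model settings can be passed straight through to litellm. A hedged usage sketch (the model string and the environment variable are litellm conventions, and the key value is a placeholder):

import os

from google.adk.models.lite_llm import LiteLlm   # assumed import path

os.environ["OPENAI_API_KEY"] = "sk-..."          # whatever litellm needs for the chosen model

llm = LiteLlm(
    model="openai/gpt-4o-mini",   # any litellm-supported model string
    temperature=0.2,              # forwarded to litellm via _additional_args
    max_tokens=512,
)
# llm.generate_content_async(llm_request) then builds messages/tools from the
# LlmRequest and calls litellm's acompletion (or completion when stream=True).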
98  src/google/adk/models/llm_request.py  Normal file
@@ -0,0 +1,98 @@
|
||||
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional

from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from ..tools.base_tool import BaseTool


class LlmRequest(BaseModel):
  """LLM request class that allows passing in tools, output schema and system

  instructions to the model.

  Attributes:
    model: The model name.
    contents: The contents to send to the model.
    config: Additional config for the generate content request.
    tools_dict: The tools dictionary.
  """

  model_config = ConfigDict(arbitrary_types_allowed=True)
  """The model config."""

  model: Optional[str] = None
  """The model name."""

  contents: list[types.Content] = Field(default_factory=list)
  """The contents to send to the model."""

  config: Optional[types.GenerateContentConfig] = None
  live_connect_config: types.LiveConnectConfig = types.LiveConnectConfig()
  """Additional config for the generate content request.

  tools in generate_content_config should not be set.
  """
  tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
  """The tools dictionary."""

  def append_instructions(self, instructions: list[str]) -> None:
    """Appends instructions to the system instruction.

    Args:
      instructions: The instructions to append.
    """

    if self.config.system_instruction:
      self.config.system_instruction += '\n\n' + '\n\n'.join(instructions)
    else:
      self.config.system_instruction = '\n\n'.join(instructions)

  def append_tools(self, tools: list[BaseTool]) -> None:
    """Appends tools to the request.

    Args:
      tools: The tools to append.
    """

    if not tools:
      return
    declarations = []
    for tool in tools:
      if isinstance(tool, BaseTool):
        declaration = tool._get_declaration()
      else:
        declaration = tool.get_declaration()
      if declaration:
        declarations.append(declaration)
        self.tools_dict[tool.name] = tool
    if declarations:
      self.config.tools.append(types.Tool(function_declarations=declarations))

  def set_output_schema(self, base_model: type[BaseModel]) -> None:
    """Sets the output schema for the request.

    Args:
      base_model: The pydantic base model to set the output schema to.
    """

    self.config.response_schema = base_model
    self.config.response_mime_type = 'application/json'
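Usage note (not part of the diff): a minimal sketch of driving LlmRequest by hand. It assumes config is populated before the helpers run, since append_instructions, append_tools and set_output_schema all write through self.config; the model name and schema below are illustrative only.

from google.genai import types
from pydantic import BaseModel

from google.adk.models.llm_request import LlmRequest


class CityAnswer(BaseModel):  # illustrative output schema
  city: str


request = LlmRequest(
    model='gemini-2.0-flash',  # illustrative model name
    config=types.GenerateContentConfig(tools=[]),
)
request.append_instructions(['Answer with the capital city only.'])
request.set_output_schema(CityAnswer)  # sets response_schema and JSON mime type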
111
src/google/adk/models/llm_response.py
Normal file
@@ -0,0 +1,111 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional

from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict


class LlmResponse(BaseModel):
  """LLM response class that provides the first candidate response from the

  model if available. Otherwise, returns error code and message.

  Attributes:
    content: The content of the response.
    grounding_metadata: The grounding metadata of the response.
    partial: Indicates whether the text content is part of an unfinished text
      stream. Only used for streaming mode and when the content is plain text.
    turn_complete: Indicates whether the response from the model is complete.
      Only used for streaming mode.
    error_code: Error code if the response is an error. Code varies by model.
    error_message: Error message if the response is an error.
    interrupted: Flag indicating that LLM was interrupted when generating the
      content. Usually it's due to user interruption during a bidi streaming.
  """

  model_config = ConfigDict(extra='forbid')
  """The model config."""

  content: Optional[types.Content] = None
  """The content of the response."""

  grounding_metadata: Optional[types.GroundingMetadata] = None
  """The grounding metadata of the response."""

  partial: Optional[bool] = None
  """Indicates whether the text content is part of an unfinished text stream.

  Only used for streaming mode and when the content is plain text.
  """

  turn_complete: Optional[bool] = None
  """Indicates whether the response from the model is complete.

  Only used for streaming mode.
  """

  error_code: Optional[str] = None
  """Error code if the response is an error. Code varies by model."""

  error_message: Optional[str] = None
  """Error message if the response is an error."""

  interrupted: Optional[bool] = None
  """Flag indicating that LLM was interrupted when generating the content.
  Usually it's due to user interruption during a bidi streaming.
  """

  @staticmethod
  def create(
      generate_content_response: types.GenerateContentResponse,
  ) -> 'LlmResponse':
    """Creates an LlmResponse from a GenerateContentResponse.

    Args:
      generate_content_response: The GenerateContentResponse to create the
        LlmResponse from.

    Returns:
      The LlmResponse.
    """

    if generate_content_response.candidates:
      candidate = generate_content_response.candidates[0]
      if candidate.content and candidate.content.parts:
        return LlmResponse(
            content=candidate.content,
            grounding_metadata=candidate.grounding_metadata,
        )
      else:
        return LlmResponse(
            error_code=candidate.finish_reason,
            error_message=candidate.finish_message,
        )
    else:
      if generate_content_response.prompt_feedback:
        prompt_feedback = generate_content_response.prompt_feedback
        return LlmResponse(
            error_code=prompt_feedback.block_reason,
            error_message=prompt_feedback.block_reason_message,
        )
      else:
        return LlmResponse(
            error_code='UNKNOWN_ERROR',
            error_message='Unknown error.',
        )
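Usage note (not part of the diff): a minimal sketch of LlmResponse.create with a hand-built GenerateContentResponse; the candidate payload is fabricated purely for illustration.

from google.genai import types

from google.adk.models.llm_response import LlmResponse

raw = types.GenerateContentResponse(
    candidates=[
        types.Candidate(
            content=types.Content(role='model', parts=[types.Part(text='Hello!')])
        )
    ]
)
response = LlmResponse.create(raw)
assert response.content.parts[0].text == 'Hello!'
# With no candidates and no prompt_feedback, create() falls back to
# error_code='UNKNOWN_ERROR'.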
102
src/google/adk/models/registry.py
Normal file
@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The registry class for model."""

from __future__ import annotations

from functools import lru_cache
import logging
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
  from .base_llm import BaseLlm

logger = logging.getLogger(__name__)


_llm_registry_dict: dict[str, type[BaseLlm]] = {}
"""Registry for LLMs.

Key is the regex that matches the model name.
Value is the class that implements the model.
"""


class LLMRegistry:
  """Registry for LLMs."""

  @staticmethod
  def new_llm(model: str) -> BaseLlm:
    """Creates a new LLM instance.

    Args:
      model: The model name.

    Returns:
      The LLM instance.
    """

    return LLMRegistry.resolve(model)(model=model)

  @staticmethod
  def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
    """Registers a new LLM class.

    Args:
      model_name_regex: The regex that matches the model name.
      llm_cls: The class that implements the model.
    """

    if model_name_regex in _llm_registry_dict:
      logger.info(
          'Updating LLM class for %s from %s to %s',
          model_name_regex,
          _llm_registry_dict[model_name_regex],
          llm_cls,
      )

    _llm_registry_dict[model_name_regex] = llm_cls

  @staticmethod
  def register(llm_cls: type[BaseLlm]):
    """Registers a new LLM class.

    Args:
      llm_cls: The class that implements the model.
    """

    for regex in llm_cls.supported_models():
      LLMRegistry._register(regex, llm_cls)

  @staticmethod
  @lru_cache(maxsize=32)
  def resolve(model: str) -> type[BaseLlm]:
    """Resolves the model to a BaseLlm subclass.

    Args:
      model: The model name.

    Returns:
      The BaseLlm subclass.

    Raises:
      ValueError: If the model is not found.
    """

    for regex, llm_class in _llm_registry_dict.items():
      if re.compile(regex).fullmatch(model):
        return llm_class

    raise ValueError(f'Model {model} not found.')
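Usage note (not part of the diff): a minimal sketch of registering and resolving a model class; EchoLlm and the echo-* model names are hypothetical.

from google.adk.models.base_llm import BaseLlm
from google.adk.models.registry import LLMRegistry


class EchoLlm(BaseLlm):  # hypothetical BaseLlm subclass

  @staticmethod
  def supported_models() -> list[str]:
    return [r'echo-.*']  # regexes are matched with re.fullmatch


LLMRegistry.register(EchoLlm)  # registers every regex from supported_models()
assert LLMRegistry.resolve('echo-mini') is EchoLlm
# LLMRegistry.new_llm('echo-mini') would then return EchoLlm(model='echo-mini').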