structure saas with tools
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines the interface to support a model."""
|
||||
|
||||
from .base_llm import BaseLlm
|
||||
from .google_llm import Gemini
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
from .registry import LLMRegistry
|
||||
|
||||
__all__ = [
|
||||
'BaseLlm',
|
||||
'Gemini',
|
||||
'LLMRegistry',
|
||||
]
|
||||
|
||||
|
||||
for regex in Gemini.supported_models():
|
||||
LLMRegistry.register(Gemini)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,243 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Anthropic integration for Claude models."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import cached_property
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncGenerator
|
||||
from typing import Generator
|
||||
from typing import Iterable
|
||||
from typing import Literal
|
||||
from typing import Optional, Union
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from anthropic import AnthropicVertex
|
||||
from anthropic import NOT_GIVEN
|
||||
from anthropic import types as anthropic_types
|
||||
from google.genai import types
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_llm import BaseLlm
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llm_request import LlmRequest
|
||||
|
||||
__all__ = ["Claude"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_TOKEN = 1024
|
||||
|
||||
|
||||
class ClaudeRequest(BaseModel):
|
||||
system_instruction: str
|
||||
messages: Iterable[anthropic_types.MessageParam]
|
||||
tools: list[anthropic_types.ToolParam]
|
||||
|
||||
|
||||
def to_claude_role(role: Optional[str]) -> Literal["user", "assistant"]:
|
||||
if role in ["model", "assistant"]:
|
||||
return "assistant"
|
||||
return "user"
|
||||
|
||||
|
||||
def to_google_genai_finish_reason(
|
||||
anthropic_stop_reason: Optional[str],
|
||||
) -> types.FinishReason:
|
||||
if anthropic_stop_reason in ["end_turn", "stop_sequence", "tool_use"]:
|
||||
return "STOP"
|
||||
if anthropic_stop_reason == "max_tokens":
|
||||
return "MAX_TOKENS"
|
||||
return "FINISH_REASON_UNSPECIFIED"
|
||||
|
||||
|
||||
def part_to_message_block(
|
||||
part: types.Part,
|
||||
) -> Union[
|
||||
anthropic_types.TextBlockParam,
|
||||
anthropic_types.ImageBlockParam,
|
||||
anthropic_types.ToolUseBlockParam,
|
||||
anthropic_types.ToolResultBlockParam,
|
||||
]:
|
||||
if part.text:
|
||||
return anthropic_types.TextBlockParam(text=part.text, type="text")
|
||||
if part.function_call:
|
||||
assert part.function_call.name
|
||||
|
||||
return anthropic_types.ToolUseBlockParam(
|
||||
id=part.function_call.id or "",
|
||||
name=part.function_call.name,
|
||||
input=part.function_call.args,
|
||||
type="tool_use",
|
||||
)
|
||||
if part.function_response:
|
||||
content = ""
|
||||
if (
|
||||
"result" in part.function_response.response
|
||||
and part.function_response.response["result"]
|
||||
):
|
||||
# Transformation is required because the content is a list of dict.
|
||||
# ToolResultBlockParam content doesn't support list of dict. Converting
|
||||
# to str to prevent anthropic.BadRequestError from being thrown.
|
||||
content = str(part.function_response.response["result"])
|
||||
return anthropic_types.ToolResultBlockParam(
|
||||
tool_use_id=part.function_response.id or "",
|
||||
type="tool_result",
|
||||
content=content,
|
||||
is_error=False,
|
||||
)
|
||||
raise NotImplementedError("Not supported yet.")
|
||||
|
||||
|
||||
def content_to_message_param(
|
||||
content: types.Content,
|
||||
) -> anthropic_types.MessageParam:
|
||||
return {
|
||||
"role": to_claude_role(content.role),
|
||||
"content": [part_to_message_block(part) for part in content.parts or []],
|
||||
}
|
||||
|
||||
|
||||
def content_block_to_part(
|
||||
content_block: anthropic_types.ContentBlock,
|
||||
) -> types.Part:
|
||||
if isinstance(content_block, anthropic_types.TextBlock):
|
||||
return types.Part.from_text(text=content_block.text)
|
||||
if isinstance(content_block, anthropic_types.ToolUseBlock):
|
||||
assert isinstance(content_block.input, dict)
|
||||
part = types.Part.from_function_call(
|
||||
name=content_block.name, args=content_block.input
|
||||
)
|
||||
part.function_call.id = content_block.id
|
||||
return part
|
||||
raise NotImplementedError("Not supported yet.")
|
||||
|
||||
|
||||
def message_to_generate_content_response(
|
||||
message: anthropic_types.Message,
|
||||
) -> LlmResponse:
|
||||
|
||||
return LlmResponse(
|
||||
content=types.Content(
|
||||
role="model",
|
||||
parts=[content_block_to_part(cb) for cb in message.content],
|
||||
),
|
||||
# TODO: Deal with these later.
|
||||
# finish_reason=to_google_genai_finish_reason(message.stop_reason),
|
||||
# usage_metadata=types.GenerateContentResponseUsageMetadata(
|
||||
# prompt_token_count=message.usage.input_tokens,
|
||||
# candidates_token_count=message.usage.output_tokens,
|
||||
# total_token_count=(
|
||||
# message.usage.input_tokens + message.usage.output_tokens
|
||||
# ),
|
||||
# ),
|
||||
)
|
||||
|
||||
|
||||
def function_declaration_to_tool_param(
|
||||
function_declaration: types.FunctionDeclaration,
|
||||
) -> anthropic_types.ToolParam:
|
||||
assert function_declaration.name
|
||||
|
||||
properties = {}
|
||||
if (
|
||||
function_declaration.parameters
|
||||
and function_declaration.parameters.properties
|
||||
):
|
||||
for key, value in function_declaration.parameters.properties.items():
|
||||
value_dict = value.model_dump(exclude_none=True)
|
||||
if "type" in value_dict:
|
||||
value_dict["type"] = value_dict["type"].lower()
|
||||
properties[key] = value_dict
|
||||
|
||||
return anthropic_types.ToolParam(
|
||||
name=function_declaration.name,
|
||||
description=function_declaration.description or "",
|
||||
input_schema={
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class Claude(BaseLlm):
|
||||
model: str = "claude-3-5-sonnet-v2@20241022"
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def supported_models() -> list[str]:
|
||||
return [r"claude-3-.*"]
|
||||
|
||||
@override
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
messages = [
|
||||
content_to_message_param(content)
|
||||
for content in llm_request.contents or []
|
||||
]
|
||||
tools = NOT_GIVEN
|
||||
if (
|
||||
llm_request.config
|
||||
and llm_request.config.tools
|
||||
and llm_request.config.tools[0].function_declarations
|
||||
):
|
||||
tools = [
|
||||
function_declaration_to_tool_param(tool)
|
||||
for tool in llm_request.config.tools[0].function_declarations
|
||||
]
|
||||
tool_choice = (
|
||||
anthropic_types.ToolChoiceAutoParam(
|
||||
type="auto",
|
||||
# TODO: allow parallel tool use.
|
||||
disable_parallel_tool_use=True,
|
||||
)
|
||||
if llm_request.tools_dict
|
||||
else NOT_GIVEN
|
||||
)
|
||||
message = self._anthropic_client.messages.create(
|
||||
model=llm_request.model,
|
||||
system=llm_request.config.system_instruction,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=MAX_TOKEN,
|
||||
)
|
||||
logger.info(
|
||||
"Claude response: %s",
|
||||
message.model_dump_json(indent=2, exclude_none=True),
|
||||
)
|
||||
yield message_to_generate_content_response(message)
|
||||
|
||||
@cached_property
|
||||
def _anthropic_client(self) -> AnthropicVertex:
|
||||
if (
|
||||
"GOOGLE_CLOUD_PROJECT" not in os.environ
|
||||
or "GOOGLE_CLOUD_LOCATION" not in os.environ
|
||||
):
|
||||
raise ValueError(
|
||||
"GOOGLE_CLOUD_PROJECT and GOOGLE_CLOUD_LOCATION must be set for using"
|
||||
" Anthropic on Vertex."
|
||||
)
|
||||
|
||||
return AnthropicVertex(
|
||||
project_id=os.environ["GOOGLE_CLOUD_PROJECT"],
|
||||
region=os.environ["GOOGLE_CLOUD_LOCATION"],
|
||||
)
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
|
||||
class BaseLlm(BaseModel):
|
||||
"""The BaseLLM class.
|
||||
|
||||
Attributes:
|
||||
model: The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001.
|
||||
model_config: The model config
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
# This allows us to use arbitrary types in the model. E.g. PIL.Image.
|
||||
arbitrary_types_allowed=True,
|
||||
)
|
||||
"""The model config."""
|
||||
|
||||
model: str
|
||||
"""The name of the LLM, e.g. gemini-1.5-flash or gemini-1.5-flash-001."""
|
||||
|
||||
@classmethod
|
||||
def supported_models(cls) -> list[str]:
|
||||
"""Returns a list of supported models in regex for LlmRegistry."""
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Generates one content from the given contents and tools.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the LLM.
|
||||
stream: bool = False, whether to do streaming call.
|
||||
|
||||
Yields:
|
||||
a generator of types.Content.
|
||||
|
||||
For non-streaming call, it will only yield one Content.
|
||||
|
||||
For streaming call, it may yield more than one content, but all yielded
|
||||
contents should be treated as one content by merging the
|
||||
parts list.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'Async generation is not supported for {self.model}.'
|
||||
)
|
||||
yield # AsyncGenerator requires a yield statement in function body.
|
||||
|
||||
def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
|
||||
"""Creates a live connection to the LLM.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the LLM.
|
||||
|
||||
Returns:
|
||||
BaseLlmConnection, the connection to the LLM.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f'Live connection is not supported for {self.model}.'
|
||||
)
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator
|
||||
from google.genai import types
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
|
||||
class BaseLlmConnection:
|
||||
"""The base class for a live model connection."""
|
||||
|
||||
@abstractmethod
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
"""Sends the conversation history to the model.
|
||||
|
||||
You call this method right after setting up the model connection.
|
||||
The model will respond if the last content is from user, otherwise it will
|
||||
wait for new user input before responding.
|
||||
|
||||
Args:
|
||||
history: The conversation history to send to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_content(self, content: types.Content):
|
||||
"""Sends a user content to the model.
|
||||
|
||||
The model will respond immediately upon receiving the content.
|
||||
If you send function responses, all parts in the content should be function
|
||||
responses.
|
||||
|
||||
Args:
|
||||
content: The content to send to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def send_realtime(self, blob: types.Blob):
|
||||
"""Sends a chunk of audio or a frame of video to the model in realtime.
|
||||
|
||||
The model may not respond immediately upon receiving the blob. It will do
|
||||
voice activity detection and decide when to respond.
|
||||
|
||||
Args:
|
||||
blob: The blob to send to the model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Receives the model response using the llm server connection.
|
||||
|
||||
Args: None.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close(self):
|
||||
"""Closes the llm server connection."""
|
||||
pass
|
||||
@@ -0,0 +1,200 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from google.genai import live
|
||||
from google.genai import types
|
||||
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GeminiLlmConnection(BaseLlmConnection):
|
||||
"""The Gemini model connection."""
|
||||
|
||||
def __init__(self, gemini_session: live.AsyncSession):
|
||||
self._gemini_session = gemini_session
|
||||
|
||||
async def send_history(self, history: list[types.Content]):
|
||||
"""Sends the conversation history to the gemini model.
|
||||
|
||||
You call this method right after setting up the model connection.
|
||||
The model will respond if the last content is from user, otherwise it will
|
||||
wait for new user input before responding.
|
||||
|
||||
Args:
|
||||
history: The conversation history to send to the model.
|
||||
"""
|
||||
|
||||
# TODO: Remove this filter and translate unary contents to streaming
|
||||
# contents properly.
|
||||
|
||||
# We ignore any audio from user during the agent transfer phase
|
||||
contents = [
|
||||
content
|
||||
for content in history
|
||||
if content.parts and content.parts[0].text
|
||||
]
|
||||
|
||||
if contents:
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientContent(
|
||||
turns=contents,
|
||||
turn_complete=contents[-1].role == 'user',
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.info('no content is sent')
|
||||
|
||||
async def send_content(self, content: types.Content):
|
||||
"""Sends a user content to the gemini model.
|
||||
|
||||
The model will respond immediately upon receiving the content.
|
||||
If you send function responses, all parts in the content should be function
|
||||
responses.
|
||||
|
||||
Args:
|
||||
content: The content to send to the model.
|
||||
"""
|
||||
|
||||
assert content.parts
|
||||
if content.parts[0].function_response:
|
||||
# All parts have to be function responses.
|
||||
function_responses = [part.function_response for part in content.parts]
|
||||
logger.debug('Sending LLM function response: %s', function_responses)
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientToolResponse(
|
||||
function_responses=function_responses
|
||||
),
|
||||
)
|
||||
else:
|
||||
logger.debug('Sending LLM new content %s', content)
|
||||
await self._gemini_session.send(
|
||||
input=types.LiveClientContent(
|
||||
turns=[content],
|
||||
turn_complete=True,
|
||||
)
|
||||
)
|
||||
|
||||
async def send_realtime(self, blob: types.Blob):
|
||||
"""Sends a chunk of audio or a frame of video to the model in realtime.
|
||||
|
||||
Args:
|
||||
blob: The blob to send to the model.
|
||||
"""
|
||||
|
||||
input_blob = blob.model_dump()
|
||||
logger.debug('Sending LLM Blob: %s', input_blob)
|
||||
await self._gemini_session.send(input=input_blob)
|
||||
|
||||
def __build_full_text_response(self, text: str):
|
||||
"""Builds a full text response.
|
||||
|
||||
The text should not partial and the returned LlmResponse is not be
|
||||
partial.
|
||||
|
||||
Args:
|
||||
text: The text to be included in the response.
|
||||
|
||||
Returns:
|
||||
An LlmResponse containing the full text.
|
||||
"""
|
||||
return LlmResponse(
|
||||
content=types.Content(
|
||||
role='model',
|
||||
parts=[types.Part.from_text(text=text)],
|
||||
),
|
||||
)
|
||||
|
||||
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Receives the model response using the llm server connection.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
|
||||
text = ''
|
||||
async for message in self._gemini_session.receive():
|
||||
logger.debug('Got LLM Live message: %s', message)
|
||||
if message.server_content:
|
||||
content = message.server_content.model_turn
|
||||
if content and content.parts:
|
||||
llm_response = LlmResponse(
|
||||
content=content, interrupted=message.server_content.interrupted
|
||||
)
|
||||
if content.parts[0].text:
|
||||
text += content.parts[0].text
|
||||
llm_response.partial = True
|
||||
# don't yield the merged text event when receiving audio data
|
||||
elif text and not content.parts[0].inline_data:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield llm_response
|
||||
|
||||
if (
|
||||
message.server_content.output_transcription
|
||||
and message.server_content.output_transcription.text
|
||||
):
|
||||
# TODO: Right now, we just support output_transcription without
|
||||
# changing interface and data protocol. Later, we can consider to
|
||||
# support output_transcription as a separate field in LlmResponse.
|
||||
|
||||
# Transcription is always considered as partial event
|
||||
# We rely on other control signals to determine when to yield the
|
||||
# full text response(turn_complete, interrupted, or tool_call).
|
||||
text += message.server_content.output_transcription.text
|
||||
parts = [
|
||||
types.Part.from_text(
|
||||
text=message.server_content.output_transcription.text
|
||||
)
|
||||
]
|
||||
llm_response = LlmResponse(
|
||||
content=types.Content(role='model', parts=parts), partial=True
|
||||
)
|
||||
yield llm_response
|
||||
|
||||
if message.server_content.turn_complete:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(
|
||||
turn_complete=True, interrupted=message.server_content.interrupted
|
||||
)
|
||||
break
|
||||
# in case of empty content or parts, we sill surface it
|
||||
# in case it's an interrupted message, we merge the previous partial
|
||||
# text. Other we don't merge. because content can be none when model
|
||||
# safety threshold is triggered
|
||||
if message.server_content.interrupted and text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
yield LlmResponse(interrupted=message.server_content.interrupted)
|
||||
if message.tool_call:
|
||||
if text:
|
||||
yield self.__build_full_text_response(text)
|
||||
text = ''
|
||||
parts = [
|
||||
types.Part(function_call=function_call)
|
||||
for function_call in message.tool_call.function_calls
|
||||
]
|
||||
yield LlmResponse(content=types.Content(role='model', parts=parts))
|
||||
|
||||
async def close(self):
|
||||
"""Closes the llm server connection."""
|
||||
|
||||
await self._gemini_session.close()
|
||||
@@ -0,0 +1,330 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from functools import cached_property
|
||||
import logging
|
||||
import sys
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from google.genai import Client
|
||||
from google.genai import types
|
||||
from typing_extensions import override
|
||||
|
||||
from .. import version
|
||||
from .base_llm import BaseLlm
|
||||
from .base_llm_connection import BaseLlmConnection
|
||||
from .gemini_llm_connection import GeminiLlmConnection
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .llm_request import LlmRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NEW_LINE = '\n'
|
||||
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}
|
||||
|
||||
|
||||
class Gemini(BaseLlm):
|
||||
"""Integration for Gemini models.
|
||||
|
||||
Attributes:
|
||||
model: The name of the Gemini model.
|
||||
"""
|
||||
|
||||
model: str = 'gemini-1.5-flash'
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def supported_models() -> list[str]:
|
||||
"""Provides the list of supported models.
|
||||
|
||||
Returns:
|
||||
A list of supported models.
|
||||
"""
|
||||
|
||||
return [
|
||||
r'gemini-.*',
|
||||
# fine-tuned vertex endpoint pattern
|
||||
r'projects\/.+\/locations\/.+\/endpoints\/.+',
|
||||
# vertex gemini long name
|
||||
r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
|
||||
]
|
||||
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Sends a request to the Gemini model.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the Gemini model.
|
||||
stream: bool = False, whether to do streaming call.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
|
||||
self._maybe_append_user_content(llm_request)
|
||||
logger.info(
|
||||
'Sending out request, model: %s, backend: %s, stream: %s',
|
||||
llm_request.model,
|
||||
self._api_backend,
|
||||
stream,
|
||||
)
|
||||
logger.info(_build_request_log(llm_request))
|
||||
|
||||
if stream:
|
||||
responses = await self.api_client.aio.models.generate_content_stream(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
response = None
|
||||
text = ''
|
||||
# for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
|
||||
# we need to mark those text content as partial and after all partial
|
||||
# contents are sent, we send an accumulated event which contains all the
|
||||
# previous partial content. The only difference is bidi rely on
|
||||
# complete_turn flag to detect end while sse depends on finish_reason.
|
||||
async for response in responses:
|
||||
logger.info(_build_response_log(response))
|
||||
llm_response = LlmResponse.create(response)
|
||||
if (
|
||||
llm_response.content
|
||||
and llm_response.content.parts
|
||||
and llm_response.content.parts[0].text
|
||||
):
|
||||
text += llm_response.content.parts[0].text
|
||||
llm_response.partial = True
|
||||
elif text and (
|
||||
not llm_response.content
|
||||
or not llm_response.content.parts
|
||||
# don't yield the merged text event when receiving audio data
|
||||
or not llm_response.content.parts[0].inline_data
|
||||
):
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(
|
||||
parts=[types.Part.from_text(text=text)],
|
||||
),
|
||||
)
|
||||
text = ''
|
||||
yield llm_response
|
||||
if (
|
||||
text
|
||||
and response
|
||||
and response.candidates
|
||||
and response.candidates[0].finish_reason == types.FinishReason.STOP
|
||||
):
|
||||
yield LlmResponse(
|
||||
content=types.ModelContent(
|
||||
parts=[types.Part.from_text(text=text)],
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
response = await self.api_client.aio.models.generate_content(
|
||||
model=llm_request.model,
|
||||
contents=llm_request.contents,
|
||||
config=llm_request.config,
|
||||
)
|
||||
logger.info(_build_response_log(response))
|
||||
yield LlmResponse.create(response)
|
||||
|
||||
@cached_property
|
||||
def api_client(self) -> Client:
|
||||
"""Provides the api client.
|
||||
|
||||
Returns:
|
||||
The api client.
|
||||
"""
|
||||
return Client(
|
||||
http_options=types.HttpOptions(headers=self._tracking_headers)
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _api_backend(self) -> str:
|
||||
return 'vertex' if self.api_client.vertexai else 'ml_dev'
|
||||
|
||||
@cached_property
|
||||
def _tracking_headers(self) -> dict[str, str]:
|
||||
framework_label = f'google-adk/{version.__version__}'
|
||||
language_label = 'gl-python/' + sys.version.split()[0]
|
||||
version_header_value = f'{framework_label} {language_label}'
|
||||
tracking_headers = {
|
||||
'x-goog-api-client': version_header_value,
|
||||
'user-agent': version_header_value,
|
||||
}
|
||||
return tracking_headers
|
||||
|
||||
@cached_property
|
||||
def _live_api_client(self) -> Client:
|
||||
if self._api_backend == 'vertex':
|
||||
# use default api version for vertex
|
||||
return Client(
|
||||
http_options=types.HttpOptions(headers=self._tracking_headers)
|
||||
)
|
||||
else:
|
||||
# use v1alpha for ml_dev
|
||||
api_version = 'v1alpha'
|
||||
return Client(
|
||||
http_options=types.HttpOptions(
|
||||
headers=self._tracking_headers, api_version=api_version
|
||||
)
|
||||
)
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
|
||||
"""Connects to the Gemini model and returns an llm connection.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the Gemini model.
|
||||
|
||||
Yields:
|
||||
BaseLlmConnection, the connection to the Gemini model.
|
||||
"""
|
||||
|
||||
llm_request.live_connect_config.system_instruction = types.Content(
|
||||
role='system',
|
||||
parts=[
|
||||
types.Part.from_text(text=llm_request.config.system_instruction)
|
||||
],
|
||||
)
|
||||
llm_request.live_connect_config.tools = llm_request.config.tools
|
||||
async with self._live_api_client.aio.live.connect(
|
||||
model=llm_request.model, config=llm_request.live_connect_config
|
||||
) as live_session:
|
||||
yield GeminiLlmConnection(live_session)
|
||||
|
||||
def _maybe_append_user_content(self, llm_request: LlmRequest):
|
||||
"""Appends a user content, so that model can continue to output.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the Gemini model.
|
||||
"""
|
||||
# If no content is provided, append a user content to hint model response
|
||||
# using system instruction.
|
||||
if not llm_request.contents:
|
||||
llm_request.contents.append(
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
text=(
|
||||
'Handle the requests as specified in the System'
|
||||
' Instruction.'
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Insert a user content to preserve user intent and to avoid empty
|
||||
# model response.
|
||||
if llm_request.contents[-1].role != 'user':
|
||||
llm_request.contents.append(
|
||||
types.Content(
|
||||
role='user',
|
||||
parts=[
|
||||
types.Part(
|
||||
text=(
|
||||
'Continue processing previous requests as instructed.'
|
||||
' Exit or provide a summary if no more outputs are'
|
||||
' needed.'
|
||||
)
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _build_function_declaration_log(
|
||||
func_decl: types.FunctionDeclaration,
|
||||
) -> str:
|
||||
param_str = '{}'
|
||||
if func_decl.parameters and func_decl.parameters.properties:
|
||||
param_str = str({
|
||||
k: v.model_dump(exclude_none=True)
|
||||
for k, v in func_decl.parameters.properties.items()
|
||||
})
|
||||
return_str = 'None'
|
||||
if func_decl.response:
|
||||
return_str = str(func_decl.response.model_dump(exclude_none=True))
|
||||
return f'{func_decl.name}: {param_str} -> {return_str}'
|
||||
|
||||
|
||||
def _build_request_log(req: LlmRequest) -> str:
|
||||
function_decls: list[types.FunctionDeclaration] = cast(
|
||||
list[types.FunctionDeclaration],
|
||||
req.config.tools[0].function_declarations if req.config.tools else [],
|
||||
)
|
||||
function_logs = (
|
||||
[
|
||||
_build_function_declaration_log(func_decl)
|
||||
for func_decl in function_decls
|
||||
]
|
||||
if function_decls
|
||||
else []
|
||||
)
|
||||
contents_logs = [
|
||||
content.model_dump_json(
|
||||
exclude_none=True,
|
||||
exclude={
|
||||
'parts': {
|
||||
i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
|
||||
}
|
||||
},
|
||||
)
|
||||
for content in req.contents
|
||||
]
|
||||
|
||||
return f"""
|
||||
LLM Request:
|
||||
-----------------------------------------------------------
|
||||
System Instruction:
|
||||
{req.config.system_instruction}
|
||||
-----------------------------------------------------------
|
||||
Contents:
|
||||
{_NEW_LINE.join(contents_logs)}
|
||||
-----------------------------------------------------------
|
||||
Functions:
|
||||
{_NEW_LINE.join(function_logs)}
|
||||
-----------------------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
def _build_response_log(resp: types.GenerateContentResponse) -> str:
|
||||
function_calls_text = []
|
||||
if function_calls := resp.function_calls:
|
||||
for func_call in function_calls:
|
||||
function_calls_text.append(
|
||||
f'name: {func_call.name}, args: {func_call.args}'
|
||||
)
|
||||
return f"""
|
||||
LLM Response:
|
||||
-----------------------------------------------------------
|
||||
Text:
|
||||
{resp.text}
|
||||
-----------------------------------------------------------
|
||||
Function calls:
|
||||
{_NEW_LINE.join(function_calls_text)}
|
||||
-----------------------------------------------------------
|
||||
Raw response:
|
||||
{resp.model_dump_json(exclude_none=True)}
|
||||
-----------------------------------------------------------
|
||||
"""
|
||||
690
.venv/lib/python3.10/site-packages/google/adk/models/lite_llm.py
Normal file
690
.venv/lib/python3.10/site-packages/google/adk/models/lite_llm.py
Normal file
@@ -0,0 +1,690 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import AsyncGenerator
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Generator
|
||||
from typing import Iterable
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from google.genai import types
|
||||
from litellm import acompletion
|
||||
from litellm import ChatCompletionAssistantMessage
|
||||
from litellm import ChatCompletionDeveloperMessage
|
||||
from litellm import ChatCompletionImageUrlObject
|
||||
from litellm import ChatCompletionMessageToolCall
|
||||
from litellm import ChatCompletionTextObject
|
||||
from litellm import ChatCompletionToolMessage
|
||||
from litellm import ChatCompletionUserMessage
|
||||
from litellm import ChatCompletionVideoUrlObject
|
||||
from litellm import completion
|
||||
from litellm import CustomStreamWrapper
|
||||
from litellm import Function
|
||||
from litellm import Message
|
||||
from litellm import ModelResponse
|
||||
from litellm import OpenAIMessageContent
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from typing_extensions import override
|
||||
|
||||
from .base_llm import BaseLlm
|
||||
from .llm_request import LlmRequest
|
||||
from .llm_response import LlmResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NEW_LINE = "\n"
|
||||
_EXCLUDED_PART_FIELD = {"inline_data": {"data"}}
|
||||
|
||||
|
||||
class FunctionChunk(BaseModel):
|
||||
id: Optional[str]
|
||||
name: Optional[str]
|
||||
args: Optional[str]
|
||||
|
||||
|
||||
class TextChunk(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class LiteLLMClient:
|
||||
"""Provides acompletion method (for better testability)."""
|
||||
|
||||
async def acompletion(
|
||||
self, model, messages, tools, **kwargs
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
"""Asynchronously calls acompletion.
|
||||
|
||||
Args:
|
||||
model: The model name.
|
||||
messages: The messages to send to the model.
|
||||
tools: The tools to use for the model.
|
||||
**kwargs: Additional arguments to pass to acompletion.
|
||||
|
||||
Returns:
|
||||
The model response as a message.
|
||||
"""
|
||||
|
||||
return await acompletion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def completion(
|
||||
self, model, messages, tools, stream=False, **kwargs
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
"""Synchronously calls completion. This is used for streaming only.
|
||||
|
||||
Args:
|
||||
model: The model to use.
|
||||
messages: The messages to send.
|
||||
tools: The tools to use for the model.
|
||||
stream: Whether to stream the response.
|
||||
**kwargs: Additional arguments to pass to completion.
|
||||
|
||||
Returns:
|
||||
The response from the model.
|
||||
"""
|
||||
|
||||
return completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
stream=stream,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _safe_json_serialize(obj) -> str:
|
||||
"""Convert any Python object to a JSON-serializable type or string.
|
||||
|
||||
Args:
|
||||
obj: The object to serialize.
|
||||
|
||||
Returns:
|
||||
The JSON-serialized object string or string.
|
||||
"""
|
||||
|
||||
try:
|
||||
# Try direct JSON serialization first
|
||||
return json.dumps(obj)
|
||||
except (TypeError, OverflowError):
|
||||
return str(obj)
|
||||
|
||||
|
||||
def _content_to_message_param(
|
||||
content: types.Content,
|
||||
) -> Union[Message, list[Message]]:
|
||||
"""Converts a types.Content to a litellm Message or list of Messages.
|
||||
|
||||
Handles multipart function responses by returning a list of
|
||||
ChatCompletionToolMessage objects if multiple function_response parts exist.
|
||||
|
||||
Args:
|
||||
content: The content to convert.
|
||||
|
||||
Returns:
|
||||
A litellm Message, a list of litellm Messages.
|
||||
"""
|
||||
|
||||
tool_messages = []
|
||||
for part in content.parts:
|
||||
if part.function_response:
|
||||
tool_messages.append(
|
||||
ChatCompletionToolMessage(
|
||||
role="tool",
|
||||
tool_call_id=part.function_response.id,
|
||||
content=_safe_json_serialize(part.function_response.response),
|
||||
)
|
||||
)
|
||||
if tool_messages:
|
||||
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
|
||||
|
||||
# Handle user or assistant messages
|
||||
role = _to_litellm_role(content.role)
|
||||
message_content = _get_content(content.parts) or None
|
||||
|
||||
if role == "user":
|
||||
return ChatCompletionUserMessage(role="user", content=message_content)
|
||||
else: # assistant/model
|
||||
tool_calls = []
|
||||
content_present = False
|
||||
for part in content.parts:
|
||||
if part.function_call:
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=part.function_call.id,
|
||||
function=Function(
|
||||
name=part.function_call.name,
|
||||
arguments=part.function_call.args,
|
||||
),
|
||||
)
|
||||
)
|
||||
elif part.text or part.inline_data:
|
||||
content_present = True
|
||||
|
||||
final_content = message_content if content_present else None
|
||||
|
||||
return ChatCompletionAssistantMessage(
|
||||
role=role,
|
||||
content=final_content,
|
||||
tool_calls=tool_calls or None,
|
||||
)
|
||||
|
||||
|
||||
def _get_content(
|
||||
parts: Iterable[types.Part],
|
||||
) -> Union[OpenAIMessageContent, str]:
|
||||
"""Converts a list of parts to litellm content.
|
||||
|
||||
Args:
|
||||
parts: The parts to convert.
|
||||
|
||||
Returns:
|
||||
The litellm content.
|
||||
"""
|
||||
|
||||
content_objects = []
|
||||
for part in parts:
|
||||
if part.text:
|
||||
if len(parts) == 1:
|
||||
return part.text
|
||||
content_objects.append(
|
||||
ChatCompletionTextObject(
|
||||
type="text",
|
||||
text=part.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
part.inline_data
|
||||
and part.inline_data.data
|
||||
and part.inline_data.mime_type
|
||||
):
|
||||
base64_string = base64.b64encode(part.inline_data.data).decode("utf-8")
|
||||
data_uri = f"data:{part.inline_data.mime_type};base64,{base64_string}"
|
||||
|
||||
if part.inline_data.mime_type.startswith("image"):
|
||||
content_objects.append(
|
||||
ChatCompletionImageUrlObject(
|
||||
type="image_url",
|
||||
image_url=data_uri,
|
||||
)
|
||||
)
|
||||
elif part.inline_data.mime_type.startswith("video"):
|
||||
content_objects.append(
|
||||
ChatCompletionVideoUrlObject(
|
||||
type="video_url",
|
||||
video_url=data_uri,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError("LiteLlm(BaseLlm) does not support this content part.")
|
||||
|
||||
return content_objects
|
||||
|
||||
|
||||
def _to_litellm_role(role: Optional[str]) -> Literal["user", "assistant"]:
|
||||
"""Converts a types.Content role to a litellm role.
|
||||
|
||||
Args:
|
||||
role: The types.Content role.
|
||||
|
||||
Returns:
|
||||
The litellm role.
|
||||
"""
|
||||
|
||||
if role in ["model", "assistant"]:
|
||||
return "assistant"
|
||||
return "user"
|
||||
|
||||
|
||||
TYPE_LABELS = {
|
||||
"STRING": "string",
|
||||
"NUMBER": "number",
|
||||
"BOOLEAN": "boolean",
|
||||
"OBJECT": "object",
|
||||
"ARRAY": "array",
|
||||
"INTEGER": "integer",
|
||||
}
|
||||
|
||||
|
||||
def _schema_to_dict(schema: types.Schema) -> dict:
|
||||
"""Recursively converts a types.Schema to a dictionary.
|
||||
|
||||
Args:
|
||||
schema: The schema to convert.
|
||||
|
||||
Returns:
|
||||
The dictionary representation of the schema.
|
||||
"""
|
||||
|
||||
schema_dict = schema.model_dump(exclude_none=True)
|
||||
if "type" in schema_dict:
|
||||
schema_dict["type"] = schema_dict["type"].lower()
|
||||
if "items" in schema_dict:
|
||||
if isinstance(schema_dict["items"], dict):
|
||||
schema_dict["items"] = _schema_to_dict(
|
||||
types.Schema.model_validate(schema_dict["items"])
|
||||
)
|
||||
elif isinstance(schema_dict["items"]["type"], types.Type):
|
||||
schema_dict["items"]["type"] = TYPE_LABELS[
|
||||
schema_dict["items"]["type"].value
|
||||
]
|
||||
if "properties" in schema_dict:
|
||||
properties = {}
|
||||
for key, value in schema_dict["properties"].items():
|
||||
if isinstance(value, types.Schema):
|
||||
properties[key] = _schema_to_dict(value)
|
||||
else:
|
||||
properties[key] = value
|
||||
if "type" in properties[key]:
|
||||
properties[key]["type"] = properties[key]["type"].lower()
|
||||
schema_dict["properties"] = properties
|
||||
return schema_dict
|
||||
|
||||
|
||||
def _function_declaration_to_tool_param(
|
||||
function_declaration: types.FunctionDeclaration,
|
||||
) -> dict:
|
||||
"""Converts a types.FunctionDeclaration to a openapi spec dictionary.
|
||||
|
||||
Args:
|
||||
function_declaration: The function declaration to convert.
|
||||
|
||||
Returns:
|
||||
The openapi spec dictionary representation of the function declaration.
|
||||
"""
|
||||
|
||||
assert function_declaration.name
|
||||
|
||||
properties = {}
|
||||
if (
|
||||
function_declaration.parameters
|
||||
and function_declaration.parameters.properties
|
||||
):
|
||||
for key, value in function_declaration.parameters.properties.items():
|
||||
properties[key] = _schema_to_dict(value)
|
||||
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_declaration.name,
|
||||
"description": function_declaration.description or "",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _model_response_to_chunk(
|
||||
response: ModelResponse,
|
||||
) -> Generator[
|
||||
Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
|
||||
]:
|
||||
"""Converts a litellm message to text or function chunk.
|
||||
|
||||
Args:
|
||||
response: The response from the model.
|
||||
|
||||
Yields:
|
||||
A tuple of text or function chunk and finish reason.
|
||||
"""
|
||||
|
||||
message = None
|
||||
if response.get("choices", None):
|
||||
message = response["choices"][0].get("message", None)
|
||||
finish_reason = response["choices"][0].get("finish_reason", None)
|
||||
# check streaming delta
|
||||
if message is None and response["choices"][0].get("delta", None):
|
||||
message = response["choices"][0]["delta"]
|
||||
|
||||
if message.get("content", None):
|
||||
yield TextChunk(text=message.get("content")), finish_reason
|
||||
|
||||
if message.get("tool_calls", None):
|
||||
for tool_call in message.get("tool_calls"):
|
||||
# aggregate tool_call
|
||||
if tool_call.type == "function":
|
||||
yield FunctionChunk(
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
args=tool_call.function.arguments,
|
||||
), finish_reason
|
||||
|
||||
if finish_reason and not (
|
||||
message.get("content", None) or message.get("tool_calls", None)
|
||||
):
|
||||
yield None, finish_reason
|
||||
|
||||
if not message:
|
||||
yield None, None
|
||||
|
||||
|
||||
def _model_response_to_generate_content_response(
|
||||
response: ModelResponse,
|
||||
) -> LlmResponse:
|
||||
"""Converts a litellm response to LlmResponse.
|
||||
|
||||
Args:
|
||||
response: The model response.
|
||||
|
||||
Returns:
|
||||
The LlmResponse.
|
||||
"""
|
||||
|
||||
message = None
|
||||
if response.get("choices", None):
|
||||
message = response["choices"][0].get("message", None)
|
||||
|
||||
if not message:
|
||||
raise ValueError("No message in response")
|
||||
return _message_to_generate_content_response(message)
|
||||
|
||||
|
||||
def _message_to_generate_content_response(
|
||||
message: Message, is_partial: bool = False
|
||||
) -> LlmResponse:
|
||||
"""Converts a litellm message to LlmResponse.
|
||||
|
||||
Args:
|
||||
message: The message to convert.
|
||||
is_partial: Whether the message is partial.
|
||||
|
||||
Returns:
|
||||
The LlmResponse.
|
||||
"""
|
||||
|
||||
parts = []
|
||||
if message.get("content", None):
|
||||
parts.append(types.Part.from_text(text=message.get("content")))
|
||||
|
||||
if message.get("tool_calls", None):
|
||||
for tool_call in message.get("tool_calls"):
|
||||
if tool_call.type == "function":
|
||||
part = types.Part.from_function_call(
|
||||
name=tool_call.function.name,
|
||||
args=json.loads(tool_call.function.arguments or "{}"),
|
||||
)
|
||||
part.function_call.id = tool_call.id
|
||||
parts.append(part)
|
||||
|
||||
return LlmResponse(
|
||||
content=types.Content(role="model", parts=parts), partial=is_partial
|
||||
)
|
||||
|
||||
|
||||
def _get_completion_inputs(
|
||||
llm_request: LlmRequest,
|
||||
) -> tuple[Iterable[Message], Iterable[dict]]:
|
||||
"""Converts an LlmRequest to litellm inputs.
|
||||
|
||||
Args:
|
||||
llm_request: The LlmRequest to convert.
|
||||
|
||||
Returns:
|
||||
The litellm inputs (message list and tool dictionary).
|
||||
"""
|
||||
messages = []
|
||||
for content in llm_request.contents or []:
|
||||
message_param_or_list = _content_to_message_param(content)
|
||||
if isinstance(message_param_or_list, list):
|
||||
messages.extend(message_param_or_list)
|
||||
elif message_param_or_list: # Ensure it's not None before appending
|
||||
messages.append(message_param_or_list)
|
||||
|
||||
if llm_request.config.system_instruction:
|
||||
messages.insert(
|
||||
0,
|
||||
ChatCompletionDeveloperMessage(
|
||||
role="developer",
|
||||
content=llm_request.config.system_instruction,
|
||||
),
|
||||
)
|
||||
|
||||
tools = None
|
||||
if (
|
||||
llm_request.config
|
||||
and llm_request.config.tools
|
||||
and llm_request.config.tools[0].function_declarations
|
||||
):
|
||||
tools = [
|
||||
_function_declaration_to_tool_param(tool)
|
||||
for tool in llm_request.config.tools[0].function_declarations
|
||||
]
|
||||
return messages, tools
|
||||
|
||||
|
||||
def _build_function_declaration_log(
|
||||
func_decl: types.FunctionDeclaration,
|
||||
) -> str:
|
||||
"""Builds a function declaration log.
|
||||
|
||||
Args:
|
||||
func_decl: The function declaration to convert.
|
||||
|
||||
Returns:
|
||||
The function declaration log.
|
||||
"""
|
||||
|
||||
param_str = "{}"
|
||||
if func_decl.parameters and func_decl.parameters.properties:
|
||||
param_str = str({
|
||||
k: v.model_dump(exclude_none=True)
|
||||
for k, v in func_decl.parameters.properties.items()
|
||||
})
|
||||
return_str = "None"
|
||||
if func_decl.response:
|
||||
return_str = str(func_decl.response.model_dump(exclude_none=True))
|
||||
return f"{func_decl.name}: {param_str} -> {return_str}"
|
||||
|
||||
|
||||
def _build_request_log(req: LlmRequest) -> str:
|
||||
"""Builds a request log.
|
||||
|
||||
Args:
|
||||
req: The request to convert.
|
||||
|
||||
Returns:
|
||||
The request log.
|
||||
"""
|
||||
|
||||
function_decls: list[types.FunctionDeclaration] = cast(
|
||||
list[types.FunctionDeclaration],
|
||||
req.config.tools[0].function_declarations if req.config.tools else [],
|
||||
)
|
||||
function_logs = (
|
||||
[
|
||||
_build_function_declaration_log(func_decl)
|
||||
for func_decl in function_decls
|
||||
]
|
||||
if function_decls
|
||||
else []
|
||||
)
|
||||
contents_logs = [
|
||||
content.model_dump_json(
|
||||
exclude_none=True,
|
||||
exclude={
|
||||
"parts": {
|
||||
i: _EXCLUDED_PART_FIELD for i in range(len(content.parts))
|
||||
}
|
||||
},
|
||||
)
|
||||
for content in req.contents
|
||||
]
|
||||
|
||||
return f"""
|
||||
LLM Request:
|
||||
-----------------------------------------------------------
|
||||
System Instruction:
|
||||
{req.config.system_instruction}
|
||||
-----------------------------------------------------------
|
||||
Contents:
|
||||
{_NEW_LINE.join(contents_logs)}
|
||||
-----------------------------------------------------------
|
||||
Functions:
|
||||
{_NEW_LINE.join(function_logs)}
|
||||
-----------------------------------------------------------
|
||||
"""
|
||||
|
||||
|
||||
class LiteLlm(BaseLlm):
|
||||
"""Wrapper around litellm.
|
||||
|
||||
This wrapper can be used with any of the models supported by litellm. The
|
||||
environment variable(s) needed for authenticating with the model endpoint must
|
||||
be set prior to instantiating this class.
|
||||
|
||||
Example usage:
|
||||
```
|
||||
os.environ["VERTEXAI_PROJECT"] = "your-gcp-project-id"
|
||||
os.environ["VERTEXAI_LOCATION"] = "your-gcp-location"
|
||||
|
||||
agent = Agent(
|
||||
model=LiteLlm(model="vertex_ai/claude-3-7-sonnet@20250219"),
|
||||
...
|
||||
)
|
||||
```
|
||||
|
||||
Attributes:
|
||||
model: The name of the LiteLlm model.
|
||||
llm_client: The LLM client to use for the model.
|
||||
model_config: The model config.
|
||||
"""
|
||||
|
||||
llm_client: LiteLLMClient = Field(default_factory=LiteLLMClient)
|
||||
"""The LLM client to use for the model."""
|
||||
|
||||
_additional_args: Dict[str, Any] = None
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""Initializes the LiteLlm class.
|
||||
|
||||
Args:
|
||||
model: The name of the LiteLlm model.
|
||||
**kwargs: Additional arguments to pass to the litellm completion api.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
self._additional_args = kwargs
|
||||
# preventing generation call with llm_client
|
||||
# and overriding messages, tools and stream which are managed internally
|
||||
self._additional_args.pop("llm_client", None)
|
||||
self._additional_args.pop("messages", None)
|
||||
self._additional_args.pop("tools", None)
|
||||
# public api called from runner determines to stream or not
|
||||
self._additional_args.pop("stream", None)
|
||||
|
||||
async def generate_content_async(
|
||||
self, llm_request: LlmRequest, stream: bool = False
|
||||
) -> AsyncGenerator[LlmResponse, None]:
|
||||
"""Generates content asynchronously.
|
||||
|
||||
Args:
|
||||
llm_request: LlmRequest, the request to send to the LiteLlm model.
|
||||
stream: bool = False, whether to do streaming call.
|
||||
|
||||
Yields:
|
||||
LlmResponse: The model response.
|
||||
"""
|
||||
|
||||
logger.info(_build_request_log(llm_request))
|
||||
|
||||
messages, tools = _get_completion_inputs(llm_request)
|
||||
|
||||
completion_args = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
}
|
||||
completion_args.update(self._additional_args)
|
||||
|
||||
if stream:
|
||||
text = ""
|
||||
function_name = ""
|
||||
function_args = ""
|
||||
function_id = None
|
||||
completion_args["stream"] = True
|
||||
for part in self.llm_client.completion(**completion_args):
|
||||
for chunk, finish_reason in _model_response_to_chunk(part):
|
||||
if isinstance(chunk, FunctionChunk):
|
||||
if chunk.name:
|
||||
function_name += chunk.name
|
||||
if chunk.args:
|
||||
function_args += chunk.args
|
||||
function_id = chunk.id or function_id
|
||||
elif isinstance(chunk, TextChunk):
|
||||
text += chunk.text
|
||||
yield _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content=chunk.text,
|
||||
),
|
||||
is_partial=True,
|
||||
)
|
||||
if finish_reason == "tool_calls" and function_id:
|
||||
yield _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(
|
||||
role="assistant",
|
||||
content="",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
type="function",
|
||||
id=function_id,
|
||||
function=Function(
|
||||
name=function_name,
|
||||
arguments=function_args,
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
function_name = ""
|
||||
function_args = ""
|
||||
function_id = None
|
||||
elif finish_reason == "stop" and text:
|
||||
yield _message_to_generate_content_response(
|
||||
ChatCompletionAssistantMessage(role="assistant", content=text)
|
||||
)
|
||||
text = ""
|
||||
|
||||
else:
|
||||
response = await self.llm_client.acompletion(**completion_args)
|
||||
yield _model_response_to_generate_content_response(response)
|
||||
|
||||
@staticmethod
|
||||
@override
|
||||
def supported_models() -> list[str]:
|
||||
"""Provides the list of supported models.
|
||||
|
||||
LiteLlm supports all models supported by litellm. We do not keep track of
|
||||
these models here. So we return an empty list.
|
||||
|
||||
Returns:
|
||||
A list of supported models.
|
||||
"""
|
||||
|
||||
return []
|
||||
@@ -0,0 +1,98 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Optional

from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from ..tools.base_tool import BaseTool


class LlmRequest(BaseModel):
  """LLM request class that allows passing in tools, output schema and system

  instructions to the model.

  Attributes:
    model: The model name.
    contents: The contents to send to the model.
    config: Additional config for the generate content request.
    tools_dict: The tools dictionary.
  """

  model_config = ConfigDict(arbitrary_types_allowed=True)
  """The model config."""

  model: Optional[str] = None
  """The model name."""

  contents: list[types.Content] = Field(default_factory=list)
  """The contents to send to the model."""

  config: Optional[types.GenerateContentConfig] = None
  live_connect_config: types.LiveConnectConfig = types.LiveConnectConfig()
  """Additional config for the generate content request.

  tools in generate_content_config should not be set.
  """
  tools_dict: dict[str, BaseTool] = Field(default_factory=dict, exclude=True)
  """The tools dictionary."""

  def append_instructions(self, instructions: list[str]) -> None:
    """Appends instructions to the system instruction.

    Args:
      instructions: The instructions to append.
    """

    if self.config.system_instruction:
      self.config.system_instruction += '\n\n' + '\n\n'.join(instructions)
    else:
      self.config.system_instruction = '\n\n'.join(instructions)

  def append_tools(self, tools: list[BaseTool]) -> None:
    """Appends tools to the request.

    Args:
      tools: The tools to append.
    """

    if not tools:
      return
    declarations = []
    for tool in tools:
      if isinstance(tool, BaseTool):
        declaration = tool._get_declaration()
      else:
        declaration = tool.get_declaration()
      if declaration:
        declarations.append(declaration)
        self.tools_dict[tool.name] = tool
    if declarations:
      self.config.tools.append(types.Tool(function_declarations=declarations))

  def set_output_schema(self, base_model: type[BaseModel]) -> None:
    """Sets the output schema for the request.

    Args:
      base_model: The pydantic base model to set the output schema to.
    """

    self.config.response_schema = base_model
    self.config.response_mime_type = 'application/json'
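A hedged sketch of how an LlmRequest is typically assembled; the model name, instruction text, and schema below are illustrative. Note that append_instructions, append_tools, and set_output_schema all assume config is already set, and append_tools appends into config.tools, so the config is created with an empty tools list here:

# Illustrative only: values are examples, not defaults from this module.
from google.genai import types
from pydantic import BaseModel

request = LlmRequest(
    model='gemini-2.0-flash',  # example model name
    config=types.GenerateContentConfig(tools=[]),
)
request.contents.append(
    types.Content(role='user', parts=[types.Part(text='Summarize the report.')])
)
request.append_instructions(['Answer in three bullet points.'])


class Summary(BaseModel):
  title: str
  bullets: list[str]


# Forces a JSON response conforming to the pydantic schema.
request.set_output_schema(Summary)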
@@ -0,0 +1,120 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any, Optional

from google.genai import types
from pydantic import BaseModel
from pydantic import ConfigDict


class LlmResponse(BaseModel):
  """LLM response class that provides the first candidate response from the

  model if available. Otherwise, returns error code and message.

  Attributes:
    content: The content of the response.
    grounding_metadata: The grounding metadata of the response.
    partial: Indicates whether the text content is part of a unfinished text
      stream. Only used for streaming mode and when the content is plain text.
    turn_complete: Indicates whether the response from the model is complete.
      Only used for streaming mode.
    error_code: Error code if the response is an error. Code varies by model.
    error_message: Error message if the response is an error.
    interrupted: Flag indicating that LLM was interrupted when generating the
      content. Usually it's due to user interruption during a bidi streaming.
    custom_metadata: The custom metadata of the LlmResponse.
  """

  model_config = ConfigDict(extra='forbid')
  """The model config."""

  content: Optional[types.Content] = None
  """The content of the response."""

  grounding_metadata: Optional[types.GroundingMetadata] = None
  """The grounding metadata of the response."""

  partial: Optional[bool] = None
  """Indicates whether the text content is part of a unfinished text stream.

  Only used for streaming mode and when the content is plain text.
  """

  turn_complete: Optional[bool] = None
  """Indicates whether the response from the model is complete.

  Only used for streaming mode.
  """

  error_code: Optional[str] = None
  """Error code if the response is an error. Code varies by model."""

  error_message: Optional[str] = None
  """Error message if the response is an error."""

  interrupted: Optional[bool] = None
  """Flag indicating that LLM was interrupted when generating the content.
  Usually it's due to user interruption during a bidi streaming.
  """

  custom_metadata: Optional[dict[str, Any]] = None
  """The custom metadata of the LlmResponse.

  An optional key-value pair to label an LlmResponse.

  NOTE: the entire dict must be JSON serializable.
  """

  @staticmethod
  def create(
      generate_content_response: types.GenerateContentResponse,
  ) -> 'LlmResponse':
    """Creates an LlmResponse from a GenerateContentResponse.

    Args:
      generate_content_response: The GenerateContentResponse to create the
        LlmResponse from.

    Returns:
      The LlmResponse.
    """

    if generate_content_response.candidates:
      candidate = generate_content_response.candidates[0]
      if candidate.content and candidate.content.parts:
        return LlmResponse(
            content=candidate.content,
            grounding_metadata=candidate.grounding_metadata,
        )
      else:
        return LlmResponse(
            error_code=candidate.finish_reason,
            error_message=candidate.finish_message,
        )
    else:
      if generate_content_response.prompt_feedback:
        prompt_feedback = generate_content_response.prompt_feedback
        return LlmResponse(
            error_code=prompt_feedback.block_reason,
            error_message=prompt_feedback.block_reason_message,
        )
      else:
        return LlmResponse(
            error_code='UNKNOWN_ERROR',
            error_message='Unknown error.',
        )
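A hedged sketch of the branching in create(): it either surfaces the first candidate's content or folds the failure reason into error_code and error_message. The helper below is illustrative and not part of this module:

# Illustrative helper built on LlmResponse.create().
def first_text(raw: types.GenerateContentResponse) -> str:
  llm_response = LlmResponse.create(raw)
  if llm_response.content is not None:
    # Candidate path: join any text parts of the first candidate.
    return ''.join(part.text or '' for part in llm_response.content.parts)
  # Error path: finish_reason or block_reason landed in the error fields.
  raise RuntimeError(f'{llm_response.error_code}: {llm_response.error_message}')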
.venv/lib/python3.10/site-packages/google/adk/models/registry.py (new file, 102 lines)
@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The registry class for model."""

from __future__ import annotations

from functools import lru_cache
import logging
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
  from .base_llm import BaseLlm

logger = logging.getLogger(__name__)


_llm_registry_dict: dict[str, type[BaseLlm]] = {}
"""Registry for LLMs.

Key is the regex that matches the model name.
Value is the class that implements the model.
"""


class LLMRegistry:
  """Registry for LLMs."""

  @staticmethod
  def new_llm(model: str) -> BaseLlm:
    """Creates a new LLM instance.

    Args:
      model: The model name.

    Returns:
      The LLM instance.
    """

    return LLMRegistry.resolve(model)(model=model)

  @staticmethod
  def _register(model_name_regex: str, llm_cls: type[BaseLlm]):
    """Registers a new LLM class.

    Args:
      model_name_regex: The regex that matches the model name.
      llm_cls: The class that implements the model.
    """

    if model_name_regex in _llm_registry_dict:
      logger.info(
          'Updating LLM class for %s from %s to %s',
          model_name_regex,
          _llm_registry_dict[model_name_regex],
          llm_cls,
      )

    _llm_registry_dict[model_name_regex] = llm_cls

  @staticmethod
  def register(llm_cls: type[BaseLlm]):
    """Registers a new LLM class.

    Args:
      llm_cls: The class that implements the model.
    """

    for regex in llm_cls.supported_models():
      LLMRegistry._register(regex, llm_cls)

  @staticmethod
  @lru_cache(maxsize=32)
  def resolve(model: str) -> type[BaseLlm]:
    """Resolves the model to a BaseLlm subclass.

    Args:
      model: The model name.

    Returns:
      The BaseLlm subclass.
    Raises:
      ValueError: If the model is not found.
    """

    for regex, llm_class in _llm_registry_dict.items():
      if re.compile(regex).fullmatch(model):
        return llm_class

    raise ValueError(f'Model {model} not found.')
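A hedged sketch of the register/resolve round trip; MyLlm, its regex, and the model name are made up, and the generation methods a real BaseLlm subclass would provide are elided:

# Illustrative only.
from .base_llm import BaseLlm


class MyLlm(BaseLlm):

  @staticmethod
  def supported_models() -> list[str]:
    # Each pattern is matched with re.fullmatch against the requested name.
    return [r'my-llm-.*']


LLMRegistry.register(MyLlm)                # fans each regex out to _register()
llm = LLMRegistry.new_llm('my-llm-small')  # resolve() returns MyLlm, then calls MyLlm(model=...)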