# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
from functools import cached_property
import logging
import sys
from typing import AsyncGenerator
from typing import cast
from typing import TYPE_CHECKING

from google.genai import Client
from google.genai import types
from typing_extensions import override

from .. import version
from .base_llm import BaseLlm
from .base_llm_connection import BaseLlmConnection
from .gemini_llm_connection import GeminiLlmConnection
from .llm_response import LlmResponse

if TYPE_CHECKING:
  from .llm_request import LlmRequest

logger = logging.getLogger(__name__)

_NEW_LINE = '\n'
_EXCLUDED_PART_FIELD = {'inline_data': {'data'}}


class Gemini(BaseLlm):
  """Integration for Gemini models.

  Attributes:
    model: The name of the Gemini model.
  """

  model: str = 'gemini-1.5-flash'

  @staticmethod
  @override
  def supported_models() -> list[str]:
    """Provides the list of supported models.

    Returns:
      A list of supported models.
    """
    return [
        r'gemini-.*',
        # fine-tuned vertex endpoint pattern
        r'projects\/.+\/locations\/.+\/endpoints\/.+',
        # vertex gemini long name
        r'projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+',
    ]

  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Sends a request to the Gemini model.

    Args:
      llm_request: LlmRequest, the request to send to the Gemini model.
      stream: bool = False, whether to do a streaming call.

    Yields:
      LlmResponse: The model response.
    """
    self._maybe_append_user_content(llm_request)
    logger.info(
        'Sending out request, model: %s, backend: %s, stream: %s',
        llm_request.model,
        self._api_backend,
        stream,
    )
    logger.info(_build_request_log(llm_request))

    if stream:
      responses = await self.api_client.aio.models.generate_content_stream(
          model=llm_request.model,
          contents=llm_request.contents,
          config=llm_request.config,
      )
      response = None
      text = ''
      # For SSE, similar to bidi (see the receive method in
      # gemini_llm_connection.py), we mark text content as partial and, after
      # all partial contents are sent, send an accumulated event that contains
      # all of the previous partial content. The only difference is that bidi
      # relies on the turn_complete flag to detect the end of the turn, while
      # SSE depends on finish_reason.
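      # Illustrative example (assumption, not part of the original comment):
      # for a reply streamed as the chunks 'Hello' and ' world', the loop
      # below yields two partial LlmResponse events ('Hello', then ' world')
      # and, once finish_reason == STOP, one final merged LlmResponse whose
      # content is 'Hello world'.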
      async for response in responses:
        logger.info(_build_response_log(response))
        llm_response = LlmResponse.create(response)
        if (
            llm_response.content
            and llm_response.content.parts
            and llm_response.content.parts[0].text
        ):
          text += llm_response.content.parts[0].text
          llm_response.partial = True
        elif text and (
            not llm_response.content
            or not llm_response.content.parts
            # don't yield the merged text event when receiving audio data
            or not llm_response.content.parts[0].inline_data
        ):
          yield LlmResponse(
              content=types.ModelContent(
                  parts=[types.Part.from_text(text=text)],
              ),
          )
          text = ''
        yield llm_response
      if (
          text
          and response
          and response.candidates
          and response.candidates[0].finish_reason == types.FinishReason.STOP
      ):
        yield LlmResponse(
            content=types.ModelContent(
                parts=[types.Part.from_text(text=text)],
            ),
        )
    else:
      response = await self.api_client.aio.models.generate_content(
          model=llm_request.model,
          contents=llm_request.contents,
          config=llm_request.config,
      )
      logger.info(_build_response_log(response))
      yield LlmResponse.create(response)

  @cached_property
  def api_client(self) -> Client:
    """Provides the api client.

    Returns:
      The api client.
    """
    return Client(
        http_options=types.HttpOptions(headers=self._tracking_headers)
    )

  @cached_property
  def _api_backend(self) -> str:
    return 'vertex' if self.api_client.vertexai else 'ml_dev'

  @cached_property
  def _tracking_headers(self) -> dict[str, str]:
    framework_label = f'google-adk/{version.__version__}'
    language_label = 'gl-python/' + sys.version.split()[0]
    version_header_value = f'{framework_label} {language_label}'
    tracking_headers = {
        'x-goog-api-client': version_header_value,
        'user-agent': version_header_value,
    }
    return tracking_headers

  @cached_property
  def _live_api_client(self) -> Client:
    if self._api_backend == 'vertex':
      # use the beta api version for the vertex backend
      api_version = 'v1beta1'
      return Client(
          http_options=types.HttpOptions(
              headers=self._tracking_headers, api_version=api_version
          )
      )
    else:
      # use v1alpha for ml_dev
      api_version = 'v1alpha'
      return Client(
          http_options=types.HttpOptions(
              headers=self._tracking_headers, api_version=api_version
          )
      )

  @contextlib.asynccontextmanager
  async def connect(self, llm_request: LlmRequest) -> BaseLlmConnection:
    """Connects to the Gemini model and returns an llm connection.

    Args:
      llm_request: LlmRequest, the request to send to the Gemini model.

    Yields:
      BaseLlmConnection, the connection to the Gemini model.
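
    Example (illustrative sketch, not part of the original docstring; it
    assumes an already-built LlmRequest and the send_content/receive methods
    of GeminiLlmConnection):

      async with gemini.connect(llm_request) as connection:
        await connection.send_content(user_content)
        async for llm_response in connection.receive():
          ...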
""" llm_request.live_connect_config.system_instruction = types.Content( role='system', parts=[ types.Part.from_text(text=llm_request.config.system_instruction) ], ) llm_request.live_connect_config.tools = llm_request.config.tools async with self._live_api_client.aio.live.connect( model=llm_request.model, config=llm_request.live_connect_config ) as live_session: yield GeminiLlmConnection(live_session) def _build_function_declaration_log( func_decl: types.FunctionDeclaration, ) -> str: param_str = '{}' if func_decl.parameters and func_decl.parameters.properties: param_str = str({ k: v.model_dump(exclude_none=True) for k, v in func_decl.parameters.properties.items() }) return_str = 'None' if func_decl.response: return_str = str(func_decl.response.model_dump(exclude_none=True)) return f'{func_decl.name}: {param_str} -> {return_str}' def _build_request_log(req: LlmRequest) -> str: function_decls: list[types.FunctionDeclaration] = cast( list[types.FunctionDeclaration], req.config.tools[0].function_declarations if req.config.tools else [], ) function_logs = ( [ _build_function_declaration_log(func_decl) for func_decl in function_decls ] if function_decls else [] ) contents_logs = [ content.model_dump_json( exclude_none=True, exclude={ 'parts': { i: _EXCLUDED_PART_FIELD for i in range(len(content.parts)) } }, ) for content in req.contents ] return f""" LLM Request: ----------------------------------------------------------- System Instruction: {req.config.system_instruction} ----------------------------------------------------------- Contents: {_NEW_LINE.join(contents_logs)} ----------------------------------------------------------- Functions: {_NEW_LINE.join(function_logs)} ----------------------------------------------------------- """ def _build_response_log(resp: types.GenerateContentResponse) -> str: function_calls_text = [] if function_calls := resp.function_calls: for func_call in function_calls: function_calls_text.append( f'name: {func_call.name}, args: {func_call.args}' ) return f""" LLM Response: ----------------------------------------------------------- Text: {resp.text} ----------------------------------------------------------- Function calls: {_NEW_LINE.join(function_calls_text)} ----------------------------------------------------------- Raw response: {resp.model_dump_json(exclude_none=True)} ----------------------------------------------------------- """