# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # from collections.abc import Iterator import sys from typing import AsyncIterator, Awaitable, Optional, Union, get_args from . import _transformers as t from . import types from .models import AsyncModels, Models from .types import Content, ContentOrDict, GenerateContentConfigOrDict, GenerateContentResponse, Part, PartUnionDict if sys.version_info >= (3, 10): from typing import TypeGuard else: from typing_extensions import TypeGuard def _validate_content(content: Content) -> bool: if not content.parts: return False for part in content.parts: if part == Part(): return False if part.text is not None and part.text == "": return False return True def _validate_contents(contents: list[Content]) -> bool: if not contents: return False for content in contents: if not _validate_content(content): return False return True def _validate_response(response: GenerateContentResponse) -> bool: if not response.candidates: return False if not response.candidates[0].content: return False return _validate_content(response.candidates[0].content) def _extract_curated_history( comprehensive_history: list[Content], ) -> list[Content]: """Extracts the curated (valid) history from a comprehensive history. The comprehensive history contains all turns (user input and model responses), including any invalid or rejected model outputs. This function filters that history to return only the valid turns. A "turn" starts with one user input (a single content) and then follows by corresponding model response (which may consist of multiple contents). Turns are assumed to alternate: user input, model output, user input, model output, etc. Args: comprehensive_history: A list representing the complete chat history. Including invalid turns. Returns: curated history, which is a list of valid turns. """ if not comprehensive_history: return [] curated_history = [] length = len(comprehensive_history) i = 0 current_input = comprehensive_history[i] if current_input.role != "user": raise ValueError("History must start with a user turn.") while i < length: if comprehensive_history[i].role not in ["user", "model"]: raise ValueError( f"Role must be user or model, but got {comprehensive_history[i].role}" ) if comprehensive_history[i].role == "user": current_input = comprehensive_history[i] i += 1 else: current_output = [] is_valid = True while i < length and comprehensive_history[i].role == "model": current_output.append(comprehensive_history[i]) if is_valid and not _validate_content(comprehensive_history[i]): is_valid = False i += 1 if is_valid: curated_history.append(current_input) curated_history.extend(current_output) return curated_history class _BaseChat: """Base chat session.""" def __init__( self, *, model: str, config: Optional[GenerateContentConfigOrDict] = None, history: list[ContentOrDict], ): self._model = model self._config = config content_models = [] for content in history: if not isinstance(content, Content): content_model = Content.model_validate(content) else: content_model = content content_models.append(content_model) self._comprehensive_history = content_models """Comprehensive history is the full history of the chat, including turns of the invalid contents from the model and their associated inputs. """ self._curated_history = _extract_curated_history(content_models) """Curated history is the set of valid turns that will be used in the subsequent send requests. """ def record_history( self, user_input: Content, model_output: list[Content], automatic_function_calling_history: list[Content], is_valid: bool, ) -> None: """Records the chat history. Maintaining both comprehensive and curated histories. Args: user_input: The user's input content. model_output: A list of `Content` from the model's response. This can be an empty list if the model produced no output. automatic_function_calling_history: A list of `Content` representing the history of automatic function calls, including the user input as the first entry. is_valid: A boolean flag indicating whether the current model output is considered valid. """ input_contents = ( # Because the AFC input contains the entire curated chat history in # addition to the new user input, we need to truncate the AFC history # to deduplicate the existing chat history. automatic_function_calling_history[len(self._curated_history):] if automatic_function_calling_history else [user_input] ) # Appends an empty content when model returns empty response, so that the # history is always alternating between user and model. output_contents = ( model_output if model_output else [Content(role="model", parts=[])] ) self._comprehensive_history.extend(input_contents) self._comprehensive_history.extend(output_contents) if is_valid: self._curated_history.extend(input_contents) self._curated_history.extend(output_contents) def get_history(self, curated: bool = False) -> list[Content]: """Returns the chat history. Args: curated: A boolean flag indicating whether to return the curated (valid) history or the comprehensive (all turns) history. Defaults to False (returns the comprehensive history). Returns: A list of `Content` objects representing the chat history. """ if curated: return self._curated_history else: return self._comprehensive_history def _is_part_type( contents: Union[list[PartUnionDict], PartUnionDict], ) -> TypeGuard[t.ContentType]: if isinstance(contents, list): return all(_is_part_type(part) for part in contents) else: allowed_part_types = get_args(types.PartUnion) if type(contents) in allowed_part_types: return True else: # Some images don't pass isinstance(item, PIL.Image.Image) # For example if types.PIL_Image is not None and isinstance(contents, types.PIL_Image): return True return False class Chat(_BaseChat): """Chat session.""" def __init__( self, *, modules: Models, model: str, config: Optional[GenerateContentConfigOrDict] = None, history: list[ContentOrDict], ): self._modules = modules super().__init__( model=model, config=config, history=history, ) def send_message( self, message: Union[list[PartUnionDict], PartUnionDict], config: Optional[GenerateContentConfigOrDict] = None, ) -> GenerateContentResponse: """Sends the conversation history with the additional message and returns the model's response. Args: message: The message to send to the model. config: Optional config to override the default Chat config for this request. Returns: The model's response. Usage: .. code-block:: python chat = client.chats.create(model='gemini-1.5-flash') response = chat.send_message('tell me a story') """ if not _is_part_type(message): raise ValueError( f"Message must be a valid part type: {types.PartUnion} or" f" {types.PartUnionDict}, got {type(message)}" ) input_content = t.t_content(self._modules._api_client, message) response = self._modules.generate_content( model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] config=config if config else self._config, ) model_output = ( [response.candidates[0].content] if response.candidates and response.candidates[0].content else [] ) automatic_function_calling_history = ( response.automatic_function_calling_history if response.automatic_function_calling_history else [] ) self.record_history( user_input=input_content, model_output=model_output, automatic_function_calling_history=automatic_function_calling_history, is_valid=_validate_response(response), ) return response def send_message_stream( self, message: Union[list[PartUnionDict], PartUnionDict], config: Optional[GenerateContentConfigOrDict] = None, ) -> Iterator[GenerateContentResponse]: """Sends the conversation history with the additional message and yields the model's response in chunks. Args: message: The message to send to the model. config: Optional config to override the default Chat config for this request. Yields: The model's response in chunks. Usage: .. code-block:: python chat = client.chats.create(model='gemini-1.5-flash') for chunk in chat.send_message_stream('tell me a story'): print(chunk.text) """ if not _is_part_type(message): raise ValueError( f"Message must be a valid part type: {types.PartUnion} or" f" {types.PartUnionDict}, got {type(message)}" ) input_content = t.t_content(self._modules._api_client, message) output_contents = [] finish_reason = None is_valid = True chunk = None if isinstance(self._modules, Models): for chunk in self._modules.generate_content_stream( model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] config=config if config else self._config, ): if not _validate_response(chunk): is_valid = False if chunk.candidates and chunk.candidates[0].content: output_contents.append(chunk.candidates[0].content) if chunk.candidates and chunk.candidates[0].finish_reason: finish_reason = chunk.candidates[0].finish_reason yield chunk automatic_function_calling_history = ( chunk.automatic_function_calling_history if chunk.automatic_function_calling_history else [] ) self.record_history( user_input=input_content, model_output=output_contents, automatic_function_calling_history=automatic_function_calling_history, is_valid=is_valid and output_contents is not None and finish_reason is not None, ) class Chats: """A util class to create chat sessions.""" def __init__(self, modules: Models): self._modules = modules def create( self, *, model: str, config: Optional[GenerateContentConfigOrDict] = None, history: Optional[list[ContentOrDict]] = None, ) -> Chat: """Creates a new chat session. Args: model: The model to use for the chat. config: The configuration to use for the generate content request. history: The history to use for the chat. Returns: A new chat session. """ return Chat( modules=self._modules, model=model, config=config, history=history if history else [], ) class AsyncChat(_BaseChat): """Async chat session.""" def __init__( self, *, modules: AsyncModels, model: str, config: Optional[GenerateContentConfigOrDict] = None, history: list[ContentOrDict], ): self._modules = modules super().__init__( model=model, config=config, history=history, ) async def send_message( self, message: Union[list[PartUnionDict], PartUnionDict], config: Optional[GenerateContentConfigOrDict] = None, ) -> GenerateContentResponse: """Sends the conversation history with the additional message and returns model's response. Args: message: The message to send to the model. config: Optional config to override the default Chat config for this request. Returns: The model's response. Usage: .. code-block:: python chat = client.aio.chats.create(model='gemini-1.5-flash') response = await chat.send_message('tell me a story') """ if not _is_part_type(message): raise ValueError( f"Message must be a valid part type: {types.PartUnion} or" f" {types.PartUnionDict}, got {type(message)}" ) input_content = t.t_content(self._modules._api_client, message) response = await self._modules.generate_content( model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] config=config if config else self._config, ) model_output = ( [response.candidates[0].content] if response.candidates and response.candidates[0].content else [] ) automatic_function_calling_history = ( response.automatic_function_calling_history if response.automatic_function_calling_history else [] ) self.record_history( user_input=input_content, model_output=model_output, automatic_function_calling_history=automatic_function_calling_history, is_valid=_validate_response(response), ) return response async def send_message_stream( self, message: Union[list[PartUnionDict], PartUnionDict], config: Optional[GenerateContentConfigOrDict] = None, ) -> AsyncIterator[GenerateContentResponse]: """Sends the conversation history with the additional message and yields the model's response in chunks. Args: message: The message to send to the model. config: Optional config to override the default Chat config for this request. Yields: The model's response in chunks. Usage: .. code-block:: python chat = client.aio.chats.create(model='gemini-1.5-flash') async for chunk in await chat.send_message_stream('tell me a story'): print(chunk.text) """ if not _is_part_type(message): raise ValueError( f"Message must be a valid part type: {types.PartUnion} or" f" {types.PartUnionDict}, got {type(message)}" ) input_content = t.t_content(self._modules._api_client, message) async def async_generator(): # type: ignore[no-untyped-def] output_contents = [] finish_reason = None is_valid = True chunk = None async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined] model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] config=config if config else self._config, ): if not _validate_response(chunk): is_valid = False if chunk.candidates and chunk.candidates[0].content: output_contents.append(chunk.candidates[0].content) if chunk.candidates and chunk.candidates[0].finish_reason: finish_reason = chunk.candidates[0].finish_reason yield chunk if not output_contents or finish_reason is None: is_valid = False self.record_history( user_input=input_content, model_output=output_contents, automatic_function_calling_history=chunk.automatic_function_calling_history if chunk.automatic_function_calling_history else [], is_valid=is_valid, ) return async_generator() # type: ignore[no-untyped-call, no-any-return] class AsyncChats: """A util class to create async chat sessions.""" def __init__(self, modules: AsyncModels): self._modules = modules def create( self, *, model: str, config: Optional[GenerateContentConfigOrDict] = None, history: Optional[list[ContentOrDict]] = None, ) -> AsyncChat: """Creates a new chat session. Args: model: The model to use for the chat. config: The configuration to use for the generate content request. history: The history to use for the chat. Returns: A new chat session. """ return AsyncChat( modules=self._modules, model=model, config=config, history=history if history else [], )