structure saas with tools
This commit is contained in:
984
.venv/lib/python3.10/site-packages/google/genai/live.py
Normal file
984
.venv/lib/python3.10/site-packages/google/genai/live.py
Normal file
@@ -0,0 +1,984 @@
|
||||
# Copyright 2025 Google LLC
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
"""[Preview] Live API client."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
|
||||
import warnings
|
||||
|
||||
import google.auth
|
||||
import pydantic
|
||||
from websockets import ConnectionClosed
|
||||
|
||||
from . import _api_module
|
||||
from . import _common
|
||||
from . import _transformers as t
|
||||
from . import client
|
||||
from . import types
|
||||
from ._api_client import BaseApiClient
|
||||
from ._common import get_value_by_path as getv
|
||||
from ._common import set_value_by_path as setv
|
||||
from . import _live_converters as live_converters
|
||||
from .models import _Content_to_mldev
|
||||
from .models import _Content_to_vertex
|
||||
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import ClientConnection
|
||||
from websockets.asyncio.client import connect
|
||||
except ModuleNotFoundError:
|
||||
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
|
||||
from websockets.client import ClientConnection # type: ignore
|
||||
from websockets.client import connect # type: ignore
|
||||
|
||||
logger = logging.getLogger('google_genai.live')
|
||||
|
||||
# Error message raised when a FunctionResponse sent to the Gemini API
# (non-Vertex backend) lacks the `id` that ties it back to the server's
# ToolCall.FunctionCalls entry.
_FUNCTION_RESPONSE_REQUIRES_ID = (
    'FunctionResponse request must have an `id` field from the'
    ' response of a ToolCall.FunctionalCalls in Google AI.'
)
|
||||
|
||||
|
||||
class AsyncSession:
  """[Preview] AsyncSession.

  A live, bidirectional session over a websocket. Obtained from
  `AsyncLive.connect`; use the `send_*` methods to push input and
  `receive` to iterate over server messages.
  """

  def __init__(
      self, api_client: BaseApiClient, websocket: ClientConnection
  ):
    # Client used to pick the backend (Vertex vs Gemini API) and drive
    # the request/response converters.
    self._api_client = api_client
    # The already-open websocket connection to the live server.
    self._ws = websocket
|
||||
async def send(
|
||||
self,
|
||||
*,
|
||||
input: Optional[
|
||||
Union[
|
||||
types.ContentListUnion,
|
||||
types.ContentListUnionDict,
|
||||
types.LiveClientContentOrDict,
|
||||
types.LiveClientRealtimeInputOrDict,
|
||||
types.LiveClientToolResponseOrDict,
|
||||
types.FunctionResponseOrDict,
|
||||
Sequence[types.FunctionResponseOrDict],
|
||||
]
|
||||
] = None,
|
||||
end_of_turn: Optional[bool] = False,
|
||||
) -> None:
|
||||
"""[Deprecated] Send input to the model.
|
||||
|
||||
> **Warning**: This method is deprecated and will be removed in a future
|
||||
version (not before Q3 2025). Please use one of the more specific methods:
|
||||
`send_client_content`, `send_realtime_input`, or `send_tool_response`
|
||||
instead.
|
||||
|
||||
The method will send the input request to the server.
|
||||
|
||||
Args:
|
||||
input: The input request to the model.
|
||||
end_of_turn: Whether the input is the last message in a turn.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
|
||||
async with client.aio.live.connect(model='...') as session:
|
||||
await session.send(input='Hello world!', end_of_turn=True)
|
||||
async for message in session.receive():
|
||||
print(message)
|
||||
"""
|
||||
warnings.warn(
|
||||
'The `session.send` method is deprecated and will be removed in a '
|
||||
'future version (not before Q3 2025).\n'
|
||||
'Please use one of the more specific methods: `send_client_content`, '
|
||||
'`send_realtime_input`, or `send_tool_response` instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
client_message = self._parse_client_message(input, end_of_turn)
|
||||
await self._ws.send(json.dumps(client_message))
|
||||
|
||||
async def send_client_content(
|
||||
self,
|
||||
*,
|
||||
turns: Optional[
|
||||
Union[
|
||||
types.Content,
|
||||
types.ContentDict,
|
||||
list[Union[types.Content, types.ContentDict]]
|
||||
]
|
||||
] = None,
|
||||
turn_complete: bool = True,
|
||||
) -> None:
|
||||
"""Send non-realtime, turn based content to the model.
|
||||
|
||||
There are two ways to send messages to the live API:
|
||||
`send_client_content` and `send_realtime_input`.
|
||||
|
||||
`send_client_content` messages are added to the model context **in order**.
|
||||
Having a conversation using `send_client_content` messages is roughly
|
||||
equivalent to using the `Chat.send_message_stream` method, except that the
|
||||
state of the `chat` history is stored on the API server.
|
||||
|
||||
Because of `send_client_content`'s order guarantee, the model cannot
|
||||
respond as quickly to `send_client_content` messages as to
|
||||
`send_realtime_input` messages. This makes the biggest difference when
|
||||
sending objects that have significant preprocessing time (typically images).
|
||||
|
||||
The `send_client_content` message sends a list of `Content` objects,
|
||||
which has more options than the `media:Blob` sent by `send_realtime_input`.
|
||||
|
||||
The main use-cases for `send_client_content` over `send_realtime_input` are:
|
||||
|
||||
- Prefilling a conversation context (including sending anything that can't
|
||||
be represented as a realtime message), before starting a realtime
|
||||
conversation.
|
||||
- Conducting a non-realtime conversation, similar to `client.chat`, using
|
||||
the live api.
|
||||
|
||||
Caution: Interleaving `send_client_content` and `send_realtime_input`
|
||||
in the same conversation is not recommended and can lead to unexpected
|
||||
results.
|
||||
|
||||
Args:
|
||||
turns: A `Content` object or list of `Content` objects (or equivalent
|
||||
dicts).
|
||||
turn_complete: if true (the default) the model will reply immediately. If
|
||||
false, the model will wait for you to send additional client_content,
|
||||
and will not return until you send `turn_complete=True`.
|
||||
|
||||
Example:
|
||||
```
|
||||
import google.genai
|
||||
from google.genai import types
|
||||
import os
|
||||
|
||||
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
|
||||
else:
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-001';
|
||||
|
||||
client = genai.Client()
|
||||
async with client.aio.live.connect(
|
||||
model=MODEL_NAME,
|
||||
config={"response_modalities": ["TEXT"]}
|
||||
) as session:
|
||||
await session.send_client_content(
|
||||
turns=types.Content(
|
||||
role='user',
|
||||
parts=[types.Part(text="Hello world!")]))
|
||||
async for msg in session.receive():
|
||||
if msg.text:
|
||||
print(msg.text)
|
||||
```
|
||||
"""
|
||||
client_content = t.t_client_content(turns, turn_complete)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
client_content_dict = live_converters._LiveClientContent_to_vertex(
|
||||
api_client=self._api_client, from_object=client_content
|
||||
)
|
||||
else:
|
||||
client_content_dict = live_converters._LiveClientContent_to_mldev(
|
||||
api_client=self._api_client, from_object=client_content
|
||||
)
|
||||
|
||||
await self._ws.send(json.dumps({'client_content': client_content_dict}))
|
||||
|
||||
async def send_realtime_input(
|
||||
self,
|
||||
*,
|
||||
media: Optional[types.BlobImageUnionDict] = None,
|
||||
audio: Optional[types.BlobOrDict] = None,
|
||||
audio_stream_end: Optional[bool] = None,
|
||||
video: Optional[types.BlobImageUnionDict] = None,
|
||||
text: Optional[str] = None,
|
||||
activity_start: Optional[types.ActivityStartOrDict] = None,
|
||||
activity_end: Optional[types.ActivityEndOrDict] = None,
|
||||
) -> None:
|
||||
"""Send realtime input to the model, only send one argument per call.
|
||||
|
||||
Use `send_realtime_input` for realtime audio chunks and video
|
||||
frames(images).
|
||||
|
||||
With `send_realtime_input` the api will respond to audio automatically
|
||||
based on voice activity detection (VAD).
|
||||
|
||||
`send_realtime_input` is optimized for responsivness at the expense of
|
||||
deterministic ordering. Audio and video tokens are added to the
|
||||
context when they become available.
|
||||
|
||||
Args:
|
||||
media: A `Blob`-like object, the realtime media to send.
|
||||
|
||||
Example:
|
||||
```
|
||||
from pathlib import Path
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
import PIL.Image
|
||||
|
||||
import os
|
||||
|
||||
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
|
||||
else:
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-001';
|
||||
|
||||
|
||||
client = genai.Client()
|
||||
|
||||
async with client.aio.live.connect(
|
||||
model=MODEL_NAME,
|
||||
config={"response_modalities": ["TEXT"]},
|
||||
) as session:
|
||||
await session.send_realtime_input(
|
||||
media=PIL.Image.open('image.jpg'))
|
||||
|
||||
audio_bytes = Path('audio.pcm').read_bytes()
|
||||
await session.send_realtime_input(
|
||||
media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))
|
||||
|
||||
async for msg in session.receive():
|
||||
if msg.text is not None:
|
||||
print(f'{msg.text}')
|
||||
```
|
||||
"""
|
||||
kwargs:dict[str, Any] = {}
|
||||
if media is not None:
|
||||
kwargs['media'] = media
|
||||
if audio is not None:
|
||||
kwargs['audio'] = audio
|
||||
if audio_stream_end is not None:
|
||||
kwargs['audio_stream_end'] = audio_stream_end
|
||||
if video is not None:
|
||||
kwargs['video'] = video
|
||||
if text is not None:
|
||||
kwargs['text'] = text
|
||||
if activity_start is not None:
|
||||
kwargs['activity_start'] = activity_start
|
||||
if activity_end is not None:
|
||||
kwargs['activity_end'] = activity_end
|
||||
|
||||
if len(kwargs) != 1:
|
||||
raise ValueError(
|
||||
f'Only one argument can be set, got {len(kwargs)}:'
|
||||
f' {list(kwargs.keys())}'
|
||||
)
|
||||
realtime_input = types.LiveSendRealtimeInputParameters.model_validate(
|
||||
kwargs
|
||||
)
|
||||
|
||||
if self._api_client.vertexai:
|
||||
realtime_input_dict = (
|
||||
live_converters._LiveSendRealtimeInputParameters_to_vertex(
|
||||
api_client=self._api_client, from_object=realtime_input
|
||||
)
|
||||
)
|
||||
else:
|
||||
realtime_input_dict = (
|
||||
live_converters._LiveSendRealtimeInputParameters_to_mldev(
|
||||
api_client=self._api_client, from_object=realtime_input
|
||||
)
|
||||
)
|
||||
realtime_input_dict = _common.convert_to_dict(realtime_input_dict)
|
||||
realtime_input_dict = _common.encode_unserializable_types(
|
||||
realtime_input_dict
|
||||
)
|
||||
await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))
|
||||
|
||||
async def send_tool_response(
|
||||
self,
|
||||
*,
|
||||
function_responses: Union[
|
||||
types.FunctionResponseOrDict,
|
||||
Sequence[types.FunctionResponseOrDict],
|
||||
],
|
||||
) -> None:
|
||||
"""Send a tool response to the session.
|
||||
|
||||
Use `send_tool_response` to reply to `LiveServerToolCall` messages
|
||||
from the server.
|
||||
|
||||
To set the available tools, use the `config.tools` argument
|
||||
when you connect to the session (`client.live.connect`).
|
||||
|
||||
Args:
|
||||
function_responses: A `FunctionResponse`-like object or list of
|
||||
`FunctionResponse`-like objects.
|
||||
|
||||
Example:
|
||||
```
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
import os
|
||||
|
||||
if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
|
||||
else:
|
||||
MODEL_NAME = 'gemini-2.0-flash-live-001';
|
||||
|
||||
client = genai.Client()
|
||||
|
||||
tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
|
||||
config = {
|
||||
"tools": tools,
|
||||
"response_modalities": ['TEXT']
|
||||
}
|
||||
|
||||
async with client.aio.live.connect(
|
||||
model='models/gemini-2.0-flash-live-001',
|
||||
config=config
|
||||
) as session:
|
||||
prompt = "Turn on the lights please"
|
||||
await session.send_client_content(
|
||||
turns={"parts": [{'text': prompt}]}
|
||||
)
|
||||
|
||||
async for chunk in session.receive():
|
||||
if chunk.server_content:
|
||||
if chunk.text is not None:
|
||||
print(chunk.text)
|
||||
elif chunk.tool_call:
|
||||
print(chunk.tool_call)
|
||||
print('_'*80)
|
||||
function_response=types.FunctionResponse(
|
||||
name='turn_on_the_lights',
|
||||
response={'result': 'ok'},
|
||||
id=chunk.tool_call.function_calls[0].id,
|
||||
)
|
||||
print(function_response)
|
||||
await session.send_tool_response(
|
||||
function_responses=function_response
|
||||
)
|
||||
|
||||
print('_'*80)
|
||||
"""
|
||||
tool_response = t.t_tool_response(function_responses)
|
||||
if self._api_client.vertexai:
|
||||
tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
|
||||
api_client=self._api_client, from_object=tool_response
|
||||
)
|
||||
else:
|
||||
tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
|
||||
api_client=self._api_client, from_object=tool_response
|
||||
)
|
||||
for response in tool_response_dict.get('functionResponses', []):
|
||||
if response.get('id') is None:
|
||||
raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
|
||||
|
||||
await self._ws.send(json.dumps({'tool_response': tool_response_dict}))
|
||||
|
||||
async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
|
||||
"""Receive model responses from the server.
|
||||
|
||||
The method will yield the model responses from the server. The returned
|
||||
responses will represent a complete model turn. When the returned message
|
||||
is function call, user must call `send` with the function response to
|
||||
continue the turn.
|
||||
|
||||
Yields:
|
||||
The model responses from the server.
|
||||
|
||||
Example usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
client = genai.Client(api_key=API_KEY)
|
||||
|
||||
async with client.aio.live.connect(model='...') as session:
|
||||
await session.send(input='Hello world!', end_of_turn=True)
|
||||
async for message in session.receive():
|
||||
print(message)
|
||||
"""
|
||||
# TODO(b/365983264) Handle intermittent issues for the user.
|
||||
while result := await self._receive():
|
||||
if result.server_content and result.server_content.turn_complete:
|
||||
yield result
|
||||
break
|
||||
yield result
|
||||
|
||||
  async def start_stream(
      self, *, stream: AsyncIterator[bytes], mime_type: str
  ) -> AsyncIterator[types.LiveServerMessage]:
    """[Deprecated] Start a live session from a data stream.

    > **Warning**: This method is deprecated and will be removed in a future
    version (not before Q2 2025). Please use one of the more specific methods:
    `send_client_content`, `send_realtime_input`, or `send_tool_response`
    instead.

    The interaction terminates when the input stream is complete.
    This method will start two async tasks. One task will be used to send the
    input stream to the model and the other task will be used to receive the
    responses from the model.

    Args:
      stream: An iterator that yields the model response.
      mime_type: The MIME type of the data in the stream.

    Yields:
      The audio bytes received from the model and server response messages.

    Example usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {'response_modalities': ['AUDIO']}
      async def audio_stream():
        stream = read_audio()
        for data in stream:
          yield data
      async with client.aio.live.connect(model='...', config=config) as session:
        for audio in session.start_stream(stream = audio_stream(),
                                          mime_type = 'audio/pcm'):
          play_audio_chunk(audio.data)
    """
    warnings.warn(
        'Setting `AsyncSession.start_stream` is deprecated, '
        'and will be removed in a future release (not before Q3 2025). '
        'Please use the `receive`, and `send_realtime_input`, methods instead.',
        DeprecationWarning,
        stacklevel=4,
    )
    stop_event = asyncio.Event()
    # Start the send loop. When stream is complete stop_event is set.
    asyncio.create_task(self._send_loop(stream, mime_type, stop_event))
    recv_task = None
    while not stop_event.is_set():
      try:
        # Race a single receive against the stop signal so the loop can
        # exit promptly when the sender finishes.
        recv_task = asyncio.create_task(self._receive())
        await asyncio.wait(
            [
                recv_task,
                asyncio.create_task(stop_event.wait()),
            ],
            return_when=asyncio.FIRST_COMPLETED,
        )
        if recv_task.done():
          yield recv_task.result()
          # Give a chance for the send loop to process requests.
          await asyncio.sleep(10**-12)
      except ConnectionClosed:
        # Server closed the socket; stop yielding and fall through to
        # cleanup of any pending receive.
        break
    if recv_task is not None and not recv_task.done():
      recv_task.cancel()
      # Wait for the task to finish (cancelled or not)
      try:
        await recv_task
      except asyncio.CancelledError:
        pass
|
||||
|
||||
  async def _receive(self) -> types.LiveServerMessage:
    """Read one raw websocket message and parse it into a LiveServerMessage.

    Returns:
      The parsed server message; an empty message if the socket yielded an
      empty frame.

    Raises:
      ValueError: If the frame is not valid JSON.
    """
    parameter_model = types.LiveServerMessage()
    try:
      # Newer websockets releases accept decode=False to return raw bytes.
      raw_response = await self._ws.recv(decode=False)
    except TypeError:
      # Older websockets API: recv() takes no decode argument.
      raw_response = await self._ws.recv()  # type: ignore[assignment]
    if raw_response:
      try:
        response = json.loads(raw_response)
      except json.decoder.JSONDecodeError:
        raise ValueError(f'Failed to parse response: {raw_response!r}')
    else:
      response = {}

    # Convert from the backend wire format into the common message shape.
    if self._api_client.vertexai:
      response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
    else:
      response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)

    return types.LiveServerMessage._from_response(
        response=response_dict, kwargs=parameter_model.model_dump()
    )
|
||||
|
||||
async def _send_loop(
|
||||
self,
|
||||
data_stream: AsyncIterator[bytes],
|
||||
mime_type: str,
|
||||
stop_event: asyncio.Event,
|
||||
) -> None:
|
||||
async for data in data_stream:
|
||||
model_input = types.LiveClientRealtimeInput(
|
||||
media_chunks=[types.Blob(data=data, mime_type=mime_type)]
|
||||
)
|
||||
await self.send(input=model_input)
|
||||
# Give a chance for the receive loop to process responses.
|
||||
await asyncio.sleep(10**-12)
|
||||
# Give a chance for the receiver to process the last response.
|
||||
stop_event.set()
|
||||
|
||||
  def _parse_client_message(
      self,
      input: Optional[
          Union[
              types.ContentListUnion,
              types.ContentListUnionDict,
              types.LiveClientContentOrDict,
              types.LiveClientRealtimeInputOrDict,
              types.LiveClientToolResponseOrDict,
              types.FunctionResponseOrDict,
              Sequence[types.FunctionResponseOrDict],
          ]
      ] = None,
      end_of_turn: Optional[bool] = False,
  ) -> types.LiveClientMessageDict:
    """Classify a loosely-typed `send` input into a client message dict.

    The dispatch below first normalizes scalar inputs (str, Blob, single
    function-response dict) into one-element lists, then matches the
    normalized value against sequence, dict, and typed-object shapes to
    build exactly one of `client_content`, `realtime_input`, or
    `tool_response`.

    Raises:
      ValueError: If the input shape is unsupported, or a function
        response lacks an `id` on the non-Vertex backend.
    """

    formatted_input: Any = input

    # An empty/None input means "end of turn".
    if not input:
      logging.info('No input provided. Assume it is the end of turn.')
      return {'client_content': {'turn_complete': True}}
    # --- Normalization: wrap scalar inputs into one-element lists. ---
    if isinstance(input, str):
      formatted_input = [input]
    elif isinstance(input, dict) and 'data' in input:
      # A dict with 'data' is treated as a Blob (realtime media).
      try:
        blob_input = types.Blob(**input)
      except pydantic.ValidationError:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )
      if (
          isinstance(blob_input, types.Blob)
          and isinstance(blob_input.data, bytes)
      ):
        formatted_input = [
            blob_input.model_dump(mode='json', exclude_none=True)
        ]
    elif isinstance(input, types.Blob):
      formatted_input = [input]
    elif isinstance(input, dict) and 'name' in input and 'response' in input:
      # ToolResponse.FunctionResponse
      if not (self._api_client.vertexai) and 'id' not in input:
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      formatted_input = [input]

    # --- Dispatch on the normalized shape. Branch order matters:
    # earlier, more specific Sequence checks must precede the generic
    # Sequence branch. ---
    if isinstance(formatted_input, Sequence) and any(
        isinstance(c, dict) and 'name' in c and 'response' in c
        for c in formatted_input
    ):
      # ToolResponse.FunctionResponse
      function_responses_input = []
      for item in formatted_input:
        if isinstance(item, dict):
          try:
            function_response_input = types.FunctionResponse(**item)
          except pydantic.ValidationError:
            raise ValueError(
                f'Unsupported input type "{type(input)}" or input content'
                f' "{input}"'
            )
          # Gemini API requires each response to reference a call id.
          if (
              function_response_input.id is None
              and not self._api_client.vertexai
          ):
            raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
          else:
            function_response_dict = function_response_input.model_dump(
                exclude_none=True, mode='json'
            )
            function_response_typeddict = types.FunctionResponseDict(
                name=function_response_dict.get('name'),
                response=function_response_dict.get('response'),
            )
            if function_response_dict.get('id'):
              function_response_typeddict['id'] = function_response_dict.get(
                  'id'
              )
            function_responses_input.append(function_response_typeddict)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_responses_input
          )
      )
    elif isinstance(formatted_input, Sequence) and any(
        isinstance(c, str) for c in formatted_input
    ):
      # A sequence containing strings becomes turn-based client content.
      to_object: dict[str, Any] = {}
      content_input_parts: list[types.PartUnion] = []
      for item in formatted_input:
        if isinstance(item, get_args(types.PartUnion)):
          content_input_parts.append(item)
      if self._api_client.vertexai:
        contents = [
            _Content_to_vertex(self._api_client, item, to_object)
            for item in t.t_contents(self._api_client, content_input_parts)
        ]
      else:
        contents = [
            _Content_to_mldev(self._api_client, item, to_object)
            for item in t.t_contents(self._api_client, content_input_parts)
        ]

      content_dict_list: list[types.ContentDict] = []
      for item in contents:
        try:
          content_input = types.Content(**item)
        except pydantic.ValidationError:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          )
        content_dict_list.append(
            types.ContentDict(
                parts=content_input.model_dump(exclude_none=True, mode='json')[
                    'parts'
                ],
                role=content_input.role,
            )
        )

      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=content_dict_list, turn_complete=end_of_turn
          )
      )
    elif isinstance(formatted_input, Sequence):
      # Generic sequence: expected to hold Blob-like media chunks.
      # NOTE(review): this branch catches *every* remaining sequence, so the
      # `Sequence[FunctionResponse]` branch further below appears
      # unreachable — a list of FunctionResponse objects raises here
      # instead. Confirm against upstream before reordering.
      if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
        pass
      elif any(isinstance(b, types.Blob) for b in formatted_input):
        formatted_input = [
            b.model_dump(exclude_none=True, mode='json')
            for b in formatted_input
        ]
      else:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )

      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=formatted_input
          )
      )

    elif isinstance(formatted_input, dict):
      # Dict input: keyed by which message-type fields are present.
      if 'content' in formatted_input or 'turns' in formatted_input:
        # TODO(b/365983264) Add validation checks for content_update input_dict.
        if 'turns' in formatted_input:
          content_turns = formatted_input['turns']
        else:
          content_turns = formatted_input['content']
        client_message = types.LiveClientMessageDict(
            client_content=types.LiveClientContentDict(
                turns=content_turns,
                turn_complete=formatted_input.get('turn_complete'),
            )
        )
      elif 'media_chunks' in formatted_input:
        try:
          realtime_input = types.LiveClientRealtimeInput(**formatted_input)
        except pydantic.ValidationError:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          )
        client_message = types.LiveClientMessageDict(
            realtime_input=types.LiveClientRealtimeInputDict(
                media_chunks=realtime_input.model_dump(
                    exclude_none=True, mode='json'
                )['media_chunks']
            )
        )
      elif 'function_responses' in formatted_input:
        try:
          tool_response_input = types.LiveClientToolResponse(**formatted_input)
        except pydantic.ValidationError:
          raise ValueError(
              f'Unsupported input type "{type(input)}" or input content'
              f' "{input}"'
          )
        client_message = types.LiveClientMessageDict(
            tool_response=types.LiveClientToolResponseDict(
                function_responses=tool_response_input.model_dump(
                    exclude_none=True, mode='json'
                )['function_responses']
            )
        )
      else:
        raise ValueError(
            f'Unsupported input type "{type(input)}" or input content "{input}"'
        )
    elif isinstance(formatted_input, types.LiveClientRealtimeInput):
      realtime_input_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          realtime_input=types.LiveClientRealtimeInputDict(
              media_chunks=realtime_input_dict.get('media_chunks')
          )
      )
      # If the chunks still carry raw bytes, re-encode them via the Blob
      # model and base64-decode back into a wire-safe BlobDict.
      if (
          client_message['realtime_input'] is not None
          and client_message['realtime_input']['media_chunks'] is not None
          and isinstance(
              client_message['realtime_input']['media_chunks'][0]['data'], bytes
          )
      ):
        formatted_media_chunks: list[types.BlobDict] = []
        for item in client_message['realtime_input']['media_chunks']:
          if isinstance(item, dict):
            try:
              blob_input = types.Blob(**item)
            except pydantic.ValidationError:
              raise ValueError(
                  f'Unsupported input type "{type(input)}" or input content'
                  f' "{input}"'
              )
            if (
                isinstance(blob_input, types.Blob)
                and isinstance(blob_input.data, bytes)
                and blob_input.data is not None
            ):
              formatted_media_chunks.append(
                  types.BlobDict(
                      data=base64.b64decode(blob_input.data),
                      mime_type=blob_input.mime_type,
                  )
              )

        client_message['realtime_input'][
            'media_chunks'
        ] = formatted_media_chunks

    elif isinstance(formatted_input, types.LiveClientContent):
      client_content_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      client_message = types.LiveClientMessageDict(
          client_content=types.LiveClientContentDict(
              turns=client_content_dict.get('turns'),
              turn_complete=client_content_dict.get('turn_complete'),
          )
      )
    elif isinstance(formatted_input, types.LiveClientToolResponse):
      # ToolResponse.FunctionResponse
      if (
          not (self._api_client.vertexai)
          and formatted_input.function_responses is not None
          and not (formatted_input.function_responses[0].id)
      ):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=formatted_input.model_dump(
                  exclude_none=True, mode='json'
              ).get('function_responses')
          )
      )
    elif isinstance(formatted_input, types.FunctionResponse):
      if not (self._api_client.vertexai) and not (formatted_input.id):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      function_response_dict = formatted_input.model_dump(
          exclude_none=True, mode='json'
      )
      function_response_typeddict = types.FunctionResponseDict(
          name=function_response_dict.get('name'),
          response=function_response_dict.get('response'),
      )
      if function_response_dict.get('id'):
        function_response_typeddict['id'] = function_response_dict.get('id')
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=[function_response_typeddict]
          )
      )
    elif isinstance(formatted_input, Sequence) and isinstance(
        formatted_input[0], types.FunctionResponse
    ):
      # NOTE(review): see the generic Sequence branch above — this branch
      # looks unreachable because all sequences are consumed earlier.
      if not (self._api_client.vertexai) and not (formatted_input[0].id):
        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
      function_response_list: list[types.FunctionResponseDict] = []
      for item in formatted_input:
        function_response_dict = item.model_dump(exclude_none=True, mode='json')
        function_response_typeddict = types.FunctionResponseDict(
            name=function_response_dict.get('name'),
            response=function_response_dict.get('response'),
        )
        if function_response_dict.get('id'):
          function_response_typeddict['id'] = function_response_dict.get('id')
        function_response_list.append(function_response_typeddict)
      client_message = types.LiveClientMessageDict(
          tool_response=types.LiveClientToolResponseDict(
              function_responses=function_response_list
          )
      )

    else:
      raise ValueError(
          f'Unsupported input type "{type(input)}" or input content "{input}"'
      )

    return client_message
|
||||
|
||||
  async def close(self) -> None:
    """Terminate the session by closing the underlying websocket."""
    # Close the websocket connection.
    await self._ws.close()
|
||||
|
||||
|
||||
class AsyncLive(_api_module.BaseModule):
  """[Preview] AsyncLive."""

  @contextlib.asynccontextmanager
  async def connect(
      self,
      *,
      model: str,
      config: Optional[types.LiveConnectConfigOrDict] = None,
  ) -> AsyncIterator[AsyncSession]:
    """[Preview] Connect to the live server.

    Note: the live API is currently in preview.

    Opens a websocket to the backend (Gemini API when an API key is set,
    Vertex otherwise), sends the setup message, and yields an
    `AsyncSession` bound to the open connection.

    Args:
      model: The model to connect to.
      config: Optional live-connection configuration.

    Usage:

    .. code-block:: python

      client = genai.Client(api_key=API_KEY)
      config = {}
      async with client.aio.live.connect(model='...', config=config) as session:
        await session.send(input='Hello world!', end_of_turn=True)
        async for message in session.receive():
          print(message)
    """
    base_url = self._api_client._websocket_base_url()
    if isinstance(base_url, bytes):
      base_url = base_url.decode('utf-8')
    transformed_model = t.t_model(self._api_client, model)

    parameter_model = _t_live_connect_config(self._api_client, config)

    if self._api_client.api_key:
      # Gemini API (mldev) path: authenticate via API key in the URI.
      api_key = self._api_client.api_key
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
      headers = self._api_client._http_options.headers

      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_mldev(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                  model=transformed_model,
                  config=parameter_model,
              ).model_dump(exclude_none=True)
          )
      )
      # The converter already folded config into 'setup'; drop the raw key.
      del request_dict['config']

      setv(request_dict, ['setup', 'model'], transformed_model)

      request = json.dumps(request_dict)
    else:
      # Get bearer token through Application Default Credentials.
      creds, _ = google.auth.default(  # type: ignore[no-untyped-call]
          scopes=['https://www.googleapis.com/auth/cloud-platform']
      )

      # creds.valid is False, and creds.token is None
      # Need to refresh credentials to populate those
      # NOTE(review): accesses google.auth.transport.requests without an
      # explicit import of that submodule — confirm it is loaded as a side
      # effect of google.auth.default() in practice.
      auth_req = google.auth.transport.requests.Request()  # type: ignore[no-untyped-call]
      creds.refresh(auth_req)
      bearer_token = creds.token
      headers = self._api_client._http_options.headers
      if headers is not None:
        headers.update({
            'Authorization': 'Bearer {}'.format(bearer_token),
        })
      version = self._api_client._http_options.api_version
      uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
      location = self._api_client.location
      project = self._api_client.project
      # Vertex needs fully-qualified publisher model resource names.
      if transformed_model.startswith('publishers/'):
        transformed_model = (
            f'projects/{project}/locations/{location}/' + transformed_model
        )
      request_dict = _common.convert_to_dict(
          live_converters._LiveConnectParameters_to_vertex(
              api_client=self._api_client,
              from_object=types.LiveConnectParameters(
                  model=transformed_model,
                  config=parameter_model,
              ).model_dump(exclude_none=True)
          )
      )
      del request_dict['config']

      # Vertex defaults to AUDIO responses when none were requested.
      if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
        setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])

      request = json.dumps(request_dict)

    try:
      # Newer websockets API uses `additional_headers`.
      async with connect(uri, additional_headers=headers) as ws:
        await ws.send(request)
        # The first server frame acknowledges the setup message.
        logger.info(await ws.recv(decode=False))

        yield AsyncSession(api_client=self._api_client, websocket=ws)
    except TypeError:
      # Try with the older websockets API
      async with connect(uri, extra_headers=headers) as ws:
        await ws.send(request)
        logger.info(await ws.recv())

        yield AsyncSession(api_client=self._api_client, websocket=ws)
|
||||
|
||||
|
||||
def _t_live_connect_config(
    api_client: BaseApiClient,
    config: Optional[types.LiveConnectConfigOrDict],
) -> types.LiveConnectConfig:
  """Normalize a user-supplied connect config into a LiveConnectConfig.

  Accepts None, a dict, or an existing LiveConnectConfig, converting the
  `system_instruction` field to a `Content` in every case. Emits a
  deprecation warning when `generation_config` is set.
  """
  # Ensure the config is a LiveConnectConfig.
  if config is None:
    parameter_model = types.LiveConnectConfig()
  elif isinstance(config, dict):
    # Convert system_instruction first so it survives model construction.
    if getv(config, ['system_instruction']) is not None:
      converted_system_instruction = t.t_content(
          api_client, getv(config, ['system_instruction'])
      )
    else:
      converted_system_instruction = None
    parameter_model = types.LiveConnectConfig(**config)
    parameter_model.system_instruction = converted_system_instruction
  else:
    if config.system_instruction is None:
      system_instruction = None
    else:
      system_instruction = t.t_content(
          api_client, getv(config, ['system_instruction'])
      )
    # NOTE(review): this branch mutates the caller's config object in
    # place (system_instruction is overwritten) — confirm that is intended.
    parameter_model = config
    parameter_model.system_instruction = system_instruction

  if parameter_model.generation_config is not None:
    warnings.warn(
        'Setting `LiveConnectConfig.generation_config` is deprecated, '
        'please set the fields on `LiveConnectConfig` directly. This will '
        'become an error in a future version (not before Q3 2025)',
        DeprecationWarning,
        stacklevel=4,
    )

  return parameter_model
|
||||
Reference in New Issue
Block a user