Files
evo-ai/.venv/lib/python3.10/site-packages/google/genai/live.py
2025-04-25 15:30:54 -03:00

985 lines
34 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""[Preview] Live API client."""
import asyncio
import base64
import contextlib
import json
import logging
from typing import Any, AsyncIterator, Dict, Optional, Sequence, Union, cast, get_args
import warnings
import google.auth
import pydantic
from websockets import ConnectionClosed
from . import _api_module
from . import _common
from . import _transformers as t
from . import client
from . import types
from ._api_client import BaseApiClient
from ._common import get_value_by_path as getv
from ._common import set_value_by_path as setv
from . import _live_converters as live_converters
from .models import _Content_to_mldev
from .models import _Content_to_vertex
try:
from websockets.asyncio.client import ClientConnection
from websockets.asyncio.client import connect
except ModuleNotFoundError:
# This try/except is for TAP, mypy complains about it which is why we have the type: ignore
from websockets.client import ClientConnection # type: ignore
from websockets.client import connect # type: ignore
logger = logging.getLogger('google_genai.live')
_FUNCTION_RESPONSE_REQUIRES_ID = (
'FunctionResponse request must have an `id` field from the'
' response of a ToolCall.FunctionalCalls in Google AI.'
)
class AsyncSession:
    """[Preview] AsyncSession.

    One Live API session bound to an already-open websocket. Instances are
    created by `AsyncLive.connect`. Outgoing messages are converted to the
    Vertex AI or Google AI wire format depending on `api_client.vertexai`.
    """

    def __init__(
        self, api_client: BaseApiClient, websocket: ClientConnection
    ):
        self._api_client = api_client
        self._ws = websocket

    async def send(
        self,
        *,
        input: Optional[
            Union[
                types.ContentListUnion,
                types.ContentListUnionDict,
                types.LiveClientContentOrDict,
                types.LiveClientRealtimeInputOrDict,
                types.LiveClientToolResponseOrDict,
                types.FunctionResponseOrDict,
                Sequence[types.FunctionResponseOrDict],
            ]
        ] = None,
        end_of_turn: Optional[bool] = False,
    ) -> None:
        """[Deprecated] Send input to the model.

        > **Warning**: This method is deprecated and will be removed in a future
        version (not before Q3 2025). Please use one of the more specific methods:
        `send_client_content`, `send_realtime_input`, or `send_tool_response`
        instead.

        The method will send the input request to the server.

        Args:
          input: The input request to the model.
          end_of_turn: Whether the input is the last message in a turn.

        Raises:
          ValueError: If `input` cannot be converted into a client message.

        Example usage:

        .. code-block:: python

            client = genai.Client(api_key=API_KEY)

            async with client.aio.live.connect(model='...') as session:
                await session.send(input='Hello world!', end_of_turn=True)
                async for message in session.receive():
                    print(message)
        """
        warnings.warn(
            'The `session.send` method is deprecated and will be removed in a '
            'future version (not before Q3 2025).\n'
            'Please use one of the more specific methods: `send_client_content`, '
            '`send_realtime_input`, or `send_tool_response` instead.',
            DeprecationWarning,
            stacklevel=2,
        )
        client_message = self._parse_client_message(input, end_of_turn)
        await self._ws.send(json.dumps(client_message))

    async def send_client_content(
        self,
        *,
        turns: Optional[
            Union[
                types.Content,
                types.ContentDict,
                list[Union[types.Content, types.ContentDict]]
            ]
        ] = None,
        turn_complete: bool = True,
    ) -> None:
        """Send non-realtime, turn based content to the model.

        There are two ways to send messages to the live API:
        `send_client_content` and `send_realtime_input`.

        `send_client_content` messages are added to the model context **in order**.
        Having a conversation using `send_client_content` messages is roughly
        equivalent to using the `Chat.send_message_stream` method, except that the
        state of the `chat` history is stored on the API server.

        Because of `send_client_content`'s order guarantee, the model cannot
        respond as quickly to `send_client_content` messages as to
        `send_realtime_input` messages. This makes the biggest difference when
        sending objects that have significant preprocessing time (typically images).

        The `send_client_content` message sends a list of `Content` objects,
        which has more options than the `media:Blob` sent by `send_realtime_input`.

        The main use-cases for `send_client_content` over `send_realtime_input` are:

        - Prefilling a conversation context (including sending anything that can't
          be represented as a realtime message), before starting a realtime
          conversation.
        - Conducting a non-realtime conversation, similar to `client.chat`, using
          the live api.

        Caution: Interleaving `send_client_content` and `send_realtime_input`
        in the same conversation is not recommended and can lead to unexpected
        results.

        Args:
          turns: A `Content` object or list of `Content` objects (or equivalent
            dicts).
          turn_complete: if true (the default) the model will reply immediately. If
            false, the model will wait for you to send additional client_content,
            and will not return until you send `turn_complete=True`.

        Example:
        ```
        from google import genai
        from google.genai import types
        import os

        if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
            MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
        else:
            MODEL_NAME = 'gemini-2.0-flash-live-001'

        client = genai.Client()
        async with client.aio.live.connect(
            model=MODEL_NAME,
            config={"response_modalities": ["TEXT"]}
        ) as session:
            await session.send_client_content(
                turns=types.Content(
                    role='user',
                    parts=[types.Part(text="Hello world!")]))
            async for msg in session.receive():
                if msg.text:
                    print(msg.text)
        ```
        """
        client_content = t.t_client_content(turns, turn_complete)

        if self._api_client.vertexai:
            client_content_dict = live_converters._LiveClientContent_to_vertex(
                api_client=self._api_client, from_object=client_content
            )
        else:
            client_content_dict = live_converters._LiveClientContent_to_mldev(
                api_client=self._api_client, from_object=client_content
            )

        await self._ws.send(json.dumps({'client_content': client_content_dict}))

    async def send_realtime_input(
        self,
        *,
        media: Optional[types.BlobImageUnionDict] = None,
        audio: Optional[types.BlobOrDict] = None,
        audio_stream_end: Optional[bool] = None,
        video: Optional[types.BlobImageUnionDict] = None,
        text: Optional[str] = None,
        activity_start: Optional[types.ActivityStartOrDict] = None,
        activity_end: Optional[types.ActivityEndOrDict] = None,
    ) -> None:
        """Send realtime input to the model, only send one argument per call.

        Use `send_realtime_input` for realtime audio chunks and video
        frames (images).

        With `send_realtime_input` the api will respond to audio automatically
        based on voice activity detection (VAD).

        `send_realtime_input` is optimized for responsiveness at the expense of
        deterministic ordering. Audio and video tokens are added to the
        context when they become available.

        Exactly one of the arguments below must be given (non-None).

        Args:
          media: A `Blob`, image, or equivalent dict to send as realtime media.
          audio: A `Blob` (or equivalent dict) holding audio data.
          audio_stream_end: Value for the `audio_stream_end` field.
          video: A `Blob`, image, or equivalent dict to send as a video frame.
          text: Text to send as realtime input.
          activity_start: Value for the `activity_start` field.
          activity_end: Value for the `activity_end` field.

        Raises:
          ValueError: If zero or more than one argument is set.

        Example:
        ```
        from pathlib import Path

        from google import genai
        from google.genai import types

        import PIL.Image

        import os

        if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
            MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
        else:
            MODEL_NAME = 'gemini-2.0-flash-live-001'

        client = genai.Client()

        async with client.aio.live.connect(
            model=MODEL_NAME,
            config={"response_modalities": ["TEXT"]},
        ) as session:
            await session.send_realtime_input(
                media=PIL.Image.open('image.jpg'))

            audio_bytes = Path('audio.pcm').read_bytes()
            await session.send_realtime_input(
                media=types.Blob(data=audio_bytes, mime_type='audio/pcm;rate=16000'))

            async for msg in session.receive():
                if msg.text is not None:
                    print(f'{msg.text}')
        ```
        """
        kwargs: dict[str, Any] = {}
        if media is not None:
            kwargs['media'] = media
        if audio is not None:
            kwargs['audio'] = audio
        if audio_stream_end is not None:
            kwargs['audio_stream_end'] = audio_stream_end
        if video is not None:
            kwargs['video'] = video
        if text is not None:
            kwargs['text'] = text
        if activity_start is not None:
            kwargs['activity_start'] = activity_start
        if activity_end is not None:
            kwargs['activity_end'] = activity_end

        if len(kwargs) != 1:
            raise ValueError(
                f'Only one argument can be set, got {len(kwargs)}:'
                f' {list(kwargs.keys())}'
            )
        realtime_input = types.LiveSendRealtimeInputParameters.model_validate(
            kwargs
        )

        if self._api_client.vertexai:
            realtime_input_dict = (
                live_converters._LiveSendRealtimeInputParameters_to_vertex(
                    api_client=self._api_client, from_object=realtime_input
                )
            )
        else:
            realtime_input_dict = (
                live_converters._LiveSendRealtimeInputParameters_to_mldev(
                    api_client=self._api_client, from_object=realtime_input
                )
            )
        realtime_input_dict = _common.convert_to_dict(realtime_input_dict)
        realtime_input_dict = _common.encode_unserializable_types(
            realtime_input_dict
        )
        await self._ws.send(json.dumps({'realtime_input': realtime_input_dict}))

    async def send_tool_response(
        self,
        *,
        function_responses: Union[
            types.FunctionResponseOrDict,
            Sequence[types.FunctionResponseOrDict],
        ],
    ) -> None:
        """Send a tool response to the session.

        Use `send_tool_response` to reply to `LiveServerToolCall` messages
        from the server.

        To set the available tools, use the `config.tools` argument
        when you connect to the session (`client.live.connect`).

        Args:
          function_responses: A `FunctionResponse`-like object or list of
            `FunctionResponse`-like objects.

        Raises:
          ValueError: When not using Vertex AI, if any function response lacks
            an `id`.

        Example:
        ```
        from google import genai
        from google.genai import types

        import os

        if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'):
            MODEL_NAME = 'gemini-2.0-flash-live-preview-04-09'
        else:
            MODEL_NAME = 'gemini-2.0-flash-live-001'

        client = genai.Client()

        tools = [{'function_declarations': [{'name': 'turn_on_the_lights'}]}]
        config = {
            "tools": tools,
            "response_modalities": ['TEXT']
        }

        async with client.aio.live.connect(
            model=MODEL_NAME,
            config=config
        ) as session:
            prompt = "Turn on the lights please"
            await session.send_client_content(
                turns={"parts": [{'text': prompt}]}
            )

            async for chunk in session.receive():
                if chunk.server_content:
                    if chunk.text is not None:
                        print(chunk.text)
                elif chunk.tool_call:
                    print(chunk.tool_call)
                    print('_'*80)
                    function_response=types.FunctionResponse(
                        name='turn_on_the_lights',
                        response={'result': 'ok'},
                        id=chunk.tool_call.function_calls[0].id,
                    )
                    print(function_response)
                    await session.send_tool_response(
                        function_responses=function_response
                    )

                    print('_'*80)
        ```
        """
        tool_response = t.t_tool_response(function_responses)
        if self._api_client.vertexai:
            tool_response_dict = live_converters._LiveClientToolResponse_to_vertex(
                api_client=self._api_client, from_object=tool_response
            )
        else:
            tool_response_dict = live_converters._LiveClientToolResponse_to_mldev(
                api_client=self._api_client, from_object=tool_response
            )
            # Google AI requires every response to carry the id of the call it
            # answers (same rule `_parse_client_message` enforces).
            for response in tool_response_dict.get('functionResponses', []):
                if response.get('id') is None:
                    raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)

        await self._ws.send(json.dumps({'tool_response': tool_response_dict}))

    async def receive(self) -> AsyncIterator[types.LiveServerMessage]:
        """Receive model responses from the server.

        Yields server messages until (and including) one whose
        `server_content.turn_complete` is set, then stops. When a yielded
        message is a tool call, reply with `send_tool_response` to continue
        the turn.

        Yields:
          The model responses from the server.

        Example usage:

        .. code-block:: python

            client = genai.Client(api_key=API_KEY)

            async with client.aio.live.connect(model='...') as session:
                await session.send_client_content(
                    turns={'parts': [{'text': 'Hello world!'}]})
                async for message in session.receive():
                    print(message)
        """
        # TODO(b/365983264) Handle intermittent issues for the user.
        while result := await self._receive():
            if result.server_content and result.server_content.turn_complete:
                yield result
                break
            yield result

    async def start_stream(
        self, *, stream: AsyncIterator[bytes], mime_type: str
    ) -> AsyncIterator[types.LiveServerMessage]:
        """[Deprecated] Start a live session from a data stream.

        > **Warning**: This method is deprecated and will be removed in a future
        version (not before Q3 2025). Please use one of the more specific methods:
        `send_client_content`, `send_realtime_input`, or `send_tool_response`
        instead.

        The interaction terminates when the input stream is complete.
        This method will start two async tasks. One task will be used to send the
        input stream to the model and the other task will be used to receive the
        responses from the model. Tasks still pending when iteration stops
        (stream finished, connection closed, or the caller stops iterating) are
        cancelled.

        Args:
          stream: An async iterator that yields the input data chunks to send.
          mime_type: The MIME type of the data in the stream.

        Yields:
          `LiveServerMessage`s received from the server while the stream is
          being sent.

        Example usage:

        .. code-block:: python

            client = genai.Client(api_key=API_KEY)
            config = {'response_modalities': ['AUDIO']}

            async def audio_stream():
                stream = read_audio()
                for data in stream:
                    yield data

            async with client.aio.live.connect(model='...', config=config) as session:
                async for audio in session.start_stream(
                    stream=audio_stream(), mime_type='audio/pcm'
                ):
                    play_audio_chunk(audio.data)
        """
        warnings.warn(
            'Setting `AsyncSession.start_stream` is deprecated, '
            'and will be removed in a future release (not before Q3 2025). '
            'Please use the `receive`, and `send_realtime_input`, methods instead.',
            DeprecationWarning,
            stacklevel=4,
        )
        stop_event = asyncio.Event()
        # Start the send loop; it sets stop_event once the stream is exhausted.
        # Keep a reference: the event loop only holds weak references to tasks.
        send_task = asyncio.create_task(
            self._send_loop(stream, mime_type, stop_event)
        )
        # One waiter reused across iterations. Creating a fresh
        # `stop_event.wait()` task per iteration left one pending task behind
        # for every message received before the stream ended.
        stop_task = asyncio.create_task(stop_event.wait())
        recv_task = None
        try:
            while not stop_event.is_set():
                try:
                    recv_task = asyncio.create_task(self._receive())
                    await asyncio.wait(
                        [recv_task, stop_task],
                        return_when=asyncio.FIRST_COMPLETED,
                    )
                    if recv_task.done():
                        yield recv_task.result()
                        # Give a chance for the send loop to process requests.
                        await asyncio.sleep(10**-12)
                except ConnectionClosed:
                    break
        finally:
            if recv_task is not None and not recv_task.done():
                recv_task.cancel()
                # Wait for the task to finish (cancelled or not).
                try:
                    await recv_task
                except asyncio.CancelledError:
                    pass
            # Do not leave helper tasks running after iteration has ended.
            for task in (stop_task, send_task):
                if not task.done():
                    task.cancel()

    async def _receive(self) -> types.LiveServerMessage:
        """Read one websocket frame and convert it to a `LiveServerMessage`.

        Raises:
          ValueError: If the frame is non-empty and not valid JSON.
        """
        parameter_model = types.LiveServerMessage()
        try:
            raw_response = await self._ws.recv(decode=False)
        except TypeError:
            # Older websockets releases: `recv()` takes no `decode` argument.
            raw_response = await self._ws.recv()  # type: ignore[assignment]
        if raw_response:
            try:
                response = json.loads(raw_response)
            except json.decoder.JSONDecodeError as e:
                raise ValueError(f'Failed to parse response: {raw_response!r}') from e
        else:
            response = {}

        if self._api_client.vertexai:
            response_dict = live_converters._LiveServerMessage_from_vertex(self._api_client, response)
        else:
            response_dict = live_converters._LiveServerMessage_from_mldev(self._api_client, response)

        return types.LiveServerMessage._from_response(
            response=response_dict, kwargs=parameter_model.model_dump()
        )

    async def _send_loop(
        self,
        data_stream: AsyncIterator[bytes],
        mime_type: str,
        stop_event: asyncio.Event,
    ) -> None:
        """Send each chunk of `data_stream` as realtime media, then set `stop_event`."""
        async for data in data_stream:
            model_input = types.LiveClientRealtimeInput(
                media_chunks=[types.Blob(data=data, mime_type=mime_type)]
            )
            await self.send(input=model_input)
            # Give a chance for the receive loop to process responses.
            await asyncio.sleep(10**-12)
        # Give a chance for the receiver to process the last response.
        stop_event.set()

    def _parse_client_message(
        self,
        input: Optional[
            Union[
                types.ContentListUnion,
                types.ContentListUnionDict,
                types.LiveClientContentOrDict,
                types.LiveClientRealtimeInputOrDict,
                types.LiveClientToolResponseOrDict,
                types.FunctionResponseOrDict,
                Sequence[types.FunctionResponseOrDict],
            ]
        ] = None,
        end_of_turn: Optional[bool] = False,
    ) -> types.LiveClientMessageDict:
        """Convert a legacy `send` input into a `LiveClientMessageDict`.

        Accepted inputs:
        - strings
        - `Blob`s, or dicts with a `data` key
        - function responses: objects, dicts, or lists of either
        - content dicts (`turns` / `content`)
        - realtime-input dicts (`media_chunks`)
        - tool-response dicts (`function_responses`)
        - the corresponding `LiveClient*` models

        Args:
          input: The value passed to `send`. A falsy value is treated as a bare
            end-of-turn signal.
          end_of_turn: Used as `turn_complete` when `input` contains strings.

        Returns:
          The message dict to serialize and send over the websocket.

        Raises:
          ValueError: If the input cannot be interpreted, or if (outside Vertex
            AI) a function response lacks an `id`.
        """
        formatted_input: Any = input
        if not input:
            logger.info('No input provided. Assume it is the end of turn.')
            return {'client_content': {'turn_complete': True}}
        if isinstance(input, str):
            formatted_input = [input]
        elif isinstance(input, dict) and 'data' in input:
            try:
                blob_input = types.Blob(**input)
            except pydantic.ValidationError as e:
                raise ValueError(
                    f'Unsupported input type "{type(input)}" or input content "{input}"'
                ) from e
            if (
                isinstance(blob_input, types.Blob)
                and isinstance(blob_input.data, bytes)
            ):
                formatted_input = [
                    blob_input.model_dump(mode='json', exclude_none=True)
                ]
        elif isinstance(input, types.Blob):
            formatted_input = [input]
        elif isinstance(input, dict) and 'name' in input and 'response' in input:
            # ToolResponse.FunctionResponse
            if not (self._api_client.vertexai) and 'id' not in input:
                raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
            formatted_input = [input]

        if isinstance(formatted_input, Sequence) and any(
            isinstance(c, dict) and 'name' in c and 'response' in c
            for c in formatted_input
        ):
            # ToolResponse.FunctionResponse given as dicts.
            function_responses_input = []
            for item in formatted_input:
                if isinstance(item, dict):
                    try:
                        function_response_input = types.FunctionResponse(**item)
                    except pydantic.ValidationError as e:
                        raise ValueError(
                            f'Unsupported input type "{type(input)}" or input content'
                            f' "{input}"'
                        ) from e
                    if (
                        function_response_input.id is None
                        and not self._api_client.vertexai
                    ):
                        raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
                    else:
                        function_response_dict = function_response_input.model_dump(
                            exclude_none=True, mode='json'
                        )
                        function_response_typeddict = types.FunctionResponseDict(
                            name=function_response_dict.get('name'),
                            response=function_response_dict.get('response'),
                        )
                        if function_response_dict.get('id'):
                            function_response_typeddict['id'] = function_response_dict.get(
                                'id'
                            )
                        function_responses_input.append(function_response_typeddict)
            client_message = types.LiveClientMessageDict(
                tool_response=types.LiveClientToolResponseDict(
                    function_responses=function_responses_input
                )
            )
        elif isinstance(formatted_input, Sequence) and any(
            isinstance(c, str) for c in formatted_input
        ):
            to_object: dict[str, Any] = {}
            content_input_parts: list[types.PartUnion] = []
            for item in formatted_input:
                if isinstance(item, get_args(types.PartUnion)):
                    content_input_parts.append(item)
            if self._api_client.vertexai:
                contents = [
                    _Content_to_vertex(self._api_client, item, to_object)
                    for item in t.t_contents(self._api_client, content_input_parts)
                ]
            else:
                contents = [
                    _Content_to_mldev(self._api_client, item, to_object)
                    for item in t.t_contents(self._api_client, content_input_parts)
                ]

            content_dict_list: list[types.ContentDict] = []
            for item in contents:
                try:
                    content_input = types.Content(**item)
                except pydantic.ValidationError as e:
                    raise ValueError(
                        f'Unsupported input type "{type(input)}" or input content'
                        f' "{input}"'
                    ) from e
                content_dict_list.append(
                    types.ContentDict(
                        parts=content_input.model_dump(exclude_none=True, mode='json')[
                            'parts'
                        ],
                        role=content_input.role,
                    )
                )

            client_message = types.LiveClientMessageDict(
                client_content=types.LiveClientContentDict(
                    turns=content_dict_list, turn_complete=end_of_turn
                )
            )
        elif isinstance(formatted_input, Sequence) and all(
            isinstance(c, types.FunctionResponse) for c in formatted_input
        ):
            # A list of FunctionResponse objects. This must be checked before
            # the generic Sequence branch below, which would otherwise reject
            # it as an unsupported input.
            if not self._api_client.vertexai and any(
                not function_response.id for function_response in formatted_input
            ):
                raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
            function_response_list: list[types.FunctionResponseDict] = []
            for item in formatted_input:
                function_response_dict = item.model_dump(exclude_none=True, mode='json')
                function_response_typeddict = types.FunctionResponseDict(
                    name=function_response_dict.get('name'),
                    response=function_response_dict.get('response'),
                )
                if function_response_dict.get('id'):
                    function_response_typeddict['id'] = function_response_dict.get('id')
                function_response_list.append(function_response_typeddict)
            client_message = types.LiveClientMessageDict(
                tool_response=types.LiveClientToolResponseDict(
                    function_responses=function_response_list
                )
            )
        elif isinstance(formatted_input, Sequence):
            if any((isinstance(b, dict) and 'data' in b) for b in formatted_input):
                pass
            elif any(isinstance(b, types.Blob) for b in formatted_input):
                formatted_input = [
                    b.model_dump(exclude_none=True, mode='json')
                    for b in formatted_input
                ]
            else:
                raise ValueError(
                    f'Unsupported input type "{type(input)}" or input content "{input}"'
                )

            client_message = types.LiveClientMessageDict(
                realtime_input=types.LiveClientRealtimeInputDict(
                    media_chunks=formatted_input
                )
            )

        elif isinstance(formatted_input, dict):
            if 'content' in formatted_input or 'turns' in formatted_input:
                # TODO(b/365983264) Add validation checks for content_update input_dict.
                if 'turns' in formatted_input:
                    content_turns = formatted_input['turns']
                else:
                    content_turns = formatted_input['content']
                client_message = types.LiveClientMessageDict(
                    client_content=types.LiveClientContentDict(
                        turns=content_turns,
                        turn_complete=formatted_input.get('turn_complete'),
                    )
                )
            elif 'media_chunks' in formatted_input:
                try:
                    realtime_input = types.LiveClientRealtimeInput(**formatted_input)
                except pydantic.ValidationError as e:
                    raise ValueError(
                        f'Unsupported input type "{type(input)}" or input content'
                        f' "{input}"'
                    ) from e
                client_message = types.LiveClientMessageDict(
                    realtime_input=types.LiveClientRealtimeInputDict(
                        media_chunks=realtime_input.model_dump(
                            exclude_none=True, mode='json'
                        )['media_chunks']
                    )
                )
            elif 'function_responses' in formatted_input:
                try:
                    tool_response_input = types.LiveClientToolResponse(**formatted_input)
                except pydantic.ValidationError as e:
                    raise ValueError(
                        f'Unsupported input type "{type(input)}" or input content'
                        f' "{input}"'
                    ) from e
                client_message = types.LiveClientMessageDict(
                    tool_response=types.LiveClientToolResponseDict(
                        function_responses=tool_response_input.model_dump(
                            exclude_none=True, mode='json'
                        )['function_responses']
                    )
                )
            else:
                raise ValueError(
                    f'Unsupported input type "{type(input)}" or input content "{input}"'
                )
        elif isinstance(formatted_input, types.LiveClientRealtimeInput):
            realtime_input_dict = formatted_input.model_dump(
                exclude_none=True, mode='json'
            )
            client_message = types.LiveClientMessageDict(
                realtime_input=types.LiveClientRealtimeInputDict(
                    media_chunks=realtime_input_dict.get('media_chunks')
                )
            )
            realtime_input_payload = client_message['realtime_input']
            media_chunks = realtime_input_payload.get('media_chunks')
            # Guard against an empty list before peeking at the first chunk
            # (indexing [0] used to raise IndexError).
            # NOTE(review): `model_dump(mode='json')` above presumably renders
            # bytes as str, in which case this branch never runs; confirm.
            if media_chunks and isinstance(media_chunks[0].get('data'), bytes):
                formatted_media_chunks: list[types.BlobDict] = []
                for item in media_chunks:
                    if isinstance(item, dict):
                        try:
                            blob_input = types.Blob(**item)
                        except pydantic.ValidationError as e:
                            raise ValueError(
                                f'Unsupported input type "{type(input)}" or input content'
                                f' "{input}"'
                            ) from e
                        if (
                            isinstance(blob_input, types.Blob)
                            and isinstance(blob_input.data, bytes)
                            and blob_input.data is not None
                        ):
                            formatted_media_chunks.append(
                                types.BlobDict(
                                    data=base64.b64decode(blob_input.data),
                                    mime_type=blob_input.mime_type,
                                )
                            )

                realtime_input_payload['media_chunks'] = formatted_media_chunks

        elif isinstance(formatted_input, types.LiveClientContent):
            client_content_dict = formatted_input.model_dump(
                exclude_none=True, mode='json'
            )
            client_message = types.LiveClientMessageDict(
                client_content=types.LiveClientContentDict(
                    turns=client_content_dict.get('turns'),
                    turn_complete=client_content_dict.get('turn_complete'),
                )
            )
        elif isinstance(formatted_input, types.LiveClientToolResponse):
            # ToolResponse.FunctionResponse: every response needs an id outside
            # Vertex AI. An empty list is accepted instead of raising IndexError.
            if not self._api_client.vertexai and any(
                not function_response.id
                for function_response in (formatted_input.function_responses or [])
            ):
                raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
            client_message = types.LiveClientMessageDict(
                tool_response=types.LiveClientToolResponseDict(
                    function_responses=formatted_input.model_dump(
                        exclude_none=True, mode='json'
                    ).get('function_responses')
                )
            )
        elif isinstance(formatted_input, types.FunctionResponse):
            if not (self._api_client.vertexai) and not (formatted_input.id):
                raise ValueError(_FUNCTION_RESPONSE_REQUIRES_ID)
            function_response_dict = formatted_input.model_dump(
                exclude_none=True, mode='json'
            )
            function_response_typeddict = types.FunctionResponseDict(
                name=function_response_dict.get('name'),
                response=function_response_dict.get('response'),
            )
            if function_response_dict.get('id'):
                function_response_typeddict['id'] = function_response_dict.get('id')
            client_message = types.LiveClientMessageDict(
                tool_response=types.LiveClientToolResponseDict(
                    function_responses=[function_response_typeddict]
                )
            )
        else:
            raise ValueError(
                f'Unsupported input type "{type(input)}" or input content "{input}"'
            )

        return client_message

    async def close(self) -> None:
        """Close the underlying websocket connection."""
        await self._ws.close()
class AsyncLive(_api_module.BaseModule):
    """[Preview] AsyncLive."""

    @contextlib.asynccontextmanager
    async def connect(
        self,
        *,
        model: str,
        config: Optional[types.LiveConnectConfigOrDict] = None,
    ) -> AsyncIterator[AsyncSession]:
        """[Preview] Connect to the live server.

        Note: the live API is currently in preview.

        Opens a websocket, sends the setup request built from `model` and
        `config`, logs the server's first reply, and yields an `AsyncSession`.
        The websocket is closed when the `async with` block exits.

        Args:
          model: The model to connect to.
          config: Optional live connection configuration (object or dict).

        Usage:

        .. code-block:: python

            client = genai.Client(api_key=API_KEY)
            config = {}
            async with client.aio.live.connect(model='...', config=config) as session:
                await session.send_client_content(
                    turns={'parts': [{'text': 'Hello world!'}]})
                async for message in session.receive():
                    print(message)
        """
        base_url = self._api_client._websocket_base_url()
        if isinstance(base_url, bytes):
            base_url = base_url.decode('utf-8')
        transformed_model = t.t_model(self._api_client, model)

        parameter_model = _t_live_connect_config(self._api_client, config)

        if self._api_client.api_key:
            api_key = self._api_client.api_key
            version = self._api_client._http_options.api_version
            uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.BidiGenerateContent?key={api_key}'
            headers = self._api_client._http_options.headers

            request_dict = _common.convert_to_dict(
                live_converters._LiveConnectParameters_to_mldev(
                    api_client=self._api_client,
                    from_object=types.LiveConnectParameters(
                        model=transformed_model,
                        config=parameter_model,
                    ).model_dump(exclude_none=True)
                )
            )
            del request_dict['config']

            setv(request_dict, ['setup', 'model'], transformed_model)

            request = json.dumps(request_dict)
        else:
            # `import google.auth` alone does not guarantee this submodule is
            # loaded, so import it explicitly.
            from google.auth.transport import requests as auth_requests

            # Get bearer token through Application Default Credentials.
            creds, _ = google.auth.default(  # type: ignore[no-untyped-call]
                scopes=['https://www.googleapis.com/auth/cloud-platform']
            )

            # creds.valid is False, and creds.token is None
            # Need to refresh credentials to populate those
            auth_req = auth_requests.Request()  # type: ignore[no-untyped-call]
            creds.refresh(auth_req)
            bearer_token = creds.token
            # Copy instead of updating the client's shared header dict in place,
            # and always attach the token (previously skipped when headers was None).
            headers = dict(self._api_client._http_options.headers or {})
            headers['Authorization'] = 'Bearer {}'.format(bearer_token)
            version = self._api_client._http_options.api_version
            uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
            location = self._api_client.location
            project = self._api_client.project
            if transformed_model.startswith('publishers/'):
                transformed_model = (
                    f'projects/{project}/locations/{location}/' + transformed_model
                )
            request_dict = _common.convert_to_dict(
                live_converters._LiveConnectParameters_to_vertex(
                    api_client=self._api_client,
                    from_object=types.LiveConnectParameters(
                        model=transformed_model,
                        config=parameter_model,
                    ).model_dump(exclude_none=True)
                )
            )
            del request_dict['config']

            if getv(request_dict, ['setup', 'generationConfig', 'responseModalities']) is None:
                setv(request_dict, ['setup', 'generationConfig', 'responseModalities'], ['AUDIO'])

            request = json.dumps(request_dict)

        async with contextlib.AsyncExitStack() as stack:
            try:
                ws = await stack.enter_async_context(
                    connect(uri, additional_headers=headers)
                )
                await ws.send(request)
                logger.info(await ws.recv(decode=False))
            except TypeError:
                # Older websockets releases take `extra_headers`, and their
                # `recv()` has no `decode` argument. Close anything the first
                # attempt opened, then retry with the legacy API.
                await stack.aclose()
                ws = await stack.enter_async_context(
                    connect(uri, extra_headers=headers)
                )
                await ws.send(request)
                logger.info(await ws.recv())

            # Yield outside the try. A TypeError raised by the caller's code
            # inside the session must propagate unchanged. Previously it was
            # caught here, a second connection was opened and yielded, and
            # asynccontextmanager raised RuntimeError.
            yield AsyncSession(api_client=self._api_client, websocket=ws)
def _t_live_connect_config(
    api_client: BaseApiClient,
    config: Optional[types.LiveConnectConfigOrDict],
) -> types.LiveConnectConfig:
    """Normalize `config` into a `LiveConnectConfig`.

    The system instruction (if any) is converted with `t.t_content`.

    Args:
      api_client: Client passed to `t.t_content` for the conversion.
      config: None, a dict of `LiveConnectConfig` fields, or a
        `LiveConnectConfig` instance.

    Returns:
      A `LiveConnectConfig`. A `LiveConnectConfig` passed in is copied rather
      than modified in place.

    Warns:
      DeprecationWarning: If `generation_config` is set.
    """
    if config is None:
        parameter_model = types.LiveConnectConfig()
    elif isinstance(config, dict):
        raw_instruction = getv(config, ['system_instruction'])
        if raw_instruction is not None:
            converted_system_instruction = t.t_content(api_client, raw_instruction)
        else:
            converted_system_instruction = None
        parameter_model = types.LiveConnectConfig(**config)
        parameter_model.system_instruction = converted_system_instruction
    else:
        if config.system_instruction is None:
            system_instruction = None
        else:
            system_instruction = t.t_content(api_client, config.system_instruction)
        # Work on a copy so the caller's config object is not mutated.
        parameter_model = config.model_copy(
            update={'system_instruction': system_instruction}
        )

    if parameter_model.generation_config is not None:
        warnings.warn(
            'Setting `LiveConnectConfig.generation_config` is deprecated, '
            'please set the fields on `LiveConnectConfig` directly. This will '
            'become an error in a future version (not before Q3 2025)',
            DeprecationWarning,
            stacklevel=4,
        )

    return parameter_model