Mirror of https://github.com/EvolutionAPI/adk-python.git
Synced 2025-07-14 01:41:25 -06:00

Commit 509db3f9fb (parent 4d5760917d)

Add token usage to Gemini (streaming), LiteLLM and Anthropic

Also includes a token_usage sample that showcases the token usage of
sub-agents with different models under a parent agent.

PiperOrigin-RevId: 759347015
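What the change surfaces to callers, in short: each LlmResponse (and therefore each event yielded by the Runner) can now carry usage_metadata, so token counts can be aggregated across sub-agents backed by different models. A minimal sketch, mirroring the asyncio_run.py sample below; the runner, user_id, session and content objects are assumed to be set up as in that sample:

    total_prompt = total_candidates = total = 0
    async for event in runner.run_async(
        user_id=user_id, session_id=session.id, new_message=content
    ):
      # usage_metadata is populated by the model integrations changed in this
      # commit (Gemini streaming, LiteLLM, Anthropic); it may be None otherwise.
      if event.usage_metadata:
        total_prompt += event.usage_metadata.prompt_token_count or 0
        total_candidates += event.usage_metadata.candidates_token_count or 0
        total += event.usage_metadata.total_token_count or 0
    print(
        f'Session tokens: {total} '
        f'(prompt={total_prompt}, candidates={total_candidates})'
    )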
contributing/samples/token_usage/__init__.py (new executable file, 15 lines)

@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import agent
contributing/samples/token_usage/agent.py (new executable file, 97 lines)

@@ -0,0 +1,97 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from google.adk import Agent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.anthropic_llm import Claude
from google.adk.models.lite_llm import LiteLlm
from google.adk.planners import BuiltInPlanner
from google.adk.planners import PlanReActPlanner
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def roll_die(sides: int, tool_context: ToolContext) -> int:
  """Roll a die and return the rolled result.

  Args:
    sides: The integer number of sides the die has.

  Returns:
    An integer of the result of rolling the die.
  """
  result = random.randint(1, sides)
  if 'rolls' not in tool_context.state:
    tool_context.state['rolls'] = []

  tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
  return result


roll_agent_with_openai = LlmAgent(
    model=LiteLlm(model='openai/gpt-4o'),
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_openai',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

roll_agent_with_claude = LlmAgent(
    model=Claude(model='claude-3-7-sonnet@20250219'),
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_claude',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

roll_agent_with_litellm_claude = LlmAgent(
    model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'),
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_litellm_claude',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

roll_agent_with_gemini = LlmAgent(
    model='gemini-2.0-flash',
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_gemini',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

root_agent = SequentialAgent(
    name='code_pipeline_agent',
    sub_agents=[
        roll_agent_with_openai,
        roll_agent_with_claude,
        roll_agent_with_litellm_claude,
        roll_agent_with_gemini,
    ],
)
contributing/samples/token_usage/asyncio_run.py (new executable file, 102 lines)

@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time
import warnings

import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types

load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()


async def main():
  app_name = 'my_app'
  user_id_1 = 'user1'
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name=app_name,
      agent=agent.root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  session_11 = session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  total_prompt_tokens = 0
  total_candidate_tokens = 0
  total_tokens = 0

  async def run_prompt(session: Session, new_message: str):
    nonlocal total_prompt_tokens
    nonlocal total_candidate_tokens
    nonlocal total_tokens
    content = types.Content(
        role='user', parts=[types.Part.from_text(text=new_message)]
    )
    print('** User says:', content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
    ):
      if event.content.parts and event.content.parts[0].text:
        print(f'** {event.author}: {event.content.parts[0].text}')
      if event.usage_metadata:
        total_prompt_tokens += event.usage_metadata.prompt_token_count or 0
        total_candidate_tokens += (
            event.usage_metadata.candidates_token_count or 0
        )
        total_tokens += event.usage_metadata.total_token_count or 0
        print(
            'Turn tokens:'
            f' {event.usage_metadata.total_token_count} (prompt={event.usage_metadata.prompt_token_count},'
            f' candidates={event.usage_metadata.candidates_token_count})'
        )

    print(
        f'Session tokens: {total_tokens} (prompt={total_prompt_tokens},'
        f' candidates={total_candidate_tokens})'
    )

  start_time = time.time()
  print('Start time:', start_time)
  print('------------------------------------')
  await run_prompt(session_11, 'Hi')
  await run_prompt(session_11, 'Roll a die with 100 sides')
  print(
      await artifact_service.list_artifact_keys(
          app_name=app_name, user_id=user_id_1, session_id=session_11.id
      )
  )
  end_time = time.time()
  print('------------------------------------')
  print('End time:', end_time)
  print('Total time:', end_time - start_time)


if __name__ == '__main__':
  asyncio.run(main())
@@ -218,7 +218,7 @@ class BaseLlmFlow(ABC):

    When the model returns transcription, the author is "user". Otherwise, the
    author is the agent name (not 'model').

    Args:
      llm_response: The LLM response from the LLM call.
    """
@@ -140,15 +140,15 @@ def message_to_generate_content_response(
          role="model",
          parts=[content_block_to_part(cb) for cb in message.content],
      ),
      usage_metadata=types.GenerateContentResponseUsageMetadata(
          prompt_token_count=message.usage.input_tokens,
          candidates_token_count=message.usage.output_tokens,
          total_token_count=(
              message.usage.input_tokens + message.usage.output_tokens
          ),
      ),
      # TODO: Deal with these later.
      # finish_reason=to_google_genai_finish_reason(message.stop_reason),
      # usage_metadata=types.GenerateContentResponseUsageMetadata(
      #     prompt_token_count=message.usage.input_tokens,
      #     candidates_token_count=message.usage.output_tokens,
      #     total_token_count=(
      #         message.usage.input_tokens + message.usage.output_tokens
      #     ),
      # ),
  )


@@ -196,6 +196,12 @@ def function_declaration_to_tool_param(


class Claude(BaseLlm):
  """Integration with Claude models served from Vertex AI.

  Attributes:
    model: The name of the Claude model.
  """

  model: str = "claude-3-5-sonnet-v2@20241022"

  @staticmethod
@@ -121,6 +121,7 @@ class Gemini(BaseLlm):
              content=types.ModelContent(
                  parts=[types.Part.from_text(text=text)],
              ),
              usage_metadata=llm_response.usage_metadata,
          )
          text = ''
        yield llm_response

@@ -174,7 +175,7 @@ class Gemini(BaseLlm):
  @cached_property
  def _live_api_client(self) -> Client:
    if self._api_backend == 'vertex':
      #use beta version for vertex api
      # use beta version for vertex api
      api_version = 'v1beta1'
      # use default api version for vertex
      return Client(
@@ -67,6 +67,12 @@ class TextChunk(BaseModel):
  text: str


class UsageMetadataChunk(BaseModel):
  prompt_tokens: int
  completion_tokens: int
  total_tokens: int


class LiteLLMClient:
  """Provides acompletion method (for better testability)."""

@@ -344,15 +350,20 @@ def _function_declaration_to_tool_param(
def _model_response_to_chunk(
    response: ModelResponse,
) -> Generator[
    Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
    Tuple[
        Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
        Optional[str],
    ],
    None,
    None,
]:
  """Converts a litellm message to text or function chunk.
  """Converts a litellm message to text, function or usage metadata chunk.

  Args:
    response: The response from the model.

  Yields:
    A tuple of text or function chunk and finish reason.
    A tuple of text or function or usage metadata chunk and finish reason.
  """

  message = None

@@ -384,11 +395,21 @@ def _model_response_to_chunk(
  if not message:
    yield None, None

  # Ideally usage would be expected with the last ModelResponseStream with a
  # finish_reason set. But this is not the case we are observing from litellm.
  # So we are sending it as a separate chunk to be set on the llm_response.
  if response.get("usage", None):
    yield UsageMetadataChunk(
        prompt_tokens=response["usage"].get("prompt_tokens", 0),
        completion_tokens=response["usage"].get("completion_tokens", 0),
        total_tokens=response["usage"].get("total_tokens", 0),
    ), None


def _model_response_to_generate_content_response(
    response: ModelResponse,
) -> LlmResponse:
  """Converts a litellm response to LlmResponse.
  """Converts a litellm response to LlmResponse. Also adds usage metadata.

  Args:
    response: The model response.

@@ -403,7 +424,15 @@ def _model_response_to_generate_content_response(

  if not message:
    raise ValueError("No message in response")
  return _message_to_generate_content_response(message)

  llm_response = _message_to_generate_content_response(message)
  if response.get("usage", None):
    llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
        prompt_token_count=response["usage"].get("prompt_tokens", 0),
        candidates_token_count=response["usage"].get("completion_tokens", 0),
        total_token_count=response["usage"].get("total_tokens", 0),
    )
  return llm_response


def _message_to_generate_content_response(

@@ -628,6 +657,10 @@ class LiteLlm(BaseLlm):
      function_args = ""
      function_id = None
      completion_args["stream"] = True
      aggregated_llm_response = None
      aggregated_llm_response_with_tool_call = None
      usage_metadata = None

      for part in self.llm_client.completion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):

@@ -645,32 +678,55 @@ class LiteLlm(BaseLlm):
                ),
                is_partial=True,
            )
          elif isinstance(chunk, UsageMetadataChunk):
            usage_metadata = types.GenerateContentResponseUsageMetadata(
                prompt_token_count=chunk.prompt_tokens,
                candidates_token_count=chunk.completion_tokens,
                total_token_count=chunk.total_tokens,
            )

          if finish_reason == "tool_calls" and function_id:
            yield _message_to_generate_content_response(
                ChatCompletionAssistantMessage(
                    role="assistant",
                    content="",
                    tool_calls=[
                        ChatCompletionMessageToolCall(
                            type="function",
                            id=function_id,
                            function=Function(
                                name=function_name,
                                arguments=function_args,
                            ),
                        )
                    ],
            aggregated_llm_response_with_tool_call = (
                _message_to_generate_content_response(
                    ChatCompletionAssistantMessage(
                        role="assistant",
                        content="",
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                type="function",
                                id=function_id,
                                function=Function(
                                    name=function_name,
                                    arguments=function_args,
                                ),
                            )
                        ],
                    )
                )
            )
            function_name = ""
            function_args = ""
            function_id = None
          elif finish_reason == "stop" and text:
            yield _message_to_generate_content_response(
            aggregated_llm_response = _message_to_generate_content_response(
                ChatCompletionAssistantMessage(role="assistant", content=text)
            )
            text = ""

      # waiting until streaming ends to yield the llm_response as litellm tends
      # to send chunk that contains usage_metadata after the chunk with
      # finish_reason set to tool_calls or stop.
      if aggregated_llm_response:
        if usage_metadata:
          aggregated_llm_response.usage_metadata = usage_metadata
          usage_metadata = None
        yield aggregated_llm_response

      if aggregated_llm_response_with_tool_call:
        if usage_metadata:
          aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
        yield aggregated_llm_response_with_tool_call

    else:
      response = await self.llm_client.acompletion(**completion_args)
      yield _model_response_to_generate_content_response(response)
@@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import TextChunk
from google.adk.models.lite_llm import UsageMetadataChunk
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from litellm import ChatCompletionAssistantMessage

@@ -314,13 +315,10 @@ litellm_append_user_content_test_cases = [
    litellm_append_user_content_test_cases
)
def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output):

  lite_llm_instance._maybe_append_user_content(
      llm_request
  )

  assert len(llm_request.contents) == expected_output

  lite_llm_instance._maybe_append_user_content(llm_request)

  assert len(llm_request.contents) == expected_output


function_declaration_test_cases = [
@@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response(
  assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'


@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):

  async for response in lite_llm_instance.generate_content_async(
      LLM_REQUEST_WITH_FUNCTION_DECLARATION
  ):
    assert response.content.role == "model"
    assert response.content.parts[0].text == "Test response"
    assert response.content.parts[1].function_call.name == "test_function"
    assert response.content.parts[1].function_call.args == {
        "test_arg": "test_value"
    }
    assert response.content.parts[1].function_call.id == "test_tool_call_id"

  mock_acompletion.assert_called_once()

  _, kwargs = mock_acompletion.call_args
  assert kwargs["model"] == "test_model"
  assert kwargs["messages"][0]["role"] == "user"
  assert kwargs["messages"][0]["content"] == "Test prompt"
  assert kwargs["tools"][0]["function"]["name"] == "test_function"
  assert (
      kwargs["tools"][0]["function"]["description"]
      == "Test function description"
  )
  assert (
      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
          "type"
      ]
      == "string"
  )


@pytest.mark.asyncio
async def test_generate_content_async_with_usage_metadata(
    lite_llm_instance, mock_acompletion
):
  mock_response_with_usage_metadata = ModelResponse(
      choices=[
          Choices(
              message=ChatCompletionAssistantMessage(
                  role="assistant",
                  content="Test response",
              )
          )
      ],
      usage={
          "prompt_tokens": 10,
          "completion_tokens": 5,
          "total_tokens": 15,
      },
  )
  mock_acompletion.return_value = mock_response_with_usage_metadata

  llm_request = LlmRequest(
      contents=[
          types.Content(
              role="user", parts=[types.Part.from_text(text="Test prompt")]
          ),
      ],
      config=types.GenerateContentConfig(
          system_instruction="test instruction",
      ),
  )
  async for response in lite_llm_instance.generate_content_async(llm_request):
    assert response.content.role == "model"
    assert response.content.parts[0].text == "Test response"
    assert response.usage_metadata.prompt_token_count == 10
    assert response.usage_metadata.candidates_token_count == 5
    assert response.usage_metadata.total_token_count == 15

  mock_acompletion.assert_called_once()


def test_content_to_message_param_user_message():
  content = types.Content(
      role="user", parts=[types.Part.from_text(text="Test prompt")]
@@ -704,7 +776,7 @@ def test_to_litellm_role():


@pytest.mark.parametrize(
    "response, expected_chunk, expected_finished",
    "response, expected_chunks, expected_finished",
    [
        (
            ModelResponse(

@@ -716,7 +788,35 @@ def test_to_litellm_role():
                    }
                ]
            ),
            TextChunk(text="this is a test"),
            [
                TextChunk(text="this is a test"),
                UsageMetadataChunk(
                    prompt_tokens=0, completion_tokens=0, total_tokens=0
                ),
            ],
            "stop",
        ),
        (
            ModelResponse(
                choices=[
                    {
                        "message": {
                            "content": "this is a test",
                        }
                    }
                ],
                usage={
                    "prompt_tokens": 3,
                    "completion_tokens": 5,
                    "total_tokens": 8,
                },
            ),
            [
                TextChunk(text="this is a test"),
                UsageMetadataChunk(
                    prompt_tokens=3, completion_tokens=5, total_tokens=8
                ),
            ],
            "stop",
        ),
        (

@@ -741,28 +841,53 @@ def test_to_litellm_role():
                    )
                ]
            ),
            FunctionChunk(id="1", name="test_function", args='{"key": "va'),
            [
                FunctionChunk(id="1", name="test_function", args='{"key": "va'),
                UsageMetadataChunk(
                    prompt_tokens=0, completion_tokens=0, total_tokens=0
                ),
            ],
            None,
        ),
        (
            ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
            None,
            [
                None,
                UsageMetadataChunk(
                    prompt_tokens=0, completion_tokens=0, total_tokens=0
                ),
            ],
            "tool_calls",
        ),
        (ModelResponse(choices=[{}]), None, "stop"),
        (
            ModelResponse(choices=[{}]),
            [
                None,
                UsageMetadataChunk(
                    prompt_tokens=0, completion_tokens=0, total_tokens=0
                ),
            ],
            "stop",
        ),
    ],
)
def test_model_response_to_chunk(response, expected_chunk, expected_finished):
def test_model_response_to_chunk(response, expected_chunks, expected_finished):
  result = list(_model_response_to_chunk(response))
  assert len(result) == 1
  assert len(result) == 2
  chunk, finished = result[0]
  if expected_chunk:
    assert isinstance(chunk, type(expected_chunk))
    assert chunk == expected_chunk
  if expected_chunks:
    assert isinstance(chunk, type(expected_chunks[0]))
    assert chunk == expected_chunks[0]
  else:
    assert chunk is None
  assert finished == expected_finished

  usage_chunk, _ = result[1]
  assert usage_chunk is not None
  assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
  assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
  assert usage_chunk.total_tokens == expected_chunks[1].total_tokens


@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):

@@ -893,3 +1018,71 @@ async def test_generate_content_async_stream(
          ]
          == "string"
      )


@pytest.mark.asyncio
async def test_generate_content_async_stream_with_usage_metadata(
    mock_completion, lite_llm_instance
):

  streaming_model_response_with_usage_metadata = [
      *STREAMING_MODEL_RESPONSE,
      ModelResponse(
          usage={
              "prompt_tokens": 10,
              "completion_tokens": 5,
              "total_tokens": 15,
          },
          choices=[
              StreamingChoices(
                  finish_reason=None,
              )
          ],
      ),
  ]

  mock_completion.return_value = iter(
      streaming_model_response_with_usage_metadata
  )

  responses = [
      response
      async for response in lite_llm_instance.generate_content_async(
          LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
      )
  ]
  assert len(responses) == 4
  assert responses[0].content.role == "model"
  assert responses[0].content.parts[0].text == "zero, "
  assert responses[1].content.role == "model"
  assert responses[1].content.parts[0].text == "one, "
  assert responses[2].content.role == "model"
  assert responses[2].content.parts[0].text == "two:"
  assert responses[3].content.role == "model"
  assert responses[3].content.parts[0].function_call.name == "test_function"
  assert responses[3].content.parts[0].function_call.args == {
      "test_arg": "test_value"
  }
  assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"

  assert responses[3].usage_metadata.prompt_token_count == 10
  assert responses[3].usage_metadata.candidates_token_count == 5
  assert responses[3].usage_metadata.total_token_count == 15

  mock_completion.assert_called_once()

  _, kwargs = mock_completion.call_args
  assert kwargs["model"] == "test_model"
  assert kwargs["messages"][0]["role"] == "user"
  assert kwargs["messages"][0]["content"] == "Test prompt"
  assert kwargs["tools"][0]["function"]["name"] == "test_function"
  assert (
      kwargs["tools"][0]["function"]["description"]
      == "Test function description"
  )
  assert (
      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
          "type"
      ]
      == "string"
  )