mirror of https://github.com/EvolutionAPI/adk-python.git
synced 2025-07-14 01:41:25 -06:00

Add token usage to gemini (streaming), litellm and anthropic

Also included a token_usage sample that showcases the token usage of subagents with different models under a parent agent.

PiperOrigin-RevId: 759347015

This commit is contained in:
parent 4d5760917d
commit 509db3f9fb
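What the change enables on the caller side: events streamed from a runner can now carry populated usage_metadata whether the underlying model is Gemini, a LiteLLM-routed model, or Claude. Below is a minimal sketch of consuming it; the helper name run_and_count_tokens is hypothetical, and the full runnable version of this pattern is the asyncio_run.py sample added in this commit.

from google.adk import Runner
from google.adk.sessions import Session
from google.genai import types


async def run_and_count_tokens(runner: Runner, session: Session, prompt: str):
  """Streams one turn and sums the usage_metadata attached to each event."""
  content = types.Content(role='user', parts=[types.Part.from_text(text=prompt)])
  prompt_tokens = candidate_tokens = total_tokens = 0
  async for event in runner.run_async(
      user_id=session.user_id, session_id=session.id, new_message=content
  ):
    if event.usage_metadata:
      prompt_tokens += event.usage_metadata.prompt_token_count or 0
      candidate_tokens += event.usage_metadata.candidates_token_count or 0
      total_tokens += event.usage_metadata.total_token_count or 0
  return prompt_tokens, candidate_tokens, total_tokens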
contributing/samples/token_usage/__init__.py  (Executable file, 15 lines added)
@@ -0,0 +1,15 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import agent

contributing/samples/token_usage/agent.py  (Executable file, 97 lines added)
@@ -0,0 +1,97 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

from google.adk import Agent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.sequential_agent import SequentialAgent
from google.adk.models.anthropic_llm import Claude
from google.adk.models.lite_llm import LiteLlm
from google.adk.planners import BuiltInPlanner
from google.adk.planners import PlanReActPlanner
from google.adk.tools.tool_context import ToolContext
from google.genai import types


def roll_die(sides: int, tool_context: ToolContext) -> int:
  """Roll a die and return the rolled result.

  Args:
    sides: The integer number of sides the die has.

  Returns:
    An integer of the result of rolling the die.
  """
  result = random.randint(1, sides)
  if 'rolls' not in tool_context.state:
    tool_context.state['rolls'] = []

  tool_context.state['rolls'] = tool_context.state['rolls'] + [result]
  return result


roll_agent_with_openai = LlmAgent(
    model=LiteLlm(model='openai/gpt-4o'),
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_openai',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

roll_agent_with_claude = LlmAgent(
    model=Claude(model='claude-3-7-sonnet@20250219'),
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_claude',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

roll_agent_with_litellm_claude = LlmAgent(
    model=LiteLlm(model='vertex_ai/claude-3-7-sonnet'),
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_litellm_claude',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

roll_agent_with_gemini = LlmAgent(
    model='gemini-2.0-flash',
    description='Handles rolling dice of different sizes.',
    name='roll_agent_with_gemini',
    instruction="""
      You are responsible for rolling dice based on the user's request.
      When asked to roll a die, you must call the roll_die tool with the number of sides as an integer.
    """,
    tools=[roll_die],
)

root_agent = SequentialAgent(
    name='code_pipeline_agent',
    sub_agents=[
        roll_agent_with_openai,
        roll_agent_with_claude,
        roll_agent_with_litellm_claude,
        roll_agent_with_gemini,
    ],
)

contributing/samples/token_usage/asyncio_run.py  (Executable file, 102 lines added)
@@ -0,0 +1,102 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import time
import warnings

import agent
from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService
from google.adk.sessions import Session
from google.genai import types

load_dotenv(override=True)
warnings.filterwarnings('ignore', category=UserWarning)
logs.log_to_tmp_folder()


async def main():
  app_name = 'my_app'
  user_id_1 = 'user1'
  session_service = InMemorySessionService()
  artifact_service = InMemoryArtifactService()
  runner = Runner(
      app_name=app_name,
      agent=agent.root_agent,
      artifact_service=artifact_service,
      session_service=session_service,
  )
  session_11 = session_service.create_session(
      app_name=app_name, user_id=user_id_1
  )

  total_prompt_tokens = 0
  total_candidate_tokens = 0
  total_tokens = 0

  async def run_prompt(session: Session, new_message: str):
    nonlocal total_prompt_tokens
    nonlocal total_candidate_tokens
    nonlocal total_tokens
    content = types.Content(
        role='user', parts=[types.Part.from_text(text=new_message)]
    )
    print('** User says:', content.model_dump(exclude_none=True))
    async for event in runner.run_async(
        user_id=user_id_1,
        session_id=session.id,
        new_message=content,
    ):
      if event.content.parts and event.content.parts[0].text:
        print(f'** {event.author}: {event.content.parts[0].text}')
      if event.usage_metadata:
        total_prompt_tokens += event.usage_metadata.prompt_token_count or 0
        total_candidate_tokens += (
            event.usage_metadata.candidates_token_count or 0
        )
        total_tokens += event.usage_metadata.total_token_count or 0
        print(
            'Turn tokens:'
            f' {event.usage_metadata.total_token_count} (prompt={event.usage_metadata.prompt_token_count},'
            f' candidates={event.usage_metadata.candidates_token_count})'
        )

    print(
        f'Session tokens: {total_tokens} (prompt={total_prompt_tokens},'
        f' candidates={total_candidate_tokens})'
    )

  start_time = time.time()
  print('Start time:', start_time)
  print('------------------------------------')
  await run_prompt(session_11, 'Hi')
  await run_prompt(session_11, 'Roll a die with 100 sides')
  print(
      await artifact_service.list_artifact_keys(
          app_name=app_name, user_id=user_id_1, session_id=session_11.id
      )
  )
  end_time = time.time()
  print('------------------------------------')
  print('End time:', end_time)
  print('Total time:', end_time - start_time)


if __name__ == '__main__':
  asyncio.run(main())

Anthropic integration (anthropic_llm): populate usage_metadata from the Anthropic usage counts instead of leaving it commented out.

@@ -140,15 +140,15 @@ def message_to_generate_content_response(
           role="model",
           parts=[content_block_to_part(cb) for cb in message.content],
       ),
+      usage_metadata=types.GenerateContentResponseUsageMetadata(
+          prompt_token_count=message.usage.input_tokens,
+          candidates_token_count=message.usage.output_tokens,
+          total_token_count=(
+              message.usage.input_tokens + message.usage.output_tokens
+          ),
+      ),
       # TODO: Deal with these later.
       # finish_reason=to_google_genai_finish_reason(message.stop_reason),
-      # usage_metadata=types.GenerateContentResponseUsageMetadata(
-      #     prompt_token_count=message.usage.input_tokens,
-      #     candidates_token_count=message.usage.output_tokens,
-      #     total_token_count=(
-      #         message.usage.input_tokens + message.usage.output_tokens
-      #     ),
-      # ),
   )
@@ -196,6 +196,12 @@ def function_declaration_to_tool_param(


 class Claude(BaseLlm):
+  """Integration with Claude models served from Vertex AI.
+
+  Attributes:
+    model: The name of the Claude model.
+  """
+
   model: str = "claude-3-5-sonnet-v2@20241022"

   @staticmethod
Gemini integration (streaming path): carry usage_metadata through on the aggregated streaming response.

@@ -121,6 +121,7 @@ class Gemini(BaseLlm):
               content=types.ModelContent(
                   parts=[types.Part.from_text(text=text)],
               ),
+              usage_metadata=llm_response.usage_metadata,
           )
           text = ''
         yield llm_response
LiteLLM integration (lite_llm): introduce a UsageMetadataChunk and thread usage counts through both the streaming and non-streaming paths.

@@ -67,6 +67,12 @@ class TextChunk(BaseModel):
   text: str


+class UsageMetadataChunk(BaseModel):
+  prompt_tokens: int
+  completion_tokens: int
+  total_tokens: int
+
+
 class LiteLLMClient:
   """Provides acompletion method (for better testability)."""
@@ -344,15 +350,20 @@ def _function_declaration_to_tool_param(
 def _model_response_to_chunk(
     response: ModelResponse,
 ) -> Generator[
-    Tuple[Optional[Union[TextChunk, FunctionChunk]], Optional[str]], None, None
+    Tuple[
+        Optional[Union[TextChunk, FunctionChunk, UsageMetadataChunk]],
+        Optional[str],
+    ],
+    None,
+    None,
 ]:
-  """Converts a litellm message to text or function chunk.
+  """Converts a litellm message to text, function or usage metadata chunk.

   Args:
     response: The response from the model.

   Yields:
-    A tuple of text or function chunk and finish reason.
+    A tuple of text or function or usage metadata chunk and finish reason.
   """

   message = None
@@ -384,11 +395,21 @@ def _model_response_to_chunk(
   if not message:
     yield None, None

+  # Ideally usage would be expected with the last ModelResponseStream with a
+  # finish_reason set. But this is not the case we are observing from litellm.
+  # So we are sending it as a separate chunk to be set on the llm_response.
+  if response.get("usage", None):
+    yield UsageMetadataChunk(
+        prompt_tokens=response["usage"].get("prompt_tokens", 0),
+        completion_tokens=response["usage"].get("completion_tokens", 0),
+        total_tokens=response["usage"].get("total_tokens", 0),
+    ), None
+

 def _model_response_to_generate_content_response(
     response: ModelResponse,
 ) -> LlmResponse:
-  """Converts a litellm response to LlmResponse.
+  """Converts a litellm response to LlmResponse. Also adds usage metadata.

   Args:
     response: The model response.
@@ -403,7 +424,15 @@ def _model_response_to_generate_content_response(

   if not message:
     raise ValueError("No message in response")
-  return _message_to_generate_content_response(message)
+
+  llm_response = _message_to_generate_content_response(message)
+  if response.get("usage", None):
+    llm_response.usage_metadata = types.GenerateContentResponseUsageMetadata(
+        prompt_token_count=response["usage"].get("prompt_tokens", 0),
+        candidates_token_count=response["usage"].get("completion_tokens", 0),
+        total_token_count=response["usage"].get("total_tokens", 0),
+    )
+  return llm_response


 def _message_to_generate_content_response(
@@ -628,6 +657,10 @@ class LiteLlm(BaseLlm):
       function_args = ""
       function_id = None
       completion_args["stream"] = True
+      aggregated_llm_response = None
+      aggregated_llm_response_with_tool_call = None
+      usage_metadata = None
+
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
@@ -645,8 +678,16 @@ class LiteLlm(BaseLlm):
                 ),
                 is_partial=True,
             )
+          elif isinstance(chunk, UsageMetadataChunk):
+            usage_metadata = types.GenerateContentResponseUsageMetadata(
+                prompt_token_count=chunk.prompt_tokens,
+                candidates_token_count=chunk.completion_tokens,
+                total_token_count=chunk.total_tokens,
+            )
+
           if finish_reason == "tool_calls" and function_id:
-            yield _message_to_generate_content_response(
+            aggregated_llm_response_with_tool_call = (
+                _message_to_generate_content_response(
                     ChatCompletionAssistantMessage(
                         role="assistant",
                         content="",
@@ -662,15 +703,30 @@ class LiteLlm(BaseLlm):
                         ],
                     )
                 )
+            )
             function_name = ""
             function_args = ""
             function_id = None
           elif finish_reason == "stop" and text:
-            yield _message_to_generate_content_response(
+            aggregated_llm_response = _message_to_generate_content_response(
                 ChatCompletionAssistantMessage(role="assistant", content=text)
             )
             text = ""
+
+      # waiting until streaming ends to yield the llm_response as litellm tends
+      # to send chunk that contains usage_metadata after the chunk with
+      # finish_reason set to tool_calls or stop.
+      if aggregated_llm_response:
+        if usage_metadata:
+          aggregated_llm_response.usage_metadata = usage_metadata
+          usage_metadata = None
+        yield aggregated_llm_response
+
+      if aggregated_llm_response_with_tool_call:
+        if usage_metadata:
+          aggregated_llm_response_with_tool_call.usage_metadata = usage_metadata
+        yield aggregated_llm_response_with_tool_call
+
     else:
       response = await self.llm_client.acompletion(**completion_args)
       yield _model_response_to_generate_content_response(response)
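The streaming branch above relies on one ordering quirk: litellm tends to deliver the usage chunk after the chunk that carries finish_reason, so the final response is buffered and only yielded once the stream is exhausted, with usage attached at that point. Below is a stripped-down sketch of that buffer-then-attach pattern, using stand-in types rather than ADK or litellm ones.

from dataclasses import dataclass
from typing import Iterable, Iterator, Optional


@dataclass
class FakeChunk:
  """Stand-in for a streamed chunk; not an ADK or litellm type."""
  text: str = ''
  finish_reason: Optional[str] = None
  usage: Optional[dict] = None  # arrives in a trailing chunk, after finish_reason


def buffer_then_attach(chunks: Iterable[FakeChunk]) -> Iterator[dict]:
  final = None
  usage = None
  for chunk in chunks:
    if chunk.usage is not None:
      usage = chunk.usage  # remember usage even though the response already "finished"
    if chunk.finish_reason == 'stop':
      final = {'text': chunk.text}  # buffer instead of yielding immediately
  if final is not None:
    final['usage'] = usage  # attach the late-arriving usage before yielding
    yield final


# Example: usage shows up one chunk after finish_reason, yet ends up on the response.
stream = [
    FakeChunk(text='hello', finish_reason='stop'),
    FakeChunk(usage={'prompt_tokens': 3, 'completion_tokens': 2, 'total_tokens': 5}),
]
assert next(buffer_then_attach(stream))['usage']['total_tokens'] == 5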
LiteLLM unit tests: cover the new usage metadata handling.

@@ -25,6 +25,7 @@ from google.adk.models.lite_llm import FunctionChunk
 from google.adk.models.lite_llm import LiteLlm
 from google.adk.models.lite_llm import LiteLLMClient
 from google.adk.models.lite_llm import TextChunk
+from google.adk.models.lite_llm import UsageMetadataChunk
 from google.adk.models.llm_request import LlmRequest
 from google.genai import types
 from litellm import ChatCompletionAssistantMessage
@@ -315,14 +316,11 @@ litellm_append_user_content_test_cases = [
 )
 def test_maybe_append_user_content(lite_llm_instance, llm_request, expected_output):
-
-  lite_llm_instance._maybe_append_user_content(
-      llm_request
-  )
+  lite_llm_instance._maybe_append_user_content(llm_request)

   assert len(llm_request.contents) == expected_output


 function_declaration_test_cases = [
     (
         "simple_function",
@@ -567,6 +565,80 @@ async def test_generate_content_async_with_tool_response(
   assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'


+@pytest.mark.asyncio
+async def test_generate_content_async(mock_acompletion, lite_llm_instance):
+
+  async for response in lite_llm_instance.generate_content_async(
+      LLM_REQUEST_WITH_FUNCTION_DECLARATION
+  ):
+    assert response.content.role == "model"
+    assert response.content.parts[0].text == "Test response"
+    assert response.content.parts[1].function_call.name == "test_function"
+    assert response.content.parts[1].function_call.args == {
+        "test_arg": "test_value"
+    }
+    assert response.content.parts[1].function_call.id == "test_tool_call_id"
+
+  mock_acompletion.assert_called_once()
+
+  _, kwargs = mock_acompletion.call_args
+  assert kwargs["model"] == "test_model"
+  assert kwargs["messages"][0]["role"] == "user"
+  assert kwargs["messages"][0]["content"] == "Test prompt"
+  assert kwargs["tools"][0]["function"]["name"] == "test_function"
+  assert (
+      kwargs["tools"][0]["function"]["description"]
+      == "Test function description"
+  )
+  assert (
+      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+          "type"
+      ]
+      == "string"
+  )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_with_usage_metadata(
+    lite_llm_instance, mock_acompletion
+):
+  mock_response_with_usage_metadata = ModelResponse(
+      choices=[
+          Choices(
+              message=ChatCompletionAssistantMessage(
+                  role="assistant",
+                  content="Test response",
+              )
+          )
+      ],
+      usage={
+          "prompt_tokens": 10,
+          "completion_tokens": 5,
+          "total_tokens": 15,
+      },
+  )
+  mock_acompletion.return_value = mock_response_with_usage_metadata
+
+  llm_request = LlmRequest(
+      contents=[
+          types.Content(
+              role="user", parts=[types.Part.from_text(text="Test prompt")]
+          ),
+      ],
+      config=types.GenerateContentConfig(
+          system_instruction="test instruction",
+      ),
+  )
+  async for response in lite_llm_instance.generate_content_async(llm_request):
+    assert response.content.role == "model"
+    assert response.content.parts[0].text == "Test response"
+    assert response.usage_metadata.prompt_token_count == 10
+    assert response.usage_metadata.candidates_token_count == 5
+    assert response.usage_metadata.total_token_count == 15
+
+  mock_acompletion.assert_called_once()
+
+
 def test_content_to_message_param_user_message():
   content = types.Content(
       role="user", parts=[types.Part.from_text(text="Test prompt")]
@@ -704,7 +776,7 @@ def test_to_litellm_role():


 @pytest.mark.parametrize(
-    "response, expected_chunk, expected_finished",
+    "response, expected_chunks, expected_finished",
     [
         (
             ModelResponse(
@@ -716,7 +788,35 @@ def test_to_litellm_role():
                     }
                 ]
             ),
+            [
                 TextChunk(text="this is a test"),
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
+            "stop",
+        ),
+        (
+            ModelResponse(
+                choices=[
+                    {
+                        "message": {
+                            "content": "this is a test",
+                        }
+                    }
+                ],
+                usage={
+                    "prompt_tokens": 3,
+                    "completion_tokens": 5,
+                    "total_tokens": 8,
+                },
+            ),
+            [
+                TextChunk(text="this is a test"),
+                UsageMetadataChunk(
+                    prompt_tokens=3, completion_tokens=5, total_tokens=8
+                ),
+            ],
             "stop",
         ),
         (
@@ -741,28 +841,53 @@ def test_to_litellm_role():
                     )
                 ]
             ),
+            [
                 FunctionChunk(id="1", name="test_function", args='{"key": "va'),
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             None,
         ),
         (
             ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
+            [
                 None,
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
             "tool_calls",
         ),
-        (ModelResponse(choices=[{}]), None, "stop"),
+        (
+            ModelResponse(choices=[{}]),
+            [
+                None,
+                UsageMetadataChunk(
+                    prompt_tokens=0, completion_tokens=0, total_tokens=0
+                ),
+            ],
+            "stop",
+        ),
     ],
 )
-def test_model_response_to_chunk(response, expected_chunk, expected_finished):
+def test_model_response_to_chunk(response, expected_chunks, expected_finished):
   result = list(_model_response_to_chunk(response))
-  assert len(result) == 1
+  assert len(result) == 2
   chunk, finished = result[0]
-  if expected_chunk:
-    assert isinstance(chunk, type(expected_chunk))
-    assert chunk == expected_chunk
+  if expected_chunks:
+    assert isinstance(chunk, type(expected_chunks[0]))
+    assert chunk == expected_chunks[0]
   else:
     assert chunk is None
   assert finished == expected_finished
+
+  usage_chunk, _ = result[1]
+  assert usage_chunk is not None
+  assert usage_chunk.prompt_tokens == expected_chunks[1].prompt_tokens
+  assert usage_chunk.completion_tokens == expected_chunks[1].completion_tokens
+  assert usage_chunk.total_tokens == expected_chunks[1].total_tokens


 @pytest.mark.asyncio
 async def test_acompletion_additional_args(mock_acompletion, mock_client):
@@ -893,3 +1018,71 @@ async def test_generate_content_async_stream(
         ]
         == "string"
     )
+
+
+@pytest.mark.asyncio
+async def test_generate_content_async_stream_with_usage_metadata(
+    mock_completion, lite_llm_instance
+):
+
+  streaming_model_response_with_usage_metadata = [
+      *STREAMING_MODEL_RESPONSE,
+      ModelResponse(
+          usage={
+              "prompt_tokens": 10,
+              "completion_tokens": 5,
+              "total_tokens": 15,
+          },
+          choices=[
+              StreamingChoices(
+                  finish_reason=None,
+              )
+          ],
+      ),
+  ]
+
+  mock_completion.return_value = iter(
+      streaming_model_response_with_usage_metadata
+  )
+
+  responses = [
+      response
+      async for response in lite_llm_instance.generate_content_async(
+          LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
+      )
+  ]
+  assert len(responses) == 4
+  assert responses[0].content.role == "model"
+  assert responses[0].content.parts[0].text == "zero, "
+  assert responses[1].content.role == "model"
+  assert responses[1].content.parts[0].text == "one, "
+  assert responses[2].content.role == "model"
+  assert responses[2].content.parts[0].text == "two:"
+  assert responses[3].content.role == "model"
+  assert responses[3].content.parts[0].function_call.name == "test_function"
+  assert responses[3].content.parts[0].function_call.args == {
+      "test_arg": "test_value"
+  }
+  assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
+
+  assert responses[3].usage_metadata.prompt_token_count == 10
+  assert responses[3].usage_metadata.candidates_token_count == 5
+  assert responses[3].usage_metadata.total_token_count == 15
+
+  mock_completion.assert_called_once()
+
+  _, kwargs = mock_completion.call_args
+  assert kwargs["model"] == "test_model"
+  assert kwargs["messages"][0]["role"] == "user"
+  assert kwargs["messages"][0]["content"] == "Test prompt"
+  assert kwargs["tools"][0]["function"]["name"] == "test_function"
+  assert (
+      kwargs["tools"][0]["function"]["description"]
+      == "Test function description"
+  )
+  assert (
+      kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
+          "type"
+      ]
+      == "string"
+  )