Moves unittests to root folder and adds github action to run unit tests. (#72)

* Move unit tests to root package.

* Adds deps to "test" extra, and mark two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds github workflow

* minor fix in lite_llm.py for python 3.9.

* format pyproject.toml
This commit is contained in:
Jack Sun
2025-04-11 08:25:59 -07:00
committed by GitHub
parent 59117b9b96
commit 05142a07cc
66 changed files with 50 additions and 2 deletions

View File

@@ -0,0 +1,14 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@@ -0,0 +1,224 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest import mock
from google.adk import version
from google.adk.models.gemini_llm_connection import GeminiLlmConnection
from google.adk.models.google_llm import Gemini
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types
from google.genai.types import Content
from google.genai.types import Part
import pytest
@pytest.fixture
def generate_content_response():
return types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
role="model",
parts=[Part.from_text(text="Hello, how can I help you?")],
),
finish_reason=types.FinishReason.STOP,
)
]
)
@pytest.fixture
def gemini_llm():
return Gemini(model="gemini-1.5-flash")
@pytest.fixture
def llm_request():
return LlmRequest(
model="gemini-1.5-flash",
contents=[Content(role="user", parts=[Part.from_text(text="Hello")])],
config=types.GenerateContentConfig(
temperature=0.1,
response_modalities=[types.Modality.TEXT],
system_instruction="You are a helpful assistant",
),
)
def test_supported_models():
models = Gemini.supported_models()
assert len(models) == 3
assert models[0] == r"gemini-.*"
assert models[1] == r"projects\/.+\/locations\/.+\/endpoints\/.+"
assert (
models[2]
== r"projects\/.+\/locations\/.+\/publishers\/google\/models\/gemini.+"
)
def test_client_version_header():
model = Gemini(model="gemini-1.5-flash")
client = model.api_client
expected_header = (
f"google-adk/{version.__version__}"
f" gl-python/{sys.version.split()[0]} google-genai-sdk/"
)
assert (
expected_header
in client._api_client._http_options.headers["x-goog-api-client"]
)
assert (
expected_header in client._api_client._http_options.headers["user-agent"]
)
def test_maybe_append_user_content(gemini_llm, llm_request):
# Test with user content already present
gemini_llm._maybe_append_user_content(llm_request)
assert len(llm_request.contents) == 1
# Test with model content as the last message
llm_request.contents.append(
Content(role="model", parts=[Part.from_text(text="Response")])
)
gemini_llm._maybe_append_user_content(llm_request)
assert len(llm_request.contents) == 3
assert llm_request.contents[-1].role == "user"
assert "Continue processing" in llm_request.contents[-1].parts[0].text
@pytest.mark.asyncio
async def test_generate_content_async(
gemini_llm, llm_request, generate_content_response
):
with mock.patch.object(gemini_llm, "api_client") as mock_client:
# Create a mock coroutine that returns the generate_content_response
async def mock_coro():
return generate_content_response
# Assign the coroutine to the mocked method
mock_client.aio.models.generate_content.return_value = mock_coro()
responses = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=False
)
]
assert len(responses) == 1
assert isinstance(responses[0], LlmResponse)
assert responses[0].content.parts[0].text == "Hello, how can I help you?"
mock_client.aio.models.generate_content.assert_called_once()
@pytest.mark.asyncio
async def test_generate_content_async_stream(gemini_llm, llm_request):
with mock.patch.object(gemini_llm, "api_client") as mock_client:
# Create mock stream responses
class MockAsyncIterator:
def __init__(self, seq):
self.iter = iter(seq)
def __aiter__(self):
return self
async def __anext__(self):
try:
return next(self.iter)
except StopIteration:
raise StopAsyncIteration
mock_responses = [
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
role="model", parts=[Part.from_text(text="Hello")]
),
finish_reason=None,
)
]
),
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
role="model", parts=[Part.from_text(text=", how")]
),
finish_reason=None,
)
]
),
types.GenerateContentResponse(
candidates=[
types.Candidate(
content=Content(
role="model",
parts=[Part.from_text(text=" can I help you?")],
),
finish_reason=types.FinishReason.STOP,
)
]
),
]
# Create a mock coroutine that returns the MockAsyncIterator
async def mock_coro():
return MockAsyncIterator(mock_responses)
# Set the mock to return the coroutine
mock_client.aio.models.generate_content_stream.return_value = mock_coro()
responses = [
resp
async for resp in gemini_llm.generate_content_async(
llm_request, stream=True
)
]
# Assertions remain the same
assert len(responses) == 4
assert responses[0].partial is True
assert responses[1].partial is True
assert responses[2].partial is True
assert responses[3].content.parts[0].text == "Hello, how can I help you?"
mock_client.aio.models.generate_content_stream.assert_called_once()
@pytest.mark.asyncio
async def test_connect(gemini_llm, llm_request):
# Create a mock connection
mock_connection = mock.MagicMock(spec=GeminiLlmConnection)
# Create a mock context manager
class MockContextManager:
async def __aenter__(self):
return mock_connection
async def __aexit__(self, *args):
pass
# Mock the connect method at the class level
with mock.patch(
"google.adk.models.google_llm.Gemini.connect",
return_value=MockContextManager(),
):
async with gemini_llm.connect(llm_request) as connection:
assert connection is mock_connection

View File

@@ -0,0 +1,804 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import AsyncMock
from unittest.mock import Mock
from google.adk.models.lite_llm import _content_to_message_param
from google.adk.models.lite_llm import _function_declaration_to_tool_param
from google.adk.models.lite_llm import _get_content
from google.adk.models.lite_llm import _message_to_generate_content_response
from google.adk.models.lite_llm import _model_response_to_chunk
from google.adk.models.lite_llm import _to_litellm_role
from google.adk.models.lite_llm import FunctionChunk
from google.adk.models.lite_llm import LiteLlm
from google.adk.models.lite_llm import LiteLLMClient
from google.adk.models.lite_llm import TextChunk
from google.adk.models.llm_request import LlmRequest
from google.genai import types
from litellm import ChatCompletionAssistantMessage
from litellm import ChatCompletionMessageToolCall
from litellm import Function
from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Choices
from litellm.types.utils import Delta
from litellm.types.utils import ModelResponse
from litellm.types.utils import StreamingChoices
import pytest
LLM_REQUEST_WITH_FUNCTION_DECLARATION = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
tools=[
types.Tool(
function_declarations=[
types.FunctionDeclaration(
name="test_function",
description="Test function description",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"test_arg": types.Schema(
type=types.Type.STRING
),
"array_arg": types.Schema(
type=types.Type.ARRAY,
items={
"type": types.Type.STRING,
},
),
"nested_arg": types.Schema(
type=types.Type.OBJECT,
properties={
"nested_key1": types.Schema(
type=types.Type.STRING
),
"nested_key2": types.Schema(
type=types.Type.STRING
),
},
),
},
),
)
]
)
],
),
)
STREAMING_MODEL_RESPONSE = [
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
content="zero, ",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
content="one, ",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
content="two:",
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id="test_tool_call_id",
function=Function(
name="test_function",
arguments='{"test_arg": "test_',
),
index=0,
)
],
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id=None,
function=Function(
name=None,
arguments='value"}',
),
index=0,
)
],
),
)
]
),
ModelResponse(
choices=[
StreamingChoices(
finish_reason="tool_use",
)
]
),
]
@pytest.fixture
def mock_response():
return ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
id="test_tool_call_id",
function=Function(
name="test_function",
arguments='{"test_arg": "test_value"}',
),
)
],
)
)
]
)
@pytest.fixture
def mock_acompletion(mock_response):
return AsyncMock(return_value=mock_response)
@pytest.fixture
def mock_completion(mock_response):
return Mock(return_value=mock_response)
@pytest.fixture
def mock_client(mock_acompletion, mock_completion):
return MockLLMClient(mock_acompletion, mock_completion)
@pytest.fixture
def lite_llm_instance(mock_client):
return LiteLlm(model="test_model", llm_client=mock_client)
class MockLLMClient(LiteLLMClient):
def __init__(self, acompletion_mock, completion_mock):
self.acompletion_mock = acompletion_mock
self.completion_mock = completion_mock
async def acompletion(self, model, messages, tools, **kwargs):
return await self.acompletion_mock(
model=model, messages=messages, tools=tools, **kwargs
)
def completion(self, model, messages, tools, stream, **kwargs):
return self.completion_mock(
model=model, messages=messages, tools=tools, stream=stream, **kwargs
)
@pytest.mark.asyncio
async def test_generate_content_async(mock_acompletion, lite_llm_instance):
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION
):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.content.parts[1].function_call.name == "test_function"
assert response.content.parts[1].function_call.args == {
"test_arg": "test_value"
}
assert response.content.parts[1].function_call.id == "test_tool_call_id"
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)
function_declaration_test_cases = [
(
"simple_function",
types.FunctionDeclaration(
name="test_function",
description="Test function description",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"test_arg": types.Schema(type=types.Type.STRING),
"array_arg": types.Schema(
type=types.Type.ARRAY,
items=types.Schema(
type=types.Type.STRING,
),
),
"nested_arg": types.Schema(
type=types.Type.OBJECT,
properties={
"nested_key1": types.Schema(type=types.Type.STRING),
"nested_key2": types.Schema(type=types.Type.STRING),
},
),
},
),
),
{
"type": "function",
"function": {
"name": "test_function",
"description": "Test function description",
"parameters": {
"type": "object",
"properties": {
"test_arg": {"type": "string"},
"array_arg": {
"items": {"type": "string"},
"type": "array",
},
"nested_arg": {
"properties": {
"nested_key1": {"type": "string"},
"nested_key2": {"type": "string"},
},
"type": "object",
},
},
},
},
},
),
(
"no_description",
types.FunctionDeclaration(
name="test_function_no_description",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"test_arg": types.Schema(type=types.Type.STRING),
},
),
),
{
"type": "function",
"function": {
"name": "test_function_no_description",
"description": "",
"parameters": {
"type": "object",
"properties": {
"test_arg": {"type": "string"},
},
},
},
},
),
(
"empty_parameters",
types.FunctionDeclaration(
name="test_function_empty_params",
parameters=types.Schema(type=types.Type.OBJECT, properties={}),
),
{
"type": "function",
"function": {
"name": "test_function_empty_params",
"description": "",
"parameters": {
"type": "object",
"properties": {},
},
},
},
),
(
"nested_array",
types.FunctionDeclaration(
name="test_function_nested_array",
parameters=types.Schema(
type=types.Type.OBJECT,
properties={
"array_arg": types.Schema(
type=types.Type.ARRAY,
items=types.Schema(
type=types.Type.OBJECT,
properties={
"nested_key": types.Schema(
type=types.Type.STRING
)
},
),
),
},
),
),
{
"type": "function",
"function": {
"name": "test_function_nested_array",
"description": "",
"parameters": {
"type": "object",
"properties": {
"array_arg": {
"items": {
"properties": {
"nested_key": {"type": "string"}
},
"type": "object",
},
"type": "array",
},
},
},
},
},
),
]
@pytest.mark.parametrize(
"_, function_declaration, expected_output",
function_declaration_test_cases,
ids=[case[0] for case in function_declaration_test_cases],
)
def test_function_declaration_to_tool_param(
_, function_declaration, expected_output
):
assert (
_function_declaration_to_tool_param(function_declaration)
== expected_output
)
@pytest.mark.asyncio
async def test_generate_content_async_with_system_instruction(
lite_llm_instance, mock_acompletion
):
mock_response_with_system_instruction = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
)
)
]
)
mock_acompletion.return_value = mock_response_with_system_instruction
llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
],
config=types.GenerateContentConfig(
system_instruction="Test system instruction"
),
)
async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "developer"
assert kwargs["messages"][0]["content"] == "Test system instruction"
assert kwargs["messages"][1]["role"] == "user"
assert kwargs["messages"][1]["content"] == "Test prompt"
@pytest.mark.asyncio
async def test_generate_content_async_with_tool_response(
lite_llm_instance, mock_acompletion
):
mock_response_with_tool_response = ModelResponse(
choices=[
Choices(
message=ChatCompletionAssistantMessage(
role="tool",
content='{"result": "test_result"}',
tool_call_id="test_tool_call_id",
)
)
]
)
mock_acompletion.return_value = mock_response_with_tool_response
llm_request = LlmRequest(
contents=[
types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
),
types.Content(
role="tool",
parts=[
types.Part.from_function_response(
name="test_function",
response={"result": "test_result"},
)
],
),
],
config=types.GenerateContentConfig(
system_instruction="test instruction",
),
)
async for response in lite_llm_instance.generate_content_async(llm_request):
assert response.content.role == "model"
assert response.content.parts[0].text == '{"result": "test_result"}'
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][2]["role"] == "tool"
assert kwargs["messages"][2]["content"] == '{"result": "test_result"}'
def test_content_to_message_param_user_message():
content = types.Content(
role="user", parts=[types.Part.from_text(text="Test prompt")]
)
message = _content_to_message_param(content)
assert message["role"] == "user"
assert message["content"] == "Test prompt"
def test_content_to_message_param_assistant_message():
content = types.Content(
role="assistant", parts=[types.Part.from_text(text="Test response")]
)
message = _content_to_message_param(content)
assert message["role"] == "assistant"
assert message["content"] == "Test response"
def test_content_to_message_param_function_call():
content = types.Content(
role="assistant",
parts=[
types.Part.from_function_call(
name="test_function", args={"test_arg": "test_value"}
)
],
)
content.parts[0].function_call.id = "test_tool_call_id"
message = _content_to_message_param(content)
assert message["role"] == "assistant"
assert message["content"] == []
assert message["tool_calls"][0].type == "function"
assert message["tool_calls"][0].id == "test_tool_call_id"
assert message["tool_calls"][0].function.name == "test_function"
assert (
message["tool_calls"][0].function.arguments
== '{"test_arg": "test_value"}'
)
def test_message_to_generate_content_response_text():
message = ChatCompletionAssistantMessage(
role="assistant",
content="Test response",
)
response = _message_to_generate_content_response(message)
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
def test_message_to_generate_content_response_tool_call():
message = ChatCompletionAssistantMessage(
role="assistant",
content=None,
tool_calls=[
ChatCompletionMessageToolCall(
type="function",
id="test_tool_call_id",
function=Function(
name="test_function",
arguments='{"test_arg": "test_value"}',
),
)
],
)
response = _message_to_generate_content_response(message)
assert response.content.role == "model"
assert response.content.parts[0].function_call.name == "test_function"
assert response.content.parts[0].function_call.args == {
"test_arg": "test_value"
}
assert response.content.parts[0].function_call.id == "test_tool_call_id"
def test_get_content_text():
parts = [types.Part.from_text(text="Test text")]
content = _get_content(parts)
assert content == "Test text"
def test_get_content_image():
parts = [
types.Part.from_bytes(data=b"test_image_data", mime_type="image/png")
]
content = _get_content(parts)
assert content[0]["type"] == "image_url"
assert content[0]["image_url"] == ""
def test_get_content_video():
parts = [
types.Part.from_bytes(data=b"test_video_data", mime_type="video/mp4")
]
content = _get_content(parts)
assert content[0]["type"] == "video_url"
assert content[0]["video_url"] == "data:video/mp4;base64,dGVzdF92aWRlb19kYXRh"
def test_to_litellm_role():
assert _to_litellm_role("model") == "assistant"
assert _to_litellm_role("assistant") == "assistant"
assert _to_litellm_role("user") == "user"
assert _to_litellm_role(None) == "user"
@pytest.mark.parametrize(
"response, expected_chunk, expected_finished",
[
(
ModelResponse(
choices=[
{
"message": {
"content": "this is a test",
}
}
]
),
TextChunk(text="this is a test"),
"stop",
),
(
ModelResponse(
choices=[
StreamingChoices(
finish_reason=None,
delta=Delta(
role="assistant",
tool_calls=[
ChatCompletionDeltaToolCall(
type="function",
id="1",
function=Function(
name="test_function",
arguments='{"key": "va',
),
index=0,
)
],
),
)
]
),
FunctionChunk(id="1", name="test_function", args='{"key": "va'),
None,
),
(
ModelResponse(choices=[{"finish_reason": "tool_calls"}]),
None,
"tool_calls",
),
(ModelResponse(choices=[{}]), None, "stop"),
],
)
def test_model_response_to_chunk(response, expected_chunk, expected_finished):
result = list(_model_response_to_chunk(response))
assert len(result) == 1
chunk, finished = result[0]
if expected_chunk:
assert isinstance(chunk, type(expected_chunk))
assert chunk == expected_chunk
else:
assert chunk is None
assert finished == expected_finished
@pytest.mark.asyncio
async def test_acompletion_additional_args(mock_acompletion, mock_client):
lite_llm_instance = LiteLlm(
# valid args
model="test_model",
llm_client=mock_client,
api_key="test_key",
api_base="some://url",
api_version="2024-09-12",
# invalid args (ignored)
stream=True,
messages=[{"role": "invalid", "content": "invalid"}],
tools=[{
"type": "function",
"function": {
"name": "invalid",
},
}],
)
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION
):
assert response.content.role == "model"
assert response.content.parts[0].text == "Test response"
assert response.content.parts[1].function_call.name == "test_function"
assert response.content.parts[1].function_call.args == {
"test_arg": "test_value"
}
assert response.content.parts[1].function_call.id == "test_tool_call_id"
mock_acompletion.assert_called_once()
_, kwargs = mock_acompletion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert "stream" not in kwargs
assert "llm_client" not in kwargs
assert kwargs["api_base"] == "some://url"
@pytest.mark.asyncio
async def test_completion_additional_args(mock_completion, mock_client):
lite_llm_instance = LiteLlm(
# valid args
model="test_model",
llm_client=mock_client,
api_key="test_key",
api_base="some://url",
api_version="2024-09-12",
# invalid args (ignored)
stream=False,
messages=[{"role": "invalid", "content": "invalid"}],
tools=[{
"type": "function",
"function": {
"name": "invalid",
},
}],
)
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
responses = [
response
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
)
]
assert len(responses) == 4
mock_completion.assert_called_once()
_, kwargs = mock_completion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert kwargs["stream"]
assert "llm_client" not in kwargs
assert kwargs["api_base"] == "some://url"
@pytest.mark.asyncio
async def test_generate_content_async_stream(
mock_completion, lite_llm_instance
):
mock_completion.return_value = iter(STREAMING_MODEL_RESPONSE)
responses = [
response
async for response in lite_llm_instance.generate_content_async(
LLM_REQUEST_WITH_FUNCTION_DECLARATION, stream=True
)
]
assert len(responses) == 4
assert responses[0].content.role == "model"
assert responses[0].content.parts[0].text == "zero, "
assert responses[1].content.role == "model"
assert responses[1].content.parts[0].text == "one, "
assert responses[2].content.role == "model"
assert responses[2].content.parts[0].text == "two:"
assert responses[3].content.role == "model"
assert responses[3].content.parts[0].function_call.name == "test_function"
assert responses[3].content.parts[0].function_call.args == {
"test_arg": "test_value"
}
assert responses[3].content.parts[0].function_call.id == "test_tool_call_id"
mock_completion.assert_called_once()
_, kwargs = mock_completion.call_args
assert kwargs["model"] == "test_model"
assert kwargs["messages"][0]["role"] == "user"
assert kwargs["messages"][0]["content"] == "Test prompt"
assert kwargs["tools"][0]["function"]["name"] == "test_function"
assert (
kwargs["tools"][0]["function"]["description"]
== "Test function description"
)
assert (
kwargs["tools"][0]["function"]["parameters"]["properties"]["test_arg"][
"type"
]
== "string"
)

View File

@@ -0,0 +1,60 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.adk import models
from google.adk.models.anthropic_llm import Claude
from google.adk.models.google_llm import Gemini
from google.adk.models.registry import LLMRegistry
import pytest
@pytest.mark.parametrize(
'model_name',
[
'gemini-1.5-flash',
'gemini-1.5-flash-001',
'gemini-1.5-flash-002',
'gemini-1.5-pro',
'gemini-1.5-pro-001',
'gemini-1.5-pro-002',
'gemini-2.0-flash-exp',
'projects/123456/locations/us-central1/endpoints/123456', # finetuned vertex gemini endpoint
'projects/123456/locations/us-central1/publishers/google/models/gemini-2.0-flash-exp', # vertex gemini long name
],
)
def test_match_gemini_family(model_name):
assert models.LLMRegistry.resolve(model_name) is Gemini
@pytest.mark.parametrize(
'model_name',
[
'claude-3-5-haiku@20241022',
'claude-3-5-sonnet-v2@20241022',
'claude-3-5-sonnet@20240620',
'claude-3-haiku@20240307',
'claude-3-opus@20240229',
'claude-3-sonnet@20240229',
],
)
def test_match_claude_family(model_name):
LLMRegistry.register(Claude)
assert models.LLMRegistry.resolve(model_name) is Claude
def test_non_exist_model():
with pytest.raises(ValueError) as e_info:
models.LLMRegistry.resolve('non-exist-model')
assert 'Model non-exist-model not found.' in str(e_info.value)