adk-python/tests/unittests/agents/test_langgraph_agent.py
Jack Sun 05142a07cc
Moves unit tests to the root folder and adds a GitHub Action to run them. (#72)
* Move unit tests to root package.

* Adds deps to the "test" extra, and marks two broken tests in tests/unittests/auth/test_auth_handler.py

* Adds GitHub workflow

* Minor fix in lite_llm.py for Python 3.9.

* format pyproject.toml
2025-04-11 08:25:59 -07:00

192 lines
6.5 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock
from google.adk.agents.invocation_context import InvocationContext
from google.adk.agents.langgraph_agent import LangGraphAgent
from google.adk.events import Event
from google.genai import types
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langgraph.graph.graph import CompiledGraph
import pytest
def _event(author: str, role: str, text: str) -> Event:
  """Builds a single-text-part Event for the shared test invocation.

  Args:
    author: Event author (e.g. "user", "root_agent", "weather_agent").
    role: Content role, "user" or "model".
    text: The text of the single content part.

  Returns:
    An Event with invocation_id "test_invocation_id".
  """
  return Event(
      invocation_id="test_invocation_id",
      author=author,
      content=types.Content(
          role=role,
          parts=[types.Part.from_text(text=text)],
      ),
  )


@pytest.mark.parametrize(
    "checkpointer_value, events_list, expected_messages",
    [
        # With a checkpointer: only the user input is expected; the event
        # authored by another agent (root_agent) is not forwarded.
        (
            MagicMock(),
            [
                _event("user", "user", "test prompt"),
                _event("root_agent", "model", "(some delegation)"),
            ],
            [
                SystemMessage(content="test system prompt"),
                HumanMessage(content="test prompt"),
            ],
        ),
        # Without a checkpointer: the whole conversation is expected. User
        # events become HumanMessages, this agent's own replies become
        # AIMessages, and other agents' replies (root_agent) are dropped.
        (
            None,
            [
                _event("user", "user", "user prompt 1"),
                _event("root_agent", "model", "root agent response"),
                _event("weather_agent", "model", "weather agent response"),
                _event("user", "user", "user prompt 2"),
            ],
            [
                SystemMessage(content="test system prompt"),
                HumanMessage(content="user prompt 1"),
                AIMessage(content="weather agent response"),
                HumanMessage(content="user prompt 2"),
            ],
        ),
        # With a checkpointer on a follow-up turn: only the user message that
        # comes after weather_agent's last reply is expected.
        (
            MagicMock(),
            [
                _event("user", "user", "user prompt 1"),
                _event("root_agent", "model", "root agent response"),
                _event("weather_agent", "model", "weather agent response"),
                _event("user", "user", "user prompt 2"),
            ],
            [
                SystemMessage(content="test system prompt"),
                HumanMessage(content="user prompt 2"),
            ],
        ),
    ],
    ids=[
        "checkpointer_first_turn",
        "no_checkpointer_full_history",
        "checkpointer_follow_up_turn",
    ],
)
@pytest.mark.asyncio
async def test_langgraph_agent(
    checkpointer_value, events_list, expected_messages
):
  """Checks which session events LangGraphAgent sends to graph.invoke.

  The graph is a mock. It reports an empty state, carries the parametrized
  checkpointer, and always returns a single AIMessage("test response"). The
  test asserts three things: that reply is surfaced as the agent's last
  event, ``invoke`` is called exactly once, and the call uses
  ``expected_messages`` with a config keyed by the session id.

  Args:
    checkpointer_value: Value assigned to the mocked graph's ``checkpointer``
      (a MagicMock or None).
    events_list: Session events visible to the agent.
    expected_messages: Messages the agent is expected to pass to
      ``graph.invoke``.
  """
  mock_graph = MagicMock(spec=CompiledGraph)
  mock_graph_state = MagicMock()
  mock_graph_state.values = {}
  mock_graph.get_state.return_value = mock_graph_state
  mock_graph.checkpointer = checkpointer_value
  mock_graph.invoke.return_value = {
      "messages": [AIMessage(content="test response")]
  }

  mock_session = MagicMock()
  mock_session.events = events_list

  mock_parent_context = MagicMock(spec=InvocationContext)
  mock_parent_context.session = mock_session
  mock_parent_context.branch = "parent_agent"
  mock_parent_context.end_invocation = False
  mock_parent_context.invocation_id = "test_invocation_id"
  # model_copy returns the same mock, so the attributes configured above stay
  # visible if the agent copies the context (presumed; confirm in the agent).
  mock_parent_context.model_copy.return_value = mock_parent_context

  weather_agent = LangGraphAgent(
      name="weather_agent",
      description="A agent that answers weather questions",
      instruction="test system prompt",
      graph=mock_graph,
  )

  result_event = None
  async for event in weather_agent.run_async(mock_parent_context):
    result_event = event

  # Fail with a clear message instead of an AttributeError on None.
  assert result_event is not None, "run_async yielded no events"
  assert result_event.author == "weather_agent"
  assert result_event.content.parts[0].text == "test response"
  mock_graph.invoke.assert_called_once_with(
      {"messages": expected_messages},
      {"configurable": {"thread_id": mock_session.id}},
  )