Files
evo-ai/.venv/lib/python3.10/site-packages/google/adk/agents/langgraph_agent.py
2025-04-25 15:30:54 -03:00

141 lines
4.2 KiB
Python

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import AsyncGenerator
from typing import Union
from google.genai import types
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langgraph.graph.graph import CompiledGraph
from pydantic import ConfigDict
from typing_extensions import override
from ..events.event import Event
from .base_agent import BaseAgent
from .invocation_context import InvocationContext
def _get_last_human_messages(events: list[Event]) -> list[HumanMessage]:
    """Collects the trailing run of user messages from a list of events.

    The events are scanned from newest to oldest. User events that carry
    content parts are converted to ``HumanMessage`` objects (using the text of
    the first part only). Scanning stops at the first non-user event that is
    encountered after at least one user message has been collected.

    Args:
      events: the list of events, ordered oldest to newest

    Returns:
      the collected human messages, in their original (chronological) order
    """
    trailing: list[HumanMessage] = []
    for event in reversed(events):
        if event.author != 'user':
            if trailing:
                # The run of latest user messages ended; nothing older matters.
                break
            # Non-user events newer than the latest user message are skipped.
            continue
        if not event.content or not event.content.parts:
            # User event without any content parts carries nothing to forward.
            continue
        first_part = event.content.parts[0]
        trailing.append(HumanMessage(content=first_part.text))
    # Messages were gathered newest-first; restore chronological order.
    trailing.reverse()
    return trailing
class LangGraphAgent(BaseAgent):
    """Agent that answers each invocation by running a compiled LangGraph graph.

    Currently a concept implementation, supports single and multi-turn.

    The session id is used as the LangGraph ``thread_id``. When the graph was
    compiled with a checkpointer, only the latest user messages are sent to the
    graph; otherwise the whole user/agent conversation is rebuilt from the
    session events on every invocation.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    # The compiled LangGraph graph that is invoked once per agent invocation.
    graph: CompiledGraph

    # Optional system instruction; sent as a SystemMessage when the graph has
    # no stored messages for the current thread.
    instruction: str = ''

    @override
    async def _run_async_impl(
        self,
        ctx: InvocationContext,
    ) -> AsyncGenerator[Event, None]:
        """Runs the graph once and yields a single event with its final message.

        Args:
          ctx: the invocation context; provides the session (id and events),
            the invocation id and the branch

        Yields:
          one event authored by this agent whose content is the ``content`` of
          the last message in the graph's final state
        """
        # Needed for langgraph checkpointer (for subsequent invocations; multi-turn)
        config: RunnableConfig = {'configurable': {'thread_id': ctx.session.id}}

        # Look up the messages already stored for this thread. The state is only
        # queried when a checkpointer is configured: LangGraph's get_state()
        # raises ValueError("No checkpointer set") otherwise, which would make
        # the checkpointer-less mode handled by _get_messages() unreachable.
        graph_messages = []
        if self.graph.checkpointer:
            current_graph_state = self.graph.get_state(config)
            if current_graph_state.values:
                graph_messages = current_graph_state.values.get('messages', [])

        # Add instruction as SystemMessage if graph state is empty
        messages = (
            [SystemMessage(content=self.instruction)]
            if self.instruction and not graph_messages
            else []
        )

        # Add events to messages (evaluating the memory used; parent agent vs checkpointer)
        messages += self._get_messages(ctx.session.events)

        # Use the Runnable
        final_state = self.graph.invoke({'messages': messages}, config)
        result = final_state['messages'][-1].content

        result_event = Event(
            invocation_id=ctx.invocation_id,
            author=self.name,
            branch=ctx.branch,
            content=types.Content(
                role='model',
                parts=[types.Part.from_text(text=result)],
            ),
        )
        yield result_event

    def _get_messages(
        self, events: list[Event]
    ) -> list[Union[HumanMessage, AIMessage]]:
        """Extracts messages from given list of events.

        If the developer provides their own memory within langgraph, we return the
        last user messages only. Otherwise, we return all messages between the user
        and the agent.

        Args:
          events: the list of events

        Returns:
          list of messages
        """
        if self.graph.checkpointer:
            # The graph keeps its own history; only forward what is new.
            return _get_last_human_messages(events)
        else:
            return self._get_conversation_with_agent(events)

    def _get_conversation_with_agent(
        self, events: list[Event]
    ) -> list[Union[HumanMessage, AIMessage]]:
        """Rebuilds the conversation between the user and this agent.

        Events without content parts, and events authored by anyone other than
        the user or this agent, are ignored. Only the text of the first content
        part of each event is used.

        Args:
          events: the list of events

        Returns:
          list of messages, in the same order as the events
        """
        messages = []
        for event in events:
            if not event.content or not event.content.parts:
                continue
            if event.author == 'user':
                messages.append(HumanMessage(content=event.content.parts[0].text))
            elif event.author == self.name:
                messages.append(AIMessage(content=event.content.parts[0].text))
        return messages