chore: Allows BaseAgent in the constructor of InMemoryRunner.

PiperOrigin-RevId: 759818175
This commit is contained in:
Wei Sun (Jack) 2025-05-16 17:42:22 -07:00 committed by Copybara-Service
parent 74b8841e62
commit 021aaddf32
2 changed files with 6 additions and 14 deletions

View File

@ -17,11 +17,9 @@ import time
import agent import agent
from dotenv import load_dotenv from dotenv import load_dotenv
from google.adk import Runner
from google.adk.agents.run_config import RunConfig from google.adk.agents.run_config import RunConfig
from google.adk.artifacts import InMemoryArtifactService
from google.adk.cli.utils import logs from google.adk.cli.utils import logs
from google.adk.sessions import InMemorySessionService from google.adk.runners import InMemoryRunner
from google.adk.sessions import Session from google.adk.sessions import Session
from google.genai import types from google.genai import types
@ -32,15 +30,11 @@ logs.log_to_tmp_folder()
async def main(): async def main():
app_name = 'my_app' app_name = 'my_app'
user_id_1 = 'user1' user_id_1 = 'user1'
session_service = InMemorySessionService() runner = InMemoryRunner(
artifact_service = InMemoryArtifactService()
runner = Runner(
app_name=app_name,
agent=agent.root_agent, agent=agent.root_agent,
artifact_service=artifact_service, app_name=app_name,
session_service=session_service,
) )
session_11 = await session_service.create_session( session_11 = await runner.session_service.create_session(
app_name=app_name, user_id=user_id_1 app_name=app_name, user_id=user_id_1
) )
@ -85,7 +79,7 @@ async def main():
await run_prompt(session_11, 'What numbers did I got?') await run_prompt(session_11, 'What numbers did I got?')
await run_prompt_bytes(session_11, 'Hi bytes') await run_prompt_bytes(session_11, 'Hi bytes')
print( print(
await artifact_service.list_artifact_keys( await runner.artifact_service.list_artifact_keys(
app_name=app_name, user_id=user_id_1, session_id=session_11.id app_name=app_name, user_id=user_id_1, session_id=session_11.id
) )
) )

View File

@ -23,7 +23,6 @@ from typing import Generator
from typing import Optional from typing import Optional
import warnings import warnings
from deprecated import deprecated
from google.genai import types from google.genai import types
from .agents.active_streaming_tool import ActiveStreamingTool from .agents.active_streaming_tool import ActiveStreamingTool
@ -33,7 +32,6 @@ from .agents.invocation_context import new_invocation_context_id
from .agents.live_request_queue import LiveRequestQueue from .agents.live_request_queue import LiveRequestQueue
from .agents.llm_agent import LlmAgent from .agents.llm_agent import LlmAgent
from .agents.run_config import RunConfig from .agents.run_config import RunConfig
from .agents.run_config import StreamingMode
from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.base_artifact_service import BaseArtifactService
from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService
from .events.event import Event from .events.event import Event
@ -475,7 +473,7 @@ class InMemoryRunner(Runner):
session service for the runner. session service for the runner.
""" """
def __init__(self, agent: LlmAgent, *, app_name: str = 'InMemoryRunner'): def __init__(self, agent: BaseAgent, *, app_name: str = 'InMemoryRunner'):
"""Initializes the InMemoryRunner. """Initializes the InMemoryRunner.
Args: Args: