Add two adk run options:

--replay : a json file that contains the initial session state and user queries, adk run will create a new session based on the state and run the user queries against the session. Users cannot continue to interact with agent.

--resume : a json file that contains the previously saved session (by --save_session option), adk run will replay this session and then user can continue to interact with the agent.
PiperOrigin-RevId: 752923403
This commit is contained in:
Xiang (Sean) Zhou 2025-04-29 16:22:43 -07:00 committed by Copybara-Service
parent 2a9ddec7e3
commit dbbeb190b0
2 changed files with 105 additions and 49 deletions

View File

@ -39,12 +39,12 @@ class InputFile(BaseModel):
async def run_input_file( async def run_input_file(
app_name: str, app_name: str,
user_id: str,
root_agent: LlmAgent, root_agent: LlmAgent,
artifact_service: BaseArtifactService, artifact_service: BaseArtifactService,
session: Session,
session_service: BaseSessionService, session_service: BaseSessionService,
input_path: str, input_path: str,
) -> None: ) -> Session:
runner = Runner( runner = Runner(
app_name=app_name, app_name=app_name,
agent=root_agent, agent=root_agent,
@ -55,9 +55,11 @@ async def run_input_file(
input_file = InputFile.model_validate_json(f.read()) input_file = InputFile.model_validate_json(f.read())
input_file.state['_time'] = datetime.now() input_file.state['_time'] = datetime.now()
session.state = input_file.state session = session_service.create_session(
app_name=app_name, user_id=user_id, state=input_file.state
)
for query in input_file.queries: for query in input_file.queries:
click.echo(f'user: {query}') click.echo(f'[user]: {query}')
content = types.Content(role='user', parts=[types.Part(text=query)]) content = types.Content(role='user', parts=[types.Part(text=query)])
async for event in runner.run_async( async for event in runner.run_async(
user_id=session.user_id, session_id=session.id, new_message=content user_id=session.user_id, session_id=session.id, new_message=content
@ -65,23 +67,23 @@ async def run_input_file(
if event.content and event.content.parts: if event.content and event.content.parts:
if text := ''.join(part.text or '' for part in event.content.parts): if text := ''.join(part.text or '' for part in event.content.parts):
click.echo(f'[{event.author}]: {text}') click.echo(f'[{event.author}]: {text}')
return session
async def run_interactively( async def run_interactively(
app_name: str,
root_agent: LlmAgent, root_agent: LlmAgent,
artifact_service: BaseArtifactService, artifact_service: BaseArtifactService,
session: Session, session: Session,
session_service: BaseSessionService, session_service: BaseSessionService,
) -> None: ) -> None:
runner = Runner( runner = Runner(
app_name=app_name, app_name=session.app_name,
agent=root_agent, agent=root_agent,
artifact_service=artifact_service, artifact_service=artifact_service,
session_service=session_service, session_service=session_service,
) )
while True: while True:
query = input('user: ') query = input('[user]: ')
if not query or not query.strip(): if not query or not query.strip():
continue continue
if query == 'exit': if query == 'exit':
@ -100,7 +102,8 @@ async def run_cli(
*, *,
agent_parent_dir: str, agent_parent_dir: str,
agent_folder_name: str, agent_folder_name: str,
json_file_path: Optional[str] = None, input_file: Optional[str] = None,
saved_session_file: Optional[str] = None,
save_session: bool, save_session: bool,
) -> None: ) -> None:
"""Runs an interactive CLI for a certain agent. """Runs an interactive CLI for a certain agent.
@ -109,8 +112,11 @@ async def run_cli(
agent_parent_dir: str, the absolute path of the parent folder of the agent agent_parent_dir: str, the absolute path of the parent folder of the agent
folder. folder.
agent_folder_name: str, the name of the agent folder. agent_folder_name: str, the name of the agent folder.
json_file_path: Optional[str], the absolute path to the json file, either input_file: Optional[str], the absolute path to the json file that contains
*.input.json or *.session.json. the initial session state and user queries, exclusive with
saved_session_file.
saved_session_file: Optional[str], the absolute path to the json file that
contains a previously saved session, exclusive with input_file.
save_session: bool, whether to save the session on exit. save_session: bool, whether to save the session on exit.
""" """
if agent_parent_dir not in sys.path: if agent_parent_dir not in sys.path:
@ -118,46 +124,50 @@ async def run_cli(
artifact_service = InMemoryArtifactService() artifact_service = InMemoryArtifactService()
session_service = InMemorySessionService() session_service = InMemorySessionService()
session = session_service.create_session(
app_name=agent_folder_name, user_id='test_user'
)
agent_module_path = os.path.join(agent_parent_dir, agent_folder_name) agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
agent_module = importlib.import_module(agent_folder_name) agent_module = importlib.import_module(agent_folder_name)
user_id = 'test_user'
session = session_service.create_session(
app_name=agent_folder_name, user_id=user_id
)
root_agent = agent_module.agent.root_agent root_agent = agent_module.agent.root_agent
envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir) envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
if json_file_path: if input_file:
if json_file_path.endswith('.input.json'): session = await run_input_file(
await run_input_file(
app_name=agent_folder_name, app_name=agent_folder_name,
user_id=user_id,
root_agent=root_agent, root_agent=root_agent,
artifact_service=artifact_service, artifact_service=artifact_service,
session=session,
session_service=session_service, session_service=session_service,
input_path=json_file_path, input_path=input_file,
) )
elif json_file_path.endswith('.session.json'): elif saved_session_file:
with open(json_file_path, 'r') as f:
session = Session.model_validate_json(f.read()) loaded_session = None
for content in session.get_contents(): with open(saved_session_file, 'r') as f:
if content.role == 'user': loaded_session = Session.model_validate_json(f.read())
print('user: ', content.parts[0].text)
if loaded_session:
for event in loaded_session.events:
session_service.append_event(session, event)
content = event.content
if not content or not content.parts or not content.parts[0].text:
continue
if event.author == 'user':
click.echo(f'[user]: {content.parts[0].text}')
else: else:
print(content.parts[0].text) click.echo(f'[{event.author}]: {content.parts[0].text}')
await run_interactively( await run_interactively(
agent_folder_name,
root_agent, root_agent,
artifact_service, artifact_service,
session, session,
session_service, session_service,
) )
else: else:
print(f'Unsupported file type: {json_file_path}') click.echo(f'Running agent {root_agent.name}, type exit to exit.')
exit(1)
else:
print(f'Running agent {root_agent.name}, type exit to exit.')
await run_interactively( await run_interactively(
agent_folder_name,
root_agent, root_agent,
artifact_service, artifact_service,
session, session,
@ -165,9 +175,6 @@ async def run_cli(
) )
if save_session: if save_session:
if json_file_path:
session_path = json_file_path.replace('.input.json', '.session.json')
else:
session_id = input('Session ID to save: ') session_id = input('Session ID to save: ')
session_path = f'{agent_module_path}/{session_id}.session.json' session_path = f'{agent_module_path}/{session_id}.session.json'

View File

@ -96,6 +96,23 @@ def cli_create_cmd(
) )
def validate_exclusive(ctx, param, value):
# Store the validated parameters in the context
if not hasattr(ctx, "exclusive_opts"):
ctx.exclusive_opts = {}
# If this option has a value and we've already seen another exclusive option
if value is not None and any(ctx.exclusive_opts.values()):
exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
raise click.UsageError(
f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
)
# Record this option's value
ctx.exclusive_opts[param.name] = value is not None
return value
@main.command("run") @main.command("run")
@click.option( @click.option(
"--save_session", "--save_session",
@ -105,13 +122,43 @@ def cli_create_cmd(
default=False, default=False,
help="Optional. Whether to save the session to a json file on exit.", help="Optional. Whether to save the session to a json file on exit.",
) )
@click.option(
"--replay",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains the initial state of the session and user"
" queries. A new session will be created using this state. And user"
" queries are run againt the newly created session. Users cannot"
" continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.option(
"--resume",
type=click.Path(
exists=True, dir_okay=False, file_okay=True, resolve_path=True
),
help=(
"The json file that contains a previously saved session (by"
"--save_session option). The previous session will be re-displayed. And"
" user can continue to interact with the agent."
),
callback=validate_exclusive,
)
@click.argument( @click.argument(
"agent", "agent",
type=click.Path( type=click.Path(
exists=True, dir_okay=True, file_okay=False, resolve_path=True exists=True, dir_okay=True, file_okay=False, resolve_path=True
), ),
) )
def cli_run(agent: str, save_session: bool): def cli_run(
agent: str,
save_session: bool,
replay: Optional[str],
resume: Optional[str],
):
"""Runs an interactive CLI for a certain agent. """Runs an interactive CLI for a certain agent.
AGENT: The path to the agent source code folder. AGENT: The path to the agent source code folder.
@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
run_cli( run_cli(
agent_parent_dir=agent_parent_folder, agent_parent_dir=agent_parent_folder,
agent_folder_name=agent_folder_name, agent_folder_name=agent_folder_name,
input_file=replay,
saved_session_file=resume,
save_session=save_session, save_session=save_session,
) )
) )